xref: /petsc/src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp (revision 4742e46b56cb5d0762110e30c569ce3737a8e22a)
1*4742e46bSJacob Faibussowitsch #ifndef PETSCMATSEQDENSECUPM_HPP
2*4742e46bSJacob Faibussowitsch #define PETSCMATSEQDENSECUPM_HPP
3*4742e46bSJacob Faibussowitsch 
4*4742e46bSJacob Faibussowitsch #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/
5*4742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/seq/dense.h>
6*4742e46bSJacob Faibussowitsch 
7*4742e46bSJacob Faibussowitsch #if defined(__cplusplus)
8*4742e46bSJacob Faibussowitsch   #include <petsc/private/deviceimpl.h> // PetscDeviceContextGetOptionalNullContext_Internal()
9*4742e46bSJacob Faibussowitsch   #include <petsc/private/randomimpl.h> // _p_PetscRandom
10*4742e46bSJacob Faibussowitsch   #include <petsc/private/vecimpl.h>    // _p_Vec
11*4742e46bSJacob Faibussowitsch   #include <petsc/private/cupmobject.hpp>
12*4742e46bSJacob Faibussowitsch   #include <petsc/private/cupmsolverinterface.hpp>
13*4742e46bSJacob Faibussowitsch 
14*4742e46bSJacob Faibussowitsch   #include <petsc/private/cpp/type_traits.hpp> // PetscObjectCast()
15*4742e46bSJacob Faibussowitsch   #include <petsc/private/cpp/utility.hpp>     // util::exchange()
16*4742e46bSJacob Faibussowitsch 
17*4742e46bSJacob Faibussowitsch   #include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp> // for VecSeq_CUPM
18*4742e46bSJacob Faibussowitsch 
19*4742e46bSJacob Faibussowitsch namespace Petsc
20*4742e46bSJacob Faibussowitsch {
21*4742e46bSJacob Faibussowitsch 
22*4742e46bSJacob Faibussowitsch namespace mat
23*4742e46bSJacob Faibussowitsch {
24*4742e46bSJacob Faibussowitsch 
25*4742e46bSJacob Faibussowitsch namespace cupm
26*4742e46bSJacob Faibussowitsch {
27*4742e46bSJacob Faibussowitsch 
28*4742e46bSJacob Faibussowitsch namespace impl
29*4742e46bSJacob Faibussowitsch {
30*4742e46bSJacob Faibussowitsch 
31*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
32*4742e46bSJacob Faibussowitsch class MatDense_Seq_CUPM : MatDense_CUPM<T, MatDense_Seq_CUPM<T>> {
33*4742e46bSJacob Faibussowitsch public:
34*4742e46bSJacob Faibussowitsch   MATDENSECUPM_HEADER(T, MatDense_Seq_CUPM<T>);
35*4742e46bSJacob Faibussowitsch 
36*4742e46bSJacob Faibussowitsch private:
37*4742e46bSJacob Faibussowitsch   struct Mat_SeqDenseCUPM {
38*4742e46bSJacob Faibussowitsch     PetscScalar *d_v;           // pointer to the matrix on the GPU
39*4742e46bSJacob Faibussowitsch     PetscScalar *unplacedarray; // if one called MatCUPMDensePlaceArray(), this is where it stashed the original
40*4742e46bSJacob Faibussowitsch     bool         d_user_alloc;
41*4742e46bSJacob Faibussowitsch     bool         d_unplaced_user_alloc;
42*4742e46bSJacob Faibussowitsch     // factorization support
43*4742e46bSJacob Faibussowitsch     cupmBlasInt_t *d_fact_ipiv;  // device pivots
44*4742e46bSJacob Faibussowitsch     cupmScalar_t  *d_fact_tau;   // device QR tau vector
45*4742e46bSJacob Faibussowitsch     cupmBlasInt_t *d_fact_info;  // device info
46*4742e46bSJacob Faibussowitsch     cupmScalar_t  *d_fact_work;  // device workspace
47*4742e46bSJacob Faibussowitsch     cupmBlasInt_t  d_fact_lwork; // size of device workspace
48*4742e46bSJacob Faibussowitsch     // workspace
49*4742e46bSJacob Faibussowitsch     Vec workvec;
50*4742e46bSJacob Faibussowitsch   };
51*4742e46bSJacob Faibussowitsch 
52*4742e46bSJacob Faibussowitsch   static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;
53*4742e46bSJacob Faibussowitsch 
54*4742e46bSJacob Faibussowitsch   static PetscErrorCode HostToDevice_(Mat, PetscDeviceContext) noexcept;
55*4742e46bSJacob Faibussowitsch   static PetscErrorCode DeviceToHost_(Mat, PetscDeviceContext) noexcept;
56*4742e46bSJacob Faibussowitsch 
57*4742e46bSJacob Faibussowitsch   static PetscErrorCode CheckCUPMSolverInfo_(const cupmBlasInt_t *, cupmStream_t) noexcept;
58*4742e46bSJacob Faibussowitsch 
59*4742e46bSJacob Faibussowitsch   template <typename Derived>
60*4742e46bSJacob Faibussowitsch   struct SolveCommon;
61*4742e46bSJacob Faibussowitsch   struct SolveQR;
62*4742e46bSJacob Faibussowitsch   struct SolveCholesky;
63*4742e46bSJacob Faibussowitsch   struct SolveLU;
64*4742e46bSJacob Faibussowitsch 
65*4742e46bSJacob Faibussowitsch   template <typename Solver, bool transpose>
66*4742e46bSJacob Faibussowitsch   static PetscErrorCode MatSolve_Factored_Dispatch_(Mat, Vec, Vec) noexcept;
67*4742e46bSJacob Faibussowitsch   template <typename Solver, bool transpose>
68*4742e46bSJacob Faibussowitsch   static PetscErrorCode MatMatSolve_Factored_Dispatch_(Mat, Mat, Mat) noexcept;
69*4742e46bSJacob Faibussowitsch   template <bool transpose>
70*4742e46bSJacob Faibussowitsch   static PetscErrorCode MatMultAdd_Dispatch_(Mat, Vec, Vec, Vec) noexcept;
71*4742e46bSJacob Faibussowitsch 
72*4742e46bSJacob Faibussowitsch   template <bool to_host>
73*4742e46bSJacob Faibussowitsch   static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;
74*4742e46bSJacob Faibussowitsch 
75*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;
76*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr Mat_SeqDense *MatIMPLCast_(Mat) noexcept;
77*4742e46bSJacob Faibussowitsch 
78*4742e46bSJacob Faibussowitsch public:
79*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr Mat_SeqDenseCUPM *MatCUPMCast(Mat) noexcept;
80*4742e46bSJacob Faibussowitsch 
81*4742e46bSJacob Faibussowitsch   // define these by hand since they don't fit the above mold
82*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr const char *MatConvert_seqdensecupm_seqdense_C() noexcept;
83*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_seqaij_seqdensecupm_C() noexcept;
84*4742e46bSJacob Faibussowitsch 
85*4742e46bSJacob Faibussowitsch   static PetscErrorCode Create(Mat) noexcept;
86*4742e46bSJacob Faibussowitsch   static PetscErrorCode Destroy(Mat) noexcept;
87*4742e46bSJacob Faibussowitsch   static PetscErrorCode SetUp(Mat) noexcept;
88*4742e46bSJacob Faibussowitsch   static PetscErrorCode Reset(Mat) noexcept;
89*4742e46bSJacob Faibussowitsch 
90*4742e46bSJacob Faibussowitsch   static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
91*4742e46bSJacob Faibussowitsch   static PetscErrorCode Convert_SeqDense_SeqDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;
92*4742e46bSJacob Faibussowitsch   static PetscErrorCode Convert_SeqDenseCUPM_SeqDense(Mat, MatType, MatReuse, Mat *) noexcept;
93*4742e46bSJacob Faibussowitsch 
94*4742e46bSJacob Faibussowitsch   template <PetscMemType, PetscMemoryAccessMode>
95*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext) noexcept;
96*4742e46bSJacob Faibussowitsch   template <PetscMemType, PetscMemoryAccessMode>
97*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext) noexcept;
98*4742e46bSJacob Faibussowitsch   template <PetscMemoryAccessMode>
99*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetArrayAndMemType(Mat, PetscScalar **, PetscMemType *, PetscDeviceContext) noexcept;
100*4742e46bSJacob Faibussowitsch   template <PetscMemoryAccessMode>
101*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreArrayAndMemType(Mat, PetscScalar **, PetscDeviceContext) noexcept;
102*4742e46bSJacob Faibussowitsch 
103*4742e46bSJacob Faibussowitsch private:
104*4742e46bSJacob Faibussowitsch   template <PetscMemType mtype, PetscMemoryAccessMode mode>
105*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
106*4742e46bSJacob Faibussowitsch   {
107*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
108*4742e46bSJacob Faibussowitsch 
109*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
110*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
111*4742e46bSJacob Faibussowitsch     PetscCall(GetArray<mtype, mode>(m, p, dctx));
112*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
113*4742e46bSJacob Faibussowitsch   }
114*4742e46bSJacob Faibussowitsch 
115*4742e46bSJacob Faibussowitsch   template <PetscMemType mtype, PetscMemoryAccessMode mode>
116*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
117*4742e46bSJacob Faibussowitsch   {
118*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
119*4742e46bSJacob Faibussowitsch 
120*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
121*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
122*4742e46bSJacob Faibussowitsch     PetscCall(RestoreArray<mtype, mode>(m, p, dctx));
123*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
124*4742e46bSJacob Faibussowitsch   }
125*4742e46bSJacob Faibussowitsch 
126*4742e46bSJacob Faibussowitsch   template <PetscMemoryAccessMode mode>
127*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetArrayAndMemTypeC_(Mat m, PetscScalar **p, PetscMemType *tp) noexcept
128*4742e46bSJacob Faibussowitsch   {
129*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
130*4742e46bSJacob Faibussowitsch 
131*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
132*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
133*4742e46bSJacob Faibussowitsch     PetscCall(GetArrayAndMemType<mode>(m, p, tp, dctx));
134*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
135*4742e46bSJacob Faibussowitsch   }
136*4742e46bSJacob Faibussowitsch 
137*4742e46bSJacob Faibussowitsch   template <PetscMemoryAccessMode mode>
138*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreArrayAndMemTypeC_(Mat m, PetscScalar **p) noexcept
139*4742e46bSJacob Faibussowitsch   {
140*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
141*4742e46bSJacob Faibussowitsch 
142*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
143*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
144*4742e46bSJacob Faibussowitsch     PetscCall(RestoreArrayAndMemType<mode>(m, p, dctx));
145*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
146*4742e46bSJacob Faibussowitsch   }
147*4742e46bSJacob Faibussowitsch 
148*4742e46bSJacob Faibussowitsch public:
149*4742e46bSJacob Faibussowitsch   static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
150*4742e46bSJacob Faibussowitsch   static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
151*4742e46bSJacob Faibussowitsch   static PetscErrorCode ResetArray(Mat) noexcept;
152*4742e46bSJacob Faibussowitsch 
153*4742e46bSJacob Faibussowitsch   template <bool transpose_A, bool transpose_B>
154*4742e46bSJacob Faibussowitsch   static PetscErrorCode MatMatMult_Numeric_Dispatch(Mat, Mat, Mat) noexcept;
155*4742e46bSJacob Faibussowitsch   static PetscErrorCode Copy(Mat, Mat, MatStructure) noexcept;
156*4742e46bSJacob Faibussowitsch   static PetscErrorCode ZeroEntries(Mat) noexcept;
157*4742e46bSJacob Faibussowitsch   static PetscErrorCode Scale(Mat, PetscScalar) noexcept;
158*4742e46bSJacob Faibussowitsch   static PetscErrorCode Shift(Mat, PetscScalar) noexcept;
159*4742e46bSJacob Faibussowitsch   static PetscErrorCode AXPY(Mat, PetscScalar, Mat, MatStructure) noexcept;
160*4742e46bSJacob Faibussowitsch   static PetscErrorCode Duplicate(Mat, MatDuplicateOption, Mat *) noexcept;
161*4742e46bSJacob Faibussowitsch   static PetscErrorCode SetRandom(Mat, PetscRandom) noexcept;
162*4742e46bSJacob Faibussowitsch 
163*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetColumnVector(Mat, Vec, PetscInt) noexcept;
164*4742e46bSJacob Faibussowitsch   template <PetscMemoryAccessMode>
165*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
166*4742e46bSJacob Faibussowitsch   template <PetscMemoryAccessMode>
167*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;
168*4742e46bSJacob Faibussowitsch 
169*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetFactor(Mat, MatFactorType, Mat *) noexcept;
170*4742e46bSJacob Faibussowitsch   static PetscErrorCode InvertFactors(Mat) noexcept;
171*4742e46bSJacob Faibussowitsch 
172*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetSubMatrix(Mat, PetscInt, PetscInt, PetscInt, PetscInt, Mat *) noexcept;
173*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreSubMatrix(Mat, Mat *) noexcept;
174*4742e46bSJacob Faibussowitsch };
175*4742e46bSJacob Faibussowitsch 
176*4742e46bSJacob Faibussowitsch } // namespace impl
177*4742e46bSJacob Faibussowitsch 
178*4742e46bSJacob Faibussowitsch namespace
179*4742e46bSJacob Faibussowitsch {
180*4742e46bSJacob Faibussowitsch 
181*4742e46bSJacob Faibussowitsch // Declare this here so that the functions below can make use of it
182*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
183*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateSeqDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
184*4742e46bSJacob Faibussowitsch {
185*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
186*4742e46bSJacob Faibussowitsch   PetscCall(impl::MatDense_Seq_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, m, n, data, A, dctx, preallocate));
187*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
188*4742e46bSJacob Faibussowitsch }
189*4742e46bSJacob Faibussowitsch 
190*4742e46bSJacob Faibussowitsch } // anonymous namespace
191*4742e46bSJacob Faibussowitsch 
192*4742e46bSJacob Faibussowitsch namespace impl
193*4742e46bSJacob Faibussowitsch {
194*4742e46bSJacob Faibussowitsch 
195*4742e46bSJacob Faibussowitsch // ==========================================================================================
196*4742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Utility
197*4742e46bSJacob Faibussowitsch // ==========================================================================================
198*4742e46bSJacob Faibussowitsch 
199*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
200*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetPreallocation_(Mat m, PetscDeviceContext dctx, PetscScalar *user_device_array) noexcept
201*4742e46bSJacob Faibussowitsch {
202*4742e46bSJacob Faibussowitsch   const auto   mcu   = MatCUPMCast(m);
203*4742e46bSJacob Faibussowitsch   const auto   nrows = m->rmap->n;
204*4742e46bSJacob Faibussowitsch   const auto   ncols = m->cmap->n;
205*4742e46bSJacob Faibussowitsch   auto        &lda   = MatIMPLCast(m)->lda;
206*4742e46bSJacob Faibussowitsch   cupmStream_t stream;
207*4742e46bSJacob Faibussowitsch 
208*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
209*4742e46bSJacob Faibussowitsch   PetscCheckTypeName(m, MATSEQDENSECUPM());
210*4742e46bSJacob Faibussowitsch   PetscValidDeviceContext(dctx, 2);
211*4742e46bSJacob Faibussowitsch   PetscCall(checkCupmBlasIntCast(nrows));
212*4742e46bSJacob Faibussowitsch   PetscCall(checkCupmBlasIntCast(ncols));
213*4742e46bSJacob Faibussowitsch   PetscCall(GetHandlesFrom_(dctx, &stream));
214*4742e46bSJacob Faibussowitsch   if (lda <= 0) lda = nrows;
215*4742e46bSJacob Faibussowitsch   if (!mcu->d_user_alloc) PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream));
216*4742e46bSJacob Faibussowitsch   if (user_device_array) {
217*4742e46bSJacob Faibussowitsch     mcu->d_user_alloc = PETSC_TRUE;
218*4742e46bSJacob Faibussowitsch     mcu->d_v          = user_device_array;
219*4742e46bSJacob Faibussowitsch   } else {
220*4742e46bSJacob Faibussowitsch     PetscInt size;
221*4742e46bSJacob Faibussowitsch 
222*4742e46bSJacob Faibussowitsch     mcu->d_user_alloc = PETSC_FALSE;
223*4742e46bSJacob Faibussowitsch     PetscCall(PetscIntMultError(lda, ncols, &size));
224*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMMallocAsync(&mcu->d_v, size, stream));
225*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMMemsetAsync(mcu->d_v, 0, size, stream));
226*4742e46bSJacob Faibussowitsch   }
227*4742e46bSJacob Faibussowitsch   m->offloadmask = PETSC_OFFLOAD_GPU;
228*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
229*4742e46bSJacob Faibussowitsch }
230*4742e46bSJacob Faibussowitsch 
231*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
232*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::HostToDevice_(Mat m, PetscDeviceContext dctx) noexcept
233*4742e46bSJacob Faibussowitsch {
234*4742e46bSJacob Faibussowitsch   const auto nrows = m->rmap->n;
235*4742e46bSJacob Faibussowitsch   const auto ncols = m->cmap->n;
236*4742e46bSJacob Faibussowitsch   const auto copy  = m->offloadmask == PETSC_OFFLOAD_CPU || m->offloadmask == PETSC_OFFLOAD_UNALLOCATED;
237*4742e46bSJacob Faibussowitsch 
238*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
239*4742e46bSJacob Faibussowitsch   PetscCheckTypeName(m, MATSEQDENSECUPM());
240*4742e46bSJacob Faibussowitsch   if (m->boundtocpu) PetscFunctionReturn(PETSC_SUCCESS);
241*4742e46bSJacob Faibussowitsch   PetscCall(PetscInfo(m, "%s matrix %" PetscInt_FMT " x %" PetscInt_FMT "\n", copy ? "Copy" : "Reusing", nrows, ncols));
242*4742e46bSJacob Faibussowitsch   if (copy) {
243*4742e46bSJacob Faibussowitsch     const auto   mcu = MatCUPMCast(m);
244*4742e46bSJacob Faibussowitsch     cupmStream_t stream;
245*4742e46bSJacob Faibussowitsch 
246*4742e46bSJacob Faibussowitsch     // Allocate GPU memory if not present
247*4742e46bSJacob Faibussowitsch     if (!mcu->d_v) PetscCall(SetPreallocation(m, dctx));
248*4742e46bSJacob Faibussowitsch     PetscCall(GetHandlesFrom_(dctx, &stream));
249*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogEventBegin(MAT_DenseCopyToGPU, m, 0, 0, 0));
250*4742e46bSJacob Faibussowitsch     {
251*4742e46bSJacob Faibussowitsch       const auto mimpl = MatIMPLCast(m);
252*4742e46bSJacob Faibussowitsch       const auto lda   = mimpl->lda;
253*4742e46bSJacob Faibussowitsch       const auto src   = mimpl->v;
254*4742e46bSJacob Faibussowitsch       const auto dest  = mcu->d_v;
255*4742e46bSJacob Faibussowitsch 
256*4742e46bSJacob Faibussowitsch       if (lda > nrows) {
257*4742e46bSJacob Faibussowitsch         PetscCall(PetscCUPMMemcpy2DAsync(dest, lda, src, lda, nrows, ncols, cupmMemcpyHostToDevice, stream));
258*4742e46bSJacob Faibussowitsch       } else {
259*4742e46bSJacob Faibussowitsch         PetscCall(PetscCUPMMemcpyAsync(dest, src, lda * ncols, cupmMemcpyHostToDevice, stream));
260*4742e46bSJacob Faibussowitsch       }
261*4742e46bSJacob Faibussowitsch     }
262*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogEventEnd(MAT_DenseCopyToGPU, m, 0, 0, 0));
263*4742e46bSJacob Faibussowitsch     // order important, ensure that offloadmask is PETSC_OFFLOAD_BOTH
264*4742e46bSJacob Faibussowitsch     m->offloadmask = PETSC_OFFLOAD_BOTH;
265*4742e46bSJacob Faibussowitsch   }
266*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
267*4742e46bSJacob Faibussowitsch }
268*4742e46bSJacob Faibussowitsch 
269*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
270*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::DeviceToHost_(Mat m, PetscDeviceContext dctx) noexcept
271*4742e46bSJacob Faibussowitsch {
272*4742e46bSJacob Faibussowitsch   const auto nrows = m->rmap->n;
273*4742e46bSJacob Faibussowitsch   const auto ncols = m->cmap->n;
274*4742e46bSJacob Faibussowitsch   const auto copy  = m->offloadmask == PETSC_OFFLOAD_GPU;
275*4742e46bSJacob Faibussowitsch 
276*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
277*4742e46bSJacob Faibussowitsch   PetscCheckTypeName(m, MATSEQDENSECUPM());
278*4742e46bSJacob Faibussowitsch   PetscCall(PetscInfo(m, "%s matrix %" PetscInt_FMT " x %" PetscInt_FMT "\n", copy ? "Copy" : "Reusing", nrows, ncols));
279*4742e46bSJacob Faibussowitsch   if (copy) {
280*4742e46bSJacob Faibussowitsch     const auto   mimpl = MatIMPLCast(m);
281*4742e46bSJacob Faibussowitsch     cupmStream_t stream;
282*4742e46bSJacob Faibussowitsch 
283*4742e46bSJacob Faibussowitsch     // MatCreateSeqDenseCUPM may not allocate CPU memory. Allocate if needed
284*4742e46bSJacob Faibussowitsch     if (!mimpl->v) PetscCall(MatSeqDenseSetPreallocation(m, nullptr));
285*4742e46bSJacob Faibussowitsch     PetscCall(GetHandlesFrom_(dctx, &stream));
286*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogEventBegin(MAT_DenseCopyFromGPU, m, 0, 0, 0));
287*4742e46bSJacob Faibussowitsch     {
288*4742e46bSJacob Faibussowitsch       const auto lda  = mimpl->lda;
289*4742e46bSJacob Faibussowitsch       const auto dest = mimpl->v;
290*4742e46bSJacob Faibussowitsch       const auto src  = MatCUPMCast(m)->d_v;
291*4742e46bSJacob Faibussowitsch 
292*4742e46bSJacob Faibussowitsch       if (lda > nrows) {
293*4742e46bSJacob Faibussowitsch         PetscCall(PetscCUPMMemcpy2DAsync(dest, lda, src, lda, nrows, ncols, cupmMemcpyDeviceToHost, stream));
294*4742e46bSJacob Faibussowitsch       } else {
295*4742e46bSJacob Faibussowitsch         PetscCall(PetscCUPMMemcpyAsync(dest, src, lda * ncols, cupmMemcpyDeviceToHost, stream));
296*4742e46bSJacob Faibussowitsch       }
297*4742e46bSJacob Faibussowitsch     }
298*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogEventEnd(MAT_DenseCopyFromGPU, m, 0, 0, 0));
299*4742e46bSJacob Faibussowitsch     // order is important, MatSeqDenseSetPreallocation() might set offloadmask
300*4742e46bSJacob Faibussowitsch     m->offloadmask = PETSC_OFFLOAD_BOTH;
301*4742e46bSJacob Faibussowitsch   }
302*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
303*4742e46bSJacob Faibussowitsch }
304*4742e46bSJacob Faibussowitsch 
305*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
306*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::CheckCUPMSolverInfo_(const cupmBlasInt_t *fact_info, cupmStream_t stream) noexcept
307*4742e46bSJacob Faibussowitsch {
308*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
309*4742e46bSJacob Faibussowitsch   if (PetscDefined(USE_DEBUG)) {
310*4742e46bSJacob Faibussowitsch     cupmBlasInt_t info = 0;
311*4742e46bSJacob Faibussowitsch 
312*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMMemcpyAsync(&info, fact_info, 1, cupmMemcpyDeviceToHost, stream));
313*4742e46bSJacob Faibussowitsch     if (stream) PetscCallCUPM(cupmStreamSynchronize(stream));
314*4742e46bSJacob Faibussowitsch     static_assert(std::is_same<decltype(info), int>::value, "");
315*4742e46bSJacob Faibussowitsch     PetscCheck(info <= 0, PETSC_COMM_SELF, PETSC_ERR_MAT_CH_ZRPVT, "Bad factorization: zero pivot in row %d", info - 1);
316*4742e46bSJacob Faibussowitsch     PetscCheck(info >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Wrong argument to cupmSolver %d", -info);
317*4742e46bSJacob Faibussowitsch   }
318*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
319*4742e46bSJacob Faibussowitsch }
320*4742e46bSJacob Faibussowitsch 
321*4742e46bSJacob Faibussowitsch // ==========================================================================================
322*4742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Solver Dispatch
323*4742e46bSJacob Faibussowitsch // ==========================================================================================
324*4742e46bSJacob Faibussowitsch 
325*4742e46bSJacob Faibussowitsch // specific solvers called through the dispatch_() family of functions
326*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
327*4742e46bSJacob Faibussowitsch template <typename Derived>
328*4742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveCommon {
329*4742e46bSJacob Faibussowitsch   using derived_type = Derived;
330*4742e46bSJacob Faibussowitsch 
331*4742e46bSJacob Faibussowitsch   template <typename F>
332*4742e46bSJacob Faibussowitsch   static PetscErrorCode ResizeFactLwork(Mat_SeqDenseCUPM *mcu, cupmStream_t stream, F &&cupmSolverComputeFactLwork) noexcept
333*4742e46bSJacob Faibussowitsch   {
334*4742e46bSJacob Faibussowitsch     cupmBlasInt_t lwork;
335*4742e46bSJacob Faibussowitsch 
336*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
337*4742e46bSJacob Faibussowitsch     PetscCallCUPMSOLVER(cupmSolverComputeFactLwork(&lwork));
338*4742e46bSJacob Faibussowitsch     if (lwork > mcu->d_fact_lwork) {
339*4742e46bSJacob Faibussowitsch       mcu->d_fact_lwork = lwork;
340*4742e46bSJacob Faibussowitsch       PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream));
341*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, lwork, stream));
342*4742e46bSJacob Faibussowitsch     }
343*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
344*4742e46bSJacob Faibussowitsch   }
345*4742e46bSJacob Faibussowitsch 
346*4742e46bSJacob Faibussowitsch   static PetscErrorCode FactorPrepare(Mat A, cupmStream_t stream) noexcept
347*4742e46bSJacob Faibussowitsch   {
348*4742e46bSJacob Faibussowitsch     const auto mcu = MatCUPMCast(A);
349*4742e46bSJacob Faibussowitsch 
350*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
351*4742e46bSJacob Faibussowitsch     PetscCall(PetscInfo(A, "%s factor %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", derived_type::NAME(), A->rmap->n, A->cmap->n));
352*4742e46bSJacob Faibussowitsch     A->factortype             = derived_type::MATFACTORTYPE();
353*4742e46bSJacob Faibussowitsch     A->ops->solve             = MatSolve_Factored_Dispatch_<derived_type, false>;
354*4742e46bSJacob Faibussowitsch     A->ops->solvetranspose    = MatSolve_Factored_Dispatch_<derived_type, true>;
355*4742e46bSJacob Faibussowitsch     A->ops->matsolve          = MatMatSolve_Factored_Dispatch_<derived_type, false>;
356*4742e46bSJacob Faibussowitsch     A->ops->matsolvetranspose = MatMatSolve_Factored_Dispatch_<derived_type, true>;
357*4742e46bSJacob Faibussowitsch 
358*4742e46bSJacob Faibussowitsch     PetscCall(PetscStrFreeAllocpy(MATSOLVERCUPM(), &A->solvertype));
359*4742e46bSJacob Faibussowitsch     if (!mcu->d_fact_info) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_info, 1, stream));
360*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
361*4742e46bSJacob Faibussowitsch   }
362*4742e46bSJacob Faibussowitsch };
363*4742e46bSJacob Faibussowitsch 
364*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
365*4742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveLU : SolveCommon<SolveLU> {
366*4742e46bSJacob Faibussowitsch   using base_type = SolveCommon<SolveLU>;
367*4742e46bSJacob Faibussowitsch 
368*4742e46bSJacob Faibussowitsch   static constexpr const char   *NAME() noexcept { return "LU"; }
369*4742e46bSJacob Faibussowitsch   static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_LU; }
370*4742e46bSJacob Faibussowitsch 
371*4742e46bSJacob Faibussowitsch   static PetscErrorCode Factor(Mat A, IS, IS, const MatFactorInfo *) noexcept
372*4742e46bSJacob Faibussowitsch   {
373*4742e46bSJacob Faibussowitsch     const auto         m = static_cast<cupmBlasInt_t>(A->rmap->n);
374*4742e46bSJacob Faibussowitsch     const auto         n = static_cast<cupmBlasInt_t>(A->cmap->n);
375*4742e46bSJacob Faibussowitsch     cupmStream_t       stream;
376*4742e46bSJacob Faibussowitsch     cupmSolverHandle_t handle;
377*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
378*4742e46bSJacob Faibussowitsch 
379*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
380*4742e46bSJacob Faibussowitsch     if (!m || !n) PetscFunctionReturn(PETSC_SUCCESS);
381*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx, &handle, &stream));
382*4742e46bSJacob Faibussowitsch     PetscCall(base_type::FactorPrepare(A, stream));
383*4742e46bSJacob Faibussowitsch     {
384*4742e46bSJacob Faibussowitsch       const auto mcu = MatCUPMCast(A);
385*4742e46bSJacob Faibussowitsch       const auto da  = DeviceArrayReadWrite(dctx, A);
386*4742e46bSJacob Faibussowitsch       const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda);
387*4742e46bSJacob Faibussowitsch 
388*4742e46bSJacob Faibussowitsch       // clang-format off
389*4742e46bSJacob Faibussowitsch       PetscCall(
390*4742e46bSJacob Faibussowitsch         base_type::ResizeFactLwork(
391*4742e46bSJacob Faibussowitsch           mcu, stream,
392*4742e46bSJacob Faibussowitsch           [&](cupmBlasInt_t *fact_lwork)
393*4742e46bSJacob Faibussowitsch           {
394*4742e46bSJacob Faibussowitsch             return cupmSolverXgetrf_bufferSize(handle, m, n, da.cupmdata(), lda, fact_lwork);
395*4742e46bSJacob Faibussowitsch           }
396*4742e46bSJacob Faibussowitsch         )
397*4742e46bSJacob Faibussowitsch       );
398*4742e46bSJacob Faibussowitsch       // clang-format on
399*4742e46bSJacob Faibussowitsch       if (!mcu->d_fact_ipiv) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_ipiv, n, stream));
400*4742e46bSJacob Faibussowitsch 
401*4742e46bSJacob Faibussowitsch       PetscCall(PetscLogGpuTimeBegin());
402*4742e46bSJacob Faibussowitsch       PetscCallCUPMSOLVER(cupmSolverXgetrf(handle, m, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_ipiv, mcu->d_fact_info));
403*4742e46bSJacob Faibussowitsch       PetscCall(PetscLogGpuTimeEnd());
404*4742e46bSJacob Faibussowitsch       PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream));
405*4742e46bSJacob Faibussowitsch     }
406*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops(2.0 * n * n * m / 3.0));
407*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
408*4742e46bSJacob Faibussowitsch   }
409*4742e46bSJacob Faibussowitsch 
410*4742e46bSJacob Faibussowitsch   template <bool transpose>
411*4742e46bSJacob Faibussowitsch   static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept
412*4742e46bSJacob Faibussowitsch   {
413*4742e46bSJacob Faibussowitsch     const auto         mcu       = MatCUPMCast(A);
414*4742e46bSJacob Faibussowitsch     const auto         fact_info = mcu->d_fact_info;
415*4742e46bSJacob Faibussowitsch     const auto         fact_ipiv = mcu->d_fact_ipiv;
416*4742e46bSJacob Faibussowitsch     const auto         lda       = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda);
417*4742e46bSJacob Faibussowitsch     cupmSolverHandle_t handle;
418*4742e46bSJacob Faibussowitsch 
419*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
420*4742e46bSJacob Faibussowitsch     PetscCall(GetHandlesFrom_(dctx, &handle));
421*4742e46bSJacob Faibussowitsch     PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k));
422*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeBegin());
423*4742e46bSJacob Faibussowitsch     {
424*4742e46bSJacob Faibussowitsch       constexpr auto op  = transpose ? CUPMSOLVER_OP_T : CUPMSOLVER_OP_N;
425*4742e46bSJacob Faibussowitsch       const auto     da  = DeviceArrayRead(dctx, A);
426*4742e46bSJacob Faibussowitsch       const auto     lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda);
427*4742e46bSJacob Faibussowitsch 
428*4742e46bSJacob Faibussowitsch       // clang-format off
429*4742e46bSJacob Faibussowitsch       PetscCall(
430*4742e46bSJacob Faibussowitsch         base_type::ResizeFactLwork(
431*4742e46bSJacob Faibussowitsch           mcu, stream,
432*4742e46bSJacob Faibussowitsch           [&](cupmBlasInt_t *lwork)
433*4742e46bSJacob Faibussowitsch           {
434*4742e46bSJacob Faibussowitsch             return cupmSolverXgetrs_bufferSize(
435*4742e46bSJacob Faibussowitsch               handle, op, m, nrhs, da.cupmdata(), lda, fact_ipiv, x, ldx, lwork
436*4742e46bSJacob Faibussowitsch             );
437*4742e46bSJacob Faibussowitsch           }
438*4742e46bSJacob Faibussowitsch         )
439*4742e46bSJacob Faibussowitsch       );
440*4742e46bSJacob Faibussowitsch       // clang-format on
441*4742e46bSJacob Faibussowitsch       PetscCallCUPMSOLVER(cupmSolverXgetrs(handle, op, m, nrhs, da.cupmdata(), lda, fact_ipiv, x, ldx, mcu->d_fact_work, mcu->d_fact_lwork, fact_info));
442*4742e46bSJacob Faibussowitsch       PetscCall(CheckCUPMSolverInfo_(fact_info, stream));
443*4742e46bSJacob Faibussowitsch     }
444*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeEnd());
445*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops(nrhs * (2.0 * m * m - m)));
446*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
447*4742e46bSJacob Faibussowitsch   }
448*4742e46bSJacob Faibussowitsch };
449*4742e46bSJacob Faibussowitsch 
450*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
451*4742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveCholesky : SolveCommon<SolveCholesky> {
452*4742e46bSJacob Faibussowitsch   using base_type = SolveCommon<SolveCholesky>;
453*4742e46bSJacob Faibussowitsch 
454*4742e46bSJacob Faibussowitsch   static constexpr const char   *NAME() noexcept { return "Cholesky"; }
455*4742e46bSJacob Faibussowitsch   static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_CHOLESKY; }
456*4742e46bSJacob Faibussowitsch 
457*4742e46bSJacob Faibussowitsch   static PetscErrorCode Factor(Mat A, IS, const MatFactorInfo *) noexcept
458*4742e46bSJacob Faibussowitsch   {
459*4742e46bSJacob Faibussowitsch     const auto         n = static_cast<cupmBlasInt_t>(A->rmap->n);
460*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
461*4742e46bSJacob Faibussowitsch     cupmSolverHandle_t handle;
462*4742e46bSJacob Faibussowitsch     cupmStream_t       stream;
463*4742e46bSJacob Faibussowitsch 
464*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
465*4742e46bSJacob Faibussowitsch     if (!n || !A->cmap->n) PetscFunctionReturn(PETSC_SUCCESS);
466*4742e46bSJacob Faibussowitsch     PetscCheck(A->spd == PETSC_BOOL3_TRUE, PETSC_COMM_SELF, PETSC_ERR_SUP, "%ssytrs unavailable. Use MAT_FACTOR_LU", cupmSolverName());
467*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx, &handle, &stream));
468*4742e46bSJacob Faibussowitsch     PetscCall(base_type::FactorPrepare(A, stream));
469*4742e46bSJacob Faibussowitsch     {
470*4742e46bSJacob Faibussowitsch       const auto mcu = MatCUPMCast(A);
471*4742e46bSJacob Faibussowitsch       const auto da  = DeviceArrayReadWrite(dctx, A);
472*4742e46bSJacob Faibussowitsch       const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda);
473*4742e46bSJacob Faibussowitsch 
474*4742e46bSJacob Faibussowitsch       // clang-format off
475*4742e46bSJacob Faibussowitsch       PetscCall(
476*4742e46bSJacob Faibussowitsch         base_type::ResizeFactLwork(
477*4742e46bSJacob Faibussowitsch           mcu, stream,
478*4742e46bSJacob Faibussowitsch           [&](cupmBlasInt_t *fact_lwork)
479*4742e46bSJacob Faibussowitsch           {
480*4742e46bSJacob Faibussowitsch             return cupmSolverXpotrf_bufferSize(
481*4742e46bSJacob Faibussowitsch               handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, fact_lwork
482*4742e46bSJacob Faibussowitsch             );
483*4742e46bSJacob Faibussowitsch           }
484*4742e46bSJacob Faibussowitsch         )
485*4742e46bSJacob Faibussowitsch       );
486*4742e46bSJacob Faibussowitsch       // clang-format on
487*4742e46bSJacob Faibussowitsch       PetscCall(PetscLogGpuTimeBegin());
488*4742e46bSJacob Faibussowitsch       PetscCallCUPMSOLVER(cupmSolverXpotrf(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info));
489*4742e46bSJacob Faibussowitsch       PetscCall(PetscLogGpuTimeEnd());
490*4742e46bSJacob Faibussowitsch       PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream));
491*4742e46bSJacob Faibussowitsch     }
492*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops(1.0 * n * n * n / 3.0));
493*4742e46bSJacob Faibussowitsch 
494*4742e46bSJacob Faibussowitsch   #if 0
495*4742e46bSJacob Faibussowitsch     // At the time of writing this interface (cuda 10.0), cusolverDn does not implement *sytrs
496*4742e46bSJacob Faibussowitsch     // and *hetr* routines. The code below should work, and it can be activated when *sytrs
497*4742e46bSJacob Faibussowitsch     // routines will be available
498*4742e46bSJacob Faibussowitsch     if (!mcu->d_fact_ipiv) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_ipiv, n, stream));
499*4742e46bSJacob Faibussowitsch     if (!mcu->d_fact_lwork) {
500*4742e46bSJacob Faibussowitsch       PetscCallCUPMSOLVER(cupmSolverDnXsytrf_bufferSize(handle, n, da.cupmdata(), lda, &mcu->d_fact_lwork));
501*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, mcu->d_fact_lwork, stream));
502*4742e46bSJacob Faibussowitsch     }
503*4742e46bSJacob Faibussowitsch     if (mcu->d_fact_info) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_info, 1, stream));
504*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeBegin());
505*4742e46bSJacob Faibussowitsch     PetscCallCUPMSOLVER(cupmSolverXsytrf(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da, lda, mcu->d_fact_ipiv, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info));
506*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeEnd());
507*4742e46bSJacob Faibussowitsch   #endif
508*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
509*4742e46bSJacob Faibussowitsch   }
510*4742e46bSJacob Faibussowitsch 
511*4742e46bSJacob Faibussowitsch   template <bool transpose>
512*4742e46bSJacob Faibussowitsch   static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept
513*4742e46bSJacob Faibussowitsch   {
514*4742e46bSJacob Faibussowitsch     const auto         mcu       = MatCUPMCast(A);
515*4742e46bSJacob Faibussowitsch     const auto         fact_info = mcu->d_fact_info;
516*4742e46bSJacob Faibussowitsch     cupmSolverHandle_t handle;
517*4742e46bSJacob Faibussowitsch 
518*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
519*4742e46bSJacob Faibussowitsch     PetscAssert(!mcu->d_fact_ipiv, PETSC_COMM_SELF, PETSC_ERR_LIB, "%ssytrs not implemented", cupmSolverName());
520*4742e46bSJacob Faibussowitsch     PetscCall(GetHandlesFrom_(dctx, &handle));
521*4742e46bSJacob Faibussowitsch     PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k));
522*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeBegin());
523*4742e46bSJacob Faibussowitsch     {
524*4742e46bSJacob Faibussowitsch       const auto da  = DeviceArrayRead(dctx, A);
525*4742e46bSJacob Faibussowitsch       const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda);
526*4742e46bSJacob Faibussowitsch 
527*4742e46bSJacob Faibussowitsch       // clang-format off
528*4742e46bSJacob Faibussowitsch       PetscCall(
529*4742e46bSJacob Faibussowitsch         base_type::ResizeFactLwork(
530*4742e46bSJacob Faibussowitsch           mcu, stream,
531*4742e46bSJacob Faibussowitsch           [&](cupmBlasInt_t *lwork)
532*4742e46bSJacob Faibussowitsch           {
533*4742e46bSJacob Faibussowitsch             return cupmSolverXpotrs_bufferSize(
534*4742e46bSJacob Faibussowitsch               handle, CUPMSOLVER_FILL_MODE_LOWER, m, nrhs, da.cupmdata(), lda, x, ldx, lwork
535*4742e46bSJacob Faibussowitsch             );
536*4742e46bSJacob Faibussowitsch           }
537*4742e46bSJacob Faibussowitsch         )
538*4742e46bSJacob Faibussowitsch       );
539*4742e46bSJacob Faibussowitsch       // clang-format on
540*4742e46bSJacob Faibussowitsch       PetscCallCUPMSOLVER(cupmSolverXpotrs(handle, CUPMSOLVER_FILL_MODE_LOWER, m, nrhs, da.cupmdata(), lda, x, ldx, mcu->d_fact_work, mcu->d_fact_lwork, fact_info));
541*4742e46bSJacob Faibussowitsch       PetscCall(CheckCUPMSolverInfo_(fact_info, stream));
542*4742e46bSJacob Faibussowitsch     }
543*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeEnd());
544*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops(nrhs * (2.0 * m * m - m)));
545*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
546*4742e46bSJacob Faibussowitsch   }
547*4742e46bSJacob Faibussowitsch };
548*4742e46bSJacob Faibussowitsch 
549*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
550*4742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveQR : SolveCommon<SolveQR> {
551*4742e46bSJacob Faibussowitsch   using base_type = SolveCommon<SolveQR>;
552*4742e46bSJacob Faibussowitsch 
553*4742e46bSJacob Faibussowitsch   static constexpr const char   *NAME() noexcept { return "QR"; }
554*4742e46bSJacob Faibussowitsch   static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_QR; }
555*4742e46bSJacob Faibussowitsch 
556*4742e46bSJacob Faibussowitsch   static PetscErrorCode Factor(Mat A, IS, const MatFactorInfo *) noexcept
557*4742e46bSJacob Faibussowitsch   {
558*4742e46bSJacob Faibussowitsch     const auto         m     = static_cast<cupmBlasInt_t>(A->rmap->n);
559*4742e46bSJacob Faibussowitsch     const auto         n     = static_cast<cupmBlasInt_t>(A->cmap->n);
560*4742e46bSJacob Faibussowitsch     const auto         min   = std::min(m, n);
561*4742e46bSJacob Faibussowitsch     const auto         mimpl = MatIMPLCast(A);
562*4742e46bSJacob Faibussowitsch     cupmStream_t       stream;
563*4742e46bSJacob Faibussowitsch     cupmSolverHandle_t handle;
564*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
565*4742e46bSJacob Faibussowitsch 
566*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
567*4742e46bSJacob Faibussowitsch     if (!m || !n) PetscFunctionReturn(PETSC_SUCCESS);
568*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx, &handle, &stream));
569*4742e46bSJacob Faibussowitsch     PetscCall(base_type::FactorPrepare(A, stream));
570*4742e46bSJacob Faibussowitsch     mimpl->rank = min;
571*4742e46bSJacob Faibussowitsch     {
572*4742e46bSJacob Faibussowitsch       const auto mcu = MatCUPMCast(A);
573*4742e46bSJacob Faibussowitsch       const auto da  = DeviceArrayReadWrite(dctx, A);
574*4742e46bSJacob Faibussowitsch       const auto lda = static_cast<cupmBlasInt_t>(mimpl->lda);
575*4742e46bSJacob Faibussowitsch 
576*4742e46bSJacob Faibussowitsch       if (!mcu->workvec) PetscCall(vec::cupm::VecCreateSeqCUPMAsync<T>(PetscObjectComm(PetscObjectCast(A)), m, &mcu->workvec));
577*4742e46bSJacob Faibussowitsch       if (!mcu->d_fact_tau) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_tau, min, stream));
578*4742e46bSJacob Faibussowitsch       // clang-format off
579*4742e46bSJacob Faibussowitsch       PetscCall(
580*4742e46bSJacob Faibussowitsch         base_type::ResizeFactLwork(
581*4742e46bSJacob Faibussowitsch           mcu, stream,
582*4742e46bSJacob Faibussowitsch           [&](cupmBlasInt_t *fact_lwork)
583*4742e46bSJacob Faibussowitsch           {
584*4742e46bSJacob Faibussowitsch             return cupmSolverXgeqrf_bufferSize(handle, m, n, da.cupmdata(), lda, fact_lwork);
585*4742e46bSJacob Faibussowitsch           }
586*4742e46bSJacob Faibussowitsch         )
587*4742e46bSJacob Faibussowitsch       );
588*4742e46bSJacob Faibussowitsch       // clang-format on
589*4742e46bSJacob Faibussowitsch       PetscCall(PetscLogGpuTimeBegin());
590*4742e46bSJacob Faibussowitsch       PetscCallCUPMSOLVER(cupmSolverXgeqrf(handle, m, n, da.cupmdata(), lda, mcu->d_fact_tau, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info));
591*4742e46bSJacob Faibussowitsch       PetscCall(PetscLogGpuTimeEnd());
592*4742e46bSJacob Faibussowitsch       PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream));
593*4742e46bSJacob Faibussowitsch     }
594*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops(2.0 * min * min * (std::max(m, n) - min / 3.0)));
595*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
596*4742e46bSJacob Faibussowitsch   }
597*4742e46bSJacob Faibussowitsch 
598*4742e46bSJacob Faibussowitsch   template <bool transpose>
599*4742e46bSJacob Faibussowitsch   static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept
600*4742e46bSJacob Faibussowitsch   {
601*4742e46bSJacob Faibussowitsch     const auto         mimpl      = MatIMPLCast(A);
602*4742e46bSJacob Faibussowitsch     const auto         rank       = static_cast<cupmBlasInt_t>(mimpl->rank);
603*4742e46bSJacob Faibussowitsch     const auto         mcu        = MatCUPMCast(A);
604*4742e46bSJacob Faibussowitsch     const auto         fact_info  = mcu->d_fact_info;
605*4742e46bSJacob Faibussowitsch     const auto         fact_tau   = mcu->d_fact_tau;
606*4742e46bSJacob Faibussowitsch     const auto         fact_work  = mcu->d_fact_work;
607*4742e46bSJacob Faibussowitsch     const auto         fact_lwork = mcu->d_fact_lwork;
608*4742e46bSJacob Faibussowitsch     cupmSolverHandle_t solver_handle;
609*4742e46bSJacob Faibussowitsch     cupmBlasHandle_t   blas_handle;
610*4742e46bSJacob Faibussowitsch 
611*4742e46bSJacob Faibussowitsch     PetscFunctionBegin;
612*4742e46bSJacob Faibussowitsch     PetscCall(GetHandlesFrom_(dctx, &blas_handle, &solver_handle));
613*4742e46bSJacob Faibussowitsch     PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k));
614*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeBegin());
615*4742e46bSJacob Faibussowitsch     {
616*4742e46bSJacob Faibussowitsch       const auto da  = DeviceArrayRead(dctx, A);
617*4742e46bSJacob Faibussowitsch       const auto one = cupmScalarCast(1.0);
618*4742e46bSJacob Faibussowitsch       const auto lda = static_cast<cupmBlasInt_t>(mimpl->lda);
619*4742e46bSJacob Faibussowitsch 
620*4742e46bSJacob Faibussowitsch       if (transpose) {
621*4742e46bSJacob Faibussowitsch         PetscCallCUPMBLAS(cupmBlasXtrsm(blas_handle, CUPMBLAS_SIDE_LEFT, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_T, CUPMBLAS_DIAG_NON_UNIT, rank, nrhs, &one, da.cupmdata(), lda, x, ldx));
622*4742e46bSJacob Faibussowitsch         PetscCallCUPMSOLVER(cupmSolverXormqr(solver_handle, CUPMSOLVER_SIDE_LEFT, CUPMSOLVER_OP_N, m, nrhs, rank, da.cupmdata(), lda, fact_tau, x, ldx, fact_work, fact_lwork, fact_info));
623*4742e46bSJacob Faibussowitsch         PetscCall(CheckCUPMSolverInfo_(fact_info, stream));
624*4742e46bSJacob Faibussowitsch       } else {
625*4742e46bSJacob Faibussowitsch         constexpr auto op = PetscDefined(USE_COMPLEX) ? CUPMSOLVER_OP_C : CUPMSOLVER_OP_T;
626*4742e46bSJacob Faibussowitsch 
627*4742e46bSJacob Faibussowitsch         PetscCallCUPMSOLVER(cupmSolverXormqr(solver_handle, CUPMSOLVER_SIDE_LEFT, op, m, nrhs, rank, da.cupmdata(), lda, fact_tau, x, ldx, fact_work, fact_lwork, fact_info));
628*4742e46bSJacob Faibussowitsch         PetscCall(CheckCUPMSolverInfo_(fact_info, stream));
629*4742e46bSJacob Faibussowitsch         PetscCallCUPMBLAS(cupmBlasXtrsm(blas_handle, CUPMBLAS_SIDE_LEFT, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, rank, nrhs, &one, da.cupmdata(), lda, x, ldx));
630*4742e46bSJacob Faibussowitsch       }
631*4742e46bSJacob Faibussowitsch     }
632*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeEnd());
633*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogFlops(nrhs * (4.0 * m * rank - (rank * rank))));
634*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
635*4742e46bSJacob Faibussowitsch   }
636*4742e46bSJacob Faibussowitsch };
637*4742e46bSJacob Faibussowitsch 
638*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
639*4742e46bSJacob Faibussowitsch template <typename Solver, bool transpose>
640*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatSolve_Factored_Dispatch_(Mat A, Vec x, Vec y) noexcept
641*4742e46bSJacob Faibussowitsch {
642*4742e46bSJacob Faibussowitsch   using namespace vec::cupm;
643*4742e46bSJacob Faibussowitsch   const auto         pobj_A  = PetscObjectCast(A);
644*4742e46bSJacob Faibussowitsch   const auto         m       = static_cast<cupmBlasInt_t>(A->rmap->n);
645*4742e46bSJacob Faibussowitsch   const auto         k       = static_cast<cupmBlasInt_t>(A->cmap->n);
646*4742e46bSJacob Faibussowitsch   auto              &workvec = MatCUPMCast(A)->workvec;
647*4742e46bSJacob Faibussowitsch   PetscScalar       *y_array = nullptr;
648*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
649*4742e46bSJacob Faibussowitsch   PetscBool          xiscupm, yiscupm, aiscupm;
650*4742e46bSJacob Faibussowitsch   bool               use_y_array_directly;
651*4742e46bSJacob Faibussowitsch   cupmStream_t       stream;
652*4742e46bSJacob Faibussowitsch 
653*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
654*4742e46bSJacob Faibussowitsch   PetscCheck(A->factortype != MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix must be factored to solve");
655*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(PetscObjectCast(x), VecSeq_CUPM::VECSEQCUPM(), &xiscupm));
656*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(PetscObjectCast(y), VecSeq_CUPM::VECSEQCUPM(), &yiscupm));
657*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(pobj_A, MATSEQDENSECUPM(), &aiscupm));
658*4742e46bSJacob Faibussowitsch   PetscAssert(aiscupm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Matrix A is somehow not CUPM?????????????????????????????");
659*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &stream));
660*4742e46bSJacob Faibussowitsch   use_y_array_directly = yiscupm && (k >= m);
661*4742e46bSJacob Faibussowitsch   {
662*4742e46bSJacob Faibussowitsch     const PetscScalar *x_array;
663*4742e46bSJacob Faibussowitsch     const auto         xisdevice = xiscupm && PetscOffloadDevice(x->offloadmask);
664*4742e46bSJacob Faibussowitsch     const auto         copy_mode = xisdevice ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
665*4742e46bSJacob Faibussowitsch 
666*4742e46bSJacob Faibussowitsch     if (!use_y_array_directly && !workvec) PetscCall(VecCreateSeqCUPMAsync<T>(PetscObjectComm(pobj_A), m, &workvec));
667*4742e46bSJacob Faibussowitsch     // The logic here is to try to minimize the amount of memory copying:
668*4742e46bSJacob Faibussowitsch     //
669*4742e46bSJacob Faibussowitsch     // If we call VecCUPMGetArrayRead(X, &x) every time xiscupm and the data is not offloaded
670*4742e46bSJacob Faibussowitsch     // to the GPU yet, then the data is copied to the GPU. But we are only trying to get the
671*4742e46bSJacob Faibussowitsch     // data in order to copy it into the y array. So the array x will be wherever the data
672*4742e46bSJacob Faibussowitsch     // already is so that only one memcpy is performed
673*4742e46bSJacob Faibussowitsch     if (xisdevice) {
674*4742e46bSJacob Faibussowitsch       PetscCall(VecCUPMGetArrayReadAsync<T>(x, &x_array, dctx));
675*4742e46bSJacob Faibussowitsch     } else {
676*4742e46bSJacob Faibussowitsch       PetscCall(VecGetArrayRead(x, &x_array));
677*4742e46bSJacob Faibussowitsch     }
678*4742e46bSJacob Faibussowitsch     PetscCall(VecCUPMGetArrayWriteAsync<T>(use_y_array_directly ? y : workvec, &y_array, dctx));
679*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMMemcpyAsync(y_array, x_array, m, copy_mode, stream));
680*4742e46bSJacob Faibussowitsch     if (xisdevice) {
681*4742e46bSJacob Faibussowitsch       PetscCall(VecCUPMRestoreArrayReadAsync<T>(x, &x_array, dctx));
682*4742e46bSJacob Faibussowitsch     } else {
683*4742e46bSJacob Faibussowitsch       PetscCall(VecRestoreArrayRead(x, &x_array));
684*4742e46bSJacob Faibussowitsch     }
685*4742e46bSJacob Faibussowitsch   }
686*4742e46bSJacob Faibussowitsch 
687*4742e46bSJacob Faibussowitsch   if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A));
688*4742e46bSJacob Faibussowitsch   PetscCall(Solver{}.template Solve<transpose>(A, cupmScalarPtrCast(y_array), m, m, 1, k, dctx, stream));
689*4742e46bSJacob Faibussowitsch   if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A));
690*4742e46bSJacob Faibussowitsch 
691*4742e46bSJacob Faibussowitsch   if (use_y_array_directly) {
692*4742e46bSJacob Faibussowitsch     PetscCall(VecCUPMRestoreArrayWriteAsync<T>(y, &y_array, dctx));
693*4742e46bSJacob Faibussowitsch   } else {
694*4742e46bSJacob Faibussowitsch     const auto   copy_mode = yiscupm ? cupmMemcpyDeviceToDevice : cupmMemcpyDeviceToHost;
695*4742e46bSJacob Faibussowitsch     PetscScalar *yv;
696*4742e46bSJacob Faibussowitsch 
697*4742e46bSJacob Faibussowitsch     // The logic here is that the data is not yet in either y's GPU array or its CPU array.
698*4742e46bSJacob Faibussowitsch     // There is nothing in the interface to say where the user would like it to end up. So we
699*4742e46bSJacob Faibussowitsch     // choose the GPU, because it is the faster option
700*4742e46bSJacob Faibussowitsch     if (yiscupm) {
701*4742e46bSJacob Faibussowitsch       PetscCall(VecCUPMGetArrayWriteAsync<T>(y, &yv, dctx));
702*4742e46bSJacob Faibussowitsch     } else {
703*4742e46bSJacob Faibussowitsch       PetscCall(VecGetArray(y, &yv));
704*4742e46bSJacob Faibussowitsch     }
705*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMMemcpyAsync(yv, y_array, k, copy_mode, stream));
706*4742e46bSJacob Faibussowitsch     if (yiscupm) {
707*4742e46bSJacob Faibussowitsch       PetscCall(VecCUPMRestoreArrayWriteAsync<T>(y, &yv, dctx));
708*4742e46bSJacob Faibussowitsch     } else {
709*4742e46bSJacob Faibussowitsch       PetscCall(VecRestoreArray(y, &yv));
710*4742e46bSJacob Faibussowitsch     }
711*4742e46bSJacob Faibussowitsch     PetscCall(VecCUPMRestoreArrayWriteAsync<T>(workvec, &y_array));
712*4742e46bSJacob Faibussowitsch   }
713*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
714*4742e46bSJacob Faibussowitsch }
715*4742e46bSJacob Faibussowitsch 
716*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
717*4742e46bSJacob Faibussowitsch template <typename Solver, bool transpose>
718*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMatSolve_Factored_Dispatch_(Mat A, Mat B, Mat X) noexcept
719*4742e46bSJacob Faibussowitsch {
720*4742e46bSJacob Faibussowitsch   const auto         m = static_cast<cupmBlasInt_t>(A->rmap->n);
721*4742e46bSJacob Faibussowitsch   const auto         k = static_cast<cupmBlasInt_t>(A->cmap->n);
722*4742e46bSJacob Faibussowitsch   cupmBlasInt_t      nrhs, ldb, ldx, ldy;
723*4742e46bSJacob Faibussowitsch   PetscScalar       *y;
724*4742e46bSJacob Faibussowitsch   PetscBool          biscupm, xiscupm, aiscupm;
725*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
726*4742e46bSJacob Faibussowitsch   cupmStream_t       stream;
727*4742e46bSJacob Faibussowitsch 
728*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
729*4742e46bSJacob Faibussowitsch   PetscCheck(A->factortype != MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix must be factored to solve");
730*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(PetscObjectCast(B), MATSEQDENSECUPM(), &biscupm));
731*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(PetscObjectCast(X), MATSEQDENSECUPM(), &xiscupm));
732*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(PetscObjectCast(A), MATSEQDENSECUPM(), &aiscupm));
733*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &stream));
734*4742e46bSJacob Faibussowitsch   {
735*4742e46bSJacob Faibussowitsch     PetscInt n;
736*4742e46bSJacob Faibussowitsch 
737*4742e46bSJacob Faibussowitsch     PetscCall(MatGetSize(B, nullptr, &n));
738*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMBlasIntCast(n, &nrhs));
739*4742e46bSJacob Faibussowitsch     PetscCall(MatDenseGetLDA(B, &n));
740*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMBlasIntCast(n, &ldb));
741*4742e46bSJacob Faibussowitsch     PetscCall(MatDenseGetLDA(X, &n));
742*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMBlasIntCast(n, &ldx));
743*4742e46bSJacob Faibussowitsch   }
744*4742e46bSJacob Faibussowitsch   {
745*4742e46bSJacob Faibussowitsch     // The logic here is to try to minimize the amount of memory copying:
746*4742e46bSJacob Faibussowitsch     //
747*4742e46bSJacob Faibussowitsch     // If we call MatDenseCUPMGetArrayRead(B, &b) every time biscupm and the data is not
748*4742e46bSJacob Faibussowitsch     // offloaded to the GPU yet, then the data is copied to the GPU. But we are only trying to
749*4742e46bSJacob Faibussowitsch     // get the data in order to copy it into the y array. So the array b will be wherever the
750*4742e46bSJacob Faibussowitsch     // data already is so that only one memcpy is performed
751*4742e46bSJacob Faibussowitsch     const auto         bisdevice = biscupm && PetscOffloadDevice(B->offloadmask);
752*4742e46bSJacob Faibussowitsch     const auto         copy_mode = bisdevice ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
753*4742e46bSJacob Faibussowitsch     const PetscScalar *b;
754*4742e46bSJacob Faibussowitsch 
755*4742e46bSJacob Faibussowitsch     if (bisdevice) {
756*4742e46bSJacob Faibussowitsch       b = DeviceArrayRead(dctx, B);
757*4742e46bSJacob Faibussowitsch     } else if (biscupm) {
758*4742e46bSJacob Faibussowitsch       b = HostArrayRead(dctx, B);
759*4742e46bSJacob Faibussowitsch     } else {
760*4742e46bSJacob Faibussowitsch       PetscCall(MatDenseGetArrayRead(B, &b));
761*4742e46bSJacob Faibussowitsch     }
762*4742e46bSJacob Faibussowitsch 
763*4742e46bSJacob Faibussowitsch     if (ldx < m || !xiscupm) {
764*4742e46bSJacob Faibussowitsch       // X's array cannot serve as the array (too small or not on device), B's array cannot
765*4742e46bSJacob Faibussowitsch       // serve as the array (const), so allocate a new array
766*4742e46bSJacob Faibussowitsch       ldy = m;
767*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMallocAsync(&y, nrhs * m));
768*4742e46bSJacob Faibussowitsch     } else {
769*4742e46bSJacob Faibussowitsch       // X's array should serve as the array
770*4742e46bSJacob Faibussowitsch       ldy = ldx;
771*4742e46bSJacob Faibussowitsch       y   = DeviceArrayWrite(dctx, X);
772*4742e46bSJacob Faibussowitsch     }
773*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMMemcpy2DAsync(y, ldy, b, ldb, m, nrhs, copy_mode, stream));
774*4742e46bSJacob Faibussowitsch     if (!bisdevice && !biscupm) PetscCall(MatDenseRestoreArrayRead(B, &b));
775*4742e46bSJacob Faibussowitsch   }
776*4742e46bSJacob Faibussowitsch 
777*4742e46bSJacob Faibussowitsch   // convert to CUPM twice??????????????????????????????????
778*4742e46bSJacob Faibussowitsch   // but A should already be CUPM??????????????????????????????????????
779*4742e46bSJacob Faibussowitsch   if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A));
780*4742e46bSJacob Faibussowitsch   PetscCall(Solver{}.template Solve<transpose>(A, cupmScalarPtrCast(y), ldy, m, nrhs, k, dctx, stream));
781*4742e46bSJacob Faibussowitsch   if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A));
782*4742e46bSJacob Faibussowitsch 
783*4742e46bSJacob Faibussowitsch   if (ldx < m || !xiscupm) {
784*4742e46bSJacob Faibussowitsch     const auto   copy_mode = xiscupm ? cupmMemcpyDeviceToDevice : cupmMemcpyDeviceToHost;
785*4742e46bSJacob Faibussowitsch     PetscScalar *x;
786*4742e46bSJacob Faibussowitsch 
787*4742e46bSJacob Faibussowitsch     // The logic here is that the data is not yet in either X's GPU array or its CPU
788*4742e46bSJacob Faibussowitsch     // array. There is nothing in the interface to say where the user would like it to end up.
789*4742e46bSJacob Faibussowitsch     // So we choose the GPU, because it is the faster option
790*4742e46bSJacob Faibussowitsch     if (xiscupm) {
791*4742e46bSJacob Faibussowitsch       x = DeviceArrayWrite(dctx, X);
792*4742e46bSJacob Faibussowitsch     } else {
793*4742e46bSJacob Faibussowitsch       PetscCall(MatDenseGetArray(X, &x));
794*4742e46bSJacob Faibussowitsch     }
795*4742e46bSJacob Faibussowitsch     PetscCall(PetscCUPMMemcpy2DAsync(x, ldx, y, ldy, k, nrhs, copy_mode, stream));
796*4742e46bSJacob Faibussowitsch     if (!xiscupm) PetscCall(MatDenseRestoreArray(X, &x));
797*4742e46bSJacob Faibussowitsch     PetscCallCUPM(cupmFreeAsync(y, stream));
798*4742e46bSJacob Faibussowitsch   }
799*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
800*4742e46bSJacob Faibussowitsch }
801*4742e46bSJacob Faibussowitsch 
802*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
803*4742e46bSJacob Faibussowitsch template <bool transpose>
804*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMultAdd_Dispatch_(Mat A, Vec xx, Vec yy, Vec zz) noexcept
805*4742e46bSJacob Faibussowitsch {
806*4742e46bSJacob Faibussowitsch   const auto         m = static_cast<cupmBlasInt_t>(A->rmap->n);
807*4742e46bSJacob Faibussowitsch   const auto         n = static_cast<cupmBlasInt_t>(A->cmap->n);
808*4742e46bSJacob Faibussowitsch   cupmBlasHandle_t   handle;
809*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
810*4742e46bSJacob Faibussowitsch 
811*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
812*4742e46bSJacob Faibussowitsch   if (yy && yy != zz) PetscCall(VecSeq_CUPM::Copy(yy, zz)); // mult add
813*4742e46bSJacob Faibussowitsch   if (!m || !n) {
814*4742e46bSJacob Faibussowitsch     // mult only
815*4742e46bSJacob Faibussowitsch     if (!yy) PetscCall(VecSeq_CUPM::Set(zz, 0.0));
816*4742e46bSJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
817*4742e46bSJacob Faibussowitsch   }
818*4742e46bSJacob Faibussowitsch   PetscCall(PetscInfo(A, "Matrix-vector product %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " on backend\n", m, n));
819*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &handle));
820*4742e46bSJacob Faibussowitsch   {
821*4742e46bSJacob Faibussowitsch     constexpr auto op   = transpose ? CUPMBLAS_OP_T : CUPMBLAS_OP_N;
822*4742e46bSJacob Faibussowitsch     const auto     one  = cupmScalarCast(1.0);
823*4742e46bSJacob Faibussowitsch     const auto     zero = cupmScalarCast(0.0);
824*4742e46bSJacob Faibussowitsch     const auto     da   = DeviceArrayRead(dctx, A);
825*4742e46bSJacob Faibussowitsch     const auto     dxx  = VecSeq_CUPM::DeviceArrayRead(dctx, xx);
826*4742e46bSJacob Faibussowitsch     const auto     dzz  = VecSeq_CUPM::DeviceArrayReadWrite(dctx, zz);
827*4742e46bSJacob Faibussowitsch 
828*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeBegin());
829*4742e46bSJacob Faibussowitsch     PetscCallCUPMBLAS(cupmBlasXgemv(handle, op, m, n, &one, da.cupmdata(), static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda), dxx.cupmdata(), 1, (yy ? &one : &zero), dzz.cupmdata(), 1));
830*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeEnd());
831*4742e46bSJacob Faibussowitsch   }
832*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * m * n - (yy ? 0 : m)));
833*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
834*4742e46bSJacob Faibussowitsch }
835*4742e46bSJacob Faibussowitsch 
836*4742e46bSJacob Faibussowitsch // ==========================================================================================
837*4742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Conversion Dispatch
838*4742e46bSJacob Faibussowitsch // ==========================================================================================
839*4742e46bSJacob Faibussowitsch 
840*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
841*4742e46bSJacob Faibussowitsch template <bool to_host>
842*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_Dispatch_(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept
843*4742e46bSJacob Faibussowitsch {
844*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
845*4742e46bSJacob Faibussowitsch   if (reuse == MAT_REUSE_MATRIX || reuse == MAT_INITIAL_MATRIX) {
846*4742e46bSJacob Faibussowitsch     // TODO these cases should be optimized
847*4742e46bSJacob Faibussowitsch     PetscCall(MatConvert_Basic(M, type, reuse, newmat));
848*4742e46bSJacob Faibussowitsch   } else {
849*4742e46bSJacob Faibussowitsch     const auto B    = *newmat;
850*4742e46bSJacob Faibussowitsch     const auto pobj = PetscObjectCast(B);
851*4742e46bSJacob Faibussowitsch 
852*4742e46bSJacob Faibussowitsch     if (to_host) {
853*4742e46bSJacob Faibussowitsch       PetscCall(BindToCPU(B, PETSC_TRUE));
854*4742e46bSJacob Faibussowitsch       PetscCall(Reset(B));
855*4742e46bSJacob Faibussowitsch     } else {
856*4742e46bSJacob Faibussowitsch       PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
857*4742e46bSJacob Faibussowitsch     }
858*4742e46bSJacob Faibussowitsch 
859*4742e46bSJacob Faibussowitsch     PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecSeq_CUPM::VECCUPM(), &B->defaultvectype));
860*4742e46bSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATSEQDENSE : MATSEQDENSECUPM()));
861*4742e46bSJacob Faibussowitsch     // cvec might be the wrong VecType, destroy and rebuild it if necessary
862*4742e46bSJacob Faibussowitsch     // REVIEW ME: this is possibly very inefficient
863*4742e46bSJacob Faibussowitsch     PetscCall(VecDestroy(&MatIMPLCast(B)->cvec));
864*4742e46bSJacob Faibussowitsch 
865*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatConvert_seqdensecupm_seqdense_C(), nullptr, Convert_SeqDenseCUPM_SeqDense);
866*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
867*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
868*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
869*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
870*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
871*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
872*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
873*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
874*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
875*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_seqaij_seqdensecupm_C(), nullptr, MatProductSetFromOptions_SeqAIJ_SeqDense);
876*4742e46bSJacob Faibussowitsch 
877*4742e46bSJacob Faibussowitsch     if (to_host) {
878*4742e46bSJacob Faibussowitsch       B->offloadmask = PETSC_OFFLOAD_CPU;
879*4742e46bSJacob Faibussowitsch     } else {
880*4742e46bSJacob Faibussowitsch       Mat_SeqDenseCUPM *mcu;
881*4742e46bSJacob Faibussowitsch 
882*4742e46bSJacob Faibussowitsch       PetscCall(PetscNew(&mcu));
883*4742e46bSJacob Faibussowitsch       B->spptr       = mcu;
884*4742e46bSJacob Faibussowitsch       B->offloadmask = PETSC_OFFLOAD_UNALLOCATED; // REVIEW ME: why not offload host??
885*4742e46bSJacob Faibussowitsch       PetscCall(BindToCPU(B, PETSC_FALSE));
886*4742e46bSJacob Faibussowitsch     }
887*4742e46bSJacob Faibussowitsch 
888*4742e46bSJacob Faibussowitsch     MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
889*4742e46bSJacob Faibussowitsch     MatSetOp_CUPM(to_host, B, destroy, MatDestroy_SeqDense, Destroy);
890*4742e46bSJacob Faibussowitsch   }
891*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
892*4742e46bSJacob Faibussowitsch }
893*4742e46bSJacob Faibussowitsch 
894*4742e46bSJacob Faibussowitsch // ==========================================================================================
895*4742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Public API
896*4742e46bSJacob Faibussowitsch // ==========================================================================================
897*4742e46bSJacob Faibussowitsch 
898*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
899*4742e46bSJacob Faibussowitsch inline constexpr MatType MatDense_Seq_CUPM<T>::MATIMPLCUPM_() noexcept
900*4742e46bSJacob Faibussowitsch {
901*4742e46bSJacob Faibussowitsch   return MATSEQDENSECUPM();
902*4742e46bSJacob Faibussowitsch }
903*4742e46bSJacob Faibussowitsch 
904*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
905*4742e46bSJacob Faibussowitsch inline constexpr typename MatDense_Seq_CUPM<T>::Mat_SeqDenseCUPM *MatDense_Seq_CUPM<T>::MatCUPMCast(Mat m) noexcept
906*4742e46bSJacob Faibussowitsch {
907*4742e46bSJacob Faibussowitsch   return static_cast<Mat_SeqDenseCUPM *>(m->spptr);
908*4742e46bSJacob Faibussowitsch }
909*4742e46bSJacob Faibussowitsch 
910*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
911*4742e46bSJacob Faibussowitsch inline constexpr Mat_SeqDense *MatDense_Seq_CUPM<T>::MatIMPLCast_(Mat m) noexcept
912*4742e46bSJacob Faibussowitsch {
913*4742e46bSJacob Faibussowitsch   return static_cast<Mat_SeqDense *>(m->data);
914*4742e46bSJacob Faibussowitsch }
915*4742e46bSJacob Faibussowitsch 
916*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
917*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_Seq_CUPM<T>::MatConvert_seqdensecupm_seqdense_C() noexcept
918*4742e46bSJacob Faibussowitsch {
919*4742e46bSJacob Faibussowitsch   return T == device::cupm::DeviceType::CUDA ? "MatConvert_seqdensecuda_seqdense_C" : "MatConvert_seqdensehip_seqdense_C";
920*4742e46bSJacob Faibussowitsch }
921*4742e46bSJacob Faibussowitsch 
922*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
923*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_Seq_CUPM<T>::MatProductSetFromOptions_seqaij_seqdensecupm_C() noexcept
924*4742e46bSJacob Faibussowitsch {
925*4742e46bSJacob Faibussowitsch   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_seqaij_seqdensecuda_C" : "MatProductSetFromOptions_seqaij_seqdensehip_C";
926*4742e46bSJacob Faibussowitsch }
927*4742e46bSJacob Faibussowitsch 
928*4742e46bSJacob Faibussowitsch // ==========================================================================================
929*4742e46bSJacob Faibussowitsch 
930*4742e46bSJacob Faibussowitsch // MatCreate_SeqDenseCUPM()
931*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
932*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Create(Mat A) noexcept
933*4742e46bSJacob Faibussowitsch {
934*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
935*4742e46bSJacob Faibussowitsch   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
936*4742e46bSJacob Faibussowitsch   PetscCall(MatCreate_SeqDense(A));
937*4742e46bSJacob Faibussowitsch   PetscCall(Convert_SeqDense_SeqDenseCUPM(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A));
938*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
939*4742e46bSJacob Faibussowitsch }
940*4742e46bSJacob Faibussowitsch 
941*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
942*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Destroy(Mat A) noexcept
943*4742e46bSJacob Faibussowitsch {
944*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
945*4742e46bSJacob Faibussowitsch   // prevent copying back data if we own the data pointer
946*4742e46bSJacob Faibussowitsch   if (!MatIMPLCast(A)->user_alloc) A->offloadmask = PETSC_OFFLOAD_CPU;
947*4742e46bSJacob Faibussowitsch   PetscCall(Convert_SeqDenseCUPM_SeqDense(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A));
948*4742e46bSJacob Faibussowitsch   PetscCall(MatDestroy_SeqDense(A));
949*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
950*4742e46bSJacob Faibussowitsch }
951*4742e46bSJacob Faibussowitsch 
952*4742e46bSJacob Faibussowitsch // obj->ops->setup()
953*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
954*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetUp(Mat A) noexcept
955*4742e46bSJacob Faibussowitsch {
956*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
957*4742e46bSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(A->rmap));
958*4742e46bSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(A->cmap));
959*4742e46bSJacob Faibussowitsch   if (!A->preallocated) {
960*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
961*4742e46bSJacob Faibussowitsch 
962*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
963*4742e46bSJacob Faibussowitsch     PetscCall(SetPreallocation(A, dctx));
964*4742e46bSJacob Faibussowitsch   }
965*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
966*4742e46bSJacob Faibussowitsch }
967*4742e46bSJacob Faibussowitsch 
968*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
969*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Reset(Mat A) noexcept
970*4742e46bSJacob Faibussowitsch {
971*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
972*4742e46bSJacob Faibussowitsch   if (const auto mcu = MatCUPMCast(A)) {
973*4742e46bSJacob Faibussowitsch     cupmStream_t stream;
974*4742e46bSJacob Faibussowitsch 
975*4742e46bSJacob Faibussowitsch     PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME());
976*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&stream));
977*4742e46bSJacob Faibussowitsch     if (!mcu->d_user_alloc) PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream));
978*4742e46bSJacob Faibussowitsch     PetscCallCUPM(cupmFreeAsync(mcu->d_fact_tau, stream));
979*4742e46bSJacob Faibussowitsch     PetscCallCUPM(cupmFreeAsync(mcu->d_fact_ipiv, stream));
980*4742e46bSJacob Faibussowitsch     PetscCallCUPM(cupmFreeAsync(mcu->d_fact_info, stream));
981*4742e46bSJacob Faibussowitsch     PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream));
982*4742e46bSJacob Faibussowitsch     PetscCall(VecDestroy(&mcu->workvec));
983*4742e46bSJacob Faibussowitsch     PetscCall(PetscFree(A->spptr /* mcu */));
984*4742e46bSJacob Faibussowitsch   }
985*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
986*4742e46bSJacob Faibussowitsch }
987*4742e46bSJacob Faibussowitsch 
988*4742e46bSJacob Faibussowitsch // ==========================================================================================
989*4742e46bSJacob Faibussowitsch 
990*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
991*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::BindToCPU(Mat A, PetscBool to_host) noexcept
992*4742e46bSJacob Faibussowitsch {
993*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
994*4742e46bSJacob Faibussowitsch   const auto pobj  = PetscObjectCast(A);
995*4742e46bSJacob Faibussowitsch 
996*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
997*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
998*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
999*4742e46bSJacob Faibussowitsch   A->boundtocpu = to_host;
1000*4742e46bSJacob Faibussowitsch   PetscCall(PetscStrFreeAllocpy(to_host ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
1001*4742e46bSJacob Faibussowitsch   if (to_host) {
1002*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
1003*4742e46bSJacob Faibussowitsch 
1004*4742e46bSJacob Faibussowitsch     // make sure we have an up-to-date copy on the CPU
1005*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
1006*4742e46bSJacob Faibussowitsch     PetscCall(DeviceToHost_(A, dctx));
1007*4742e46bSJacob Faibussowitsch   } else {
1008*4742e46bSJacob Faibussowitsch     PetscBool iscupm;
1009*4742e46bSJacob Faibussowitsch 
1010*4742e46bSJacob Faibussowitsch     if (auto &cvec = mimpl->cvec) {
1011*4742e46bSJacob Faibussowitsch       PetscCall(PetscObjectTypeCompare(PetscObjectCast(cvec), VecSeq_CUPM::VECSEQCUPM(), &iscupm));
1012*4742e46bSJacob Faibussowitsch       if (!iscupm) PetscCall(VecDestroy(&cvec));
1013*4742e46bSJacob Faibussowitsch     }
1014*4742e46bSJacob Faibussowitsch     if (auto &cmat = mimpl->cmat) {
1015*4742e46bSJacob Faibussowitsch       PetscCall(PetscObjectTypeCompare(PetscObjectCast(cmat), MATSEQDENSECUPM(), &iscupm));
1016*4742e46bSJacob Faibussowitsch       if (!iscupm) PetscCall(MatDestroy(&cmat));
1017*4742e46bSJacob Faibussowitsch     }
1018*4742e46bSJacob Faibussowitsch   }
1019*4742e46bSJacob Faibussowitsch 
1020*4742e46bSJacob Faibussowitsch   // ============================================================
1021*4742e46bSJacob Faibussowitsch   // Composed ops
1022*4742e46bSJacob Faibussowitsch   // ============================================================
1023*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArray_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>);
1024*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayRead_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>);
1025*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayWrite_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>);
1026*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ_WRITE>);
1027*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ_WRITE>);
1028*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayReadAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ>);
1029*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayReadAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ>);
1030*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayWriteAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_WRITE>);
1031*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayWriteAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_WRITE>);
1032*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
1033*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
1034*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
1035*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
1036*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
1037*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
1038*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseGetSubMatrix_C", MatDenseGetSubMatrix_SeqDense, GetSubMatrix);
1039*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreSubMatrix_C", MatDenseRestoreSubMatrix_SeqDense, RestoreSubMatrix);
1040*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(to_host, pobj, "MatQRFactor_C", MatQRFactor_SeqDense, SolveQR::Factor);
1041*4742e46bSJacob Faibussowitsch   // always the same
1042*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction(pobj, "MatDenseSetLDA_C", MatDenseSetLDA_SeqDense));
1043*4742e46bSJacob Faibussowitsch 
1044*4742e46bSJacob Faibussowitsch   // ============================================================
1045*4742e46bSJacob Faibussowitsch   // Function pointer ops
1046*4742e46bSJacob Faibussowitsch   // ============================================================
1047*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, duplicate, MatDuplicate_SeqDense, Duplicate);
1048*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, mult, MatMult_SeqDense, [](Mat A, Vec xx, Vec yy) { return MatMultAdd_Dispatch_</* transpose */ false>(A, xx, nullptr, yy); });
1049*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, multtranspose, MatMultTranspose_SeqDense, [](Mat A, Vec xx, Vec yy) { return MatMultAdd_Dispatch_</* transpose */ true>(A, xx, nullptr, yy); });
1050*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, multadd, MatMultAdd_SeqDense, MatMultAdd_Dispatch_</* transpose */ false>);
1051*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, multtransposeadd, MatMultTransposeAdd_SeqDense, MatMultAdd_Dispatch_</* transpose */ true>);
1052*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, matmultnumeric, MatMatMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ false, /* transpose_B */ false>);
1053*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, mattransposemultnumeric, MatMatTransposeMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ false, /* transpose_B */ true>);
1054*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, transposematmultnumeric, MatTransposeMatMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ true, /* transpose_B */ false>);
1055*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, axpy, MatAXPY_SeqDense, AXPY);
1056*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, choleskyfactor, MatCholeskyFactor_SeqDense, SolveCholesky::Factor);
1057*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, lufactor, MatLUFactor_SeqDense, SolveLU::Factor);
1058*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, getcolumnvector, MatGetColumnVector_SeqDense, GetColumnVector);
1059*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, scale, MatScale_SeqDense, Scale);
1060*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, shift, MatShift_SeqDense, Shift);
1061*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, copy, MatCopy_SeqDense, Copy);
1062*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, zeroentries, MatZeroEntries_SeqDense, ZeroEntries);
1063*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, setup, MatSetUp_SeqDense, SetUp);
1064*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(to_host, A, setrandom, MatSetRandom_SeqDense, SetRandom);
1065*4742e46bSJacob Faibussowitsch   // seemingly always the same
1066*4742e46bSJacob Faibussowitsch   A->ops->productsetfromoptions = MatProductSetFromOptions_SeqDense;
1067*4742e46bSJacob Faibussowitsch 
1068*4742e46bSJacob Faibussowitsch   if (const auto cmat = mimpl->cmat) PetscCall(MatBindToCPU(cmat, to_host));
1069*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1070*4742e46bSJacob Faibussowitsch }
1071*4742e46bSJacob Faibussowitsch 
1072*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1073*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_SeqDenseCUPM_SeqDense(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept
1074*4742e46bSJacob Faibussowitsch {
1075*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1076*4742e46bSJacob Faibussowitsch   PetscCall(Convert_Dispatch_</* to host */ true>(M, type, reuse, newmat));
1077*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1078*4742e46bSJacob Faibussowitsch }
1079*4742e46bSJacob Faibussowitsch 
1080*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1081*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_SeqDense_SeqDenseCUPM(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept
1082*4742e46bSJacob Faibussowitsch {
1083*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1084*4742e46bSJacob Faibussowitsch   PetscCall(Convert_Dispatch_</* to host */ false>(M, type, reuse, newmat));
1085*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1086*4742e46bSJacob Faibussowitsch }
1087*4742e46bSJacob Faibussowitsch 
1088*4742e46bSJacob Faibussowitsch // ==========================================================================================
1089*4742e46bSJacob Faibussowitsch 
1090*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1091*4742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode access>
1092*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetArray(Mat m, PetscScalar **array, PetscDeviceContext dctx) noexcept
1093*4742e46bSJacob Faibussowitsch {
1094*4742e46bSJacob Faibussowitsch   constexpr auto hostmem     = PetscMemTypeHost(mtype);
1095*4742e46bSJacob Faibussowitsch   constexpr auto read_access = PetscMemoryAccessRead(access);
1096*4742e46bSJacob Faibussowitsch 
1097*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1098*4742e46bSJacob Faibussowitsch   static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
1099*4742e46bSJacob Faibussowitsch   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1100*4742e46bSJacob Faibussowitsch   if (hostmem) {
1101*4742e46bSJacob Faibussowitsch     if (read_access) {
1102*4742e46bSJacob Faibussowitsch       PetscCall(DeviceToHost_(m, dctx));
1103*4742e46bSJacob Faibussowitsch     } else if (!MatIMPLCast(m)->v) {
1104*4742e46bSJacob Faibussowitsch       // MatCreateSeqDenseCUPM may not allocate CPU memory. Allocate if needed
1105*4742e46bSJacob Faibussowitsch       PetscCall(MatSeqDenseSetPreallocation(m, nullptr));
1106*4742e46bSJacob Faibussowitsch     }
1107*4742e46bSJacob Faibussowitsch     *array = MatIMPLCast(m)->v;
1108*4742e46bSJacob Faibussowitsch   } else {
1109*4742e46bSJacob Faibussowitsch     if (read_access) {
1110*4742e46bSJacob Faibussowitsch       PetscCall(HostToDevice_(m, dctx));
1111*4742e46bSJacob Faibussowitsch     } else if (!MatCUPMCast(m)->d_v) {
1112*4742e46bSJacob Faibussowitsch       // write-only
1113*4742e46bSJacob Faibussowitsch       PetscCall(SetPreallocation(m, dctx, nullptr));
1114*4742e46bSJacob Faibussowitsch     }
1115*4742e46bSJacob Faibussowitsch     *array = MatCUPMCast(m)->d_v;
1116*4742e46bSJacob Faibussowitsch   }
1117*4742e46bSJacob Faibussowitsch   if (PetscMemoryAccessWrite(access)) {
1118*4742e46bSJacob Faibussowitsch     m->offloadmask = hostmem ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
1119*4742e46bSJacob Faibussowitsch     PetscCall(PetscObjectStateIncrease(PetscObjectCast(m)));
1120*4742e46bSJacob Faibussowitsch   }
1121*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1122*4742e46bSJacob Faibussowitsch }
1123*4742e46bSJacob Faibussowitsch 
1124*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1125*4742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode access>
1126*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreArray(Mat m, PetscScalar **array, PetscDeviceContext) noexcept
1127*4742e46bSJacob Faibussowitsch {
1128*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1129*4742e46bSJacob Faibussowitsch   static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
1130*4742e46bSJacob Faibussowitsch   if (PetscMemoryAccessWrite(access)) {
1131*4742e46bSJacob Faibussowitsch     // WRITE or READ_WRITE
1132*4742e46bSJacob Faibussowitsch     m->offloadmask = PetscMemTypeHost(mtype) ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
1133*4742e46bSJacob Faibussowitsch     PetscCall(PetscObjectStateIncrease(PetscObjectCast(m)));
1134*4742e46bSJacob Faibussowitsch   }
1135*4742e46bSJacob Faibussowitsch   *array = nullptr;
1136*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1137*4742e46bSJacob Faibussowitsch }
1138*4742e46bSJacob Faibussowitsch 
1139*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1140*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access>
1141*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetArrayAndMemType(Mat m, PetscScalar **array, PetscMemType *mtype, PetscDeviceContext dctx) noexcept
1142*4742e46bSJacob Faibussowitsch {
1143*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1144*4742e46bSJacob Faibussowitsch   PetscCall(GetArray<PETSC_MEMTYPE_DEVICE, access>(m, array, dctx));
1145*4742e46bSJacob Faibussowitsch   if (mtype) *mtype = PETSC_MEMTYPE_CUPM();
1146*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1147*4742e46bSJacob Faibussowitsch }
1148*4742e46bSJacob Faibussowitsch 
1149*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1150*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access>
1151*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreArrayAndMemType(Mat m, PetscScalar **array, PetscDeviceContext dctx) noexcept
1152*4742e46bSJacob Faibussowitsch {
1153*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1154*4742e46bSJacob Faibussowitsch   PetscCall(RestoreArray<PETSC_MEMTYPE_DEVICE, access>(m, array, dctx));
1155*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1156*4742e46bSJacob Faibussowitsch }
1157*4742e46bSJacob Faibussowitsch 
1158*4742e46bSJacob Faibussowitsch // ==========================================================================================
1159*4742e46bSJacob Faibussowitsch 
1160*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1161*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
1162*4742e46bSJacob Faibussowitsch {
1163*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
1164*4742e46bSJacob Faibussowitsch   const auto mcu   = MatCUPMCast(A);
1165*4742e46bSJacob Faibussowitsch 
1166*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1167*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
1168*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
1169*4742e46bSJacob Faibussowitsch   PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME());
1170*4742e46bSJacob Faibussowitsch   if (mimpl->v) {
1171*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
1172*4742e46bSJacob Faibussowitsch 
1173*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
1174*4742e46bSJacob Faibussowitsch     PetscCall(HostToDevice_(A, dctx));
1175*4742e46bSJacob Faibussowitsch   }
1176*4742e46bSJacob Faibussowitsch   mcu->unplacedarray         = util::exchange(mcu->d_v, const_cast<PetscScalar *>(array));
1177*4742e46bSJacob Faibussowitsch   mcu->d_unplaced_user_alloc = util::exchange(mcu->d_user_alloc, PETSC_TRUE);
1178*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1179*4742e46bSJacob Faibussowitsch }
1180*4742e46bSJacob Faibussowitsch 
1181*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1182*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
1183*4742e46bSJacob Faibussowitsch {
1184*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
1185*4742e46bSJacob Faibussowitsch   const auto mcu   = MatCUPMCast(A);
1186*4742e46bSJacob Faibussowitsch 
1187*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1188*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
1189*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
1190*4742e46bSJacob Faibussowitsch   PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME());
1191*4742e46bSJacob Faibussowitsch   if (!mcu->d_user_alloc) {
1192*4742e46bSJacob Faibussowitsch     cupmStream_t stream;
1193*4742e46bSJacob Faibussowitsch 
1194*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&stream));
1195*4742e46bSJacob Faibussowitsch     PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream));
1196*4742e46bSJacob Faibussowitsch   }
1197*4742e46bSJacob Faibussowitsch   mcu->d_v          = const_cast<PetscScalar *>(array);
1198*4742e46bSJacob Faibussowitsch   mcu->d_user_alloc = PETSC_FALSE;
1199*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1200*4742e46bSJacob Faibussowitsch }
1201*4742e46bSJacob Faibussowitsch 
1202*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1203*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ResetArray(Mat A) noexcept
1204*4742e46bSJacob Faibussowitsch {
1205*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
1206*4742e46bSJacob Faibussowitsch   const auto mcu   = MatCUPMCast(A);
1207*4742e46bSJacob Faibussowitsch 
1208*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1209*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
1210*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
1211*4742e46bSJacob Faibussowitsch   if (mimpl->v) {
1212*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
1213*4742e46bSJacob Faibussowitsch 
1214*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
1215*4742e46bSJacob Faibussowitsch     PetscCall(HostToDevice_(A, dctx));
1216*4742e46bSJacob Faibussowitsch   }
1217*4742e46bSJacob Faibussowitsch   mcu->d_v          = util::exchange(mcu->unplacedarray, nullptr);
1218*4742e46bSJacob Faibussowitsch   mcu->d_user_alloc = mcu->d_unplaced_user_alloc;
1219*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1220*4742e46bSJacob Faibussowitsch }
1221*4742e46bSJacob Faibussowitsch 
1222*4742e46bSJacob Faibussowitsch // ==========================================================================================
1223*4742e46bSJacob Faibussowitsch 
1224*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1225*4742e46bSJacob Faibussowitsch template <bool transpose_A, bool transpose_B>
1226*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMatMult_Numeric_Dispatch(Mat A, Mat B, Mat C) noexcept
1227*4742e46bSJacob Faibussowitsch {
1228*4742e46bSJacob Faibussowitsch   cupmBlasInt_t      m, n, k;
1229*4742e46bSJacob Faibussowitsch   PetscBool          Aiscupm, Biscupm;
1230*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1231*4742e46bSJacob Faibussowitsch   cupmBlasHandle_t   handle;
1232*4742e46bSJacob Faibussowitsch 
1233*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1234*4742e46bSJacob Faibussowitsch   PetscCall(PetscCUPMBlasIntCast(C->rmap->n, &m));
1235*4742e46bSJacob Faibussowitsch   PetscCall(PetscCUPMBlasIntCast(C->cmap->n, &n));
1236*4742e46bSJacob Faibussowitsch   PetscCall(PetscCUPMBlasIntCast(transpose_A ? A->rmap->n : A->cmap->n, &k));
1237*4742e46bSJacob Faibussowitsch   if (!m || !n || !k) PetscFunctionReturn(PETSC_SUCCESS);
1238*4742e46bSJacob Faibussowitsch 
1239*4742e46bSJacob Faibussowitsch   // we may end up with SEQDENSE as one of the arguments
1240*4742e46bSJacob Faibussowitsch   // REVIEW ME: how? and why is it not B and C????????
1241*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(PetscObjectCast(A), MATSEQDENSECUPM(), &Aiscupm));
1242*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(PetscObjectCast(B), MATSEQDENSECUPM(), &Biscupm));
1243*4742e46bSJacob Faibussowitsch   if (!Aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A));
1244*4742e46bSJacob Faibussowitsch   if (!Biscupm) PetscCall(MatConvert(B, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &B));
1245*4742e46bSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix-Matrix product %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " on backend\n", m, k, n));
1246*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &handle));
1247*4742e46bSJacob Faibussowitsch 
1248*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1249*4742e46bSJacob Faibussowitsch   {
1250*4742e46bSJacob Faibussowitsch     const auto one  = cupmScalarCast(1.0);
1251*4742e46bSJacob Faibussowitsch     const auto zero = cupmScalarCast(0.0);
1252*4742e46bSJacob Faibussowitsch     const auto da   = DeviceArrayRead(dctx, A);
1253*4742e46bSJacob Faibussowitsch     const auto db   = DeviceArrayRead(dctx, B);
1254*4742e46bSJacob Faibussowitsch     const auto dc   = DeviceArrayWrite(dctx, C);
1255*4742e46bSJacob Faibussowitsch     PetscInt   alda, blda, clda;
1256*4742e46bSJacob Faibussowitsch 
1257*4742e46bSJacob Faibussowitsch     PetscCall(MatDenseGetLDA(A, &alda));
1258*4742e46bSJacob Faibussowitsch     PetscCall(MatDenseGetLDA(B, &blda));
1259*4742e46bSJacob Faibussowitsch     PetscCall(MatDenseGetLDA(C, &clda));
1260*4742e46bSJacob Faibussowitsch     PetscCallCUPMBLAS(cupmBlasXgemm(handle, transpose_A ? CUPMBLAS_OP_T : CUPMBLAS_OP_N, transpose_B ? CUPMBLAS_OP_T : CUPMBLAS_OP_N, m, n, k, &one, da.cupmdata(), alda, db.cupmdata(), blda, &zero, dc.cupmdata(), clda));
1261*4742e46bSJacob Faibussowitsch   }
1262*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
1263*4742e46bSJacob Faibussowitsch 
1264*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(1.0 * m * n * k + 1.0 * m * n * (k - 1)));
1265*4742e46bSJacob Faibussowitsch   if (!Aiscupm) PetscCall(MatConvert(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A));
1266*4742e46bSJacob Faibussowitsch   if (!Biscupm) PetscCall(MatConvert(B, MATSEQDENSE, MAT_INPLACE_MATRIX, &B));
1267*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1268*4742e46bSJacob Faibussowitsch }
1269*4742e46bSJacob Faibussowitsch 
1270*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1271*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Copy(Mat A, Mat B, MatStructure str) noexcept
1272*4742e46bSJacob Faibussowitsch {
1273*4742e46bSJacob Faibussowitsch   const auto m = A->rmap->n;
1274*4742e46bSJacob Faibussowitsch   const auto n = A->cmap->n;
1275*4742e46bSJacob Faibussowitsch 
1276*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1277*4742e46bSJacob Faibussowitsch   PetscAssert(m == B->rmap->n && n == B->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "size(B) != size(A)");
1278*4742e46bSJacob Faibussowitsch   // The two matrices must have the same copy implementation to be eligible for fast copy
1279*4742e46bSJacob Faibussowitsch   if (A->ops->copy == B->ops->copy) {
1280*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
1281*4742e46bSJacob Faibussowitsch     cupmStream_t       stream;
1282*4742e46bSJacob Faibussowitsch 
1283*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx, &stream));
1284*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeBegin());
1285*4742e46bSJacob Faibussowitsch     {
1286*4742e46bSJacob Faibussowitsch       const auto va = DeviceArrayRead(dctx, A);
1287*4742e46bSJacob Faibussowitsch       const auto vb = DeviceArrayWrite(dctx, B);
1288*4742e46bSJacob Faibussowitsch       // order is important, DeviceArrayRead/Write() might call SetPreallocation() which sets
1289*4742e46bSJacob Faibussowitsch       // lda!
1290*4742e46bSJacob Faibussowitsch       const auto lda_a = MatIMPLCast(A)->lda;
1291*4742e46bSJacob Faibussowitsch       const auto lda_b = MatIMPLCast(B)->lda;
1292*4742e46bSJacob Faibussowitsch 
1293*4742e46bSJacob Faibussowitsch       if (lda_a > m || lda_b > m) {
1294*4742e46bSJacob Faibussowitsch         PetscAssert(lda_b > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "B lda (%" PetscBLASInt_FMT ") must be > 0 at this point, this indicates Mat%sSetPreallocation() was not called when it should have been!", lda_b, cupmNAME());
1295*4742e46bSJacob Faibussowitsch         PetscAssert(lda_a > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A lda (%" PetscBLASInt_FMT ") must be > 0 at this point, this indicates Mat%sSetPreallocation() was not called when it should have been!", lda_a, cupmNAME());
1296*4742e46bSJacob Faibussowitsch         PetscCall(PetscCUPMMemcpy2DAsync(vb.data(), lda_b, va.data(), lda_a, m, n, cupmMemcpyDeviceToDevice, stream));
1297*4742e46bSJacob Faibussowitsch       } else {
1298*4742e46bSJacob Faibussowitsch         PetscCall(PetscCUPMMemcpyAsync(vb.data(), va.data(), m * n, cupmMemcpyDeviceToDevice, stream));
1299*4742e46bSJacob Faibussowitsch       }
1300*4742e46bSJacob Faibussowitsch     }
1301*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeEnd());
1302*4742e46bSJacob Faibussowitsch   } else {
1303*4742e46bSJacob Faibussowitsch     PetscCall(MatCopy_Basic(A, B, str));
1304*4742e46bSJacob Faibussowitsch   }
1305*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1306*4742e46bSJacob Faibussowitsch }
1307*4742e46bSJacob Faibussowitsch 
1308*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1309*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ZeroEntries(Mat m) noexcept
1310*4742e46bSJacob Faibussowitsch {
1311*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1312*4742e46bSJacob Faibussowitsch   cupmStream_t       stream;
1313*4742e46bSJacob Faibussowitsch 
1314*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1315*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &stream));
1316*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1317*4742e46bSJacob Faibussowitsch   {
1318*4742e46bSJacob Faibussowitsch     const auto va  = DeviceArrayWrite(dctx, m);
1319*4742e46bSJacob Faibussowitsch     const auto lda = MatIMPLCast(m)->lda;
1320*4742e46bSJacob Faibussowitsch     const auto ma  = m->rmap->n;
1321*4742e46bSJacob Faibussowitsch     const auto na  = m->cmap->n;
1322*4742e46bSJacob Faibussowitsch 
1323*4742e46bSJacob Faibussowitsch     if (lda > ma) {
1324*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMemset2DAsync(va.data(), lda, 0, ma, na, stream));
1325*4742e46bSJacob Faibussowitsch     } else {
1326*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMemsetAsync(va.data(), 0, ma * na, stream));
1327*4742e46bSJacob Faibussowitsch     }
1328*4742e46bSJacob Faibussowitsch   }
1329*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
1330*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1331*4742e46bSJacob Faibussowitsch }
1332*4742e46bSJacob Faibussowitsch 
1333*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1334*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Scale(Mat A, PetscScalar alpha) noexcept
1335*4742e46bSJacob Faibussowitsch {
1336*4742e46bSJacob Faibussowitsch   const auto         m = static_cast<cupmBlasInt_t>(A->rmap->n);
1337*4742e46bSJacob Faibussowitsch   const auto         n = static_cast<cupmBlasInt_t>(A->cmap->n);
1338*4742e46bSJacob Faibussowitsch   const auto         N = m * n;
1339*4742e46bSJacob Faibussowitsch   cupmBlasHandle_t   handle;
1340*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1341*4742e46bSJacob Faibussowitsch 
1342*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1343*4742e46bSJacob Faibussowitsch   PetscCall(PetscInfo(A, "Performing Scale %d x %d on backend\n", m, n));
1344*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &handle));
1345*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1346*4742e46bSJacob Faibussowitsch   {
1347*4742e46bSJacob Faibussowitsch     const auto cu_alpha = cupmScalarCast(alpha);
1348*4742e46bSJacob Faibussowitsch     const auto da       = DeviceArrayReadWrite(dctx, A);
1349*4742e46bSJacob Faibussowitsch     const auto lda      = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda);
1350*4742e46bSJacob Faibussowitsch 
1351*4742e46bSJacob Faibussowitsch     if (lda > m) {
1352*4742e46bSJacob Faibussowitsch       for (cupmBlasInt_t j = 0; j < n; ++j) PetscCallCUPMBLAS(cupmBlasXscal(handle, m, &cu_alpha, da.cupmdata() + lda * j, 1));
1353*4742e46bSJacob Faibussowitsch     } else {
1354*4742e46bSJacob Faibussowitsch       PetscCallCUPMBLAS(cupmBlasXscal(handle, N, &cu_alpha, da.cupmdata(), 1));
1355*4742e46bSJacob Faibussowitsch     }
1356*4742e46bSJacob Faibussowitsch   }
1357*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
1358*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(N));
1359*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1360*4742e46bSJacob Faibussowitsch }
1361*4742e46bSJacob Faibussowitsch 
1362*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1363*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept
1364*4742e46bSJacob Faibussowitsch {
1365*4742e46bSJacob Faibussowitsch   const auto         m = A->rmap->n;
1366*4742e46bSJacob Faibussowitsch   const auto         n = A->cmap->n;
1367*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1368*4742e46bSJacob Faibussowitsch 
1369*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1370*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx));
1371*4742e46bSJacob Faibussowitsch   PetscCall(PetscInfo(A, "Performing Shift %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", m, n));
1372*4742e46bSJacob Faibussowitsch   PetscCall(PointwiseUnaryTransform(A, 0, m, n, dctx, device::cupm::functors::make_plus_equals(alpha)));
1373*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1374*4742e46bSJacob Faibussowitsch }
1375*4742e46bSJacob Faibussowitsch 
1376*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1377*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::AXPY(Mat Y, PetscScalar alpha, Mat X, MatStructure) noexcept
1378*4742e46bSJacob Faibussowitsch {
1379*4742e46bSJacob Faibussowitsch   const auto         m_x = X->rmap->n, m_y = Y->rmap->n;
1380*4742e46bSJacob Faibussowitsch   const auto         n_x = X->cmap->n, n_y = Y->cmap->n;
1381*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1382*4742e46bSJacob Faibussowitsch   cupmBlasHandle_t   handle;
1383*4742e46bSJacob Faibussowitsch 
1384*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1385*4742e46bSJacob Faibussowitsch   if (!m_x || !n_x) PetscFunctionReturn(PETSC_SUCCESS);
1386*4742e46bSJacob Faibussowitsch   PetscCall(PetscInfo(Y, "Performing AXPY %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", m_y, n_y));
1387*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &handle));
1388*4742e46bSJacob Faibussowitsch   {
1389*4742e46bSJacob Faibussowitsch     const auto N        = m_x * n_x;
1390*4742e46bSJacob Faibussowitsch     const auto dx       = DeviceArrayRead(dctx, X);
1391*4742e46bSJacob Faibussowitsch     const auto dy       = alpha == 0.0 ? DeviceArrayWrite(dctx, Y).cupmdata() : DeviceArrayReadWrite(dctx, Y).cupmdata();
1392*4742e46bSJacob Faibussowitsch     const auto ldax     = static_cast<cupmBlasInt_t>(MatIMPLCast(X)->lda);
1393*4742e46bSJacob Faibussowitsch     const auto lday     = static_cast<cupmBlasInt_t>(MatIMPLCast(Y)->lda);
1394*4742e46bSJacob Faibussowitsch     const auto cu_alpha = cupmScalarCast(alpha);
1395*4742e46bSJacob Faibussowitsch 
1396*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeBegin());
1397*4742e46bSJacob Faibussowitsch     if (ldax > m_x || lday > m_x) {
1398*4742e46bSJacob Faibussowitsch       for (cupmBlasInt_t j = 0; j < n_x; j++) PetscCallCUPMBLAS(cupmBlasXaxpy(handle, m_x, &cu_alpha, dx.cupmdata() + j * ldax, 1, dy + j * lday, 1));
1399*4742e46bSJacob Faibussowitsch     } else {
1400*4742e46bSJacob Faibussowitsch       PetscCallCUPMBLAS(cupmBlasXaxpy(handle, N, &cu_alpha, dx.cupmdata(), 1, dy, 1));
1401*4742e46bSJacob Faibussowitsch     }
1402*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeEnd());
1403*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops(PetscMax(2 * N - 1, 0)));
1404*4742e46bSJacob Faibussowitsch   }
1405*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1406*4742e46bSJacob Faibussowitsch }
1407*4742e46bSJacob Faibussowitsch 
1408*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1409*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Duplicate(Mat A, MatDuplicateOption opt, Mat *B) noexcept
1410*4742e46bSJacob Faibussowitsch {
1411*4742e46bSJacob Faibussowitsch   const auto         hopt = (opt == MAT_COPY_VALUES && A->offloadmask != PETSC_OFFLOAD_CPU) ? MAT_DO_NOT_COPY_VALUES : opt;
1412*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1413*4742e46bSJacob Faibussowitsch 
1414*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1415*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx));
1416*4742e46bSJacob Faibussowitsch   // do not call SetPreallocation() yet, we call it afterwards??
1417*4742e46bSJacob Faibussowitsch   PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), A->rmap->n, A->cmap->n, nullptr, B, dctx, /* preallocate */ false));
1418*4742e46bSJacob Faibussowitsch   PetscCall(MatDuplicateNoCreate_SeqDense(*B, A, hopt));
1419*4742e46bSJacob Faibussowitsch   if (opt == MAT_COPY_VALUES && hopt != MAT_COPY_VALUES) PetscCall(Copy(A, *B, SAME_NONZERO_PATTERN));
1420*4742e46bSJacob Faibussowitsch   // allocate memory if needed
1421*4742e46bSJacob Faibussowitsch   if (opt != MAT_COPY_VALUES && !MatCUPMCast(*B)->d_v) PetscCall(SetPreallocation(*B, dctx));
1422*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1423*4742e46bSJacob Faibussowitsch }
1424*4742e46bSJacob Faibussowitsch 
1425*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1426*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetRandom(Mat A, PetscRandom rng) noexcept
1427*4742e46bSJacob Faibussowitsch {
1428*4742e46bSJacob Faibussowitsch   PetscBool device;
1429*4742e46bSJacob Faibussowitsch 
1430*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1431*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare(PetscObjectCast(rng), PETSCDEVICERAND(), &device));
1432*4742e46bSJacob Faibussowitsch   if (device) {
1433*4742e46bSJacob Faibussowitsch     const auto         m = A->rmap->n;
1434*4742e46bSJacob Faibussowitsch     const auto         n = A->cmap->n;
1435*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
1436*4742e46bSJacob Faibussowitsch 
1437*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
1438*4742e46bSJacob Faibussowitsch     {
1439*4742e46bSJacob Faibussowitsch       const auto a = DeviceArrayWrite(dctx, A);
1440*4742e46bSJacob Faibussowitsch       PetscInt   lda;
1441*4742e46bSJacob Faibussowitsch 
1442*4742e46bSJacob Faibussowitsch       PetscCall(MatDenseGetLDA(A, &lda));
1443*4742e46bSJacob Faibussowitsch       if (lda > m) {
1444*4742e46bSJacob Faibussowitsch         for (PetscInt i = 0; i < n; i++) PetscCall(PetscRandomGetValues(rng, m, a.data() + i * lda));
1445*4742e46bSJacob Faibussowitsch       } else {
1446*4742e46bSJacob Faibussowitsch         PetscInt mn;
1447*4742e46bSJacob Faibussowitsch 
1448*4742e46bSJacob Faibussowitsch         PetscCall(PetscIntMultError(m, n, &mn));
1449*4742e46bSJacob Faibussowitsch         PetscCall(PetscRandomGetValues(rng, mn, a));
1450*4742e46bSJacob Faibussowitsch       }
1451*4742e46bSJacob Faibussowitsch     }
1452*4742e46bSJacob Faibussowitsch   } else {
1453*4742e46bSJacob Faibussowitsch     PetscCall(MatSetRandom_SeqDense(A, rng));
1454*4742e46bSJacob Faibussowitsch   }
1455*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1456*4742e46bSJacob Faibussowitsch }
1457*4742e46bSJacob Faibussowitsch 
1458*4742e46bSJacob Faibussowitsch // ==========================================================================================
1459*4742e46bSJacob Faibussowitsch 
1460*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1461*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetColumnVector(Mat A, Vec v, PetscInt col) noexcept
1462*4742e46bSJacob Faibussowitsch {
1463*4742e46bSJacob Faibussowitsch   const auto         offloadmask = A->offloadmask;
1464*4742e46bSJacob Faibussowitsch   const auto         n           = A->rmap->n;
1465*4742e46bSJacob Faibussowitsch   const auto         col_offset  = [&](const PetscScalar *ptr) { return ptr + col * MatIMPLCast(A)->lda; };
1466*4742e46bSJacob Faibussowitsch   PetscBool          viscupm;
1467*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1468*4742e46bSJacob Faibussowitsch   cupmStream_t       stream;
1469*4742e46bSJacob Faibussowitsch 
1470*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1471*4742e46bSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(v), &viscupm, VecSeq_CUPM::VECSEQCUPM(), VecSeq_CUPM::VECMPICUPM(), VecSeq_CUPM::VECCUPM(), ""));
1472*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &stream));
1473*4742e46bSJacob Faibussowitsch   if (viscupm && !v->boundtocpu) {
1474*4742e46bSJacob Faibussowitsch     const auto x = VecSeq_CUPM::DeviceArrayWrite(dctx, v);
1475*4742e46bSJacob Faibussowitsch 
1476*4742e46bSJacob Faibussowitsch     // update device data
1477*4742e46bSJacob Faibussowitsch     if (PetscOffloadDevice(offloadmask)) {
1478*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMemcpyAsync(x.data(), col_offset(DeviceArrayRead(dctx, A)), n, cupmMemcpyDeviceToDevice, stream));
1479*4742e46bSJacob Faibussowitsch     } else {
1480*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMemcpyAsync(x.data(), col_offset(HostArrayRead(dctx, A)), n, cupmMemcpyHostToDevice, stream));
1481*4742e46bSJacob Faibussowitsch     }
1482*4742e46bSJacob Faibussowitsch   } else {
1483*4742e46bSJacob Faibussowitsch     PetscScalar *x;
1484*4742e46bSJacob Faibussowitsch 
1485*4742e46bSJacob Faibussowitsch     // update host data
1486*4742e46bSJacob Faibussowitsch     PetscCall(VecGetArrayWrite(v, &x));
1487*4742e46bSJacob Faibussowitsch     if (PetscOffloadUnallocated(offloadmask) || PetscOffloadHost(offloadmask)) {
1488*4742e46bSJacob Faibussowitsch       PetscCall(PetscArraycpy(x, col_offset(HostArrayRead(dctx, A)), n));
1489*4742e46bSJacob Faibussowitsch     } else if (PetscOffloadDevice(offloadmask)) {
1490*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMemcpyAsync(x, col_offset(DeviceArrayRead(dctx, A)), n, cupmMemcpyDeviceToHost, stream));
1491*4742e46bSJacob Faibussowitsch     }
1492*4742e46bSJacob Faibussowitsch     PetscCall(VecRestoreArrayWrite(v, &x));
1493*4742e46bSJacob Faibussowitsch   }
1494*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1495*4742e46bSJacob Faibussowitsch }
1496*4742e46bSJacob Faibussowitsch 
1497*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1498*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access>
1499*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
1500*4742e46bSJacob Faibussowitsch {
1501*4742e46bSJacob Faibussowitsch   using namespace vec::cupm;
1502*4742e46bSJacob Faibussowitsch   const auto         mimpl = MatIMPLCast(A);
1503*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1504*4742e46bSJacob Faibussowitsch 
1505*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1506*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
1507*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
1508*4742e46bSJacob Faibussowitsch   mimpl->vecinuse = col + 1;
1509*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx));
1510*4742e46bSJacob Faibussowitsch   PetscCall(GetArray<PETSC_MEMTYPE_DEVICE, access>(A, const_cast<PetscScalar **>(&mimpl->ptrinuse), dctx));
1511*4742e46bSJacob Faibussowitsch   if (!mimpl->cvec) {
1512*4742e46bSJacob Faibussowitsch     // we pass the data of A, to prevent allocating needless GPU memory the first time
1513*4742e46bSJacob Faibussowitsch     // VecCUPMPlaceArray is called
1514*4742e46bSJacob Faibussowitsch     PetscCall(VecCreateSeqCUPMWithArraysAsync<T>(PetscObjectComm(PetscObjectCast(A)), A->rmap->bs, A->rmap->n, nullptr, mimpl->ptrinuse, &mimpl->cvec));
1515*4742e46bSJacob Faibussowitsch   }
1516*4742e46bSJacob Faibussowitsch   PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(mimpl->lda)));
1517*4742e46bSJacob Faibussowitsch   if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec));
1518*4742e46bSJacob Faibussowitsch   *v = mimpl->cvec;
1519*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1520*4742e46bSJacob Faibussowitsch }
1521*4742e46bSJacob Faibussowitsch 
1522*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1523*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access>
1524*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
1525*4742e46bSJacob Faibussowitsch {
1526*4742e46bSJacob Faibussowitsch   using namespace vec::cupm;
1527*4742e46bSJacob Faibussowitsch   const auto         mimpl = MatIMPLCast(A);
1528*4742e46bSJacob Faibussowitsch   const auto         cvec  = mimpl->cvec;
1529*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1530*4742e46bSJacob Faibussowitsch 
1531*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1532*4742e46bSJacob Faibussowitsch   PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
1533*4742e46bSJacob Faibussowitsch   PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
1534*4742e46bSJacob Faibussowitsch   mimpl->vecinuse = 0;
1535*4742e46bSJacob Faibussowitsch   if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
1536*4742e46bSJacob Faibussowitsch   PetscCall(VecCUPMResetArrayAsync<T>(cvec));
1537*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx));
1538*4742e46bSJacob Faibussowitsch   PetscCall(RestoreArray<PETSC_MEMTYPE_DEVICE, access>(A, const_cast<PetscScalar **>(&mimpl->ptrinuse), dctx));
1539*4742e46bSJacob Faibussowitsch   if (v) *v = nullptr;
1540*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1541*4742e46bSJacob Faibussowitsch }
1542*4742e46bSJacob Faibussowitsch 
1543*4742e46bSJacob Faibussowitsch // ==========================================================================================
1544*4742e46bSJacob Faibussowitsch 
1545*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1546*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetFactor(Mat A, MatFactorType ftype, Mat *fact_out) noexcept
1547*4742e46bSJacob Faibussowitsch {
1548*4742e46bSJacob Faibussowitsch   Mat                fact;
1549*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1550*4742e46bSJacob Faibussowitsch 
1551*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1552*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx));
1553*4742e46bSJacob Faibussowitsch   PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), A->rmap->n, A->cmap->n, nullptr, &fact, dctx, /* preallocate */ false));
1554*4742e46bSJacob Faibussowitsch   fact->factortype = ftype;
1555*4742e46bSJacob Faibussowitsch   switch (ftype) {
1556*4742e46bSJacob Faibussowitsch   case MAT_FACTOR_LU:
1557*4742e46bSJacob Faibussowitsch   case MAT_FACTOR_ILU: // fall-through
1558*4742e46bSJacob Faibussowitsch     fact->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqDense;
1559*4742e46bSJacob Faibussowitsch     fact->ops->ilufactorsymbolic = MatLUFactorSymbolic_SeqDense;
1560*4742e46bSJacob Faibussowitsch     break;
1561*4742e46bSJacob Faibussowitsch   case MAT_FACTOR_CHOLESKY:
1562*4742e46bSJacob Faibussowitsch   case MAT_FACTOR_ICC: // fall-through
1563*4742e46bSJacob Faibussowitsch     fact->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqDense;
1564*4742e46bSJacob Faibussowitsch     break;
1565*4742e46bSJacob Faibussowitsch   case MAT_FACTOR_QR: {
1566*4742e46bSJacob Faibussowitsch     const auto pobj = PetscObjectCast(fact);
1567*4742e46bSJacob Faibussowitsch 
1568*4742e46bSJacob Faibussowitsch     PetscCall(PetscObjectComposeFunction(pobj, "MatQRFactor_C", MatQRFactor_SeqDense));
1569*4742e46bSJacob Faibussowitsch     PetscCall(PetscObjectComposeFunction(pobj, "MatQRFactorSymbolic_C", MatQRFactorSymbolic_SeqDense));
1570*4742e46bSJacob Faibussowitsch   } break;
1571*4742e46bSJacob Faibussowitsch   case MAT_FACTOR_NONE:
1572*4742e46bSJacob Faibussowitsch   case MAT_FACTOR_ILUDT:     // fall-through
1573*4742e46bSJacob Faibussowitsch   case MAT_FACTOR_NUM_TYPES: // fall-through
1574*4742e46bSJacob Faibussowitsch     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s not supported", MatFactorTypes[ftype]);
1575*4742e46bSJacob Faibussowitsch   }
1576*4742e46bSJacob Faibussowitsch   PetscCall(PetscStrFreeAllocpy(MATSOLVERCUPM(), &fact->solvertype));
1577*4742e46bSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_LU));
1578*4742e46bSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_ILU));
1579*4742e46bSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_CHOLESKY));
1580*4742e46bSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_ICC));
1581*4742e46bSJacob Faibussowitsch   *fact_out = fact;
1582*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1583*4742e46bSJacob Faibussowitsch }
1584*4742e46bSJacob Faibussowitsch 
1585*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1586*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::InvertFactors(Mat A) noexcept
1587*4742e46bSJacob Faibussowitsch {
1588*4742e46bSJacob Faibussowitsch   const auto         mimpl = MatIMPLCast(A);
1589*4742e46bSJacob Faibussowitsch   const auto         mcu   = MatCUPMCast(A);
1590*4742e46bSJacob Faibussowitsch   const auto         n     = static_cast<cupmBlasInt_t>(A->cmap->n);
1591*4742e46bSJacob Faibussowitsch   cupmSolverHandle_t handle;
1592*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1593*4742e46bSJacob Faibussowitsch   cupmStream_t       stream;
1594*4742e46bSJacob Faibussowitsch 
1595*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1596*4742e46bSJacob Faibussowitsch   #if PetscDefined(HAVE_CUDA) && PetscDefined(USING_NVCC)
1597*4742e46bSJacob Faibussowitsch   // HIP appears to have this by default??
1598*4742e46bSJacob Faibussowitsch   PetscCheck(PETSC_PKG_CUDA_VERSION_GE(10, 1, 0), PETSC_COMM_SELF, PETSC_ERR_SUP, "Upgrade to CUDA version 10.1.0 or higher");
1599*4742e46bSJacob Faibussowitsch   #endif
1600*4742e46bSJacob Faibussowitsch   if (!n || !A->rmap->n) PetscFunctionReturn(PETSC_SUCCESS);
1601*4742e46bSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_CHOLESKY, PETSC_COMM_SELF, PETSC_ERR_LIB, "Factor type %s not implemented", MatFactorTypes[A->factortype]);
1602*4742e46bSJacob Faibussowitsch   // spd
1603*4742e46bSJacob Faibussowitsch   PetscCheck(!mcu->d_fact_ipiv, PETSC_COMM_SELF, PETSC_ERR_LIB, "%sDnsytri not implemented", cupmSolverName());
1604*4742e46bSJacob Faibussowitsch 
1605*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx, &handle, &stream));
1606*4742e46bSJacob Faibussowitsch   {
1607*4742e46bSJacob Faibussowitsch     const auto    da  = DeviceArrayReadWrite(dctx, A);
1608*4742e46bSJacob Faibussowitsch     const auto    lda = static_cast<cupmBlasInt_t>(mimpl->lda);
1609*4742e46bSJacob Faibussowitsch     cupmBlasInt_t il;
1610*4742e46bSJacob Faibussowitsch 
1611*4742e46bSJacob Faibussowitsch     PetscCallCUPMSOLVER(cupmSolverXpotri_bufferSize(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, &il));
1612*4742e46bSJacob Faibussowitsch     if (il > mcu->d_fact_lwork) {
1613*4742e46bSJacob Faibussowitsch       mcu->d_fact_lwork = il;
1614*4742e46bSJacob Faibussowitsch       PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream));
1615*4742e46bSJacob Faibussowitsch       PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, il, stream));
1616*4742e46bSJacob Faibussowitsch     }
1617*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeBegin());
1618*4742e46bSJacob Faibussowitsch     PetscCallCUPMSOLVER(cupmSolverXpotri(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info));
1619*4742e46bSJacob Faibussowitsch     PetscCall(PetscLogGpuTimeEnd());
1620*4742e46bSJacob Faibussowitsch   }
1621*4742e46bSJacob Faibussowitsch   PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream));
1622*4742e46bSJacob Faibussowitsch   // TODO (write cuda kernel)
1623*4742e46bSJacob Faibussowitsch   PetscCall(MatSeqDenseSymmetrize_Private(A, PETSC_TRUE));
1624*4742e46bSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(1.0 * n * n * n / 3.0));
1625*4742e46bSJacob Faibussowitsch 
1626*4742e46bSJacob Faibussowitsch   A->ops->solve          = nullptr;
1627*4742e46bSJacob Faibussowitsch   A->ops->solvetranspose = nullptr;
1628*4742e46bSJacob Faibussowitsch   A->ops->matsolve       = nullptr;
1629*4742e46bSJacob Faibussowitsch   A->factortype          = MAT_FACTOR_NONE;
1630*4742e46bSJacob Faibussowitsch 
1631*4742e46bSJacob Faibussowitsch   PetscCall(PetscFree(A->solvertype));
1632*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1633*4742e46bSJacob Faibussowitsch }
1634*4742e46bSJacob Faibussowitsch 
1635*4742e46bSJacob Faibussowitsch // ==========================================================================================
1636*4742e46bSJacob Faibussowitsch 
1637*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1638*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetSubMatrix(Mat A, PetscInt rbegin, PetscInt rend, PetscInt cbegin, PetscInt cend, Mat *mat) noexcept
1639*4742e46bSJacob Faibussowitsch {
1640*4742e46bSJacob Faibussowitsch   const auto         mimpl        = MatIMPLCast(A);
1641*4742e46bSJacob Faibussowitsch   const auto         array_offset = [&](PetscScalar *ptr) { return ptr + rbegin + static_cast<std::size_t>(cbegin) * mimpl->lda; };
1642*4742e46bSJacob Faibussowitsch   const auto         n            = rend - rbegin;
1643*4742e46bSJacob Faibussowitsch   const auto         m            = cend - cbegin;
1644*4742e46bSJacob Faibussowitsch   auto              &cmat         = mimpl->cmat;
1645*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
1646*4742e46bSJacob Faibussowitsch 
1647*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1648*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
1649*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
1650*4742e46bSJacob Faibussowitsch   mimpl->matinuse = cbegin + 1;
1651*4742e46bSJacob Faibussowitsch 
1652*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx));
1653*4742e46bSJacob Faibussowitsch   PetscCall(HostToDevice_(A, dctx));
1654*4742e46bSJacob Faibussowitsch 
1655*4742e46bSJacob Faibussowitsch   if (cmat && ((m != cmat->cmap->N) || (n != cmat->rmap->N))) PetscCall(MatDestroy(&cmat));
1656*4742e46bSJacob Faibussowitsch   {
1657*4742e46bSJacob Faibussowitsch     const auto device_array = array_offset(MatCUPMCast(A)->d_v);
1658*4742e46bSJacob Faibussowitsch 
1659*4742e46bSJacob Faibussowitsch     if (cmat) {
1660*4742e46bSJacob Faibussowitsch       PetscCall(PlaceArray(cmat, device_array));
1661*4742e46bSJacob Faibussowitsch     } else {
1662*4742e46bSJacob Faibussowitsch       PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), n, m, device_array, &cmat, dctx));
1663*4742e46bSJacob Faibussowitsch     }
1664*4742e46bSJacob Faibussowitsch   }
1665*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseSetLDA(cmat, mimpl->lda));
1666*4742e46bSJacob Faibussowitsch   // place CPU array if present but do not copy any data
1667*4742e46bSJacob Faibussowitsch   if (const auto host_array = mimpl->v) {
1668*4742e46bSJacob Faibussowitsch     cmat->offloadmask = PETSC_OFFLOAD_GPU;
1669*4742e46bSJacob Faibussowitsch     PetscCall(MatDensePlaceArray(cmat, array_offset(host_array)));
1670*4742e46bSJacob Faibussowitsch   }
1671*4742e46bSJacob Faibussowitsch 
1672*4742e46bSJacob Faibussowitsch   cmat->offloadmask = A->offloadmask;
1673*4742e46bSJacob Faibussowitsch   *mat              = cmat;
1674*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1675*4742e46bSJacob Faibussowitsch }
1676*4742e46bSJacob Faibussowitsch 
1677*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1678*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreSubMatrix(Mat A, Mat *m) noexcept
1679*4742e46bSJacob Faibussowitsch {
1680*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
1681*4742e46bSJacob Faibussowitsch   const auto cmat  = mimpl->cmat;
1682*4742e46bSJacob Faibussowitsch   const auto reset = static_cast<bool>(mimpl->v);
1683*4742e46bSJacob Faibussowitsch   bool       copy, was_offload_host;
1684*4742e46bSJacob Faibussowitsch 
1685*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1686*4742e46bSJacob Faibussowitsch   PetscCheck(mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetSubMatrix() first");
1687*4742e46bSJacob Faibussowitsch   PetscCheck(cmat, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column matrix");
1688*4742e46bSJacob Faibussowitsch   PetscCheck(*m == cmat, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Not the matrix obtained from MatDenseGetSubMatrix()");
1689*4742e46bSJacob Faibussowitsch   mimpl->matinuse = 0;
1690*4742e46bSJacob Faibussowitsch 
1691*4742e46bSJacob Faibussowitsch   // calls to ResetArray may change it, so save it here
1692*4742e46bSJacob Faibussowitsch   was_offload_host = cmat->offloadmask == PETSC_OFFLOAD_CPU;
1693*4742e46bSJacob Faibussowitsch   if (was_offload_host && !reset) {
1694*4742e46bSJacob Faibussowitsch     copy = true;
1695*4742e46bSJacob Faibussowitsch     PetscCall(MatSeqDenseSetPreallocation(A, nullptr));
1696*4742e46bSJacob Faibussowitsch   } else {
1697*4742e46bSJacob Faibussowitsch     copy = false;
1698*4742e46bSJacob Faibussowitsch   }
1699*4742e46bSJacob Faibussowitsch 
1700*4742e46bSJacob Faibussowitsch   PetscCall(ResetArray(cmat));
1701*4742e46bSJacob Faibussowitsch   if (reset) PetscCall(MatDenseResetArray(cmat));
1702*4742e46bSJacob Faibussowitsch   if (copy) {
1703*4742e46bSJacob Faibussowitsch     PetscDeviceContext dctx;
1704*4742e46bSJacob Faibussowitsch 
1705*4742e46bSJacob Faibussowitsch     PetscCall(GetHandles_(&dctx));
1706*4742e46bSJacob Faibussowitsch     PetscCall(DeviceToHost_(A, dctx));
1707*4742e46bSJacob Faibussowitsch   } else {
1708*4742e46bSJacob Faibussowitsch     A->offloadmask = was_offload_host ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
1709*4742e46bSJacob Faibussowitsch   }
1710*4742e46bSJacob Faibussowitsch 
1711*4742e46bSJacob Faibussowitsch   cmat->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
1712*4742e46bSJacob Faibussowitsch   *m                = nullptr;
1713*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1714*4742e46bSJacob Faibussowitsch }
1715*4742e46bSJacob Faibussowitsch 
1716*4742e46bSJacob Faibussowitsch // ==========================================================================================
1717*4742e46bSJacob Faibussowitsch 
1718*4742e46bSJacob Faibussowitsch namespace
1719*4742e46bSJacob Faibussowitsch {
1720*4742e46bSJacob Faibussowitsch 
1721*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1722*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatMatMultNumeric_SeqDenseCUPM_SeqDenseCUPM(Mat A, Mat B, Mat C, PetscBool TA, PetscBool TB) noexcept
1723*4742e46bSJacob Faibussowitsch {
1724*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1725*4742e46bSJacob Faibussowitsch   if (TA) {
1726*4742e46bSJacob Faibussowitsch     if (TB) {
1727*4742e46bSJacob Faibussowitsch       PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<true, true>(A, B, C));
1728*4742e46bSJacob Faibussowitsch     } else {
1729*4742e46bSJacob Faibussowitsch       PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<true, false>(A, B, C));
1730*4742e46bSJacob Faibussowitsch     }
1731*4742e46bSJacob Faibussowitsch   } else {
1732*4742e46bSJacob Faibussowitsch     if (TB) {
1733*4742e46bSJacob Faibussowitsch       PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<false, true>(A, B, C));
1734*4742e46bSJacob Faibussowitsch     } else {
1735*4742e46bSJacob Faibussowitsch       PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<false, false>(A, B, C));
1736*4742e46bSJacob Faibussowitsch     }
1737*4742e46bSJacob Faibussowitsch   }
1738*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1739*4742e46bSJacob Faibussowitsch }
1740*4742e46bSJacob Faibussowitsch 
1741*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1742*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatSolverTypeRegister_DENSECUPM() noexcept
1743*4742e46bSJacob Faibussowitsch {
1744*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
1745*4742e46bSJacob Faibussowitsch   for (auto ftype : util::make_array(MAT_FACTOR_LU, MAT_FACTOR_CHOLESKY, MAT_FACTOR_QR)) {
1746*4742e46bSJacob Faibussowitsch     PetscCall(MatSolverTypeRegister(MatDense_Seq_CUPM<T>::MATSOLVERCUPM(), MATSEQDENSE, ftype, MatDense_Seq_CUPM<T>::GetFactor));
1747*4742e46bSJacob Faibussowitsch     PetscCall(MatSolverTypeRegister(MatDense_Seq_CUPM<T>::MATSOLVERCUPM(), MatDense_Seq_CUPM<T>::MATSEQDENSECUPM(), ftype, MatDense_Seq_CUPM<T>::GetFactor));
1748*4742e46bSJacob Faibussowitsch   }
1749*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1750*4742e46bSJacob Faibussowitsch }
1751*4742e46bSJacob Faibussowitsch 
1752*4742e46bSJacob Faibussowitsch } // anonymous namespace
1753*4742e46bSJacob Faibussowitsch 
1754*4742e46bSJacob Faibussowitsch } // namespace impl
1755*4742e46bSJacob Faibussowitsch 
1756*4742e46bSJacob Faibussowitsch } // namespace cupm
1757*4742e46bSJacob Faibussowitsch 
1758*4742e46bSJacob Faibussowitsch } // namespace mat
1759*4742e46bSJacob Faibussowitsch 
1760*4742e46bSJacob Faibussowitsch } // namespace Petsc
1761*4742e46bSJacob Faibussowitsch 
1762*4742e46bSJacob Faibussowitsch #endif // __cplusplus
1763*4742e46bSJacob Faibussowitsch 
1764*4742e46bSJacob Faibussowitsch #endif // PETSCMATSEQDENSECUPM_HPP
1765