1*4742e46bSJacob Faibussowitsch #ifndef PETSCMATSEQDENSECUPM_HPP 2*4742e46bSJacob Faibussowitsch #define PETSCMATSEQDENSECUPM_HPP 3*4742e46bSJacob Faibussowitsch 4*4742e46bSJacob Faibussowitsch #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/ 5*4742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/seq/dense.h> 6*4742e46bSJacob Faibussowitsch 7*4742e46bSJacob Faibussowitsch #if defined(__cplusplus) 8*4742e46bSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> // PetscDeviceContextGetOptionalNullContext_Internal() 9*4742e46bSJacob Faibussowitsch #include <petsc/private/randomimpl.h> // _p_PetscRandom 10*4742e46bSJacob Faibussowitsch #include <petsc/private/vecimpl.h> // _p_Vec 11*4742e46bSJacob Faibussowitsch #include <petsc/private/cupmobject.hpp> 12*4742e46bSJacob Faibussowitsch #include <petsc/private/cupmsolverinterface.hpp> 13*4742e46bSJacob Faibussowitsch 14*4742e46bSJacob Faibussowitsch #include <petsc/private/cpp/type_traits.hpp> // PetscObjectCast() 15*4742e46bSJacob Faibussowitsch #include <petsc/private/cpp/utility.hpp> // util::exchange() 16*4742e46bSJacob Faibussowitsch 17*4742e46bSJacob Faibussowitsch #include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp> // for VecSeq_CUPM 18*4742e46bSJacob Faibussowitsch 19*4742e46bSJacob Faibussowitsch namespace Petsc 20*4742e46bSJacob Faibussowitsch { 21*4742e46bSJacob Faibussowitsch 22*4742e46bSJacob Faibussowitsch namespace mat 23*4742e46bSJacob Faibussowitsch { 24*4742e46bSJacob Faibussowitsch 25*4742e46bSJacob Faibussowitsch namespace cupm 26*4742e46bSJacob Faibussowitsch { 27*4742e46bSJacob Faibussowitsch 28*4742e46bSJacob Faibussowitsch namespace impl 29*4742e46bSJacob Faibussowitsch { 30*4742e46bSJacob Faibussowitsch 31*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 32*4742e46bSJacob Faibussowitsch class MatDense_Seq_CUPM : MatDense_CUPM<T, MatDense_Seq_CUPM<T>> { 33*4742e46bSJacob Faibussowitsch public: 34*4742e46bSJacob Faibussowitsch MATDENSECUPM_HEADER(T, MatDense_Seq_CUPM<T>); 35*4742e46bSJacob Faibussowitsch 36*4742e46bSJacob Faibussowitsch private: 37*4742e46bSJacob Faibussowitsch struct Mat_SeqDenseCUPM { 38*4742e46bSJacob Faibussowitsch PetscScalar *d_v; // pointer to the matrix on the GPU 39*4742e46bSJacob Faibussowitsch PetscScalar *unplacedarray; // if one called MatCUPMDensePlaceArray(), this is where it stashed the original 40*4742e46bSJacob Faibussowitsch bool d_user_alloc; 41*4742e46bSJacob Faibussowitsch bool d_unplaced_user_alloc; 42*4742e46bSJacob Faibussowitsch // factorization support 43*4742e46bSJacob Faibussowitsch cupmBlasInt_t *d_fact_ipiv; // device pivots 44*4742e46bSJacob Faibussowitsch cupmScalar_t *d_fact_tau; // device QR tau vector 45*4742e46bSJacob Faibussowitsch cupmBlasInt_t *d_fact_info; // device info 46*4742e46bSJacob Faibussowitsch cupmScalar_t *d_fact_work; // device workspace 47*4742e46bSJacob Faibussowitsch cupmBlasInt_t d_fact_lwork; // size of device workspace 48*4742e46bSJacob Faibussowitsch // workspace 49*4742e46bSJacob Faibussowitsch Vec workvec; 50*4742e46bSJacob Faibussowitsch }; 51*4742e46bSJacob Faibussowitsch 52*4742e46bSJacob Faibussowitsch static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept; 53*4742e46bSJacob Faibussowitsch 54*4742e46bSJacob Faibussowitsch static PetscErrorCode HostToDevice_(Mat, PetscDeviceContext) noexcept; 55*4742e46bSJacob Faibussowitsch static PetscErrorCode DeviceToHost_(Mat, PetscDeviceContext) noexcept; 56*4742e46bSJacob Faibussowitsch 57*4742e46bSJacob Faibussowitsch static PetscErrorCode CheckCUPMSolverInfo_(const cupmBlasInt_t *, cupmStream_t) noexcept; 58*4742e46bSJacob Faibussowitsch 59*4742e46bSJacob Faibussowitsch template <typename Derived> 60*4742e46bSJacob Faibussowitsch struct SolveCommon; 61*4742e46bSJacob Faibussowitsch struct SolveQR; 62*4742e46bSJacob Faibussowitsch struct SolveCholesky; 63*4742e46bSJacob Faibussowitsch struct SolveLU; 64*4742e46bSJacob Faibussowitsch 65*4742e46bSJacob Faibussowitsch template <typename Solver, bool transpose> 66*4742e46bSJacob Faibussowitsch static PetscErrorCode MatSolve_Factored_Dispatch_(Mat, Vec, Vec) noexcept; 67*4742e46bSJacob Faibussowitsch template <typename Solver, bool transpose> 68*4742e46bSJacob Faibussowitsch static PetscErrorCode MatMatSolve_Factored_Dispatch_(Mat, Mat, Mat) noexcept; 69*4742e46bSJacob Faibussowitsch template <bool transpose> 70*4742e46bSJacob Faibussowitsch static PetscErrorCode MatMultAdd_Dispatch_(Mat, Vec, Vec, Vec) noexcept; 71*4742e46bSJacob Faibussowitsch 72*4742e46bSJacob Faibussowitsch template <bool to_host> 73*4742e46bSJacob Faibussowitsch static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept; 74*4742e46bSJacob Faibussowitsch 75*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr MatType MATIMPLCUPM_() noexcept; 76*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr Mat_SeqDense *MatIMPLCast_(Mat) noexcept; 77*4742e46bSJacob Faibussowitsch 78*4742e46bSJacob Faibussowitsch public: 79*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr Mat_SeqDenseCUPM *MatCUPMCast(Mat) noexcept; 80*4742e46bSJacob Faibussowitsch 81*4742e46bSJacob Faibussowitsch // define these by hand since they don't fit the above mold 82*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatConvert_seqdensecupm_seqdense_C() noexcept; 83*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_seqaij_seqdensecupm_C() noexcept; 84*4742e46bSJacob Faibussowitsch 85*4742e46bSJacob Faibussowitsch static PetscErrorCode Create(Mat) noexcept; 86*4742e46bSJacob Faibussowitsch static PetscErrorCode Destroy(Mat) noexcept; 87*4742e46bSJacob Faibussowitsch static PetscErrorCode SetUp(Mat) noexcept; 88*4742e46bSJacob Faibussowitsch static PetscErrorCode Reset(Mat) noexcept; 89*4742e46bSJacob Faibussowitsch 90*4742e46bSJacob Faibussowitsch static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept; 91*4742e46bSJacob Faibussowitsch static PetscErrorCode Convert_SeqDense_SeqDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept; 92*4742e46bSJacob Faibussowitsch static PetscErrorCode Convert_SeqDenseCUPM_SeqDense(Mat, MatType, MatReuse, Mat *) noexcept; 93*4742e46bSJacob Faibussowitsch 94*4742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode> 95*4742e46bSJacob Faibussowitsch static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext) noexcept; 96*4742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode> 97*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext) noexcept; 98*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 99*4742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayAndMemType(Mat, PetscScalar **, PetscMemType *, PetscDeviceContext) noexcept; 100*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 101*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayAndMemType(Mat, PetscScalar **, PetscDeviceContext) noexcept; 102*4742e46bSJacob Faibussowitsch 103*4742e46bSJacob Faibussowitsch private: 104*4742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode> 105*4742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept 106*4742e46bSJacob Faibussowitsch { 107*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 108*4742e46bSJacob Faibussowitsch 109*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 110*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 111*4742e46bSJacob Faibussowitsch PetscCall(GetArray<mtype, mode>(m, p, dctx)); 112*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 113*4742e46bSJacob Faibussowitsch } 114*4742e46bSJacob Faibussowitsch 115*4742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode> 116*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept 117*4742e46bSJacob Faibussowitsch { 118*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 119*4742e46bSJacob Faibussowitsch 120*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 121*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 122*4742e46bSJacob Faibussowitsch PetscCall(RestoreArray<mtype, mode>(m, p, dctx)); 123*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 124*4742e46bSJacob Faibussowitsch } 125*4742e46bSJacob Faibussowitsch 126*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode mode> 127*4742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayAndMemTypeC_(Mat m, PetscScalar **p, PetscMemType *tp) noexcept 128*4742e46bSJacob Faibussowitsch { 129*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 130*4742e46bSJacob Faibussowitsch 131*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 132*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 133*4742e46bSJacob Faibussowitsch PetscCall(GetArrayAndMemType<mode>(m, p, tp, dctx)); 134*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 135*4742e46bSJacob Faibussowitsch } 136*4742e46bSJacob Faibussowitsch 137*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode mode> 138*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayAndMemTypeC_(Mat m, PetscScalar **p) noexcept 139*4742e46bSJacob Faibussowitsch { 140*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 141*4742e46bSJacob Faibussowitsch 142*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 143*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 144*4742e46bSJacob Faibussowitsch PetscCall(RestoreArrayAndMemType<mode>(m, p, dctx)); 145*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 146*4742e46bSJacob Faibussowitsch } 147*4742e46bSJacob Faibussowitsch 148*4742e46bSJacob Faibussowitsch public: 149*4742e46bSJacob Faibussowitsch static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept; 150*4742e46bSJacob Faibussowitsch static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept; 151*4742e46bSJacob Faibussowitsch static PetscErrorCode ResetArray(Mat) noexcept; 152*4742e46bSJacob Faibussowitsch 153*4742e46bSJacob Faibussowitsch template <bool transpose_A, bool transpose_B> 154*4742e46bSJacob Faibussowitsch static PetscErrorCode MatMatMult_Numeric_Dispatch(Mat, Mat, Mat) noexcept; 155*4742e46bSJacob Faibussowitsch static PetscErrorCode Copy(Mat, Mat, MatStructure) noexcept; 156*4742e46bSJacob Faibussowitsch static PetscErrorCode ZeroEntries(Mat) noexcept; 157*4742e46bSJacob Faibussowitsch static PetscErrorCode Scale(Mat, PetscScalar) noexcept; 158*4742e46bSJacob Faibussowitsch static PetscErrorCode Shift(Mat, PetscScalar) noexcept; 159*4742e46bSJacob Faibussowitsch static PetscErrorCode AXPY(Mat, PetscScalar, Mat, MatStructure) noexcept; 160*4742e46bSJacob Faibussowitsch static PetscErrorCode Duplicate(Mat, MatDuplicateOption, Mat *) noexcept; 161*4742e46bSJacob Faibussowitsch static PetscErrorCode SetRandom(Mat, PetscRandom) noexcept; 162*4742e46bSJacob Faibussowitsch 163*4742e46bSJacob Faibussowitsch static PetscErrorCode GetColumnVector(Mat, Vec, PetscInt) noexcept; 164*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 165*4742e46bSJacob Faibussowitsch static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept; 166*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 167*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept; 168*4742e46bSJacob Faibussowitsch 169*4742e46bSJacob Faibussowitsch static PetscErrorCode GetFactor(Mat, MatFactorType, Mat *) noexcept; 170*4742e46bSJacob Faibussowitsch static PetscErrorCode InvertFactors(Mat) noexcept; 171*4742e46bSJacob Faibussowitsch 172*4742e46bSJacob Faibussowitsch static PetscErrorCode GetSubMatrix(Mat, PetscInt, PetscInt, PetscInt, PetscInt, Mat *) noexcept; 173*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreSubMatrix(Mat, Mat *) noexcept; 174*4742e46bSJacob Faibussowitsch }; 175*4742e46bSJacob Faibussowitsch 176*4742e46bSJacob Faibussowitsch } // namespace impl 177*4742e46bSJacob Faibussowitsch 178*4742e46bSJacob Faibussowitsch namespace 179*4742e46bSJacob Faibussowitsch { 180*4742e46bSJacob Faibussowitsch 181*4742e46bSJacob Faibussowitsch // Declare this here so that the functions below can make use of it 182*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 183*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateSeqDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept 184*4742e46bSJacob Faibussowitsch { 185*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 186*4742e46bSJacob Faibussowitsch PetscCall(impl::MatDense_Seq_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, m, n, data, A, dctx, preallocate)); 187*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 188*4742e46bSJacob Faibussowitsch } 189*4742e46bSJacob Faibussowitsch 190*4742e46bSJacob Faibussowitsch } // anonymous namespace 191*4742e46bSJacob Faibussowitsch 192*4742e46bSJacob Faibussowitsch namespace impl 193*4742e46bSJacob Faibussowitsch { 194*4742e46bSJacob Faibussowitsch 195*4742e46bSJacob Faibussowitsch // ========================================================================================== 196*4742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Utility 197*4742e46bSJacob Faibussowitsch // ========================================================================================== 198*4742e46bSJacob Faibussowitsch 199*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 200*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetPreallocation_(Mat m, PetscDeviceContext dctx, PetscScalar *user_device_array) noexcept 201*4742e46bSJacob Faibussowitsch { 202*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(m); 203*4742e46bSJacob Faibussowitsch const auto nrows = m->rmap->n; 204*4742e46bSJacob Faibussowitsch const auto ncols = m->cmap->n; 205*4742e46bSJacob Faibussowitsch auto &lda = MatIMPLCast(m)->lda; 206*4742e46bSJacob Faibussowitsch cupmStream_t stream; 207*4742e46bSJacob Faibussowitsch 208*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 209*4742e46bSJacob Faibussowitsch PetscCheckTypeName(m, MATSEQDENSECUPM()); 210*4742e46bSJacob Faibussowitsch PetscValidDeviceContext(dctx, 2); 211*4742e46bSJacob Faibussowitsch PetscCall(checkCupmBlasIntCast(nrows)); 212*4742e46bSJacob Faibussowitsch PetscCall(checkCupmBlasIntCast(ncols)); 213*4742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &stream)); 214*4742e46bSJacob Faibussowitsch if (lda <= 0) lda = nrows; 215*4742e46bSJacob Faibussowitsch if (!mcu->d_user_alloc) PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream)); 216*4742e46bSJacob Faibussowitsch if (user_device_array) { 217*4742e46bSJacob Faibussowitsch mcu->d_user_alloc = PETSC_TRUE; 218*4742e46bSJacob Faibussowitsch mcu->d_v = user_device_array; 219*4742e46bSJacob Faibussowitsch } else { 220*4742e46bSJacob Faibussowitsch PetscInt size; 221*4742e46bSJacob Faibussowitsch 222*4742e46bSJacob Faibussowitsch mcu->d_user_alloc = PETSC_FALSE; 223*4742e46bSJacob Faibussowitsch PetscCall(PetscIntMultError(lda, ncols, &size)); 224*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&mcu->d_v, size, stream)); 225*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemsetAsync(mcu->d_v, 0, size, stream)); 226*4742e46bSJacob Faibussowitsch } 227*4742e46bSJacob Faibussowitsch m->offloadmask = PETSC_OFFLOAD_GPU; 228*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 229*4742e46bSJacob Faibussowitsch } 230*4742e46bSJacob Faibussowitsch 231*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 232*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::HostToDevice_(Mat m, PetscDeviceContext dctx) noexcept 233*4742e46bSJacob Faibussowitsch { 234*4742e46bSJacob Faibussowitsch const auto nrows = m->rmap->n; 235*4742e46bSJacob Faibussowitsch const auto ncols = m->cmap->n; 236*4742e46bSJacob Faibussowitsch const auto copy = m->offloadmask == PETSC_OFFLOAD_CPU || m->offloadmask == PETSC_OFFLOAD_UNALLOCATED; 237*4742e46bSJacob Faibussowitsch 238*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 239*4742e46bSJacob Faibussowitsch PetscCheckTypeName(m, MATSEQDENSECUPM()); 240*4742e46bSJacob Faibussowitsch if (m->boundtocpu) PetscFunctionReturn(PETSC_SUCCESS); 241*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(m, "%s matrix %" PetscInt_FMT " x %" PetscInt_FMT "\n", copy ? "Copy" : "Reusing", nrows, ncols)); 242*4742e46bSJacob Faibussowitsch if (copy) { 243*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(m); 244*4742e46bSJacob Faibussowitsch cupmStream_t stream; 245*4742e46bSJacob Faibussowitsch 246*4742e46bSJacob Faibussowitsch // Allocate GPU memory if not present 247*4742e46bSJacob Faibussowitsch if (!mcu->d_v) PetscCall(SetPreallocation(m, dctx)); 248*4742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &stream)); 249*4742e46bSJacob Faibussowitsch PetscCall(PetscLogEventBegin(MAT_DenseCopyToGPU, m, 0, 0, 0)); 250*4742e46bSJacob Faibussowitsch { 251*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(m); 252*4742e46bSJacob Faibussowitsch const auto lda = mimpl->lda; 253*4742e46bSJacob Faibussowitsch const auto src = mimpl->v; 254*4742e46bSJacob Faibussowitsch const auto dest = mcu->d_v; 255*4742e46bSJacob Faibussowitsch 256*4742e46bSJacob Faibussowitsch if (lda > nrows) { 257*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(dest, lda, src, lda, nrows, ncols, cupmMemcpyHostToDevice, stream)); 258*4742e46bSJacob Faibussowitsch } else { 259*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(dest, src, lda * ncols, cupmMemcpyHostToDevice, stream)); 260*4742e46bSJacob Faibussowitsch } 261*4742e46bSJacob Faibussowitsch } 262*4742e46bSJacob Faibussowitsch PetscCall(PetscLogEventEnd(MAT_DenseCopyToGPU, m, 0, 0, 0)); 263*4742e46bSJacob Faibussowitsch // order important, ensure that offloadmask is PETSC_OFFLOAD_BOTH 264*4742e46bSJacob Faibussowitsch m->offloadmask = PETSC_OFFLOAD_BOTH; 265*4742e46bSJacob Faibussowitsch } 266*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 267*4742e46bSJacob Faibussowitsch } 268*4742e46bSJacob Faibussowitsch 269*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 270*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::DeviceToHost_(Mat m, PetscDeviceContext dctx) noexcept 271*4742e46bSJacob Faibussowitsch { 272*4742e46bSJacob Faibussowitsch const auto nrows = m->rmap->n; 273*4742e46bSJacob Faibussowitsch const auto ncols = m->cmap->n; 274*4742e46bSJacob Faibussowitsch const auto copy = m->offloadmask == PETSC_OFFLOAD_GPU; 275*4742e46bSJacob Faibussowitsch 276*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 277*4742e46bSJacob Faibussowitsch PetscCheckTypeName(m, MATSEQDENSECUPM()); 278*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(m, "%s matrix %" PetscInt_FMT " x %" PetscInt_FMT "\n", copy ? "Copy" : "Reusing", nrows, ncols)); 279*4742e46bSJacob Faibussowitsch if (copy) { 280*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(m); 281*4742e46bSJacob Faibussowitsch cupmStream_t stream; 282*4742e46bSJacob Faibussowitsch 283*4742e46bSJacob Faibussowitsch // MatCreateSeqDenseCUPM may not allocate CPU memory. Allocate if needed 284*4742e46bSJacob Faibussowitsch if (!mimpl->v) PetscCall(MatSeqDenseSetPreallocation(m, nullptr)); 285*4742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &stream)); 286*4742e46bSJacob Faibussowitsch PetscCall(PetscLogEventBegin(MAT_DenseCopyFromGPU, m, 0, 0, 0)); 287*4742e46bSJacob Faibussowitsch { 288*4742e46bSJacob Faibussowitsch const auto lda = mimpl->lda; 289*4742e46bSJacob Faibussowitsch const auto dest = mimpl->v; 290*4742e46bSJacob Faibussowitsch const auto src = MatCUPMCast(m)->d_v; 291*4742e46bSJacob Faibussowitsch 292*4742e46bSJacob Faibussowitsch if (lda > nrows) { 293*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(dest, lda, src, lda, nrows, ncols, cupmMemcpyDeviceToHost, stream)); 294*4742e46bSJacob Faibussowitsch } else { 295*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(dest, src, lda * ncols, cupmMemcpyDeviceToHost, stream)); 296*4742e46bSJacob Faibussowitsch } 297*4742e46bSJacob Faibussowitsch } 298*4742e46bSJacob Faibussowitsch PetscCall(PetscLogEventEnd(MAT_DenseCopyFromGPU, m, 0, 0, 0)); 299*4742e46bSJacob Faibussowitsch // order is important, MatSeqDenseSetPreallocation() might set offloadmask 300*4742e46bSJacob Faibussowitsch m->offloadmask = PETSC_OFFLOAD_BOTH; 301*4742e46bSJacob Faibussowitsch } 302*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 303*4742e46bSJacob Faibussowitsch } 304*4742e46bSJacob Faibussowitsch 305*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 306*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::CheckCUPMSolverInfo_(const cupmBlasInt_t *fact_info, cupmStream_t stream) noexcept 307*4742e46bSJacob Faibussowitsch { 308*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 309*4742e46bSJacob Faibussowitsch if (PetscDefined(USE_DEBUG)) { 310*4742e46bSJacob Faibussowitsch cupmBlasInt_t info = 0; 311*4742e46bSJacob Faibussowitsch 312*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(&info, fact_info, 1, cupmMemcpyDeviceToHost, stream)); 313*4742e46bSJacob Faibussowitsch if (stream) PetscCallCUPM(cupmStreamSynchronize(stream)); 314*4742e46bSJacob Faibussowitsch static_assert(std::is_same<decltype(info), int>::value, ""); 315*4742e46bSJacob Faibussowitsch PetscCheck(info <= 0, PETSC_COMM_SELF, PETSC_ERR_MAT_CH_ZRPVT, "Bad factorization: zero pivot in row %d", info - 1); 316*4742e46bSJacob Faibussowitsch PetscCheck(info >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Wrong argument to cupmSolver %d", -info); 317*4742e46bSJacob Faibussowitsch } 318*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 319*4742e46bSJacob Faibussowitsch } 320*4742e46bSJacob Faibussowitsch 321*4742e46bSJacob Faibussowitsch // ========================================================================================== 322*4742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Solver Dispatch 323*4742e46bSJacob Faibussowitsch // ========================================================================================== 324*4742e46bSJacob Faibussowitsch 325*4742e46bSJacob Faibussowitsch // specific solvers called through the dispatch_() family of functions 326*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 327*4742e46bSJacob Faibussowitsch template <typename Derived> 328*4742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveCommon { 329*4742e46bSJacob Faibussowitsch using derived_type = Derived; 330*4742e46bSJacob Faibussowitsch 331*4742e46bSJacob Faibussowitsch template <typename F> 332*4742e46bSJacob Faibussowitsch static PetscErrorCode ResizeFactLwork(Mat_SeqDenseCUPM *mcu, cupmStream_t stream, F &&cupmSolverComputeFactLwork) noexcept 333*4742e46bSJacob Faibussowitsch { 334*4742e46bSJacob Faibussowitsch cupmBlasInt_t lwork; 335*4742e46bSJacob Faibussowitsch 336*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 337*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverComputeFactLwork(&lwork)); 338*4742e46bSJacob Faibussowitsch if (lwork > mcu->d_fact_lwork) { 339*4742e46bSJacob Faibussowitsch mcu->d_fact_lwork = lwork; 340*4742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream)); 341*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, lwork, stream)); 342*4742e46bSJacob Faibussowitsch } 343*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 344*4742e46bSJacob Faibussowitsch } 345*4742e46bSJacob Faibussowitsch 346*4742e46bSJacob Faibussowitsch static PetscErrorCode FactorPrepare(Mat A, cupmStream_t stream) noexcept 347*4742e46bSJacob Faibussowitsch { 348*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 349*4742e46bSJacob Faibussowitsch 350*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 351*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "%s factor %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", derived_type::NAME(), A->rmap->n, A->cmap->n)); 352*4742e46bSJacob Faibussowitsch A->factortype = derived_type::MATFACTORTYPE(); 353*4742e46bSJacob Faibussowitsch A->ops->solve = MatSolve_Factored_Dispatch_<derived_type, false>; 354*4742e46bSJacob Faibussowitsch A->ops->solvetranspose = MatSolve_Factored_Dispatch_<derived_type, true>; 355*4742e46bSJacob Faibussowitsch A->ops->matsolve = MatMatSolve_Factored_Dispatch_<derived_type, false>; 356*4742e46bSJacob Faibussowitsch A->ops->matsolvetranspose = MatMatSolve_Factored_Dispatch_<derived_type, true>; 357*4742e46bSJacob Faibussowitsch 358*4742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(MATSOLVERCUPM(), &A->solvertype)); 359*4742e46bSJacob Faibussowitsch if (!mcu->d_fact_info) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_info, 1, stream)); 360*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 361*4742e46bSJacob Faibussowitsch } 362*4742e46bSJacob Faibussowitsch }; 363*4742e46bSJacob Faibussowitsch 364*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 365*4742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveLU : SolveCommon<SolveLU> { 366*4742e46bSJacob Faibussowitsch using base_type = SolveCommon<SolveLU>; 367*4742e46bSJacob Faibussowitsch 368*4742e46bSJacob Faibussowitsch static constexpr const char *NAME() noexcept { return "LU"; } 369*4742e46bSJacob Faibussowitsch static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_LU; } 370*4742e46bSJacob Faibussowitsch 371*4742e46bSJacob Faibussowitsch static PetscErrorCode Factor(Mat A, IS, IS, const MatFactorInfo *) noexcept 372*4742e46bSJacob Faibussowitsch { 373*4742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 374*4742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->cmap->n); 375*4742e46bSJacob Faibussowitsch cupmStream_t stream; 376*4742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 377*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 378*4742e46bSJacob Faibussowitsch 379*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 380*4742e46bSJacob Faibussowitsch if (!m || !n) PetscFunctionReturn(PETSC_SUCCESS); 381*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle, &stream)); 382*4742e46bSJacob Faibussowitsch PetscCall(base_type::FactorPrepare(A, stream)); 383*4742e46bSJacob Faibussowitsch { 384*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 385*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 386*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 387*4742e46bSJacob Faibussowitsch 388*4742e46bSJacob Faibussowitsch // clang-format off 389*4742e46bSJacob Faibussowitsch PetscCall( 390*4742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 391*4742e46bSJacob Faibussowitsch mcu, stream, 392*4742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *fact_lwork) 393*4742e46bSJacob Faibussowitsch { 394*4742e46bSJacob Faibussowitsch return cupmSolverXgetrf_bufferSize(handle, m, n, da.cupmdata(), lda, fact_lwork); 395*4742e46bSJacob Faibussowitsch } 396*4742e46bSJacob Faibussowitsch ) 397*4742e46bSJacob Faibussowitsch ); 398*4742e46bSJacob Faibussowitsch // clang-format on 399*4742e46bSJacob Faibussowitsch if (!mcu->d_fact_ipiv) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_ipiv, n, stream)); 400*4742e46bSJacob Faibussowitsch 401*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 402*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXgetrf(handle, m, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_ipiv, mcu->d_fact_info)); 403*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 404*4742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream)); 405*4742e46bSJacob Faibussowitsch } 406*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(2.0 * n * n * m / 3.0)); 407*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 408*4742e46bSJacob Faibussowitsch } 409*4742e46bSJacob Faibussowitsch 410*4742e46bSJacob Faibussowitsch template <bool transpose> 411*4742e46bSJacob Faibussowitsch static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept 412*4742e46bSJacob Faibussowitsch { 413*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 414*4742e46bSJacob Faibussowitsch const auto fact_info = mcu->d_fact_info; 415*4742e46bSJacob Faibussowitsch const auto fact_ipiv = mcu->d_fact_ipiv; 416*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 417*4742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 418*4742e46bSJacob Faibussowitsch 419*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 420*4742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &handle)); 421*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k)); 422*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 423*4742e46bSJacob Faibussowitsch { 424*4742e46bSJacob Faibussowitsch constexpr auto op = transpose ? CUPMSOLVER_OP_T : CUPMSOLVER_OP_N; 425*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 426*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 427*4742e46bSJacob Faibussowitsch 428*4742e46bSJacob Faibussowitsch // clang-format off 429*4742e46bSJacob Faibussowitsch PetscCall( 430*4742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 431*4742e46bSJacob Faibussowitsch mcu, stream, 432*4742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *lwork) 433*4742e46bSJacob Faibussowitsch { 434*4742e46bSJacob Faibussowitsch return cupmSolverXgetrs_bufferSize( 435*4742e46bSJacob Faibussowitsch handle, op, m, nrhs, da.cupmdata(), lda, fact_ipiv, x, ldx, lwork 436*4742e46bSJacob Faibussowitsch ); 437*4742e46bSJacob Faibussowitsch } 438*4742e46bSJacob Faibussowitsch ) 439*4742e46bSJacob Faibussowitsch ); 440*4742e46bSJacob Faibussowitsch // clang-format on 441*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXgetrs(handle, op, m, nrhs, da.cupmdata(), lda, fact_ipiv, x, ldx, mcu->d_fact_work, mcu->d_fact_lwork, fact_info)); 442*4742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(fact_info, stream)); 443*4742e46bSJacob Faibussowitsch } 444*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 445*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(nrhs * (2.0 * m * m - m))); 446*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 447*4742e46bSJacob Faibussowitsch } 448*4742e46bSJacob Faibussowitsch }; 449*4742e46bSJacob Faibussowitsch 450*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 451*4742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveCholesky : SolveCommon<SolveCholesky> { 452*4742e46bSJacob Faibussowitsch using base_type = SolveCommon<SolveCholesky>; 453*4742e46bSJacob Faibussowitsch 454*4742e46bSJacob Faibussowitsch static constexpr const char *NAME() noexcept { return "Cholesky"; } 455*4742e46bSJacob Faibussowitsch static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_CHOLESKY; } 456*4742e46bSJacob Faibussowitsch 457*4742e46bSJacob Faibussowitsch static PetscErrorCode Factor(Mat A, IS, const MatFactorInfo *) noexcept 458*4742e46bSJacob Faibussowitsch { 459*4742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->rmap->n); 460*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 461*4742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 462*4742e46bSJacob Faibussowitsch cupmStream_t stream; 463*4742e46bSJacob Faibussowitsch 464*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 465*4742e46bSJacob Faibussowitsch if (!n || !A->cmap->n) PetscFunctionReturn(PETSC_SUCCESS); 466*4742e46bSJacob Faibussowitsch PetscCheck(A->spd == PETSC_BOOL3_TRUE, PETSC_COMM_SELF, PETSC_ERR_SUP, "%ssytrs unavailable. Use MAT_FACTOR_LU", cupmSolverName()); 467*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle, &stream)); 468*4742e46bSJacob Faibussowitsch PetscCall(base_type::FactorPrepare(A, stream)); 469*4742e46bSJacob Faibussowitsch { 470*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 471*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 472*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 473*4742e46bSJacob Faibussowitsch 474*4742e46bSJacob Faibussowitsch // clang-format off 475*4742e46bSJacob Faibussowitsch PetscCall( 476*4742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 477*4742e46bSJacob Faibussowitsch mcu, stream, 478*4742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *fact_lwork) 479*4742e46bSJacob Faibussowitsch { 480*4742e46bSJacob Faibussowitsch return cupmSolverXpotrf_bufferSize( 481*4742e46bSJacob Faibussowitsch handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, fact_lwork 482*4742e46bSJacob Faibussowitsch ); 483*4742e46bSJacob Faibussowitsch } 484*4742e46bSJacob Faibussowitsch ) 485*4742e46bSJacob Faibussowitsch ); 486*4742e46bSJacob Faibussowitsch // clang-format on 487*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 488*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXpotrf(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info)); 489*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 490*4742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream)); 491*4742e46bSJacob Faibussowitsch } 492*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(1.0 * n * n * n / 3.0)); 493*4742e46bSJacob Faibussowitsch 494*4742e46bSJacob Faibussowitsch #if 0 495*4742e46bSJacob Faibussowitsch // At the time of writing this interface (cuda 10.0), cusolverDn does not implement *sytrs 496*4742e46bSJacob Faibussowitsch // and *hetr* routines. The code below should work, and it can be activated when *sytrs 497*4742e46bSJacob Faibussowitsch // routines will be available 498*4742e46bSJacob Faibussowitsch if (!mcu->d_fact_ipiv) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_ipiv, n, stream)); 499*4742e46bSJacob Faibussowitsch if (!mcu->d_fact_lwork) { 500*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverDnXsytrf_bufferSize(handle, n, da.cupmdata(), lda, &mcu->d_fact_lwork)); 501*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, mcu->d_fact_lwork, stream)); 502*4742e46bSJacob Faibussowitsch } 503*4742e46bSJacob Faibussowitsch if (mcu->d_fact_info) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_info, 1, stream)); 504*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 505*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXsytrf(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da, lda, mcu->d_fact_ipiv, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info)); 506*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 507*4742e46bSJacob Faibussowitsch #endif 508*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 509*4742e46bSJacob Faibussowitsch } 510*4742e46bSJacob Faibussowitsch 511*4742e46bSJacob Faibussowitsch template <bool transpose> 512*4742e46bSJacob Faibussowitsch static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept 513*4742e46bSJacob Faibussowitsch { 514*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 515*4742e46bSJacob Faibussowitsch const auto fact_info = mcu->d_fact_info; 516*4742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 517*4742e46bSJacob Faibussowitsch 518*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 519*4742e46bSJacob Faibussowitsch PetscAssert(!mcu->d_fact_ipiv, PETSC_COMM_SELF, PETSC_ERR_LIB, "%ssytrs not implemented", cupmSolverName()); 520*4742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &handle)); 521*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k)); 522*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 523*4742e46bSJacob Faibussowitsch { 524*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 525*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 526*4742e46bSJacob Faibussowitsch 527*4742e46bSJacob Faibussowitsch // clang-format off 528*4742e46bSJacob Faibussowitsch PetscCall( 529*4742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 530*4742e46bSJacob Faibussowitsch mcu, stream, 531*4742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *lwork) 532*4742e46bSJacob Faibussowitsch { 533*4742e46bSJacob Faibussowitsch return cupmSolverXpotrs_bufferSize( 534*4742e46bSJacob Faibussowitsch handle, CUPMSOLVER_FILL_MODE_LOWER, m, nrhs, da.cupmdata(), lda, x, ldx, lwork 535*4742e46bSJacob Faibussowitsch ); 536*4742e46bSJacob Faibussowitsch } 537*4742e46bSJacob Faibussowitsch ) 538*4742e46bSJacob Faibussowitsch ); 539*4742e46bSJacob Faibussowitsch // clang-format on 540*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXpotrs(handle, CUPMSOLVER_FILL_MODE_LOWER, m, nrhs, da.cupmdata(), lda, x, ldx, mcu->d_fact_work, mcu->d_fact_lwork, fact_info)); 541*4742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(fact_info, stream)); 542*4742e46bSJacob Faibussowitsch } 543*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 544*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(nrhs * (2.0 * m * m - m))); 545*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 546*4742e46bSJacob Faibussowitsch } 547*4742e46bSJacob Faibussowitsch }; 548*4742e46bSJacob Faibussowitsch 549*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 550*4742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveQR : SolveCommon<SolveQR> { 551*4742e46bSJacob Faibussowitsch using base_type = SolveCommon<SolveQR>; 552*4742e46bSJacob Faibussowitsch 553*4742e46bSJacob Faibussowitsch static constexpr const char *NAME() noexcept { return "QR"; } 554*4742e46bSJacob Faibussowitsch static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_QR; } 555*4742e46bSJacob Faibussowitsch 556*4742e46bSJacob Faibussowitsch static PetscErrorCode Factor(Mat A, IS, const MatFactorInfo *) noexcept 557*4742e46bSJacob Faibussowitsch { 558*4742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 559*4742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->cmap->n); 560*4742e46bSJacob Faibussowitsch const auto min = std::min(m, n); 561*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 562*4742e46bSJacob Faibussowitsch cupmStream_t stream; 563*4742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 564*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 565*4742e46bSJacob Faibussowitsch 566*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 567*4742e46bSJacob Faibussowitsch if (!m || !n) PetscFunctionReturn(PETSC_SUCCESS); 568*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle, &stream)); 569*4742e46bSJacob Faibussowitsch PetscCall(base_type::FactorPrepare(A, stream)); 570*4742e46bSJacob Faibussowitsch mimpl->rank = min; 571*4742e46bSJacob Faibussowitsch { 572*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 573*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 574*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(mimpl->lda); 575*4742e46bSJacob Faibussowitsch 576*4742e46bSJacob Faibussowitsch if (!mcu->workvec) PetscCall(vec::cupm::VecCreateSeqCUPMAsync<T>(PetscObjectComm(PetscObjectCast(A)), m, &mcu->workvec)); 577*4742e46bSJacob Faibussowitsch if (!mcu->d_fact_tau) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_tau, min, stream)); 578*4742e46bSJacob Faibussowitsch // clang-format off 579*4742e46bSJacob Faibussowitsch PetscCall( 580*4742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 581*4742e46bSJacob Faibussowitsch mcu, stream, 582*4742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *fact_lwork) 583*4742e46bSJacob Faibussowitsch { 584*4742e46bSJacob Faibussowitsch return cupmSolverXgeqrf_bufferSize(handle, m, n, da.cupmdata(), lda, fact_lwork); 585*4742e46bSJacob Faibussowitsch } 586*4742e46bSJacob Faibussowitsch ) 587*4742e46bSJacob Faibussowitsch ); 588*4742e46bSJacob Faibussowitsch // clang-format on 589*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 590*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXgeqrf(handle, m, n, da.cupmdata(), lda, mcu->d_fact_tau, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info)); 591*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 592*4742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream)); 593*4742e46bSJacob Faibussowitsch } 594*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(2.0 * min * min * (std::max(m, n) - min / 3.0))); 595*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 596*4742e46bSJacob Faibussowitsch } 597*4742e46bSJacob Faibussowitsch 598*4742e46bSJacob Faibussowitsch template <bool transpose> 599*4742e46bSJacob Faibussowitsch static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept 600*4742e46bSJacob Faibussowitsch { 601*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 602*4742e46bSJacob Faibussowitsch const auto rank = static_cast<cupmBlasInt_t>(mimpl->rank); 603*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 604*4742e46bSJacob Faibussowitsch const auto fact_info = mcu->d_fact_info; 605*4742e46bSJacob Faibussowitsch const auto fact_tau = mcu->d_fact_tau; 606*4742e46bSJacob Faibussowitsch const auto fact_work = mcu->d_fact_work; 607*4742e46bSJacob Faibussowitsch const auto fact_lwork = mcu->d_fact_lwork; 608*4742e46bSJacob Faibussowitsch cupmSolverHandle_t solver_handle; 609*4742e46bSJacob Faibussowitsch cupmBlasHandle_t blas_handle; 610*4742e46bSJacob Faibussowitsch 611*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 612*4742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &blas_handle, &solver_handle)); 613*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k)); 614*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 615*4742e46bSJacob Faibussowitsch { 616*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 617*4742e46bSJacob Faibussowitsch const auto one = cupmScalarCast(1.0); 618*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(mimpl->lda); 619*4742e46bSJacob Faibussowitsch 620*4742e46bSJacob Faibussowitsch if (transpose) { 621*4742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXtrsm(blas_handle, CUPMBLAS_SIDE_LEFT, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_T, CUPMBLAS_DIAG_NON_UNIT, rank, nrhs, &one, da.cupmdata(), lda, x, ldx)); 622*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXormqr(solver_handle, CUPMSOLVER_SIDE_LEFT, CUPMSOLVER_OP_N, m, nrhs, rank, da.cupmdata(), lda, fact_tau, x, ldx, fact_work, fact_lwork, fact_info)); 623*4742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(fact_info, stream)); 624*4742e46bSJacob Faibussowitsch } else { 625*4742e46bSJacob Faibussowitsch constexpr auto op = PetscDefined(USE_COMPLEX) ? CUPMSOLVER_OP_C : CUPMSOLVER_OP_T; 626*4742e46bSJacob Faibussowitsch 627*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXormqr(solver_handle, CUPMSOLVER_SIDE_LEFT, op, m, nrhs, rank, da.cupmdata(), lda, fact_tau, x, ldx, fact_work, fact_lwork, fact_info)); 628*4742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(fact_info, stream)); 629*4742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXtrsm(blas_handle, CUPMBLAS_SIDE_LEFT, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, rank, nrhs, &one, da.cupmdata(), lda, x, ldx)); 630*4742e46bSJacob Faibussowitsch } 631*4742e46bSJacob Faibussowitsch } 632*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 633*4742e46bSJacob Faibussowitsch PetscCall(PetscLogFlops(nrhs * (4.0 * m * rank - (rank * rank)))); 634*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 635*4742e46bSJacob Faibussowitsch } 636*4742e46bSJacob Faibussowitsch }; 637*4742e46bSJacob Faibussowitsch 638*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 639*4742e46bSJacob Faibussowitsch template <typename Solver, bool transpose> 640*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatSolve_Factored_Dispatch_(Mat A, Vec x, Vec y) noexcept 641*4742e46bSJacob Faibussowitsch { 642*4742e46bSJacob Faibussowitsch using namespace vec::cupm; 643*4742e46bSJacob Faibussowitsch const auto pobj_A = PetscObjectCast(A); 644*4742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 645*4742e46bSJacob Faibussowitsch const auto k = static_cast<cupmBlasInt_t>(A->cmap->n); 646*4742e46bSJacob Faibussowitsch auto &workvec = MatCUPMCast(A)->workvec; 647*4742e46bSJacob Faibussowitsch PetscScalar *y_array = nullptr; 648*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 649*4742e46bSJacob Faibussowitsch PetscBool xiscupm, yiscupm, aiscupm; 650*4742e46bSJacob Faibussowitsch bool use_y_array_directly; 651*4742e46bSJacob Faibussowitsch cupmStream_t stream; 652*4742e46bSJacob Faibussowitsch 653*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 654*4742e46bSJacob Faibussowitsch PetscCheck(A->factortype != MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix must be factored to solve"); 655*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(x), VecSeq_CUPM::VECSEQCUPM(), &xiscupm)); 656*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(y), VecSeq_CUPM::VECSEQCUPM(), &yiscupm)); 657*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(pobj_A, MATSEQDENSECUPM(), &aiscupm)); 658*4742e46bSJacob Faibussowitsch PetscAssert(aiscupm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Matrix A is somehow not CUPM?????????????????????????????"); 659*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 660*4742e46bSJacob Faibussowitsch use_y_array_directly = yiscupm && (k >= m); 661*4742e46bSJacob Faibussowitsch { 662*4742e46bSJacob Faibussowitsch const PetscScalar *x_array; 663*4742e46bSJacob Faibussowitsch const auto xisdevice = xiscupm && PetscOffloadDevice(x->offloadmask); 664*4742e46bSJacob Faibussowitsch const auto copy_mode = xisdevice ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice; 665*4742e46bSJacob Faibussowitsch 666*4742e46bSJacob Faibussowitsch if (!use_y_array_directly && !workvec) PetscCall(VecCreateSeqCUPMAsync<T>(PetscObjectComm(pobj_A), m, &workvec)); 667*4742e46bSJacob Faibussowitsch // The logic here is to try to minimize the amount of memory copying: 668*4742e46bSJacob Faibussowitsch // 669*4742e46bSJacob Faibussowitsch // If we call VecCUPMGetArrayRead(X, &x) every time xiscupm and the data is not offloaded 670*4742e46bSJacob Faibussowitsch // to the GPU yet, then the data is copied to the GPU. But we are only trying to get the 671*4742e46bSJacob Faibussowitsch // data in order to copy it into the y array. So the array x will be wherever the data 672*4742e46bSJacob Faibussowitsch // already is so that only one memcpy is performed 673*4742e46bSJacob Faibussowitsch if (xisdevice) { 674*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMGetArrayReadAsync<T>(x, &x_array, dctx)); 675*4742e46bSJacob Faibussowitsch } else { 676*4742e46bSJacob Faibussowitsch PetscCall(VecGetArrayRead(x, &x_array)); 677*4742e46bSJacob Faibussowitsch } 678*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMGetArrayWriteAsync<T>(use_y_array_directly ? y : workvec, &y_array, dctx)); 679*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(y_array, x_array, m, copy_mode, stream)); 680*4742e46bSJacob Faibussowitsch if (xisdevice) { 681*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMRestoreArrayReadAsync<T>(x, &x_array, dctx)); 682*4742e46bSJacob Faibussowitsch } else { 683*4742e46bSJacob Faibussowitsch PetscCall(VecRestoreArrayRead(x, &x_array)); 684*4742e46bSJacob Faibussowitsch } 685*4742e46bSJacob Faibussowitsch } 686*4742e46bSJacob Faibussowitsch 687*4742e46bSJacob Faibussowitsch if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 688*4742e46bSJacob Faibussowitsch PetscCall(Solver{}.template Solve<transpose>(A, cupmScalarPtrCast(y_array), m, m, 1, k, dctx, stream)); 689*4742e46bSJacob Faibussowitsch if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A)); 690*4742e46bSJacob Faibussowitsch 691*4742e46bSJacob Faibussowitsch if (use_y_array_directly) { 692*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMRestoreArrayWriteAsync<T>(y, &y_array, dctx)); 693*4742e46bSJacob Faibussowitsch } else { 694*4742e46bSJacob Faibussowitsch const auto copy_mode = yiscupm ? cupmMemcpyDeviceToDevice : cupmMemcpyDeviceToHost; 695*4742e46bSJacob Faibussowitsch PetscScalar *yv; 696*4742e46bSJacob Faibussowitsch 697*4742e46bSJacob Faibussowitsch // The logic here is that the data is not yet in either y's GPU array or its CPU array. 698*4742e46bSJacob Faibussowitsch // There is nothing in the interface to say where the user would like it to end up. So we 699*4742e46bSJacob Faibussowitsch // choose the GPU, because it is the faster option 700*4742e46bSJacob Faibussowitsch if (yiscupm) { 701*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMGetArrayWriteAsync<T>(y, &yv, dctx)); 702*4742e46bSJacob Faibussowitsch } else { 703*4742e46bSJacob Faibussowitsch PetscCall(VecGetArray(y, &yv)); 704*4742e46bSJacob Faibussowitsch } 705*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(yv, y_array, k, copy_mode, stream)); 706*4742e46bSJacob Faibussowitsch if (yiscupm) { 707*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMRestoreArrayWriteAsync<T>(y, &yv, dctx)); 708*4742e46bSJacob Faibussowitsch } else { 709*4742e46bSJacob Faibussowitsch PetscCall(VecRestoreArray(y, &yv)); 710*4742e46bSJacob Faibussowitsch } 711*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMRestoreArrayWriteAsync<T>(workvec, &y_array)); 712*4742e46bSJacob Faibussowitsch } 713*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 714*4742e46bSJacob Faibussowitsch } 715*4742e46bSJacob Faibussowitsch 716*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 717*4742e46bSJacob Faibussowitsch template <typename Solver, bool transpose> 718*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMatSolve_Factored_Dispatch_(Mat A, Mat B, Mat X) noexcept 719*4742e46bSJacob Faibussowitsch { 720*4742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 721*4742e46bSJacob Faibussowitsch const auto k = static_cast<cupmBlasInt_t>(A->cmap->n); 722*4742e46bSJacob Faibussowitsch cupmBlasInt_t nrhs, ldb, ldx, ldy; 723*4742e46bSJacob Faibussowitsch PetscScalar *y; 724*4742e46bSJacob Faibussowitsch PetscBool biscupm, xiscupm, aiscupm; 725*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 726*4742e46bSJacob Faibussowitsch cupmStream_t stream; 727*4742e46bSJacob Faibussowitsch 728*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 729*4742e46bSJacob Faibussowitsch PetscCheck(A->factortype != MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix must be factored to solve"); 730*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(B), MATSEQDENSECUPM(), &biscupm)); 731*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(X), MATSEQDENSECUPM(), &xiscupm)); 732*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(A), MATSEQDENSECUPM(), &aiscupm)); 733*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 734*4742e46bSJacob Faibussowitsch { 735*4742e46bSJacob Faibussowitsch PetscInt n; 736*4742e46bSJacob Faibussowitsch 737*4742e46bSJacob Faibussowitsch PetscCall(MatGetSize(B, nullptr, &n)); 738*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(n, &nrhs)); 739*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(B, &n)); 740*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(n, &ldb)); 741*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(X, &n)); 742*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(n, &ldx)); 743*4742e46bSJacob Faibussowitsch } 744*4742e46bSJacob Faibussowitsch { 745*4742e46bSJacob Faibussowitsch // The logic here is to try to minimize the amount of memory copying: 746*4742e46bSJacob Faibussowitsch // 747*4742e46bSJacob Faibussowitsch // If we call MatDenseCUPMGetArrayRead(B, &b) every time biscupm and the data is not 748*4742e46bSJacob Faibussowitsch // offloaded to the GPU yet, then the data is copied to the GPU. But we are only trying to 749*4742e46bSJacob Faibussowitsch // get the data in order to copy it into the y array. So the array b will be wherever the 750*4742e46bSJacob Faibussowitsch // data already is so that only one memcpy is performed 751*4742e46bSJacob Faibussowitsch const auto bisdevice = biscupm && PetscOffloadDevice(B->offloadmask); 752*4742e46bSJacob Faibussowitsch const auto copy_mode = bisdevice ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice; 753*4742e46bSJacob Faibussowitsch const PetscScalar *b; 754*4742e46bSJacob Faibussowitsch 755*4742e46bSJacob Faibussowitsch if (bisdevice) { 756*4742e46bSJacob Faibussowitsch b = DeviceArrayRead(dctx, B); 757*4742e46bSJacob Faibussowitsch } else if (biscupm) { 758*4742e46bSJacob Faibussowitsch b = HostArrayRead(dctx, B); 759*4742e46bSJacob Faibussowitsch } else { 760*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetArrayRead(B, &b)); 761*4742e46bSJacob Faibussowitsch } 762*4742e46bSJacob Faibussowitsch 763*4742e46bSJacob Faibussowitsch if (ldx < m || !xiscupm) { 764*4742e46bSJacob Faibussowitsch // X's array cannot serve as the array (too small or not on device), B's array cannot 765*4742e46bSJacob Faibussowitsch // serve as the array (const), so allocate a new array 766*4742e46bSJacob Faibussowitsch ldy = m; 767*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&y, nrhs * m)); 768*4742e46bSJacob Faibussowitsch } else { 769*4742e46bSJacob Faibussowitsch // X's array should serve as the array 770*4742e46bSJacob Faibussowitsch ldy = ldx; 771*4742e46bSJacob Faibussowitsch y = DeviceArrayWrite(dctx, X); 772*4742e46bSJacob Faibussowitsch } 773*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(y, ldy, b, ldb, m, nrhs, copy_mode, stream)); 774*4742e46bSJacob Faibussowitsch if (!bisdevice && !biscupm) PetscCall(MatDenseRestoreArrayRead(B, &b)); 775*4742e46bSJacob Faibussowitsch } 776*4742e46bSJacob Faibussowitsch 777*4742e46bSJacob Faibussowitsch // convert to CUPM twice?????????????????????????????????? 778*4742e46bSJacob Faibussowitsch // but A should already be CUPM?????????????????????????????????????? 779*4742e46bSJacob Faibussowitsch if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 780*4742e46bSJacob Faibussowitsch PetscCall(Solver{}.template Solve<transpose>(A, cupmScalarPtrCast(y), ldy, m, nrhs, k, dctx, stream)); 781*4742e46bSJacob Faibussowitsch if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 782*4742e46bSJacob Faibussowitsch 783*4742e46bSJacob Faibussowitsch if (ldx < m || !xiscupm) { 784*4742e46bSJacob Faibussowitsch const auto copy_mode = xiscupm ? cupmMemcpyDeviceToDevice : cupmMemcpyDeviceToHost; 785*4742e46bSJacob Faibussowitsch PetscScalar *x; 786*4742e46bSJacob Faibussowitsch 787*4742e46bSJacob Faibussowitsch // The logic here is that the data is not yet in either X's GPU array or its CPU 788*4742e46bSJacob Faibussowitsch // array. There is nothing in the interface to say where the user would like it to end up. 789*4742e46bSJacob Faibussowitsch // So we choose the GPU, because it is the faster option 790*4742e46bSJacob Faibussowitsch if (xiscupm) { 791*4742e46bSJacob Faibussowitsch x = DeviceArrayWrite(dctx, X); 792*4742e46bSJacob Faibussowitsch } else { 793*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetArray(X, &x)); 794*4742e46bSJacob Faibussowitsch } 795*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(x, ldx, y, ldy, k, nrhs, copy_mode, stream)); 796*4742e46bSJacob Faibussowitsch if (!xiscupm) PetscCall(MatDenseRestoreArray(X, &x)); 797*4742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(y, stream)); 798*4742e46bSJacob Faibussowitsch } 799*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 800*4742e46bSJacob Faibussowitsch } 801*4742e46bSJacob Faibussowitsch 802*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 803*4742e46bSJacob Faibussowitsch template <bool transpose> 804*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMultAdd_Dispatch_(Mat A, Vec xx, Vec yy, Vec zz) noexcept 805*4742e46bSJacob Faibussowitsch { 806*4742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 807*4742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->cmap->n); 808*4742e46bSJacob Faibussowitsch cupmBlasHandle_t handle; 809*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 810*4742e46bSJacob Faibussowitsch 811*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 812*4742e46bSJacob Faibussowitsch if (yy && yy != zz) PetscCall(VecSeq_CUPM::Copy(yy, zz)); // mult add 813*4742e46bSJacob Faibussowitsch if (!m || !n) { 814*4742e46bSJacob Faibussowitsch // mult only 815*4742e46bSJacob Faibussowitsch if (!yy) PetscCall(VecSeq_CUPM::Set(zz, 0.0)); 816*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 817*4742e46bSJacob Faibussowitsch } 818*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "Matrix-vector product %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " on backend\n", m, n)); 819*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle)); 820*4742e46bSJacob Faibussowitsch { 821*4742e46bSJacob Faibussowitsch constexpr auto op = transpose ? CUPMBLAS_OP_T : CUPMBLAS_OP_N; 822*4742e46bSJacob Faibussowitsch const auto one = cupmScalarCast(1.0); 823*4742e46bSJacob Faibussowitsch const auto zero = cupmScalarCast(0.0); 824*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 825*4742e46bSJacob Faibussowitsch const auto dxx = VecSeq_CUPM::DeviceArrayRead(dctx, xx); 826*4742e46bSJacob Faibussowitsch const auto dzz = VecSeq_CUPM::DeviceArrayReadWrite(dctx, zz); 827*4742e46bSJacob Faibussowitsch 828*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 829*4742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXgemv(handle, op, m, n, &one, da.cupmdata(), static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda), dxx.cupmdata(), 1, (yy ? &one : &zero), dzz.cupmdata(), 1)); 830*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 831*4742e46bSJacob Faibussowitsch } 832*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(2.0 * m * n - (yy ? 0 : m))); 833*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 834*4742e46bSJacob Faibussowitsch } 835*4742e46bSJacob Faibussowitsch 836*4742e46bSJacob Faibussowitsch // ========================================================================================== 837*4742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Conversion Dispatch 838*4742e46bSJacob Faibussowitsch // ========================================================================================== 839*4742e46bSJacob Faibussowitsch 840*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 841*4742e46bSJacob Faibussowitsch template <bool to_host> 842*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_Dispatch_(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept 843*4742e46bSJacob Faibussowitsch { 844*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 845*4742e46bSJacob Faibussowitsch if (reuse == MAT_REUSE_MATRIX || reuse == MAT_INITIAL_MATRIX) { 846*4742e46bSJacob Faibussowitsch // TODO these cases should be optimized 847*4742e46bSJacob Faibussowitsch PetscCall(MatConvert_Basic(M, type, reuse, newmat)); 848*4742e46bSJacob Faibussowitsch } else { 849*4742e46bSJacob Faibussowitsch const auto B = *newmat; 850*4742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(B); 851*4742e46bSJacob Faibussowitsch 852*4742e46bSJacob Faibussowitsch if (to_host) { 853*4742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_TRUE)); 854*4742e46bSJacob Faibussowitsch PetscCall(Reset(B)); 855*4742e46bSJacob Faibussowitsch } else { 856*4742e46bSJacob Faibussowitsch PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 857*4742e46bSJacob Faibussowitsch } 858*4742e46bSJacob Faibussowitsch 859*4742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecSeq_CUPM::VECCUPM(), &B->defaultvectype)); 860*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATSEQDENSE : MATSEQDENSECUPM())); 861*4742e46bSJacob Faibussowitsch // cvec might be the wrong VecType, destroy and rebuild it if necessary 862*4742e46bSJacob Faibussowitsch // REVIEW ME: this is possibly very inefficient 863*4742e46bSJacob Faibussowitsch PetscCall(VecDestroy(&MatIMPLCast(B)->cvec)); 864*4742e46bSJacob Faibussowitsch 865*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatConvert_seqdensecupm_seqdense_C(), nullptr, Convert_SeqDenseCUPM_SeqDense); 866*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 867*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 868*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 869*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 870*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 871*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 872*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray); 873*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray); 874*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray); 875*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_seqaij_seqdensecupm_C(), nullptr, MatProductSetFromOptions_SeqAIJ_SeqDense); 876*4742e46bSJacob Faibussowitsch 877*4742e46bSJacob Faibussowitsch if (to_host) { 878*4742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_CPU; 879*4742e46bSJacob Faibussowitsch } else { 880*4742e46bSJacob Faibussowitsch Mat_SeqDenseCUPM *mcu; 881*4742e46bSJacob Faibussowitsch 882*4742e46bSJacob Faibussowitsch PetscCall(PetscNew(&mcu)); 883*4742e46bSJacob Faibussowitsch B->spptr = mcu; 884*4742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_UNALLOCATED; // REVIEW ME: why not offload host?? 885*4742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_FALSE)); 886*4742e46bSJacob Faibussowitsch } 887*4742e46bSJacob Faibussowitsch 888*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU); 889*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, B, destroy, MatDestroy_SeqDense, Destroy); 890*4742e46bSJacob Faibussowitsch } 891*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 892*4742e46bSJacob Faibussowitsch } 893*4742e46bSJacob Faibussowitsch 894*4742e46bSJacob Faibussowitsch // ========================================================================================== 895*4742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Public API 896*4742e46bSJacob Faibussowitsch // ========================================================================================== 897*4742e46bSJacob Faibussowitsch 898*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 899*4742e46bSJacob Faibussowitsch inline constexpr MatType MatDense_Seq_CUPM<T>::MATIMPLCUPM_() noexcept 900*4742e46bSJacob Faibussowitsch { 901*4742e46bSJacob Faibussowitsch return MATSEQDENSECUPM(); 902*4742e46bSJacob Faibussowitsch } 903*4742e46bSJacob Faibussowitsch 904*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 905*4742e46bSJacob Faibussowitsch inline constexpr typename MatDense_Seq_CUPM<T>::Mat_SeqDenseCUPM *MatDense_Seq_CUPM<T>::MatCUPMCast(Mat m) noexcept 906*4742e46bSJacob Faibussowitsch { 907*4742e46bSJacob Faibussowitsch return static_cast<Mat_SeqDenseCUPM *>(m->spptr); 908*4742e46bSJacob Faibussowitsch } 909*4742e46bSJacob Faibussowitsch 910*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 911*4742e46bSJacob Faibussowitsch inline constexpr Mat_SeqDense *MatDense_Seq_CUPM<T>::MatIMPLCast_(Mat m) noexcept 912*4742e46bSJacob Faibussowitsch { 913*4742e46bSJacob Faibussowitsch return static_cast<Mat_SeqDense *>(m->data); 914*4742e46bSJacob Faibussowitsch } 915*4742e46bSJacob Faibussowitsch 916*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 917*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_Seq_CUPM<T>::MatConvert_seqdensecupm_seqdense_C() noexcept 918*4742e46bSJacob Faibussowitsch { 919*4742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatConvert_seqdensecuda_seqdense_C" : "MatConvert_seqdensehip_seqdense_C"; 920*4742e46bSJacob Faibussowitsch } 921*4742e46bSJacob Faibussowitsch 922*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 923*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_Seq_CUPM<T>::MatProductSetFromOptions_seqaij_seqdensecupm_C() noexcept 924*4742e46bSJacob Faibussowitsch { 925*4742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_seqaij_seqdensecuda_C" : "MatProductSetFromOptions_seqaij_seqdensehip_C"; 926*4742e46bSJacob Faibussowitsch } 927*4742e46bSJacob Faibussowitsch 928*4742e46bSJacob Faibussowitsch // ========================================================================================== 929*4742e46bSJacob Faibussowitsch 930*4742e46bSJacob Faibussowitsch // MatCreate_SeqDenseCUPM() 931*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 932*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Create(Mat A) noexcept 933*4742e46bSJacob Faibussowitsch { 934*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 935*4742e46bSJacob Faibussowitsch PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 936*4742e46bSJacob Faibussowitsch PetscCall(MatCreate_SeqDense(A)); 937*4742e46bSJacob Faibussowitsch PetscCall(Convert_SeqDense_SeqDenseCUPM(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 938*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 939*4742e46bSJacob Faibussowitsch } 940*4742e46bSJacob Faibussowitsch 941*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 942*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Destroy(Mat A) noexcept 943*4742e46bSJacob Faibussowitsch { 944*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 945*4742e46bSJacob Faibussowitsch // prevent copying back data if we own the data pointer 946*4742e46bSJacob Faibussowitsch if (!MatIMPLCast(A)->user_alloc) A->offloadmask = PETSC_OFFLOAD_CPU; 947*4742e46bSJacob Faibussowitsch PetscCall(Convert_SeqDenseCUPM_SeqDense(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A)); 948*4742e46bSJacob Faibussowitsch PetscCall(MatDestroy_SeqDense(A)); 949*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 950*4742e46bSJacob Faibussowitsch } 951*4742e46bSJacob Faibussowitsch 952*4742e46bSJacob Faibussowitsch // obj->ops->setup() 953*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 954*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetUp(Mat A) noexcept 955*4742e46bSJacob Faibussowitsch { 956*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 957*4742e46bSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(A->rmap)); 958*4742e46bSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(A->cmap)); 959*4742e46bSJacob Faibussowitsch if (!A->preallocated) { 960*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 961*4742e46bSJacob Faibussowitsch 962*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 963*4742e46bSJacob Faibussowitsch PetscCall(SetPreallocation(A, dctx)); 964*4742e46bSJacob Faibussowitsch } 965*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 966*4742e46bSJacob Faibussowitsch } 967*4742e46bSJacob Faibussowitsch 968*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 969*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Reset(Mat A) noexcept 970*4742e46bSJacob Faibussowitsch { 971*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 972*4742e46bSJacob Faibussowitsch if (const auto mcu = MatCUPMCast(A)) { 973*4742e46bSJacob Faibussowitsch cupmStream_t stream; 974*4742e46bSJacob Faibussowitsch 975*4742e46bSJacob Faibussowitsch PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME()); 976*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&stream)); 977*4742e46bSJacob Faibussowitsch if (!mcu->d_user_alloc) PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream)); 978*4742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_tau, stream)); 979*4742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_ipiv, stream)); 980*4742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_info, stream)); 981*4742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream)); 982*4742e46bSJacob Faibussowitsch PetscCall(VecDestroy(&mcu->workvec)); 983*4742e46bSJacob Faibussowitsch PetscCall(PetscFree(A->spptr /* mcu */)); 984*4742e46bSJacob Faibussowitsch } 985*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 986*4742e46bSJacob Faibussowitsch } 987*4742e46bSJacob Faibussowitsch 988*4742e46bSJacob Faibussowitsch // ========================================================================================== 989*4742e46bSJacob Faibussowitsch 990*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 991*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::BindToCPU(Mat A, PetscBool to_host) noexcept 992*4742e46bSJacob Faibussowitsch { 993*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 994*4742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(A); 995*4742e46bSJacob Faibussowitsch 996*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 997*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 998*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 999*4742e46bSJacob Faibussowitsch A->boundtocpu = to_host; 1000*4742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(to_host ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype)); 1001*4742e46bSJacob Faibussowitsch if (to_host) { 1002*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1003*4742e46bSJacob Faibussowitsch 1004*4742e46bSJacob Faibussowitsch // make sure we have an up-to-date copy on the CPU 1005*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1006*4742e46bSJacob Faibussowitsch PetscCall(DeviceToHost_(A, dctx)); 1007*4742e46bSJacob Faibussowitsch } else { 1008*4742e46bSJacob Faibussowitsch PetscBool iscupm; 1009*4742e46bSJacob Faibussowitsch 1010*4742e46bSJacob Faibussowitsch if (auto &cvec = mimpl->cvec) { 1011*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(cvec), VecSeq_CUPM::VECSEQCUPM(), &iscupm)); 1012*4742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(VecDestroy(&cvec)); 1013*4742e46bSJacob Faibussowitsch } 1014*4742e46bSJacob Faibussowitsch if (auto &cmat = mimpl->cmat) { 1015*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(cmat), MATSEQDENSECUPM(), &iscupm)); 1016*4742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(MatDestroy(&cmat)); 1017*4742e46bSJacob Faibussowitsch } 1018*4742e46bSJacob Faibussowitsch } 1019*4742e46bSJacob Faibussowitsch 1020*4742e46bSJacob Faibussowitsch // ============================================================ 1021*4742e46bSJacob Faibussowitsch // Composed ops 1022*4742e46bSJacob Faibussowitsch // ============================================================ 1023*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArray_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>); 1024*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayRead_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>); 1025*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayWrite_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>); 1026*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ_WRITE>); 1027*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ_WRITE>); 1028*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayReadAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ>); 1029*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayReadAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ>); 1030*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayWriteAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_WRITE>); 1031*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayWriteAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_WRITE>); 1032*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 1033*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 1034*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>); 1035*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>); 1036*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 1037*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 1038*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetSubMatrix_C", MatDenseGetSubMatrix_SeqDense, GetSubMatrix); 1039*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreSubMatrix_C", MatDenseRestoreSubMatrix_SeqDense, RestoreSubMatrix); 1040*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatQRFactor_C", MatQRFactor_SeqDense, SolveQR::Factor); 1041*4742e46bSJacob Faibussowitsch // always the same 1042*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction(pobj, "MatDenseSetLDA_C", MatDenseSetLDA_SeqDense)); 1043*4742e46bSJacob Faibussowitsch 1044*4742e46bSJacob Faibussowitsch // ============================================================ 1045*4742e46bSJacob Faibussowitsch // Function pointer ops 1046*4742e46bSJacob Faibussowitsch // ============================================================ 1047*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, duplicate, MatDuplicate_SeqDense, Duplicate); 1048*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, mult, MatMult_SeqDense, [](Mat A, Vec xx, Vec yy) { return MatMultAdd_Dispatch_</* transpose */ false>(A, xx, nullptr, yy); }); 1049*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, multtranspose, MatMultTranspose_SeqDense, [](Mat A, Vec xx, Vec yy) { return MatMultAdd_Dispatch_</* transpose */ true>(A, xx, nullptr, yy); }); 1050*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, multadd, MatMultAdd_SeqDense, MatMultAdd_Dispatch_</* transpose */ false>); 1051*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, multtransposeadd, MatMultTransposeAdd_SeqDense, MatMultAdd_Dispatch_</* transpose */ true>); 1052*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, matmultnumeric, MatMatMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ false, /* transpose_B */ false>); 1053*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, mattransposemultnumeric, MatMatTransposeMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ false, /* transpose_B */ true>); 1054*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, transposematmultnumeric, MatTransposeMatMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ true, /* transpose_B */ false>); 1055*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, axpy, MatAXPY_SeqDense, AXPY); 1056*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, choleskyfactor, MatCholeskyFactor_SeqDense, SolveCholesky::Factor); 1057*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, lufactor, MatLUFactor_SeqDense, SolveLU::Factor); 1058*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, getcolumnvector, MatGetColumnVector_SeqDense, GetColumnVector); 1059*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, scale, MatScale_SeqDense, Scale); 1060*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, shift, MatShift_SeqDense, Shift); 1061*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, copy, MatCopy_SeqDense, Copy); 1062*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, zeroentries, MatZeroEntries_SeqDense, ZeroEntries); 1063*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, setup, MatSetUp_SeqDense, SetUp); 1064*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, setrandom, MatSetRandom_SeqDense, SetRandom); 1065*4742e46bSJacob Faibussowitsch // seemingly always the same 1066*4742e46bSJacob Faibussowitsch A->ops->productsetfromoptions = MatProductSetFromOptions_SeqDense; 1067*4742e46bSJacob Faibussowitsch 1068*4742e46bSJacob Faibussowitsch if (const auto cmat = mimpl->cmat) PetscCall(MatBindToCPU(cmat, to_host)); 1069*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1070*4742e46bSJacob Faibussowitsch } 1071*4742e46bSJacob Faibussowitsch 1072*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1073*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_SeqDenseCUPM_SeqDense(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept 1074*4742e46bSJacob Faibussowitsch { 1075*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1076*4742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ true>(M, type, reuse, newmat)); 1077*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1078*4742e46bSJacob Faibussowitsch } 1079*4742e46bSJacob Faibussowitsch 1080*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1081*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_SeqDense_SeqDenseCUPM(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept 1082*4742e46bSJacob Faibussowitsch { 1083*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1084*4742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ false>(M, type, reuse, newmat)); 1085*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1086*4742e46bSJacob Faibussowitsch } 1087*4742e46bSJacob Faibussowitsch 1088*4742e46bSJacob Faibussowitsch // ========================================================================================== 1089*4742e46bSJacob Faibussowitsch 1090*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1091*4742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode access> 1092*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetArray(Mat m, PetscScalar **array, PetscDeviceContext dctx) noexcept 1093*4742e46bSJacob Faibussowitsch { 1094*4742e46bSJacob Faibussowitsch constexpr auto hostmem = PetscMemTypeHost(mtype); 1095*4742e46bSJacob Faibussowitsch constexpr auto read_access = PetscMemoryAccessRead(access); 1096*4742e46bSJacob Faibussowitsch 1097*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1098*4742e46bSJacob Faibussowitsch static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), ""); 1099*4742e46bSJacob Faibussowitsch PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1100*4742e46bSJacob Faibussowitsch if (hostmem) { 1101*4742e46bSJacob Faibussowitsch if (read_access) { 1102*4742e46bSJacob Faibussowitsch PetscCall(DeviceToHost_(m, dctx)); 1103*4742e46bSJacob Faibussowitsch } else if (!MatIMPLCast(m)->v) { 1104*4742e46bSJacob Faibussowitsch // MatCreateSeqDenseCUPM may not allocate CPU memory. Allocate if needed 1105*4742e46bSJacob Faibussowitsch PetscCall(MatSeqDenseSetPreallocation(m, nullptr)); 1106*4742e46bSJacob Faibussowitsch } 1107*4742e46bSJacob Faibussowitsch *array = MatIMPLCast(m)->v; 1108*4742e46bSJacob Faibussowitsch } else { 1109*4742e46bSJacob Faibussowitsch if (read_access) { 1110*4742e46bSJacob Faibussowitsch PetscCall(HostToDevice_(m, dctx)); 1111*4742e46bSJacob Faibussowitsch } else if (!MatCUPMCast(m)->d_v) { 1112*4742e46bSJacob Faibussowitsch // write-only 1113*4742e46bSJacob Faibussowitsch PetscCall(SetPreallocation(m, dctx, nullptr)); 1114*4742e46bSJacob Faibussowitsch } 1115*4742e46bSJacob Faibussowitsch *array = MatCUPMCast(m)->d_v; 1116*4742e46bSJacob Faibussowitsch } 1117*4742e46bSJacob Faibussowitsch if (PetscMemoryAccessWrite(access)) { 1118*4742e46bSJacob Faibussowitsch m->offloadmask = hostmem ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU; 1119*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectStateIncrease(PetscObjectCast(m))); 1120*4742e46bSJacob Faibussowitsch } 1121*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1122*4742e46bSJacob Faibussowitsch } 1123*4742e46bSJacob Faibussowitsch 1124*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1125*4742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode access> 1126*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreArray(Mat m, PetscScalar **array, PetscDeviceContext) noexcept 1127*4742e46bSJacob Faibussowitsch { 1128*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1129*4742e46bSJacob Faibussowitsch static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), ""); 1130*4742e46bSJacob Faibussowitsch if (PetscMemoryAccessWrite(access)) { 1131*4742e46bSJacob Faibussowitsch // WRITE or READ_WRITE 1132*4742e46bSJacob Faibussowitsch m->offloadmask = PetscMemTypeHost(mtype) ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU; 1133*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectStateIncrease(PetscObjectCast(m))); 1134*4742e46bSJacob Faibussowitsch } 1135*4742e46bSJacob Faibussowitsch *array = nullptr; 1136*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1137*4742e46bSJacob Faibussowitsch } 1138*4742e46bSJacob Faibussowitsch 1139*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1140*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 1141*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetArrayAndMemType(Mat m, PetscScalar **array, PetscMemType *mtype, PetscDeviceContext dctx) noexcept 1142*4742e46bSJacob Faibussowitsch { 1143*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1144*4742e46bSJacob Faibussowitsch PetscCall(GetArray<PETSC_MEMTYPE_DEVICE, access>(m, array, dctx)); 1145*4742e46bSJacob Faibussowitsch if (mtype) *mtype = PETSC_MEMTYPE_CUPM(); 1146*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1147*4742e46bSJacob Faibussowitsch } 1148*4742e46bSJacob Faibussowitsch 1149*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1150*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 1151*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreArrayAndMemType(Mat m, PetscScalar **array, PetscDeviceContext dctx) noexcept 1152*4742e46bSJacob Faibussowitsch { 1153*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1154*4742e46bSJacob Faibussowitsch PetscCall(RestoreArray<PETSC_MEMTYPE_DEVICE, access>(m, array, dctx)); 1155*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1156*4742e46bSJacob Faibussowitsch } 1157*4742e46bSJacob Faibussowitsch 1158*4742e46bSJacob Faibussowitsch // ========================================================================================== 1159*4742e46bSJacob Faibussowitsch 1160*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1161*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept 1162*4742e46bSJacob Faibussowitsch { 1163*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 1164*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 1165*4742e46bSJacob Faibussowitsch 1166*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1167*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 1168*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 1169*4742e46bSJacob Faibussowitsch PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME()); 1170*4742e46bSJacob Faibussowitsch if (mimpl->v) { 1171*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1172*4742e46bSJacob Faibussowitsch 1173*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1174*4742e46bSJacob Faibussowitsch PetscCall(HostToDevice_(A, dctx)); 1175*4742e46bSJacob Faibussowitsch } 1176*4742e46bSJacob Faibussowitsch mcu->unplacedarray = util::exchange(mcu->d_v, const_cast<PetscScalar *>(array)); 1177*4742e46bSJacob Faibussowitsch mcu->d_unplaced_user_alloc = util::exchange(mcu->d_user_alloc, PETSC_TRUE); 1178*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1179*4742e46bSJacob Faibussowitsch } 1180*4742e46bSJacob Faibussowitsch 1181*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1182*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept 1183*4742e46bSJacob Faibussowitsch { 1184*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 1185*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 1186*4742e46bSJacob Faibussowitsch 1187*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1188*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 1189*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 1190*4742e46bSJacob Faibussowitsch PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME()); 1191*4742e46bSJacob Faibussowitsch if (!mcu->d_user_alloc) { 1192*4742e46bSJacob Faibussowitsch cupmStream_t stream; 1193*4742e46bSJacob Faibussowitsch 1194*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&stream)); 1195*4742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream)); 1196*4742e46bSJacob Faibussowitsch } 1197*4742e46bSJacob Faibussowitsch mcu->d_v = const_cast<PetscScalar *>(array); 1198*4742e46bSJacob Faibussowitsch mcu->d_user_alloc = PETSC_FALSE; 1199*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1200*4742e46bSJacob Faibussowitsch } 1201*4742e46bSJacob Faibussowitsch 1202*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1203*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ResetArray(Mat A) noexcept 1204*4742e46bSJacob Faibussowitsch { 1205*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 1206*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 1207*4742e46bSJacob Faibussowitsch 1208*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1209*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 1210*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 1211*4742e46bSJacob Faibussowitsch if (mimpl->v) { 1212*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1213*4742e46bSJacob Faibussowitsch 1214*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1215*4742e46bSJacob Faibussowitsch PetscCall(HostToDevice_(A, dctx)); 1216*4742e46bSJacob Faibussowitsch } 1217*4742e46bSJacob Faibussowitsch mcu->d_v = util::exchange(mcu->unplacedarray, nullptr); 1218*4742e46bSJacob Faibussowitsch mcu->d_user_alloc = mcu->d_unplaced_user_alloc; 1219*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1220*4742e46bSJacob Faibussowitsch } 1221*4742e46bSJacob Faibussowitsch 1222*4742e46bSJacob Faibussowitsch // ========================================================================================== 1223*4742e46bSJacob Faibussowitsch 1224*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1225*4742e46bSJacob Faibussowitsch template <bool transpose_A, bool transpose_B> 1226*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMatMult_Numeric_Dispatch(Mat A, Mat B, Mat C) noexcept 1227*4742e46bSJacob Faibussowitsch { 1228*4742e46bSJacob Faibussowitsch cupmBlasInt_t m, n, k; 1229*4742e46bSJacob Faibussowitsch PetscBool Aiscupm, Biscupm; 1230*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1231*4742e46bSJacob Faibussowitsch cupmBlasHandle_t handle; 1232*4742e46bSJacob Faibussowitsch 1233*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1234*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(C->rmap->n, &m)); 1235*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(C->cmap->n, &n)); 1236*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(transpose_A ? A->rmap->n : A->cmap->n, &k)); 1237*4742e46bSJacob Faibussowitsch if (!m || !n || !k) PetscFunctionReturn(PETSC_SUCCESS); 1238*4742e46bSJacob Faibussowitsch 1239*4742e46bSJacob Faibussowitsch // we may end up with SEQDENSE as one of the arguments 1240*4742e46bSJacob Faibussowitsch // REVIEW ME: how? and why is it not B and C???????? 1241*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(A), MATSEQDENSECUPM(), &Aiscupm)); 1242*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(B), MATSEQDENSECUPM(), &Biscupm)); 1243*4742e46bSJacob Faibussowitsch if (!Aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 1244*4742e46bSJacob Faibussowitsch if (!Biscupm) PetscCall(MatConvert(B, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &B)); 1245*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(C, "Matrix-Matrix product %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " on backend\n", m, k, n)); 1246*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle)); 1247*4742e46bSJacob Faibussowitsch 1248*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 1249*4742e46bSJacob Faibussowitsch { 1250*4742e46bSJacob Faibussowitsch const auto one = cupmScalarCast(1.0); 1251*4742e46bSJacob Faibussowitsch const auto zero = cupmScalarCast(0.0); 1252*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 1253*4742e46bSJacob Faibussowitsch const auto db = DeviceArrayRead(dctx, B); 1254*4742e46bSJacob Faibussowitsch const auto dc = DeviceArrayWrite(dctx, C); 1255*4742e46bSJacob Faibussowitsch PetscInt alda, blda, clda; 1256*4742e46bSJacob Faibussowitsch 1257*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(A, &alda)); 1258*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(B, &blda)); 1259*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(C, &clda)); 1260*4742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXgemm(handle, transpose_A ? CUPMBLAS_OP_T : CUPMBLAS_OP_N, transpose_B ? CUPMBLAS_OP_T : CUPMBLAS_OP_N, m, n, k, &one, da.cupmdata(), alda, db.cupmdata(), blda, &zero, dc.cupmdata(), clda)); 1261*4742e46bSJacob Faibussowitsch } 1262*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 1263*4742e46bSJacob Faibussowitsch 1264*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(1.0 * m * n * k + 1.0 * m * n * (k - 1))); 1265*4742e46bSJacob Faibussowitsch if (!Aiscupm) PetscCall(MatConvert(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A)); 1266*4742e46bSJacob Faibussowitsch if (!Biscupm) PetscCall(MatConvert(B, MATSEQDENSE, MAT_INPLACE_MATRIX, &B)); 1267*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1268*4742e46bSJacob Faibussowitsch } 1269*4742e46bSJacob Faibussowitsch 1270*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1271*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Copy(Mat A, Mat B, MatStructure str) noexcept 1272*4742e46bSJacob Faibussowitsch { 1273*4742e46bSJacob Faibussowitsch const auto m = A->rmap->n; 1274*4742e46bSJacob Faibussowitsch const auto n = A->cmap->n; 1275*4742e46bSJacob Faibussowitsch 1276*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1277*4742e46bSJacob Faibussowitsch PetscAssert(m == B->rmap->n && n == B->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "size(B) != size(A)"); 1278*4742e46bSJacob Faibussowitsch // The two matrices must have the same copy implementation to be eligible for fast copy 1279*4742e46bSJacob Faibussowitsch if (A->ops->copy == B->ops->copy) { 1280*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1281*4742e46bSJacob Faibussowitsch cupmStream_t stream; 1282*4742e46bSJacob Faibussowitsch 1283*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 1284*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 1285*4742e46bSJacob Faibussowitsch { 1286*4742e46bSJacob Faibussowitsch const auto va = DeviceArrayRead(dctx, A); 1287*4742e46bSJacob Faibussowitsch const auto vb = DeviceArrayWrite(dctx, B); 1288*4742e46bSJacob Faibussowitsch // order is important, DeviceArrayRead/Write() might call SetPreallocation() which sets 1289*4742e46bSJacob Faibussowitsch // lda! 1290*4742e46bSJacob Faibussowitsch const auto lda_a = MatIMPLCast(A)->lda; 1291*4742e46bSJacob Faibussowitsch const auto lda_b = MatIMPLCast(B)->lda; 1292*4742e46bSJacob Faibussowitsch 1293*4742e46bSJacob Faibussowitsch if (lda_a > m || lda_b > m) { 1294*4742e46bSJacob Faibussowitsch PetscAssert(lda_b > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "B lda (%" PetscBLASInt_FMT ") must be > 0 at this point, this indicates Mat%sSetPreallocation() was not called when it should have been!", lda_b, cupmNAME()); 1295*4742e46bSJacob Faibussowitsch PetscAssert(lda_a > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A lda (%" PetscBLASInt_FMT ") must be > 0 at this point, this indicates Mat%sSetPreallocation() was not called when it should have been!", lda_a, cupmNAME()); 1296*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(vb.data(), lda_b, va.data(), lda_a, m, n, cupmMemcpyDeviceToDevice, stream)); 1297*4742e46bSJacob Faibussowitsch } else { 1298*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(vb.data(), va.data(), m * n, cupmMemcpyDeviceToDevice, stream)); 1299*4742e46bSJacob Faibussowitsch } 1300*4742e46bSJacob Faibussowitsch } 1301*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 1302*4742e46bSJacob Faibussowitsch } else { 1303*4742e46bSJacob Faibussowitsch PetscCall(MatCopy_Basic(A, B, str)); 1304*4742e46bSJacob Faibussowitsch } 1305*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1306*4742e46bSJacob Faibussowitsch } 1307*4742e46bSJacob Faibussowitsch 1308*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1309*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ZeroEntries(Mat m) noexcept 1310*4742e46bSJacob Faibussowitsch { 1311*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1312*4742e46bSJacob Faibussowitsch cupmStream_t stream; 1313*4742e46bSJacob Faibussowitsch 1314*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1315*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 1316*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 1317*4742e46bSJacob Faibussowitsch { 1318*4742e46bSJacob Faibussowitsch const auto va = DeviceArrayWrite(dctx, m); 1319*4742e46bSJacob Faibussowitsch const auto lda = MatIMPLCast(m)->lda; 1320*4742e46bSJacob Faibussowitsch const auto ma = m->rmap->n; 1321*4742e46bSJacob Faibussowitsch const auto na = m->cmap->n; 1322*4742e46bSJacob Faibussowitsch 1323*4742e46bSJacob Faibussowitsch if (lda > ma) { 1324*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemset2DAsync(va.data(), lda, 0, ma, na, stream)); 1325*4742e46bSJacob Faibussowitsch } else { 1326*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemsetAsync(va.data(), 0, ma * na, stream)); 1327*4742e46bSJacob Faibussowitsch } 1328*4742e46bSJacob Faibussowitsch } 1329*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 1330*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1331*4742e46bSJacob Faibussowitsch } 1332*4742e46bSJacob Faibussowitsch 1333*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1334*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Scale(Mat A, PetscScalar alpha) noexcept 1335*4742e46bSJacob Faibussowitsch { 1336*4742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 1337*4742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->cmap->n); 1338*4742e46bSJacob Faibussowitsch const auto N = m * n; 1339*4742e46bSJacob Faibussowitsch cupmBlasHandle_t handle; 1340*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1341*4742e46bSJacob Faibussowitsch 1342*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1343*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "Performing Scale %d x %d on backend\n", m, n)); 1344*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle)); 1345*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 1346*4742e46bSJacob Faibussowitsch { 1347*4742e46bSJacob Faibussowitsch const auto cu_alpha = cupmScalarCast(alpha); 1348*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 1349*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 1350*4742e46bSJacob Faibussowitsch 1351*4742e46bSJacob Faibussowitsch if (lda > m) { 1352*4742e46bSJacob Faibussowitsch for (cupmBlasInt_t j = 0; j < n; ++j) PetscCallCUPMBLAS(cupmBlasXscal(handle, m, &cu_alpha, da.cupmdata() + lda * j, 1)); 1353*4742e46bSJacob Faibussowitsch } else { 1354*4742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXscal(handle, N, &cu_alpha, da.cupmdata(), 1)); 1355*4742e46bSJacob Faibussowitsch } 1356*4742e46bSJacob Faibussowitsch } 1357*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 1358*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(N)); 1359*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1360*4742e46bSJacob Faibussowitsch } 1361*4742e46bSJacob Faibussowitsch 1362*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1363*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept 1364*4742e46bSJacob Faibussowitsch { 1365*4742e46bSJacob Faibussowitsch const auto m = A->rmap->n; 1366*4742e46bSJacob Faibussowitsch const auto n = A->cmap->n; 1367*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1368*4742e46bSJacob Faibussowitsch 1369*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1370*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1371*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "Performing Shift %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", m, n)); 1372*4742e46bSJacob Faibussowitsch PetscCall(PointwiseUnaryTransform(A, 0, m, n, dctx, device::cupm::functors::make_plus_equals(alpha))); 1373*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1374*4742e46bSJacob Faibussowitsch } 1375*4742e46bSJacob Faibussowitsch 1376*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1377*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::AXPY(Mat Y, PetscScalar alpha, Mat X, MatStructure) noexcept 1378*4742e46bSJacob Faibussowitsch { 1379*4742e46bSJacob Faibussowitsch const auto m_x = X->rmap->n, m_y = Y->rmap->n; 1380*4742e46bSJacob Faibussowitsch const auto n_x = X->cmap->n, n_y = Y->cmap->n; 1381*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1382*4742e46bSJacob Faibussowitsch cupmBlasHandle_t handle; 1383*4742e46bSJacob Faibussowitsch 1384*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1385*4742e46bSJacob Faibussowitsch if (!m_x || !n_x) PetscFunctionReturn(PETSC_SUCCESS); 1386*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(Y, "Performing AXPY %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", m_y, n_y)); 1387*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle)); 1388*4742e46bSJacob Faibussowitsch { 1389*4742e46bSJacob Faibussowitsch const auto N = m_x * n_x; 1390*4742e46bSJacob Faibussowitsch const auto dx = DeviceArrayRead(dctx, X); 1391*4742e46bSJacob Faibussowitsch const auto dy = alpha == 0.0 ? DeviceArrayWrite(dctx, Y).cupmdata() : DeviceArrayReadWrite(dctx, Y).cupmdata(); 1392*4742e46bSJacob Faibussowitsch const auto ldax = static_cast<cupmBlasInt_t>(MatIMPLCast(X)->lda); 1393*4742e46bSJacob Faibussowitsch const auto lday = static_cast<cupmBlasInt_t>(MatIMPLCast(Y)->lda); 1394*4742e46bSJacob Faibussowitsch const auto cu_alpha = cupmScalarCast(alpha); 1395*4742e46bSJacob Faibussowitsch 1396*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 1397*4742e46bSJacob Faibussowitsch if (ldax > m_x || lday > m_x) { 1398*4742e46bSJacob Faibussowitsch for (cupmBlasInt_t j = 0; j < n_x; j++) PetscCallCUPMBLAS(cupmBlasXaxpy(handle, m_x, &cu_alpha, dx.cupmdata() + j * ldax, 1, dy + j * lday, 1)); 1399*4742e46bSJacob Faibussowitsch } else { 1400*4742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXaxpy(handle, N, &cu_alpha, dx.cupmdata(), 1, dy, 1)); 1401*4742e46bSJacob Faibussowitsch } 1402*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 1403*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(PetscMax(2 * N - 1, 0))); 1404*4742e46bSJacob Faibussowitsch } 1405*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1406*4742e46bSJacob Faibussowitsch } 1407*4742e46bSJacob Faibussowitsch 1408*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1409*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Duplicate(Mat A, MatDuplicateOption opt, Mat *B) noexcept 1410*4742e46bSJacob Faibussowitsch { 1411*4742e46bSJacob Faibussowitsch const auto hopt = (opt == MAT_COPY_VALUES && A->offloadmask != PETSC_OFFLOAD_CPU) ? MAT_DO_NOT_COPY_VALUES : opt; 1412*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1413*4742e46bSJacob Faibussowitsch 1414*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1415*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1416*4742e46bSJacob Faibussowitsch // do not call SetPreallocation() yet, we call it afterwards?? 1417*4742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), A->rmap->n, A->cmap->n, nullptr, B, dctx, /* preallocate */ false)); 1418*4742e46bSJacob Faibussowitsch PetscCall(MatDuplicateNoCreate_SeqDense(*B, A, hopt)); 1419*4742e46bSJacob Faibussowitsch if (opt == MAT_COPY_VALUES && hopt != MAT_COPY_VALUES) PetscCall(Copy(A, *B, SAME_NONZERO_PATTERN)); 1420*4742e46bSJacob Faibussowitsch // allocate memory if needed 1421*4742e46bSJacob Faibussowitsch if (opt != MAT_COPY_VALUES && !MatCUPMCast(*B)->d_v) PetscCall(SetPreallocation(*B, dctx)); 1422*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1423*4742e46bSJacob Faibussowitsch } 1424*4742e46bSJacob Faibussowitsch 1425*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1426*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetRandom(Mat A, PetscRandom rng) noexcept 1427*4742e46bSJacob Faibussowitsch { 1428*4742e46bSJacob Faibussowitsch PetscBool device; 1429*4742e46bSJacob Faibussowitsch 1430*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1431*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(rng), PETSCDEVICERAND(), &device)); 1432*4742e46bSJacob Faibussowitsch if (device) { 1433*4742e46bSJacob Faibussowitsch const auto m = A->rmap->n; 1434*4742e46bSJacob Faibussowitsch const auto n = A->cmap->n; 1435*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1436*4742e46bSJacob Faibussowitsch 1437*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1438*4742e46bSJacob Faibussowitsch { 1439*4742e46bSJacob Faibussowitsch const auto a = DeviceArrayWrite(dctx, A); 1440*4742e46bSJacob Faibussowitsch PetscInt lda; 1441*4742e46bSJacob Faibussowitsch 1442*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(A, &lda)); 1443*4742e46bSJacob Faibussowitsch if (lda > m) { 1444*4742e46bSJacob Faibussowitsch for (PetscInt i = 0; i < n; i++) PetscCall(PetscRandomGetValues(rng, m, a.data() + i * lda)); 1445*4742e46bSJacob Faibussowitsch } else { 1446*4742e46bSJacob Faibussowitsch PetscInt mn; 1447*4742e46bSJacob Faibussowitsch 1448*4742e46bSJacob Faibussowitsch PetscCall(PetscIntMultError(m, n, &mn)); 1449*4742e46bSJacob Faibussowitsch PetscCall(PetscRandomGetValues(rng, mn, a)); 1450*4742e46bSJacob Faibussowitsch } 1451*4742e46bSJacob Faibussowitsch } 1452*4742e46bSJacob Faibussowitsch } else { 1453*4742e46bSJacob Faibussowitsch PetscCall(MatSetRandom_SeqDense(A, rng)); 1454*4742e46bSJacob Faibussowitsch } 1455*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1456*4742e46bSJacob Faibussowitsch } 1457*4742e46bSJacob Faibussowitsch 1458*4742e46bSJacob Faibussowitsch // ========================================================================================== 1459*4742e46bSJacob Faibussowitsch 1460*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1461*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetColumnVector(Mat A, Vec v, PetscInt col) noexcept 1462*4742e46bSJacob Faibussowitsch { 1463*4742e46bSJacob Faibussowitsch const auto offloadmask = A->offloadmask; 1464*4742e46bSJacob Faibussowitsch const auto n = A->rmap->n; 1465*4742e46bSJacob Faibussowitsch const auto col_offset = [&](const PetscScalar *ptr) { return ptr + col * MatIMPLCast(A)->lda; }; 1466*4742e46bSJacob Faibussowitsch PetscBool viscupm; 1467*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1468*4742e46bSJacob Faibussowitsch cupmStream_t stream; 1469*4742e46bSJacob Faibussowitsch 1470*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1471*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(v), &viscupm, VecSeq_CUPM::VECSEQCUPM(), VecSeq_CUPM::VECMPICUPM(), VecSeq_CUPM::VECCUPM(), "")); 1472*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 1473*4742e46bSJacob Faibussowitsch if (viscupm && !v->boundtocpu) { 1474*4742e46bSJacob Faibussowitsch const auto x = VecSeq_CUPM::DeviceArrayWrite(dctx, v); 1475*4742e46bSJacob Faibussowitsch 1476*4742e46bSJacob Faibussowitsch // update device data 1477*4742e46bSJacob Faibussowitsch if (PetscOffloadDevice(offloadmask)) { 1478*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(x.data(), col_offset(DeviceArrayRead(dctx, A)), n, cupmMemcpyDeviceToDevice, stream)); 1479*4742e46bSJacob Faibussowitsch } else { 1480*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(x.data(), col_offset(HostArrayRead(dctx, A)), n, cupmMemcpyHostToDevice, stream)); 1481*4742e46bSJacob Faibussowitsch } 1482*4742e46bSJacob Faibussowitsch } else { 1483*4742e46bSJacob Faibussowitsch PetscScalar *x; 1484*4742e46bSJacob Faibussowitsch 1485*4742e46bSJacob Faibussowitsch // update host data 1486*4742e46bSJacob Faibussowitsch PetscCall(VecGetArrayWrite(v, &x)); 1487*4742e46bSJacob Faibussowitsch if (PetscOffloadUnallocated(offloadmask) || PetscOffloadHost(offloadmask)) { 1488*4742e46bSJacob Faibussowitsch PetscCall(PetscArraycpy(x, col_offset(HostArrayRead(dctx, A)), n)); 1489*4742e46bSJacob Faibussowitsch } else if (PetscOffloadDevice(offloadmask)) { 1490*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(x, col_offset(DeviceArrayRead(dctx, A)), n, cupmMemcpyDeviceToHost, stream)); 1491*4742e46bSJacob Faibussowitsch } 1492*4742e46bSJacob Faibussowitsch PetscCall(VecRestoreArrayWrite(v, &x)); 1493*4742e46bSJacob Faibussowitsch } 1494*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1495*4742e46bSJacob Faibussowitsch } 1496*4742e46bSJacob Faibussowitsch 1497*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1498*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 1499*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept 1500*4742e46bSJacob Faibussowitsch { 1501*4742e46bSJacob Faibussowitsch using namespace vec::cupm; 1502*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 1503*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1504*4742e46bSJacob Faibussowitsch 1505*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1506*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 1507*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 1508*4742e46bSJacob Faibussowitsch mimpl->vecinuse = col + 1; 1509*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1510*4742e46bSJacob Faibussowitsch PetscCall(GetArray<PETSC_MEMTYPE_DEVICE, access>(A, const_cast<PetscScalar **>(&mimpl->ptrinuse), dctx)); 1511*4742e46bSJacob Faibussowitsch if (!mimpl->cvec) { 1512*4742e46bSJacob Faibussowitsch // we pass the data of A, to prevent allocating needless GPU memory the first time 1513*4742e46bSJacob Faibussowitsch // VecCUPMPlaceArray is called 1514*4742e46bSJacob Faibussowitsch PetscCall(VecCreateSeqCUPMWithArraysAsync<T>(PetscObjectComm(PetscObjectCast(A)), A->rmap->bs, A->rmap->n, nullptr, mimpl->ptrinuse, &mimpl->cvec)); 1515*4742e46bSJacob Faibussowitsch } 1516*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(mimpl->lda))); 1517*4742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec)); 1518*4742e46bSJacob Faibussowitsch *v = mimpl->cvec; 1519*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1520*4742e46bSJacob Faibussowitsch } 1521*4742e46bSJacob Faibussowitsch 1522*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1523*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 1524*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept 1525*4742e46bSJacob Faibussowitsch { 1526*4742e46bSJacob Faibussowitsch using namespace vec::cupm; 1527*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 1528*4742e46bSJacob Faibussowitsch const auto cvec = mimpl->cvec; 1529*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1530*4742e46bSJacob Faibussowitsch 1531*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1532*4742e46bSJacob Faibussowitsch PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first"); 1533*4742e46bSJacob Faibussowitsch PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector"); 1534*4742e46bSJacob Faibussowitsch mimpl->vecinuse = 0; 1535*4742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec)); 1536*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMResetArrayAsync<T>(cvec)); 1537*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1538*4742e46bSJacob Faibussowitsch PetscCall(RestoreArray<PETSC_MEMTYPE_DEVICE, access>(A, const_cast<PetscScalar **>(&mimpl->ptrinuse), dctx)); 1539*4742e46bSJacob Faibussowitsch if (v) *v = nullptr; 1540*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1541*4742e46bSJacob Faibussowitsch } 1542*4742e46bSJacob Faibussowitsch 1543*4742e46bSJacob Faibussowitsch // ========================================================================================== 1544*4742e46bSJacob Faibussowitsch 1545*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1546*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetFactor(Mat A, MatFactorType ftype, Mat *fact_out) noexcept 1547*4742e46bSJacob Faibussowitsch { 1548*4742e46bSJacob Faibussowitsch Mat fact; 1549*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1550*4742e46bSJacob Faibussowitsch 1551*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1552*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1553*4742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), A->rmap->n, A->cmap->n, nullptr, &fact, dctx, /* preallocate */ false)); 1554*4742e46bSJacob Faibussowitsch fact->factortype = ftype; 1555*4742e46bSJacob Faibussowitsch switch (ftype) { 1556*4742e46bSJacob Faibussowitsch case MAT_FACTOR_LU: 1557*4742e46bSJacob Faibussowitsch case MAT_FACTOR_ILU: // fall-through 1558*4742e46bSJacob Faibussowitsch fact->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqDense; 1559*4742e46bSJacob Faibussowitsch fact->ops->ilufactorsymbolic = MatLUFactorSymbolic_SeqDense; 1560*4742e46bSJacob Faibussowitsch break; 1561*4742e46bSJacob Faibussowitsch case MAT_FACTOR_CHOLESKY: 1562*4742e46bSJacob Faibussowitsch case MAT_FACTOR_ICC: // fall-through 1563*4742e46bSJacob Faibussowitsch fact->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqDense; 1564*4742e46bSJacob Faibussowitsch break; 1565*4742e46bSJacob Faibussowitsch case MAT_FACTOR_QR: { 1566*4742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(fact); 1567*4742e46bSJacob Faibussowitsch 1568*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction(pobj, "MatQRFactor_C", MatQRFactor_SeqDense)); 1569*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction(pobj, "MatQRFactorSymbolic_C", MatQRFactorSymbolic_SeqDense)); 1570*4742e46bSJacob Faibussowitsch } break; 1571*4742e46bSJacob Faibussowitsch case MAT_FACTOR_NONE: 1572*4742e46bSJacob Faibussowitsch case MAT_FACTOR_ILUDT: // fall-through 1573*4742e46bSJacob Faibussowitsch case MAT_FACTOR_NUM_TYPES: // fall-through 1574*4742e46bSJacob Faibussowitsch SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s not supported", MatFactorTypes[ftype]); 1575*4742e46bSJacob Faibussowitsch } 1576*4742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(MATSOLVERCUPM(), &fact->solvertype)); 1577*4742e46bSJacob Faibussowitsch PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_LU)); 1578*4742e46bSJacob Faibussowitsch PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_ILU)); 1579*4742e46bSJacob Faibussowitsch PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_CHOLESKY)); 1580*4742e46bSJacob Faibussowitsch PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_ICC)); 1581*4742e46bSJacob Faibussowitsch *fact_out = fact; 1582*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1583*4742e46bSJacob Faibussowitsch } 1584*4742e46bSJacob Faibussowitsch 1585*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1586*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::InvertFactors(Mat A) noexcept 1587*4742e46bSJacob Faibussowitsch { 1588*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 1589*4742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 1590*4742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->cmap->n); 1591*4742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 1592*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1593*4742e46bSJacob Faibussowitsch cupmStream_t stream; 1594*4742e46bSJacob Faibussowitsch 1595*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1596*4742e46bSJacob Faibussowitsch #if PetscDefined(HAVE_CUDA) && PetscDefined(USING_NVCC) 1597*4742e46bSJacob Faibussowitsch // HIP appears to have this by default?? 1598*4742e46bSJacob Faibussowitsch PetscCheck(PETSC_PKG_CUDA_VERSION_GE(10, 1, 0), PETSC_COMM_SELF, PETSC_ERR_SUP, "Upgrade to CUDA version 10.1.0 or higher"); 1599*4742e46bSJacob Faibussowitsch #endif 1600*4742e46bSJacob Faibussowitsch if (!n || !A->rmap->n) PetscFunctionReturn(PETSC_SUCCESS); 1601*4742e46bSJacob Faibussowitsch PetscCheck(A->factortype == MAT_FACTOR_CHOLESKY, PETSC_COMM_SELF, PETSC_ERR_LIB, "Factor type %s not implemented", MatFactorTypes[A->factortype]); 1602*4742e46bSJacob Faibussowitsch // spd 1603*4742e46bSJacob Faibussowitsch PetscCheck(!mcu->d_fact_ipiv, PETSC_COMM_SELF, PETSC_ERR_LIB, "%sDnsytri not implemented", cupmSolverName()); 1604*4742e46bSJacob Faibussowitsch 1605*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle, &stream)); 1606*4742e46bSJacob Faibussowitsch { 1607*4742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 1608*4742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(mimpl->lda); 1609*4742e46bSJacob Faibussowitsch cupmBlasInt_t il; 1610*4742e46bSJacob Faibussowitsch 1611*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXpotri_bufferSize(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, &il)); 1612*4742e46bSJacob Faibussowitsch if (il > mcu->d_fact_lwork) { 1613*4742e46bSJacob Faibussowitsch mcu->d_fact_lwork = il; 1614*4742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream)); 1615*4742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, il, stream)); 1616*4742e46bSJacob Faibussowitsch } 1617*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 1618*4742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXpotri(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info)); 1619*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 1620*4742e46bSJacob Faibussowitsch } 1621*4742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream)); 1622*4742e46bSJacob Faibussowitsch // TODO (write cuda kernel) 1623*4742e46bSJacob Faibussowitsch PetscCall(MatSeqDenseSymmetrize_Private(A, PETSC_TRUE)); 1624*4742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(1.0 * n * n * n / 3.0)); 1625*4742e46bSJacob Faibussowitsch 1626*4742e46bSJacob Faibussowitsch A->ops->solve = nullptr; 1627*4742e46bSJacob Faibussowitsch A->ops->solvetranspose = nullptr; 1628*4742e46bSJacob Faibussowitsch A->ops->matsolve = nullptr; 1629*4742e46bSJacob Faibussowitsch A->factortype = MAT_FACTOR_NONE; 1630*4742e46bSJacob Faibussowitsch 1631*4742e46bSJacob Faibussowitsch PetscCall(PetscFree(A->solvertype)); 1632*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1633*4742e46bSJacob Faibussowitsch } 1634*4742e46bSJacob Faibussowitsch 1635*4742e46bSJacob Faibussowitsch // ========================================================================================== 1636*4742e46bSJacob Faibussowitsch 1637*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1638*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetSubMatrix(Mat A, PetscInt rbegin, PetscInt rend, PetscInt cbegin, PetscInt cend, Mat *mat) noexcept 1639*4742e46bSJacob Faibussowitsch { 1640*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 1641*4742e46bSJacob Faibussowitsch const auto array_offset = [&](PetscScalar *ptr) { return ptr + rbegin + static_cast<std::size_t>(cbegin) * mimpl->lda; }; 1642*4742e46bSJacob Faibussowitsch const auto n = rend - rbegin; 1643*4742e46bSJacob Faibussowitsch const auto m = cend - cbegin; 1644*4742e46bSJacob Faibussowitsch auto &cmat = mimpl->cmat; 1645*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1646*4742e46bSJacob Faibussowitsch 1647*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1648*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 1649*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 1650*4742e46bSJacob Faibussowitsch mimpl->matinuse = cbegin + 1; 1651*4742e46bSJacob Faibussowitsch 1652*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1653*4742e46bSJacob Faibussowitsch PetscCall(HostToDevice_(A, dctx)); 1654*4742e46bSJacob Faibussowitsch 1655*4742e46bSJacob Faibussowitsch if (cmat && ((m != cmat->cmap->N) || (n != cmat->rmap->N))) PetscCall(MatDestroy(&cmat)); 1656*4742e46bSJacob Faibussowitsch { 1657*4742e46bSJacob Faibussowitsch const auto device_array = array_offset(MatCUPMCast(A)->d_v); 1658*4742e46bSJacob Faibussowitsch 1659*4742e46bSJacob Faibussowitsch if (cmat) { 1660*4742e46bSJacob Faibussowitsch PetscCall(PlaceArray(cmat, device_array)); 1661*4742e46bSJacob Faibussowitsch } else { 1662*4742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), n, m, device_array, &cmat, dctx)); 1663*4742e46bSJacob Faibussowitsch } 1664*4742e46bSJacob Faibussowitsch } 1665*4742e46bSJacob Faibussowitsch PetscCall(MatDenseSetLDA(cmat, mimpl->lda)); 1666*4742e46bSJacob Faibussowitsch // place CPU array if present but do not copy any data 1667*4742e46bSJacob Faibussowitsch if (const auto host_array = mimpl->v) { 1668*4742e46bSJacob Faibussowitsch cmat->offloadmask = PETSC_OFFLOAD_GPU; 1669*4742e46bSJacob Faibussowitsch PetscCall(MatDensePlaceArray(cmat, array_offset(host_array))); 1670*4742e46bSJacob Faibussowitsch } 1671*4742e46bSJacob Faibussowitsch 1672*4742e46bSJacob Faibussowitsch cmat->offloadmask = A->offloadmask; 1673*4742e46bSJacob Faibussowitsch *mat = cmat; 1674*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1675*4742e46bSJacob Faibussowitsch } 1676*4742e46bSJacob Faibussowitsch 1677*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1678*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreSubMatrix(Mat A, Mat *m) noexcept 1679*4742e46bSJacob Faibussowitsch { 1680*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 1681*4742e46bSJacob Faibussowitsch const auto cmat = mimpl->cmat; 1682*4742e46bSJacob Faibussowitsch const auto reset = static_cast<bool>(mimpl->v); 1683*4742e46bSJacob Faibussowitsch bool copy, was_offload_host; 1684*4742e46bSJacob Faibussowitsch 1685*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1686*4742e46bSJacob Faibussowitsch PetscCheck(mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetSubMatrix() first"); 1687*4742e46bSJacob Faibussowitsch PetscCheck(cmat, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column matrix"); 1688*4742e46bSJacob Faibussowitsch PetscCheck(*m == cmat, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Not the matrix obtained from MatDenseGetSubMatrix()"); 1689*4742e46bSJacob Faibussowitsch mimpl->matinuse = 0; 1690*4742e46bSJacob Faibussowitsch 1691*4742e46bSJacob Faibussowitsch // calls to ResetArray may change it, so save it here 1692*4742e46bSJacob Faibussowitsch was_offload_host = cmat->offloadmask == PETSC_OFFLOAD_CPU; 1693*4742e46bSJacob Faibussowitsch if (was_offload_host && !reset) { 1694*4742e46bSJacob Faibussowitsch copy = true; 1695*4742e46bSJacob Faibussowitsch PetscCall(MatSeqDenseSetPreallocation(A, nullptr)); 1696*4742e46bSJacob Faibussowitsch } else { 1697*4742e46bSJacob Faibussowitsch copy = false; 1698*4742e46bSJacob Faibussowitsch } 1699*4742e46bSJacob Faibussowitsch 1700*4742e46bSJacob Faibussowitsch PetscCall(ResetArray(cmat)); 1701*4742e46bSJacob Faibussowitsch if (reset) PetscCall(MatDenseResetArray(cmat)); 1702*4742e46bSJacob Faibussowitsch if (copy) { 1703*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1704*4742e46bSJacob Faibussowitsch 1705*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1706*4742e46bSJacob Faibussowitsch PetscCall(DeviceToHost_(A, dctx)); 1707*4742e46bSJacob Faibussowitsch } else { 1708*4742e46bSJacob Faibussowitsch A->offloadmask = was_offload_host ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU; 1709*4742e46bSJacob Faibussowitsch } 1710*4742e46bSJacob Faibussowitsch 1711*4742e46bSJacob Faibussowitsch cmat->offloadmask = PETSC_OFFLOAD_UNALLOCATED; 1712*4742e46bSJacob Faibussowitsch *m = nullptr; 1713*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1714*4742e46bSJacob Faibussowitsch } 1715*4742e46bSJacob Faibussowitsch 1716*4742e46bSJacob Faibussowitsch // ========================================================================================== 1717*4742e46bSJacob Faibussowitsch 1718*4742e46bSJacob Faibussowitsch namespace 1719*4742e46bSJacob Faibussowitsch { 1720*4742e46bSJacob Faibussowitsch 1721*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1722*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatMatMultNumeric_SeqDenseCUPM_SeqDenseCUPM(Mat A, Mat B, Mat C, PetscBool TA, PetscBool TB) noexcept 1723*4742e46bSJacob Faibussowitsch { 1724*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1725*4742e46bSJacob Faibussowitsch if (TA) { 1726*4742e46bSJacob Faibussowitsch if (TB) { 1727*4742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<true, true>(A, B, C)); 1728*4742e46bSJacob Faibussowitsch } else { 1729*4742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<true, false>(A, B, C)); 1730*4742e46bSJacob Faibussowitsch } 1731*4742e46bSJacob Faibussowitsch } else { 1732*4742e46bSJacob Faibussowitsch if (TB) { 1733*4742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<false, true>(A, B, C)); 1734*4742e46bSJacob Faibussowitsch } else { 1735*4742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<false, false>(A, B, C)); 1736*4742e46bSJacob Faibussowitsch } 1737*4742e46bSJacob Faibussowitsch } 1738*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1739*4742e46bSJacob Faibussowitsch } 1740*4742e46bSJacob Faibussowitsch 1741*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1742*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatSolverTypeRegister_DENSECUPM() noexcept 1743*4742e46bSJacob Faibussowitsch { 1744*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 1745*4742e46bSJacob Faibussowitsch for (auto ftype : util::make_array(MAT_FACTOR_LU, MAT_FACTOR_CHOLESKY, MAT_FACTOR_QR)) { 1746*4742e46bSJacob Faibussowitsch PetscCall(MatSolverTypeRegister(MatDense_Seq_CUPM<T>::MATSOLVERCUPM(), MATSEQDENSE, ftype, MatDense_Seq_CUPM<T>::GetFactor)); 1747*4742e46bSJacob Faibussowitsch PetscCall(MatSolverTypeRegister(MatDense_Seq_CUPM<T>::MATSOLVERCUPM(), MatDense_Seq_CUPM<T>::MATSEQDENSECUPM(), ftype, MatDense_Seq_CUPM<T>::GetFactor)); 1748*4742e46bSJacob Faibussowitsch } 1749*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1750*4742e46bSJacob Faibussowitsch } 1751*4742e46bSJacob Faibussowitsch 1752*4742e46bSJacob Faibussowitsch } // anonymous namespace 1753*4742e46bSJacob Faibussowitsch 1754*4742e46bSJacob Faibussowitsch } // namespace impl 1755*4742e46bSJacob Faibussowitsch 1756*4742e46bSJacob Faibussowitsch } // namespace cupm 1757*4742e46bSJacob Faibussowitsch 1758*4742e46bSJacob Faibussowitsch } // namespace mat 1759*4742e46bSJacob Faibussowitsch 1760*4742e46bSJacob Faibussowitsch } // namespace Petsc 1761*4742e46bSJacob Faibussowitsch 1762*4742e46bSJacob Faibussowitsch #endif // __cplusplus 1763*4742e46bSJacob Faibussowitsch 1764*4742e46bSJacob Faibussowitsch #endif // PETSCMATSEQDENSECUPM_HPP 1765