1a4963045SJacob Faibussowitsch #pragma once 24742e46bSJacob Faibussowitsch 34742e46bSJacob Faibussowitsch #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/ 44742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/seq/dense.h> 54742e46bSJacob Faibussowitsch 64742e46bSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> // PetscDeviceContextGetOptionalNullContext_Internal() 74742e46bSJacob Faibussowitsch #include <petsc/private/randomimpl.h> // _p_PetscRandom 84742e46bSJacob Faibussowitsch #include <petsc/private/vecimpl.h> // _p_Vec 94742e46bSJacob Faibussowitsch #include <petsc/private/cupmobject.hpp> 104742e46bSJacob Faibussowitsch #include <petsc/private/cupmsolverinterface.hpp> 114742e46bSJacob Faibussowitsch 124742e46bSJacob Faibussowitsch #include <petsc/private/cpp/type_traits.hpp> // PetscObjectCast() 134742e46bSJacob Faibussowitsch #include <petsc/private/cpp/utility.hpp> // util::exchange() 144742e46bSJacob Faibussowitsch 154742e46bSJacob Faibussowitsch #include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp> // for VecSeq_CUPM 164742e46bSJacob Faibussowitsch 174742e46bSJacob Faibussowitsch namespace Petsc 184742e46bSJacob Faibussowitsch { 194742e46bSJacob Faibussowitsch 204742e46bSJacob Faibussowitsch namespace mat 214742e46bSJacob Faibussowitsch { 224742e46bSJacob Faibussowitsch 234742e46bSJacob Faibussowitsch namespace cupm 244742e46bSJacob Faibussowitsch { 254742e46bSJacob Faibussowitsch 264742e46bSJacob Faibussowitsch namespace impl 274742e46bSJacob Faibussowitsch { 284742e46bSJacob Faibussowitsch 294742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3085f25e71SJed Brown class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL MatDense_Seq_CUPM : MatDense_CUPM<T, MatDense_Seq_CUPM<T>> { 314742e46bSJacob Faibussowitsch public: 324742e46bSJacob Faibussowitsch MATDENSECUPM_HEADER(T, MatDense_Seq_CUPM<T>); 334742e46bSJacob Faibussowitsch 344742e46bSJacob Faibussowitsch private: 354742e46bSJacob Faibussowitsch struct Mat_SeqDenseCUPM { 364742e46bSJacob Faibussowitsch PetscScalar *d_v; // pointer to the matrix on the GPU 374742e46bSJacob Faibussowitsch PetscScalar *unplacedarray; // if one called MatCUPMDensePlaceArray(), this is where it stashed the original 384742e46bSJacob Faibussowitsch bool d_user_alloc; 394742e46bSJacob Faibussowitsch bool d_unplaced_user_alloc; 404742e46bSJacob Faibussowitsch // factorization support 414742e46bSJacob Faibussowitsch cupmBlasInt_t *d_fact_ipiv; // device pivots 424742e46bSJacob Faibussowitsch cupmScalar_t *d_fact_tau; // device QR tau vector 434742e46bSJacob Faibussowitsch cupmBlasInt_t *d_fact_info; // device info 444742e46bSJacob Faibussowitsch cupmScalar_t *d_fact_work; // device workspace 454742e46bSJacob Faibussowitsch cupmBlasInt_t d_fact_lwork; // size of device workspace 464742e46bSJacob Faibussowitsch // workspace 474742e46bSJacob Faibussowitsch Vec workvec; 484742e46bSJacob Faibussowitsch }; 494742e46bSJacob Faibussowitsch 504742e46bSJacob Faibussowitsch static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept; 514742e46bSJacob Faibussowitsch 524742e46bSJacob Faibussowitsch static PetscErrorCode HostToDevice_(Mat, PetscDeviceContext) noexcept; 534742e46bSJacob Faibussowitsch static PetscErrorCode DeviceToHost_(Mat, PetscDeviceContext) noexcept; 544742e46bSJacob Faibussowitsch 554742e46bSJacob Faibussowitsch static PetscErrorCode CheckCUPMSolverInfo_(const cupmBlasInt_t *, cupmStream_t) noexcept; 564742e46bSJacob Faibussowitsch 574742e46bSJacob Faibussowitsch template <typename Derived> 584742e46bSJacob Faibussowitsch struct SolveCommon; 594742e46bSJacob Faibussowitsch struct SolveQR; 604742e46bSJacob Faibussowitsch struct SolveCholesky; 614742e46bSJacob Faibussowitsch struct SolveLU; 624742e46bSJacob Faibussowitsch 634742e46bSJacob Faibussowitsch template <typename Solver, bool transpose> 644742e46bSJacob Faibussowitsch static PetscErrorCode MatSolve_Factored_Dispatch_(Mat, Vec, Vec) noexcept; 654742e46bSJacob Faibussowitsch template <typename Solver, bool transpose> 664742e46bSJacob Faibussowitsch static PetscErrorCode MatMatSolve_Factored_Dispatch_(Mat, Mat, Mat) noexcept; 6795571869SBlanca Mellado Pinto template <bool transpose, bool hermitian> 680be0d8bdSHansol Suh static PetscErrorCode MatMultAddColumnRange_Dispatch_(Mat, Vec, Vec, Vec, PetscInt, PetscInt) noexcept; 690be0d8bdSHansol Suh template <bool transpose, bool hermitian> 700be0d8bdSHansol Suh static PetscErrorCode MatMultColumnRange_Dispatch_(Mat, Vec, Vec, PetscInt, PetscInt) noexcept; 710be0d8bdSHansol Suh template <bool transpose, bool hermitian> 724742e46bSJacob Faibussowitsch static PetscErrorCode MatMultAdd_Dispatch_(Mat, Vec, Vec, Vec) noexcept; 734742e46bSJacob Faibussowitsch 744742e46bSJacob Faibussowitsch template <bool to_host> 754742e46bSJacob Faibussowitsch static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept; 764742e46bSJacob Faibussowitsch 774742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr MatType MATIMPLCUPM_() noexcept; 784742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr Mat_SeqDense *MatIMPLCast_(Mat) noexcept; 794742e46bSJacob Faibussowitsch 804742e46bSJacob Faibussowitsch public: 814742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr Mat_SeqDenseCUPM *MatCUPMCast(Mat) noexcept; 824742e46bSJacob Faibussowitsch 834742e46bSJacob Faibussowitsch // define these by hand since they don't fit the above mold 844742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatConvert_seqdensecupm_seqdense_C() noexcept; 854742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_seqaij_seqdensecupm_C() noexcept; 864742e46bSJacob Faibussowitsch 874742e46bSJacob Faibussowitsch static PetscErrorCode Create(Mat) noexcept; 884742e46bSJacob Faibussowitsch static PetscErrorCode Destroy(Mat) noexcept; 894742e46bSJacob Faibussowitsch static PetscErrorCode SetUp(Mat) noexcept; 904742e46bSJacob Faibussowitsch static PetscErrorCode Reset(Mat) noexcept; 914742e46bSJacob Faibussowitsch 924742e46bSJacob Faibussowitsch static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept; 934742e46bSJacob Faibussowitsch static PetscErrorCode Convert_SeqDense_SeqDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept; 944742e46bSJacob Faibussowitsch static PetscErrorCode Convert_SeqDenseCUPM_SeqDense(Mat, MatType, MatReuse, Mat *) noexcept; 954742e46bSJacob Faibussowitsch 964742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode> 974742e46bSJacob Faibussowitsch static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext) noexcept; 984742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode> 994742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext) noexcept; 1004742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 1014742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayAndMemType(Mat, PetscScalar **, PetscMemType *, PetscDeviceContext) noexcept; 1024742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 1034742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayAndMemType(Mat, PetscScalar **, PetscDeviceContext) noexcept; 1044742e46bSJacob Faibussowitsch 1054742e46bSJacob Faibussowitsch private: 1064742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode> 1074742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept 1084742e46bSJacob Faibussowitsch { 1094742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1104742e46bSJacob Faibussowitsch 1114742e46bSJacob Faibussowitsch PetscFunctionBegin; 1124742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1134742e46bSJacob Faibussowitsch PetscCall(GetArray<mtype, mode>(m, p, dctx)); 1144742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1154742e46bSJacob Faibussowitsch } 1164742e46bSJacob Faibussowitsch 1174742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode> 1184742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept 1194742e46bSJacob Faibussowitsch { 1204742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1214742e46bSJacob Faibussowitsch 1224742e46bSJacob Faibussowitsch PetscFunctionBegin; 1234742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1244742e46bSJacob Faibussowitsch PetscCall(RestoreArray<mtype, mode>(m, p, dctx)); 1254742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1264742e46bSJacob Faibussowitsch } 1274742e46bSJacob Faibussowitsch 1284742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode mode> 1294742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayAndMemTypeC_(Mat m, PetscScalar **p, PetscMemType *tp) noexcept 1304742e46bSJacob Faibussowitsch { 1314742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1324742e46bSJacob Faibussowitsch 1334742e46bSJacob Faibussowitsch PetscFunctionBegin; 1344742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1354742e46bSJacob Faibussowitsch PetscCall(GetArrayAndMemType<mode>(m, p, tp, dctx)); 1364742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1374742e46bSJacob Faibussowitsch } 1384742e46bSJacob Faibussowitsch 1394742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode mode> 1404742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayAndMemTypeC_(Mat m, PetscScalar **p) noexcept 1414742e46bSJacob Faibussowitsch { 1424742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 1434742e46bSJacob Faibussowitsch 1444742e46bSJacob Faibussowitsch PetscFunctionBegin; 1454742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 1464742e46bSJacob Faibussowitsch PetscCall(RestoreArrayAndMemType<mode>(m, p, dctx)); 1474742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1484742e46bSJacob Faibussowitsch } 1494742e46bSJacob Faibussowitsch 1504742e46bSJacob Faibussowitsch public: 1514742e46bSJacob Faibussowitsch static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept; 1524742e46bSJacob Faibussowitsch static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept; 1534742e46bSJacob Faibussowitsch static PetscErrorCode ResetArray(Mat) noexcept; 1544742e46bSJacob Faibussowitsch 1554742e46bSJacob Faibussowitsch template <bool transpose_A, bool transpose_B> 1564742e46bSJacob Faibussowitsch static PetscErrorCode MatMatMult_Numeric_Dispatch(Mat, Mat, Mat) noexcept; 1574742e46bSJacob Faibussowitsch static PetscErrorCode Copy(Mat, Mat, MatStructure) noexcept; 1584742e46bSJacob Faibussowitsch static PetscErrorCode ZeroEntries(Mat) noexcept; 1593853def2SToby Isaac static PetscErrorCode Conjugate(Mat) noexcept; 1604742e46bSJacob Faibussowitsch static PetscErrorCode Scale(Mat, PetscScalar) noexcept; 1614742e46bSJacob Faibussowitsch static PetscErrorCode AXPY(Mat, PetscScalar, Mat, MatStructure) noexcept; 1624742e46bSJacob Faibussowitsch static PetscErrorCode Duplicate(Mat, MatDuplicateOption, Mat *) noexcept; 1634742e46bSJacob Faibussowitsch static PetscErrorCode SetRandom(Mat, PetscRandom) noexcept; 1644742e46bSJacob Faibussowitsch 1654742e46bSJacob Faibussowitsch static PetscErrorCode GetColumnVector(Mat, Vec, PetscInt) noexcept; 1664742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 1674742e46bSJacob Faibussowitsch static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept; 1684742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 1694742e46bSJacob Faibussowitsch static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept; 1704742e46bSJacob Faibussowitsch 1714742e46bSJacob Faibussowitsch static PetscErrorCode GetFactor(Mat, MatFactorType, Mat *) noexcept; 1724742e46bSJacob Faibussowitsch static PetscErrorCode InvertFactors(Mat) noexcept; 1734742e46bSJacob Faibussowitsch 1744742e46bSJacob Faibussowitsch static PetscErrorCode GetSubMatrix(Mat, PetscInt, PetscInt, PetscInt, PetscInt, Mat *) noexcept; 1754742e46bSJacob Faibussowitsch static PetscErrorCode RestoreSubMatrix(Mat, Mat *) noexcept; 1764742e46bSJacob Faibussowitsch }; 1774742e46bSJacob Faibussowitsch 1784742e46bSJacob Faibussowitsch } // namespace impl 1794742e46bSJacob Faibussowitsch 1804742e46bSJacob Faibussowitsch namespace 1814742e46bSJacob Faibussowitsch { 1824742e46bSJacob Faibussowitsch 1834742e46bSJacob Faibussowitsch // Declare this here so that the functions below can make use of it 1844742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1854742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateSeqDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept 1864742e46bSJacob Faibussowitsch { 1874742e46bSJacob Faibussowitsch PetscFunctionBegin; 1884742e46bSJacob Faibussowitsch PetscCall(impl::MatDense_Seq_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, m, n, data, A, dctx, preallocate)); 1894742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1904742e46bSJacob Faibussowitsch } 1914742e46bSJacob Faibussowitsch 1924742e46bSJacob Faibussowitsch } // anonymous namespace 1934742e46bSJacob Faibussowitsch 1944742e46bSJacob Faibussowitsch namespace impl 1954742e46bSJacob Faibussowitsch { 1964742e46bSJacob Faibussowitsch 1974742e46bSJacob Faibussowitsch // ========================================================================================== 1984742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Utility 1994742e46bSJacob Faibussowitsch // ========================================================================================== 2004742e46bSJacob Faibussowitsch 2014742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2024742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetPreallocation_(Mat m, PetscDeviceContext dctx, PetscScalar *user_device_array) noexcept 2034742e46bSJacob Faibussowitsch { 2044742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(m); 2054742e46bSJacob Faibussowitsch const auto nrows = m->rmap->n; 2064742e46bSJacob Faibussowitsch const auto ncols = m->cmap->n; 2074742e46bSJacob Faibussowitsch auto &lda = MatIMPLCast(m)->lda; 2084742e46bSJacob Faibussowitsch cupmStream_t stream; 2094742e46bSJacob Faibussowitsch 2104742e46bSJacob Faibussowitsch PetscFunctionBegin; 2114742e46bSJacob Faibussowitsch PetscCheckTypeName(m, MATSEQDENSECUPM()); 2124742e46bSJacob Faibussowitsch PetscValidDeviceContext(dctx, 2); 2134742e46bSJacob Faibussowitsch PetscCall(checkCupmBlasIntCast(nrows)); 2144742e46bSJacob Faibussowitsch PetscCall(checkCupmBlasIntCast(ncols)); 2154742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &stream)); 2164742e46bSJacob Faibussowitsch if (lda <= 0) lda = nrows; 2174742e46bSJacob Faibussowitsch if (!mcu->d_user_alloc) PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream)); 2184742e46bSJacob Faibussowitsch if (user_device_array) { 2194742e46bSJacob Faibussowitsch mcu->d_user_alloc = PETSC_TRUE; 2204742e46bSJacob Faibussowitsch mcu->d_v = user_device_array; 2214742e46bSJacob Faibussowitsch } else { 22216130775SJose E. Roman std::size_t size; 2234742e46bSJacob Faibussowitsch 2244742e46bSJacob Faibussowitsch mcu->d_user_alloc = PETSC_FALSE; 22516130775SJose E. Roman size = lda * ncols; 2264742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&mcu->d_v, size, stream)); 2274742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemsetAsync(mcu->d_v, 0, size, stream)); 2284742e46bSJacob Faibussowitsch } 2294742e46bSJacob Faibussowitsch m->offloadmask = PETSC_OFFLOAD_GPU; 2304742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 2314742e46bSJacob Faibussowitsch } 2324742e46bSJacob Faibussowitsch 2334742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2344742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::HostToDevice_(Mat m, PetscDeviceContext dctx) noexcept 2354742e46bSJacob Faibussowitsch { 2364742e46bSJacob Faibussowitsch const auto nrows = m->rmap->n; 2374742e46bSJacob Faibussowitsch const auto ncols = m->cmap->n; 2384742e46bSJacob Faibussowitsch const auto copy = m->offloadmask == PETSC_OFFLOAD_CPU || m->offloadmask == PETSC_OFFLOAD_UNALLOCATED; 2394742e46bSJacob Faibussowitsch 2404742e46bSJacob Faibussowitsch PetscFunctionBegin; 2414742e46bSJacob Faibussowitsch PetscCheckTypeName(m, MATSEQDENSECUPM()); 2424742e46bSJacob Faibussowitsch if (m->boundtocpu) PetscFunctionReturn(PETSC_SUCCESS); 2434742e46bSJacob Faibussowitsch PetscCall(PetscInfo(m, "%s matrix %" PetscInt_FMT " x %" PetscInt_FMT "\n", copy ? "Copy" : "Reusing", nrows, ncols)); 2444742e46bSJacob Faibussowitsch if (copy) { 2454742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(m); 2464742e46bSJacob Faibussowitsch cupmStream_t stream; 2474742e46bSJacob Faibussowitsch 2484742e46bSJacob Faibussowitsch // Allocate GPU memory if not present 2493d9668e3SJacob Faibussowitsch if (!mcu->d_v) PetscCall(SetPreallocation(m, dctx, nullptr)); 2504742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &stream)); 2514742e46bSJacob Faibussowitsch PetscCall(PetscLogEventBegin(MAT_DenseCopyToGPU, m, 0, 0, 0)); 2524742e46bSJacob Faibussowitsch { 2534742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(m); 2544742e46bSJacob Faibussowitsch const auto lda = mimpl->lda; 2554742e46bSJacob Faibussowitsch const auto src = mimpl->v; 2564742e46bSJacob Faibussowitsch const auto dest = mcu->d_v; 2574742e46bSJacob Faibussowitsch 2584742e46bSJacob Faibussowitsch if (lda > nrows) { 2594742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(dest, lda, src, lda, nrows, ncols, cupmMemcpyHostToDevice, stream)); 2604742e46bSJacob Faibussowitsch } else { 2614742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(dest, src, lda * ncols, cupmMemcpyHostToDevice, stream)); 2624742e46bSJacob Faibussowitsch } 2634742e46bSJacob Faibussowitsch } 2644742e46bSJacob Faibussowitsch PetscCall(PetscLogEventEnd(MAT_DenseCopyToGPU, m, 0, 0, 0)); 2654742e46bSJacob Faibussowitsch // order important, ensure that offloadmask is PETSC_OFFLOAD_BOTH 2664742e46bSJacob Faibussowitsch m->offloadmask = PETSC_OFFLOAD_BOTH; 2674742e46bSJacob Faibussowitsch } 2684742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 2694742e46bSJacob Faibussowitsch } 2704742e46bSJacob Faibussowitsch 2714742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2724742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::DeviceToHost_(Mat m, PetscDeviceContext dctx) noexcept 2734742e46bSJacob Faibussowitsch { 2744742e46bSJacob Faibussowitsch const auto nrows = m->rmap->n; 2754742e46bSJacob Faibussowitsch const auto ncols = m->cmap->n; 2764742e46bSJacob Faibussowitsch const auto copy = m->offloadmask == PETSC_OFFLOAD_GPU; 2774742e46bSJacob Faibussowitsch 2784742e46bSJacob Faibussowitsch PetscFunctionBegin; 2794742e46bSJacob Faibussowitsch PetscCheckTypeName(m, MATSEQDENSECUPM()); 2804742e46bSJacob Faibussowitsch PetscCall(PetscInfo(m, "%s matrix %" PetscInt_FMT " x %" PetscInt_FMT "\n", copy ? "Copy" : "Reusing", nrows, ncols)); 2814742e46bSJacob Faibussowitsch if (copy) { 2824742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(m); 2834742e46bSJacob Faibussowitsch cupmStream_t stream; 2844742e46bSJacob Faibussowitsch 2854742e46bSJacob Faibussowitsch // MatCreateSeqDenseCUPM may not allocate CPU memory. Allocate if needed 2864742e46bSJacob Faibussowitsch if (!mimpl->v) PetscCall(MatSeqDenseSetPreallocation(m, nullptr)); 2874742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &stream)); 2884742e46bSJacob Faibussowitsch PetscCall(PetscLogEventBegin(MAT_DenseCopyFromGPU, m, 0, 0, 0)); 2894742e46bSJacob Faibussowitsch { 2904742e46bSJacob Faibussowitsch const auto lda = mimpl->lda; 2914742e46bSJacob Faibussowitsch const auto dest = mimpl->v; 2924742e46bSJacob Faibussowitsch const auto src = MatCUPMCast(m)->d_v; 2934742e46bSJacob Faibussowitsch 2944742e46bSJacob Faibussowitsch if (lda > nrows) { 2954742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(dest, lda, src, lda, nrows, ncols, cupmMemcpyDeviceToHost, stream)); 2964742e46bSJacob Faibussowitsch } else { 2974742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(dest, src, lda * ncols, cupmMemcpyDeviceToHost, stream)); 2984742e46bSJacob Faibussowitsch } 2994742e46bSJacob Faibussowitsch } 3004742e46bSJacob Faibussowitsch PetscCall(PetscLogEventEnd(MAT_DenseCopyFromGPU, m, 0, 0, 0)); 3014742e46bSJacob Faibussowitsch // order is important, MatSeqDenseSetPreallocation() might set offloadmask 3024742e46bSJacob Faibussowitsch m->offloadmask = PETSC_OFFLOAD_BOTH; 3034742e46bSJacob Faibussowitsch } 3044742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3054742e46bSJacob Faibussowitsch } 3064742e46bSJacob Faibussowitsch 3074742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3084742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::CheckCUPMSolverInfo_(const cupmBlasInt_t *fact_info, cupmStream_t stream) noexcept 3094742e46bSJacob Faibussowitsch { 3104742e46bSJacob Faibussowitsch PetscFunctionBegin; 3114742e46bSJacob Faibussowitsch if (PetscDefined(USE_DEBUG)) { 3124742e46bSJacob Faibussowitsch cupmBlasInt_t info = 0; 3134742e46bSJacob Faibussowitsch 3144742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(&info, fact_info, 1, cupmMemcpyDeviceToHost, stream)); 3154742e46bSJacob Faibussowitsch if (stream) PetscCallCUPM(cupmStreamSynchronize(stream)); 3164742e46bSJacob Faibussowitsch static_assert(std::is_same<decltype(info), int>::value, ""); 3174742e46bSJacob Faibussowitsch PetscCheck(info <= 0, PETSC_COMM_SELF, PETSC_ERR_MAT_CH_ZRPVT, "Bad factorization: zero pivot in row %d", info - 1); 3184742e46bSJacob Faibussowitsch PetscCheck(info >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Wrong argument to cupmSolver %d", -info); 3194742e46bSJacob Faibussowitsch } 3204742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3214742e46bSJacob Faibussowitsch } 3224742e46bSJacob Faibussowitsch 3234742e46bSJacob Faibussowitsch // ========================================================================================== 3244742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Solver Dispatch 3254742e46bSJacob Faibussowitsch // ========================================================================================== 3264742e46bSJacob Faibussowitsch 3274742e46bSJacob Faibussowitsch // specific solvers called through the dispatch_() family of functions 3284742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3294742e46bSJacob Faibussowitsch template <typename Derived> 3304742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveCommon { 3314742e46bSJacob Faibussowitsch using derived_type = Derived; 3324742e46bSJacob Faibussowitsch 3334742e46bSJacob Faibussowitsch template <typename F> 3344742e46bSJacob Faibussowitsch static PetscErrorCode ResizeFactLwork(Mat_SeqDenseCUPM *mcu, cupmStream_t stream, F &&cupmSolverComputeFactLwork) noexcept 3354742e46bSJacob Faibussowitsch { 3364742e46bSJacob Faibussowitsch cupmBlasInt_t lwork; 3374742e46bSJacob Faibussowitsch 3384742e46bSJacob Faibussowitsch PetscFunctionBegin; 3394742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverComputeFactLwork(&lwork)); 3404742e46bSJacob Faibussowitsch if (lwork > mcu->d_fact_lwork) { 3414742e46bSJacob Faibussowitsch mcu->d_fact_lwork = lwork; 3424742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream)); 3434742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, lwork, stream)); 3444742e46bSJacob Faibussowitsch } 3454742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3464742e46bSJacob Faibussowitsch } 3474742e46bSJacob Faibussowitsch 3484742e46bSJacob Faibussowitsch static PetscErrorCode FactorPrepare(Mat A, cupmStream_t stream) noexcept 3494742e46bSJacob Faibussowitsch { 3504742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 3514742e46bSJacob Faibussowitsch 3524742e46bSJacob Faibussowitsch PetscFunctionBegin; 3534742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "%s factor %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", derived_type::NAME(), A->rmap->n, A->cmap->n)); 3544742e46bSJacob Faibussowitsch A->factortype = derived_type::MATFACTORTYPE(); 3554742e46bSJacob Faibussowitsch A->ops->solve = MatSolve_Factored_Dispatch_<derived_type, false>; 3564742e46bSJacob Faibussowitsch A->ops->solvetranspose = MatSolve_Factored_Dispatch_<derived_type, true>; 3574742e46bSJacob Faibussowitsch A->ops->matsolve = MatMatSolve_Factored_Dispatch_<derived_type, false>; 3584742e46bSJacob Faibussowitsch A->ops->matsolvetranspose = MatMatSolve_Factored_Dispatch_<derived_type, true>; 3594742e46bSJacob Faibussowitsch 3604742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(MATSOLVERCUPM(), &A->solvertype)); 3614742e46bSJacob Faibussowitsch if (!mcu->d_fact_info) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_info, 1, stream)); 3624742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3634742e46bSJacob Faibussowitsch } 3644742e46bSJacob Faibussowitsch }; 3654742e46bSJacob Faibussowitsch 3664742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3674742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveLU : SolveCommon<SolveLU> { 3684742e46bSJacob Faibussowitsch using base_type = SolveCommon<SolveLU>; 3694742e46bSJacob Faibussowitsch 3704742e46bSJacob Faibussowitsch static constexpr const char *NAME() noexcept { return "LU"; } 3714742e46bSJacob Faibussowitsch static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_LU; } 3724742e46bSJacob Faibussowitsch 3734742e46bSJacob Faibussowitsch static PetscErrorCode Factor(Mat A, IS, IS, const MatFactorInfo *) noexcept 3744742e46bSJacob Faibussowitsch { 3754742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 3764742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->cmap->n); 3774742e46bSJacob Faibussowitsch cupmStream_t stream; 3784742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 3794742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 3804742e46bSJacob Faibussowitsch 3814742e46bSJacob Faibussowitsch PetscFunctionBegin; 3824742e46bSJacob Faibussowitsch if (!m || !n) PetscFunctionReturn(PETSC_SUCCESS); 3834742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle, &stream)); 3844742e46bSJacob Faibussowitsch PetscCall(base_type::FactorPrepare(A, stream)); 3854742e46bSJacob Faibussowitsch { 3864742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 3874742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 3884742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 3894742e46bSJacob Faibussowitsch 3904742e46bSJacob Faibussowitsch // clang-format off 3914742e46bSJacob Faibussowitsch PetscCall( 3924742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 3934742e46bSJacob Faibussowitsch mcu, stream, 3944742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *fact_lwork) 3954742e46bSJacob Faibussowitsch { 3964742e46bSJacob Faibussowitsch return cupmSolverXgetrf_bufferSize(handle, m, n, da.cupmdata(), lda, fact_lwork); 3974742e46bSJacob Faibussowitsch } 3984742e46bSJacob Faibussowitsch ) 3994742e46bSJacob Faibussowitsch ); 4004742e46bSJacob Faibussowitsch // clang-format on 4014742e46bSJacob Faibussowitsch if (!mcu->d_fact_ipiv) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_ipiv, n, stream)); 4024742e46bSJacob Faibussowitsch 4034742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 4044742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXgetrf(handle, m, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_ipiv, mcu->d_fact_info)); 4054742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 4064742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream)); 4074742e46bSJacob Faibussowitsch } 4084742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(2.0 * n * n * m / 3.0)); 4094742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 4104742e46bSJacob Faibussowitsch } 4114742e46bSJacob Faibussowitsch 4124742e46bSJacob Faibussowitsch template <bool transpose> 4134742e46bSJacob Faibussowitsch static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept 4144742e46bSJacob Faibussowitsch { 4154742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 4164742e46bSJacob Faibussowitsch const auto fact_info = mcu->d_fact_info; 4174742e46bSJacob Faibussowitsch const auto fact_ipiv = mcu->d_fact_ipiv; 4184742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 4194742e46bSJacob Faibussowitsch 4204742e46bSJacob Faibussowitsch PetscFunctionBegin; 4214742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &handle)); 4224742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k)); 4234742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 4244742e46bSJacob Faibussowitsch { 4254742e46bSJacob Faibussowitsch constexpr auto op = transpose ? CUPMSOLVER_OP_T : CUPMSOLVER_OP_N; 4264742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 4274742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 4284742e46bSJacob Faibussowitsch 4294742e46bSJacob Faibussowitsch // clang-format off 4304742e46bSJacob Faibussowitsch PetscCall( 4314742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 4324742e46bSJacob Faibussowitsch mcu, stream, 4334742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *lwork) 4344742e46bSJacob Faibussowitsch { 4354742e46bSJacob Faibussowitsch return cupmSolverXgetrs_bufferSize( 4364742e46bSJacob Faibussowitsch handle, op, m, nrhs, da.cupmdata(), lda, fact_ipiv, x, ldx, lwork 4374742e46bSJacob Faibussowitsch ); 4384742e46bSJacob Faibussowitsch } 4394742e46bSJacob Faibussowitsch ) 4404742e46bSJacob Faibussowitsch ); 4414742e46bSJacob Faibussowitsch // clang-format on 4424742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXgetrs(handle, op, m, nrhs, da.cupmdata(), lda, fact_ipiv, x, ldx, mcu->d_fact_work, mcu->d_fact_lwork, fact_info)); 4434742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(fact_info, stream)); 4444742e46bSJacob Faibussowitsch } 4454742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 4464742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(nrhs * (2.0 * m * m - m))); 4474742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 4484742e46bSJacob Faibussowitsch } 4494742e46bSJacob Faibussowitsch }; 4504742e46bSJacob Faibussowitsch 4514742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 4524742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveCholesky : SolveCommon<SolveCholesky> { 4534742e46bSJacob Faibussowitsch using base_type = SolveCommon<SolveCholesky>; 4544742e46bSJacob Faibussowitsch 4554742e46bSJacob Faibussowitsch static constexpr const char *NAME() noexcept { return "Cholesky"; } 4564742e46bSJacob Faibussowitsch static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_CHOLESKY; } 4574742e46bSJacob Faibussowitsch 4584742e46bSJacob Faibussowitsch static PetscErrorCode Factor(Mat A, IS, const MatFactorInfo *) noexcept 4594742e46bSJacob Faibussowitsch { 4604742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->rmap->n); 4614742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 4624742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 4634742e46bSJacob Faibussowitsch cupmStream_t stream; 4644742e46bSJacob Faibussowitsch 4654742e46bSJacob Faibussowitsch PetscFunctionBegin; 4664742e46bSJacob Faibussowitsch if (!n || !A->cmap->n) PetscFunctionReturn(PETSC_SUCCESS); 4674742e46bSJacob Faibussowitsch PetscCheck(A->spd == PETSC_BOOL3_TRUE, PETSC_COMM_SELF, PETSC_ERR_SUP, "%ssytrs unavailable. Use MAT_FACTOR_LU", cupmSolverName()); 4684742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle, &stream)); 4694742e46bSJacob Faibussowitsch PetscCall(base_type::FactorPrepare(A, stream)); 4704742e46bSJacob Faibussowitsch { 4714742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 4724742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 4734742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 4744742e46bSJacob Faibussowitsch 4754742e46bSJacob Faibussowitsch // clang-format off 4764742e46bSJacob Faibussowitsch PetscCall( 4774742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 4784742e46bSJacob Faibussowitsch mcu, stream, 4794742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *fact_lwork) 4804742e46bSJacob Faibussowitsch { 4814742e46bSJacob Faibussowitsch return cupmSolverXpotrf_bufferSize( 4824742e46bSJacob Faibussowitsch handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, fact_lwork 4834742e46bSJacob Faibussowitsch ); 4844742e46bSJacob Faibussowitsch } 4854742e46bSJacob Faibussowitsch ) 4864742e46bSJacob Faibussowitsch ); 4874742e46bSJacob Faibussowitsch // clang-format on 4884742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 4894742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXpotrf(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info)); 4904742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 4914742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream)); 4924742e46bSJacob Faibussowitsch } 4934742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(1.0 * n * n * n / 3.0)); 4944742e46bSJacob Faibussowitsch 4954742e46bSJacob Faibussowitsch #if 0 4964742e46bSJacob Faibussowitsch // At the time of writing this interface (cuda 10.0), cusolverDn does not implement *sytrs 4974742e46bSJacob Faibussowitsch // and *hetr* routines. The code below should work, and it can be activated when *sytrs 4984742e46bSJacob Faibussowitsch // routines will be available 4994742e46bSJacob Faibussowitsch if (!mcu->d_fact_ipiv) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_ipiv, n, stream)); 5004742e46bSJacob Faibussowitsch if (!mcu->d_fact_lwork) { 5014742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverDnXsytrf_bufferSize(handle, n, da.cupmdata(), lda, &mcu->d_fact_lwork)); 5024742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, mcu->d_fact_lwork, stream)); 5034742e46bSJacob Faibussowitsch } 5044742e46bSJacob Faibussowitsch if (mcu->d_fact_info) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_info, 1, stream)); 5054742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 5064742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXsytrf(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da, lda, mcu->d_fact_ipiv, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info)); 5074742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 5084742e46bSJacob Faibussowitsch #endif 5094742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5104742e46bSJacob Faibussowitsch } 5114742e46bSJacob Faibussowitsch 5124742e46bSJacob Faibussowitsch template <bool transpose> 5134742e46bSJacob Faibussowitsch static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept 5144742e46bSJacob Faibussowitsch { 5154742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 5164742e46bSJacob Faibussowitsch const auto fact_info = mcu->d_fact_info; 5174742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 5184742e46bSJacob Faibussowitsch 5194742e46bSJacob Faibussowitsch PetscFunctionBegin; 5204742e46bSJacob Faibussowitsch PetscAssert(!mcu->d_fact_ipiv, PETSC_COMM_SELF, PETSC_ERR_LIB, "%ssytrs not implemented", cupmSolverName()); 5214742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &handle)); 5224742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k)); 5234742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 5244742e46bSJacob Faibussowitsch { 5254742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 5264742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 5274742e46bSJacob Faibussowitsch 5284742e46bSJacob Faibussowitsch // clang-format off 5294742e46bSJacob Faibussowitsch PetscCall( 5304742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 5314742e46bSJacob Faibussowitsch mcu, stream, 5324742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *lwork) 5334742e46bSJacob Faibussowitsch { 5344742e46bSJacob Faibussowitsch return cupmSolverXpotrs_bufferSize( 5354742e46bSJacob Faibussowitsch handle, CUPMSOLVER_FILL_MODE_LOWER, m, nrhs, da.cupmdata(), lda, x, ldx, lwork 5364742e46bSJacob Faibussowitsch ); 5374742e46bSJacob Faibussowitsch } 5384742e46bSJacob Faibussowitsch ) 5394742e46bSJacob Faibussowitsch ); 5404742e46bSJacob Faibussowitsch // clang-format on 5414742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXpotrs(handle, CUPMSOLVER_FILL_MODE_LOWER, m, nrhs, da.cupmdata(), lda, x, ldx, mcu->d_fact_work, mcu->d_fact_lwork, fact_info)); 5424742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(fact_info, stream)); 5434742e46bSJacob Faibussowitsch } 5444742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 5454742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(nrhs * (2.0 * m * m - m))); 5464742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5474742e46bSJacob Faibussowitsch } 5484742e46bSJacob Faibussowitsch }; 5494742e46bSJacob Faibussowitsch 5504742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 5514742e46bSJacob Faibussowitsch struct MatDense_Seq_CUPM<T>::SolveQR : SolveCommon<SolveQR> { 5524742e46bSJacob Faibussowitsch using base_type = SolveCommon<SolveQR>; 5534742e46bSJacob Faibussowitsch 5544742e46bSJacob Faibussowitsch static constexpr const char *NAME() noexcept { return "QR"; } 5554742e46bSJacob Faibussowitsch static constexpr MatFactorType MATFACTORTYPE() noexcept { return MAT_FACTOR_QR; } 5564742e46bSJacob Faibussowitsch 5574742e46bSJacob Faibussowitsch static PetscErrorCode Factor(Mat A, IS, const MatFactorInfo *) noexcept 5584742e46bSJacob Faibussowitsch { 5594742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 5604742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->cmap->n); 5614742e46bSJacob Faibussowitsch const auto min = std::min(m, n); 5624742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 5634742e46bSJacob Faibussowitsch cupmStream_t stream; 5644742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 5654742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 5664742e46bSJacob Faibussowitsch 5674742e46bSJacob Faibussowitsch PetscFunctionBegin; 5684742e46bSJacob Faibussowitsch if (!m || !n) PetscFunctionReturn(PETSC_SUCCESS); 5694742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle, &stream)); 5704742e46bSJacob Faibussowitsch PetscCall(base_type::FactorPrepare(A, stream)); 5714742e46bSJacob Faibussowitsch mimpl->rank = min; 5724742e46bSJacob Faibussowitsch { 5734742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 5744742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 5754742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(mimpl->lda); 5764742e46bSJacob Faibussowitsch 5774742e46bSJacob Faibussowitsch if (!mcu->workvec) PetscCall(vec::cupm::VecCreateSeqCUPMAsync<T>(PetscObjectComm(PetscObjectCast(A)), m, &mcu->workvec)); 5784742e46bSJacob Faibussowitsch if (!mcu->d_fact_tau) PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_tau, min, stream)); 5794742e46bSJacob Faibussowitsch // clang-format off 5804742e46bSJacob Faibussowitsch PetscCall( 5814742e46bSJacob Faibussowitsch base_type::ResizeFactLwork( 5824742e46bSJacob Faibussowitsch mcu, stream, 5834742e46bSJacob Faibussowitsch [&](cupmBlasInt_t *fact_lwork) 5844742e46bSJacob Faibussowitsch { 5854742e46bSJacob Faibussowitsch return cupmSolverXgeqrf_bufferSize(handle, m, n, da.cupmdata(), lda, fact_lwork); 5864742e46bSJacob Faibussowitsch } 5874742e46bSJacob Faibussowitsch ) 5884742e46bSJacob Faibussowitsch ); 5894742e46bSJacob Faibussowitsch // clang-format on 5904742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 5914742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXgeqrf(handle, m, n, da.cupmdata(), lda, mcu->d_fact_tau, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info)); 5924742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 5934742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream)); 5944742e46bSJacob Faibussowitsch } 5954742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(2.0 * min * min * (std::max(m, n) - min / 3.0))); 5964742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 5974742e46bSJacob Faibussowitsch } 5984742e46bSJacob Faibussowitsch 5994742e46bSJacob Faibussowitsch template <bool transpose> 6004742e46bSJacob Faibussowitsch static PetscErrorCode Solve(Mat A, cupmScalar_t *x, cupmBlasInt_t ldx, cupmBlasInt_t m, cupmBlasInt_t nrhs, cupmBlasInt_t k, PetscDeviceContext dctx, cupmStream_t stream) noexcept 6014742e46bSJacob Faibussowitsch { 6024742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 6034742e46bSJacob Faibussowitsch const auto rank = static_cast<cupmBlasInt_t>(mimpl->rank); 6044742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 6054742e46bSJacob Faibussowitsch const auto fact_info = mcu->d_fact_info; 6064742e46bSJacob Faibussowitsch const auto fact_tau = mcu->d_fact_tau; 6074742e46bSJacob Faibussowitsch const auto fact_work = mcu->d_fact_work; 6084742e46bSJacob Faibussowitsch const auto fact_lwork = mcu->d_fact_lwork; 6094742e46bSJacob Faibussowitsch cupmSolverHandle_t solver_handle; 6104742e46bSJacob Faibussowitsch cupmBlasHandle_t blas_handle; 6114742e46bSJacob Faibussowitsch 6124742e46bSJacob Faibussowitsch PetscFunctionBegin; 6134742e46bSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &blas_handle, &solver_handle)); 6144742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "%s solve %d x %d on backend\n", NAME(), m, k)); 6154742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 6164742e46bSJacob Faibussowitsch { 6174742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 6184742e46bSJacob Faibussowitsch const auto one = cupmScalarCast(1.0); 6194742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(mimpl->lda); 6204742e46bSJacob Faibussowitsch 6214742e46bSJacob Faibussowitsch if (transpose) { 6224742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXtrsm(blas_handle, CUPMBLAS_SIDE_LEFT, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_T, CUPMBLAS_DIAG_NON_UNIT, rank, nrhs, &one, da.cupmdata(), lda, x, ldx)); 6234742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXormqr(solver_handle, CUPMSOLVER_SIDE_LEFT, CUPMSOLVER_OP_N, m, nrhs, rank, da.cupmdata(), lda, fact_tau, x, ldx, fact_work, fact_lwork, fact_info)); 6244742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(fact_info, stream)); 6254742e46bSJacob Faibussowitsch } else { 6264742e46bSJacob Faibussowitsch constexpr auto op = PetscDefined(USE_COMPLEX) ? CUPMSOLVER_OP_C : CUPMSOLVER_OP_T; 6274742e46bSJacob Faibussowitsch 6284742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXormqr(solver_handle, CUPMSOLVER_SIDE_LEFT, op, m, nrhs, rank, da.cupmdata(), lda, fact_tau, x, ldx, fact_work, fact_lwork, fact_info)); 6294742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(fact_info, stream)); 6304742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXtrsm(blas_handle, CUPMBLAS_SIDE_LEFT, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, rank, nrhs, &one, da.cupmdata(), lda, x, ldx)); 6314742e46bSJacob Faibussowitsch } 6324742e46bSJacob Faibussowitsch } 6334742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 6344742e46bSJacob Faibussowitsch PetscCall(PetscLogFlops(nrhs * (4.0 * m * rank - (rank * rank)))); 6354742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 6364742e46bSJacob Faibussowitsch } 6374742e46bSJacob Faibussowitsch }; 6384742e46bSJacob Faibussowitsch 6394742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 6404742e46bSJacob Faibussowitsch template <typename Solver, bool transpose> 6414742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatSolve_Factored_Dispatch_(Mat A, Vec x, Vec y) noexcept 6424742e46bSJacob Faibussowitsch { 6434742e46bSJacob Faibussowitsch using namespace vec::cupm; 6444742e46bSJacob Faibussowitsch const auto pobj_A = PetscObjectCast(A); 6454742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 6464742e46bSJacob Faibussowitsch const auto k = static_cast<cupmBlasInt_t>(A->cmap->n); 6474742e46bSJacob Faibussowitsch auto &workvec = MatCUPMCast(A)->workvec; 6484742e46bSJacob Faibussowitsch PetscScalar *y_array = nullptr; 6494742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 6504742e46bSJacob Faibussowitsch PetscBool xiscupm, yiscupm, aiscupm; 6514742e46bSJacob Faibussowitsch bool use_y_array_directly; 6524742e46bSJacob Faibussowitsch cupmStream_t stream; 6534742e46bSJacob Faibussowitsch 6544742e46bSJacob Faibussowitsch PetscFunctionBegin; 6554742e46bSJacob Faibussowitsch PetscCheck(A->factortype != MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix must be factored to solve"); 6564742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(x), VecSeq_CUPM::VECSEQCUPM(), &xiscupm)); 6574742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(y), VecSeq_CUPM::VECSEQCUPM(), &yiscupm)); 6584742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(pobj_A, MATSEQDENSECUPM(), &aiscupm)); 6594742e46bSJacob Faibussowitsch PetscAssert(aiscupm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Matrix A is somehow not CUPM?????????????????????????????"); 6604742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 6614742e46bSJacob Faibussowitsch use_y_array_directly = yiscupm && (k >= m); 6624742e46bSJacob Faibussowitsch { 6634742e46bSJacob Faibussowitsch const PetscScalar *x_array; 6644742e46bSJacob Faibussowitsch const auto xisdevice = xiscupm && PetscOffloadDevice(x->offloadmask); 6654742e46bSJacob Faibussowitsch const auto copy_mode = xisdevice ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice; 6664742e46bSJacob Faibussowitsch 6674742e46bSJacob Faibussowitsch if (!use_y_array_directly && !workvec) PetscCall(VecCreateSeqCUPMAsync<T>(PetscObjectComm(pobj_A), m, &workvec)); 6684742e46bSJacob Faibussowitsch // The logic here is to try to minimize the amount of memory copying: 6694742e46bSJacob Faibussowitsch // 6704742e46bSJacob Faibussowitsch // If we call VecCUPMGetArrayRead(X, &x) every time xiscupm and the data is not offloaded 6714742e46bSJacob Faibussowitsch // to the GPU yet, then the data is copied to the GPU. But we are only trying to get the 6724742e46bSJacob Faibussowitsch // data in order to copy it into the y array. So the array x will be wherever the data 6734742e46bSJacob Faibussowitsch // already is so that only one memcpy is performed 6744742e46bSJacob Faibussowitsch if (xisdevice) { 6754742e46bSJacob Faibussowitsch PetscCall(VecCUPMGetArrayReadAsync<T>(x, &x_array, dctx)); 6764742e46bSJacob Faibussowitsch } else { 6774742e46bSJacob Faibussowitsch PetscCall(VecGetArrayRead(x, &x_array)); 6784742e46bSJacob Faibussowitsch } 6794742e46bSJacob Faibussowitsch PetscCall(VecCUPMGetArrayWriteAsync<T>(use_y_array_directly ? y : workvec, &y_array, dctx)); 6804742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(y_array, x_array, m, copy_mode, stream)); 6814742e46bSJacob Faibussowitsch if (xisdevice) { 6824742e46bSJacob Faibussowitsch PetscCall(VecCUPMRestoreArrayReadAsync<T>(x, &x_array, dctx)); 6834742e46bSJacob Faibussowitsch } else { 6844742e46bSJacob Faibussowitsch PetscCall(VecRestoreArrayRead(x, &x_array)); 6854742e46bSJacob Faibussowitsch } 6864742e46bSJacob Faibussowitsch } 6874742e46bSJacob Faibussowitsch 6884742e46bSJacob Faibussowitsch if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 6894742e46bSJacob Faibussowitsch PetscCall(Solver{}.template Solve<transpose>(A, cupmScalarPtrCast(y_array), m, m, 1, k, dctx, stream)); 6904742e46bSJacob Faibussowitsch if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A)); 6914742e46bSJacob Faibussowitsch 6924742e46bSJacob Faibussowitsch if (use_y_array_directly) { 6934742e46bSJacob Faibussowitsch PetscCall(VecCUPMRestoreArrayWriteAsync<T>(y, &y_array, dctx)); 6944742e46bSJacob Faibussowitsch } else { 6954742e46bSJacob Faibussowitsch const auto copy_mode = yiscupm ? cupmMemcpyDeviceToDevice : cupmMemcpyDeviceToHost; 6964742e46bSJacob Faibussowitsch PetscScalar *yv; 6974742e46bSJacob Faibussowitsch 6984742e46bSJacob Faibussowitsch // The logic here is that the data is not yet in either y's GPU array or its CPU array. 6994742e46bSJacob Faibussowitsch // There is nothing in the interface to say where the user would like it to end up. So we 7004742e46bSJacob Faibussowitsch // choose the GPU, because it is the faster option 7014742e46bSJacob Faibussowitsch if (yiscupm) { 7024742e46bSJacob Faibussowitsch PetscCall(VecCUPMGetArrayWriteAsync<T>(y, &yv, dctx)); 7034742e46bSJacob Faibussowitsch } else { 7044742e46bSJacob Faibussowitsch PetscCall(VecGetArray(y, &yv)); 7054742e46bSJacob Faibussowitsch } 7064742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(yv, y_array, k, copy_mode, stream)); 7074742e46bSJacob Faibussowitsch if (yiscupm) { 7084742e46bSJacob Faibussowitsch PetscCall(VecCUPMRestoreArrayWriteAsync<T>(y, &yv, dctx)); 7094742e46bSJacob Faibussowitsch } else { 7104742e46bSJacob Faibussowitsch PetscCall(VecRestoreArray(y, &yv)); 7114742e46bSJacob Faibussowitsch } 7124742e46bSJacob Faibussowitsch PetscCall(VecCUPMRestoreArrayWriteAsync<T>(workvec, &y_array)); 7134742e46bSJacob Faibussowitsch } 7144742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 7154742e46bSJacob Faibussowitsch } 7164742e46bSJacob Faibussowitsch 7174742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 7184742e46bSJacob Faibussowitsch template <typename Solver, bool transpose> 7194742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMatSolve_Factored_Dispatch_(Mat A, Mat B, Mat X) noexcept 7204742e46bSJacob Faibussowitsch { 7214742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 7224742e46bSJacob Faibussowitsch const auto k = static_cast<cupmBlasInt_t>(A->cmap->n); 7234742e46bSJacob Faibussowitsch cupmBlasInt_t nrhs, ldb, ldx, ldy; 7244742e46bSJacob Faibussowitsch PetscScalar *y; 7254742e46bSJacob Faibussowitsch PetscBool biscupm, xiscupm, aiscupm; 7264742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 7274742e46bSJacob Faibussowitsch cupmStream_t stream; 7284742e46bSJacob Faibussowitsch 7294742e46bSJacob Faibussowitsch PetscFunctionBegin; 7304742e46bSJacob Faibussowitsch PetscCheck(A->factortype != MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix must be factored to solve"); 7314742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(B), MATSEQDENSECUPM(), &biscupm)); 7324742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(X), MATSEQDENSECUPM(), &xiscupm)); 7334742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(A), MATSEQDENSECUPM(), &aiscupm)); 7344742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 7354742e46bSJacob Faibussowitsch { 7364742e46bSJacob Faibussowitsch PetscInt n; 7374742e46bSJacob Faibussowitsch 7384742e46bSJacob Faibussowitsch PetscCall(MatGetSize(B, nullptr, &n)); 7394742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(n, &nrhs)); 7404742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(B, &n)); 7414742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(n, &ldb)); 7424742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(X, &n)); 7434742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(n, &ldx)); 7444742e46bSJacob Faibussowitsch } 7454742e46bSJacob Faibussowitsch { 7464742e46bSJacob Faibussowitsch // The logic here is to try to minimize the amount of memory copying: 7474742e46bSJacob Faibussowitsch // 7484742e46bSJacob Faibussowitsch // If we call MatDenseCUPMGetArrayRead(B, &b) every time biscupm and the data is not 7494742e46bSJacob Faibussowitsch // offloaded to the GPU yet, then the data is copied to the GPU. But we are only trying to 7504742e46bSJacob Faibussowitsch // get the data in order to copy it into the y array. So the array b will be wherever the 7514742e46bSJacob Faibussowitsch // data already is so that only one memcpy is performed 7524742e46bSJacob Faibussowitsch const auto bisdevice = biscupm && PetscOffloadDevice(B->offloadmask); 7534742e46bSJacob Faibussowitsch const auto copy_mode = bisdevice ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice; 7544742e46bSJacob Faibussowitsch const PetscScalar *b; 7554742e46bSJacob Faibussowitsch 7564742e46bSJacob Faibussowitsch if (bisdevice) { 7574742e46bSJacob Faibussowitsch b = DeviceArrayRead(dctx, B); 7584742e46bSJacob Faibussowitsch } else if (biscupm) { 7594742e46bSJacob Faibussowitsch b = HostArrayRead(dctx, B); 7604742e46bSJacob Faibussowitsch } else { 7614742e46bSJacob Faibussowitsch PetscCall(MatDenseGetArrayRead(B, &b)); 7624742e46bSJacob Faibussowitsch } 7634742e46bSJacob Faibussowitsch 7644742e46bSJacob Faibussowitsch if (ldx < m || !xiscupm) { 7654742e46bSJacob Faibussowitsch // X's array cannot serve as the array (too small or not on device), B's array cannot 7664742e46bSJacob Faibussowitsch // serve as the array (const), so allocate a new array 7674742e46bSJacob Faibussowitsch ldy = m; 7684742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&y, nrhs * m)); 7694742e46bSJacob Faibussowitsch } else { 7704742e46bSJacob Faibussowitsch // X's array should serve as the array 7714742e46bSJacob Faibussowitsch ldy = ldx; 7724742e46bSJacob Faibussowitsch y = DeviceArrayWrite(dctx, X); 7734742e46bSJacob Faibussowitsch } 7744742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(y, ldy, b, ldb, m, nrhs, copy_mode, stream)); 7754742e46bSJacob Faibussowitsch if (!bisdevice && !biscupm) PetscCall(MatDenseRestoreArrayRead(B, &b)); 7764742e46bSJacob Faibussowitsch } 7774742e46bSJacob Faibussowitsch 7784742e46bSJacob Faibussowitsch // convert to CUPM twice?????????????????????????????????? 7794742e46bSJacob Faibussowitsch // but A should already be CUPM?????????????????????????????????????? 7804742e46bSJacob Faibussowitsch if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 7814742e46bSJacob Faibussowitsch PetscCall(Solver{}.template Solve<transpose>(A, cupmScalarPtrCast(y), ldy, m, nrhs, k, dctx, stream)); 7824742e46bSJacob Faibussowitsch if (!aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 7834742e46bSJacob Faibussowitsch 7844742e46bSJacob Faibussowitsch if (ldx < m || !xiscupm) { 7854742e46bSJacob Faibussowitsch const auto copy_mode = xiscupm ? cupmMemcpyDeviceToDevice : cupmMemcpyDeviceToHost; 7864742e46bSJacob Faibussowitsch PetscScalar *x; 7874742e46bSJacob Faibussowitsch 7884742e46bSJacob Faibussowitsch // The logic here is that the data is not yet in either X's GPU array or its CPU 7894742e46bSJacob Faibussowitsch // array. There is nothing in the interface to say where the user would like it to end up. 7904742e46bSJacob Faibussowitsch // So we choose the GPU, because it is the faster option 7914742e46bSJacob Faibussowitsch if (xiscupm) { 7924742e46bSJacob Faibussowitsch x = DeviceArrayWrite(dctx, X); 7934742e46bSJacob Faibussowitsch } else { 7944742e46bSJacob Faibussowitsch PetscCall(MatDenseGetArray(X, &x)); 7954742e46bSJacob Faibussowitsch } 7964742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(x, ldx, y, ldy, k, nrhs, copy_mode, stream)); 7974742e46bSJacob Faibussowitsch if (!xiscupm) PetscCall(MatDenseRestoreArray(X, &x)); 7984742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(y, stream)); 7994742e46bSJacob Faibussowitsch } 8004742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 8014742e46bSJacob Faibussowitsch } 8024742e46bSJacob Faibussowitsch 8034742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 80495571869SBlanca Mellado Pinto template <bool transpose, bool hermitian> 8050be0d8bdSHansol Suh inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMultAddColumnRange_Dispatch_(Mat A, Vec xx, Vec yy, Vec zz, PetscInt c_start, PetscInt c_end) noexcept 8064742e46bSJacob Faibussowitsch { 8074742e46bSJacob Faibussowitsch const auto m = static_cast<cupmBlasInt_t>(A->rmap->n); 8080be0d8bdSHansol Suh const auto n = static_cast<cupmBlasInt_t>(c_end - c_start); 8090be0d8bdSHansol Suh const auto lda = static_cast<cupmBlasInt_t>(MatIMPLCast(A)->lda); 8103853def2SToby Isaac PetscBool xiscupm, yiscupm, ziscupm; 8114742e46bSJacob Faibussowitsch cupmBlasHandle_t handle; 8123853def2SToby Isaac Vec x = xx, y = yy, z = zz; 8134742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 8144742e46bSJacob Faibussowitsch 8154742e46bSJacob Faibussowitsch PetscFunctionBegin; 8163853def2SToby Isaac PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xx), &xiscupm, VecSeq_CUPM::VECSEQCUPM(), VecSeq_CUPM::VECMPICUPM(), VecSeq_CUPM::VECCUPM(), "")); 8173853def2SToby Isaac if (!xiscupm || xx->boundtocpu) { 8183853def2SToby Isaac PetscCall(VecCreate(PetscObjectComm(PetscObjectCast(xx)), &x)); 8193853def2SToby Isaac PetscCall(VecSetLayout(x, xx->map)); 8203853def2SToby Isaac PetscCall(VecSetType(x, VecSeq_CUPM::VECCUPM())); 8213853def2SToby Isaac PetscCall(VecCopy(xx, x)); 8223853def2SToby Isaac } 8233853def2SToby Isaac 8243853def2SToby Isaac if (yy) { 8253853def2SToby Isaac PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yy), &yiscupm, VecSeq_CUPM::VECSEQCUPM(), VecSeq_CUPM::VECMPICUPM(), VecSeq_CUPM::VECCUPM(), "")); 8263853def2SToby Isaac if (!yiscupm || yy->boundtocpu) { 8273853def2SToby Isaac PetscCall(VecCreate(PetscObjectComm(PetscObjectCast(yy)), &y)); 8283853def2SToby Isaac PetscCall(VecSetLayout(y, yy->map)); 8293853def2SToby Isaac PetscCall(VecSetType(y, VecSeq_CUPM::VECCUPM())); 8303853def2SToby Isaac PetscCall(VecCopy(yy, y)); 8313853def2SToby Isaac } 8323853def2SToby Isaac } 8333853def2SToby Isaac 8343853def2SToby Isaac if (zz != yy) { 8353853def2SToby Isaac PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(zz), &ziscupm, VecSeq_CUPM::VECSEQCUPM(), VecSeq_CUPM::VECMPICUPM(), VecSeq_CUPM::VECCUPM(), "")); 8363853def2SToby Isaac if (!ziscupm || zz->boundtocpu) { 8373853def2SToby Isaac PetscCall(VecCreate(PetscObjectComm(PetscObjectCast(zz)), &z)); 8383853def2SToby Isaac PetscCall(VecSetLayout(z, zz->map)); 8393853def2SToby Isaac PetscCall(VecSetType(z, VecSeq_CUPM::VECCUPM())); 8403853def2SToby Isaac } 8413853def2SToby Isaac } else { 8423853def2SToby Isaac z = y; 8433853def2SToby Isaac } 8443853def2SToby Isaac 8453853def2SToby Isaac if (y && y != z) PetscCall(VecSeq_CUPM::Copy(y, z)); // mult add 8464742e46bSJacob Faibussowitsch if (!m || !n) { 8474742e46bSJacob Faibussowitsch // mult only 8483853def2SToby Isaac if (!y) PetscCall(VecSeq_CUPM::Set(z, 0.0)); 8494742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 8504742e46bSJacob Faibussowitsch } 8514742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "Matrix-vector product %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " on backend\n", m, n)); 8524742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle)); 8534742e46bSJacob Faibussowitsch { 85495571869SBlanca Mellado Pinto constexpr auto op = transpose ? (hermitian ? CUPMBLAS_OP_C : CUPMBLAS_OP_T) : CUPMBLAS_OP_N; 8554742e46bSJacob Faibussowitsch const auto one = cupmScalarCast(1.0); 8564742e46bSJacob Faibussowitsch const auto zero = cupmScalarCast(0.0); 8574742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 8583853def2SToby Isaac const auto dxx = VecSeq_CUPM::DeviceArrayRead(dctx, x); 8593853def2SToby Isaac const auto dzz = VecSeq_CUPM::DeviceArrayReadWrite(dctx, z); 8604742e46bSJacob Faibussowitsch 8614742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 8623853def2SToby Isaac PetscCallCUPMBLAS(cupmBlasXgemv(handle, op, m, n, &one, da.cupmdata() + c_start * lda, lda, dxx.cupmdata() + (transpose ? 0 : c_start), 1, y ? &one : &zero, dzz.cupmdata() + (transpose ? c_start : 0), 1)); 8634742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 8644742e46bSJacob Faibussowitsch } 8654742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(2.0 * m * n - (yy ? 0 : m))); 8663853def2SToby Isaac if (z != zz) { 8673853def2SToby Isaac PetscCall(VecCopy(z, zz)); 8683853def2SToby Isaac if (z != y) PetscCall(VecDestroy(&z)); 8693853def2SToby Isaac } 8703853def2SToby Isaac if (y != yy) PetscCall(VecDestroy(&y)); 8713853def2SToby Isaac if (x != xx) PetscCall(VecDestroy(&x)); 8724742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 8734742e46bSJacob Faibussowitsch } 8744742e46bSJacob Faibussowitsch 8750be0d8bdSHansol Suh template <device::cupm::DeviceType T> 8760be0d8bdSHansol Suh template <bool transpose, bool hermitian> 8770be0d8bdSHansol Suh inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMultColumnRange_Dispatch_(Mat A, Vec xx, Vec yy, PetscInt c_start, PetscInt c_end) noexcept 8780be0d8bdSHansol Suh { 8790be0d8bdSHansol Suh PetscFunctionBegin; 8800be0d8bdSHansol Suh PetscCall(MatMultAddColumnRange_Dispatch_<transpose, hermitian>(A, xx, nullptr, yy, c_start, c_end)); 8810be0d8bdSHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 8820be0d8bdSHansol Suh } 8830be0d8bdSHansol Suh 8840be0d8bdSHansol Suh template <device::cupm::DeviceType T> 8850be0d8bdSHansol Suh template <bool transpose, bool hermitian> 8860be0d8bdSHansol Suh inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMultAdd_Dispatch_(Mat A, Vec xx, Vec yy, Vec zz) noexcept 8870be0d8bdSHansol Suh { 8880be0d8bdSHansol Suh PetscFunctionBegin; 8890be0d8bdSHansol Suh PetscCall(MatMultAddColumnRange_Dispatch_<transpose, hermitian>(A, xx, yy, zz, 0, A->cmap->n)); 8900be0d8bdSHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 8910be0d8bdSHansol Suh } 8920be0d8bdSHansol Suh 8934742e46bSJacob Faibussowitsch // ========================================================================================== 8944742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Private API - Conversion Dispatch 8954742e46bSJacob Faibussowitsch // ========================================================================================== 8964742e46bSJacob Faibussowitsch 8974742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 8984742e46bSJacob Faibussowitsch template <bool to_host> 8994742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_Dispatch_(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept 9004742e46bSJacob Faibussowitsch { 9014742e46bSJacob Faibussowitsch PetscFunctionBegin; 9024742e46bSJacob Faibussowitsch if (reuse == MAT_REUSE_MATRIX || reuse == MAT_INITIAL_MATRIX) { 9034742e46bSJacob Faibussowitsch // TODO these cases should be optimized 9044742e46bSJacob Faibussowitsch PetscCall(MatConvert_Basic(M, type, reuse, newmat)); 9054742e46bSJacob Faibussowitsch } else { 9064742e46bSJacob Faibussowitsch const auto B = *newmat; 9074742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(B); 9084742e46bSJacob Faibussowitsch 9094742e46bSJacob Faibussowitsch if (to_host) { 9104742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_TRUE)); 9114742e46bSJacob Faibussowitsch PetscCall(Reset(B)); 9124742e46bSJacob Faibussowitsch } else { 9134742e46bSJacob Faibussowitsch PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 9144742e46bSJacob Faibussowitsch } 9154742e46bSJacob Faibussowitsch 9164742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecSeq_CUPM::VECCUPM(), &B->defaultvectype)); 9174742e46bSJacob Faibussowitsch PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATSEQDENSE : MATSEQDENSECUPM())); 9184742e46bSJacob Faibussowitsch // cvec might be the wrong VecType, destroy and rebuild it if necessary 9194742e46bSJacob Faibussowitsch // REVIEW ME: this is possibly very inefficient 9204742e46bSJacob Faibussowitsch PetscCall(VecDestroy(&MatIMPLCast(B)->cvec)); 9214742e46bSJacob Faibussowitsch 9224742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatConvert_seqdensecupm_seqdense_C(), nullptr, Convert_SeqDenseCUPM_SeqDense); 9234742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 9244742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 9254742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 9264742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 9274742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 9284742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 9294742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray); 9304742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray); 9314742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray); 9324742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_seqaij_seqdensecupm_C(), nullptr, MatProductSetFromOptions_SeqAIJ_SeqDense); 9333d9668e3SJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation); 9344742e46bSJacob Faibussowitsch 9354742e46bSJacob Faibussowitsch if (to_host) { 9364742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_CPU; 9374742e46bSJacob Faibussowitsch } else { 9384742e46bSJacob Faibussowitsch Mat_SeqDenseCUPM *mcu; 9394742e46bSJacob Faibussowitsch 9404742e46bSJacob Faibussowitsch PetscCall(PetscNew(&mcu)); 9414742e46bSJacob Faibussowitsch B->spptr = mcu; 9424742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_UNALLOCATED; // REVIEW ME: why not offload host?? 9434742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_FALSE)); 9444742e46bSJacob Faibussowitsch } 9454742e46bSJacob Faibussowitsch 9464742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU); 9474742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, B, destroy, MatDestroy_SeqDense, Destroy); 9484742e46bSJacob Faibussowitsch } 9494742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 9504742e46bSJacob Faibussowitsch } 9514742e46bSJacob Faibussowitsch 9524742e46bSJacob Faibussowitsch // ========================================================================================== 9534742e46bSJacob Faibussowitsch // MatDense_Seq_CUPM - Public API 9544742e46bSJacob Faibussowitsch // ========================================================================================== 9554742e46bSJacob Faibussowitsch 9564742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 9574742e46bSJacob Faibussowitsch inline constexpr MatType MatDense_Seq_CUPM<T>::MATIMPLCUPM_() noexcept 9584742e46bSJacob Faibussowitsch { 9594742e46bSJacob Faibussowitsch return MATSEQDENSECUPM(); 9604742e46bSJacob Faibussowitsch } 9614742e46bSJacob Faibussowitsch 9624742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 9634742e46bSJacob Faibussowitsch inline constexpr typename MatDense_Seq_CUPM<T>::Mat_SeqDenseCUPM *MatDense_Seq_CUPM<T>::MatCUPMCast(Mat m) noexcept 9644742e46bSJacob Faibussowitsch { 9654742e46bSJacob Faibussowitsch return static_cast<Mat_SeqDenseCUPM *>(m->spptr); 9664742e46bSJacob Faibussowitsch } 9674742e46bSJacob Faibussowitsch 9684742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 9694742e46bSJacob Faibussowitsch inline constexpr Mat_SeqDense *MatDense_Seq_CUPM<T>::MatIMPLCast_(Mat m) noexcept 9704742e46bSJacob Faibussowitsch { 9714742e46bSJacob Faibussowitsch return static_cast<Mat_SeqDense *>(m->data); 9724742e46bSJacob Faibussowitsch } 9734742e46bSJacob Faibussowitsch 9744742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 9754742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_Seq_CUPM<T>::MatConvert_seqdensecupm_seqdense_C() noexcept 9764742e46bSJacob Faibussowitsch { 9774742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatConvert_seqdensecuda_seqdense_C" : "MatConvert_seqdensehip_seqdense_C"; 9784742e46bSJacob Faibussowitsch } 9794742e46bSJacob Faibussowitsch 9804742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 9814742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_Seq_CUPM<T>::MatProductSetFromOptions_seqaij_seqdensecupm_C() noexcept 9824742e46bSJacob Faibussowitsch { 9834742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_seqaij_seqdensecuda_C" : "MatProductSetFromOptions_seqaij_seqdensehip_C"; 9844742e46bSJacob Faibussowitsch } 9854742e46bSJacob Faibussowitsch 9864742e46bSJacob Faibussowitsch // ========================================================================================== 9874742e46bSJacob Faibussowitsch 9884742e46bSJacob Faibussowitsch // MatCreate_SeqDenseCUPM() 9894742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 9904742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Create(Mat A) noexcept 9914742e46bSJacob Faibussowitsch { 9924742e46bSJacob Faibussowitsch PetscFunctionBegin; 9934742e46bSJacob Faibussowitsch PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 9944742e46bSJacob Faibussowitsch PetscCall(MatCreate_SeqDense(A)); 9954742e46bSJacob Faibussowitsch PetscCall(Convert_SeqDense_SeqDenseCUPM(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 9964742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 9974742e46bSJacob Faibussowitsch } 9984742e46bSJacob Faibussowitsch 9994742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 10004742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Destroy(Mat A) noexcept 10014742e46bSJacob Faibussowitsch { 10024742e46bSJacob Faibussowitsch PetscFunctionBegin; 10034742e46bSJacob Faibussowitsch // prevent copying back data if we own the data pointer 10044742e46bSJacob Faibussowitsch if (!MatIMPLCast(A)->user_alloc) A->offloadmask = PETSC_OFFLOAD_CPU; 10054742e46bSJacob Faibussowitsch PetscCall(Convert_SeqDenseCUPM_SeqDense(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A)); 10064742e46bSJacob Faibussowitsch PetscCall(MatDestroy_SeqDense(A)); 10074742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 10084742e46bSJacob Faibussowitsch } 10094742e46bSJacob Faibussowitsch 10104742e46bSJacob Faibussowitsch // obj->ops->setup() 10114742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 10124742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetUp(Mat A) noexcept 10134742e46bSJacob Faibussowitsch { 10144742e46bSJacob Faibussowitsch PetscFunctionBegin; 10154742e46bSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(A->rmap)); 10164742e46bSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(A->cmap)); 10174742e46bSJacob Faibussowitsch if (!A->preallocated) { 10184742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 10194742e46bSJacob Faibussowitsch 10204742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 10213d9668e3SJacob Faibussowitsch PetscCall(SetPreallocation(A, dctx, nullptr)); 10224742e46bSJacob Faibussowitsch } 10234742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 10244742e46bSJacob Faibussowitsch } 10254742e46bSJacob Faibussowitsch 10264742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 10274742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Reset(Mat A) noexcept 10284742e46bSJacob Faibussowitsch { 10294742e46bSJacob Faibussowitsch PetscFunctionBegin; 10304742e46bSJacob Faibussowitsch if (const auto mcu = MatCUPMCast(A)) { 10314742e46bSJacob Faibussowitsch cupmStream_t stream; 10324742e46bSJacob Faibussowitsch 10334742e46bSJacob Faibussowitsch PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME()); 10344742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&stream)); 10354742e46bSJacob Faibussowitsch if (!mcu->d_user_alloc) PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream)); 10364742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_tau, stream)); 10374742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_ipiv, stream)); 10384742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_info, stream)); 10394742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream)); 10404742e46bSJacob Faibussowitsch PetscCall(VecDestroy(&mcu->workvec)); 10414742e46bSJacob Faibussowitsch PetscCall(PetscFree(A->spptr /* mcu */)); 10424742e46bSJacob Faibussowitsch } 10434742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 10444742e46bSJacob Faibussowitsch } 10454742e46bSJacob Faibussowitsch 10464742e46bSJacob Faibussowitsch // ========================================================================================== 10474742e46bSJacob Faibussowitsch 10484742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 10494742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::BindToCPU(Mat A, PetscBool to_host) noexcept 10504742e46bSJacob Faibussowitsch { 10514742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 10524742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(A); 10534742e46bSJacob Faibussowitsch 10544742e46bSJacob Faibussowitsch PetscFunctionBegin; 10554742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 10564742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 10574742e46bSJacob Faibussowitsch A->boundtocpu = to_host; 10584742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(to_host ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype)); 10594742e46bSJacob Faibussowitsch if (to_host) { 10604742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 10614742e46bSJacob Faibussowitsch 10624742e46bSJacob Faibussowitsch // make sure we have an up-to-date copy on the CPU 10634742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 10644742e46bSJacob Faibussowitsch PetscCall(DeviceToHost_(A, dctx)); 10654742e46bSJacob Faibussowitsch } else { 10664742e46bSJacob Faibussowitsch PetscBool iscupm; 10674742e46bSJacob Faibussowitsch 10684742e46bSJacob Faibussowitsch if (auto &cvec = mimpl->cvec) { 10694742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(cvec), VecSeq_CUPM::VECSEQCUPM(), &iscupm)); 10704742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(VecDestroy(&cvec)); 10714742e46bSJacob Faibussowitsch } 10724742e46bSJacob Faibussowitsch if (auto &cmat = mimpl->cmat) { 10734742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(cmat), MATSEQDENSECUPM(), &iscupm)); 10744742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(MatDestroy(&cmat)); 10754742e46bSJacob Faibussowitsch } 10764742e46bSJacob Faibussowitsch } 10774742e46bSJacob Faibussowitsch 10784742e46bSJacob Faibussowitsch // ============================================================ 10794742e46bSJacob Faibussowitsch // Composed ops 10804742e46bSJacob Faibussowitsch // ============================================================ 10814742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArray_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>); 10824742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayRead_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>); 10834742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayWrite_C", MatDenseGetArray_SeqDense, GetArrayC_<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>); 10844742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ_WRITE>); 10854742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ_WRITE>); 10864742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayReadAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ>); 10874742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayReadAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_READ>); 10884742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetArrayWriteAndMemType_C", nullptr, GetArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_WRITE>); 10894742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreArrayWriteAndMemType_C", nullptr, RestoreArrayAndMemTypeC_<PETSC_MEMORY_ACCESS_WRITE>); 10904742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 10914742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 10924742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>); 10934742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>); 10944742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_SeqDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 10954742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_SeqDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 10964742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseGetSubMatrix_C", MatDenseGetSubMatrix_SeqDense, GetSubMatrix); 10974742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatDenseRestoreSubMatrix_C", MatDenseRestoreSubMatrix_SeqDense, RestoreSubMatrix); 10984742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, "MatQRFactor_C", MatQRFactor_SeqDense, SolveQR::Factor); 1099*d016bddeSToby Isaac MatComposeOp_CUPM(to_host, pobj, "MatMultColumnRange_C", MatMultColumnRange_SeqDense, MatMultColumnRange_Dispatch_</* transpose */ false, /* hermitian */ false>); 11000be0d8bdSHansol Suh MatComposeOp_CUPM(to_host, pobj, "MatMultAddColumnRange_C", MatMultAddColumnRange_SeqDense, MatMultAddColumnRange_Dispatch_</* transpose */ false, /* hermitian */ false>); 11010be0d8bdSHansol Suh MatComposeOp_CUPM(to_host, pobj, "MatMultHermitianTransposeColumnRange_C", MatMultHermitianTransposeColumnRange_SeqDense, MatMultColumnRange_Dispatch_</* transpose */ true, /* hermitian */ true>); 11020be0d8bdSHansol Suh MatComposeOp_CUPM(to_host, pobj, "MatMultHermitianTransposeAddColumnRange_C", MatMultHermitianTransposeAddColumnRange_SeqDense, MatMultAddColumnRange_Dispatch_</* transpose */ true, /* hermitian */ true>); 11034742e46bSJacob Faibussowitsch // always the same 11044742e46bSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction(pobj, "MatDenseSetLDA_C", MatDenseSetLDA_SeqDense)); 11054742e46bSJacob Faibussowitsch 11064742e46bSJacob Faibussowitsch // ============================================================ 11074742e46bSJacob Faibussowitsch // Function pointer ops 11084742e46bSJacob Faibussowitsch // ============================================================ 11094742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, duplicate, MatDuplicate_SeqDense, Duplicate); 111095571869SBlanca Mellado Pinto MatSetOp_CUPM(to_host, A, mult, MatMult_SeqDense, [](Mat A, Vec xx, Vec yy) { return MatMultAdd_Dispatch_</* transpose */ false, /* hermitian */ false>(A, xx, nullptr, yy); }); 111195571869SBlanca Mellado Pinto MatSetOp_CUPM(to_host, A, multtranspose, MatMultTranspose_SeqDense, [](Mat A, Vec xx, Vec yy) { return MatMultAdd_Dispatch_</* transpose */ true, /* hermitian */ false>(A, xx, nullptr, yy); }); 111295571869SBlanca Mellado Pinto MatSetOp_CUPM(to_host, A, multhermitiantranspose, MatMultTranspose_SeqDense, [](Mat A, Vec xx, Vec yy) { return MatMultAdd_Dispatch_</* transpose */ true, /* hermitian */ true>(A, xx, nullptr, yy); }); 111395571869SBlanca Mellado Pinto MatSetOp_CUPM(to_host, A, multadd, MatMultAdd_SeqDense, MatMultAdd_Dispatch_</* transpose */ false, /* hermitian */ false>); 111495571869SBlanca Mellado Pinto MatSetOp_CUPM(to_host, A, multtransposeadd, MatMultTransposeAdd_SeqDense, MatMultAdd_Dispatch_</* transpose */ true, /* hermitian */ false>); 111595571869SBlanca Mellado Pinto MatSetOp_CUPM(to_host, A, multhermitiantransposeadd, MatMultHermitianTransposeAdd_SeqDense, MatMultAdd_Dispatch_</* transpose */ true, /* hermitian */ true>); 11164742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, matmultnumeric, MatMatMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ false, /* transpose_B */ false>); 11174742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, mattransposemultnumeric, MatMatTransposeMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ false, /* transpose_B */ true>); 11184742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, transposematmultnumeric, MatTransposeMatMultNumeric_SeqDense_SeqDense, MatMatMult_Numeric_Dispatch</* transpose_A */ true, /* transpose_B */ false>); 11194742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, axpy, MatAXPY_SeqDense, AXPY); 11204742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, choleskyfactor, MatCholeskyFactor_SeqDense, SolveCholesky::Factor); 11214742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, lufactor, MatLUFactor_SeqDense, SolveLU::Factor); 11224742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, getcolumnvector, MatGetColumnVector_SeqDense, GetColumnVector); 11233853def2SToby Isaac MatSetOp_CUPM(to_host, A, conjugate, MatConjugate_SeqDense, Conjugate); 11244742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, scale, MatScale_SeqDense, Scale); 11254742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, shift, MatShift_SeqDense, Shift); 11264742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, copy, MatCopy_SeqDense, Copy); 11274742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, zeroentries, MatZeroEntries_SeqDense, ZeroEntries); 11284742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, setup, MatSetUp_SeqDense, SetUp); 11294742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, A, setrandom, MatSetRandom_SeqDense, SetRandom); 113014277c92SJacob Faibussowitsch MatSetOp_CUPM(to_host, A, getdiagonal, MatGetDiagonal_SeqDense, GetDiagonal); 11314742e46bSJacob Faibussowitsch // seemingly always the same 11324742e46bSJacob Faibussowitsch A->ops->productsetfromoptions = MatProductSetFromOptions_SeqDense; 11334742e46bSJacob Faibussowitsch 11344742e46bSJacob Faibussowitsch if (const auto cmat = mimpl->cmat) PetscCall(MatBindToCPU(cmat, to_host)); 11354742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 11364742e46bSJacob Faibussowitsch } 11374742e46bSJacob Faibussowitsch 11384742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 11394742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_SeqDenseCUPM_SeqDense(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept 11404742e46bSJacob Faibussowitsch { 11414742e46bSJacob Faibussowitsch PetscFunctionBegin; 11424742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ true>(M, type, reuse, newmat)); 11434742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 11444742e46bSJacob Faibussowitsch } 11454742e46bSJacob Faibussowitsch 11464742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 11474742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Convert_SeqDense_SeqDenseCUPM(Mat M, MatType type, MatReuse reuse, Mat *newmat) noexcept 11484742e46bSJacob Faibussowitsch { 11494742e46bSJacob Faibussowitsch PetscFunctionBegin; 11504742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ false>(M, type, reuse, newmat)); 11514742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 11524742e46bSJacob Faibussowitsch } 11534742e46bSJacob Faibussowitsch 11544742e46bSJacob Faibussowitsch // ========================================================================================== 11554742e46bSJacob Faibussowitsch 11564742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 11574742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode access> 11584742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetArray(Mat m, PetscScalar **array, PetscDeviceContext dctx) noexcept 11594742e46bSJacob Faibussowitsch { 11604742e46bSJacob Faibussowitsch constexpr auto hostmem = PetscMemTypeHost(mtype); 11614742e46bSJacob Faibussowitsch constexpr auto read_access = PetscMemoryAccessRead(access); 11624742e46bSJacob Faibussowitsch 11634742e46bSJacob Faibussowitsch PetscFunctionBegin; 11644742e46bSJacob Faibussowitsch static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), ""); 11654742e46bSJacob Faibussowitsch PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 11664742e46bSJacob Faibussowitsch if (hostmem) { 11674742e46bSJacob Faibussowitsch if (read_access) { 11684742e46bSJacob Faibussowitsch PetscCall(DeviceToHost_(m, dctx)); 11694742e46bSJacob Faibussowitsch } else if (!MatIMPLCast(m)->v) { 11704742e46bSJacob Faibussowitsch // MatCreateSeqDenseCUPM may not allocate CPU memory. Allocate if needed 11714742e46bSJacob Faibussowitsch PetscCall(MatSeqDenseSetPreallocation(m, nullptr)); 11724742e46bSJacob Faibussowitsch } 11734742e46bSJacob Faibussowitsch *array = MatIMPLCast(m)->v; 11744742e46bSJacob Faibussowitsch } else { 11754742e46bSJacob Faibussowitsch if (read_access) { 11764742e46bSJacob Faibussowitsch PetscCall(HostToDevice_(m, dctx)); 11774742e46bSJacob Faibussowitsch } else if (!MatCUPMCast(m)->d_v) { 11784742e46bSJacob Faibussowitsch // write-only 11794742e46bSJacob Faibussowitsch PetscCall(SetPreallocation(m, dctx, nullptr)); 11804742e46bSJacob Faibussowitsch } 11814742e46bSJacob Faibussowitsch *array = MatCUPMCast(m)->d_v; 11824742e46bSJacob Faibussowitsch } 11834742e46bSJacob Faibussowitsch if (PetscMemoryAccessWrite(access)) { 11844742e46bSJacob Faibussowitsch m->offloadmask = hostmem ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU; 11854742e46bSJacob Faibussowitsch PetscCall(PetscObjectStateIncrease(PetscObjectCast(m))); 11864742e46bSJacob Faibussowitsch } 11874742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 11884742e46bSJacob Faibussowitsch } 11894742e46bSJacob Faibussowitsch 11904742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 11914742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode access> 11924742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreArray(Mat m, PetscScalar **array, PetscDeviceContext) noexcept 11934742e46bSJacob Faibussowitsch { 11944742e46bSJacob Faibussowitsch PetscFunctionBegin; 11954742e46bSJacob Faibussowitsch static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), ""); 11964742e46bSJacob Faibussowitsch if (PetscMemoryAccessWrite(access)) { 11974742e46bSJacob Faibussowitsch // WRITE or READ_WRITE 11984742e46bSJacob Faibussowitsch m->offloadmask = PetscMemTypeHost(mtype) ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU; 11994742e46bSJacob Faibussowitsch PetscCall(PetscObjectStateIncrease(PetscObjectCast(m))); 12004742e46bSJacob Faibussowitsch } 120101329d7dSJacob Faibussowitsch if (array) { 120201329d7dSJacob Faibussowitsch PetscCall(CheckPointerMatchesMemType_(*array, mtype)); 12034742e46bSJacob Faibussowitsch *array = nullptr; 120401329d7dSJacob Faibussowitsch } 12054742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 12064742e46bSJacob Faibussowitsch } 12074742e46bSJacob Faibussowitsch 12084742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 12094742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 12104742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetArrayAndMemType(Mat m, PetscScalar **array, PetscMemType *mtype, PetscDeviceContext dctx) noexcept 12114742e46bSJacob Faibussowitsch { 12124742e46bSJacob Faibussowitsch PetscFunctionBegin; 12134742e46bSJacob Faibussowitsch PetscCall(GetArray<PETSC_MEMTYPE_DEVICE, access>(m, array, dctx)); 12144742e46bSJacob Faibussowitsch if (mtype) *mtype = PETSC_MEMTYPE_CUPM(); 12154742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 12164742e46bSJacob Faibussowitsch } 12174742e46bSJacob Faibussowitsch 12184742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 12194742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 12204742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreArrayAndMemType(Mat m, PetscScalar **array, PetscDeviceContext dctx) noexcept 12214742e46bSJacob Faibussowitsch { 12224742e46bSJacob Faibussowitsch PetscFunctionBegin; 12234742e46bSJacob Faibussowitsch PetscCall(RestoreArray<PETSC_MEMTYPE_DEVICE, access>(m, array, dctx)); 12244742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 12254742e46bSJacob Faibussowitsch } 12264742e46bSJacob Faibussowitsch 12274742e46bSJacob Faibussowitsch // ========================================================================================== 12284742e46bSJacob Faibussowitsch 12294742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 12304742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept 12314742e46bSJacob Faibussowitsch { 12324742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 12334742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 12344742e46bSJacob Faibussowitsch 12354742e46bSJacob Faibussowitsch PetscFunctionBegin; 12364742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 12374742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 12384742e46bSJacob Faibussowitsch PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME()); 12394742e46bSJacob Faibussowitsch if (mimpl->v) { 12404742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 12414742e46bSJacob Faibussowitsch 12424742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 12434742e46bSJacob Faibussowitsch PetscCall(HostToDevice_(A, dctx)); 12444742e46bSJacob Faibussowitsch } 12454742e46bSJacob Faibussowitsch mcu->unplacedarray = util::exchange(mcu->d_v, const_cast<PetscScalar *>(array)); 12464742e46bSJacob Faibussowitsch mcu->d_unplaced_user_alloc = util::exchange(mcu->d_user_alloc, PETSC_TRUE); 12474742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 12484742e46bSJacob Faibussowitsch } 12494742e46bSJacob Faibussowitsch 12504742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 12514742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept 12524742e46bSJacob Faibussowitsch { 12534742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 12544742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 12554742e46bSJacob Faibussowitsch 12564742e46bSJacob Faibussowitsch PetscFunctionBegin; 12574742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 12584742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 12594742e46bSJacob Faibussowitsch PetscCheck(!mcu->unplacedarray, PETSC_COMM_SELF, PETSC_ERR_ORDER, "MatDense%sResetArray() must be called first", cupmNAME()); 12604742e46bSJacob Faibussowitsch if (!mcu->d_user_alloc) { 12614742e46bSJacob Faibussowitsch cupmStream_t stream; 12624742e46bSJacob Faibussowitsch 12634742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&stream)); 12644742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_v, stream)); 12654742e46bSJacob Faibussowitsch } 12664742e46bSJacob Faibussowitsch mcu->d_v = const_cast<PetscScalar *>(array); 12674742e46bSJacob Faibussowitsch mcu->d_user_alloc = PETSC_FALSE; 12684742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 12694742e46bSJacob Faibussowitsch } 12704742e46bSJacob Faibussowitsch 12714742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 12724742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ResetArray(Mat A) noexcept 12734742e46bSJacob Faibussowitsch { 12744742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 12754742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 12764742e46bSJacob Faibussowitsch 12774742e46bSJacob Faibussowitsch PetscFunctionBegin; 12784742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 12794742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 12804742e46bSJacob Faibussowitsch if (mimpl->v) { 12814742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 12824742e46bSJacob Faibussowitsch 12834742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 12844742e46bSJacob Faibussowitsch PetscCall(HostToDevice_(A, dctx)); 12854742e46bSJacob Faibussowitsch } 12864742e46bSJacob Faibussowitsch mcu->d_v = util::exchange(mcu->unplacedarray, nullptr); 12874742e46bSJacob Faibussowitsch mcu->d_user_alloc = mcu->d_unplaced_user_alloc; 12884742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 12894742e46bSJacob Faibussowitsch } 12904742e46bSJacob Faibussowitsch 12914742e46bSJacob Faibussowitsch // ========================================================================================== 12924742e46bSJacob Faibussowitsch 12934742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 12944742e46bSJacob Faibussowitsch template <bool transpose_A, bool transpose_B> 12954742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::MatMatMult_Numeric_Dispatch(Mat A, Mat B, Mat C) noexcept 12964742e46bSJacob Faibussowitsch { 12974742e46bSJacob Faibussowitsch cupmBlasInt_t m, n, k; 12984742e46bSJacob Faibussowitsch PetscBool Aiscupm, Biscupm; 12994742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 13004742e46bSJacob Faibussowitsch cupmBlasHandle_t handle; 13014742e46bSJacob Faibussowitsch 13024742e46bSJacob Faibussowitsch PetscFunctionBegin; 13034742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(C->rmap->n, &m)); 13044742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(C->cmap->n, &n)); 13054742e46bSJacob Faibussowitsch PetscCall(PetscCUPMBlasIntCast(transpose_A ? A->rmap->n : A->cmap->n, &k)); 13064742e46bSJacob Faibussowitsch if (!m || !n || !k) PetscFunctionReturn(PETSC_SUCCESS); 13074742e46bSJacob Faibussowitsch 13084742e46bSJacob Faibussowitsch // we may end up with SEQDENSE as one of the arguments 13094742e46bSJacob Faibussowitsch // REVIEW ME: how? and why is it not B and C???????? 13104742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(A), MATSEQDENSECUPM(), &Aiscupm)); 13114742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(B), MATSEQDENSECUPM(), &Biscupm)); 13124742e46bSJacob Faibussowitsch if (!Aiscupm) PetscCall(MatConvert(A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 13134742e46bSJacob Faibussowitsch if (!Biscupm) PetscCall(MatConvert(B, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &B)); 13144742e46bSJacob Faibussowitsch PetscCall(PetscInfo(C, "Matrix-Matrix product %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " x %" PetscBLASInt_FMT " on backend\n", m, k, n)); 13154742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle)); 13164742e46bSJacob Faibussowitsch 13174742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 13184742e46bSJacob Faibussowitsch { 13194742e46bSJacob Faibussowitsch const auto one = cupmScalarCast(1.0); 13204742e46bSJacob Faibussowitsch const auto zero = cupmScalarCast(0.0); 13214742e46bSJacob Faibussowitsch const auto da = DeviceArrayRead(dctx, A); 13224742e46bSJacob Faibussowitsch const auto db = DeviceArrayRead(dctx, B); 13234742e46bSJacob Faibussowitsch const auto dc = DeviceArrayWrite(dctx, C); 13244742e46bSJacob Faibussowitsch PetscInt alda, blda, clda; 13254742e46bSJacob Faibussowitsch 13264742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(A, &alda)); 13274742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(B, &blda)); 13284742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(C, &clda)); 13294742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXgemm(handle, transpose_A ? CUPMBLAS_OP_T : CUPMBLAS_OP_N, transpose_B ? CUPMBLAS_OP_T : CUPMBLAS_OP_N, m, n, k, &one, da.cupmdata(), alda, db.cupmdata(), blda, &zero, dc.cupmdata(), clda)); 13304742e46bSJacob Faibussowitsch } 13314742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 13324742e46bSJacob Faibussowitsch 13334742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(1.0 * m * n * k + 1.0 * m * n * (k - 1))); 13344742e46bSJacob Faibussowitsch if (!Aiscupm) PetscCall(MatConvert(A, MATSEQDENSE, MAT_INPLACE_MATRIX, &A)); 13354742e46bSJacob Faibussowitsch if (!Biscupm) PetscCall(MatConvert(B, MATSEQDENSE, MAT_INPLACE_MATRIX, &B)); 13364742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 13374742e46bSJacob Faibussowitsch } 13384742e46bSJacob Faibussowitsch 13394742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 13404742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Copy(Mat A, Mat B, MatStructure str) noexcept 13414742e46bSJacob Faibussowitsch { 13424742e46bSJacob Faibussowitsch const auto m = A->rmap->n; 13434742e46bSJacob Faibussowitsch const auto n = A->cmap->n; 13444742e46bSJacob Faibussowitsch 13454742e46bSJacob Faibussowitsch PetscFunctionBegin; 13464742e46bSJacob Faibussowitsch PetscAssert(m == B->rmap->n && n == B->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "size(B) != size(A)"); 13474742e46bSJacob Faibussowitsch // The two matrices must have the same copy implementation to be eligible for fast copy 13484742e46bSJacob Faibussowitsch if (A->ops->copy == B->ops->copy) { 13494742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 13504742e46bSJacob Faibussowitsch cupmStream_t stream; 13514742e46bSJacob Faibussowitsch 13524742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 13534742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 13544742e46bSJacob Faibussowitsch { 13554742e46bSJacob Faibussowitsch const auto va = DeviceArrayRead(dctx, A); 13564742e46bSJacob Faibussowitsch const auto vb = DeviceArrayWrite(dctx, B); 13574742e46bSJacob Faibussowitsch // order is important, DeviceArrayRead/Write() might call SetPreallocation() which sets 13584742e46bSJacob Faibussowitsch // lda! 13594742e46bSJacob Faibussowitsch const auto lda_a = MatIMPLCast(A)->lda; 13604742e46bSJacob Faibussowitsch const auto lda_b = MatIMPLCast(B)->lda; 13614742e46bSJacob Faibussowitsch 13624742e46bSJacob Faibussowitsch if (lda_a > m || lda_b > m) { 13634742e46bSJacob Faibussowitsch PetscAssert(lda_b > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "B lda (%" PetscBLASInt_FMT ") must be > 0 at this point, this indicates Mat%sSetPreallocation() was not called when it should have been!", lda_b, cupmNAME()); 13644742e46bSJacob Faibussowitsch PetscAssert(lda_a > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A lda (%" PetscBLASInt_FMT ") must be > 0 at this point, this indicates Mat%sSetPreallocation() was not called when it should have been!", lda_a, cupmNAME()); 13654742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpy2DAsync(vb.data(), lda_b, va.data(), lda_a, m, n, cupmMemcpyDeviceToDevice, stream)); 13664742e46bSJacob Faibussowitsch } else { 13674742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(vb.data(), va.data(), m * n, cupmMemcpyDeviceToDevice, stream)); 13684742e46bSJacob Faibussowitsch } 13694742e46bSJacob Faibussowitsch } 13704742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 13714742e46bSJacob Faibussowitsch } else { 13724742e46bSJacob Faibussowitsch PetscCall(MatCopy_Basic(A, B, str)); 13734742e46bSJacob Faibussowitsch } 13744742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 13754742e46bSJacob Faibussowitsch } 13764742e46bSJacob Faibussowitsch 13774742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 13784742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::ZeroEntries(Mat m) noexcept 13794742e46bSJacob Faibussowitsch { 13804742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 13814742e46bSJacob Faibussowitsch cupmStream_t stream; 13824742e46bSJacob Faibussowitsch 13834742e46bSJacob Faibussowitsch PetscFunctionBegin; 13844742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 13854742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 13864742e46bSJacob Faibussowitsch { 13874742e46bSJacob Faibussowitsch const auto va = DeviceArrayWrite(dctx, m); 13884742e46bSJacob Faibussowitsch const auto lda = MatIMPLCast(m)->lda; 13894742e46bSJacob Faibussowitsch const auto ma = m->rmap->n; 13904742e46bSJacob Faibussowitsch const auto na = m->cmap->n; 13914742e46bSJacob Faibussowitsch 13924742e46bSJacob Faibussowitsch if (lda > ma) { 13934742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemset2DAsync(va.data(), lda, 0, ma, na, stream)); 13944742e46bSJacob Faibussowitsch } else { 13954742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemsetAsync(va.data(), 0, ma * na, stream)); 13964742e46bSJacob Faibussowitsch } 13974742e46bSJacob Faibussowitsch } 13984742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 13994742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 14004742e46bSJacob Faibussowitsch } 14014742e46bSJacob Faibussowitsch 14022ea277ceSJacob Faibussowitsch namespace detail 14032ea277ceSJacob Faibussowitsch { 14042ea277ceSJacob Faibussowitsch 14052ea277ceSJacob Faibussowitsch // ========================================================================================== 14062ea277ceSJacob Faibussowitsch // SubMatIndexFunctor 14072ea277ceSJacob Faibussowitsch // 14082ea277ceSJacob Faibussowitsch // Iterator which permutes a linear index range into matrix indices for am nrows x ncols 14092ea277ceSJacob Faibussowitsch // submat with leading dimension lda. Essentially SubMatIndexFunctor(i) returns the index for 14102ea277ceSJacob Faibussowitsch // the i'th sequential entry in the matrix. 14112ea277ceSJacob Faibussowitsch // ========================================================================================== 14122ea277ceSJacob Faibussowitsch template <typename T> 14132ea277ceSJacob Faibussowitsch struct SubMatIndexFunctor { 14142ea277ceSJacob Faibussowitsch PETSC_HOSTDEVICE_INLINE_DECL T operator()(T x) const noexcept { return ((x / nrows) * lda) + (x % nrows); } 14152ea277ceSJacob Faibussowitsch 14162ea277ceSJacob Faibussowitsch PetscInt nrows; 14172ea277ceSJacob Faibussowitsch PetscInt ncols; 14182ea277ceSJacob Faibussowitsch PetscInt lda; 14192ea277ceSJacob Faibussowitsch }; 14202ea277ceSJacob Faibussowitsch 14212ea277ceSJacob Faibussowitsch template <typename Iterator> 14222ea277ceSJacob Faibussowitsch struct SubMatrixIterator : MatrixIteratorBase<Iterator, SubMatIndexFunctor<typename thrust::iterator_difference<Iterator>::type>> { 14232ea277ceSJacob Faibussowitsch using base_type = MatrixIteratorBase<Iterator, SubMatIndexFunctor<typename thrust::iterator_difference<Iterator>::type>>; 14242ea277ceSJacob Faibussowitsch 14252ea277ceSJacob Faibussowitsch using iterator = typename base_type::iterator; 14262ea277ceSJacob Faibussowitsch 14272ea277ceSJacob Faibussowitsch constexpr SubMatrixIterator(Iterator first, Iterator last, PetscInt nrows, PetscInt ncols, PetscInt lda) noexcept : 14282ea277ceSJacob Faibussowitsch base_type{ 14292ea277ceSJacob Faibussowitsch std::move(first), std::move(last), {nrows, ncols, lda} 14302ea277ceSJacob Faibussowitsch } 14312ea277ceSJacob Faibussowitsch { 14322ea277ceSJacob Faibussowitsch } 14332ea277ceSJacob Faibussowitsch 14342ea277ceSJacob Faibussowitsch PETSC_NODISCARD iterator end() const noexcept { return this->begin() + (this->func.nrows * this->func.ncols); } 14352ea277ceSJacob Faibussowitsch }; 14362ea277ceSJacob Faibussowitsch 14372ea277ceSJacob Faibussowitsch namespace 14382ea277ceSJacob Faibussowitsch { 14392ea277ceSJacob Faibussowitsch 14402ea277ceSJacob Faibussowitsch template <typename T> 14412ea277ceSJacob Faibussowitsch PETSC_NODISCARD inline SubMatrixIterator<typename thrust::device_vector<T>::iterator> make_submat_iterator(PetscInt rstart, PetscInt rend, PetscInt cstart, PetscInt cend, PetscInt lda, T *ptr) noexcept 14422ea277ceSJacob Faibussowitsch { 14432ea277ceSJacob Faibussowitsch const auto nrows = rend - rstart; 14442ea277ceSJacob Faibussowitsch const auto ncols = cend - cstart; 14452ea277ceSJacob Faibussowitsch const auto dptr = thrust::device_pointer_cast(ptr); 14462ea277ceSJacob Faibussowitsch 14472ea277ceSJacob Faibussowitsch return {dptr + (rstart * lda) + cstart, dptr + ((rstart + nrows) * lda) + cstart, nrows, ncols, lda}; 14482ea277ceSJacob Faibussowitsch } 14492ea277ceSJacob Faibussowitsch 14502ea277ceSJacob Faibussowitsch } // namespace 14512ea277ceSJacob Faibussowitsch 14523853def2SToby Isaac struct conjugate { 14533853def2SToby Isaac PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &x) const noexcept { return PetscConj(x); } 14543853def2SToby Isaac }; 14553853def2SToby Isaac 14562ea277ceSJacob Faibussowitsch } // namespace detail 14572ea277ceSJacob Faibussowitsch 14584742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 14593853def2SToby Isaac inline PetscErrorCode MatDense_Seq_CUPM<T>::Conjugate(Mat A) noexcept 14603853def2SToby Isaac { 14613853def2SToby Isaac const auto m = A->rmap->n; 14623853def2SToby Isaac const auto n = A->cmap->n; 14633853def2SToby Isaac const auto N = m * n; 14643853def2SToby Isaac PetscDeviceContext dctx; 14653853def2SToby Isaac cupmStream_t stream; 14663853def2SToby Isaac 14673853def2SToby Isaac PetscFunctionBegin; 14683853def2SToby Isaac if (PetscDefined(USE_COMPLEX)) { 14693853def2SToby Isaac PetscCall(GetHandles_(&dctx, &stream)); 14703853def2SToby Isaac PetscCall(PetscLogGpuTimeBegin()); 14713853def2SToby Isaac { 14723853def2SToby Isaac const auto da = DeviceArrayReadWrite(dctx, A); 14733853def2SToby Isaac const auto lda = MatIMPLCast(A)->lda; 14743853def2SToby Isaac cupmStream_t stream; 14753853def2SToby Isaac PetscCall(GetHandlesFrom_(dctx, &stream)); 14763853def2SToby Isaac 14773853def2SToby Isaac if (lda > m) { 14783853def2SToby Isaac // clang-format off 14793853def2SToby Isaac PetscCallThrust( 14803853def2SToby Isaac const auto sub_mat = detail::make_submat_iterator(0, m, 0, n, lda, da.data()); 14813853def2SToby Isaac 14823853def2SToby Isaac THRUST_CALL( 14833853def2SToby Isaac thrust::transform, 14843853def2SToby Isaac stream, 14853853def2SToby Isaac sub_mat.begin(), sub_mat.end(), sub_mat.begin(), 14863853def2SToby Isaac detail::conjugate{} 14873853def2SToby Isaac ) 14883853def2SToby Isaac ); 14893853def2SToby Isaac // clang-format on 14903853def2SToby Isaac } else { 14913853def2SToby Isaac // clang-format off 14923853def2SToby Isaac PetscCallThrust( 14933853def2SToby Isaac const auto aptr = thrust::device_pointer_cast(da.data()); 14943853def2SToby Isaac 14953853def2SToby Isaac THRUST_CALL( 14963853def2SToby Isaac thrust::transform, 14973853def2SToby Isaac stream, 14983853def2SToby Isaac aptr, aptr + N, aptr, 14993853def2SToby Isaac detail::conjugate{} 15003853def2SToby Isaac ) 15013853def2SToby Isaac ); 15023853def2SToby Isaac // clang-format on 15033853def2SToby Isaac } 15043853def2SToby Isaac } 15053853def2SToby Isaac PetscCall(PetscLogGpuTimeEnd()); 15063853def2SToby Isaac } 15073853def2SToby Isaac PetscFunctionReturn(PETSC_SUCCESS); 15083853def2SToby Isaac } 15093853def2SToby Isaac 15103853def2SToby Isaac template <device::cupm::DeviceType T> 15114742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Scale(Mat A, PetscScalar alpha) noexcept 15124742e46bSJacob Faibussowitsch { 15132ea277ceSJacob Faibussowitsch const auto m = A->rmap->n; 15142ea277ceSJacob Faibussowitsch const auto n = A->cmap->n; 15154742e46bSJacob Faibussowitsch const auto N = m * n; 15164742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 15174742e46bSJacob Faibussowitsch 15184742e46bSJacob Faibussowitsch PetscFunctionBegin; 15192ea277ceSJacob Faibussowitsch PetscCall(PetscInfo(A, "Performing Scale %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", m, n)); 15202ea277ceSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 15214742e46bSJacob Faibussowitsch { 15224742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 15232ea277ceSJacob Faibussowitsch const auto lda = MatIMPLCast(A)->lda; 15244742e46bSJacob Faibussowitsch 15254742e46bSJacob Faibussowitsch if (lda > m) { 15262ea277ceSJacob Faibussowitsch cupmStream_t stream; 15272ea277ceSJacob Faibussowitsch 15282ea277ceSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &stream)); 15292ea277ceSJacob Faibussowitsch // clang-format off 15302ea277ceSJacob Faibussowitsch PetscCallThrust( 15312ea277ceSJacob Faibussowitsch const auto sub_mat = detail::make_submat_iterator(0, m, 0, n, lda, da.data()); 15322ea277ceSJacob Faibussowitsch 15332ea277ceSJacob Faibussowitsch THRUST_CALL( 15342ea277ceSJacob Faibussowitsch thrust::transform, 15352ea277ceSJacob Faibussowitsch stream, 15362ea277ceSJacob Faibussowitsch sub_mat.begin(), sub_mat.end(), sub_mat.begin(), 15372ea277ceSJacob Faibussowitsch device::cupm::functors::make_times_equals(alpha) 15382ea277ceSJacob Faibussowitsch ) 15392ea277ceSJacob Faibussowitsch ); 15402ea277ceSJacob Faibussowitsch // clang-format on 15414742e46bSJacob Faibussowitsch } else { 15422ea277ceSJacob Faibussowitsch const auto cu_alpha = cupmScalarCast(alpha); 15432ea277ceSJacob Faibussowitsch cupmBlasHandle_t handle; 15442ea277ceSJacob Faibussowitsch 15452ea277ceSJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &handle)); 15462ea277ceSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 15474742e46bSJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXscal(handle, N, &cu_alpha, da.cupmdata(), 1)); 15484742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 15492ea277ceSJacob Faibussowitsch } 15502ea277ceSJacob Faibussowitsch } 15514742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(N)); 15524742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 15534742e46bSJacob Faibussowitsch } 15544742e46bSJacob Faibussowitsch 15554742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 15564742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::AXPY(Mat Y, PetscScalar alpha, Mat X, MatStructure) noexcept 15574742e46bSJacob Faibussowitsch { 15584742e46bSJacob Faibussowitsch const auto m_x = X->rmap->n, m_y = Y->rmap->n; 15594742e46bSJacob Faibussowitsch const auto n_x = X->cmap->n, n_y = Y->cmap->n; 1560025e0618SJacob Faibussowitsch const auto N = m_x * n_x; 15614742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 15624742e46bSJacob Faibussowitsch 15634742e46bSJacob Faibussowitsch PetscFunctionBegin; 1564025e0618SJacob Faibussowitsch if (!m_x || !n_x || alpha == (PetscScalar)0.0) PetscFunctionReturn(PETSC_SUCCESS); 15654742e46bSJacob Faibussowitsch PetscCall(PetscInfo(Y, "Performing AXPY %" PetscInt_FMT " x %" PetscInt_FMT " on backend\n", m_y, n_y)); 1566025e0618SJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 15674742e46bSJacob Faibussowitsch { 15684742e46bSJacob Faibussowitsch const auto dx = DeviceArrayRead(dctx, X); 1569025e0618SJacob Faibussowitsch const auto dy = DeviceArrayReadWrite(dctx, Y); 1570025e0618SJacob Faibussowitsch const auto lda_x = MatIMPLCast(X)->lda; 1571025e0618SJacob Faibussowitsch const auto lda_y = MatIMPLCast(Y)->lda; 15724742e46bSJacob Faibussowitsch 1573025e0618SJacob Faibussowitsch if (lda_x > m_x || lda_y > m_x) { 1574025e0618SJacob Faibussowitsch cupmStream_t stream; 1575025e0618SJacob Faibussowitsch 1576025e0618SJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &stream)); 1577025e0618SJacob Faibussowitsch // clang-format off 1578025e0618SJacob Faibussowitsch PetscCallThrust( 1579025e0618SJacob Faibussowitsch const auto sub_mat_y = detail::make_submat_iterator(0, m_y, 0, n_y, lda_y, dy.data()); 1580025e0618SJacob Faibussowitsch const auto sub_mat_x = detail::make_submat_iterator(0, m_x, 0, n_x, lda_x, dx.data()); 1581025e0618SJacob Faibussowitsch 1582025e0618SJacob Faibussowitsch THRUST_CALL( 1583025e0618SJacob Faibussowitsch thrust::transform, 1584025e0618SJacob Faibussowitsch stream, 1585025e0618SJacob Faibussowitsch sub_mat_x.begin(), sub_mat_x.end(), sub_mat_y.begin(), sub_mat_y.begin(), 1586025e0618SJacob Faibussowitsch device::cupm::functors::make_axpy(alpha) 1587025e0618SJacob Faibussowitsch ); 1588025e0618SJacob Faibussowitsch ); 1589025e0618SJacob Faibussowitsch // clang-format on 15904742e46bSJacob Faibussowitsch } else { 1591025e0618SJacob Faibussowitsch const auto cu_alpha = cupmScalarCast(alpha); 1592025e0618SJacob Faibussowitsch cupmBlasHandle_t handle; 1593025e0618SJacob Faibussowitsch 1594025e0618SJacob Faibussowitsch PetscCall(GetHandlesFrom_(dctx, &handle)); 1595025e0618SJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 1596025e0618SJacob Faibussowitsch PetscCallCUPMBLAS(cupmBlasXaxpy(handle, N, &cu_alpha, dx.cupmdata(), 1, dy.cupmdata(), 1)); 15974742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 15984742e46bSJacob Faibussowitsch } 1599025e0618SJacob Faibussowitsch } 1600025e0618SJacob Faibussowitsch PetscCall(PetscLogGpuFlops(PetscMax(2 * N - 1, 0))); 16014742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 16024742e46bSJacob Faibussowitsch } 16034742e46bSJacob Faibussowitsch 16044742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 16054742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::Duplicate(Mat A, MatDuplicateOption opt, Mat *B) noexcept 16064742e46bSJacob Faibussowitsch { 16074742e46bSJacob Faibussowitsch const auto hopt = (opt == MAT_COPY_VALUES && A->offloadmask != PETSC_OFFLOAD_CPU) ? MAT_DO_NOT_COPY_VALUES : opt; 16084742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 16094742e46bSJacob Faibussowitsch 16104742e46bSJacob Faibussowitsch PetscFunctionBegin; 16114742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 16124742e46bSJacob Faibussowitsch // do not call SetPreallocation() yet, we call it afterwards?? 16134742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), A->rmap->n, A->cmap->n, nullptr, B, dctx, /* preallocate */ false)); 16144742e46bSJacob Faibussowitsch PetscCall(MatDuplicateNoCreate_SeqDense(*B, A, hopt)); 16154742e46bSJacob Faibussowitsch if (opt == MAT_COPY_VALUES && hopt != MAT_COPY_VALUES) PetscCall(Copy(A, *B, SAME_NONZERO_PATTERN)); 16164742e46bSJacob Faibussowitsch // allocate memory if needed 16173d9668e3SJacob Faibussowitsch if (opt != MAT_COPY_VALUES && !MatCUPMCast(*B)->d_v) PetscCall(SetPreallocation(*B, dctx, nullptr)); 16184742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 16194742e46bSJacob Faibussowitsch } 16204742e46bSJacob Faibussowitsch 16214742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 16224742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::SetRandom(Mat A, PetscRandom rng) noexcept 16234742e46bSJacob Faibussowitsch { 16243853def2SToby Isaac PetscBool device_rand_is_rander48; 16253853def2SToby Isaac PetscBool device = PETSC_FALSE; 16264742e46bSJacob Faibussowitsch 16274742e46bSJacob Faibussowitsch PetscFunctionBegin; 16283853def2SToby Isaac // CUPMObject<hip>::PETSCDEVICERAD() is PETSCRANDER48 until PetscRandom is implemented for hiprand 16293853def2SToby Isaac PetscCall(PetscStrncmp(PETSCDEVICERAND(), PETSCRANDER48, sizeof(PETSCRANDER48), &device_rand_is_rander48)); 16303853def2SToby Isaac if (!device_rand_is_rander48) PetscCall(PetscObjectTypeCompare(PetscObjectCast(rng), PETSCDEVICERAND(), &device)); 16314742e46bSJacob Faibussowitsch if (device) { 16324742e46bSJacob Faibussowitsch const auto m = A->rmap->n; 16334742e46bSJacob Faibussowitsch const auto n = A->cmap->n; 16344742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 16354742e46bSJacob Faibussowitsch 16364742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 16374742e46bSJacob Faibussowitsch { 16384742e46bSJacob Faibussowitsch const auto a = DeviceArrayWrite(dctx, A); 16394742e46bSJacob Faibussowitsch PetscInt lda; 16404742e46bSJacob Faibussowitsch 16414742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(A, &lda)); 16424742e46bSJacob Faibussowitsch if (lda > m) { 16434742e46bSJacob Faibussowitsch for (PetscInt i = 0; i < n; i++) PetscCall(PetscRandomGetValues(rng, m, a.data() + i * lda)); 16444742e46bSJacob Faibussowitsch } else { 16454742e46bSJacob Faibussowitsch PetscInt mn; 16464742e46bSJacob Faibussowitsch 16474742e46bSJacob Faibussowitsch PetscCall(PetscIntMultError(m, n, &mn)); 16484742e46bSJacob Faibussowitsch PetscCall(PetscRandomGetValues(rng, mn, a)); 16494742e46bSJacob Faibussowitsch } 16504742e46bSJacob Faibussowitsch } 16514742e46bSJacob Faibussowitsch } else { 16524742e46bSJacob Faibussowitsch PetscCall(MatSetRandom_SeqDense(A, rng)); 16534742e46bSJacob Faibussowitsch } 16544742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 16554742e46bSJacob Faibussowitsch } 16564742e46bSJacob Faibussowitsch 16574742e46bSJacob Faibussowitsch // ========================================================================================== 16584742e46bSJacob Faibussowitsch 16594742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 16604742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetColumnVector(Mat A, Vec v, PetscInt col) noexcept 16614742e46bSJacob Faibussowitsch { 16624742e46bSJacob Faibussowitsch const auto offloadmask = A->offloadmask; 16634742e46bSJacob Faibussowitsch const auto n = A->rmap->n; 16644742e46bSJacob Faibussowitsch const auto col_offset = [&](const PetscScalar *ptr) { return ptr + col * MatIMPLCast(A)->lda; }; 16654742e46bSJacob Faibussowitsch PetscBool viscupm; 16664742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 16674742e46bSJacob Faibussowitsch cupmStream_t stream; 16684742e46bSJacob Faibussowitsch 16694742e46bSJacob Faibussowitsch PetscFunctionBegin; 16704742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(v), &viscupm, VecSeq_CUPM::VECSEQCUPM(), VecSeq_CUPM::VECMPICUPM(), VecSeq_CUPM::VECCUPM(), "")); 16714742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &stream)); 16724742e46bSJacob Faibussowitsch if (viscupm && !v->boundtocpu) { 16734742e46bSJacob Faibussowitsch const auto x = VecSeq_CUPM::DeviceArrayWrite(dctx, v); 16744742e46bSJacob Faibussowitsch 16754742e46bSJacob Faibussowitsch // update device data 16764742e46bSJacob Faibussowitsch if (PetscOffloadDevice(offloadmask)) { 16774742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(x.data(), col_offset(DeviceArrayRead(dctx, A)), n, cupmMemcpyDeviceToDevice, stream)); 16784742e46bSJacob Faibussowitsch } else { 16794742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(x.data(), col_offset(HostArrayRead(dctx, A)), n, cupmMemcpyHostToDevice, stream)); 16804742e46bSJacob Faibussowitsch } 16814742e46bSJacob Faibussowitsch } else { 16824742e46bSJacob Faibussowitsch PetscScalar *x; 16834742e46bSJacob Faibussowitsch 16844742e46bSJacob Faibussowitsch // update host data 16854742e46bSJacob Faibussowitsch PetscCall(VecGetArrayWrite(v, &x)); 16864742e46bSJacob Faibussowitsch if (PetscOffloadUnallocated(offloadmask) || PetscOffloadHost(offloadmask)) { 16874742e46bSJacob Faibussowitsch PetscCall(PetscArraycpy(x, col_offset(HostArrayRead(dctx, A)), n)); 16884742e46bSJacob Faibussowitsch } else if (PetscOffloadDevice(offloadmask)) { 16894742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMemcpyAsync(x, col_offset(DeviceArrayRead(dctx, A)), n, cupmMemcpyDeviceToHost, stream)); 16904742e46bSJacob Faibussowitsch } 16914742e46bSJacob Faibussowitsch PetscCall(VecRestoreArrayWrite(v, &x)); 16924742e46bSJacob Faibussowitsch } 16934742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 16944742e46bSJacob Faibussowitsch } 16954742e46bSJacob Faibussowitsch 16964742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 16974742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 16984742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept 16994742e46bSJacob Faibussowitsch { 17004742e46bSJacob Faibussowitsch using namespace vec::cupm; 17014742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 17024742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 17034742e46bSJacob Faibussowitsch 17044742e46bSJacob Faibussowitsch PetscFunctionBegin; 17054742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 17064742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 17074742e46bSJacob Faibussowitsch mimpl->vecinuse = col + 1; 1708d16ceb75SStefano Zampini if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec)); 17094742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 17104742e46bSJacob Faibussowitsch PetscCall(GetArray<PETSC_MEMTYPE_DEVICE, access>(A, const_cast<PetscScalar **>(&mimpl->ptrinuse), dctx)); 17114742e46bSJacob Faibussowitsch PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(mimpl->lda))); 17124742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec)); 17134742e46bSJacob Faibussowitsch *v = mimpl->cvec; 17144742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17154742e46bSJacob Faibussowitsch } 17164742e46bSJacob Faibussowitsch 17174742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 17184742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 17194742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept 17204742e46bSJacob Faibussowitsch { 17214742e46bSJacob Faibussowitsch using namespace vec::cupm; 17224742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 17234742e46bSJacob Faibussowitsch const auto cvec = mimpl->cvec; 17244742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 17254742e46bSJacob Faibussowitsch 17264742e46bSJacob Faibussowitsch PetscFunctionBegin; 17274742e46bSJacob Faibussowitsch PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first"); 17284742e46bSJacob Faibussowitsch PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector"); 17294742e46bSJacob Faibussowitsch mimpl->vecinuse = 0; 17304742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec)); 17314742e46bSJacob Faibussowitsch PetscCall(VecCUPMResetArrayAsync<T>(cvec)); 17324742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 17334742e46bSJacob Faibussowitsch PetscCall(RestoreArray<PETSC_MEMTYPE_DEVICE, access>(A, const_cast<PetscScalar **>(&mimpl->ptrinuse), dctx)); 17344742e46bSJacob Faibussowitsch if (v) *v = nullptr; 17354742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17364742e46bSJacob Faibussowitsch } 17374742e46bSJacob Faibussowitsch 17384742e46bSJacob Faibussowitsch // ========================================================================================== 17394742e46bSJacob Faibussowitsch 17404742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 17414742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetFactor(Mat A, MatFactorType ftype, Mat *fact_out) noexcept 17424742e46bSJacob Faibussowitsch { 1743c72d0b4cSJacob Faibussowitsch Mat fact = nullptr; 17444742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 17454742e46bSJacob Faibussowitsch 17464742e46bSJacob Faibussowitsch PetscFunctionBegin; 17474742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 17484742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), A->rmap->n, A->cmap->n, nullptr, &fact, dctx, /* preallocate */ false)); 17494742e46bSJacob Faibussowitsch fact->factortype = ftype; 17504742e46bSJacob Faibussowitsch switch (ftype) { 17514742e46bSJacob Faibussowitsch case MAT_FACTOR_LU: 17524742e46bSJacob Faibussowitsch case MAT_FACTOR_ILU: // fall-through 17534742e46bSJacob Faibussowitsch fact->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqDense; 17544742e46bSJacob Faibussowitsch fact->ops->ilufactorsymbolic = MatLUFactorSymbolic_SeqDense; 17554742e46bSJacob Faibussowitsch break; 17564742e46bSJacob Faibussowitsch case MAT_FACTOR_CHOLESKY: 17574742e46bSJacob Faibussowitsch case MAT_FACTOR_ICC: // fall-through 17584742e46bSJacob Faibussowitsch fact->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqDense; 17594742e46bSJacob Faibussowitsch break; 17604742e46bSJacob Faibussowitsch case MAT_FACTOR_QR: { 17614742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(fact); 17624742e46bSJacob Faibussowitsch 17634742e46bSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction(pobj, "MatQRFactor_C", MatQRFactor_SeqDense)); 17644742e46bSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction(pobj, "MatQRFactorSymbolic_C", MatQRFactorSymbolic_SeqDense)); 17654742e46bSJacob Faibussowitsch } break; 17664742e46bSJacob Faibussowitsch case MAT_FACTOR_NONE: 17674742e46bSJacob Faibussowitsch case MAT_FACTOR_ILUDT: // fall-through 17684742e46bSJacob Faibussowitsch case MAT_FACTOR_NUM_TYPES: // fall-through 17694742e46bSJacob Faibussowitsch SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s not supported", MatFactorTypes[ftype]); 17704742e46bSJacob Faibussowitsch } 17714742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(MATSOLVERCUPM(), &fact->solvertype)); 17724742e46bSJacob Faibussowitsch PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_LU)); 17734742e46bSJacob Faibussowitsch PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_ILU)); 17744742e46bSJacob Faibussowitsch PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_CHOLESKY)); 17754742e46bSJacob Faibussowitsch PetscCall(PetscStrallocpy(MATORDERINGEXTERNAL, const_cast<char **>(fact->preferredordering) + MAT_FACTOR_ICC)); 17764742e46bSJacob Faibussowitsch *fact_out = fact; 17774742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17784742e46bSJacob Faibussowitsch } 17794742e46bSJacob Faibussowitsch 17804742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 17814742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::InvertFactors(Mat A) noexcept 17824742e46bSJacob Faibussowitsch { 17834742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 17844742e46bSJacob Faibussowitsch const auto mcu = MatCUPMCast(A); 17854742e46bSJacob Faibussowitsch const auto n = static_cast<cupmBlasInt_t>(A->cmap->n); 17864742e46bSJacob Faibussowitsch cupmSolverHandle_t handle; 17874742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 17884742e46bSJacob Faibussowitsch cupmStream_t stream; 17894742e46bSJacob Faibussowitsch 17904742e46bSJacob Faibussowitsch PetscFunctionBegin; 17914742e46bSJacob Faibussowitsch #if PetscDefined(HAVE_CUDA) && PetscDefined(USING_NVCC) 17924742e46bSJacob Faibussowitsch // HIP appears to have this by default?? 17934742e46bSJacob Faibussowitsch PetscCheck(PETSC_PKG_CUDA_VERSION_GE(10, 1, 0), PETSC_COMM_SELF, PETSC_ERR_SUP, "Upgrade to CUDA version 10.1.0 or higher"); 17944742e46bSJacob Faibussowitsch #endif 17954742e46bSJacob Faibussowitsch if (!n || !A->rmap->n) PetscFunctionReturn(PETSC_SUCCESS); 17964742e46bSJacob Faibussowitsch PetscCheck(A->factortype == MAT_FACTOR_CHOLESKY, PETSC_COMM_SELF, PETSC_ERR_LIB, "Factor type %s not implemented", MatFactorTypes[A->factortype]); 17974742e46bSJacob Faibussowitsch // spd 17984742e46bSJacob Faibussowitsch PetscCheck(!mcu->d_fact_ipiv, PETSC_COMM_SELF, PETSC_ERR_LIB, "%sDnsytri not implemented", cupmSolverName()); 17994742e46bSJacob Faibussowitsch 18004742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx, &handle, &stream)); 18014742e46bSJacob Faibussowitsch { 18024742e46bSJacob Faibussowitsch const auto da = DeviceArrayReadWrite(dctx, A); 18034742e46bSJacob Faibussowitsch const auto lda = static_cast<cupmBlasInt_t>(mimpl->lda); 18044742e46bSJacob Faibussowitsch cupmBlasInt_t il; 18054742e46bSJacob Faibussowitsch 18064742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXpotri_bufferSize(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, &il)); 18074742e46bSJacob Faibussowitsch if (il > mcu->d_fact_lwork) { 18084742e46bSJacob Faibussowitsch mcu->d_fact_lwork = il; 18094742e46bSJacob Faibussowitsch PetscCallCUPM(cupmFreeAsync(mcu->d_fact_work, stream)); 18104742e46bSJacob Faibussowitsch PetscCall(PetscCUPMMallocAsync(&mcu->d_fact_work, il, stream)); 18114742e46bSJacob Faibussowitsch } 18124742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeBegin()); 18134742e46bSJacob Faibussowitsch PetscCallCUPMSOLVER(cupmSolverXpotri(handle, CUPMSOLVER_FILL_MODE_LOWER, n, da.cupmdata(), lda, mcu->d_fact_work, mcu->d_fact_lwork, mcu->d_fact_info)); 18144742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuTimeEnd()); 18154742e46bSJacob Faibussowitsch } 18164742e46bSJacob Faibussowitsch PetscCall(CheckCUPMSolverInfo_(mcu->d_fact_info, stream)); 18174742e46bSJacob Faibussowitsch // TODO (write cuda kernel) 18184742e46bSJacob Faibussowitsch PetscCall(MatSeqDenseSymmetrize_Private(A, PETSC_TRUE)); 18194742e46bSJacob Faibussowitsch PetscCall(PetscLogGpuFlops(1.0 * n * n * n / 3.0)); 18204742e46bSJacob Faibussowitsch 18214742e46bSJacob Faibussowitsch A->ops->solve = nullptr; 18224742e46bSJacob Faibussowitsch A->ops->solvetranspose = nullptr; 18234742e46bSJacob Faibussowitsch A->ops->matsolve = nullptr; 18244742e46bSJacob Faibussowitsch A->factortype = MAT_FACTOR_NONE; 18254742e46bSJacob Faibussowitsch 18264742e46bSJacob Faibussowitsch PetscCall(PetscFree(A->solvertype)); 18274742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 18284742e46bSJacob Faibussowitsch } 18294742e46bSJacob Faibussowitsch 18304742e46bSJacob Faibussowitsch // ========================================================================================== 18314742e46bSJacob Faibussowitsch 18324742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 18334742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::GetSubMatrix(Mat A, PetscInt rbegin, PetscInt rend, PetscInt cbegin, PetscInt cend, Mat *mat) noexcept 18344742e46bSJacob Faibussowitsch { 18354742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 18364742e46bSJacob Faibussowitsch const auto array_offset = [&](PetscScalar *ptr) { return ptr + rbegin + static_cast<std::size_t>(cbegin) * mimpl->lda; }; 18374742e46bSJacob Faibussowitsch const auto n = rend - rbegin; 18384742e46bSJacob Faibussowitsch const auto m = cend - cbegin; 18394742e46bSJacob Faibussowitsch auto &cmat = mimpl->cmat; 18404742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 18414742e46bSJacob Faibussowitsch 18424742e46bSJacob Faibussowitsch PetscFunctionBegin; 18434742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 18444742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 18454742e46bSJacob Faibussowitsch mimpl->matinuse = cbegin + 1; 18464742e46bSJacob Faibussowitsch 18474742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 18484742e46bSJacob Faibussowitsch PetscCall(HostToDevice_(A, dctx)); 18494742e46bSJacob Faibussowitsch 18504742e46bSJacob Faibussowitsch if (cmat && ((m != cmat->cmap->N) || (n != cmat->rmap->N))) PetscCall(MatDestroy(&cmat)); 18514742e46bSJacob Faibussowitsch { 18524742e46bSJacob Faibussowitsch const auto device_array = array_offset(MatCUPMCast(A)->d_v); 18534742e46bSJacob Faibussowitsch 18544742e46bSJacob Faibussowitsch if (cmat) { 18554742e46bSJacob Faibussowitsch PetscCall(PlaceArray(cmat, device_array)); 18564742e46bSJacob Faibussowitsch } else { 18574742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PetscObjectComm(PetscObjectCast(A)), n, m, device_array, &cmat, dctx)); 18584742e46bSJacob Faibussowitsch } 18594742e46bSJacob Faibussowitsch } 18604742e46bSJacob Faibussowitsch PetscCall(MatDenseSetLDA(cmat, mimpl->lda)); 18614742e46bSJacob Faibussowitsch // place CPU array if present but do not copy any data 18624742e46bSJacob Faibussowitsch if (const auto host_array = mimpl->v) { 18634742e46bSJacob Faibussowitsch cmat->offloadmask = PETSC_OFFLOAD_GPU; 18644742e46bSJacob Faibussowitsch PetscCall(MatDensePlaceArray(cmat, array_offset(host_array))); 18654742e46bSJacob Faibussowitsch } 18664742e46bSJacob Faibussowitsch 18674742e46bSJacob Faibussowitsch cmat->offloadmask = A->offloadmask; 18684742e46bSJacob Faibussowitsch *mat = cmat; 18694742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 18704742e46bSJacob Faibussowitsch } 18714742e46bSJacob Faibussowitsch 18724742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 18734742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_Seq_CUPM<T>::RestoreSubMatrix(Mat A, Mat *m) noexcept 18744742e46bSJacob Faibussowitsch { 18754742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 18764742e46bSJacob Faibussowitsch const auto cmat = mimpl->cmat; 18774742e46bSJacob Faibussowitsch const auto reset = static_cast<bool>(mimpl->v); 18784742e46bSJacob Faibussowitsch bool copy, was_offload_host; 18794742e46bSJacob Faibussowitsch 18804742e46bSJacob Faibussowitsch PetscFunctionBegin; 18814742e46bSJacob Faibussowitsch PetscCheck(mimpl->matinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetSubMatrix() first"); 18824742e46bSJacob Faibussowitsch PetscCheck(cmat, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column matrix"); 18834742e46bSJacob Faibussowitsch PetscCheck(*m == cmat, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Not the matrix obtained from MatDenseGetSubMatrix()"); 18844742e46bSJacob Faibussowitsch mimpl->matinuse = 0; 18854742e46bSJacob Faibussowitsch 18864742e46bSJacob Faibussowitsch // calls to ResetArray may change it, so save it here 18874742e46bSJacob Faibussowitsch was_offload_host = cmat->offloadmask == PETSC_OFFLOAD_CPU; 18884742e46bSJacob Faibussowitsch if (was_offload_host && !reset) { 18894742e46bSJacob Faibussowitsch copy = true; 18904742e46bSJacob Faibussowitsch PetscCall(MatSeqDenseSetPreallocation(A, nullptr)); 18914742e46bSJacob Faibussowitsch } else { 18924742e46bSJacob Faibussowitsch copy = false; 18934742e46bSJacob Faibussowitsch } 18944742e46bSJacob Faibussowitsch 18954742e46bSJacob Faibussowitsch PetscCall(ResetArray(cmat)); 18964742e46bSJacob Faibussowitsch if (reset) PetscCall(MatDenseResetArray(cmat)); 18974742e46bSJacob Faibussowitsch if (copy) { 18984742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 18994742e46bSJacob Faibussowitsch 19004742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 19014742e46bSJacob Faibussowitsch PetscCall(DeviceToHost_(A, dctx)); 19024742e46bSJacob Faibussowitsch } else { 19034742e46bSJacob Faibussowitsch A->offloadmask = was_offload_host ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU; 19044742e46bSJacob Faibussowitsch } 19054742e46bSJacob Faibussowitsch 19064742e46bSJacob Faibussowitsch cmat->offloadmask = PETSC_OFFLOAD_UNALLOCATED; 19074742e46bSJacob Faibussowitsch *m = nullptr; 19084742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 19094742e46bSJacob Faibussowitsch } 19104742e46bSJacob Faibussowitsch 19114742e46bSJacob Faibussowitsch // ========================================================================================== 19124742e46bSJacob Faibussowitsch 19134742e46bSJacob Faibussowitsch namespace 19144742e46bSJacob Faibussowitsch { 19154742e46bSJacob Faibussowitsch 19164742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 19174742e46bSJacob Faibussowitsch inline PetscErrorCode MatMatMultNumeric_SeqDenseCUPM_SeqDenseCUPM(Mat A, Mat B, Mat C, PetscBool TA, PetscBool TB) noexcept 19184742e46bSJacob Faibussowitsch { 19194742e46bSJacob Faibussowitsch PetscFunctionBegin; 19204742e46bSJacob Faibussowitsch if (TA) { 19214742e46bSJacob Faibussowitsch if (TB) { 19224742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<true, true>(A, B, C)); 19234742e46bSJacob Faibussowitsch } else { 19244742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<true, false>(A, B, C)); 19254742e46bSJacob Faibussowitsch } 19264742e46bSJacob Faibussowitsch } else { 19274742e46bSJacob Faibussowitsch if (TB) { 19284742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<false, true>(A, B, C)); 19294742e46bSJacob Faibussowitsch } else { 19304742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::template MatMatMult_Numeric_Dispatch<false, false>(A, B, C)); 19314742e46bSJacob Faibussowitsch } 19324742e46bSJacob Faibussowitsch } 19334742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 19344742e46bSJacob Faibussowitsch } 19354742e46bSJacob Faibussowitsch 19364742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 19374742e46bSJacob Faibussowitsch inline PetscErrorCode MatSolverTypeRegister_DENSECUPM() noexcept 19384742e46bSJacob Faibussowitsch { 19394742e46bSJacob Faibussowitsch PetscFunctionBegin; 19404742e46bSJacob Faibussowitsch for (auto ftype : util::make_array(MAT_FACTOR_LU, MAT_FACTOR_CHOLESKY, MAT_FACTOR_QR)) { 19414742e46bSJacob Faibussowitsch PetscCall(MatSolverTypeRegister(MatDense_Seq_CUPM<T>::MATSOLVERCUPM(), MATSEQDENSE, ftype, MatDense_Seq_CUPM<T>::GetFactor)); 19424742e46bSJacob Faibussowitsch PetscCall(MatSolverTypeRegister(MatDense_Seq_CUPM<T>::MATSOLVERCUPM(), MatDense_Seq_CUPM<T>::MATSEQDENSECUPM(), ftype, MatDense_Seq_CUPM<T>::GetFactor)); 19434742e46bSJacob Faibussowitsch } 19444742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 19454742e46bSJacob Faibussowitsch } 19464742e46bSJacob Faibussowitsch 19474742e46bSJacob Faibussowitsch } // anonymous namespace 19484742e46bSJacob Faibussowitsch 19494742e46bSJacob Faibussowitsch } // namespace impl 19504742e46bSJacob Faibussowitsch 19514742e46bSJacob Faibussowitsch } // namespace cupm 19524742e46bSJacob Faibussowitsch 19534742e46bSJacob Faibussowitsch } // namespace mat 19544742e46bSJacob Faibussowitsch 19554742e46bSJacob Faibussowitsch } // namespace Petsc 1956