14742e46bSJacob Faibussowitsch #ifndef PETSCMATMPIDENSECUPM_HPP 24742e46bSJacob Faibussowitsch #define PETSCMATMPIDENSECUPM_HPP 34742e46bSJacob Faibussowitsch 44742e46bSJacob Faibussowitsch #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/ 54742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/mpi/mpidense.h> 64742e46bSJacob Faibussowitsch 74742e46bSJacob Faibussowitsch #ifdef __cplusplus 84742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp> 94742e46bSJacob Faibussowitsch #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp> 104742e46bSJacob Faibussowitsch 114742e46bSJacob Faibussowitsch namespace Petsc 124742e46bSJacob Faibussowitsch { 134742e46bSJacob Faibussowitsch 144742e46bSJacob Faibussowitsch namespace mat 154742e46bSJacob Faibussowitsch { 164742e46bSJacob Faibussowitsch 174742e46bSJacob Faibussowitsch namespace cupm 184742e46bSJacob Faibussowitsch { 194742e46bSJacob Faibussowitsch 204742e46bSJacob Faibussowitsch namespace impl 214742e46bSJacob Faibussowitsch { 224742e46bSJacob Faibussowitsch 234742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 244742e46bSJacob Faibussowitsch class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> { 254742e46bSJacob Faibussowitsch public: 264742e46bSJacob Faibussowitsch MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>); 274742e46bSJacob Faibussowitsch 284742e46bSJacob Faibussowitsch private: 294742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept; 304742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr MatType MATIMPLCUPM_() noexcept; 314742e46bSJacob Faibussowitsch 324742e46bSJacob Faibussowitsch static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept; 334742e46bSJacob Faibussowitsch 344742e46bSJacob Faibussowitsch template <bool to_host> 354742e46bSJacob Faibussowitsch static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept; 364742e46bSJacob Faibussowitsch 374742e46bSJacob Faibussowitsch public: 384742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept; 394742e46bSJacob Faibussowitsch 404742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept; 414742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept; 424742e46bSJacob Faibussowitsch 434742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept; 444742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept; 454742e46bSJacob Faibussowitsch 464742e46bSJacob Faibussowitsch static PetscErrorCode Create(Mat) noexcept; 474742e46bSJacob Faibussowitsch 484742e46bSJacob Faibussowitsch static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept; 494742e46bSJacob Faibussowitsch static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept; 504742e46bSJacob Faibussowitsch static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept; 514742e46bSJacob Faibussowitsch 524742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode> 534742e46bSJacob Faibussowitsch static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept; 544742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode> 554742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept; 564742e46bSJacob Faibussowitsch 574742e46bSJacob Faibussowitsch private: 584742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode> 594742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept 604742e46bSJacob Faibussowitsch { 614742e46bSJacob Faibussowitsch return GetArray<mtype, mode>(m, p); 624742e46bSJacob Faibussowitsch } 634742e46bSJacob Faibussowitsch 644742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode> 654742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept 664742e46bSJacob Faibussowitsch { 674742e46bSJacob Faibussowitsch return RestoreArray<mtype, mode>(m, p); 684742e46bSJacob Faibussowitsch } 694742e46bSJacob Faibussowitsch 704742e46bSJacob Faibussowitsch public: 714742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 724742e46bSJacob Faibussowitsch static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept; 734742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 744742e46bSJacob Faibussowitsch static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept; 754742e46bSJacob Faibussowitsch 764742e46bSJacob Faibussowitsch static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept; 774742e46bSJacob Faibussowitsch static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept; 784742e46bSJacob Faibussowitsch static PetscErrorCode ResetArray(Mat) noexcept; 794742e46bSJacob Faibussowitsch 804742e46bSJacob Faibussowitsch static PetscErrorCode Shift(Mat, PetscScalar) noexcept; 814742e46bSJacob Faibussowitsch }; 824742e46bSJacob Faibussowitsch 834742e46bSJacob Faibussowitsch } // namespace impl 844742e46bSJacob Faibussowitsch 854742e46bSJacob Faibussowitsch namespace 864742e46bSJacob Faibussowitsch { 874742e46bSJacob Faibussowitsch 884742e46bSJacob Faibussowitsch // Declare this here so that the functions below can make use of it 894742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 904742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept 914742e46bSJacob Faibussowitsch { 924742e46bSJacob Faibussowitsch PetscFunctionBegin; 934742e46bSJacob Faibussowitsch PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate)); 944742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 954742e46bSJacob Faibussowitsch } 964742e46bSJacob Faibussowitsch 974742e46bSJacob Faibussowitsch } // anonymous namespace 984742e46bSJacob Faibussowitsch 994742e46bSJacob Faibussowitsch namespace impl 1004742e46bSJacob Faibussowitsch { 1014742e46bSJacob Faibussowitsch 1024742e46bSJacob Faibussowitsch // ========================================================================================== 1034742e46bSJacob Faibussowitsch // MatDense_MPI_CUPM -- Private API 1044742e46bSJacob Faibussowitsch // ========================================================================================== 1054742e46bSJacob Faibussowitsch 1064742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1074742e46bSJacob Faibussowitsch inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept 1084742e46bSJacob Faibussowitsch { 1094742e46bSJacob Faibussowitsch return static_cast<Mat_MPIDense *>(m->data); 1104742e46bSJacob Faibussowitsch } 1114742e46bSJacob Faibussowitsch 1124742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1134742e46bSJacob Faibussowitsch inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept 1144742e46bSJacob Faibussowitsch { 1154742e46bSJacob Faibussowitsch return MATMPIDENSECUPM(); 1164742e46bSJacob Faibussowitsch } 1174742e46bSJacob Faibussowitsch 1184742e46bSJacob Faibussowitsch // ========================================================================================== 1194742e46bSJacob Faibussowitsch 1204742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1214742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept 1224742e46bSJacob Faibussowitsch { 1234742e46bSJacob Faibussowitsch PetscFunctionBegin; 1244742e46bSJacob Faibussowitsch if (auto &mimplA = MatIMPLCast(A)->A) { 1254742e46bSJacob Faibussowitsch PetscCall(MatSetType(mimplA, MATSEQDENSECUPM())); 1264742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array)); 1274742e46bSJacob Faibussowitsch } else { 1284742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx)); 1294742e46bSJacob Faibussowitsch } 1304742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1314742e46bSJacob Faibussowitsch } 1324742e46bSJacob Faibussowitsch 1334742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 1344742e46bSJacob Faibussowitsch template <bool to_host> 1354742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept 1364742e46bSJacob Faibussowitsch { 1374742e46bSJacob Faibussowitsch PetscFunctionBegin; 1384742e46bSJacob Faibussowitsch if (reuse == MAT_INITIAL_MATRIX) { 1394742e46bSJacob Faibussowitsch PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat)); 1404742e46bSJacob Faibussowitsch } else if (reuse == MAT_REUSE_MATRIX) { 1414742e46bSJacob Faibussowitsch PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN)); 1424742e46bSJacob Faibussowitsch } 1434742e46bSJacob Faibussowitsch { 1444742e46bSJacob Faibussowitsch const auto B = *newmat; 1454742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(B); 1464742e46bSJacob Faibussowitsch 1474742e46bSJacob Faibussowitsch if (to_host) { 1484742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_TRUE)); 1494742e46bSJacob Faibussowitsch } else { 1504742e46bSJacob Faibussowitsch PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 1514742e46bSJacob Faibussowitsch } 1524742e46bSJacob Faibussowitsch 1534742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype)); 1544742e46bSJacob Faibussowitsch PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM())); 1554742e46bSJacob Faibussowitsch 1564742e46bSJacob Faibussowitsch // ============================================================ 1574742e46bSJacob Faibussowitsch // Composed Ops 1584742e46bSJacob Faibussowitsch // ============================================================ 1594742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense); 1604742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense); 1614742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense); 1624742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ); 1634742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ); 1644742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 1654742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 1664742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 1674742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 1684742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 1694742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 1704742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray); 1714742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray); 1724742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray); 1734742e46bSJacob Faibussowitsch 1744742e46bSJacob Faibussowitsch if (to_host) { 1754742e46bSJacob Faibussowitsch if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A)); 1764742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_CPU; 1774742e46bSJacob Faibussowitsch } else { 1784742e46bSJacob Faibussowitsch if (auto &m_A = MatIMPLCast(B)->A) { 1794742e46bSJacob Faibussowitsch PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A)); 1804742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_BOTH; 1814742e46bSJacob Faibussowitsch } else { 1824742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_UNALLOCATED; 1834742e46bSJacob Faibussowitsch } 1844742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_FALSE)); 1854742e46bSJacob Faibussowitsch } 1864742e46bSJacob Faibussowitsch 1874742e46bSJacob Faibussowitsch // ============================================================ 1884742e46bSJacob Faibussowitsch // Function Pointer Ops 1894742e46bSJacob Faibussowitsch // ============================================================ 1904742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU); 1914742e46bSJacob Faibussowitsch } 1924742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1934742e46bSJacob Faibussowitsch } 1944742e46bSJacob Faibussowitsch 1954742e46bSJacob Faibussowitsch // ========================================================================================== 1964742e46bSJacob Faibussowitsch // MatDense_MPI_CUPM -- Public API 1974742e46bSJacob Faibussowitsch // ========================================================================================== 1984742e46bSJacob Faibussowitsch 1994742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2004742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept 2014742e46bSJacob Faibussowitsch { 2024742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C"; 2034742e46bSJacob Faibussowitsch } 2044742e46bSJacob Faibussowitsch 2054742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2064742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept 2074742e46bSJacob Faibussowitsch { 2084742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C"; 2094742e46bSJacob Faibussowitsch } 2104742e46bSJacob Faibussowitsch 2114742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2124742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept 2134742e46bSJacob Faibussowitsch { 2144742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C"; 2154742e46bSJacob Faibussowitsch } 2164742e46bSJacob Faibussowitsch 2174742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2184742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept 2194742e46bSJacob Faibussowitsch { 2204742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C"; 2214742e46bSJacob Faibussowitsch } 2224742e46bSJacob Faibussowitsch 2234742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2244742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept 2254742e46bSJacob Faibussowitsch { 2264742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C"; 2274742e46bSJacob Faibussowitsch } 2284742e46bSJacob Faibussowitsch 2294742e46bSJacob Faibussowitsch // ========================================================================================== 2304742e46bSJacob Faibussowitsch 2314742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2324742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept 2334742e46bSJacob Faibussowitsch { 2344742e46bSJacob Faibussowitsch PetscFunctionBegin; 2354742e46bSJacob Faibussowitsch PetscCall(MatCreate_MPIDense(A)); 2364742e46bSJacob Faibussowitsch PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 2374742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 2384742e46bSJacob Faibussowitsch } 2394742e46bSJacob Faibussowitsch 2404742e46bSJacob Faibussowitsch // ========================================================================================== 2414742e46bSJacob Faibussowitsch 2424742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2434742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept 2444742e46bSJacob Faibussowitsch { 2454742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 2464742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(A); 2474742e46bSJacob Faibussowitsch 2484742e46bSJacob Faibussowitsch PetscFunctionBegin; 2494742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 2504742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 2514742e46bSJacob Faibussowitsch if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost)); 2524742e46bSJacob Faibussowitsch A->boundtocpu = usehost; 2534742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype)); 2544742e46bSJacob Faibussowitsch if (!usehost) { 2554742e46bSJacob Faibussowitsch PetscBool iscupm; 2564742e46bSJacob Faibussowitsch 2574742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm)); 2584742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec)); 2594742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm)); 2604742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat)); 2614742e46bSJacob Faibussowitsch } 2624742e46bSJacob Faibussowitsch 2634742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 2644742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 2654742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>); 2664742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>); 2674742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 2684742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 2694742e46bSJacob Faibussowitsch 2704742e46bSJacob Faibussowitsch MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift); 2714742e46bSJacob Faibussowitsch 2724742e46bSJacob Faibussowitsch if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost)); 2734742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 2744742e46bSJacob Faibussowitsch } 2754742e46bSJacob Faibussowitsch 2764742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2774742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept 2784742e46bSJacob Faibussowitsch { 2794742e46bSJacob Faibussowitsch PetscFunctionBegin; 2804742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat)); 2814742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 2824742e46bSJacob Faibussowitsch } 2834742e46bSJacob Faibussowitsch 2844742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2854742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept 2864742e46bSJacob Faibussowitsch { 2874742e46bSJacob Faibussowitsch PetscFunctionBegin; 2884742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat)); 2894742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 2904742e46bSJacob Faibussowitsch } 2914742e46bSJacob Faibussowitsch 2924742e46bSJacob Faibussowitsch // ========================================================================================== 2934742e46bSJacob Faibussowitsch 2944742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 2954742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode access> 2964742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept 2974742e46bSJacob Faibussowitsch { 2984742e46bSJacob Faibussowitsch PetscFunctionBegin; 2994742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMGetArray_Private<T, access>(MatIMPLCast(A)->A, array)); 3004742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3014742e46bSJacob Faibussowitsch } 3024742e46bSJacob Faibussowitsch 3034742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3044742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode access> 3054742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept 3064742e46bSJacob Faibussowitsch { 3074742e46bSJacob Faibussowitsch PetscFunctionBegin; 3084742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array)); 3094742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3104742e46bSJacob Faibussowitsch } 3114742e46bSJacob Faibussowitsch 3124742e46bSJacob Faibussowitsch // ========================================================================================== 3134742e46bSJacob Faibussowitsch 3144742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3154742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 3164742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept 3174742e46bSJacob Faibussowitsch { 3184742e46bSJacob Faibussowitsch using namespace vec::cupm; 3194742e46bSJacob Faibussowitsch 3204742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 3214742e46bSJacob Faibussowitsch const auto mimpl_A = mimpl->A; 3224742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(A); 3234742e46bSJacob Faibussowitsch auto &cvec = mimpl->cvec; 3244742e46bSJacob Faibussowitsch PetscInt lda; 3254742e46bSJacob Faibussowitsch 3264742e46bSJacob Faibussowitsch PetscFunctionBegin; 3274742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 3284742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 3294742e46bSJacob Faibussowitsch mimpl->vecinuse = col + 1; 3304742e46bSJacob Faibussowitsch 3314742e46bSJacob Faibussowitsch if (!cvec) PetscCall(VecCreateMPICUPMWithArray<T>(PetscObjectComm(pobj), A->rmap->bs, A->rmap->n, A->rmap->N, nullptr, &cvec)); 3324742e46bSJacob Faibussowitsch 3334742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(mimpl_A, &lda)); 3344742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse))); 3354742e46bSJacob Faibussowitsch PetscCall(VecCUPMPlaceArrayAsync<T>(cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda))); 3364742e46bSJacob Faibussowitsch 3374742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(cvec)); 3384742e46bSJacob Faibussowitsch *v = cvec; 3394742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3404742e46bSJacob Faibussowitsch } 3414742e46bSJacob Faibussowitsch 3424742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3434742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 3444742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept 3454742e46bSJacob Faibussowitsch { 3464742e46bSJacob Faibussowitsch using namespace vec::cupm; 3474742e46bSJacob Faibussowitsch 3484742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 3494742e46bSJacob Faibussowitsch const auto cvec = mimpl->cvec; 3504742e46bSJacob Faibussowitsch 3514742e46bSJacob Faibussowitsch PetscFunctionBegin; 3524742e46bSJacob Faibussowitsch PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first"); 3534742e46bSJacob Faibussowitsch PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector"); 3544742e46bSJacob Faibussowitsch mimpl->vecinuse = 0; 3554742e46bSJacob Faibussowitsch 3564742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse))); 3574742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec)); 3584742e46bSJacob Faibussowitsch PetscCall(VecCUPMResetArrayAsync<T>(cvec)); 3594742e46bSJacob Faibussowitsch 3604742e46bSJacob Faibussowitsch if (v) *v = nullptr; 3614742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3624742e46bSJacob Faibussowitsch } 3634742e46bSJacob Faibussowitsch 3644742e46bSJacob Faibussowitsch // ========================================================================================== 3654742e46bSJacob Faibussowitsch 3664742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3674742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept 3684742e46bSJacob Faibussowitsch { 3694742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 3704742e46bSJacob Faibussowitsch 3714742e46bSJacob Faibussowitsch PetscFunctionBegin; 3724742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 3734742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 3744742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array)); 3754742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3764742e46bSJacob Faibussowitsch } 3774742e46bSJacob Faibussowitsch 3784742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3794742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept 3804742e46bSJacob Faibussowitsch { 3814742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 3824742e46bSJacob Faibussowitsch 3834742e46bSJacob Faibussowitsch PetscFunctionBegin; 3844742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 3854742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 3864742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array)); 3874742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 3884742e46bSJacob Faibussowitsch } 3894742e46bSJacob Faibussowitsch 3904742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 3914742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept 3924742e46bSJacob Faibussowitsch { 3934742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 3944742e46bSJacob Faibussowitsch 3954742e46bSJacob Faibussowitsch PetscFunctionBegin; 3964742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 3974742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 3984742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMResetArray<T>(mimpl->A)); 3994742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 4004742e46bSJacob Faibussowitsch } 4014742e46bSJacob Faibussowitsch 4024742e46bSJacob Faibussowitsch // ========================================================================================== 4034742e46bSJacob Faibussowitsch 4044742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 4054742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept 4064742e46bSJacob Faibussowitsch { 4074742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 4084742e46bSJacob Faibussowitsch 4094742e46bSJacob Faibussowitsch PetscFunctionBegin; 4104742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 4114742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "Performing Shift on backend\n")); 412*2ea277ceSJacob Faibussowitsch PetscCall(DiagonalUnaryTransform(A, A->rmap->rstart, A->rmap->rend, A->cmap->N, dctx, device::cupm::functors::make_plus_equals(alpha))); 4134742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 4144742e46bSJacob Faibussowitsch } 4154742e46bSJacob Faibussowitsch 4164742e46bSJacob Faibussowitsch } // namespace impl 4174742e46bSJacob Faibussowitsch 4184742e46bSJacob Faibussowitsch namespace 4194742e46bSJacob Faibussowitsch { 4204742e46bSJacob Faibussowitsch 4214742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 4224742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept 4234742e46bSJacob Faibussowitsch { 4244742e46bSJacob Faibussowitsch PetscMPIInt size; 4254742e46bSJacob Faibussowitsch 4264742e46bSJacob Faibussowitsch PetscFunctionBegin; 4274742e46bSJacob Faibussowitsch PetscValidPointer(A, 7); 4284742e46bSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(comm, &size)); 4294742e46bSJacob Faibussowitsch if (size > 1) { 4304742e46bSJacob Faibussowitsch PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx)); 4314742e46bSJacob Faibussowitsch } else { 4324742e46bSJacob Faibussowitsch if (n == PETSC_DECIDE) n = N; 4334742e46bSJacob Faibussowitsch if (m == PETSC_DECIDE) m = M; 4344742e46bSJacob Faibussowitsch // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down 4354742e46bSJacob Faibussowitsch // the line 4364742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx)); 4374742e46bSJacob Faibussowitsch } 4384742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 4394742e46bSJacob Faibussowitsch } 4404742e46bSJacob Faibussowitsch 4414742e46bSJacob Faibussowitsch } // anonymous namespace 4424742e46bSJacob Faibussowitsch 4434742e46bSJacob Faibussowitsch } // namespace cupm 4444742e46bSJacob Faibussowitsch 4454742e46bSJacob Faibussowitsch } // namespace mat 4464742e46bSJacob Faibussowitsch 4474742e46bSJacob Faibussowitsch } // namespace Petsc 4484742e46bSJacob Faibussowitsch 4494742e46bSJacob Faibussowitsch #endif // __cplusplus 4504742e46bSJacob Faibussowitsch 4514742e46bSJacob Faibussowitsch #endif // PETSCMATMPIDENSECUPM_HPP 452