1*4742e46bSJacob Faibussowitsch #ifndef PETSCMATMPIDENSECUPM_HPP 2*4742e46bSJacob Faibussowitsch #define PETSCMATMPIDENSECUPM_HPP 3*4742e46bSJacob Faibussowitsch 4*4742e46bSJacob Faibussowitsch #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/ 5*4742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/mpi/mpidense.h> 6*4742e46bSJacob Faibussowitsch 7*4742e46bSJacob Faibussowitsch #ifdef __cplusplus 8*4742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp> 9*4742e46bSJacob Faibussowitsch #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp> 10*4742e46bSJacob Faibussowitsch 11*4742e46bSJacob Faibussowitsch namespace Petsc 12*4742e46bSJacob Faibussowitsch { 13*4742e46bSJacob Faibussowitsch 14*4742e46bSJacob Faibussowitsch namespace mat 15*4742e46bSJacob Faibussowitsch { 16*4742e46bSJacob Faibussowitsch 17*4742e46bSJacob Faibussowitsch namespace cupm 18*4742e46bSJacob Faibussowitsch { 19*4742e46bSJacob Faibussowitsch 20*4742e46bSJacob Faibussowitsch namespace impl 21*4742e46bSJacob Faibussowitsch { 22*4742e46bSJacob Faibussowitsch 23*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 24*4742e46bSJacob Faibussowitsch class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> { 25*4742e46bSJacob Faibussowitsch public: 26*4742e46bSJacob Faibussowitsch MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>); 27*4742e46bSJacob Faibussowitsch 28*4742e46bSJacob Faibussowitsch private: 29*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept; 30*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr MatType MATIMPLCUPM_() noexcept; 31*4742e46bSJacob Faibussowitsch 32*4742e46bSJacob Faibussowitsch static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept; 33*4742e46bSJacob Faibussowitsch 34*4742e46bSJacob Faibussowitsch template <bool to_host> 35*4742e46bSJacob Faibussowitsch static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept; 36*4742e46bSJacob Faibussowitsch 37*4742e46bSJacob Faibussowitsch public: 38*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept; 39*4742e46bSJacob Faibussowitsch 40*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept; 41*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept; 42*4742e46bSJacob Faibussowitsch 43*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept; 44*4742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept; 45*4742e46bSJacob Faibussowitsch 46*4742e46bSJacob Faibussowitsch static PetscErrorCode Create(Mat) noexcept; 47*4742e46bSJacob Faibussowitsch 48*4742e46bSJacob Faibussowitsch static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept; 49*4742e46bSJacob Faibussowitsch static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept; 50*4742e46bSJacob Faibussowitsch static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept; 51*4742e46bSJacob Faibussowitsch 52*4742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode> 53*4742e46bSJacob Faibussowitsch static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept; 54*4742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode> 55*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept; 56*4742e46bSJacob Faibussowitsch 57*4742e46bSJacob Faibussowitsch private: 58*4742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode> 59*4742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept 60*4742e46bSJacob Faibussowitsch { 61*4742e46bSJacob Faibussowitsch return GetArray<mtype, mode>(m, p); 62*4742e46bSJacob Faibussowitsch } 63*4742e46bSJacob Faibussowitsch 64*4742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode> 65*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept 66*4742e46bSJacob Faibussowitsch { 67*4742e46bSJacob Faibussowitsch return RestoreArray<mtype, mode>(m, p); 68*4742e46bSJacob Faibussowitsch } 69*4742e46bSJacob Faibussowitsch 70*4742e46bSJacob Faibussowitsch public: 71*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 72*4742e46bSJacob Faibussowitsch static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept; 73*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode> 74*4742e46bSJacob Faibussowitsch static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept; 75*4742e46bSJacob Faibussowitsch 76*4742e46bSJacob Faibussowitsch static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept; 77*4742e46bSJacob Faibussowitsch static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept; 78*4742e46bSJacob Faibussowitsch static PetscErrorCode ResetArray(Mat) noexcept; 79*4742e46bSJacob Faibussowitsch 80*4742e46bSJacob Faibussowitsch static PetscErrorCode Shift(Mat, PetscScalar) noexcept; 81*4742e46bSJacob Faibussowitsch }; 82*4742e46bSJacob Faibussowitsch 83*4742e46bSJacob Faibussowitsch } // namespace impl 84*4742e46bSJacob Faibussowitsch 85*4742e46bSJacob Faibussowitsch namespace 86*4742e46bSJacob Faibussowitsch { 87*4742e46bSJacob Faibussowitsch 88*4742e46bSJacob Faibussowitsch // Declare this here so that the functions below can make use of it 89*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 90*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept 91*4742e46bSJacob Faibussowitsch { 92*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 93*4742e46bSJacob Faibussowitsch PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate)); 94*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 95*4742e46bSJacob Faibussowitsch } 96*4742e46bSJacob Faibussowitsch 97*4742e46bSJacob Faibussowitsch } // anonymous namespace 98*4742e46bSJacob Faibussowitsch 99*4742e46bSJacob Faibussowitsch namespace impl 100*4742e46bSJacob Faibussowitsch { 101*4742e46bSJacob Faibussowitsch 102*4742e46bSJacob Faibussowitsch // ========================================================================================== 103*4742e46bSJacob Faibussowitsch // MatDense_MPI_CUPM -- Private API 104*4742e46bSJacob Faibussowitsch // ========================================================================================== 105*4742e46bSJacob Faibussowitsch 106*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 107*4742e46bSJacob Faibussowitsch inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept 108*4742e46bSJacob Faibussowitsch { 109*4742e46bSJacob Faibussowitsch return static_cast<Mat_MPIDense *>(m->data); 110*4742e46bSJacob Faibussowitsch } 111*4742e46bSJacob Faibussowitsch 112*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 113*4742e46bSJacob Faibussowitsch inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept 114*4742e46bSJacob Faibussowitsch { 115*4742e46bSJacob Faibussowitsch return MATMPIDENSECUPM(); 116*4742e46bSJacob Faibussowitsch } 117*4742e46bSJacob Faibussowitsch 118*4742e46bSJacob Faibussowitsch // ========================================================================================== 119*4742e46bSJacob Faibussowitsch 120*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 121*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept 122*4742e46bSJacob Faibussowitsch { 123*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 124*4742e46bSJacob Faibussowitsch if (auto &mimplA = MatIMPLCast(A)->A) { 125*4742e46bSJacob Faibussowitsch PetscCall(MatSetType(mimplA, MATSEQDENSECUPM())); 126*4742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array)); 127*4742e46bSJacob Faibussowitsch } else { 128*4742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx)); 129*4742e46bSJacob Faibussowitsch } 130*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 131*4742e46bSJacob Faibussowitsch } 132*4742e46bSJacob Faibussowitsch 133*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 134*4742e46bSJacob Faibussowitsch template <bool to_host> 135*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept 136*4742e46bSJacob Faibussowitsch { 137*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 138*4742e46bSJacob Faibussowitsch if (reuse == MAT_INITIAL_MATRIX) { 139*4742e46bSJacob Faibussowitsch PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat)); 140*4742e46bSJacob Faibussowitsch } else if (reuse == MAT_REUSE_MATRIX) { 141*4742e46bSJacob Faibussowitsch PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN)); 142*4742e46bSJacob Faibussowitsch } 143*4742e46bSJacob Faibussowitsch { 144*4742e46bSJacob Faibussowitsch const auto B = *newmat; 145*4742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(B); 146*4742e46bSJacob Faibussowitsch 147*4742e46bSJacob Faibussowitsch if (to_host) { 148*4742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_TRUE)); 149*4742e46bSJacob Faibussowitsch } else { 150*4742e46bSJacob Faibussowitsch PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 151*4742e46bSJacob Faibussowitsch } 152*4742e46bSJacob Faibussowitsch 153*4742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype)); 154*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM())); 155*4742e46bSJacob Faibussowitsch 156*4742e46bSJacob Faibussowitsch // ============================================================ 157*4742e46bSJacob Faibussowitsch // Composed Ops 158*4742e46bSJacob Faibussowitsch // ============================================================ 159*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense); 160*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense); 161*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense); 162*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ); 163*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ); 164*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 165*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 166*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 167*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 168*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 169*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 170*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray); 171*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray); 172*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray); 173*4742e46bSJacob Faibussowitsch 174*4742e46bSJacob Faibussowitsch if (to_host) { 175*4742e46bSJacob Faibussowitsch if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A)); 176*4742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_CPU; 177*4742e46bSJacob Faibussowitsch } else { 178*4742e46bSJacob Faibussowitsch if (auto &m_A = MatIMPLCast(B)->A) { 179*4742e46bSJacob Faibussowitsch PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A)); 180*4742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_BOTH; 181*4742e46bSJacob Faibussowitsch } else { 182*4742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_UNALLOCATED; 183*4742e46bSJacob Faibussowitsch } 184*4742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_FALSE)); 185*4742e46bSJacob Faibussowitsch } 186*4742e46bSJacob Faibussowitsch 187*4742e46bSJacob Faibussowitsch // ============================================================ 188*4742e46bSJacob Faibussowitsch // Function Pointer Ops 189*4742e46bSJacob Faibussowitsch // ============================================================ 190*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU); 191*4742e46bSJacob Faibussowitsch } 192*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 193*4742e46bSJacob Faibussowitsch } 194*4742e46bSJacob Faibussowitsch 195*4742e46bSJacob Faibussowitsch // ========================================================================================== 196*4742e46bSJacob Faibussowitsch // MatDense_MPI_CUPM -- Public API 197*4742e46bSJacob Faibussowitsch // ========================================================================================== 198*4742e46bSJacob Faibussowitsch 199*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 200*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept 201*4742e46bSJacob Faibussowitsch { 202*4742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C"; 203*4742e46bSJacob Faibussowitsch } 204*4742e46bSJacob Faibussowitsch 205*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 206*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept 207*4742e46bSJacob Faibussowitsch { 208*4742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C"; 209*4742e46bSJacob Faibussowitsch } 210*4742e46bSJacob Faibussowitsch 211*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 212*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept 213*4742e46bSJacob Faibussowitsch { 214*4742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C"; 215*4742e46bSJacob Faibussowitsch } 216*4742e46bSJacob Faibussowitsch 217*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 218*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept 219*4742e46bSJacob Faibussowitsch { 220*4742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C"; 221*4742e46bSJacob Faibussowitsch } 222*4742e46bSJacob Faibussowitsch 223*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 224*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept 225*4742e46bSJacob Faibussowitsch { 226*4742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C"; 227*4742e46bSJacob Faibussowitsch } 228*4742e46bSJacob Faibussowitsch 229*4742e46bSJacob Faibussowitsch // ========================================================================================== 230*4742e46bSJacob Faibussowitsch 231*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 232*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept 233*4742e46bSJacob Faibussowitsch { 234*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 235*4742e46bSJacob Faibussowitsch PetscCall(MatCreate_MPIDense(A)); 236*4742e46bSJacob Faibussowitsch PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 237*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 238*4742e46bSJacob Faibussowitsch } 239*4742e46bSJacob Faibussowitsch 240*4742e46bSJacob Faibussowitsch // ========================================================================================== 241*4742e46bSJacob Faibussowitsch 242*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 243*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept 244*4742e46bSJacob Faibussowitsch { 245*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 246*4742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(A); 247*4742e46bSJacob Faibussowitsch 248*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 249*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 250*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 251*4742e46bSJacob Faibussowitsch if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost)); 252*4742e46bSJacob Faibussowitsch A->boundtocpu = usehost; 253*4742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype)); 254*4742e46bSJacob Faibussowitsch if (!usehost) { 255*4742e46bSJacob Faibussowitsch PetscBool iscupm; 256*4742e46bSJacob Faibussowitsch 257*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm)); 258*4742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec)); 259*4742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm)); 260*4742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat)); 261*4742e46bSJacob Faibussowitsch } 262*4742e46bSJacob Faibussowitsch 263*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 264*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 265*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>); 266*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>); 267*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 268*4742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 269*4742e46bSJacob Faibussowitsch 270*4742e46bSJacob Faibussowitsch MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift); 271*4742e46bSJacob Faibussowitsch 272*4742e46bSJacob Faibussowitsch if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost)); 273*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 274*4742e46bSJacob Faibussowitsch } 275*4742e46bSJacob Faibussowitsch 276*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 277*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept 278*4742e46bSJacob Faibussowitsch { 279*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 280*4742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat)); 281*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 282*4742e46bSJacob Faibussowitsch } 283*4742e46bSJacob Faibussowitsch 284*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 285*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept 286*4742e46bSJacob Faibussowitsch { 287*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 288*4742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat)); 289*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 290*4742e46bSJacob Faibussowitsch } 291*4742e46bSJacob Faibussowitsch 292*4742e46bSJacob Faibussowitsch // ========================================================================================== 293*4742e46bSJacob Faibussowitsch 294*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 295*4742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode access> 296*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept 297*4742e46bSJacob Faibussowitsch { 298*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 299*4742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMGetArray_Private<T, access>(MatIMPLCast(A)->A, array)); 300*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 301*4742e46bSJacob Faibussowitsch } 302*4742e46bSJacob Faibussowitsch 303*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 304*4742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode access> 305*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept 306*4742e46bSJacob Faibussowitsch { 307*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 308*4742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array)); 309*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 310*4742e46bSJacob Faibussowitsch } 311*4742e46bSJacob Faibussowitsch 312*4742e46bSJacob Faibussowitsch // ========================================================================================== 313*4742e46bSJacob Faibussowitsch 314*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 315*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 316*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept 317*4742e46bSJacob Faibussowitsch { 318*4742e46bSJacob Faibussowitsch using namespace vec::cupm; 319*4742e46bSJacob Faibussowitsch 320*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 321*4742e46bSJacob Faibussowitsch const auto mimpl_A = mimpl->A; 322*4742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(A); 323*4742e46bSJacob Faibussowitsch auto &cvec = mimpl->cvec; 324*4742e46bSJacob Faibussowitsch PetscInt lda; 325*4742e46bSJacob Faibussowitsch 326*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 327*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 328*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 329*4742e46bSJacob Faibussowitsch mimpl->vecinuse = col + 1; 330*4742e46bSJacob Faibussowitsch 331*4742e46bSJacob Faibussowitsch if (!cvec) PetscCall(VecCreateMPICUPMWithArray<T>(PetscObjectComm(pobj), A->rmap->bs, A->rmap->n, A->rmap->N, nullptr, &cvec)); 332*4742e46bSJacob Faibussowitsch 333*4742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(mimpl_A, &lda)); 334*4742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse))); 335*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMPlaceArrayAsync<T>(cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda))); 336*4742e46bSJacob Faibussowitsch 337*4742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(cvec)); 338*4742e46bSJacob Faibussowitsch *v = cvec; 339*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 340*4742e46bSJacob Faibussowitsch } 341*4742e46bSJacob Faibussowitsch 342*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 343*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access> 344*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept 345*4742e46bSJacob Faibussowitsch { 346*4742e46bSJacob Faibussowitsch using namespace vec::cupm; 347*4742e46bSJacob Faibussowitsch 348*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 349*4742e46bSJacob Faibussowitsch const auto cvec = mimpl->cvec; 350*4742e46bSJacob Faibussowitsch 351*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 352*4742e46bSJacob Faibussowitsch PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first"); 353*4742e46bSJacob Faibussowitsch PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector"); 354*4742e46bSJacob Faibussowitsch mimpl->vecinuse = 0; 355*4742e46bSJacob Faibussowitsch 356*4742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse))); 357*4742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec)); 358*4742e46bSJacob Faibussowitsch PetscCall(VecCUPMResetArrayAsync<T>(cvec)); 359*4742e46bSJacob Faibussowitsch 360*4742e46bSJacob Faibussowitsch if (v) *v = nullptr; 361*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 362*4742e46bSJacob Faibussowitsch } 363*4742e46bSJacob Faibussowitsch 364*4742e46bSJacob Faibussowitsch // ========================================================================================== 365*4742e46bSJacob Faibussowitsch 366*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 367*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept 368*4742e46bSJacob Faibussowitsch { 369*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 370*4742e46bSJacob Faibussowitsch 371*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 372*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 373*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 374*4742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array)); 375*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 376*4742e46bSJacob Faibussowitsch } 377*4742e46bSJacob Faibussowitsch 378*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 379*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept 380*4742e46bSJacob Faibussowitsch { 381*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 382*4742e46bSJacob Faibussowitsch 383*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 384*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 385*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 386*4742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array)); 387*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 388*4742e46bSJacob Faibussowitsch } 389*4742e46bSJacob Faibussowitsch 390*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 391*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept 392*4742e46bSJacob Faibussowitsch { 393*4742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A); 394*4742e46bSJacob Faibussowitsch 395*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 396*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 397*4742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 398*4742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMResetArray<T>(mimpl->A)); 399*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 400*4742e46bSJacob Faibussowitsch } 401*4742e46bSJacob Faibussowitsch 402*4742e46bSJacob Faibussowitsch // ========================================================================================== 403*4742e46bSJacob Faibussowitsch 404*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 405*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept 406*4742e46bSJacob Faibussowitsch { 407*4742e46bSJacob Faibussowitsch PetscDeviceContext dctx; 408*4742e46bSJacob Faibussowitsch 409*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 410*4742e46bSJacob Faibussowitsch PetscCall(GetHandles_(&dctx)); 411*4742e46bSJacob Faibussowitsch PetscCall(PetscInfo(A, "Performing Shift on backend\n")); 412*4742e46bSJacob Faibussowitsch PetscCall(PointwiseUnaryTransform(A, A->rmap->rstart, A->rmap->rend, A->cmap->N, dctx, device::cupm::functors::make_plus_equals(alpha))); 413*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 414*4742e46bSJacob Faibussowitsch } 415*4742e46bSJacob Faibussowitsch 416*4742e46bSJacob Faibussowitsch } // namespace impl 417*4742e46bSJacob Faibussowitsch 418*4742e46bSJacob Faibussowitsch namespace 419*4742e46bSJacob Faibussowitsch { 420*4742e46bSJacob Faibussowitsch 421*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T> 422*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept 423*4742e46bSJacob Faibussowitsch { 424*4742e46bSJacob Faibussowitsch PetscMPIInt size; 425*4742e46bSJacob Faibussowitsch 426*4742e46bSJacob Faibussowitsch PetscFunctionBegin; 427*4742e46bSJacob Faibussowitsch PetscValidPointer(A, 7); 428*4742e46bSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(comm, &size)); 429*4742e46bSJacob Faibussowitsch if (size > 1) { 430*4742e46bSJacob Faibussowitsch PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx)); 431*4742e46bSJacob Faibussowitsch } else { 432*4742e46bSJacob Faibussowitsch if (n == PETSC_DECIDE) n = N; 433*4742e46bSJacob Faibussowitsch if (m == PETSC_DECIDE) m = M; 434*4742e46bSJacob Faibussowitsch // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down 435*4742e46bSJacob Faibussowitsch // the line 436*4742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx)); 437*4742e46bSJacob Faibussowitsch } 438*4742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 439*4742e46bSJacob Faibussowitsch } 440*4742e46bSJacob Faibussowitsch 441*4742e46bSJacob Faibussowitsch } // anonymous namespace 442*4742e46bSJacob Faibussowitsch 443*4742e46bSJacob Faibussowitsch } // namespace cupm 444*4742e46bSJacob Faibussowitsch 445*4742e46bSJacob Faibussowitsch } // namespace mat 446*4742e46bSJacob Faibussowitsch 447*4742e46bSJacob Faibussowitsch } // namespace Petsc 448*4742e46bSJacob Faibussowitsch 449*4742e46bSJacob Faibussowitsch #endif // __cplusplus 450*4742e46bSJacob Faibussowitsch 451*4742e46bSJacob Faibussowitsch #endif // PETSCMATMPIDENSECUPM_HPP 452