#ifndef PETSCMATDENSECUPMIMPL_H
#define PETSCMATDENSECUPMIMPL_H

#define PETSC_SKIP_IMMINTRIN_H_CUDAWORKAROUND 1
#include <petsc/private/matimpl.h> /*I <petscmat.h> I*/

#ifdef __cplusplus
#include <petsc/private/deviceimpl.h>
#include <petsc/private/cupmsolverinterface.hpp>
#include <petsc/private/cupmobject.hpp>

#include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
#include "../src/sys/objects/device/impls/cupm/kernels.hpp"

#include <thrust/device_vector.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/transform.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

// forward declarations of the CUPM vector implementations; the matrix classes below only
// need the names (see MATDENSECUPM_BASE_HEADER), not the full definitions
template <device::cupm::DeviceType>
class VecSeq_CUPM;
template <device::cupm::DeviceType>
class VecMPI_CUPM;

} // namespace impl

} // namespace cupm

} // namespace vec

namespace mat
{

namespace cupm
{

namespace impl
{

// ==========================================================================================
// MatDense_CUPM_Base
//
// A base class to separate out the CRTP code from the common CUPM stuff (like the composed
// function names).
// ==========================================================================================

template <device::cupm::DeviceType T>
class MatDense_CUPM_Base : protected device::cupm::impl::CUPMObject<T> {
public:
  PETSC_CUPMOBJECT_HEADER(T);

  // Declares a constexpr member function returning the string name of the composed method
  // for the backend selected by T, e.g. MatDenseCUPMGetArray_C() yields
  // "MatDenseCUDAGetArray_C" when T is CUDA and "MatDenseHIPGetArray_C" when T is HIP
#define MatDenseCUPMComposedOpDecl(OP_NAME) \
  PETSC_NODISCARD static constexpr const char *PetscConcat(MatDenseCUPM, OP_NAME)() noexcept \
  { \
    return T == device::cupm::DeviceType::CUDA ? PetscStringize(PetscConcat(MatDenseCUDA, OP_NAME)) : PetscStringize(PetscConcat(MatDenseHIP, OP_NAME)); \
  }

  // clang-format off
  MatDenseCUPMComposedOpDecl(GetArray_C)
  MatDenseCUPMComposedOpDecl(GetArrayRead_C)
  MatDenseCUPMComposedOpDecl(GetArrayWrite_C)
  MatDenseCUPMComposedOpDecl(RestoreArray_C)
  MatDenseCUPMComposedOpDecl(RestoreArrayRead_C)
  MatDenseCUPMComposedOpDecl(RestoreArrayWrite_C)
  MatDenseCUPMComposedOpDecl(PlaceArray_C)
  MatDenseCUPMComposedOpDecl(ReplaceArray_C)
  MatDenseCUPMComposedOpDecl(ResetArray_C)
  // clang-format on

#undef MatDenseCUPMComposedOpDecl

  // backend-dependent MatType/MatSolverType selectors, resolved at compile time (see
  // definitions below)
  PETSC_NODISCARD static constexpr MatType       MATSEQDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatType       MATMPIDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatType       MATDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatSolverType MATSOLVERCUPM() noexcept;
};

// ==========================================================================================
// MatDense_CUPM_Base -- Public API
// ==========================================================================================

// MATSEQDENSECUDA or MATSEQDENSEHIP depending on the backend T
template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_CUPM_Base<T>::MATSEQDENSECUPM() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? MATSEQDENSECUDA : MATSEQDENSEHIP;
}

// MATMPIDENSECUDA or MATMPIDENSEHIP depending on the backend T
template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_CUPM_Base<T>::MATMPIDENSECUPM() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? MATMPIDENSECUDA : MATMPIDENSEHIP;
}

// MATDENSECUDA or MATDENSEHIP depending on the backend T
template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_CUPM_Base<T>::MATDENSECUPM() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? MATDENSECUDA : MATDENSEHIP;
}

// MATSOLVERCUDA or MATSOLVERHIP depending on the backend T
template <device::cupm::DeviceType T>
inline constexpr MatSolverType MatDense_CUPM_Base<T>::MATSOLVERCUPM() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? MATSOLVERCUDA : MATSOLVERHIP;
}

// Pulls the full MatDense_CUPM_Base interface (composed-op names, type selectors, and the
// CUPM vector types) into the scope of a deriving class
#define MATDENSECUPM_BASE_HEADER(T) \
  PETSC_CUPMOBJECT_HEADER(T); \
  using VecSeq_CUPM = ::Petsc::vec::cupm::impl::VecSeq_CUPM<T>; \
  using VecMPI_CUPM = ::Petsc::vec::cupm::impl::VecMPI_CUPM<T>; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSEQDENSECUPM; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATMPIDENSECUPM; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATDENSECUPM; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSOLVERCUPM; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C

// forward declare
template <device::cupm::DeviceType>
class MatDense_Seq_CUPM;
template <device::cupm::DeviceType>
class MatDense_MPI_CUPM;

// ==========================================================================================
// MatDense_CUPM
//
// The true "base" class for MatDenseCUPM. The reason MatDense_CUPM and MatDense_CUPM_Base
// exist is to separate out the CRTP code from the non-crtp code so that the generic functions
// can be called via templates below.
150 // ========================================================================================== 151 152 template <device::cupm::DeviceType T, typename Derived> 153 class MatDense_CUPM : protected MatDense_CUPM_Base<T> { 154 protected: 155 MATDENSECUPM_BASE_HEADER(T); 156 157 template <PetscMemType, PetscMemoryAccessMode> 158 class MatrixArray; 159 160 // Cast the Mat to its host struct, i.e. return the result of (Mat_SeqDense *)m->data 161 template <typename U = Derived> 162 PETSC_NODISCARD static constexpr auto MatIMPLCast(Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(U::MatIMPLCast_(m)) 163 PETSC_NODISCARD static constexpr MatType MATIMPLCUPM() noexcept; 164 165 static PetscErrorCode CreateIMPLDenseCUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, PetscInt, PetscScalar *, Mat *, PetscDeviceContext, bool) noexcept; 166 static PetscErrorCode SetPreallocation(Mat, PetscDeviceContext, PetscScalar * = nullptr) noexcept; 167 168 template <typename F> 169 static PetscErrorCode DiagonalUnaryTransform(Mat, PetscInt, PetscInt, PetscInt, PetscDeviceContext, F &&) noexcept; 170 171 PETSC_NODISCARD static auto DeviceArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>{dctx, m}) 172 PETSC_NODISCARD static auto DeviceArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>{dctx, m}) 173 PETSC_NODISCARD static auto DeviceArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m}) 174 PETSC_NODISCARD static auto HostArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>{dctx, m}) 175 PETSC_NODISCARD static auto HostArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, 
PETSC_MEMORY_ACCESS_WRITE>{dctx, m}) 176 PETSC_NODISCARD static auto HostArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m}) 177 }; 178 179 // ========================================================================================== 180 // MatDense_CUPM::MatrixArray 181 // ========================================================================================== 182 183 template <device::cupm::DeviceType T, typename D> 184 template <PetscMemType MT, PetscMemoryAccessMode MA> 185 class MatDense_CUPM<T, D>::MatrixArray : public device::cupm::impl::RestoreableArray<T, MT, MA> { 186 using base_type = device::cupm::impl::RestoreableArray<T, MT, MA>; 187 188 public: 189 MatrixArray(PetscDeviceContext, Mat) noexcept; 190 ~MatrixArray() noexcept; 191 192 // must declare move constructor since we declare a destructor 193 constexpr MatrixArray(MatrixArray &&) noexcept; 194 195 private: 196 Mat m_ = nullptr; 197 }; 198 199 // ========================================================================================== 200 // MatDense_CUPM::MatrixArray -- Public API 201 // ========================================================================================== 202 203 template <device::cupm::DeviceType T, typename D> 204 template <PetscMemType MT, PetscMemoryAccessMode MA> 205 inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(PetscDeviceContext dctx, Mat m) noexcept : base_type{dctx}, m_{m} 206 { 207 PetscFunctionBegin; 208 PetscCallAbort(PETSC_COMM_SELF, D::template GetArray<MT, MA>(m, &this->ptr_, dctx)); 209 PetscFunctionReturnVoid(); 210 } 211 212 template <device::cupm::DeviceType T, typename D> 213 template <PetscMemType MT, PetscMemoryAccessMode MA> 214 inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::~MatrixArray() noexcept 215 { 216 PetscFunctionBegin; 217 PetscCallAbort(PETSC_COMM_SELF, D::template RestoreArray<MT, MA>(m_, &this->ptr_, this->dctx_)); 
218 PetscFunctionReturnVoid(); 219 } 220 221 template <device::cupm::DeviceType T, typename D> 222 template <PetscMemType MT, PetscMemoryAccessMode MA> 223 inline constexpr MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(MatrixArray &&other) noexcept : base_type{std::move(other)}, m_{util::exchange(other.m_, nullptr)} 224 { 225 } 226 227 // ========================================================================================== 228 // MatDense_CUPM -- Protected API 229 // ========================================================================================== 230 231 template <device::cupm::DeviceType T, typename D> 232 inline constexpr MatType MatDense_CUPM<T, D>::MATIMPLCUPM() noexcept 233 { 234 return D::MATIMPLCUPM_(); 235 } 236 237 // Common core for MatCreateSeqDenseCUPM() and MatCreateMPIDenseCUPM() 238 template <device::cupm::DeviceType T, typename D> 239 inline PetscErrorCode MatDense_CUPM<T, D>::CreateIMPLDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx, bool preallocate) noexcept 240 { 241 Mat mat; 242 243 PetscFunctionBegin; 244 PetscValidPointer(A, 7); 245 PetscCall(MatCreate(comm, &mat)); 246 PetscCall(MatSetSizes(mat, m, n, M, N)); 247 PetscCall(MatSetType(mat, D::MATIMPLCUPM())); 248 if (preallocate) { 249 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 250 PetscCall(D::SetPreallocation(mat, dctx, data)); 251 } 252 *A = mat; 253 PetscFunctionReturn(PETSC_SUCCESS); 254 } 255 256 template <device::cupm::DeviceType T, typename D> 257 inline PetscErrorCode MatDense_CUPM<T, D>::SetPreallocation(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept 258 { 259 PetscFunctionBegin; 260 // cannot use PetscValidHeaderSpecificType(..., MATIMPLCUPM()) since the incoming matrix 261 // might be the local (sequential) matrix of a MatMPIDense_CUPM. 
Since this would be called 262 // from the MPI matrix'es impl MATIMPLCUPM() would return MATMPIDENSECUPM(). 263 PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 264 PetscCheckTypeNames(A, D::MATSEQDENSECUPM(), D::MATMPIDENSECUPM()); 265 PetscCall(PetscLayoutSetUp(A->rmap)); 266 PetscCall(PetscLayoutSetUp(A->cmap)); 267 PetscCall(D::SetPreallocation_(A, dctx, device_array)); 268 A->preallocated = PETSC_TRUE; 269 A->assembled = PETSC_TRUE; 270 PetscFunctionReturn(PETSC_SUCCESS); 271 } 272 273 namespace detail 274 { 275 276 // ========================================================================================== 277 // MatrixIteratorBase 278 // 279 // A base class for creating thrust iterators over the local sub-matrix. This will set up the 280 // proper iterator definitions so thrust knows how to handle things properly. Template 281 // parameters are as follows: 282 // 283 // - Iterator: 284 // The type of the primary array iterator. Usually this is 285 // thrust::device_pointer<PetscScalar>::iterator. 286 // 287 // - IndexFunctor: 288 // This should be a functor which contains an operator() that when called with an index `i`, 289 // returns the i'th permuted index into the array. For example, it could return the i'th 290 // diagonal entry. 
// ==========================================================================================
template <typename Iterator, typename IndexFunctor>
class MatrixIteratorBase {
public:
  using array_iterator_type = Iterator;
  using index_functor_type  = IndexFunctor;

  using difference_type     = typename thrust::iterator_difference<array_iterator_type>::type;
  using CountingIterator    = thrust::counting_iterator<difference_type>;
  using TransformIterator   = thrust::transform_iterator<index_functor_type, CountingIterator>;
  using PermutationIterator = thrust::permutation_iterator<array_iterator_type, TransformIterator>;
  using iterator            = PermutationIterator; // type of the begin/end iterator

  constexpr MatrixIteratorBase(array_iterator_type first, array_iterator_type last, index_functor_type idx_func) noexcept : first{std::move(first)}, last{std::move(last)}, func{std::move(idx_func)} { }

  // element i of the returned range dereferences to first[func(i)]
  PETSC_NODISCARD iterator begin() const noexcept
  {
    return PermutationIterator{
      first, TransformIterator{CountingIterator{0}, func}
    };
  }

protected:
  array_iterator_type first;
  array_iterator_type last;
  index_functor_type  func;
};

// ==========================================================================================
// StridedIndexFunctor
//
// Iterator which permutes a linear index range into strided matrix indices. Usually used to
// get the diagonal.
// ==========================================================================================
template <typename T>
struct StridedIndexFunctor {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &i) const noexcept { return stride * i; }

  T stride;
};

template <typename Iterator>
class DiagonalIterator : public MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>> {
public:
  using base_type = MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>>;

  using difference_type = typename base_type::difference_type;
  using iterator        = typename base_type::iterator;

  constexpr DiagonalIterator(Iterator first, Iterator last, difference_type stride) noexcept : base_type{std::move(first), std::move(last), {stride}} { }

  // number of strided entries between first and last is ceil((last - first) / stride)
  PETSC_NODISCARD iterator end() const noexcept { return this->begin() + (this->last - this->first + this->func.stride - 1) / this->func.stride; }
};

} // namespace detail

// Apply `functor` in-place (on the device stream attached to dctx) to the diagonal entries
// of the local row block [rstart, min(rend, cols)) of A; the stride of lda + 1 walks the
// diagonal of the column-major array
template <device::cupm::DeviceType T, typename D>
template <typename F>
inline PetscErrorCode MatDense_CUPM<T, D>::DiagonalUnaryTransform(Mat A, PetscInt rstart, PetscInt rend, PetscInt cols, PetscDeviceContext dctx, F &&functor) noexcept
{
  const auto rend2 = std::min(rend, cols);

  PetscFunctionBegin;
  if (rend2 > rstart) {
    const auto da = D::DeviceArrayReadWrite(dctx, A);
    PetscInt   lda;

    PetscCall(MatDenseGetLDA(A, &lda));
    {
      using DiagonalIterator = detail::DiagonalIterator<thrust::device_vector<PetscScalar>::iterator>;
      const auto        dptr  = thrust::device_pointer_cast(da.data());
      const std::size_t begin = rstart * lda;
      const std::size_t end   = rend2 - rstart + rend2 * lda;
      DiagonalIterator  diagonal{dptr + begin, dptr + end, lda + 1};
      cupmStream_t      stream;

      PetscCall(D::GetHandlesFrom_(dctx, &stream));
      // clang-format off
      PetscCallThrust(
        THRUST_CALL(
          thrust::transform,
          stream,
          diagonal.begin(), diagonal.end(), diagonal.begin(),
          std::forward<F>(functor)
        )
      );
      // clang-format on
    }
    PetscCall(PetscLogGpuFlops(rend2 - rstart));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Compose op_str on pobj, choosing the host implementation when use_host is true
#define MatComposeOp_CUPM(use_host, pobj, op_str, op_host, ...) \
  do { \
    if (use_host) { \
      PetscCall(PetscObjectComposeFunction(pobj, op_str, op_host)); \
    } else { \
      PetscCall(PetscObjectComposeFunction(pobj, op_str, __VA_ARGS__)); \
    } \
  } while (0)

// Set mat->ops->op_name, choosing the host implementation when use_host is true
#define MatSetOp_CUPM(use_host, mat, op_name, op_host, ...) \
  do { \
    if (use_host) { \
      (mat)->ops->op_name = op_host; \
    } else { \
      (mat)->ops->op_name = __VA_ARGS__; \
    } \
  } while (0)

// Pulls the full MatDense_CUPM interface into the scope of a deriving (CRTP) class
#define MATDENSECUPM_HEADER(T, ...) \
  MATDENSECUPM_BASE_HEADER(T); \
  friend class ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MatIMPLCast; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MATIMPLCUPM; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::CreateIMPLDenseCUPM; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::SetPreallocation; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayRead; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayWrite; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayReadWrite; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayRead; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayWrite; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayReadWrite; \
  using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DiagonalUnaryTransform

} // namespace impl

// anonymous namespace: the helpers below get internal linkage in each translation unit that
// includes this header
namespace
{
template <device::cupm::DeviceType T, PetscMemoryAccessMode access> 423 inline PetscErrorCode MatDenseCUPMGetArray_Private(Mat A, PetscScalar **array) noexcept 424 { 425 PetscFunctionBegin; 426 PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 427 PetscValidPointer(array, 2); 428 switch (access) { 429 case PETSC_MEMORY_ACCESS_READ: 430 PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C(), (Mat, PetscScalar **), (A, array)); 431 break; 432 case PETSC_MEMORY_ACCESS_WRITE: 433 PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C(), (Mat, PetscScalar **), (A, array)); 434 break; 435 case PETSC_MEMORY_ACCESS_READ_WRITE: 436 PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C(), (Mat, PetscScalar **), (A, array)); 437 break; 438 } 439 if (PetscMemoryAccessWrite(access)) PetscCall(PetscObjectStateIncrease(PetscObjectCast(A))); 440 PetscFunctionReturn(PETSC_SUCCESS); 441 } 442 443 template <device::cupm::DeviceType T, PetscMemoryAccessMode access> 444 inline PetscErrorCode MatDenseCUPMRestoreArray_Private(Mat A, PetscScalar **array) noexcept 445 { 446 PetscFunctionBegin; 447 PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 448 if (array) PetscValidPointer(array, 2); 449 switch (access) { 450 case PETSC_MEMORY_ACCESS_READ: 451 PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C(), (Mat, PetscScalar **), (A, array)); 452 break; 453 case PETSC_MEMORY_ACCESS_WRITE: 454 PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C(), (Mat, PetscScalar **), (A, array)); 455 break; 456 case PETSC_MEMORY_ACCESS_READ_WRITE: 457 PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C(), (Mat, PetscScalar **), (A, array)); 458 break; 459 } 460 if (PetscMemoryAccessWrite(access)) { 461 PetscCall(PetscObjectStateIncrease(PetscObjectCast(A))); 462 A->offloadmask = PETSC_OFFLOAD_GPU; 463 } 464 if (array) *array = nullptr; 465 PetscFunctionReturn(PETSC_SUCCESS); 466 } 467 468 
// Get read-write access to the device array of A
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Get read-only access to the device array of A
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Get write-only access to the device array of A
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Return read-write access obtained with MatDenseCUPMGetArray()
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Return read-only access obtained with MatDenseCUPMGetArrayRead()
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Return write-only access obtained with MatDenseCUPMGetArrayWrite()
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Temporarily replace the device array of A with a user-provided one (undone by
// MatDenseCUPMResetArray()); marks the GPU data as current
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMPlaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Permanently replace the device array of A with a user-provided one; marks the GPU data as
// current
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Undo a previous MatDenseCUPMPlaceArray(), restoring the matrix's original device array
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMResetArray(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C(), (Mat), (A));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // anonymous namespace

} // namespace cupm

} // namespace mat

} // namespace Petsc

#endif // __cplusplus

#endif // PETSCMATDENSECUPMIMPL_H