xref: /petsc/src/sys/objects/device/impls/cupm/cupmstream.hpp (revision 0e6b6b5985dd9b1172860d21fb88bd3966bf7c54)
1*0e6b6b59SJacob Faibussowitsch #ifndef PETSC_CUPMSTREAM_HPP
2*0e6b6b59SJacob Faibussowitsch #define PETSC_CUPMSTREAM_HPP
3*0e6b6b59SJacob Faibussowitsch 
4*0e6b6b59SJacob Faibussowitsch #include <petsc/private/cupminterface.hpp>
5*0e6b6b59SJacob Faibussowitsch 
6*0e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp"
7*0e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp"
8*0e6b6b59SJacob Faibussowitsch 
9*0e6b6b59SJacob Faibussowitsch #if defined(__cplusplus)
10*0e6b6b59SJacob Faibussowitsch namespace Petsc {
11*0e6b6b59SJacob Faibussowitsch 
12*0e6b6b59SJacob Faibussowitsch namespace device {
13*0e6b6b59SJacob Faibussowitsch 
14*0e6b6b59SJacob Faibussowitsch namespace cupm {
15*0e6b6b59SJacob Faibussowitsch 
16*0e6b6b59SJacob Faibussowitsch // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
17*0e6b6b59SJacob Faibussowitsch // identify separate cupm streams. This is so that the memory pool can accelerate allocation
18*0e6b6b59SJacob Faibussowitsch // calls as it can just pass back a pointer to memory that was used on the same
19*0e6b6b59SJacob Faibussowitsch // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
20*0e6b6b59SJacob Faibussowitsch // Address of the objects does not suffice since cupmStreams are very likely internally reused.
21*0e6b6b59SJacob Faibussowitsch 
22*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
23*0e6b6b59SJacob Faibussowitsch class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
24*0e6b6b59SJacob Faibussowitsch   using crtp_base_type = StreamBase<CUPMStream<T>>;
25*0e6b6b59SJacob Faibussowitsch   friend crtp_base_type;
26*0e6b6b59SJacob Faibussowitsch 
27*0e6b6b59SJacob Faibussowitsch public:
28*0e6b6b59SJacob Faibussowitsch   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T);
29*0e6b6b59SJacob Faibussowitsch 
30*0e6b6b59SJacob Faibussowitsch   using stream_type = cupmStream_t;
31*0e6b6b59SJacob Faibussowitsch   using id_type     = typename crtp_base_type::id_type;
32*0e6b6b59SJacob Faibussowitsch   using event_type  = CUPMEvent<T>;
33*0e6b6b59SJacob Faibussowitsch   using flag_type   = unsigned int;
34*0e6b6b59SJacob Faibussowitsch 
35*0e6b6b59SJacob Faibussowitsch   CUPMStream() noexcept = default;
36*0e6b6b59SJacob Faibussowitsch 
37*0e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD PetscErrorCode destroy() noexcept;
38*0e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD PetscErrorCode create(flag_type) noexcept;
39*0e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD PetscErrorCode change_type(PetscStreamType) noexcept;
40*0e6b6b59SJacob Faibussowitsch 
41*0e6b6b59SJacob Faibussowitsch private:
42*0e6b6b59SJacob Faibussowitsch   stream_type stream_{};
43*0e6b6b59SJacob Faibussowitsch   id_type     id_ = new_id_();
44*0e6b6b59SJacob Faibussowitsch 
45*0e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD static id_type new_id_() noexcept;
46*0e6b6b59SJacob Faibussowitsch 
47*0e6b6b59SJacob Faibussowitsch   // CRTP implementations
48*0e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD stream_type    get_stream_() const noexcept;
49*0e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD id_type        get_id_() const noexcept;
50*0e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD PetscErrorCode record_event_(event_type &) const noexcept;
51*0e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD PetscErrorCode wait_for_(event_type &) const noexcept;
52*0e6b6b59SJacob Faibussowitsch };
53*0e6b6b59SJacob Faibussowitsch 
54*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
55*0e6b6b59SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::destroy() noexcept {
56*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
57*0e6b6b59SJacob Faibussowitsch   if (stream_) {
58*0e6b6b59SJacob Faibussowitsch     PetscCallCUPM(cupmStreamDestroy(stream_));
59*0e6b6b59SJacob Faibussowitsch     stream_ = cupmStream_t{};
60*0e6b6b59SJacob Faibussowitsch     id_     = 0;
61*0e6b6b59SJacob Faibussowitsch   }
62*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
63*0e6b6b59SJacob Faibussowitsch }
64*0e6b6b59SJacob Faibussowitsch 
65*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
66*0e6b6b59SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept {
67*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
68*0e6b6b59SJacob Faibussowitsch   if (stream_) {
69*0e6b6b59SJacob Faibussowitsch     if (PetscDefined(USE_DEBUG)) {
70*0e6b6b59SJacob Faibussowitsch       flag_type current_flags;
71*0e6b6b59SJacob Faibussowitsch 
72*0e6b6b59SJacob Faibussowitsch       PetscCallCUPM(cupmStreamGetFlags(stream_, &current_flags));
73*0e6b6b59SJacob Faibussowitsch       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
74*0e6b6b59SJacob Faibussowitsch     }
75*0e6b6b59SJacob Faibussowitsch     PetscFunctionReturn(0);
76*0e6b6b59SJacob Faibussowitsch   }
77*0e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
78*0e6b6b59SJacob Faibussowitsch   id_ = new_id_();
79*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
80*0e6b6b59SJacob Faibussowitsch }
81*0e6b6b59SJacob Faibussowitsch 
82*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
83*0e6b6b59SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept {
84*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
85*0e6b6b59SJacob Faibussowitsch   if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
86*0e6b6b59SJacob Faibussowitsch     PetscCall(destroy());
87*0e6b6b59SJacob Faibussowitsch   } else {
88*0e6b6b59SJacob Faibussowitsch     const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
89*0e6b6b59SJacob Faibussowitsch 
90*0e6b6b59SJacob Faibussowitsch     if (stream_) {
91*0e6b6b59SJacob Faibussowitsch       flag_type flag;
92*0e6b6b59SJacob Faibussowitsch 
93*0e6b6b59SJacob Faibussowitsch       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
94*0e6b6b59SJacob Faibussowitsch       if ((flag != preferred) || (cupmStreamQuery(stream_) != cupmSuccess)) PetscCall(destroy());
95*0e6b6b59SJacob Faibussowitsch     }
96*0e6b6b59SJacob Faibussowitsch     PetscCall(create(preferred));
97*0e6b6b59SJacob Faibussowitsch   }
98*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
99*0e6b6b59SJacob Faibussowitsch }
100*0e6b6b59SJacob Faibussowitsch 
101*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
102*0e6b6b59SJacob Faibussowitsch inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept {
103*0e6b6b59SJacob Faibussowitsch   static id_type id = 0;
104*0e6b6b59SJacob Faibussowitsch   return id++;
105*0e6b6b59SJacob Faibussowitsch }
106*0e6b6b59SJacob Faibussowitsch 
107*0e6b6b59SJacob Faibussowitsch // CRTP implementations
108*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
109*0e6b6b59SJacob Faibussowitsch inline typename CUPMStream<T>::stream_type CUPMStream<T>::get_stream_() const noexcept {
110*0e6b6b59SJacob Faibussowitsch   return stream_;
111*0e6b6b59SJacob Faibussowitsch }
112*0e6b6b59SJacob Faibussowitsch 
113*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
114*0e6b6b59SJacob Faibussowitsch inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept {
115*0e6b6b59SJacob Faibussowitsch   return id_;
116*0e6b6b59SJacob Faibussowitsch }
117*0e6b6b59SJacob Faibussowitsch 
118*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
119*0e6b6b59SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept {
120*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
121*0e6b6b59SJacob Faibussowitsch   PetscCall(event.record(stream_));
122*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
123*0e6b6b59SJacob Faibussowitsch }
124*0e6b6b59SJacob Faibussowitsch 
125*0e6b6b59SJacob Faibussowitsch template <DeviceType T>
126*0e6b6b59SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept {
127*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
128*0e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
129*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
130*0e6b6b59SJacob Faibussowitsch }
131*0e6b6b59SJacob Faibussowitsch 
132*0e6b6b59SJacob Faibussowitsch } // namespace cupm
133*0e6b6b59SJacob Faibussowitsch 
134*0e6b6b59SJacob Faibussowitsch } // namespace device
135*0e6b6b59SJacob Faibussowitsch 
136*0e6b6b59SJacob Faibussowitsch } // namespace Petsc
137*0e6b6b59SJacob Faibussowitsch #endif // __cplusplus
138*0e6b6b59SJacob Faibussowitsch 
139*0e6b6b59SJacob Faibussowitsch #endif // PETSC_CUPMSTREAM_HPP
140