Actual source code: cupmstream.hpp
  1: #pragma once
  3: #include <petsc/private/cupminterface.hpp>
  5: #include "../segmentedmempool.hpp"
  6: #include "cupmevent.hpp"
  8: namespace Petsc
  9: {
 11: namespace device
 12: {
 14: namespace cupm
 15: {
 17: // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
 18: // identify separate cupm streams. This is so that the memory pool can accelerate allocation
 19: // calls as it can just pass back a pointer to memory that was used on the same
 20: // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
 21: // Address of the objects does not suffice since cupmStreams are very likely internally reused.
 23: template <DeviceType T>
 24: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
 25:   using crtp_base_type = StreamBase<CUPMStream<T>>;
 26:   friend crtp_base_type;
 28: public:
 29:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
 31:   using stream_type = cupmStream_t;
 32:   using id_type     = typename crtp_base_type::id_type;
 33:   using event_type  = CUPMEvent<T>;
 34:   using flag_type   = unsigned int;
 36:   CUPMStream() noexcept = default;
 38:   PetscErrorCode destroy() noexcept;
 39:   PetscErrorCode create(flag_type) noexcept;
 40:   PetscErrorCode change_type(PetscStreamType) noexcept;
 42: private:
 43:   stream_type stream_{};
 44:   id_type     id_ = new_id_();
 46:   PETSC_NODISCARD static id_type new_id_() noexcept;
 48:   // CRTP implementations
 49:   PETSC_NODISCARD const stream_type &get_stream_() const noexcept;
 50:   PETSC_NODISCARD id_type            get_id_() const noexcept;
 51:   PetscErrorCode                     record_event_(event_type &) const noexcept;
 52:   PetscErrorCode                     wait_for_(event_type &) const noexcept;
 53: };
 55: template <DeviceType T>
 56: inline PetscErrorCode CUPMStream<T>::destroy() noexcept
 57: {
 58:   PetscFunctionBegin;
 59:   if (stream_) {
 60:     PetscCallCUPM(cupmStreamDestroy(stream_));
 61:     stream_ = cupmStream_t{};
 62:     id_     = 0;
 63:   }
 64:   PetscFunctionReturn(PETSC_SUCCESS);
 65: }
 67: template <DeviceType T>
 68: inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
 69: {
 70:   PetscFunctionBegin;
 71:   if (stream_) {
 72:     if (PetscDefined(USE_DEBUG)) {
 73:       flag_type current_flags;
 75:       PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags));
 76:       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
 77:     }
 78:     PetscFunctionReturn(PETSC_SUCCESS);
 79:   }
 80:   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
 81:   id_ = new_id_();
 82:   PetscFunctionReturn(PETSC_SUCCESS);
 83: }
 85: template <DeviceType T>
 86: inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
 87: {
 88:   PetscFunctionBegin;
 89:   if (newtype == PETSC_STREAM_DEFAULT || newtype == PETSC_STREAM_DEFAULT_WITH_BARRIER) {
 90:     PetscCall(destroy());
 91:   } else { // change to a nonblokcing stream
 92:     const flag_type preferred = cupmStreamNonBlocking;
 94:     if (stream_) {
 95:       flag_type flag;
 97:       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
 98:       if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
 99:       PetscCall(destroy());
100:     }
101:     PetscCall(create(preferred));
102:   }
103:   PetscFunctionReturn(PETSC_SUCCESS);
104: }
106: template <DeviceType T>
107: inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
108: {
109:   static id_type id = 0;
110:   return id++;
111: }
113: // CRTP implementations
114: template <DeviceType T>
115: inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept
116: {
117:   return stream_;
118: }
120: template <DeviceType T>
121: inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
122: {
123:   return id_;
124: }
126: template <DeviceType T>
127: inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
128: {
129:   PetscFunctionBegin;
130:   PetscCall(event.record(stream_));
131:   PetscFunctionReturn(PETSC_SUCCESS);
132: }
134: template <DeviceType T>
135: inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
136: {
137:   PetscFunctionBegin;
138:   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
139:   PetscFunctionReturn(PETSC_SUCCESS);
140: }
142: } // namespace cupm
144: } // namespace device
146: } // namespace Petsc