waitfor.hpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. // Copyright Oliver Kowalke 2017.
  2. // Distributed under the Boost Software License, Version 1.0.
  3. // (See accompanying file LICENSE_1_0.txt or copy at
  4. // http://www.boost.org/LICENSE_1_0.txt)
  5. #ifndef BOOST_FIBERS_CUDA_WAITFOR_H
  6. #define BOOST_FIBERS_CUDA_WAITFOR_H
  7. #include <initializer_list>
  8. #include <mutex>
  9. #include <iostream>
  10. #include <set>
  11. #include <tuple>
  12. #include <vector>
  13. #include <boost/assert.hpp>
  14. #include <boost/config.hpp>
  15. #include <cuda.h>
  16. #include <boost/fiber/detail/config.hpp>
  17. #include <boost/fiber/detail/is_all_same.hpp>
  18. #include <boost/fiber/condition_variable.hpp>
  19. #include <boost/fiber/mutex.hpp>
  20. #ifdef BOOST_HAS_ABI_HEADERS
  21. # include BOOST_ABI_PREFIX
  22. #endif
  23. namespace boost {
  24. namespace fibers {
  25. namespace cuda {
  26. namespace detail {
  27. template< typename Rendezvous >
  28. static void trampoline( cudaStream_t st, cudaError_t status, void * vp) {
  29. Rendezvous * data = static_cast< Rendezvous * >( vp);
  30. data->notify( st, status);
  31. }
  32. class single_stream_rendezvous {
  33. public:
  34. single_stream_rendezvous( cudaStream_t st) {
  35. unsigned int flags = 0;
  36. cudaError_t status = ::cudaStreamAddCallback( st, trampoline< single_stream_rendezvous >, this, flags);
  37. if ( cudaSuccess != status) {
  38. st_ = st;
  39. status_ = status;
  40. done_ = true;
  41. }
  42. }
  43. void notify( cudaStream_t st, cudaError_t status) noexcept {
  44. std::unique_lock< mutex > lk{ mtx_ };
  45. st_ = st;
  46. status_ = status;
  47. done_ = true;
  48. lk.unlock();
  49. cv_.notify_one();
  50. }
  51. std::tuple< cudaStream_t, cudaError_t > wait() {
  52. std::unique_lock< mutex > lk{ mtx_ };
  53. cv_.wait( lk, [this]{ return done_; });
  54. return std::make_tuple( st_, status_);
  55. }
  56. private:
  57. mutex mtx_{};
  58. condition_variable cv_{};
  59. cudaStream_t st_{};
  60. cudaError_t status_{ cudaErrorUnknown };
  61. bool done_{ false };
  62. };
  63. class many_streams_rendezvous {
  64. public:
  65. many_streams_rendezvous( std::initializer_list< cudaStream_t > l) :
  66. stx_{ l } {
  67. results_.reserve( stx_.size() );
  68. for ( cudaStream_t st : stx_) {
  69. unsigned int flags = 0;
  70. cudaError_t status = ::cudaStreamAddCallback( st, trampoline< many_streams_rendezvous >, this, flags);
  71. if ( cudaSuccess != status) {
  72. std::unique_lock< mutex > lk{ mtx_ };
  73. stx_.erase( st);
  74. results_.push_back( std::make_tuple( st, status) );
  75. }
  76. }
  77. }
  78. void notify( cudaStream_t st, cudaError_t status) noexcept {
  79. std::unique_lock< mutex > lk{ mtx_ };
  80. stx_.erase( st);
  81. results_.push_back( std::make_tuple( st, status) );
  82. if ( stx_.empty() ) {
  83. lk.unlock();
  84. cv_.notify_one();
  85. }
  86. }
  87. std::vector< std::tuple< cudaStream_t, cudaError_t > > wait() {
  88. std::unique_lock< mutex > lk{ mtx_ };
  89. cv_.wait( lk, [this]{ return stx_.empty(); });
  90. return results_;
  91. }
  92. private:
  93. mutex mtx_{};
  94. condition_variable cv_{};
  95. std::set< cudaStream_t > stx_;
  96. std::vector< std::tuple< cudaStream_t, cudaError_t > > results_;
  97. };
  98. }
  99. void waitfor_all();
  100. inline
  101. std::tuple< cudaStream_t, cudaError_t > waitfor_all( cudaStream_t st) {
  102. detail::single_stream_rendezvous rendezvous( st);
  103. return rendezvous.wait();
  104. }
  105. template< typename ... STP >
  106. std::vector< std::tuple< cudaStream_t, cudaError_t > > waitfor_all( cudaStream_t st0, STP ... stx) {
  107. static_assert( boost::fibers::detail::is_all_same< cudaStream_t, STP ...>::value, "all arguments must be of type `CUstream*`.");
  108. detail::many_streams_rendezvous rendezvous{ st0, stx ... };
  109. return rendezvous.wait();
  110. }
  111. }}}
  112. #ifdef BOOST_HAS_ABI_HEADERS
  113. # include BOOST_ABI_SUFFIX
  114. #endif
  115. #endif // BOOST_FIBERS_CUDA_WAITFOR_H