device_ptr.hpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_DEVICE_PTR_HPP
  11. #define BOOST_COMPUTE_DEVICE_PTR_HPP
#include <cstddef>

#include <boost/type_traits.hpp>
#include <boost/static_assert.hpp>
#include <boost/compute/buffer.hpp>
#include <boost/compute/config.hpp>
#include <boost/compute/detail/is_buffer_iterator.hpp>
#include <boost/compute/detail/read_write_single_value.hpp>
#include <boost/compute/type_traits/is_device_iterator.hpp>
  19. namespace boost {
  20. namespace compute {
  21. namespace detail {
  22. template<class T, class IndexExpr>
  23. struct device_ptr_index_expr
  24. {
  25. typedef T result_type;
  26. device_ptr_index_expr(const buffer &buffer,
  27. uint_ index,
  28. const IndexExpr &expr)
  29. : m_buffer(buffer),
  30. m_index(index),
  31. m_expr(expr)
  32. {
  33. }
  34. operator T() const
  35. {
  36. BOOST_STATIC_ASSERT_MSG(boost::is_integral<IndexExpr>::value,
  37. "Index expression must be integral");
  38. BOOST_ASSERT(m_buffer.get());
  39. const context &context = m_buffer.get_context();
  40. const device &device = context.get_device();
  41. command_queue queue(context, device);
  42. return detail::read_single_value<T>(m_buffer, m_expr, queue);
  43. }
  44. const buffer &m_buffer;
  45. uint_ m_index;
  46. IndexExpr m_expr;
  47. };
  48. template<class T>
  49. class device_ptr
  50. {
  51. public:
  52. typedef T value_type;
  53. typedef std::size_t size_type;
  54. typedef std::ptrdiff_t difference_type;
  55. typedef std::random_access_iterator_tag iterator_category;
  56. typedef T* pointer;
  57. typedef T& reference;
  58. device_ptr()
  59. : m_index(0)
  60. {
  61. }
  62. device_ptr(const buffer &buffer, size_t index = 0)
  63. : m_buffer(buffer.get(), false),
  64. m_index(index)
  65. {
  66. }
  67. device_ptr(const device_ptr<T> &other)
  68. : m_buffer(other.m_buffer.get(), false),
  69. m_index(other.m_index)
  70. {
  71. }
  72. device_ptr<T>& operator=(const device_ptr<T> &other)
  73. {
  74. if(this != &other){
  75. m_buffer.get() = other.m_buffer.get();
  76. m_index = other.m_index;
  77. }
  78. return *this;
  79. }
  80. #ifndef BOOST_COMPUTE_NO_RVALUE_REFERENCES
  81. device_ptr(device_ptr<T>&& other) BOOST_NOEXCEPT
  82. : m_buffer(other.m_buffer.get(), false),
  83. m_index(other.m_index)
  84. {
  85. other.m_buffer.get() = 0;
  86. }
  87. device_ptr<T>& operator=(device_ptr<T>&& other) BOOST_NOEXCEPT
  88. {
  89. m_buffer.get() = other.m_buffer.get();
  90. m_index = other.m_index;
  91. other.m_buffer.get() = 0;
  92. return *this;
  93. }
  94. #endif // BOOST_COMPUTE_NO_RVALUE_REFERENCES
  95. ~device_ptr()
  96. {
  97. // set buffer to null so that its reference count will
  98. // not be decremented when its destructor is called
  99. m_buffer.get() = 0;
  100. }
  101. size_type get_index() const
  102. {
  103. return m_index;
  104. }
  105. const buffer& get_buffer() const
  106. {
  107. return m_buffer;
  108. }
  109. template<class OT>
  110. device_ptr<OT> cast() const
  111. {
  112. return device_ptr<OT>(m_buffer, m_index);
  113. }
  114. device_ptr<T> operator+(difference_type n) const
  115. {
  116. return device_ptr<T>(m_buffer, m_index + n);
  117. }
  118. device_ptr<T> operator+(const device_ptr<T> &other) const
  119. {
  120. return device_ptr<T>(m_buffer, m_index + other.m_index);
  121. }
  122. device_ptr<T>& operator+=(difference_type n)
  123. {
  124. m_index += static_cast<size_t>(n);
  125. return *this;
  126. }
  127. difference_type operator-(const device_ptr<T> &other) const
  128. {
  129. return static_cast<difference_type>(m_index - other.m_index);
  130. }
  131. device_ptr<T>& operator-=(difference_type n)
  132. {
  133. m_index -= n;
  134. return *this;
  135. }
  136. bool operator==(const device_ptr<T> &other) const
  137. {
  138. return m_buffer.get() == other.m_buffer.get() &&
  139. m_index == other.m_index;
  140. }
  141. bool operator!=(const device_ptr<T> &other) const
  142. {
  143. return !(*this == other);
  144. }
  145. template<class Expr>
  146. detail::device_ptr_index_expr<T, Expr>
  147. operator[](const Expr &expr) const
  148. {
  149. BOOST_ASSERT(m_buffer.get());
  150. return detail::device_ptr_index_expr<T, Expr>(m_buffer,
  151. uint_(m_index),
  152. expr);
  153. }
  154. private:
  155. const buffer m_buffer;
  156. size_t m_index;
  157. };
  158. // is_buffer_iterator specialization for device_ptr
  159. template<class Iterator>
  160. struct is_buffer_iterator<
  161. Iterator,
  162. typename boost::enable_if<
  163. boost::is_same<
  164. device_ptr<typename Iterator::value_type>,
  165. typename boost::remove_const<Iterator>::type
  166. >
  167. >::type
  168. > : public boost::true_type {};
  169. } // end detail namespace
  170. // is_device_iterator specialization for device_ptr
  171. template<class T>
  172. struct is_device_iterator<detail::device_ptr<T> > : boost::true_type {};
  173. } // end compute namespace
  174. } // end boost namespace
  175. #endif // BOOST_COMPUTE_DEVICE_PTR_HPP