transform_iterator.hpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ITERATOR_TRANSFORM_ITERATOR_HPP
  11. #define BOOST_COMPUTE_ITERATOR_TRANSFORM_ITERATOR_HPP
  12. #include <cstddef>
  13. #include <iterator>
  14. #include <boost/config.hpp>
  15. #include <boost/iterator/iterator_adaptor.hpp>
  16. #include <boost/compute/functional.hpp>
  17. #include <boost/compute/detail/meta_kernel.hpp>
  18. #include <boost/compute/detail/is_buffer_iterator.hpp>
  19. #include <boost/compute/detail/read_write_single_value.hpp>
  20. #include <boost/compute/iterator/detail/get_base_iterator_buffer.hpp>
  21. #include <boost/compute/type_traits/is_device_iterator.hpp>
  22. #include <boost/compute/type_traits/result_of.hpp>
  23. namespace boost {
  24. namespace compute {
  // Forward declaration: the detail helpers that follow (transform_iterator_base
  // in particular) must name transform_iterator before its definition, because
  // it is the derived type passed to boost::iterator_adaptor.
  template<typename InputIterator, typename UnaryFunction>
  class transform_iterator;
  28. namespace detail {
  29. // meta-function returning the value_type for a transform_iterator
  30. template<class InputIterator, class UnaryFunction>
  31. struct make_transform_iterator_value_type
  32. {
  33. typedef typename std::iterator_traits<InputIterator>::value_type value_type;
  34. typedef typename boost::compute::result_of<UnaryFunction(value_type)>::type type;
  35. };
  36. // helper class which defines the iterator_adaptor super-class
  37. // type for transform_iterator
  38. template<class InputIterator, class UnaryFunction>
  39. class transform_iterator_base
  40. {
  41. public:
  42. typedef ::boost::iterator_adaptor<
  43. ::boost::compute::transform_iterator<InputIterator, UnaryFunction>,
  44. InputIterator,
  45. typename make_transform_iterator_value_type<InputIterator, UnaryFunction>::type,
  46. typename std::iterator_traits<InputIterator>::iterator_category,
  47. typename make_transform_iterator_value_type<InputIterator, UnaryFunction>::type
  48. > type;
  49. };
  50. template<class InputIterator, class UnaryFunction, class IndexExpr>
  51. struct transform_iterator_index_expr
  52. {
  53. typedef typename
  54. make_transform_iterator_value_type<
  55. InputIterator,
  56. UnaryFunction
  57. >::type result_type;
  58. transform_iterator_index_expr(const InputIterator &input_iter,
  59. const UnaryFunction &transform_expr,
  60. const IndexExpr &index_expr)
  61. : m_input_iter(input_iter),
  62. m_transform_expr(transform_expr),
  63. m_index_expr(index_expr)
  64. {
  65. }
  66. const InputIterator m_input_iter;
  67. const UnaryFunction m_transform_expr;
  68. const IndexExpr m_index_expr;
  69. };
  70. template<class InputIterator, class UnaryFunction, class IndexExpr>
  71. inline meta_kernel& operator<<(meta_kernel &kernel,
  72. const transform_iterator_index_expr<InputIterator,
  73. UnaryFunction,
  74. IndexExpr> &expr)
  75. {
  76. return kernel << expr.m_transform_expr(expr.m_input_iter[expr.m_index_expr]);
  77. }
  78. } // end detail namespace
  79. /// \class transform_iterator
  80. /// \brief A transform iterator adaptor.
  81. ///
  82. /// The transform_iterator adaptor applies a unary function to each element
  83. /// produced from the underlying iterator when dereferenced.
  84. ///
  85. /// For example, to copy from an input range to an output range while taking
  86. /// the absolute value of each element:
  87. ///
  88. /// \snippet test/test_transform_iterator.cpp copy_abs
  89. ///
  90. /// \see buffer_iterator, make_transform_iterator()
  91. template<class InputIterator, class UnaryFunction>
  92. class transform_iterator :
  93. public detail::transform_iterator_base<InputIterator, UnaryFunction>::type
  94. {
  95. public:
  96. typedef typename
  97. detail::transform_iterator_base<InputIterator,
  98. UnaryFunction>::type super_type;
  99. typedef typename super_type::value_type value_type;
  100. typedef typename super_type::reference reference;
  101. typedef typename super_type::base_type base_type;
  102. typedef typename super_type::difference_type difference_type;
  103. typedef UnaryFunction unary_function;
  104. transform_iterator(InputIterator iterator, UnaryFunction transform)
  105. : super_type(iterator),
  106. m_transform(transform)
  107. {
  108. }
  109. transform_iterator(const transform_iterator<InputIterator,
  110. UnaryFunction> &other)
  111. : super_type(other.base()),
  112. m_transform(other.m_transform)
  113. {
  114. }
  115. transform_iterator<InputIterator, UnaryFunction>&
  116. operator=(const transform_iterator<InputIterator,
  117. UnaryFunction> &other)
  118. {
  119. if(this != &other){
  120. super_type::operator=(other);
  121. m_transform = other.m_transform;
  122. }
  123. return *this;
  124. }
  125. ~transform_iterator()
  126. {
  127. }
  128. size_t get_index() const
  129. {
  130. return super_type::base().get_index();
  131. }
  132. const buffer& get_buffer() const
  133. {
  134. return detail::get_base_iterator_buffer(*this);
  135. }
  136. template<class IndexExpression>
  137. detail::transform_iterator_index_expr<InputIterator, UnaryFunction, IndexExpression>
  138. operator[](const IndexExpression &expr) const
  139. {
  140. return detail::transform_iterator_index_expr<InputIterator,
  141. UnaryFunction,
  142. IndexExpression>(super_type::base(),
  143. m_transform,
  144. expr);
  145. }
  146. private:
  147. friend class ::boost::iterator_core_access;
  148. reference dereference() const
  149. {
  150. const context &context = super_type::base().get_buffer().get_context();
  151. command_queue queue(context, context.get_device());
  152. detail::meta_kernel k("read");
  153. size_t output_arg = k.add_arg<value_type *>(memory_object::global_memory, "output");
  154. k << "*output = " << m_transform(super_type::base()[k.lit(0)]) << ";";
  155. kernel kernel = k.compile(context);
  156. buffer output_buffer(context, sizeof(value_type));
  157. kernel.set_arg(output_arg, output_buffer);
  158. queue.enqueue_task(kernel);
  159. return detail::read_single_value<value_type>(output_buffer, queue);
  160. }
  161. private:
  162. UnaryFunction m_transform;
  163. };
  164. /// Returns a transform_iterator for \p iterator with \p transform.
  165. ///
  166. /// \param iterator the underlying iterator
  167. /// \param transform the unary transform function
  168. ///
  169. /// \return a \c transform_iterator for \p iterator with \p transform
  170. ///
  171. /// For example, to create an iterator which returns the square-root of each
  172. /// value in a \c vector<int>:
  173. /// \code
  174. /// auto sqrt_iterator = make_transform_iterator(vec.begin(), sqrt<int>());
  175. /// \endcode
  176. template<class InputIterator, class UnaryFunction>
  177. inline transform_iterator<InputIterator, UnaryFunction>
  178. make_transform_iterator(InputIterator iterator, UnaryFunction transform)
  179. {
  180. return transform_iterator<InputIterator,
  181. UnaryFunction>(iterator, transform);
  182. }
  183. /// \internal_ (is_device_iterator specialization for transform_iterator)
  184. template<class InputIterator, class UnaryFunction>
  185. struct is_device_iterator<
  186. transform_iterator<InputIterator, UnaryFunction> > : boost::true_type {};
  187. } // end compute namespace
  188. } // end boost namespace
  189. #endif // BOOST_COMPUTE_ITERATOR_TRANSFORM_ITERATOR_HPP