expression_evaluation.hpp 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. //
  2. // Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See
  5. // accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt)
  7. //
  8. // The authors gratefully acknowledge the support of
  9. // Fraunhofer IOSB, Ettlingen, Germany
  10. //
  11. #ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
  12. #define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
#include <cstddef>
#include <stdexcept>
#include <type_traits>
  15. namespace boost::numeric::ublas {
  16. template<class element_type, class storage_format, class storage_type>
  17. class tensor;
  18. template<class size_type>
  19. class basic_extents;
  20. }
  21. namespace boost::numeric::ublas::detail {
  22. template<class T, class D>
  23. struct tensor_expression;
  24. template<class T, class EL, class ER, class OP>
  25. struct binary_tensor_expression;
  26. template<class T, class E, class OP>
  27. struct unary_tensor_expression;
  28. }
  29. namespace boost::numeric::ublas::detail {
  30. template<class T, class E>
  31. struct has_tensor_types
  32. { static constexpr bool value = false; };
  33. template<class T>
  34. struct has_tensor_types<T,T>
  35. { static constexpr bool value = true; };
  36. template<class T, class D>
  37. struct has_tensor_types<T, tensor_expression<T,D>>
  38. { static constexpr bool value = std::is_same<T,D>::value || has_tensor_types<T,D>::value; };
  39. template<class T, class EL, class ER, class OP>
  40. struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
  41. { static constexpr bool value = std::is_same<T,EL>::value || std::is_same<T,ER>::value || has_tensor_types<T,EL>::value || has_tensor_types<T,ER>::value; };
  42. template<class T, class E, class OP>
  43. struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
  44. { static constexpr bool value = std::is_same<T,E>::value || has_tensor_types<T,E>::value; };
  45. } // namespace boost::numeric::ublas::detail
  46. namespace boost::numeric::ublas::detail {
  47. /** @brief Retrieves extents of the tensor
  48. *
  49. */
  50. template<class T, class F, class A>
  51. auto retrieve_extents(tensor<T,F,A> const& t)
  52. {
  53. return t.extents();
  54. }
  55. /** @brief Retrieves extents of the tensor expression
  56. *
  57. * @note tensor expression must be a binary tree with at least one tensor type
  58. *
  59. * @returns extents of the child expression if it is a tensor or extents of one child of its child.
  60. */
  61. template<class T, class D>
  62. auto retrieve_extents(tensor_expression<T,D> const& expr)
  63. {
  64. static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
  65. "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
  66. auto const& cast_expr = static_cast<D const&>(expr);
  67. if constexpr ( std::is_same<T,D>::value )
  68. return cast_expr.extents();
  69. else
  70. return retrieve_extents(cast_expr);
  71. }
  72. /** @brief Retrieves extents of the binary tensor expression
  73. *
  74. * @note tensor expression must be a binary tree with at least one tensor type
  75. *
  76. * @returns extents of the (left and if necessary then right) child expression if it is a tensor or extents of a child of its (left and if necessary then right) child.
  77. */
  78. template<class T, class EL, class ER, class OP>
  79. auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
  80. {
  81. static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
  82. "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
  83. if constexpr ( std::is_same<T,EL>::value )
  84. return expr.el.extents();
  85. if constexpr ( std::is_same<T,ER>::value )
  86. return expr.er.extents();
  87. else if constexpr ( detail::has_tensor_types<T,EL>::value )
  88. return retrieve_extents(expr.el);
  89. else if constexpr ( detail::has_tensor_types<T,ER>::value )
  90. return retrieve_extents(expr.er);
  91. }
  92. /** @brief Retrieves extents of the binary tensor expression
  93. *
  94. * @note tensor expression must be a binary tree with at least one tensor type
  95. *
  96. * @returns extents of the child expression if it is a tensor or extents of a child of its child.
  97. */
  98. template<class T, class E, class OP>
  99. auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
  100. {
  101. static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
  102. "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
  103. if constexpr ( std::is_same<T,E>::value )
  104. return expr.e.extents();
  105. else if constexpr ( detail::has_tensor_types<T,E>::value )
  106. return retrieve_extents(expr.e);
  107. }
  108. } // namespace boost::numeric::ublas::detail
  109. ///////////////
  110. namespace boost::numeric::ublas::detail {
  111. template<class T, class F, class A, class S>
  112. auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
  113. {
  114. return extents == t.extents();
  115. }
  116. template<class T, class D, class S>
  117. auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
  118. {
  119. static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
  120. "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
  121. auto const& cast_expr = static_cast<D const&>(expr);
  122. if constexpr ( std::is_same<T,D>::value )
  123. if( extents != cast_expr.extents() )
  124. return false;
  125. if constexpr ( detail::has_tensor_types<T,D>::value )
  126. if ( !all_extents_equal(cast_expr, extents))
  127. return false;
  128. return true;
  129. }
  130. template<class T, class EL, class ER, class OP, class S>
  131. auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
  132. {
  133. static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
  134. "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
  135. if constexpr ( std::is_same<T,EL>::value )
  136. if(extents != expr.el.extents())
  137. return false;
  138. if constexpr ( std::is_same<T,ER>::value )
  139. if(extents != expr.er.extents())
  140. return false;
  141. if constexpr ( detail::has_tensor_types<T,EL>::value )
  142. if(!all_extents_equal(expr.el, extents))
  143. return false;
  144. if constexpr ( detail::has_tensor_types<T,ER>::value )
  145. if(!all_extents_equal(expr.er, extents))
  146. return false;
  147. return true;
  148. }
  149. template<class T, class E, class OP, class S>
  150. auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
  151. {
  152. static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
  153. "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
  154. if constexpr ( std::is_same<T,E>::value )
  155. if(extents != expr.e.extents())
  156. return false;
  157. if constexpr ( detail::has_tensor_types<T,E>::value )
  158. if(!all_extents_equal(expr.e, extents))
  159. return false;
  160. return true;
  161. }
  162. } // namespace boost::numeric::ublas::detail
  163. namespace boost::numeric::ublas::detail {
  164. /** @brief Evaluates expression for a tensor
  165. *
  166. * Assigns the results of the expression to the tensor.
  167. *
  168. * \note Checks if shape of the tensor matches those of all tensors within the expression.
  169. */
  170. template<class tensor_type, class derived_type>
  171. void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
  172. {
  173. if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
  174. if(!detail::all_extents_equal(expr, lhs.extents() ))
  175. throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
  176. #pragma omp parallel for
  177. for(auto i = 0u; i < lhs.size(); ++i)
  178. lhs(i) = expr()(i);
  179. }
  180. /** @brief Evaluates expression for a tensor
  181. *
  182. * Applies a unary function to the results of the expressions before the assignment.
  183. * Usually applied needed for unary operators such as A += C;
  184. *
  185. * \note Checks if shape of the tensor matches those of all tensors within the expression.
  186. */
  187. template<class tensor_type, class derived_type, class unary_fn>
  188. void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
  189. {
  190. if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
  191. if(!detail::all_extents_equal( expr, lhs.extents() ))
  192. throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
  193. #pragma omp parallel for
  194. for(auto i = 0u; i < lhs.size(); ++i)
  195. fn(lhs(i), expr()(i));
  196. }
  197. /** @brief Evaluates expression for a tensor
  198. *
  199. * Applies a unary function to the results of the expressions before the assignment.
  200. * Usually applied needed for unary operators such as A += C;
  201. *
  202. * \note Checks if shape of the tensor matches those of all tensors within the expression.
  203. */
  204. template<class tensor_type, class unary_fn>
  205. void eval(tensor_type& lhs, unary_fn const fn)
  206. {
  207. #pragma omp parallel for
  208. for(auto i = 0u; i < lhs.size(); ++i)
  209. fn(lhs(i));
  210. }
  211. }
  212. #endif