// test_expression_evaluation.cpp
  1. // Copyright (c) 2018-2019 Cem Bassoy
  2. //
  3. // Distributed under the Boost Software License, Version 1.0. (See
  4. // accompanying file LICENSE_1_0.txt or copy at
  5. // http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. // The authors gratefully acknowledge the support of
  8. // Fraunhofer and Google in producing this work
  9. // which started as a Google Summer of Code project.
  10. //
  11. #include <boost/numeric/ublas/tensor/expression_evaluation.hpp>
  12. #include <boost/numeric/ublas/tensor/expression.hpp>
  13. #include <boost/numeric/ublas/tensor/tensor.hpp>
  14. #include <boost/test/unit_test.hpp>
  15. #include "utility.hpp"
  16. #include <functional>
  17. using test_types = zip<int,long,float,double,std::complex<float>>::with_t<boost::numeric::ublas::first_order, boost::numeric::ublas::last_order>;
  18. struct fixture
  19. {
  20. using extents_type = boost::numeric::ublas::shape;
  21. fixture()
  22. : extents{
  23. extents_type{}, // 0
  24. extents_type{1,1}, // 1
  25. extents_type{1,2}, // 2
  26. extents_type{2,1}, // 3
  27. extents_type{2,3}, // 4
  28. extents_type{2,3,1}, // 5
  29. extents_type{1,2,3}, // 6
  30. extents_type{1,1,2,3}, // 7
  31. extents_type{1,2,3,1,1}, // 8
  32. extents_type{4,2,3}, // 9
  33. extents_type{4,2,1,3}, // 10
  34. extents_type{4,2,1,3,1}, // 11
  35. extents_type{1,4,2,1,3,1}} // 12
  36. {
  37. }
  38. std::vector<extents_type> extents;
  39. };
  40. BOOST_FIXTURE_TEST_CASE_TEMPLATE( test_tensor_expression_retrieve_extents, value, test_types, fixture)
  41. {
  42. using namespace boost::numeric;
  43. using value_type = typename value::first_type;
  44. using layout_type = typename value::second_type;
  45. using tensor_type = ublas::tensor<value_type, layout_type>;
  46. auto uplus1 = std::bind( std::plus<value_type>{}, std::placeholders::_1, value_type(1) );
  47. auto uplus2 = std::bind( std::plus<value_type>{}, value_type(2), std::placeholders::_2 );
  48. auto bplus = std::plus <value_type>{};
  49. auto bminus = std::minus<value_type>{};
  50. for(auto const& e : extents) {
  51. auto t = tensor_type(e);
  52. auto v = value_type{};
  53. for(auto& tt: t){ tt = v; v+=value_type{1}; }
  54. BOOST_CHECK( ublas::detail::retrieve_extents( t ) == e );
  55. // uexpr1 = t+1
  56. // uexpr2 = 2+t
  57. auto uexpr1 = ublas::detail::make_unary_tensor_expression<tensor_type>( t, uplus1 );
  58. auto uexpr2 = ublas::detail::make_unary_tensor_expression<tensor_type>( t, uplus2 );
  59. BOOST_CHECK( ublas::detail::retrieve_extents( uexpr1 ) == e );
  60. BOOST_CHECK( ublas::detail::retrieve_extents( uexpr2 ) == e );
  61. // bexpr_uexpr = (t+1) + (2+t)
  62. auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_type>( uexpr1, uexpr2, bplus );
  63. BOOST_CHECK( ublas::detail::retrieve_extents( bexpr_uexpr ) == e );
  64. // bexpr_bexpr_uexpr = ((t+1) + (2+t)) - t
  65. auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_type>( bexpr_uexpr, t, bminus );
  66. BOOST_CHECK( ublas::detail::retrieve_extents( bexpr_bexpr_uexpr ) == e );
  67. }
  68. for(auto i = 0u; i < extents.size()-1; ++i)
  69. {
  70. auto v = value_type{};
  71. auto t1 = tensor_type(extents[i]);
  72. for(auto& tt: t1){ tt = v; v+=value_type{1}; }
  73. auto t2 = tensor_type(extents[i+1]);
  74. for(auto& tt: t2){ tt = v; v+=value_type{2}; }
  75. BOOST_CHECK( ublas::detail::retrieve_extents( t1 ) != ublas::detail::retrieve_extents( t2 ) );
  76. // uexpr1 = t1+1
  77. // uexpr2 = 2+t2
  78. auto uexpr1 = ublas::detail::make_unary_tensor_expression<tensor_type>( t1, uplus1 );
  79. auto uexpr2 = ublas::detail::make_unary_tensor_expression<tensor_type>( t2, uplus2 );
  80. BOOST_CHECK( ublas::detail::retrieve_extents( t1 ) == ublas::detail::retrieve_extents( uexpr1 ) );
  81. BOOST_CHECK( ublas::detail::retrieve_extents( t2 ) == ublas::detail::retrieve_extents( uexpr2 ) );
  82. BOOST_CHECK( ublas::detail::retrieve_extents( uexpr1 ) != ublas::detail::retrieve_extents( uexpr2 ) );
  83. // bexpr_uexpr = (t1+1) + (2+t2)
  84. auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_type>( uexpr1, uexpr2, bplus );
  85. BOOST_CHECK( ublas::detail::retrieve_extents( bexpr_uexpr ) == ublas::detail::retrieve_extents(t1) );
  86. // bexpr_bexpr_uexpr = ((t1+1) + (2+t2)) - t2
  87. auto bexpr_bexpr_uexpr1 = ublas::detail::make_binary_tensor_expression<tensor_type>( bexpr_uexpr, t2, bminus );
  88. BOOST_CHECK( ublas::detail::retrieve_extents( bexpr_bexpr_uexpr1 ) == ublas::detail::retrieve_extents(t2) );
  89. // bexpr_bexpr_uexpr = t2 - ((t1+1) + (2+t2))
  90. auto bexpr_bexpr_uexpr2 = ublas::detail::make_binary_tensor_expression<tensor_type>( t2, bexpr_uexpr, bminus );
  91. BOOST_CHECK( ublas::detail::retrieve_extents( bexpr_bexpr_uexpr2 ) == ublas::detail::retrieve_extents(t2) );
  92. }
  93. }
  94. BOOST_FIXTURE_TEST_CASE_TEMPLATE( test_tensor_expression_all_extents_equal, value, test_types, fixture)
  95. {
  96. using namespace boost::numeric;
  97. using value_type = typename value::first_type;
  98. using layout_type = typename value::second_type;
  99. using tensor_type = ublas::tensor<value_type, layout_type>;
  100. auto uplus1 = std::bind( std::plus<value_type>{}, std::placeholders::_1, value_type(1) );
  101. auto uplus2 = std::bind( std::plus<value_type>{}, value_type(2), std::placeholders::_2 );
  102. auto bplus = std::plus <value_type>{};
  103. auto bminus = std::minus<value_type>{};
  104. for(auto const& e : extents) {
  105. auto t = tensor_type(e);
  106. auto v = value_type{};
  107. for(auto& tt: t){ tt = v; v+=value_type{1}; }
  108. BOOST_CHECK( ublas::detail::all_extents_equal( t , e ) );
  109. // uexpr1 = t+1
  110. // uexpr2 = 2+t
  111. auto uexpr1 = ublas::detail::make_unary_tensor_expression<tensor_type>( t, uplus1 );
  112. auto uexpr2 = ublas::detail::make_unary_tensor_expression<tensor_type>( t, uplus2 );
  113. BOOST_CHECK( ublas::detail::all_extents_equal( uexpr1, e ) );
  114. BOOST_CHECK( ublas::detail::all_extents_equal( uexpr2, e ) );
  115. // bexpr_uexpr = (t+1) + (2+t)
  116. auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_type>( uexpr1, uexpr2, bplus );
  117. BOOST_CHECK( ublas::detail::all_extents_equal( bexpr_uexpr, e ) );
  118. // bexpr_bexpr_uexpr = ((t+1) + (2+t)) - t
  119. auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_type>( bexpr_uexpr, t, bminus );
  120. BOOST_CHECK( ublas::detail::all_extents_equal( bexpr_bexpr_uexpr , e ) );
  121. }
  122. for(auto i = 0u; i < extents.size()-1; ++i)
  123. {
  124. auto v = value_type{};
  125. auto t1 = tensor_type(extents[i]);
  126. for(auto& tt: t1){ tt = v; v+=value_type{1}; }
  127. auto t2 = tensor_type(extents[i+1]);
  128. for(auto& tt: t2){ tt = v; v+=value_type{2}; }
  129. BOOST_CHECK( ublas::detail::all_extents_equal( t1, ublas::detail::retrieve_extents(t1) ) );
  130. BOOST_CHECK( ublas::detail::all_extents_equal( t2, ublas::detail::retrieve_extents(t2) ) );
  131. // uexpr1 = t1+1
  132. // uexpr2 = 2+t2
  133. auto uexpr1 = ublas::detail::make_unary_tensor_expression<tensor_type>( t1, uplus1 );
  134. auto uexpr2 = ublas::detail::make_unary_tensor_expression<tensor_type>( t2, uplus2 );
  135. BOOST_CHECK( ublas::detail::all_extents_equal( uexpr1, ublas::detail::retrieve_extents(uexpr1) ) );
  136. BOOST_CHECK( ublas::detail::all_extents_equal( uexpr2, ublas::detail::retrieve_extents(uexpr2) ) );
  137. // bexpr_uexpr = (t1+1) + (2+t2)
  138. auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_type>( uexpr1, uexpr2, bplus );
  139. BOOST_CHECK( ! ublas::detail::all_extents_equal( bexpr_uexpr, ublas::detail::retrieve_extents( bexpr_uexpr ) ) );
  140. // bexpr_bexpr_uexpr = ((t1+1) + (2+t2)) - t2
  141. auto bexpr_bexpr_uexpr1 = ublas::detail::make_binary_tensor_expression<tensor_type>( bexpr_uexpr, t2, bminus );
  142. BOOST_CHECK( ! ublas::detail::all_extents_equal( bexpr_bexpr_uexpr1, ublas::detail::retrieve_extents( bexpr_bexpr_uexpr1 ) ) );
  143. // bexpr_bexpr_uexpr = t2 - ((t1+1) + (2+t2))
  144. auto bexpr_bexpr_uexpr2 = ublas::detail::make_binary_tensor_expression<tensor_type>( t2, bexpr_uexpr, bminus );
  145. BOOST_CHECK( ! ublas::detail::all_extents_equal( bexpr_bexpr_uexpr2, ublas::detail::retrieve_extents( bexpr_bexpr_uexpr2 ) ) );
  146. // bexpr_uexpr2 = (t1+1) + t2
  147. auto bexpr_uexpr2 = ublas::detail::make_binary_tensor_expression<tensor_type>( uexpr1, t2, bplus );
  148. BOOST_CHECK( ! ublas::detail::all_extents_equal( bexpr_uexpr2, ublas::detail::retrieve_extents( bexpr_uexpr2 ) ) );
  149. // bexpr_uexpr2 = ((t1+1) + t2) + t1
  150. auto bexpr_bexpr_uexpr3 = ublas::detail::make_binary_tensor_expression<tensor_type>( bexpr_uexpr2, t1, bplus );
  151. BOOST_CHECK( ! ublas::detail::all_extents_equal( bexpr_bexpr_uexpr3, ublas::detail::retrieve_extents( bexpr_bexpr_uexpr3 ) ) );
  152. // bexpr_uexpr2 = t1 + (((t1+1) + t2) + t1)
  153. auto bexpr_bexpr_uexpr4 = ublas::detail::make_binary_tensor_expression<tensor_type>( t1, bexpr_bexpr_uexpr3, bplus );
  154. BOOST_CHECK( ! ublas::detail::all_extents_equal( bexpr_bexpr_uexpr4, ublas::detail::retrieve_extents( bexpr_bexpr_uexpr4 ) ) );
  155. }
  156. }