// test_expression.cpp (5.4 KB)
  // Copyright (c) 2018-2019 Cem Bassoy
  //
  // Distributed under the Boost Software License, Version 1.0. (See
  // accompanying file LICENSE_1_0.txt or copy at
  // http://www.boost.org/LICENSE_1_0.txt)
  //
  // The authors gratefully acknowledge the support of
  // Fraunhofer and Google in producing this work
  // which started as a Google Summer of Code project.
  //
  11. #include <boost/numeric/ublas/tensor/expression.hpp>
  12. #include <boost/numeric/ublas/tensor/tensor.hpp>
  13. #include <boost/test/unit_test.hpp>
  14. #include "utility.hpp"
  15. #include <functional>
  16. #include <complex>
  17. using test_types = zip<int,long,float,double,std::complex<float>>::with_t<boost::numeric::ublas::first_order, boost::numeric::ublas::last_order>;
  18. struct fixture
  19. {
  20. using extents_type = boost::numeric::ublas::shape;
  21. fixture()
  22. : extents {
  23. extents_type{}, // 0
  24. extents_type{1,1}, // 1
  25. extents_type{1,2}, // 2
  26. extents_type{2,1}, // 3
  27. extents_type{2,3}, // 4
  28. extents_type{2,3,1}, // 5
  29. extents_type{1,2,3}, // 6
  30. extents_type{1,1,2,3}, // 7
  31. extents_type{1,2,3,1,1}, // 8
  32. extents_type{4,2,3}, // 9
  33. extents_type{4,2,1,3}, // 10
  34. extents_type{4,2,1,3,1}, // 11
  35. extents_type{1,4,2,1,3,1} } // 12
  36. {
  37. }
  38. std::vector<extents_type> extents;
  39. };
  40. BOOST_FIXTURE_TEST_CASE_TEMPLATE( test_tensor_expression_access, value, test_types, fixture)
  41. {
  42. using namespace boost::numeric;
  43. using value_type = typename value::first_type;
  44. using layout_type = typename value::second_type;
  45. using tensor_type = ublas::tensor<value_type, layout_type>;
  46. using tensor_expression_type = typename tensor_type::super_type;
  47. for(auto const& e : extents) {
  48. auto v = value_type{};
  49. auto t = tensor_type(e);
  50. for(auto& tt: t){ tt = v; v+=value_type{1}; }
  51. const auto& tensor_expression_const = static_cast<tensor_expression_type const&>( t );
  52. for(auto i = 0ul; i < t.size(); ++i)
  53. BOOST_CHECK_EQUAL( tensor_expression_const()(i), t(i) );
  54. }
  55. }
  56. BOOST_FIXTURE_TEST_CASE_TEMPLATE( test_tensor_unary_expression, value, test_types, fixture)
  57. {
  58. using namespace boost::numeric;
  59. using value_type = typename value::first_type;
  60. using layout_type = typename value::second_type;
  61. using tensor_type = ublas::tensor<value_type, layout_type>;
  62. auto uplus1 = std::bind( std::plus<value_type>{}, std::placeholders::_1, value_type(1) );
  63. for(auto const& e : extents) {
  64. auto t = tensor_type(e);
  65. auto v = value_type{};
  66. for(auto& tt: t) { tt = v; v+=value_type{1}; }
  67. const auto uexpr = ublas::detail::make_unary_tensor_expression<tensor_type>( t, uplus1 );
  68. for(auto i = 0ul; i < t.size(); ++i)
  69. BOOST_CHECK_EQUAL( uexpr(i), uplus1(t(i)) );
  70. auto uexpr_uexpr = ublas::detail::make_unary_tensor_expression<tensor_type>( uexpr, uplus1 );
  71. for(auto i = 0ul; i < t.size(); ++i)
  72. BOOST_CHECK_EQUAL( uexpr_uexpr(i), uplus1(uplus1(t(i))) );
  73. const auto & uexpr_e = uexpr.e;
  74. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_e) >, tensor_type > ) );
  75. const auto & uexpr_uexpr_e_e = uexpr_uexpr.e.e;
  76. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_uexpr_e_e) >, tensor_type > ) );
  77. }
  78. }
  79. BOOST_FIXTURE_TEST_CASE_TEMPLATE( test_tensor_binary_expression, value, test_types, fixture)
  80. {
  81. using namespace boost::numeric;
  82. using value_type = typename value::first_type;
  83. using layout_type = typename value::second_type;
  84. using tensor_type = ublas::tensor<value_type, layout_type>;
  85. auto uplus1 = std::bind( std::plus<value_type>{}, std::placeholders::_1, value_type(1) );
  86. auto uplus2 = std::bind( std::plus<value_type>{}, std::placeholders::_1, value_type(2) );
  87. auto bplus = std::plus <value_type>{};
  88. auto bminus = std::minus<value_type>{};
  89. for(auto const& e : extents) {
  90. auto t = tensor_type(e);
  91. auto v = value_type{};
  92. for(auto& tt: t){ tt = v; v+=value_type{1}; }
  93. auto uexpr1 = ublas::detail::make_unary_tensor_expression<tensor_type>( t, uplus1 );
  94. auto uexpr2 = ublas::detail::make_unary_tensor_expression<tensor_type>( t, uplus2 );
  95. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.e) >, tensor_type > ) );
  96. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.e) >, tensor_type > ) );
  97. for(auto i = 0ul; i < t.size(); ++i)
  98. BOOST_CHECK_EQUAL( uexpr1(i), uplus1(t(i)) );
  99. for(auto i = 0ul; i < t.size(); ++i)
  100. BOOST_CHECK_EQUAL( uexpr2(i), uplus2(t(i)) );
  101. auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_type>( uexpr1, uexpr2, bplus );
  102. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.el.e) >, tensor_type > ) );
  103. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.er.e) >, tensor_type > ) );
  104. for(auto i = 0ul; i < t.size(); ++i)
  105. BOOST_CHECK_EQUAL( bexpr_uexpr(i), bplus(uexpr1(i),uexpr2(i)) );
  106. auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_type>( bexpr_uexpr, t, bminus );
  107. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.el.e) >, tensor_type > ) );
  108. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.er.e) >, tensor_type > ) );
  109. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_type > ) );
  110. BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_type > ) );
  111. for(auto i = 0ul; i < t.size(); ++i)
  112. BOOST_CHECK_EQUAL( bexpr_bexpr_uexpr(i), bminus(bexpr_uexpr(i),t(i)) );
  113. }
  114. }