9
3

test_einstein_notation.cpp 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. // Copyright (c) 2018-2019 Cem Bassoy
  2. //
  3. // Distributed under the Boost Software License, Version 1.0. (See
  4. // accompanying file LICENSE_1_0.txt or copy at
  5. // http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. // The authors gratefully acknowledge the support of
  8. // Fraunhofer and Google in producing this work
  9. // which started as a Google Summer of Code project.
  10. //
  11. // And we acknowledge the support from all contributors.
  12. #include <iostream>
  13. #include <algorithm>
  14. #include <boost/numeric/ublas/tensor.hpp>
  15. #include <boost/test/unit_test.hpp>
  16. #include "utility.hpp"
  17. BOOST_AUTO_TEST_SUITE ( test_einstein_notation, * boost::unit_test::depends_on("test_multi_index") )
  18. using test_types = zip<int,long,float,double,std::complex<float>>::with_t<boost::numeric::ublas::first_order, boost::numeric::ublas::last_order>;
  19. //using test_types = zip<int>::with_t<boost::numeric::ublas::first_order>;
  20. BOOST_AUTO_TEST_CASE_TEMPLATE( test_einstein_multiplication, value, test_types )
  21. {
  22. using namespace boost::numeric::ublas;
  23. using value_type = typename value::first_type;
  24. using layout_type = typename value::second_type;
  25. using tensor_type = tensor<value_type,layout_type>;
  26. using namespace boost::numeric::ublas::index;
  27. {
  28. auto A = tensor_type{5,3};
  29. auto B = tensor_type{3,4};
  30. // auto C = tensor_type{4,5,6};
  31. for(auto j = 0u; j < A.extents().at(1); ++j)
  32. for(auto i = 0u; i < A.extents().at(0); ++i)
  33. A.at( i,j ) = value_type(i+1);
  34. for(auto j = 0u; j < B.extents().at(1); ++j)
  35. for(auto i = 0u; i < B.extents().at(0); ++i)
  36. B.at( i,j ) = value_type(i+1);
  37. auto AB = A(_,_e) * B(_e,_);
  38. // std::cout << "A = " << A << std::endl;
  39. // std::cout << "B = " << B << std::endl;
  40. // std::cout << "AB = " << AB << std::endl;
  41. for(auto j = 0u; j < AB.extents().at(1); ++j)
  42. for(auto i = 0u; i < AB.extents().at(0); ++i)
  43. BOOST_CHECK_EQUAL( AB.at( i,j ) , value_type(A.at( i,0 ) * ( B.extents().at(0) * (B.extents().at(0)+1) / 2 )) );
  44. }
  45. {
  46. auto A = tensor_type{4,5,3};
  47. auto B = tensor_type{3,4,2};
  48. for(auto k = 0u; k < A.extents().at(2); ++k)
  49. for(auto j = 0u; j < A.extents().at(1); ++j)
  50. for(auto i = 0u; i < A.extents().at(0); ++i)
  51. A.at( i,j,k ) = value_type(i+1);
  52. for(auto k = 0u; k < B.extents().at(2); ++k)
  53. for(auto j = 0u; j < B.extents().at(1); ++j)
  54. for(auto i = 0u; i < B.extents().at(0); ++i)
  55. B.at( i,j,k ) = value_type(i+1);
  56. auto AB = A(_d,_,_f) * B(_f,_d,_);
  57. // std::cout << "A = " << A << std::endl;
  58. // std::cout << "B = " << B << std::endl;
  59. // std::cout << "AB = " << AB << std::endl;
  60. // n*(n+1)/2;
  61. auto const nf = ( B.extents().at(0) * (B.extents().at(0)+1) / 2 );
  62. auto const nd = ( A.extents().at(0) * (A.extents().at(0)+1) / 2 );
  63. for(auto j = 0u; j < AB.extents().at(1); ++j)
  64. for(auto i = 0u; i < AB.extents().at(0); ++i)
  65. BOOST_CHECK_EQUAL( AB.at( i,j ) , value_type(nf * nd) );
  66. }
  67. {
  68. auto A = tensor_type{4,3};
  69. auto B = tensor_type{3,4,2};
  70. for(auto j = 0u; j < A.extents().at(1); ++j)
  71. for(auto i = 0u; i < A.extents().at(0); ++i)
  72. A.at( i,j ) = value_type(i+1);
  73. for(auto k = 0u; k < B.extents().at(2); ++k)
  74. for(auto j = 0u; j < B.extents().at(1); ++j)
  75. for(auto i = 0u; i < B.extents().at(0); ++i)
  76. B.at( i,j,k ) = value_type(i+1);
  77. auto AB = A(_d,_f) * B(_f,_d,_);
  78. // n*(n+1)/2;
  79. auto const nf = ( B.extents().at(0) * (B.extents().at(0)+1) / 2 );
  80. auto const nd = ( A.extents().at(0) * (A.extents().at(0)+1) / 2 );
  81. for(auto i = 0u; i < AB.extents().at(0); ++i)
  82. BOOST_CHECK_EQUAL ( AB.at( i ) , value_type(nf * nd) );
  83. }
  84. }
  85. BOOST_AUTO_TEST_SUITE_END()