operators_comparison.hpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. //
  2. // Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See
  5. // accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt)
  7. //
  8. // The authors gratefully acknowledge the support of
  9. // Fraunhofer IOSB, Ettlingen, Germany
  10. //
  11. #ifndef BOOST_UBLAS_TENSOR_OPERATORS_COMPARISON_HPP
  12. #define BOOST_UBLAS_TENSOR_OPERATORS_COMPARISON_HPP
#include <boost/numeric/ublas/tensor/expression.hpp>
#include <boost/numeric/ublas/tensor/expression_evaluation.hpp>
#include <functional>
#include <stdexcept>
#include <type_traits>
  17. namespace boost::numeric::ublas {
  18. template<class element_type, class storage_format, class storage_type>
  19. class tensor;
  20. }
  21. namespace boost::numeric::ublas::detail {
  22. template<class T, class F, class A, class BinaryPred>
  23. bool compare(tensor<T,F,A> const& lhs, tensor<T,F,A> const& rhs, BinaryPred pred)
  24. {
  25. if(lhs.extents() != rhs.extents()){
  26. if constexpr(!std::is_same<BinaryPred,std::equal_to<>>::value && !std::is_same<BinaryPred,std::not_equal_to<>>::value)
  27. throw std::runtime_error("Error in boost::numeric::ublas::detail::compare: cannot compare tensors with different shapes.");
  28. else
  29. return false;
  30. }
  31. if constexpr(std::is_same<BinaryPred,std::greater<>>::value || std::is_same<BinaryPred,std::less<>>::value)
  32. if(lhs.empty())
  33. return false;
  34. for(auto i = 0u; i < lhs.size(); ++i)
  35. if(!pred(lhs(i), rhs(i)))
  36. return false;
  37. return true;
  38. }
  39. template<class T, class F, class A, class UnaryPred>
  40. bool compare(tensor<T,F,A> const& rhs, UnaryPred pred)
  41. {
  42. for(auto i = 0u; i < rhs.size(); ++i)
  43. if(!pred(rhs(i)))
  44. return false;
  45. return true;
  46. }
  47. template<class T, class L, class R, class BinaryPred>
  48. bool compare(tensor_expression<T,L> const& lhs, tensor_expression<T,R> const& rhs, BinaryPred pred)
  49. {
  50. constexpr bool lhs_is_tensor = std::is_same<T,L>::value;
  51. constexpr bool rhs_is_tensor = std::is_same<T,R>::value;
  52. if constexpr (lhs_is_tensor && rhs_is_tensor)
  53. return compare(static_cast<T const&>( lhs ), static_cast<T const&>( rhs ), pred);
  54. else if constexpr (lhs_is_tensor && !rhs_is_tensor)
  55. return compare(static_cast<T const&>( lhs ), T( rhs ), pred);
  56. else if constexpr (!lhs_is_tensor && rhs_is_tensor)
  57. return compare(T( lhs ), static_cast<T const&>( rhs ), pred);
  58. else
  59. return compare(T( lhs ), T( rhs ), pred);
  60. }
  61. template<class T, class D, class UnaryPred>
  62. bool compare(tensor_expression<T,D> const& expr, UnaryPred pred)
  63. {
  64. if constexpr (std::is_same<T,D>::value)
  65. return compare(static_cast<T const&>( expr ), pred);
  66. else
  67. return compare(T( expr ), pred);
  68. }
  69. }
  70. template<class T, class L, class R>
  71. bool operator==( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
  72. boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
  73. return boost::numeric::ublas::detail::compare( lhs, rhs, std::equal_to<>{} );
  74. }
  75. template<class T, class L, class R>
  76. auto operator!=(boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
  77. boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
  78. return boost::numeric::ublas::detail::compare( lhs, rhs, std::not_equal_to<>{} );
  79. }
  80. template<class T, class L, class R>
  81. auto operator< ( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
  82. boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
  83. return boost::numeric::ublas::detail::compare( lhs, rhs, std::less<>{} );
  84. }
  85. template<class T, class L, class R>
  86. auto operator<=( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
  87. boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
  88. return boost::numeric::ublas::detail::compare( lhs, rhs, std::less_equal<>{} );
  89. }
  90. template<class T, class L, class R>
  91. auto operator> ( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
  92. boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
  93. return boost::numeric::ublas::detail::compare( lhs, rhs, std::greater<>{} );
  94. }
  95. template<class T, class L, class R>
  96. auto operator>=( boost::numeric::ublas::detail::tensor_expression<T,L> const& lhs,
  97. boost::numeric::ublas::detail::tensor_expression<T,R> const& rhs) {
  98. return boost::numeric::ublas::detail::compare( lhs, rhs, std::greater_equal<>{} );
  99. }
  100. template<class T, class D>
  101. bool operator==( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
  102. return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs == r; } );
  103. }
  104. template<class T, class D>
  105. auto operator!=( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
  106. return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs != r; } );
  107. }
  108. template<class T, class D>
  109. auto operator< ( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
  110. return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs < r; } );
  111. }
  112. template<class T, class D>
  113. auto operator<=( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
  114. return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs <= r; } );
  115. }
  116. template<class T, class D>
  117. auto operator> ( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
  118. return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs > r; } );
  119. }
  120. template<class T, class D>
  121. auto operator>=( typename T::const_reference lhs, boost::numeric::ublas::detail::tensor_expression<T,D> const& rhs) {
  122. return boost::numeric::ublas::detail::compare( rhs, [lhs](auto const& r){ return lhs >= r; } );
  123. }
  124. template<class T, class D>
  125. bool operator==( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
  126. return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l == rhs; } );
  127. }
  128. template<class T, class D>
  129. auto operator!=( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
  130. return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l != rhs; } );
  131. }
  132. template<class T, class D>
  133. auto operator< ( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
  134. return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l < rhs; } );
  135. }
  136. template<class T, class D>
  137. auto operator<=( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
  138. return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l <= rhs; } );
  139. }
  140. template<class T, class D>
  141. auto operator> ( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
  142. return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l > rhs; } );
  143. }
  144. template<class T, class D>
  145. auto operator>=( boost::numeric::ublas::detail::tensor_expression<T,D> const& lhs, typename T::const_reference rhs) {
  146. return boost::numeric::ublas::detail::compare( lhs, [rhs](auto const& l){ return l >= rhs; } );
  147. }
  148. #endif