//
//  Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
//
//  Distributed under the Boost Software License, Version 1.0. (See
//  accompanying file LICENSE_1_0.txt or copy at
//  http://www.boost.org/LICENSE_1_0.txt)
//
//  The authors gratefully acknowledge the support of
//  Fraunhofer IOSB, Ettlingen, Germany
//

#ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
#define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_

#include <cstddef>
#include <stdexcept>
#include <type_traits>


namespace boost::numeric::ublas {

template<class element_type, class storage_format, class storage_type>
class tensor;

template<class size_type>
class basic_extents;

} // namespace boost::numeric::ublas

namespace boost::numeric::ublas::detail {

template<class T, class D>
struct tensor_expression;

template<class T, class EL, class ER, class OP>
struct binary_tensor_expression;

template<class T, class E, class OP>
struct unary_tensor_expression;

} // namespace boost::numeric::ublas::detail

namespace boost::numeric::ublas::detail {

/** @brief Compile-time predicate: is the expression type E the tensor type T, or does it contain T?
 *
 * Primary template: E is neither T nor one of the expression node templates
 * handled by the specializations below, so it contains no tensor.
 */
template<class T, class E>
struct has_tensor_types
{ static constexpr bool value = false; };

/// E is the tensor type itself.
template<class T>
struct has_tensor_types<T,T>
{ static constexpr bool value = true; };

/// Generic expression node: inspect its derived (CRTP) type D.
template<class T, class D>
struct has_tensor_types<T, tensor_expression<T,D>>
{ static constexpr bool value = std::is_same<T,D>::value || has_tensor_types<T,D>::value; };

/// Binary node: true if either operand is T or (recursively) contains T.
template<class T, class EL, class ER, class OP>
struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
{ static constexpr bool value = std::is_same<T,EL>::value || std::is_same<T,ER>::value || has_tensor_types<T,EL>::value || has_tensor_types<T,ER>::value; };

/// Unary node: true if the operand is T or (recursively) contains T.
template<class T, class E, class OP>
struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
{ static constexpr bool value = std::is_same<T,E>::value || has_tensor_types<T,E>::value; };

} // namespace boost::numeric::ublas::detail


namespace boost::numeric::ublas::detail {

/** @brief Retrieves extents of the tensor
 *
 * @returns a copy of t.extents()
 */
template<class T, class F, class A>
auto retrieve_extents(tensor<T,F,A> const& t)
{
	return t.extents();
}

/** @brief Retrieves extents of the tensor expression
 *
 * @note tensor expression must be a binary tree with at least one tensor type
 *
 * The expression is cast to its derived type D. If D is the tensor type its
 * extents are returned directly; otherwise the lookup recurses into D.
 *
 * @returns extents of the child expression if it is a tensor or extents of one child of its child.
 */
template<class T, class D>
auto retrieve_extents(tensor_expression<T,D> const& expr)
{
	static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
		"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

	auto const& cast_expr = static_cast<D const&>(expr);

	if constexpr ( std::is_same<T,D>::value )
		return cast_expr.extents();
	else
		return retrieve_extents(cast_expr);
}

/** @brief Retrieves extents of the binary tensor expression
 *
 * @note tensor expression must be a binary tree with at least one tensor type
 *
 * Lookup order: the left operand if it is the tensor type, else the right operand
 * if it is the tensor type, else recurse into the left operand if it contains a
 * tensor, else recurse into the right operand.
 *
 * @returns extents of the first tensor found in the order described above.
 */
template<class T, class EL, class ER, class OP>
auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
{
	static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
		"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

	// A single if-constexpr chain: only the selected return statement is instantiated.
	if constexpr ( std::is_same<T,EL>::value )
		return expr.el.extents();
	else if constexpr ( std::is_same<T,ER>::value )
		return expr.er.extents();
	else if constexpr ( detail::has_tensor_types<T,EL>::value )
		return retrieve_extents(expr.el);
	else if constexpr ( detail::has_tensor_types<T,ER>::value )
		return retrieve_extents(expr.er);
}

/** @brief Retrieves extents of the unary tensor expression
 *
 * @note tensor expression must be a binary tree with at least one tensor type
 *
 * @returns extents of the child expression if it is a tensor or extents of a child of its child.
 */
template<class T, class E, class OP>
auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
{
	static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
		"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

	if constexpr ( std::is_same<T,E>::value )
		return expr.e.extents();
	else if constexpr ( detail::has_tensor_types<T,E>::value )
		return retrieve_extents(expr.e);
}

} // namespace boost::numeric::ublas::detail


///////////////

namespace boost::numeric::ublas::detail {

/** @brief Checks whether the tensor t has the given extents.
 *
 * @returns the result of extents == t.extents()
 */
template<class T, class F, class A, class S>
auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
{
	return extents == t.extents();
}

/** @brief Checks whether every tensor inside the expression has the given extents.
 *
 * @returns false as soon as one tensor's extents differ, true otherwise.
 */
template<class T, class D, class S>
auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
{
	static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
		"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");

	auto const& cast_expr = static_cast<D const&>(expr);

	// When D is the tensor itself, the direct comparison is sufficient; recursing
	// would only repeat the same comparison through the tensor overload.
	if constexpr ( std::is_same<T,D>::value ) {
		if( extents != cast_expr.extents() )
			return false;
	}
	else if constexpr ( detail::has_tensor_types<T,D>::value ) {
		if( !all_extents_equal(cast_expr, extents) )
			return false;
	}

	return true;
}

/** @brief Checks whether every tensor inside the binary expression has the given extents.
 *
 * Each operand is either compared directly (if it is the tensor type) or
 * searched recursively (if it contains tensors).
 *
 * @returns false as soon as one tensor's extents differ, true otherwise.
 */
template<class T, class EL, class ER, class OP, class S>
auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
{
	static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
		"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");

	if constexpr ( std::is_same<T,EL>::value ) {
		if( extents != expr.el.extents() )
			return false;
	}
	else if constexpr ( detail::has_tensor_types<T,EL>::value ) {
		if( !all_extents_equal(expr.el, extents) )
			return false;
	}

	if constexpr ( std::is_same<T,ER>::value ) {
		if( extents != expr.er.extents() )
			return false;
	}
	else if constexpr ( detail::has_tensor_types<T,ER>::value ) {
		if( !all_extents_equal(expr.er, extents) )
			return false;
	}

	return true;
}

/** @brief Checks whether every tensor inside the unary expression has the given extents.
 *
 * @returns false if a tensor's extents differ, true otherwise.
 */
template<class T, class E, class OP, class S>
auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
{
	static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
		"Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");

	if constexpr ( std::is_same<T,E>::value ) {
		if( extents != expr.e.extents() )
			return false;
	}
	else if constexpr ( detail::has_tensor_types<T,E>::value ) {
		if( !all_extents_equal(expr.e, extents) )
			return false;
	}

	return true;
}

} // namespace boost::numeric::ublas::detail


namespace boost::numeric::ublas::detail {

/** @brief Evaluates expression for a tensor
 *
 * Assigns the results of the expression to the tensor element by element:
 * lhs(i) = expr()(i) for every i in [0, lhs.size()).
 *
 * \note Checks if shape of the tensor matches those of all tensors within the expression.
 *
 * @throws std::runtime_error if a tensor inside the expression has extents different from lhs.
 */
template<class tensor_type, class derived_type>
void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
{
	if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
		if(!detail::all_extents_equal(expr, lhs.extents() ))
			throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");

	// std::size_t (not `auto i = 0u`, i.e. unsigned int) so the index cannot wrap
	// and loop forever on tensors with more than UINT_MAX elements.
	std::size_t const n = lhs.size();
#pragma omp parallel for
	for(std::size_t i = 0; i < n; ++i)
		lhs(i) = expr()(i);
}

/** @brief Evaluates expression for a tensor
 *
 * Calls fn(lhs(i), expr()(i)) for every element index i, which is how compound
 * assignments such as A += C are implemented.
 * When compiled with OpenMP the iterations may run concurrently, so fn must be
 * safe to call for distinct elements in parallel.
 *
 * \note Checks if shape of the tensor matches those of all tensors within the expression.
 *
 * @throws std::runtime_error if a tensor inside the expression has extents different from lhs.
 */
template<class tensor_type, class derived_type, class unary_fn>
void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
{
	if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
		if(!detail::all_extents_equal( expr, lhs.extents() ))
			throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");

	// std::size_t so the index cannot wrap on very large tensors.
	std::size_t const n = lhs.size();
#pragma omp parallel for
	for(std::size_t i = 0; i < n; ++i)
		fn(lhs(i), expr()(i));
}

/** @brief Applies a unary function to every element of a tensor in place
 *
 * Calls fn(lhs(i)) for every element index i. No expression is involved and no
 * shape check is performed.
 * When compiled with OpenMP the iterations may run concurrently, so fn must be
 * safe to call for distinct elements in parallel.
 */
template<class tensor_type, class unary_fn>
void eval(tensor_type& lhs, unary_fn const fn)
{
	// std::size_t so the index cannot wrap on very large tensors.
	std::size_t const n = lhs.size();
#pragma omp parallel for
	for(std::size_t i = 0; i < n; ++i)
		fn(lhs(i));
}

} // namespace boost::numeric::ublas::detail

#endif