//
//  Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
//
//  Distributed under the Boost Software License, Version 1.0. (See
//  accompanying file LICENSE_1_0.txt or copy at
//  http://www.boost.org/LICENSE_1_0.txt)
//
//  The authors gratefully acknowledge the support of
//  Fraunhofer IOSB, Ettlingen, Germany
//

#ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
#define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

namespace boost {
namespace numeric {
namespace ublas {

/** @brief Template class for storing tensor extents with runtime variable size.
 *
 * Proxy template class of std::vector<int_type>. Every constructor that takes
 * extent values enforces valid(): at least two extents, none of them zero.
 *
 * @tparam int_type unsigned integer type used for a single extent
 */
template<class int_type>
class basic_extents
{
  static_assert( std::numeric_limits<int_type>::is_integer, "Static error in basic_layout: type must be of type integer.");
  static_assert(!std::numeric_limits<int_type>::is_signed,  "Static error in basic_layout: type must be of type unsigned integer.");

public:
  using base_type       = std::vector<int_type>;
  using value_type      = typename base_type::value_type;
  using const_reference = typename base_type::const_reference;
  using reference       = typename base_type::reference;
  using size_type       = typename base_type::size_type;
  using const_pointer   = typename base_type::const_pointer;
  using const_iterator  = typename base_type::const_iterator;

  /** @brief Default constructs empty basic_extents (no validity check)
   *
   * @code auto ex = basic_extents<unsigned>{};
   */
  constexpr explicit basic_extents()
    : _base{}
  {
  }

  /** @brief Copy constructs basic_extents from a one-dimensional container
   *
   * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) );
   *
   * @note checks if size > 1 and all elements > 0
   *
   * @param b one-dimensional std::vector container
   * @throws std::length_error if !valid()
   */
  explicit basic_extents(base_type const& b)
    : _base(b)
  {
    if (!this->valid()){
      throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
    }
  }

  /** @brief Move constructs basic_extents from a one-dimensional container
   *
   * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) );
   *
   * @note checks if size > 1 and all elements > 0
   *
   * @param b one-dimensional container of type std::vector
   * @throws std::length_error if !valid()
   */
  explicit basic_extents(base_type && b)
    : _base(std::move(b))
  {
    if (!this->valid()){
      throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
    }
  }

  /** @brief Constructs basic_extents from an initializer list
   *
   * @code auto ex = basic_extents<unsigned>{3,2,4};
   *
   * @note checks if size > 1 and all elements > 0
   *
   * @param l one-dimensional list of type std::initializer_list
   * @throws std::length_error if the resulting extents are not valid()
   */
  basic_extents(std::initializer_list<value_type> l)
    : basic_extents( base_type(l) )
  {
  }

  /** @brief Constructs basic_extents from a range specified by two iterators
   *
   * @code auto ex = basic_extents<unsigned>(a.begin(), a.end());
   *
   * @note checks if size > 1 and all elements > 0
   *
   * @param first iterator pointing to the first element
   * @param last iterator pointing to the next position after the last element
   * @throws std::length_error if the resulting extents are not valid()
   */
  basic_extents(const_iterator first, const_iterator last)
    : basic_extents ( base_type( first,last ) )
  {
  }

  /** @brief Copy constructs basic_extents */
  basic_extents(basic_extents const& l )
    : _base(l._base)
  {
  }

  /** @brief Move constructs basic_extents */
  basic_extents(basic_extents && l ) noexcept
    : _base(std::move(l._base))
  {
  }

  ~basic_extents() = default;

  /** @brief Copy/move assignment via copy-and-swap; the copy is made at the call site. */
  basic_extents& operator=(basic_extents other) noexcept
  {
    swap (*this, other);
    return *this;
  }

  /** @brief Exchanges the stored extents; noexcept because operator= relies on it. */
  friend void swap(basic_extents& lhs, basic_extents& rhs) noexcept
  {
    std::swap(lhs._base , rhs._base );
  }

  /** @brief Returns true if this has a scalar shape
   *
   * @returns true if (1,1,[1,...,1])
   */
  bool is_scalar() const
  {
    return
      _base.size() != 0 &&
      std::all_of(_base.begin(), _base.end(),
                  [](const_reference a){ return a == 1;});
  }

  /** @brief Returns true if this has a vector shape
   *
   * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1
   */
  bool is_vector() const
  {
    if(_base.size() == 0){
      return false;
    }

    if(_base.size() == 1){
      return _base.at(0) > 1;
    }

    auto greater_one = [](const_reference a){ return a >  1;};
    auto equal_one   = [](const_reference a){ return a == 1;};

    return
      std::any_of(_base.begin(),   _base.begin()+2, greater_one) &&
      std::any_of(_base.begin(),   _base.begin()+2, equal_one  ) &&
      std::all_of(_base.begin()+2, _base.end(),     equal_one);
  }

  /** @brief Returns true if this has a matrix shape
   *
   * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1
   */
  bool is_matrix() const
  {
    if(_base.size() < 2){
      return false;
    }

    auto greater_one = [](const_reference a){ return a >  1;};
    auto equal_one   = [](const_reference a){ return a == 1;};

    return
      std::all_of(_base.begin(),   _base.begin()+2, greater_one) &&
      std::all_of(_base.begin()+2, _base.end(),     equal_one  );
  }

  /** @brief Returns true if this has a tensor shape
   *
   * @returns true if at least three extents and some extent beyond the second is > 1
   */
  bool is_tensor() const
  {
    if(_base.size() < 3){
      return false;
    }

    auto greater_one = [](const_reference a){ return a > 1;};

    return std::any_of(_base.begin()+2, _base.end(), greater_one);
  }

  const_pointer data() const
  {
    return this->_base.data();
  }

  const_reference operator[] (size_type p) const
  {
    return this->_base[p];
  }

  const_reference at (size_type p) const
  {
    return this->_base.at(p);
  }

  reference operator[] (size_type p)
  {
    return this->_base[p];
  }

  reference at (size_type p)
  {
    return this->_base.at(p);
  }

  bool empty() const
  {
    return this->_base.empty();
  }

  size_type size() const
  {
    return this->_base.size();
  }

  /** @brief Returns true if size > 1 and all elements > 0 */
  bool valid() const
  {
    return
      this->size() > 1 &&
      std::none_of(_base.begin(), _base.end(),
                   [](const_reference a){ return a == value_type(0); });
  }

  /** @brief Returns the number of elements a tensor holds with this
   *
   * @returns product of all extents, or 0 if there are no extents
   */
  size_type product() const
  {
    if(_base.empty()){
      return 0;
    }

    // std::accumulate computes in the type of its initial value; seed it with
    // size_type (not unsigned long, which is 32 bits on LLP64 platforms) so the
    // product is not truncated.
    return std::accumulate(_base.begin(), _base.end(), size_type(1), std::multiplies<>());
  }

  /** @brief Eliminates singleton dimensions when size > 2
   *
   * The result always keeps at least two extents (padded with trailing ones),
   * so squeezing valid extents yields valid extents.
   *
   * squeeze {  1,1} -> {  1,1}
   * squeeze {  2,1} -> {  2,1}
   * squeeze {  1,2} -> {  1,2}
   *
   * squeeze {1,2,3} -> {  2,3}
   * squeeze {2,1,3} -> {  2,3}
   * squeeze {1,3,1} -> {  3,1}
   * squeeze {1,1,1} -> {  1,1}
   *
   */
  basic_extents squeeze() const
  {
    if(this->size() <= 2){
      return *this;
    }

    auto new_extent = basic_extents{};
    auto insert_iter = std::back_inserter(new_extent._base);
    std::remove_copy(this->_base.begin(), this->_base.end(), insert_iter, value_type{1});

    // Removing every singleton can leave fewer than two extents
    // ({1,3,1} -> {3}, {1,1,1} -> {}), which would violate valid().
    while(new_extent._base.size() < 2){
      new_extent._base.push_back(value_type{1});
    }

    return new_extent;
  }

  void clear()
  {
    this->_base.clear();
  }

  bool operator == (basic_extents const& b) const
  {
    return _base == b._base;
  }

  bool operator != (basic_extents const& b) const
  {
    return !( _base == b._base );
  }

  const_iterator
  begin() const
  {
    return _base.begin();
  }

  const_iterator
  end() const
  {
    return _base.end();
  }

  base_type const& base() const { return _base; }

private:

  base_type _base;
};

using shape = basic_extents<std::size_t>;

} // namespace ublas
} // namespace numeric
} // namespace boost

#endif