strides.hpp 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. //
  2. // Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See
  5. // accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt)
  7. //
  8. // The authors gratefully acknowledge the support of
  9. // Fraunhofer IOSB, Ettlingen, Germany
  10. //
  11. /// \file strides.hpp Definition for the basic_strides template class
  12. #ifndef BOOST_UBLAS_TENSOR_STRIDES_HPP
  13. #define BOOST_UBLAS_TENSOR_STRIDES_HPP
  14. #include <vector>
  15. #include <limits>
  16. #include <numeric>
  17. #include <stdexcept>
  18. #include <initializer_list>
  19. #include <algorithm>
  20. #include <cassert>
  21. #include <boost/numeric/ublas/functional.hpp>
  22. namespace boost {
  23. namespace numeric {
  24. namespace ublas {
  25. using first_order = column_major;
  26. using last_order = row_major;
  27. template<class T>
  28. class basic_extents;
  29. /** @brief Template class for storing tensor strides for iteration with runtime variable size.
  30. *
  31. * Proxy template class of std::vector<int_type>.
  32. *
  33. */
  34. template<class __int_type, class __layout>
  35. class basic_strides
  36. {
  37. public:
  38. using base_type = std::vector<__int_type>;
  39. static_assert( std::numeric_limits<typename base_type::value_type>::is_integer,
  40. "Static error in boost::numeric::ublas::basic_strides: type must be of type integer.");
  41. static_assert(!std::numeric_limits<typename base_type::value_type>::is_signed,
  42. "Static error in boost::numeric::ublas::basic_strides: type must be of type unsigned integer.");
  43. static_assert(std::is_same<__layout,first_order>::value || std::is_same<__layout,last_order>::value,
  44. "Static error in boost::numeric::ublas::basic_strides: layout type must either first or last order");
  45. using layout_type = __layout;
  46. using value_type = typename base_type::value_type;
  47. using reference = typename base_type::reference;
  48. using const_reference = typename base_type::const_reference;
  49. using size_type = typename base_type::size_type;
  50. using const_pointer = typename base_type::const_pointer;
  51. using const_iterator = typename base_type::const_iterator;
  52. /** @brief Default constructs basic_strides
  53. *
  54. * @code auto ex = basic_strides<unsigned>{};
  55. */
  56. constexpr explicit basic_strides()
  57. : _base{}
  58. {
  59. }
  60. /** @brief Constructs basic_strides from basic_extents for the first- and last-order storage formats
  61. *
  62. * @code auto strides = basic_strides<unsigned>( basic_extents<std::size_t>{2,3,4} );
  63. *
  64. */
  65. template <class T>
  66. basic_strides(basic_extents<T> const& s)
  67. : _base(s.size(),1)
  68. {
  69. if(s.empty())
  70. return;
  71. if(!s.valid())
  72. throw std::runtime_error("Error in boost::numeric::ublas::basic_strides() : shape is not valid.");
  73. if(s.is_vector() || s.is_scalar())
  74. return;
  75. if(this->size() < 2)
  76. throw std::runtime_error("Error in boost::numeric::ublas::basic_strides() : size of strides must be greater or equal 2.");
  77. if constexpr (std::is_same<layout_type,first_order>::value){
  78. size_type k = 1ul, kend = this->size();
  79. for(; k < kend; ++k)
  80. _base[k] = _base[k-1] * s[k-1];
  81. }
  82. else {
  83. size_type k = this->size()-2, kend = 0ul;
  84. for(; k > kend; --k)
  85. _base[k] = _base[k+1] * s[k+1];
  86. _base[0] = _base[1] * s[1];
  87. }
  88. }
  89. basic_strides(basic_strides const& l)
  90. : _base(l._base)
  91. {}
  92. basic_strides(basic_strides && l )
  93. : _base(std::move(l._base))
  94. {}
  95. basic_strides(base_type const& l )
  96. : _base(l)
  97. {}
  98. basic_strides(base_type && l )
  99. : _base(std::move(l))
  100. {}
  101. ~basic_strides() = default;
  102. basic_strides& operator=(basic_strides other)
  103. {
  104. swap (*this, other);
  105. return *this;
  106. }
  107. friend void swap(basic_strides& lhs, basic_strides& rhs) {
  108. std::swap(lhs._base , rhs._base);
  109. }
  110. const_reference operator[] (size_type p) const{
  111. return _base[p];
  112. }
  113. const_pointer data() const{
  114. return _base.data();
  115. }
  116. const_reference at (size_type p) const{
  117. return _base.at(p);
  118. }
  119. bool empty() const{
  120. return _base.empty();
  121. }
  122. size_type size() const{
  123. return _base.size();
  124. }
  125. template<class other_layout>
  126. bool operator == (basic_strides<value_type, other_layout> const& b) const{
  127. return b.base() == this->base();
  128. }
  129. template<class other_layout>
  130. bool operator != (basic_strides<value_type, other_layout> const& b) const{
  131. return b.base() != this->base();
  132. }
  133. bool operator == (basic_strides const& b) const{
  134. return b._base == _base;
  135. }
  136. bool operator != (basic_strides const& b) const{
  137. return b._base != _base;
  138. }
  139. const_iterator begin() const{
  140. return _base.begin();
  141. }
  142. const_iterator end() const{
  143. return _base.end();
  144. }
  145. void clear() {
  146. this->_base.clear();
  147. }
  148. base_type const& base() const{
  149. return this->_base;
  150. }
  151. protected:
  152. base_type _base;
  153. };
  154. template<class layout_type>
  155. using strides = basic_strides<std::size_t, layout_type>;
  156. namespace detail {
  157. /** @brief Returns relative memory index with respect to a multi-index
  158. *
  159. * @code auto j = access(std::vector{3,4,5}, strides{shape{4,2,3},first_order}); @endcode
  160. *
  161. * @param[in] i multi-index of length p
  162. * @param[in] w stride vector of length p
  163. * @returns relative memory location depending on \c i and \c w
  164. */
  165. BOOST_UBLAS_INLINE
  166. template<class size_type, class layout_type>
  167. auto access(std::vector<size_type> const& i, basic_strides<size_type,layout_type> const& w)
  168. {
  169. const auto p = i.size();
  170. size_type sum = 0u;
  171. for(auto r = 0u; r < p; ++r)
  172. sum += i[r]*w[r];
  173. return sum;
  174. }
  175. /** @brief Returns relative memory index with respect to a multi-index
  176. *
  177. * @code auto j = access(0, strides{shape{4,2,3},first_order}, 2,3,4); @endcode
  178. *
  179. * @param[in] i first element of the partial multi-index
  180. * @param[in] is the following elements of the partial multi-index
  181. * @param[in] sum the current relative memory index
  182. * @returns relative memory location depending on \c i and \c w
  183. */
  184. BOOST_UBLAS_INLINE
  185. template<std::size_t r, class layout_type, class ... size_types>
  186. auto access(std::size_t sum, basic_strides<std::size_t, layout_type> const& w, std::size_t i, size_types ... is)
  187. {
  188. sum+=i*w[r];
  189. if constexpr (sizeof...(is) == 0)
  190. return sum;
  191. else
  192. return detail::access<r+1>(sum,w,std::forward<size_types>(is)...);
  193. }
  194. }
  195. }
  196. }
  197. }
  198. #endif