extents.hpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. //
  2. // Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See
  5. // accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt)
  7. //
  8. // The authors gratefully acknowledge the support of
  9. // Fraunhofer IOSB, Ettlingen, Germany
  10. //
  11. #ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
  12. #define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
#include <algorithm>
#include <cassert>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>
  20. namespace boost {
  21. namespace numeric {
  22. namespace ublas {
  23. /** @brief Template class for storing tensor extents with runtime variable size.
  24. *
  25. * Proxy template class of std::vector<int_type>.
  26. *
  27. */
  28. template<class int_type>
  29. class basic_extents
  30. {
  31. static_assert( std::numeric_limits<typename std::vector<int_type>::value_type>::is_integer, "Static error in basic_layout: type must be of type integer.");
  32. static_assert(!std::numeric_limits<typename std::vector<int_type>::value_type>::is_signed, "Static error in basic_layout: type must be of type unsigned integer.");
  33. public:
  34. using base_type = std::vector<int_type>;
  35. using value_type = typename base_type::value_type;
  36. using const_reference = typename base_type::const_reference;
  37. using reference = typename base_type::reference;
  38. using size_type = typename base_type::size_type;
  39. using const_pointer = typename base_type::const_pointer;
  40. using const_iterator = typename base_type::const_iterator;
  41. /** @brief Default constructs basic_extents
  42. *
  43. * @code auto ex = basic_extents<unsigned>{};
  44. */
  45. constexpr explicit basic_extents()
  46. : _base{}
  47. {
  48. }
  49. /** @brief Copy constructs basic_extents from a one-dimensional container
  50. *
  51. * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) );
  52. *
  53. * @note checks if size > 1 and all elements > 0
  54. *
  55. * @param b one-dimensional std::vector<int_type> container
  56. */
  57. explicit basic_extents(base_type const& b)
  58. : _base(b)
  59. {
  60. if (!this->valid()){
  61. throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
  62. }
  63. }
  64. /** @brief Move constructs basic_extents from a one-dimensional container
  65. *
  66. * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) );
  67. *
  68. * @note checks if size > 1 and all elements > 0
  69. *
  70. * @param b one-dimensional container of type std::vector<int_type>
  71. */
  72. explicit basic_extents(base_type && b)
  73. : _base(std::move(b))
  74. {
  75. if (!this->valid()){
  76. throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
  77. }
  78. }
  79. /** @brief Constructs basic_extents from an initializer list
  80. *
  81. * @code auto ex = basic_extents<unsigned>{3,2,4};
  82. *
  83. * @note checks if size > 1 and all elements > 0
  84. *
  85. * @param l one-dimensional list of type std::initializer<int_type>
  86. */
  87. basic_extents(std::initializer_list<value_type> l)
  88. : basic_extents( base_type(std::move(l)) )
  89. {
  90. }
  91. /** @brief Constructs basic_extents from a range specified by two iterators
  92. *
  93. * @code auto ex = basic_extents<unsigned>(a.begin(), a.end());
  94. *
  95. * @note checks if size > 1 and all elements > 0
  96. *
  97. * @param first iterator pointing to the first element
  98. * @param last iterator pointing to the next position after the last element
  99. */
  100. basic_extents(const_iterator first, const_iterator last)
  101. : basic_extents ( base_type( first,last ) )
  102. {
  103. }
  104. /** @brief Copy constructs basic_extents */
  105. basic_extents(basic_extents const& l )
  106. : _base(l._base)
  107. {
  108. }
  109. /** @brief Move constructs basic_extents */
  110. basic_extents(basic_extents && l ) noexcept
  111. : _base(std::move(l._base))
  112. {
  113. }
  114. ~basic_extents() = default;
  115. basic_extents& operator=(basic_extents other) noexcept
  116. {
  117. swap (*this, other);
  118. return *this;
  119. }
  120. friend void swap(basic_extents& lhs, basic_extents& rhs) {
  121. std::swap(lhs._base , rhs._base );
  122. }
  123. /** @brief Returns true if this has a scalar shape
  124. *
  125. * @returns true if (1,1,[1,...,1])
  126. */
  127. bool is_scalar() const
  128. {
  129. return
  130. _base.size() != 0 &&
  131. std::all_of(_base.begin(), _base.end(),
  132. [](const_reference a){ return a == 1;});
  133. }
  134. /** @brief Returns true if this has a vector shape
  135. *
  136. * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1
  137. */
  138. bool is_vector() const
  139. {
  140. if(_base.size() == 0){
  141. return false;
  142. }
  143. if(_base.size() == 1){
  144. return _base.at(0) > 1;
  145. }
  146. auto greater_one = [](const_reference a){ return a > 1;};
  147. auto equal_one = [](const_reference a){ return a == 1;};
  148. return
  149. std::any_of(_base.begin(), _base.begin()+2, greater_one) &&
  150. std::any_of(_base.begin(), _base.begin()+2, equal_one ) &&
  151. std::all_of(_base.begin()+2, _base.end(), equal_one);
  152. }
  153. /** @brief Returns true if this has a matrix shape
  154. *
  155. * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1
  156. */
  157. bool is_matrix() const
  158. {
  159. if(_base.size() < 2){
  160. return false;
  161. }
  162. auto greater_one = [](const_reference a){ return a > 1;};
  163. auto equal_one = [](const_reference a){ return a == 1;};
  164. return
  165. std::all_of(_base.begin(), _base.begin()+2, greater_one) &&
  166. std::all_of(_base.begin()+2, _base.end(), equal_one );
  167. }
  168. /** @brief Returns true if this is has a tensor shape
  169. *
  170. * @returns true if !empty() && !is_scalar() && !is_vector() && !is_matrix()
  171. */
  172. bool is_tensor() const
  173. {
  174. if(_base.size() < 3){
  175. return false;
  176. }
  177. auto greater_one = [](const_reference a){ return a > 1;};
  178. return std::any_of(_base.begin()+2, _base.end(), greater_one);
  179. }
  180. const_pointer data() const
  181. {
  182. return this->_base.data();
  183. }
  184. const_reference operator[] (size_type p) const
  185. {
  186. return this->_base[p];
  187. }
  188. const_reference at (size_type p) const
  189. {
  190. return this->_base.at(p);
  191. }
  192. reference operator[] (size_type p)
  193. {
  194. return this->_base[p];
  195. }
  196. reference at (size_type p)
  197. {
  198. return this->_base.at(p);
  199. }
  200. bool empty() const
  201. {
  202. return this->_base.empty();
  203. }
  204. size_type size() const
  205. {
  206. return this->_base.size();
  207. }
  208. /** @brief Returns true if size > 1 and all elements > 0 */
  209. bool valid() const
  210. {
  211. return
  212. this->size() > 1 &&
  213. std::none_of(_base.begin(), _base.end(),
  214. [](const_reference a){ return a == value_type(0); });
  215. }
  216. /** @brief Returns the number of elements a tensor holds with this */
  217. size_type product() const
  218. {
  219. if(_base.empty()){
  220. return 0;
  221. }
  222. return std::accumulate(_base.begin(), _base.end(), 1ul, std::multiplies<>());
  223. }
  224. /** @brief Eliminates singleton dimensions when size > 2
  225. *
  226. * squeeze { 1,1} -> { 1,1}
  227. * squeeze { 2,1} -> { 2,1}
  228. * squeeze { 1,2} -> { 1,2}
  229. *
  230. * squeeze {1,2,3} -> { 2,3}
  231. * squeeze {2,1,3} -> { 2,3}
  232. * squeeze {1,3,1} -> { 3,1}
  233. *
  234. */
  235. basic_extents squeeze() const
  236. {
  237. if(this->size() <= 2){
  238. return *this;
  239. }
  240. auto new_extent = basic_extents{};
  241. auto insert_iter = std::back_insert_iterator<typename basic_extents::base_type>(new_extent._base);
  242. std::remove_copy(this->_base.begin(), this->_base.end(), insert_iter ,value_type{1});
  243. return new_extent;
  244. }
  245. void clear()
  246. {
  247. this->_base.clear();
  248. }
  249. bool operator == (basic_extents const& b) const
  250. {
  251. return _base == b._base;
  252. }
  253. bool operator != (basic_extents const& b) const
  254. {
  255. return !( _base == b._base );
  256. }
  257. const_iterator
  258. begin() const
  259. {
  260. return _base.begin();
  261. }
  262. const_iterator
  263. end() const
  264. {
  265. return _base.end();
  266. }
  267. base_type const& base() const { return _base; }
  268. private:
  269. base_type _base;
  270. };
  271. using shape = basic_extents<std::size_t>;
  272. } // namespace ublas
  273. } // namespace numeric
  274. } // namespace boost
  275. #endif