discrete_distribution.hpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. /* boost random/discrete_distribution.hpp header file
  2. *
  3. * Copyright Steven Watanabe 2009-2011
  4. * Distributed under the Boost Software License, Version 1.0. (See
  5. * accompanying file LICENSE_1_0.txt or copy at
  6. * http://www.boost.org/LICENSE_1_0.txt)
  7. *
  8. * See http://www.boost.org for most recent version including documentation.
  9. *
  10. * $Id$
  11. */
  12. #ifndef BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED
  13. #define BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED
  14. #include <vector>
  15. #include <limits>
  16. #include <numeric>
  17. #include <utility>
  18. #include <iterator>
  19. #include <boost/assert.hpp>
  20. #include <boost/random/uniform_01.hpp>
  21. #include <boost/random/uniform_int_distribution.hpp>
  22. #include <boost/random/detail/config.hpp>
  23. #include <boost/random/detail/operators.hpp>
  24. #include <boost/random/detail/vector_io.hpp>
  25. #ifndef BOOST_NO_CXX11_HDR_INITIALIZER_LIST
  26. #include <initializer_list>
  27. #endif
  28. #include <boost/range/begin.hpp>
  29. #include <boost/range/end.hpp>
  30. #include <boost/random/detail/disable_warnings.hpp>
  31. namespace boost {
  32. namespace random {
  33. namespace detail {
  34. template<class IntType, class WeightType>
  35. struct integer_alias_table {
  36. WeightType get_weight(IntType bin) const {
  37. WeightType result = _average;
  38. if(bin < _excess) ++result;
  39. return result;
  40. }
  41. template<class Iter>
  42. WeightType init_average(Iter begin, Iter end) {
  43. WeightType weight_average = 0;
  44. IntType excess = 0;
  45. IntType n = 0;
  46. // weight_average * n + excess == current partial sum
  47. // This is a bit messy, but it's guaranteed not to overflow
  48. for(Iter iter = begin; iter != end; ++iter) {
  49. ++n;
  50. if(*iter < weight_average) {
  51. WeightType diff = weight_average - *iter;
  52. weight_average -= diff / n;
  53. if(diff % n > excess) {
  54. --weight_average;
  55. excess += n - diff % n;
  56. } else {
  57. excess -= diff % n;
  58. }
  59. } else {
  60. WeightType diff = *iter - weight_average;
  61. weight_average += diff / n;
  62. if(diff % n < n - excess) {
  63. excess += diff % n;
  64. } else {
  65. ++weight_average;
  66. excess -= n - diff % n;
  67. }
  68. }
  69. }
  70. _alias_table.resize(static_cast<std::size_t>(n));
  71. _average = weight_average;
  72. _excess = excess;
  73. return weight_average;
  74. }
  75. void init_empty()
  76. {
  77. _alias_table.clear();
  78. _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
  79. static_cast<IntType>(0)));
  80. _average = static_cast<WeightType>(1);
  81. _excess = static_cast<IntType>(0);
  82. }
  83. bool operator==(const integer_alias_table& other) const
  84. {
  85. return _alias_table == other._alias_table &&
  86. _average == other._average && _excess == other._excess;
  87. }
  88. static WeightType normalize(WeightType val, WeightType /* average */)
  89. {
  90. return val;
  91. }
  92. static void normalize(std::vector<WeightType>&) {}
  93. template<class URNG>
  94. WeightType test(URNG &urng) const
  95. {
  96. return uniform_int_distribution<WeightType>(0, _average)(urng);
  97. }
  98. bool accept(IntType result, WeightType val) const
  99. {
  100. return result < _excess || val < _average;
  101. }
  102. static WeightType try_get_sum(const std::vector<WeightType>& weights)
  103. {
  104. WeightType result = static_cast<WeightType>(0);
  105. for(typename std::vector<WeightType>::const_iterator
  106. iter = weights.begin(), end = weights.end();
  107. iter != end; ++iter)
  108. {
  109. if((std::numeric_limits<WeightType>::max)() - result > *iter) {
  110. return static_cast<WeightType>(0);
  111. }
  112. result += *iter;
  113. }
  114. return result;
  115. }
  116. template<class URNG>
  117. static WeightType generate_in_range(URNG &urng, WeightType max)
  118. {
  119. return uniform_int_distribution<WeightType>(
  120. static_cast<WeightType>(0), max-1)(urng);
  121. }
  122. typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
  123. alias_table_t _alias_table;
  124. WeightType _average;
  125. IntType _excess;
  126. };
  127. template<class IntType, class WeightType>
  128. struct real_alias_table {
  129. WeightType get_weight(IntType) const
  130. {
  131. return WeightType(1.0);
  132. }
  133. template<class Iter>
  134. WeightType init_average(Iter first, Iter last)
  135. {
  136. std::size_t size = std::distance(first, last);
  137. WeightType weight_sum =
  138. std::accumulate(first, last, static_cast<WeightType>(0));
  139. _alias_table.resize(size);
  140. return weight_sum / size;
  141. }
  142. void init_empty()
  143. {
  144. _alias_table.clear();
  145. _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
  146. static_cast<IntType>(0)));
  147. }
  148. bool operator==(const real_alias_table& other) const
  149. {
  150. return _alias_table == other._alias_table;
  151. }
  152. static WeightType normalize(WeightType val, WeightType average)
  153. {
  154. return val / average;
  155. }
  156. static void normalize(std::vector<WeightType>& weights)
  157. {
  158. WeightType sum =
  159. std::accumulate(weights.begin(), weights.end(),
  160. static_cast<WeightType>(0));
  161. for(typename std::vector<WeightType>::iterator
  162. iter = weights.begin(),
  163. end = weights.end();
  164. iter != end; ++iter)
  165. {
  166. *iter /= sum;
  167. }
  168. }
  169. template<class URNG>
  170. WeightType test(URNG &urng) const
  171. {
  172. return uniform_01<WeightType>()(urng);
  173. }
  174. bool accept(IntType, WeightType) const
  175. {
  176. return true;
  177. }
  178. static WeightType try_get_sum(const std::vector<WeightType>& /* weights */)
  179. {
  180. return static_cast<WeightType>(1);
  181. }
  182. template<class URNG>
  183. static WeightType generate_in_range(URNG &urng, WeightType)
  184. {
  185. return uniform_01<WeightType>()(urng);
  186. }
  187. typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
  188. alias_table_t _alias_table;
  189. };
  190. template<bool IsIntegral>
  191. struct select_alias_table;
  192. template<>
  193. struct select_alias_table<true> {
  194. template<class IntType, class WeightType>
  195. struct apply {
  196. typedef integer_alias_table<IntType, WeightType> type;
  197. };
  198. };
  199. template<>
  200. struct select_alias_table<false> {
  201. template<class IntType, class WeightType>
  202. struct apply {
  203. typedef real_alias_table<IntType, WeightType> type;
  204. };
  205. };
  206. }
  207. /**
  208. * The class @c discrete_distribution models a \random_distribution.
  209. * It produces integers in the range [0, n) with the probability
  210. * of producing each value is specified by the parameters of the
  211. * distribution.
  212. */
  213. template<class IntType = int, class WeightType = double>
  214. class discrete_distribution {
  215. public:
  216. typedef WeightType input_type;
  217. typedef IntType result_type;
  218. class param_type {
  219. public:
  220. typedef discrete_distribution distribution_type;
  221. /**
  222. * Constructs a @c param_type object, representing a distribution
  223. * with \f$p(0) = 1\f$ and \f$p(k|k>0) = 0\f$.
  224. */
  225. param_type() : _probabilities(1, static_cast<WeightType>(1)) {}
  226. /**
  227. * If @c first == @c last, equivalent to the default constructor.
  228. * Otherwise, the values of the range represent weights for the
  229. * possible values of the distribution.
  230. */
  231. template<class Iter>
  232. param_type(Iter first, Iter last) : _probabilities(first, last)
  233. {
  234. normalize();
  235. }
  236. #ifndef BOOST_NO_CXX11_HDR_INITIALIZER_LIST
  237. /**
  238. * If wl.size() == 0, equivalent to the default constructor.
  239. * Otherwise, the values of the @c initializer_list represent
  240. * weights for the possible values of the distribution.
  241. */
  242. param_type(const std::initializer_list<WeightType>& wl)
  243. : _probabilities(wl)
  244. {
  245. normalize();
  246. }
  247. #endif
  248. /**
  249. * If the range is empty, equivalent to the default constructor.
  250. * Otherwise, the elements of the range represent
  251. * weights for the possible values of the distribution.
  252. */
  253. template<class Range>
  254. explicit param_type(const Range& range)
  255. : _probabilities(boost::begin(range), boost::end(range))
  256. {
  257. normalize();
  258. }
  259. /**
  260. * If nw is zero, equivalent to the default constructor.
  261. * Otherwise, the range of the distribution is [0, nw),
  262. * and the weights are found by calling fw with values
  263. * evenly distributed between \f$\mbox{xmin} + \delta/2\f$ and
  264. * \f$\mbox{xmax} - \delta/2\f$, where
  265. * \f$\delta = (\mbox{xmax} - \mbox{xmin})/\mbox{nw}\f$.
  266. */
  267. template<class Func>
  268. param_type(std::size_t nw, double xmin, double xmax, Func fw)
  269. {
  270. std::size_t n = (nw == 0) ? 1 : nw;
  271. double delta = (xmax - xmin) / n;
  272. BOOST_ASSERT(delta > 0);
  273. for(std::size_t k = 0; k < n; ++k) {
  274. _probabilities.push_back(fw(xmin + k*delta + delta/2));
  275. }
  276. normalize();
  277. }
  278. /**
  279. * Returns a vector containing the probabilities of each possible
  280. * value of the distribution.
  281. */
  282. std::vector<WeightType> probabilities() const
  283. {
  284. return _probabilities;
  285. }
  286. /** Writes the parameters to a @c std::ostream. */
  287. BOOST_RANDOM_DETAIL_OSTREAM_OPERATOR(os, param_type, parm)
  288. {
  289. detail::print_vector(os, parm._probabilities);
  290. return os;
  291. }
  292. /** Reads the parameters from a @c std::istream. */
  293. BOOST_RANDOM_DETAIL_ISTREAM_OPERATOR(is, param_type, parm)
  294. {
  295. std::vector<WeightType> temp;
  296. detail::read_vector(is, temp);
  297. if(is) {
  298. parm._probabilities.swap(temp);
  299. }
  300. return is;
  301. }
  302. /** Returns true if the two sets of parameters are the same. */
  303. BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(param_type, lhs, rhs)
  304. {
  305. return lhs._probabilities == rhs._probabilities;
  306. }
  307. /** Returns true if the two sets of parameters are different. */
  308. BOOST_RANDOM_DETAIL_INEQUALITY_OPERATOR(param_type)
  309. private:
  310. /// @cond show_private
  311. friend class discrete_distribution;
  312. explicit param_type(const discrete_distribution& dist)
  313. : _probabilities(dist.probabilities())
  314. {}
  315. void normalize()
  316. {
  317. impl_type::normalize(_probabilities);
  318. }
  319. std::vector<WeightType> _probabilities;
  320. /// @endcond
  321. };
  322. /**
  323. * Creates a new @c discrete_distribution object that has
  324. * \f$p(0) = 1\f$ and \f$p(i|i>0) = 0\f$.
  325. */
  326. discrete_distribution()
  327. {
  328. _impl.init_empty();
  329. }
  330. /**
  331. * Constructs a discrete_distribution from an iterator range.
  332. * If @c first == @c last, equivalent to the default constructor.
  333. * Otherwise, the values of the range represent weights for the
  334. * possible values of the distribution.
  335. */
  336. template<class Iter>
  337. discrete_distribution(Iter first, Iter last)
  338. {
  339. init(first, last);
  340. }
  341. #ifndef BOOST_NO_CXX11_HDR_INITIALIZER_LIST
  342. /**
  343. * Constructs a @c discrete_distribution from a @c std::initializer_list.
  344. * If the @c initializer_list is empty, equivalent to the default
  345. * constructor. Otherwise, the values of the @c initializer_list
  346. * represent weights for the possible values of the distribution.
  347. * For example, given the distribution
  348. *
  349. * @code
  350. * discrete_distribution<> dist{1, 4, 5};
  351. * @endcode
  352. *
  353. * The probability of a 0 is 1/10, the probability of a 1 is 2/5,
  354. * the probability of a 2 is 1/2, and no other values are possible.
  355. */
  356. discrete_distribution(std::initializer_list<WeightType> wl)
  357. {
  358. init(wl.begin(), wl.end());
  359. }
  360. #endif
  361. /**
  362. * Constructs a discrete_distribution from a Boost.Range range.
  363. * If the range is empty, equivalent to the default constructor.
  364. * Otherwise, the values of the range represent weights for the
  365. * possible values of the distribution.
  366. */
  367. template<class Range>
  368. explicit discrete_distribution(const Range& range)
  369. {
  370. init(boost::begin(range), boost::end(range));
  371. }
  372. /**
  373. * Constructs a discrete_distribution that approximates a function.
  374. * If nw is zero, equivalent to the default constructor.
  375. * Otherwise, the range of the distribution is [0, nw),
  376. * and the weights are found by calling fw with values
  377. * evenly distributed between \f$\mbox{xmin} + \delta/2\f$ and
  378. * \f$\mbox{xmax} - \delta/2\f$, where
  379. * \f$\delta = (\mbox{xmax} - \mbox{xmin})/\mbox{nw}\f$.
  380. */
  381. template<class Func>
  382. discrete_distribution(std::size_t nw, double xmin, double xmax, Func fw)
  383. {
  384. std::size_t n = (nw == 0) ? 1 : nw;
  385. double delta = (xmax - xmin) / n;
  386. BOOST_ASSERT(delta > 0);
  387. std::vector<WeightType> weights;
  388. for(std::size_t k = 0; k < n; ++k) {
  389. weights.push_back(fw(xmin + k*delta + delta/2));
  390. }
  391. init(weights.begin(), weights.end());
  392. }
  393. /**
  394. * Constructs a discrete_distribution from its parameters.
  395. */
  396. explicit discrete_distribution(const param_type& parm)
  397. {
  398. param(parm);
  399. }
  400. /**
  401. * Returns a value distributed according to the parameters of the
  402. * discrete_distribution.
  403. */
  404. template<class URNG>
  405. IntType operator()(URNG& urng) const
  406. {
  407. BOOST_ASSERT(!_impl._alias_table.empty());
  408. IntType result;
  409. WeightType test;
  410. do {
  411. result = uniform_int_distribution<IntType>((min)(), (max)())(urng);
  412. test = _impl.test(urng);
  413. } while(!_impl.accept(result, test));
  414. if(test < _impl._alias_table[static_cast<std::size_t>(result)].first) {
  415. return result;
  416. } else {
  417. return(_impl._alias_table[static_cast<std::size_t>(result)].second);
  418. }
  419. }
  420. /**
  421. * Returns a value distributed according to the parameters
  422. * specified by param.
  423. */
  424. template<class URNG>
  425. IntType operator()(URNG& urng, const param_type& parm) const
  426. {
  427. if(WeightType limit = impl_type::try_get_sum(parm._probabilities)) {
  428. WeightType val = impl_type::generate_in_range(urng, limit);
  429. WeightType sum = 0;
  430. std::size_t result = 0;
  431. for(typename std::vector<WeightType>::const_iterator
  432. iter = parm._probabilities.begin(),
  433. end = parm._probabilities.end();
  434. iter != end; ++iter, ++result)
  435. {
  436. sum += *iter;
  437. if(sum > val) {
  438. return result;
  439. }
  440. }
  441. // This shouldn't be reachable, but round-off error
  442. // can prevent any match from being found when val is
  443. // very close to 1.
  444. return static_cast<IntType>(parm._probabilities.size() - 1);
  445. } else {
  446. // WeightType is integral and sum(parm._probabilities)
  447. // would overflow. Just use the easy solution.
  448. return discrete_distribution(parm)(urng);
  449. }
  450. }
  451. /** Returns the smallest value that the distribution can produce. */
  452. result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const { return 0; }
  453. /** Returns the largest value that the distribution can produce. */
  454. result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const
  455. { return static_cast<result_type>(_impl._alias_table.size() - 1); }
  456. /**
  457. * Returns a vector containing the probabilities of each
  458. * value of the distribution. For example, given
  459. *
  460. * @code
  461. * discrete_distribution<> dist = { 1, 4, 5 };
  462. * std::vector<double> p = dist.param();
  463. * @endcode
  464. *
  465. * the vector, p will contain {0.1, 0.4, 0.5}.
  466. *
  467. * If @c WeightType is integral, then the weights
  468. * will be returned unchanged.
  469. */
  470. std::vector<WeightType> probabilities() const
  471. {
  472. std::vector<WeightType> result(_impl._alias_table.size(), static_cast<WeightType>(0));
  473. std::size_t i = 0;
  474. for(typename impl_type::alias_table_t::const_iterator
  475. iter = _impl._alias_table.begin(),
  476. end = _impl._alias_table.end();
  477. iter != end; ++iter, ++i)
  478. {
  479. WeightType val = iter->first;
  480. result[i] += val;
  481. result[static_cast<std::size_t>(iter->second)] += _impl.get_weight(i) - val;
  482. }
  483. impl_type::normalize(result);
  484. return(result);
  485. }
  486. /** Returns the parameters of the distribution. */
  487. param_type param() const
  488. {
  489. return param_type(*this);
  490. }
  491. /** Sets the parameters of the distribution. */
  492. void param(const param_type& parm)
  493. {
  494. init(parm._probabilities.begin(), parm._probabilities.end());
  495. }
  496. /**
  497. * Effects: Subsequent uses of the distribution do not depend
  498. * on values produced by any engine prior to invoking reset.
  499. */
  500. void reset() {}
  501. /** Writes a distribution to a @c std::ostream. */
  502. BOOST_RANDOM_DETAIL_OSTREAM_OPERATOR(os, discrete_distribution, dd)
  503. {
  504. os << dd.param();
  505. return os;
  506. }
  507. /** Reads a distribution from a @c std::istream */
  508. BOOST_RANDOM_DETAIL_ISTREAM_OPERATOR(is, discrete_distribution, dd)
  509. {
  510. param_type parm;
  511. if(is >> parm) {
  512. dd.param(parm);
  513. }
  514. return is;
  515. }
  516. /**
  517. * Returns true if the two distributions will return the
  518. * same sequence of values, when passed equal generators.
  519. */
  520. BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(discrete_distribution, lhs, rhs)
  521. {
  522. return lhs._impl == rhs._impl;
  523. }
  524. /**
  525. * Returns true if the two distributions may return different
  526. * sequences of values, when passed equal generators.
  527. */
  528. BOOST_RANDOM_DETAIL_INEQUALITY_OPERATOR(discrete_distribution)
  529. private:
  530. /// @cond show_private
  531. template<class Iter>
  532. void init(Iter first, Iter last, std::input_iterator_tag)
  533. {
  534. std::vector<WeightType> temp(first, last);
  535. init(temp.begin(), temp.end());
  536. }
  537. template<class Iter>
  538. void init(Iter first, Iter last, std::forward_iterator_tag)
  539. {
  540. size_t input_size = std::distance(first, last);
  541. std::vector<std::pair<WeightType, IntType> > below_average;
  542. std::vector<std::pair<WeightType, IntType> > above_average;
  543. below_average.reserve(input_size);
  544. above_average.reserve(input_size);
  545. WeightType weight_average = _impl.init_average(first, last);
  546. WeightType normalized_average = _impl.get_weight(0);
  547. std::size_t i = 0;
  548. for(; first != last; ++first, ++i) {
  549. WeightType val = impl_type::normalize(*first, weight_average);
  550. std::pair<WeightType, IntType> elem(val, static_cast<IntType>(i));
  551. if(val < normalized_average) {
  552. below_average.push_back(elem);
  553. } else {
  554. above_average.push_back(elem);
  555. }
  556. }
  557. typename impl_type::alias_table_t::iterator
  558. b_iter = below_average.begin(),
  559. b_end = below_average.end(),
  560. a_iter = above_average.begin(),
  561. a_end = above_average.end()
  562. ;
  563. while(b_iter != b_end && a_iter != a_end) {
  564. _impl._alias_table[static_cast<std::size_t>(b_iter->second)] =
  565. std::make_pair(b_iter->first, a_iter->second);
  566. a_iter->first -= (_impl.get_weight(b_iter->second) - b_iter->first);
  567. if(a_iter->first < normalized_average) {
  568. *b_iter = *a_iter++;
  569. } else {
  570. ++b_iter;
  571. }
  572. }
  573. for(; b_iter != b_end; ++b_iter) {
  574. _impl._alias_table[static_cast<std::size_t>(b_iter->second)].first =
  575. _impl.get_weight(b_iter->second);
  576. }
  577. for(; a_iter != a_end; ++a_iter) {
  578. _impl._alias_table[static_cast<std::size_t>(a_iter->second)].first =
  579. _impl.get_weight(a_iter->second);
  580. }
  581. }
  582. template<class Iter>
  583. void init(Iter first, Iter last)
  584. {
  585. if(first == last) {
  586. _impl.init_empty();
  587. } else {
  588. typename std::iterator_traits<Iter>::iterator_category category;
  589. init(first, last, category);
  590. }
  591. }
  592. typedef typename detail::select_alias_table<
  593. (::boost::is_integral<WeightType>::value)
  594. >::template apply<IntType, WeightType>::type impl_type;
  595. impl_type _impl;
  596. /// @endcond
  597. };
  598. }
  599. }
  600. #include <boost/random/detail/enable_warnings.hpp>
  601. #endif