// binomial_distribution.hpp (13 KB)
  1. /* boost random/binomial_distribution.hpp header file
  2. *
  3. * Copyright Steven Watanabe 2010
  4. * Distributed under the Boost Software License, Version 1.0. (See
  5. * accompanying file LICENSE_1_0.txt or copy at
  6. * http://www.boost.org/LICENSE_1_0.txt)
  7. *
  8. * See http://www.boost.org for most recent version including documentation.
  9. *
  10. * $Id$
  11. */
  12. #ifndef BOOST_RANDOM_BINOMIAL_DISTRIBUTION_HPP_INCLUDED
  13. #define BOOST_RANDOM_BINOMIAL_DISTRIBUTION_HPP_INCLUDED
  14. #include <boost/config/no_tr1/cmath.hpp>
  15. #include <cstdlib>
  16. #include <iosfwd>
  17. #include <boost/random/detail/config.hpp>
  18. #include <boost/random/uniform_01.hpp>
  19. #include <boost/random/detail/disable_warnings.hpp>
  20. namespace boost {
  21. namespace random {
  22. namespace detail {
  /**
   * Lookup table of the Stirling-series correction term for log(k!),
   * i.e. log(k!) - (log(sqrt(2*pi)) + (k + 1/2)*log(k + 1) - (k + 1)),
   * tabulated for the ten smallest arguments k = 0 .. 9.
   *
   * It is a class template only so that the static array can be defined
   * in a header for any floating point @c RealType.
   */
  template<typename RealType>
  struct binomial_table {
    /// Correction terms, indexed directly by k (0 <= k <= 9).
    static const RealType table[10];
  };

  // Out-of-class definition of the static member.  The literals carry
  // double precision and are converted to RealType on initialization.
  template<typename RealType>
  const RealType binomial_table<RealType>::table[10] = {
    0.08106146679532726,    // k = 0
    0.04134069595540929,    // k = 1
    0.02767792568499834,    // k = 2
    0.02079067210376509,    // k = 3
    0.01664469118982119,    // k = 4
    0.01387612882307075,    // k = 5
    0.01189670994589177,    // k = 6
    0.01041126526197209,    // k = 7
    0.009255462182712733,   // k = 8
    0.008330563433362871    // k = 9
  };
  40. }
  41. /**
  42. * The binomial distribution is an integer valued distribution with
  43. * two parameters, @c t and @c p. The values of the distribution
  44. * are within the range [0,t].
  45. *
  46. * The distribution function is
  47. * \f$\displaystyle P(k) = {t \choose k}p^k(1-p)^{t-k}\f$.
  48. *
  49. * The algorithm used is the BTRD algorithm described in
  50. *
  51. * @blockquote
  52. * "The generation of binomial random variates", Wolfgang Hormann,
  53. * Journal of Statistical Computation and Simulation, Volume 46,
  54. * Issue 1 & 2 April 1993 , pages 101 - 110
  55. * @endblockquote
  56. */
  57. template<class IntType = int, class RealType = double>
  58. class binomial_distribution {
  59. public:
  60. typedef IntType result_type;
  61. typedef RealType input_type;
  62. class param_type {
  63. public:
  64. typedef binomial_distribution distribution_type;
  65. /**
  66. * Construct a param_type object. @c t and @c p
  67. * are the parameters of the distribution.
  68. *
  69. * Requires: t >=0 && 0 <= p <= 1
  70. */
  71. explicit param_type(IntType t_arg = 1, RealType p_arg = RealType (0.5))
  72. : _t(t_arg), _p(p_arg)
  73. {}
  74. /** Returns the @c t parameter of the distribution. */
  75. IntType t() const { return _t; }
  76. /** Returns the @c p parameter of the distribution. */
  77. RealType p() const { return _p; }
  78. #ifndef BOOST_RANDOM_NO_STREAM_OPERATORS
  79. /** Writes the parameters of the distribution to a @c std::ostream. */
  80. template<class CharT, class Traits>
  81. friend std::basic_ostream<CharT,Traits>&
  82. operator<<(std::basic_ostream<CharT,Traits>& os,
  83. const param_type& parm)
  84. {
  85. os << parm._p << " " << parm._t;
  86. return os;
  87. }
  88. /** Reads the parameters of the distribution from a @c std::istream. */
  89. template<class CharT, class Traits>
  90. friend std::basic_istream<CharT,Traits>&
  91. operator>>(std::basic_istream<CharT,Traits>& is, param_type& parm)
  92. {
  93. is >> parm._p >> std::ws >> parm._t;
  94. return is;
  95. }
  96. #endif
  97. /** Returns true if the parameters have the same values. */
  98. friend bool operator==(const param_type& lhs, const param_type& rhs)
  99. {
  100. return lhs._t == rhs._t && lhs._p == rhs._p;
  101. }
  102. /** Returns true if the parameters have different values. */
  103. friend bool operator!=(const param_type& lhs, const param_type& rhs)
  104. {
  105. return !(lhs == rhs);
  106. }
  107. private:
  108. IntType _t;
  109. RealType _p;
  110. };
  111. /**
  112. * Construct a @c binomial_distribution object. @c t and @c p
  113. * are the parameters of the distribution.
  114. *
  115. * Requires: t >=0 && 0 <= p <= 1
  116. */
  117. explicit binomial_distribution(IntType t_arg = 1,
  118. RealType p_arg = RealType(0.5))
  119. : _t(t_arg), _p(p_arg)
  120. {
  121. init();
  122. }
  123. /**
  124. * Construct an @c binomial_distribution object from the
  125. * parameters.
  126. */
  127. explicit binomial_distribution(const param_type& parm)
  128. : _t(parm.t()), _p(parm.p())
  129. {
  130. init();
  131. }
  132. /**
  133. * Returns a random variate distributed according to the
  134. * binomial distribution.
  135. */
  136. template<class URNG>
  137. IntType operator()(URNG& urng) const
  138. {
  139. if(use_inversion()) {
  140. if(0.5 < _p) {
  141. return _t - invert(_t, 1-_p, urng);
  142. } else {
  143. return invert(_t, _p, urng);
  144. }
  145. } else if(0.5 < _p) {
  146. return _t - generate(urng);
  147. } else {
  148. return generate(urng);
  149. }
  150. }
  151. /**
  152. * Returns a random variate distributed according to the
  153. * binomial distribution with parameters specified by @c param.
  154. */
  155. template<class URNG>
  156. IntType operator()(URNG& urng, const param_type& parm) const
  157. {
  158. return binomial_distribution(parm)(urng);
  159. }
  160. /** Returns the @c t parameter of the distribution. */
  161. IntType t() const { return _t; }
  162. /** Returns the @c p parameter of the distribution. */
  163. RealType p() const { return _p; }
  164. /** Returns the smallest value that the distribution can produce. */
  165. IntType min BOOST_PREVENT_MACRO_SUBSTITUTION() const { return 0; }
  166. /** Returns the largest value that the distribution can produce. */
  167. IntType max BOOST_PREVENT_MACRO_SUBSTITUTION() const { return _t; }
  168. /** Returns the parameters of the distribution. */
  169. param_type param() const { return param_type(_t, _p); }
  170. /** Sets parameters of the distribution. */
  171. void param(const param_type& parm)
  172. {
  173. _t = parm.t();
  174. _p = parm.p();
  175. init();
  176. }
  177. /**
  178. * Effects: Subsequent uses of the distribution do not depend
  179. * on values produced by any engine prior to invoking reset.
  180. */
  181. void reset() { }
  182. #ifndef BOOST_RANDOM_NO_STREAM_OPERATORS
  183. /** Writes the parameters of the distribution to a @c std::ostream. */
  184. template<class CharT, class Traits>
  185. friend std::basic_ostream<CharT,Traits>&
  186. operator<<(std::basic_ostream<CharT,Traits>& os,
  187. const binomial_distribution& bd)
  188. {
  189. os << bd.param();
  190. return os;
  191. }
  192. /** Reads the parameters of the distribution from a @c std::istream. */
  193. template<class CharT, class Traits>
  194. friend std::basic_istream<CharT,Traits>&
  195. operator>>(std::basic_istream<CharT,Traits>& is, binomial_distribution& bd)
  196. {
  197. bd.read(is);
  198. return is;
  199. }
  200. #endif
  201. /** Returns true if the two distributions will produce the same
  202. sequence of values, given equal generators. */
  203. friend bool operator==(const binomial_distribution& lhs,
  204. const binomial_distribution& rhs)
  205. {
  206. return lhs._t == rhs._t && lhs._p == rhs._p;
  207. }
  208. /** Returns true if the two distributions could produce different
  209. sequences of values, given equal generators. */
  210. friend bool operator!=(const binomial_distribution& lhs,
  211. const binomial_distribution& rhs)
  212. {
  213. return !(lhs == rhs);
  214. }
  215. private:
  216. /// @cond show_private
  217. template<class CharT, class Traits>
  218. void read(std::basic_istream<CharT, Traits>& is) {
  219. param_type parm;
  220. if(is >> parm) {
  221. param(parm);
  222. }
  223. }
  224. bool use_inversion() const
  225. {
  226. // BTRD is safe when np >= 10
  227. return m < 11;
  228. }
  229. // computes the correction factor for the Stirling approximation
  230. // for log(k!)
  231. static RealType fc(IntType k)
  232. {
  233. if(k < 10) return detail::binomial_table<RealType>::table[k];
  234. else {
  235. RealType ikp1 = RealType(1) / (k + 1);
  236. return (RealType(1)/12
  237. - (RealType(1)/360
  238. - (RealType(1)/1260)*(ikp1*ikp1))*(ikp1*ikp1))*ikp1;
  239. }
  240. }
  241. void init()
  242. {
  243. using std::sqrt;
  244. using std::pow;
  245. RealType p = (0.5 < _p)? (1 - _p) : _p;
  246. IntType t = _t;
  247. m = static_cast<IntType>((t+1)*p);
  248. if(use_inversion()) {
  249. _u.q_n = pow((1 - p), static_cast<RealType>(t));
  250. } else {
  251. _u.btrd.r = p/(1-p);
  252. _u.btrd.nr = (t+1)*_u.btrd.r;
  253. _u.btrd.npq = t*p*(1-p);
  254. RealType sqrt_npq = sqrt(_u.btrd.npq);
  255. _u.btrd.b = 1.15 + 2.53 * sqrt_npq;
  256. _u.btrd.a = -0.0873 + 0.0248*_u.btrd.b + 0.01*p;
  257. _u.btrd.c = t*p + 0.5;
  258. _u.btrd.alpha = (2.83 + 5.1/_u.btrd.b) * sqrt_npq;
  259. _u.btrd.v_r = 0.92 - 4.2/_u.btrd.b;
  260. _u.btrd.u_rv_r = 0.86*_u.btrd.v_r;
  261. }
  262. }
  263. template<class URNG>
  264. result_type generate(URNG& urng) const
  265. {
  266. using std::floor;
  267. using std::abs;
  268. using std::log;
  269. while(true) {
  270. RealType u;
  271. RealType v = uniform_01<RealType>()(urng);
  272. if(v <= _u.btrd.u_rv_r) {
  273. u = v/_u.btrd.v_r - 0.43;
  274. return static_cast<IntType>(floor(
  275. (2*_u.btrd.a/(0.5 - abs(u)) + _u.btrd.b)*u + _u.btrd.c));
  276. }
  277. if(v >= _u.btrd.v_r) {
  278. u = uniform_01<RealType>()(urng) - 0.5;
  279. } else {
  280. u = v/_u.btrd.v_r - 0.93;
  281. u = ((u < 0)? -0.5 : 0.5) - u;
  282. v = uniform_01<RealType>()(urng) * _u.btrd.v_r;
  283. }
  284. RealType us = 0.5 - abs(u);
  285. IntType k = static_cast<IntType>(floor((2*_u.btrd.a/us + _u.btrd.b)*u + _u.btrd.c));
  286. if(k < 0 || k > _t) continue;
  287. v = v*_u.btrd.alpha/(_u.btrd.a/(us*us) + _u.btrd.b);
  288. RealType km = abs(k - m);
  289. if(km <= 15) {
  290. RealType f = 1;
  291. if(m < k) {
  292. IntType i = m;
  293. do {
  294. ++i;
  295. f = f*(_u.btrd.nr/i - _u.btrd.r);
  296. } while(i != k);
  297. } else if(m > k) {
  298. IntType i = k;
  299. do {
  300. ++i;
  301. v = v*(_u.btrd.nr/i - _u.btrd.r);
  302. } while(i != m);
  303. }
  304. if(v <= f) return k;
  305. else continue;
  306. } else {
  307. // final acceptance/rejection
  308. v = log(v);
  309. RealType rho =
  310. (km/_u.btrd.npq)*(((km/3. + 0.625)*km + 1./6)/_u.btrd.npq + 0.5);
  311. RealType t = -km*km/(2*_u.btrd.npq);
  312. if(v < t - rho) return k;
  313. if(v > t + rho) continue;
  314. IntType nm = _t - m + 1;
  315. RealType h = (m + 0.5)*log((m + 1)/(_u.btrd.r*nm))
  316. + fc(m) + fc(_t - m);
  317. IntType nk = _t - k + 1;
  318. if(v <= h + (_t+1)*log(static_cast<RealType>(nm)/nk)
  319. + (k + 0.5)*log(nk*_u.btrd.r/(k+1))
  320. - fc(k)
  321. - fc(_t - k))
  322. {
  323. return k;
  324. } else {
  325. continue;
  326. }
  327. }
  328. }
  329. }
  330. template<class URNG>
  331. IntType invert(IntType t, RealType p, URNG& urng) const
  332. {
  333. RealType q = 1 - p;
  334. RealType s = p / q;
  335. RealType a = (t + 1) * s;
  336. RealType r = _u.q_n;
  337. RealType u = uniform_01<RealType>()(urng);
  338. IntType x = 0;
  339. while(u > r) {
  340. u = u - r;
  341. ++x;
  342. RealType r1 = ((a/x) - s) * r;
  343. // If r gets too small then the round-off error
  344. // becomes a problem. At this point, p(i) is
  345. // decreasing exponentially, so if we just call
  346. // it 0, it's close enough. Note that the
  347. // minimum value of q_n is about 1e-7, so we
  348. // may need to be a little careful to make sure that
  349. // we don't terminate the first time through the loop
  350. // for float. (Hence the test that r is decreasing)
  351. if(r1 < std::numeric_limits<RealType>::epsilon() && r1 < r) {
  352. break;
  353. }
  354. r = r1;
  355. }
  356. return x;
  357. }
  358. // parameters
  359. IntType _t;
  360. RealType _p;
  361. // common data
  362. IntType m;
  363. union {
  364. // for btrd
  365. struct {
  366. RealType r;
  367. RealType nr;
  368. RealType npq;
  369. RealType b;
  370. RealType a;
  371. RealType c;
  372. RealType alpha;
  373. RealType v_r;
  374. RealType u_rv_r;
  375. } btrd;
  376. // for inversion
  377. RealType q_n;
  378. } _u;
  379. /// @endcond
  380. };
  381. }
  382. // backwards compatibility
  383. using random::binomial_distribution;
  384. }
  385. #include <boost/random/detail/enable_warnings.hpp>
  386. #endif