discrete_distribution.cc 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. // Copyright 2017 The Abseil Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "y_absl/random/discrete_distribution.h"
  15. namespace y_absl {
  16. Y_ABSL_NAMESPACE_BEGIN
  17. namespace random_internal {
  18. // Initializes the distribution table for Walker's Aliasing algorithm, described
  19. // in Knuth, Vol 2. as well as in https://en.wikipedia.org/wiki/Alias_method
  20. std::vector<std::pair<double, size_t>> InitDiscreteDistribution(
  21. std::vector<double>* probabilities) {
  22. // The empty-case should already be handled by the constructor.
  23. assert(probabilities);
  24. assert(!probabilities->empty());
  25. // Step 1. Normalize the input probabilities to 1.0.
  26. double sum = std::accumulate(std::begin(*probabilities),
  27. std::end(*probabilities), 0.0);
  28. if (std::fabs(sum - 1.0) > 1e-6) {
  29. // Scale `probabilities` only when the sum is too far from 1.0. Scaling
  30. // unconditionally will alter the probabilities slightly.
  31. for (double& item : *probabilities) {
  32. item = item / sum;
  33. }
  34. }
  35. // Step 2. At this point `probabilities` is set to the conditional
  36. // probabilities of each element which sum to 1.0, to within reasonable error.
  37. // These values are used to construct the proportional probability tables for
  38. // the selection phases of Walker's Aliasing algorithm.
  39. //
  40. // To construct the table, pick an element which is under-full (i.e., an
  41. // element for which `(*probabilities)[i] < 1.0/n`), and pair it with an
  42. // element which is over-full (i.e., an element for which
  43. // `(*probabilities)[i] > 1.0/n`). The smaller value can always be retired.
  44. // The larger may still be greater than 1.0/n, or may now be less than 1.0/n,
  45. // and put back onto the appropriate collection.
  46. const size_t n = probabilities->size();
  47. std::vector<std::pair<double, size_t>> q;
  48. q.reserve(n);
  49. std::vector<size_t> over;
  50. std::vector<size_t> under;
  51. size_t idx = 0;
  52. for (const double item : *probabilities) {
  53. assert(item >= 0);
  54. const double v = item * n;
  55. q.emplace_back(v, 0);
  56. if (v < 1.0) {
  57. under.push_back(idx++);
  58. } else {
  59. over.push_back(idx++);
  60. }
  61. }
  62. while (!over.empty() && !under.empty()) {
  63. auto lo = under.back();
  64. under.pop_back();
  65. auto hi = over.back();
  66. over.pop_back();
  67. q[lo].second = hi;
  68. const double r = q[hi].first - (1.0 - q[lo].first);
  69. q[hi].first = r;
  70. if (r < 1.0) {
  71. under.push_back(hi);
  72. } else {
  73. over.push_back(hi);
  74. }
  75. }
  76. // Due to rounding errors, there may be un-paired elements in either
  77. // collection; these should all be values near 1.0. For these values, set `q`
  78. // to 1.0 and set the alternate to the identity.
  79. for (auto i : over) {
  80. q[i] = {1.0, i};
  81. }
  82. for (auto i : under) {
  83. q[i] = {1.0, i};
  84. }
  85. return q;
  86. }
  87. } // namespace random_internal
  88. Y_ABSL_NAMESPACE_END
  89. } // namespace y_absl