fast_uniform_bits.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. // Copyright 2017 The Abseil Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_
  15. #define ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_
  16. #include <cstddef>
  17. #include <cstdint>
  18. #include <limits>
  19. #include <type_traits>
  20. #include "absl/base/config.h"
  21. #include "absl/meta/type_traits.h"
  22. #include "absl/random/internal/traits.h"
  23. namespace absl {
  24. ABSL_NAMESPACE_BEGIN
  25. namespace random_internal {
  // Reports whether `n` has at most one bit set, i.e. whether `n` is zero or an
  // exact power of two. Within this header it is applied to the size of a
  // URBG's output range to decide whether that range maps onto a whole number
  // of bits.
  template <typename UIntType>
  constexpr bool IsPowerOfTwoOrZero(UIntType n) {
    // `n & (n - 1)` clears the lowest set bit of `n`; nothing is left over
    // exactly when `n` had a single bit set. Zero is accepted explicitly.
    return (n != 0) ? ((n & (n - 1)) == 0) : true;
  }
  // Returns the number of distinct values `URBG` can produce, that is
  // `URBG::max() - URBG::min() + 1`, expressed in `URBG::result_type`.
  //
  // When the URBG spans every representable value of its result type, zero is
  // returned instead of that count.
  //
  // A URBG whose `min()` and `max()` coincide is rejected at compile time.
  template <typename URBG>
  constexpr typename URBG::result_type RangeSize() {
    using result_type = typename URBG::result_type;
    using limits = std::numeric_limits<result_type>;
    static_assert((URBG::max)() != (URBG::min)(), "URBG range cannot be 0.");
    // The names are parenthesized so function-like `min`/`max` macros cannot
    // expand here.
    return ((URBG::min)() != limits::lowest() ||
            (URBG::max)() != (limits::max)())
               ? ((URBG::max)() - (URBG::min)() + result_type{1})
               : result_type{0};
  }
  // Computes floor(log2(n)) by counting how many times `n` can be halved
  // before it reaches 1. Inputs of 0 and 1 both yield 0.
  template <typename UIntType>
  constexpr UIntType IntegerLog2(UIntType n) {
    // One level of recursion per halving; the whole body is a single return
    // expression.
    return (n > 1) ? 1 + IntegerLog2(n >> 1) : 0;
  }
  49. // Returns the number of bits of randomness returned through
  50. // `PowerOfTwoVariate(urbg)`.
  51. template <typename URBG>
  52. constexpr size_t NumBits() {
  53. return static_cast<size_t>(
  54. RangeSize<URBG>() == 0
  55. ? std::numeric_limits<typename URBG::result_type>::digits
  56. : IntegerLog2(RangeSize<URBG>()));
  57. }
  // Builds a mask of type `UIntType` whose low `n` bits are set.
  //
  // Special cases:
  //   * `n == 0` yields a mask with every bit set. Callers in this header pass
  //     a shift that has been reduced modulo the result width, so a value of 0
  //     stands for "use the whole word".
  //   * Any `n` greater than or equal to the width of `UIntType` also yields
  //     the all-ones mask. Shifting by the full width or more is undefined
  //     behavior, so that shift is never evaluated.
  template <typename UIntType>
  constexpr UIntType MaskFromShift(size_t n) {
    return (n == 0 ||
            n >= static_cast<size_t>(std::numeric_limits<UIntType>::digits))
               ? ~UIntType{0}
               : (UIntType{1} << n) - UIntType{1};
  }
  // Empty tag types used for overload selection inside FastUniformBits: the
  // call operator chooses one of them at compile time and passes an instance
  // to Generate().
  //
  // Selects the shift-and-accumulate algorithm, used when the URBG's range size
  // is zero or a power of two.
  struct SimplifiedLoopTag {};
  // Selects the rejection-sampling algorithm, used for every other URBG.
  struct RejectionLoopTag {};
  70. // FastUniformBits implements a fast path to acquire uniform independent bits
  71. // from a type which conforms to the [rand.req.urbg] concept.
  72. // Parameterized by:
  73. // `UIntType`: the result (output) type
  74. //
  75. // The std::independent_bits_engine [rand.adapt.ibits] adaptor can be
  76. // instantiated from an existing generator through a copy or a move. It does
  77. // not, however, facilitate the production of pseudorandom bits from an un-owned
  78. // generator that will outlive the std::independent_bits_engine instance.
  79. template <typename UIntType = uint64_t>
  80. class FastUniformBits {
  81. public:
  82. using result_type = UIntType;
  83. static constexpr result_type(min)() { return 0; }
  84. static constexpr result_type(max)() {
  85. return (std::numeric_limits<result_type>::max)();
  86. }
  87. template <typename URBG>
  88. result_type operator()(URBG& g); // NOLINT(runtime/references)
  89. private:
  90. static_assert(IsUnsigned<UIntType>::value,
  91. "Class-template FastUniformBits<> must be parameterized using "
  92. "an unsigned type.");
  93. // Generate() generates a random value, dispatched on whether
  94. // the underlying URBG must use rejection sampling to generate a value,
  95. // or whether a simplified loop will suffice.
  96. template <typename URBG>
  97. result_type Generate(URBG& g, // NOLINT(runtime/references)
  98. SimplifiedLoopTag);
  99. template <typename URBG>
  100. result_type Generate(URBG& g, // NOLINT(runtime/references)
  101. RejectionLoopTag);
  102. };
  103. template <typename UIntType>
  104. template <typename URBG>
  105. typename FastUniformBits<UIntType>::result_type
  106. FastUniformBits<UIntType>::operator()(URBG& g) { // NOLINT(runtime/references)
  107. // kRangeMask is the mask used when sampling variates from the URBG when the
  108. // width of the URBG range is not a power of 2.
  109. // Y = (2 ^ kRange) - 1
  110. static_assert((URBG::max)() > (URBG::min)(),
  111. "URBG::max and URBG::min may not be equal.");
  112. using tag = absl::conditional_t<IsPowerOfTwoOrZero(RangeSize<URBG>()),
  113. SimplifiedLoopTag, RejectionLoopTag>;
  114. return Generate(g, tag{});
  115. }
  116. template <typename UIntType>
  117. template <typename URBG>
  118. typename FastUniformBits<UIntType>::result_type
  119. FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references)
  120. SimplifiedLoopTag) {
  121. // The simplified version of FastUniformBits works only on URBGs that have
  122. // a range that is a power of 2. In this case we simply loop and shift without
  123. // attempting to balance the bits across calls.
  124. static_assert(IsPowerOfTwoOrZero(RangeSize<URBG>()),
  125. "incorrect Generate tag for URBG instance");
  126. static constexpr size_t kResultBits =
  127. std::numeric_limits<result_type>::digits;
  128. static constexpr size_t kUrbgBits = NumBits<URBG>();
  129. static constexpr size_t kIters =
  130. (kResultBits / kUrbgBits) + (kResultBits % kUrbgBits != 0);
  131. static constexpr size_t kShift = (kIters == 1) ? 0 : kUrbgBits;
  132. static constexpr auto kMin = (URBG::min)();
  133. result_type r = static_cast<result_type>(g() - kMin);
  134. for (size_t n = 1; n < kIters; ++n) {
  135. r = static_cast<result_type>(r << kShift) +
  136. static_cast<result_type>(g() - kMin);
  137. }
  138. return r;
  139. }
  140. template <typename UIntType>
  141. template <typename URBG>
  142. typename FastUniformBits<UIntType>::result_type
  143. FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references)
  144. RejectionLoopTag) {
  145. static_assert(!IsPowerOfTwoOrZero(RangeSize<URBG>()),
  146. "incorrect Generate tag for URBG instance");
  147. using urbg_result_type = typename URBG::result_type;
  148. // See [rand.adapt.ibits] for more details on the constants calculated below.
  149. //
  150. // It is preferable to use roughly the same number of bits from each generator
  151. // call, however this is only possible when the number of bits provided by the
  152. // URBG is a divisor of the number of bits in `result_type`. In all other
  153. // cases, the number of bits used cannot always be the same, but it can be
  154. // guaranteed to be off by at most 1. Thus we run two loops, one with a
  155. // smaller bit-width size (`kSmallWidth`) and one with a larger width size
  156. // (satisfying `kLargeWidth == kSmallWidth + 1`). The loops are run
  157. // `kSmallIters` and `kLargeIters` times respectively such
  158. // that
  159. //
  160. // `kResultBits == kSmallIters * kSmallBits
  161. // + kLargeIters * kLargeBits`
  162. //
  163. // where `kResultBits` is the total number of bits in `result_type`.
  164. //
  165. static constexpr size_t kResultBits =
  166. std::numeric_limits<result_type>::digits; // w
  167. static constexpr urbg_result_type kUrbgRange = RangeSize<URBG>(); // R
  168. static constexpr size_t kUrbgBits = NumBits<URBG>(); // m
  169. // compute the initial estimate of the bits used.
  170. // [rand.adapt.ibits] 2 (c)
  171. static constexpr size_t kA = // ceil(w/m)
  172. (kResultBits / kUrbgBits) + ((kResultBits % kUrbgBits) != 0); // n'
  173. static constexpr size_t kABits = kResultBits / kA; // w0'
  174. static constexpr urbg_result_type kARejection =
  175. ((kUrbgRange >> kABits) << kABits); // y0'
  176. // refine the selection to reduce the rejection frequency.
  177. static constexpr size_t kTotalIters =
  178. ((kUrbgRange - kARejection) <= (kARejection / kA)) ? kA : (kA + 1); // n
  179. // [rand.adapt.ibits] 2 (b)
  180. static constexpr size_t kSmallIters =
  181. kTotalIters - (kResultBits % kTotalIters); // n0
  182. static constexpr size_t kSmallBits = kResultBits / kTotalIters; // w0
  183. static constexpr urbg_result_type kSmallRejection =
  184. ((kUrbgRange >> kSmallBits) << kSmallBits); // y0
  185. static constexpr size_t kLargeBits = kSmallBits + 1; // w0+1
  186. static constexpr urbg_result_type kLargeRejection =
  187. ((kUrbgRange >> kLargeBits) << kLargeBits); // y1
  188. //
  189. // Because `kLargeBits == kSmallBits + 1`, it follows that
  190. //
  191. // `kResultBits == kSmallIters * kSmallBits + kLargeIters`
  192. //
  193. // and therefore
  194. //
  195. // `kLargeIters == kTotalWidth % kSmallWidth`
  196. //
  197. // Intuitively, each iteration with the large width accounts for one unit
  198. // of the remainder when `kTotalWidth` is divided by `kSmallWidth`. As
  199. // mentioned above, if the URBG width is a divisor of `kTotalWidth`, then
  200. // there would be no need for any large iterations (i.e., one loop would
  201. // suffice), and indeed, in this case, `kLargeIters` would be zero.
  202. static_assert(kResultBits == kSmallIters * kSmallBits +
  203. (kTotalIters - kSmallIters) * kLargeBits,
  204. "Error in looping constant calculations.");
  205. // The small shift is essentially small bits, but due to the potential
  206. // of generating a smaller result_type from a larger urbg type, the actual
  207. // shift might be 0.
  208. static constexpr size_t kSmallShift = kSmallBits % kResultBits;
  209. static constexpr auto kSmallMask =
  210. MaskFromShift<urbg_result_type>(kSmallShift);
  211. static constexpr size_t kLargeShift = kLargeBits % kResultBits;
  212. static constexpr auto kLargeMask =
  213. MaskFromShift<urbg_result_type>(kLargeShift);
  214. static constexpr auto kMin = (URBG::min)();
  215. result_type s = 0;
  216. for (size_t n = 0; n < kSmallIters; ++n) {
  217. urbg_result_type v;
  218. do {
  219. v = g() - kMin;
  220. } while (v >= kSmallRejection);
  221. s = (s << kSmallShift) + static_cast<result_type>(v & kSmallMask);
  222. }
  223. for (size_t n = kSmallIters; n < kTotalIters; ++n) {
  224. urbg_result_type v;
  225. do {
  226. v = g() - kMin;
  227. } while (v >= kLargeRejection);
  228. s = (s << kLargeShift) + static_cast<result_type>(v & kLargeMask);
  229. }
  230. return s;
  231. }
  232. } // namespace random_internal
  233. ABSL_NAMESPACE_END
  234. } // namespace absl
  235. #endif // ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_