// DecomposedFloat.h (6.6 KB)
  #pragma once

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <cstring>
  #include <limits>
  #include <type_traits>

  #include <base/extended_types.h>
  /// Allows to check the internals of IEEE-754 floating point number.
  /// FloatTraits<T> describes the bit layout of T: UInt is an unsigned integer of the same width,
  /// and bits = 1 (sign) + exponent_bits + mantissa_bits (stored fraction bits, without the implicit leading one).
  template <typename T> struct FloatTraits;

  /// IEEE-754 binary32: 1 sign bit, 8 exponent bits, 23 stored mantissa bits.
  template <>
  struct FloatTraits<float>
  {
      using UInt = uint32_t;
      static constexpr size_t bits = 32;
      static constexpr size_t exponent_bits = 8;
      static constexpr size_t mantissa_bits = bits - exponent_bits - 1;

      /// The bit decomposition (and copying the bits of a float into UInt) is only valid
      /// if float really has this IEEE-754 layout and UInt has exactly the same size.
      static_assert(std::numeric_limits<float>::is_iec559, "float must be an IEEE-754 binary32 type");
      static_assert(sizeof(float) == sizeof(UInt), "float and its UInt must have the same size");
      static_assert(std::numeric_limits<UInt>::digits == static_cast<int>(bits), "UInt must have exactly `bits` bits");
      static_assert(static_cast<size_t>(std::numeric_limits<float>::digits) == mantissa_bits + 1,
                    "float significand must be mantissa_bits plus the implicit leading bit");
  };

  /// IEEE-754 binary64: 1 sign bit, 11 exponent bits, 52 stored mantissa bits.
  template <>
  struct FloatTraits<double>
  {
      using UInt = uint64_t;
      static constexpr size_t bits = 64;
      static constexpr size_t exponent_bits = 11;
      static constexpr size_t mantissa_bits = bits - exponent_bits - 1;

      /// Same layout guarantees as for float, for the binary64 format.
      static_assert(std::numeric_limits<double>::is_iec559, "double must be an IEEE-754 binary64 type");
      static_assert(sizeof(double) == sizeof(UInt), "double and its UInt must have the same size");
      static_assert(std::numeric_limits<UInt>::digits == static_cast<int>(bits), "UInt must have exactly `bits` bits");
      static_assert(static_cast<size_t>(std::numeric_limits<double>::digits) == mantissa_bits + 1,
                    "double significand must be mantissa_bits plus the implicit leading bit");
  };
  24. /// x = sign * (2 ^ normalized_exponent) * (1 + mantissa * 2 ^ -mantissa_bits)
  25. /// x = sign * (2 ^ normalized_exponent + mantissa * 2 ^ (normalized_exponent - mantissa_bits))
  26. template <typename T>
  27. struct DecomposedFloat
  28. {
  29. using Traits = FloatTraits<T>;
  30. explicit DecomposedFloat(T x)
  31. {
  32. memcpy(&x_uint, &x, sizeof(x));
  33. }
  34. typename Traits::UInt x_uint;
  35. bool isNegative() const
  36. {
  37. return x_uint >> (Traits::bits - 1);
  38. }
  39. /// Returns 0 for both +0. and -0.
  40. int sign() const
  41. {
  42. return (exponent() == 0 && mantissa() == 0)
  43. ? 0
  44. : (isNegative()
  45. ? -1
  46. : 1);
  47. }
  48. uint16_t exponent() const
  49. {
  50. return (x_uint >> (Traits::mantissa_bits)) & (((1ull << (Traits::exponent_bits + 1)) - 1) >> 1);
  51. }
  52. int16_t normalizedExponent() const
  53. {
  54. return int16_t(exponent()) - ((1ull << (Traits::exponent_bits - 1)) - 1);
  55. }
  56. uint64_t mantissa() const
  57. {
  58. return x_uint & ((1ull << Traits::mantissa_bits) - 1);
  59. }
  60. int64_t mantissaWithSign() const
  61. {
  62. return isNegative() ? -mantissa() : mantissa();
  63. }
  64. /// NOTE Probably floating point instructions can be better.
  65. bool isIntegerInRepresentableRange() const
  66. {
  67. return x_uint == 0
  68. || (normalizedExponent() >= 0 /// The number is not less than one
  69. /// The number is inside the range where every integer has exact representation in float
  70. && normalizedExponent() <= static_cast<int16_t>(Traits::mantissa_bits)
  71. /// After multiplying by 2^exp, the fractional part becomes zero, means the number is integer
  72. && ((mantissa() & ((1ULL << (Traits::mantissa_bits - normalizedExponent())) - 1)) == 0));
  73. }
  74. /// Compare float with integer of arbitrary width (both signed and unsigned are supported). Assuming two's complement arithmetic.
  75. /// This function is generic, big integers (128, 256 bit) are supported as well.
  76. /// Infinities are compared correctly. NaNs are treat similarly to infinities, so they can be less than all numbers.
  77. /// (note that we need total order)
  78. /// Returns -1, 0 or 1.
  79. template <typename Int>
  80. int compare(Int rhs) const
  81. {
  82. if (rhs == 0)
  83. return sign();
  84. /// Different signs
  85. if (isNegative() && rhs > 0)
  86. return -1;
  87. if (!isNegative() && rhs < 0)
  88. return 1;
  89. /// Fractional number with magnitude less than one
  90. if (normalizedExponent() < 0)
  91. {
  92. if (!isNegative())
  93. return rhs > 0 ? -1 : 1;
  94. else
  95. return rhs >= 0 ? -1 : 1;
  96. }
  97. /// The case of the most negative integer
  98. if constexpr (is_signed_v<Int>)
  99. {
  100. if (rhs == std::numeric_limits<Int>::lowest())
  101. {
  102. assert(isNegative());
  103. if (normalizedExponent() < static_cast<int16_t>(8 * sizeof(Int) - is_signed_v<Int>))
  104. return 1;
  105. if (normalizedExponent() > static_cast<int16_t>(8 * sizeof(Int) - is_signed_v<Int>))
  106. return -1;
  107. if (mantissa() == 0)
  108. return 0;
  109. else
  110. return -1;
  111. }
  112. }
  113. /// Too large number: abs(float) > abs(rhs). Also the case with infinities and NaN.
  114. if (normalizedExponent() >= static_cast<int16_t>(8 * sizeof(Int) - is_signed_v<Int>))
  115. return isNegative() ? -1 : 1;
  116. using UInt = std::conditional_t<(sizeof(Int) > sizeof(typename Traits::UInt)), make_unsigned_t<Int>, typename Traits::UInt>;
  117. UInt uint_rhs = rhs < 0 ? -rhs : rhs;
  118. /// Smaller octave: abs(rhs) < abs(float)
  119. /// FYI, TIL: octave is also called "binade", https://en.wikipedia.org/wiki/Binade
  120. if (uint_rhs < (static_cast<UInt>(1) << normalizedExponent()))
  121. return isNegative() ? -1 : 1;
  122. /// Larger octave: abs(rhs) > abs(float)
  123. if (normalizedExponent() + 1 < static_cast<int16_t>(8 * sizeof(Int) - is_signed_v<Int>)
  124. && uint_rhs >= (static_cast<UInt>(1) << (normalizedExponent() + 1)))
  125. return isNegative() ? 1 : -1;
  126. /// The same octave
  127. /// uint_rhs == 2 ^ normalizedExponent + mantissa * 2 ^ (normalizedExponent - mantissa_bits)
  128. bool large_and_always_integer = normalizedExponent() >= static_cast<int16_t>(Traits::mantissa_bits);
  129. UInt a = large_and_always_integer
  130. ? static_cast<UInt>(mantissa()) << (normalizedExponent() - Traits::mantissa_bits)
  131. : static_cast<UInt>(mantissa()) >> (Traits::mantissa_bits - normalizedExponent());
  132. UInt b = uint_rhs - (static_cast<UInt>(1) << normalizedExponent());
  133. if (a < b)
  134. return isNegative() ? 1 : -1;
  135. if (a > b)
  136. return isNegative() ? -1 : 1;
  137. /// Float has no fractional part means that the numbers are equal.
  138. if (large_and_always_integer || (mantissa() & ((1ULL << (Traits::mantissa_bits - normalizedExponent())) - 1)) == 0)
  139. return 0;
  140. else
  141. /// Float has fractional part means its abs value is larger.
  142. return isNegative() ? -1 : 1;
  143. }
  144. template <typename Int>
  145. bool equals(Int rhs) const
  146. {
  147. return compare(rhs) == 0;
  148. }
  149. template <typename Int>
  150. bool notEquals(Int rhs) const
  151. {
  152. return compare(rhs) != 0;
  153. }
  154. template <typename Int>
  155. bool less(Int rhs) const
  156. {
  157. return compare(rhs) < 0;
  158. }
  159. template <typename Int>
  160. bool greater(Int rhs) const
  161. {
  162. return compare(rhs) > 0;
  163. }
  164. template <typename Int>
  165. bool lessOrEquals(Int rhs) const
  166. {
  167. return compare(rhs) <= 0;
  168. }
  169. template <typename Int>
  170. bool greaterOrEquals(Int rhs) const
  171. {
  172. return compare(rhs) >= 0;
  173. }
  174. };
  175. using DecomposedFloat64 = DecomposedFloat<double>;
  176. using DecomposedFloat32 = DecomposedFloat<float>;