erfinv.cpp 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  #include <array>
  #include <cmath>
  #include <limits>
  #include <numeric>
  #include "erfinv.h"
  // Evaluates the polynomial coef[0] + coef[1]*x + ... + coef[N-1]*x^(N-1)
  // with Horner's scheme: start from the highest-order coefficient and fold
  // in the lower ones one at a time (coef[0] is the constant term).
  template <size_t N>
  static double polEval(double x, const std::array<double, N>& coef) {
      static_assert(N > 0, "Array coef[] should not be empty.");
      double acc = coef[N - 1];
      // Descending walk from coef[N-2] to coef[0]; the post-decrement in the
      // condition keeps the unsigned index from wrapping past zero.
      for (size_t i = N - 1; i-- > 0;) {
          acc = acc * x + coef[i];
      }
      return acc;
  }
  13. namespace NMathUdf {
  14. // https://www.jstor.org/stable/2347330
  15. double ErfInv(double x) {
  16. static constexpr std::array<double, 8> a = {
  17. 1.1975323115670912564578e0,
  18. 4.7072688112383978012285e1,
  19. 6.9706266534389598238465e2,
  20. 4.8548868893843886794648e3,
  21. 1.6235862515167575384252e4,
  22. 2.3782041382114385731252e4,
  23. 1.1819493347062294404278e4,
  24. 8.8709406962545514830200e2,
  25. };
  26. static constexpr std::array<double, 8> b = {
  27. 1.,
  28. 4.2313330701600911252e1,
  29. 6.8718700749205790830e2,
  30. 5.3941960214247511077e3,
  31. 2.1213794301586595867e4,
  32. 3.9307895800092710610e4,
  33. 2.8729085735721942674e4,
  34. 5.2264952788528545610e3,
  35. };
  36. static constexpr std::array<double, 8> c = {
  37. 1.42343711074968357734e0,
  38. 4.63033784615654529590e0,
  39. 5.76949722146069140550e0,
  40. 3.64784832476320460504e0,
  41. 1.27045825245236838258e0,
  42. 2.41780725177450611770e-1,
  43. 2.27238449892691845833e-2,
  44. 7.74545014278341407640e-4,
  45. };
  46. static constexpr std::array<double, 8> d = {
  47. 1.4142135623730950488016887e0,
  48. 2.9036514445419946173133295e0,
  49. 2.3707661626024532365971225e0,
  50. 9.7547832001787427186894837e-1,
  51. 2.0945065210512749128288442e-1,
  52. 2.1494160384252876777097297e-2,
  53. 7.7441459065157709165577218e-4,
  54. 1.4859850019840355905497876e-9,
  55. };
  56. static constexpr std::array<double, 8> e = {
  57. 6.65790464350110377720e0,
  58. 5.46378491116411436990e0,
  59. 1.78482653991729133580e0,
  60. 2.96560571828504891230e-1,
  61. 2.65321895265761230930e-2,
  62. 1.24266094738807843860e-3,
  63. 2.71155556874348757815e-5,
  64. 2.01033439929228813265e-7,
  65. };
  66. static constexpr std::array<double, 8> f = {
  67. 1.414213562373095048801689e0,
  68. 8.482908416595164588112026e-1,
  69. 1.936480946950659106176712e-1,
  70. 2.103693768272068968719679e-2,
  71. 1.112800997078859844711555e-3,
  72. 2.611088405080593625138020e-5,
  73. 2.010321207683943062279931e-7,
  74. 2.891024605872965461538222e-15,
  75. };
  76. if (isnan(x) || x <= -1. || x >= 1.) {
  77. if (x == 1.) {
  78. return std::numeric_limits<double>::infinity();
  79. }
  80. if (x == -1.) {
  81. return -std::numeric_limits<double>::infinity();
  82. }
  83. return std::numeric_limits<double>::quiet_NaN();
  84. }
  85. double sign = (x > 0) - (x < 0);
  86. x = abs(x);
  87. if (x < 1e-7) {
  88. return sign * x / M_2_SQRTPI;
  89. }
  90. double ans;
  91. if (x <= 0.85) {
  92. double r = 0.180625 - 0.25 * x * x;
  93. ans = x * polEval(r, a) / polEval(r, b);
  94. } else {
  95. double r = std::sqrt(M_LN2 - log(1. - x)) - 1.6;
  96. if (r <= 3.4) {
  97. ans = polEval(r, c) / polEval(r, d);
  98. } else {
  99. r -= 3.4;
  100. ans = polEval(r, e) / polEval(r, f);
  101. }
  102. }
  103. return ans * sign;
  104. }
  105. }