// normal_ut.cpp
  #include "normal.h"
  #include "fast.h"

  #include <library/cpp/testing/unittest/registar.h>

  #include <util/generic/vector.h>

  #include <cmath>
  #include <functional>
  6. Y_UNIT_TEST_SUITE(TestNormalDistribution) {
  7. Y_UNIT_TEST(TestDefined) {
  8. volatile auto x = NormalRandom<float>(0, 1) + NormalRandom<double>(0, 1) + NormalRandom<long double>(0, 1);
  9. (void)x;
  10. }
  11. template <class T>
  12. static void TestMD(std::function<T()> f, T m, T d) {
  13. TVector<T> v;
  14. v.reserve(20000);
  15. for (size_t i = 0; i < 20000; ++i) {
  16. v.push_back(f());
  17. }
  18. long double mm = 0;
  19. long double vv = 0;
  20. for (auto x : v) {
  21. mm += x;
  22. }
  23. mm /= v.size();
  24. for (auto x : v) {
  25. vv += (mm - x) * (mm - x);
  26. }
  27. vv /= v.size();
  28. long double dd = std::sqrt(vv);
  29. UNIT_ASSERT_DOUBLES_EQUAL(m, mm, (m + 1) * 0.05);
  30. UNIT_ASSERT_DOUBLES_EQUAL(d, dd, (d + 1) * 0.05);
  31. }
  32. Y_UNIT_TEST(Test1) {
  33. TestMD<float>(&StdNormalRandom<float>, 0, 1);
  34. TestMD<double>(&StdNormalRandom<double>, 0, 1);
  35. TestMD<long double>(&StdNormalRandom<long double>, 0, 1);
  36. }
  37. template <class T>
  38. std::function<T()> GenFunc1(T m, T d) {
  39. return [m, d]() {
  40. return NormalRandom<T>(m, d);
  41. };
  42. }
  43. template <class T>
  44. std::function<T()> GenFunc2(T m, T d) {
  45. TFastRng<ui64> rng(17);
  46. return [rng, m, d]() mutable {
  47. return NormalDistribution<T>(rng, m, d);
  48. };
  49. }
  50. Y_UNIT_TEST(Test2) {
  51. TestMD<float>(GenFunc1<float>(2, 3), 2, 3);
  52. TestMD<double>(GenFunc1<double>(3, 4), 3, 4);
  53. TestMD<long double>(GenFunc1<long double>(4, 5), 4, 5);
  54. }
  55. Y_UNIT_TEST(Test3) {
  56. TestMD<float>(GenFunc2<float>(20, 30), 20, 30);
  57. TestMD<double>(GenFunc2<double>(30, 40), 30, 40);
  58. TestMD<long double>(GenFunc2<long double>(40, 50), 40, 50);
  59. }
  60. }