stat_udf_ut.cpp 19 KB


#include <library/cpp/testing/unittest/registar.h>
#include <yql/essentials/minikql/mkql_function_registry.h>
#include <yql/essentials/minikql/mkql_program_builder.h>
#include <yql/essentials/minikql/computation/mkql_computation_node.h>
#include <yql/essentials/minikql/comp_nodes/mkql_factories.h>
#include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
#include <util/random/random.h>
#include <util/system/sanitizers.h>

#include <algorithm>
#include <array>
#include <cmath>
  10. namespace NYql {
  11. using namespace NKikimr::NMiniKQL;
  12. namespace NUdf {
  13. extern NUdf::TUniquePtr<NUdf::IUdfModule> CreateStatModule();
  14. }
  15. Y_UNIT_TEST_SUITE(TUDFStatTest) {
  16. Y_UNIT_TEST(SimplePercentile) {
  17. auto mutableFunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry())->Clone();
  18. auto randomProvider = CreateDeterministicRandomProvider(1);
  19. auto timeProvider = CreateDeterministicTimeProvider(10000000);
  20. NUdf::TUniquePtr<NUdf::IUdfModule> module = NUdf::CreateStatModule();
  21. mutableFunctionRegistry->AddModule("", "Stat", std::move(module));
  22. TScopedAlloc alloc(__LOCATION__);
  23. TTypeEnvironment env(alloc);
  24. TProgramBuilder pgmBuilder(env, *mutableFunctionRegistry);
  25. auto udfTDigest_Create = pgmBuilder.Udf("Stat.TDigest_Create");
  26. auto udfTDigest_AddValue = pgmBuilder.Udf("Stat.TDigest_AddValue");
  27. auto udfTDigest_GetPercentile = pgmBuilder.Udf("Stat.TDigest_GetPercentile");
  28. TRuntimeNode pgmDigest;
  29. {
  30. auto param1 = pgmBuilder.NewDataLiteral<double>(0.0);
  31. TVector<TRuntimeNode> params = {param1};
  32. pgmDigest = pgmBuilder.Apply(udfTDigest_Create, params);
  33. }
  34. for (int n = 1; n < 10; n += 1) {
  35. auto param2 = pgmBuilder.NewDataLiteral((double)n);
  36. TVector<TRuntimeNode> params = {pgmDigest, param2};
  37. pgmDigest = pgmBuilder.Apply(udfTDigest_AddValue, params);
  38. }
  39. TRuntimeNode pgmReturn;
  40. {
  41. auto param2 = pgmBuilder.NewDataLiteral<double>(0.9);
  42. TVector<TRuntimeNode> params = {pgmDigest, param2};
  43. pgmReturn = pgmBuilder.Apply(udfTDigest_GetPercentile, params);
  44. }
  45. TExploringNodeVisitor explorer;
  46. explorer.Walk(pgmReturn.GetNode(), env);
  47. TComputationPatternOpts opts(alloc.Ref(), env, GetBuiltinFactory(), mutableFunctionRegistry.Get(),
  48. NUdf::EValidateMode::None, NUdf::EValidatePolicy::Fail, "", EGraphPerProcess::Multi);
  49. auto pattern = MakeComputationPattern(explorer, pgmReturn, {}, opts);
  50. auto graph = pattern->Clone(opts.ToComputationOptions(*randomProvider, *timeProvider));
  51. auto value = graph->GetValue();
  52. UNIT_ASSERT_DOUBLES_EQUAL(value.Get<double>(), 8.5, 0.001);
  53. }
  54. Y_UNIT_TEST(SimplePercentileSpecific) {
  55. auto mutableFunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry())->Clone();
  56. auto randomProvider = CreateDeterministicRandomProvider(1);
  57. auto timeProvider = CreateDeterministicTimeProvider(1);
  58. NUdf::TUniquePtr<NUdf::IUdfModule> module = NUdf::CreateStatModule();
  59. mutableFunctionRegistry->AddModule("", "Stat", std::move(module));
  60. TScopedAlloc alloc(__LOCATION__);
  61. TTypeEnvironment env(alloc);
  62. TProgramBuilder pgmBuilder(env, *mutableFunctionRegistry);
  63. auto udfTDigest_Create = pgmBuilder.Udf("Stat.TDigest_Create");
  64. auto udfTDigest_AddValue = pgmBuilder.Udf("Stat.TDigest_AddValue");
  65. auto udfTDigest_GetPercentile = pgmBuilder.Udf("Stat.TDigest_GetPercentile");
  66. TRuntimeNode pgmDigest;
  67. {
  68. auto param1 = pgmBuilder.NewDataLiteral<double>(75.0);
  69. TVector<TRuntimeNode> params = {param1};
  70. pgmDigest = pgmBuilder.Apply(udfTDigest_Create, params);
  71. }
  72. TVector<double> vals = {800, 20, 150};
  73. for (auto val : vals) {
  74. auto param2 = pgmBuilder.NewDataLiteral(val);
  75. TVector<TRuntimeNode> params = {pgmDigest, param2};
  76. pgmDigest = pgmBuilder.Apply(udfTDigest_AddValue, params);
  77. }
  78. TRuntimeNode pgmReturn;
  79. {
  80. auto param2 = pgmBuilder.NewDataLiteral<double>(0.5);
  81. TVector<TRuntimeNode> params = {pgmDigest, param2};
  82. pgmReturn = pgmBuilder.Apply(udfTDigest_GetPercentile, params);
  83. }
  84. TExploringNodeVisitor explorer;
  85. explorer.Walk(pgmReturn.GetNode(), env);
  86. TComputationPatternOpts opts(alloc.Ref(), env, GetBuiltinFactory(), mutableFunctionRegistry.Get(),
  87. NUdf::EValidateMode::None, NUdf::EValidatePolicy::Fail, "", EGraphPerProcess::Multi);
  88. auto pattern = MakeComputationPattern(explorer, pgmReturn, {}, opts);
  89. auto graph = pattern->Clone(opts.ToComputationOptions(*randomProvider, *timeProvider));
  90. auto value = graph->GetValue();
  91. Cerr << value.Get<double>() << Endl;
  92. //~ UNIT_ASSERT_DOUBLES_EQUAL(value.Get<double>(), 9.0, 0.001);
  93. }
  94. Y_UNIT_TEST(SerializedPercentile) {
  95. auto mutableFunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry())->Clone();
  96. auto randomProvider = CreateDeterministicRandomProvider(1);
  97. auto timeProvider = CreateDeterministicTimeProvider(1);
  98. NUdf::TUniquePtr<NUdf::IUdfModule> module = NUdf::CreateStatModule();
  99. mutableFunctionRegistry->AddModule("", "Stat", std::move(module));
  100. TScopedAlloc alloc(__LOCATION__);
  101. TTypeEnvironment env(alloc);
  102. TProgramBuilder pgmBuilder(env, *mutableFunctionRegistry);
  103. auto udfTDigest_Create = pgmBuilder.Udf("Stat.TDigest_Create");
  104. auto udfTDigest_AddValue = pgmBuilder.Udf("Stat.TDigest_AddValue");
  105. auto udfTDigest_GetPercentile = pgmBuilder.Udf("Stat.TDigest_GetPercentile");
  106. auto udfTDigest_Serialize = pgmBuilder.Udf("Stat.TDigest_Serialize");
  107. auto udfTDigest_Deserialize = pgmBuilder.Udf("Stat.TDigest_Deserialize");
  108. TRuntimeNode pgmDigest;
  109. {
  110. auto param1 = pgmBuilder.NewDataLiteral<double>(0.0);
  111. TVector<TRuntimeNode> params = {param1};
  112. pgmDigest = pgmBuilder.Apply(udfTDigest_Create, params);
  113. }
  114. for (int n = 1; n < 10; n += 1) {
  115. auto param2 = pgmBuilder.NewDataLiteral((double)n);
  116. TVector<TRuntimeNode> params = {pgmDigest, param2};
  117. pgmDigest = pgmBuilder.Apply(udfTDigest_AddValue, params);
  118. }
  119. TRuntimeNode pgmSerializedData;
  120. {
  121. TVector<TRuntimeNode> params = {pgmDigest};
  122. pgmSerializedData = pgmBuilder.Apply(udfTDigest_Serialize, params);
  123. }
  124. TRuntimeNode pgmDigest2;
  125. {
  126. TVector<TRuntimeNode> params = {pgmSerializedData};
  127. pgmDigest2 = pgmBuilder.Apply(udfTDigest_Deserialize, params);
  128. }
  129. TRuntimeNode pgmReturn;
  130. {
  131. auto param2 = pgmBuilder.NewDataLiteral<double>(0.9);
  132. TVector<TRuntimeNode> params = {pgmDigest2, param2};
  133. pgmReturn = pgmBuilder.Apply(udfTDigest_GetPercentile, params);
  134. }
  135. TExploringNodeVisitor explorer;
  136. explorer.Walk(pgmReturn.GetNode(), env);
  137. TComputationPatternOpts opts(alloc.Ref(), env, GetBuiltinFactory(), mutableFunctionRegistry.Get(),
  138. NUdf::EValidateMode::None, NUdf::EValidatePolicy::Fail, "", EGraphPerProcess::Multi);
  139. auto pattern = MakeComputationPattern(explorer, pgmReturn, {}, opts);
  140. auto graph = pattern->Clone(opts.ToComputationOptions(*randomProvider, *timeProvider));
  141. auto value = graph->GetValue();
  142. UNIT_ASSERT_DOUBLES_EQUAL(value.Get<double>(), 8.5, 0.001);
  143. }
  144. Y_UNIT_TEST(SerializedMergedPercentile) {
  145. auto mutableFunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry())->Clone();
  146. auto randomProvider = CreateDeterministicRandomProvider(1);
  147. auto timeProvider = CreateDeterministicTimeProvider(1);
  148. NUdf::TUniquePtr<NUdf::IUdfModule> module = NUdf::CreateStatModule();
  149. mutableFunctionRegistry->AddModule("", "Stat", std::move(module));
  150. TScopedAlloc alloc(__LOCATION__);
  151. TTypeEnvironment env(alloc);
  152. TProgramBuilder pgmBuilder(env, *mutableFunctionRegistry);
  153. auto udfTDigest_Create = pgmBuilder.Udf("Stat.TDigest_Create");
  154. auto udfTDigest_AddValue = pgmBuilder.Udf("Stat.TDigest_AddValue");
  155. auto udfTDigest_GetPercentile = pgmBuilder.Udf("Stat.TDigest_GetPercentile");
  156. auto udfTDigest_Serialize = pgmBuilder.Udf("Stat.TDigest_Serialize");
  157. auto udfTDigest_Deserialize = pgmBuilder.Udf("Stat.TDigest_Deserialize");
  158. auto udfTDigest_Merge = pgmBuilder.Udf("Stat.TDigest_Merge");
  159. TVector<TRuntimeNode> pgmSerializedDataVector;
  160. for (int i = 0; i < 100; i += 10) {
  161. TRuntimeNode pgmDigest;
  162. {
  163. auto param1 = pgmBuilder.NewDataLiteral(double(i) / 10);
  164. TVector<TRuntimeNode> params = {param1};
  165. pgmDigest = pgmBuilder.Apply(udfTDigest_Create, params);
  166. }
  167. for (int n = i + 1; n < i + 10; n += 1) {
  168. auto param2 = pgmBuilder.NewDataLiteral(double(n) / 10);
  169. TVector<TRuntimeNode> params = {pgmDigest, param2};
  170. pgmDigest = pgmBuilder.Apply(udfTDigest_AddValue, params);
  171. }
  172. TRuntimeNode pgmSerializedData;
  173. {
  174. TVector<TRuntimeNode> params = {pgmDigest};
  175. pgmSerializedData = pgmBuilder.Apply(udfTDigest_Serialize, params);
  176. }
  177. pgmSerializedDataVector.push_back(pgmSerializedData);
  178. }
  179. TRuntimeNode pgmDigest;
  180. for (size_t i = 0; i < pgmSerializedDataVector.size(); ++i) {
  181. TRuntimeNode pgmDigest2;
  182. {
  183. TVector<TRuntimeNode> params = {pgmSerializedDataVector[i]};
  184. pgmDigest2 = pgmBuilder.Apply(udfTDigest_Deserialize, params);
  185. }
  186. if (!pgmDigest) {
  187. pgmDigest = pgmDigest2;
  188. } else {
  189. TVector<TRuntimeNode> params = {pgmDigest, pgmDigest2};
  190. pgmDigest = pgmBuilder.Apply(udfTDigest_Merge, params);
  191. }
  192. }
  193. TRuntimeNode pgmReturn;
  194. {
  195. auto param2 = pgmBuilder.NewDataLiteral<double>(0.9);
  196. TVector<TRuntimeNode> params = {pgmDigest, param2};
  197. pgmReturn = pgmBuilder.Apply(udfTDigest_GetPercentile, params);
  198. }
  199. TExploringNodeVisitor explorer;
  200. explorer.Walk(pgmReturn.GetNode(), env);
  201. TComputationPatternOpts opts(alloc.Ref(), env, GetBuiltinFactory(), mutableFunctionRegistry.Get(),
  202. NUdf::EValidateMode::None, NUdf::EValidatePolicy::Fail, "", EGraphPerProcess::Multi);
  203. auto pattern = MakeComputationPattern(explorer, pgmReturn, {}, opts);
  204. auto graph = pattern->Clone(opts.ToComputationOptions(*randomProvider, *timeProvider));
  205. auto value = graph->GetValue();
  206. UNIT_ASSERT_DOUBLES_EQUAL(value.Get<double>(), 8.95, 0.001);
  207. }
  208. static double GetParetoRandomNumber(double a) {
  209. return 1 / pow(RandomNumber<double>(), double(1) / a);
  210. }
  211. Y_UNIT_TEST(BigPercentile) {
  212. auto mutableFunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry())->Clone();
  213. auto randomProvider = CreateDeterministicRandomProvider(1);
  214. auto timeProvider = CreateDeterministicTimeProvider(1);
  215. NUdf::TUniquePtr<NUdf::IUdfModule> module = NUdf::CreateStatModule();
  216. mutableFunctionRegistry->AddModule("", "Stat", std::move(module));
  217. TScopedAlloc alloc(__LOCATION__);
  218. TTypeEnvironment env(alloc);
  219. TProgramBuilder pgmBuilder(env, *mutableFunctionRegistry);
  220. auto udfTDigest_Create = pgmBuilder.Udf("Stat.TDigest_Create");
  221. auto udfTDigest_AddValue = pgmBuilder.Udf("Stat.TDigest_AddValue");
  222. auto udfTDigest_GetPercentile = pgmBuilder.Udf("Stat.TDigest_GetPercentile");
  223. const size_t NUMBERS = 100000;
  224. const double PERCENTILE = 0.99;
  225. const double THRESHOLD = 0.0004; // at q=0.99 threshold is 4*delta*0.0099
  226. TVector<double> randomNumbers1;
  227. TVector<TRuntimeNode> randomNumbers2;
  228. randomNumbers1.reserve(NUMBERS);
  229. randomNumbers2.reserve(NUMBERS);
  230. for (size_t n = 0; n < NUMBERS; ++n) {
  231. double randomNumber = GetParetoRandomNumber(10);
  232. randomNumbers1.push_back(randomNumber);
  233. randomNumbers2.push_back(pgmBuilder.NewDataLiteral(randomNumber));
  234. }
  235. TRuntimeNode bigList = pgmBuilder.AsList(randomNumbers2);
  236. auto pgmDigest =
  237. pgmBuilder.Fold1(bigList,
  238. [&](TRuntimeNode item) {
  239. std::array<TRuntimeNode, 1> args;
  240. args[0] = item;
  241. return pgmBuilder.Apply(udfTDigest_Create, args);
  242. },
  243. [&](TRuntimeNode item, TRuntimeNode state) {
  244. std::array<TRuntimeNode, 2> args;
  245. args[0] = state;
  246. args[1] = item;
  247. return pgmBuilder.Apply(udfTDigest_AddValue, args);
  248. });
  249. TRuntimeNode pgmReturn =
  250. pgmBuilder.Map(pgmDigest, [&](TRuntimeNode item) {
  251. auto param2 = pgmBuilder.NewDataLiteral(PERCENTILE);
  252. std::array<TRuntimeNode, 2> args;
  253. args[0] = item;
  254. args[1] = param2;
  255. return pgmBuilder.Apply(udfTDigest_GetPercentile, args);
  256. });
  257. TExploringNodeVisitor explorer;
  258. explorer.Walk(pgmReturn.GetNode(), env);
  259. TComputationPatternOpts opts(alloc.Ref(), env, GetBuiltinFactory(), mutableFunctionRegistry.Get(),
  260. NUdf::EValidateMode::None, NUdf::EValidatePolicy::Fail, "", EGraphPerProcess::Multi);
  261. auto pattern = MakeComputationPattern(explorer, pgmReturn, {}, opts);
  262. auto graph = pattern->Clone(opts.ToComputationOptions(*randomProvider, *timeProvider));
  263. auto value = graph->GetValue();
  264. UNIT_ASSERT(value);
  265. double digestValue = value.Get<double>();
  266. std::sort(randomNumbers1.begin(), randomNumbers1.end());
  267. // This gives us a 1-based index of the last value <= digestValue
  268. auto index = std::upper_bound(randomNumbers1.begin(), randomNumbers1.end(), digestValue) - randomNumbers1.begin();
  269. // See https://en.wikipedia.org/wiki/Percentile#First_Variant.2C
  270. double p = (index - 0.5) / double(randomNumbers1.size());
  271. UNIT_ASSERT_DOUBLES_EQUAL(p, PERCENTILE, THRESHOLD);
  272. }
  273. Y_UNIT_TEST(CentroidPrecision) {
  274. auto mutableFunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry())->Clone();
  275. auto randomProvider = CreateDeterministicRandomProvider(1);
  276. auto timeProvider = CreateDeterministicTimeProvider(1);
  277. NUdf::TUniquePtr<NUdf::IUdfModule> module = NUdf::CreateStatModule();
  278. mutableFunctionRegistry->AddModule("", "Stat", std::move(module));
  279. TScopedAlloc alloc(__LOCATION__);
  280. TTypeEnvironment env(alloc);
  281. TProgramBuilder pgmBuilder(env, *mutableFunctionRegistry);
  282. auto udfTDigest_Create = pgmBuilder.Udf("Stat.TDigest_Create");
  283. auto udfTDigest_AddValue = pgmBuilder.Udf("Stat.TDigest_AddValue");
  284. auto udfTDigest_GetPercentile = pgmBuilder.Udf("Stat.TDigest_GetPercentile");
  285. const size_t NUMBERS = 100000;
  286. const double PERCENTILE = 0.25;
  287. const double minValue = 1.0;
  288. const double maxValue = 100.0;
  289. const double majorityValue = 50.0;
  290. TVector<TRuntimeNode> numbers;
  291. numbers.reserve(NUMBERS);
  292. for (size_t n = 0; n < NUMBERS - 2; ++n) {
  293. numbers.push_back(pgmBuilder.NewDataLiteral(majorityValue));
  294. }
  295. numbers.push_back(pgmBuilder.NewDataLiteral(minValue));
  296. numbers.push_back(pgmBuilder.NewDataLiteral(maxValue));
  297. TRuntimeNode bigList = pgmBuilder.AsList(numbers);
  298. auto pgmDigest =
  299. pgmBuilder.Fold1(bigList,
  300. [&](TRuntimeNode item) {
  301. std::array<TRuntimeNode, 1> args;
  302. args[0] = item;
  303. return pgmBuilder.Apply(udfTDigest_Create, args);
  304. },
  305. [&](TRuntimeNode item, TRuntimeNode state) {
  306. std::array<TRuntimeNode, 2> args;
  307. args[0] = state;
  308. args[1] = item;
  309. return pgmBuilder.Apply(udfTDigest_AddValue, args);
  310. });
  311. TRuntimeNode pgmReturn =
  312. pgmBuilder.Map(pgmDigest, [&](TRuntimeNode item) {
  313. auto param2 = pgmBuilder.NewDataLiteral(PERCENTILE);
  314. std::array<TRuntimeNode, 2> args;
  315. args[0] = item;
  316. args[1] = param2;
  317. return pgmBuilder.Apply(udfTDigest_GetPercentile, args);
  318. });
  319. TExploringNodeVisitor explorer;
  320. explorer.Walk(pgmReturn.GetNode(), env);
  321. TComputationPatternOpts opts(alloc.Ref(), env, GetBuiltinFactory(), mutableFunctionRegistry.Get(),
  322. NUdf::EValidateMode::None, NUdf::EValidatePolicy::Fail, "", EGraphPerProcess::Multi);
  323. auto pattern = MakeComputationPattern(explorer, pgmReturn, {}, opts);
  324. auto graph = pattern->Clone(opts.ToComputationOptions(*randomProvider, *timeProvider));
  325. auto value = graph->GetValue();
  326. UNIT_ASSERT(value);
  327. double digestValue = value.Get<double>();
  328. UNIT_ASSERT_EQUAL(digestValue, majorityValue);
  329. }
  330. }
  331. }