// mkql_random.cpp
  1. #include "mkql_random.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_program_builder.h>
  6. #include <yql/essentials/minikql/mkql_string_util.h>
  7. #include <util/random/mersenne.h>
  8. namespace NKikimr {
  9. namespace NMiniKQL {
  10. namespace {
  11. class TRandomMTResource : public TComputationValue<TRandomMTResource> {
  12. public:
  13. TRandomMTResource(TMemoryUsageInfo* memInfo, ui64 seed)
  14. : TComputationValue(memInfo)
  15. , Gen(seed)
  16. {
  17. }
  18. private:
  19. NUdf::TStringRef GetResourceTag() const override {
  20. return NUdf::TStringRef(RandomMTResource);
  21. }
  22. void* GetResource() override {
  23. return &Gen;
  24. }
  25. TMersenne<ui64> Gen;
  26. };
  27. class TNewMTRandWrapper : public TMutableComputationNode<TNewMTRandWrapper> {
  28. typedef TMutableComputationNode<TNewMTRandWrapper> TBaseComputation;
  29. public:
  30. TNewMTRandWrapper(TComputationMutables& mutables, IComputationNode* seed)
  31. : TBaseComputation(mutables)
  32. , Seed(seed)
  33. {
  34. }
  35. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  36. const ui64 seedValue = Seed->GetValue(compCtx).Get<ui64>();
  37. return compCtx.HolderFactory.Create<TRandomMTResource>(seedValue);
  38. }
  39. private:
  40. void RegisterDependencies() const final {
  41. DependsOn(Seed);
  42. }
  43. IComputationNode* const Seed;
  44. };
  45. class TNextMTRandWrapper : public TMutableComputationNode<TNextMTRandWrapper> {
  46. typedef TMutableComputationNode<TNextMTRandWrapper> TBaseComputation;
  47. public:
  48. TNextMTRandWrapper(TComputationMutables& mutables, IComputationNode* rand)
  49. : TBaseComputation(mutables)
  50. , Rand(rand)
  51. , ResPair(mutables)
  52. {
  53. }
  54. NUdf::TUnboxedValue DoCalculate(TComputationContext& compCtx) const {
  55. auto rand = Rand->GetValue(compCtx);
  56. Y_DEBUG_ABORT_UNLESS(rand.GetResourceTag() == NUdf::TStringRef(RandomMTResource));
  57. NUdf::TUnboxedValue *items = nullptr;
  58. const auto tuple = ResPair.NewArray(compCtx, 2, items);
  59. items[0] = NUdf::TUnboxedValuePod(static_cast<TMersenne<ui64>*>(rand.GetResource())->GenRand());
  60. items[1] = std::move(rand);
  61. return tuple;
  62. }
  63. private:
  64. void RegisterDependencies() const final {
  65. DependsOn(Rand);
  66. }
  67. IComputationNode* const Rand;
  68. const TContainerCacheOnContext ResPair;
  69. };
  70. template <ERandom Rnd>
  71. class TRandomWrapper : public TMutableComputationNode<TRandomWrapper<Rnd>> {
  72. typedef TMutableComputationNode<TRandomWrapper<Rnd>> TBaseComputation;
  73. public:
  74. TRandomWrapper(TComputationMutables& mutables, TComputationNodePtrVector&& dependentNodes)
  75. : TBaseComputation(mutables)
  76. , DependentNodes(dependentNodes)
  77. {
  78. }
  79. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  80. switch (Rnd) {
  81. case ERandom::Double:
  82. return NUdf::TUnboxedValuePod(ctx.RandomProvider.GenRandReal2());
  83. case ERandom::Number:
  84. return NUdf::TUnboxedValuePod(ctx.RandomProvider.GenRand64());
  85. case ERandom::Uuid: {
  86. auto uuid = ctx.RandomProvider.GenUuid4();
  87. return MakeString(NUdf::TStringRef((const char*)&uuid, sizeof(uuid)));
  88. }
  89. }
  90. Y_ABORT("Unexpected");
  91. }
  92. private:
  93. void RegisterDependencies() const final {
  94. std::for_each(DependentNodes.cbegin(), DependentNodes.cend(), std::bind(&TRandomWrapper::DependsOn, this, std::placeholders::_1));
  95. }
  96. const TComputationNodePtrVector DependentNodes;
  97. };
  98. }
  99. IComputationNode* WrapNewMTRand(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  100. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");
  101. TDataType* dataType = AS_TYPE(TDataType, callable.GetInput(0));
  102. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<ui64>::Id,
  103. "Expected ui64");
  104. auto data = LocateNode(ctx.NodeLocator, callable, 0);
  105. return new TNewMTRandWrapper(ctx.Mutables, data);
  106. }
  107. IComputationNode* WrapNextMTRand(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  108. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");
  109. AS_TYPE(TResourceType, callable.GetInput(0));
  110. auto rand = LocateNode(ctx.NodeLocator, callable, 0);
  111. return new TNextMTRandWrapper(ctx.Mutables, rand);
  112. }
  113. template <ERandom Rnd>
  114. IComputationNode* WrapRandom(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  115. TComputationNodePtrVector dependentNodes(callable.GetInputsCount());
  116. for (ui32 i = 0; i < callable.GetInputsCount(); ++i) {
  117. dependentNodes[i] = LocateNode(ctx.NodeLocator, callable, i);
  118. }
  119. return new TRandomWrapper<Rnd>(ctx.Mutables, std::move(dependentNodes));
  120. }
  121. template
  122. IComputationNode* WrapRandom<ERandom::Double>(TCallable& callable, const TComputationNodeFactoryContext& ctx);
  123. template
  124. IComputationNode* WrapRandom<ERandom::Number>(TCallable& callable, const TComputationNodeFactoryContext& ctx);
  125. template
  126. IComputationNode* WrapRandom<ERandom::Uuid>(TCallable& callable, const TComputationNodeFactoryContext& ctx);
  127. }
  128. }