registry.cpp

#include "registry.h"

#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_node_serialization.h>

#include <memory>

namespace NYql {
namespace {
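
// Holds the allocator, type environment, and computation graph built from a
// serialized kernel program. The Arrow kernels handed out by GetKernel() live
// inside this graph, so the loader must stay alive while they are in use; see
// the shared_ptr deleter trick in LoadKernels() below.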
class TLoader : public std::enable_shared_from_this<TLoader> {
public:
    TLoader()
        : Alloc_(__LOCATION__)
        , Env_(Alloc_)
    {
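        // TScopedAlloc attaches itself to the constructing thread; detach it
        // here and re-acquire it only where the allocator is actually used.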
        Alloc_.Release();
    }

    void Init(const TString& serialized,
              const NKikimr::NMiniKQL::IFunctionRegistry& functionRegistry,
              const NKikimr::NMiniKQL::TComputationNodeFactory& nodeFactory) {
        TGuard<NKikimr::NMiniKQL::TScopedAlloc> allocGuard(Alloc_);
        Pgm_ = NKikimr::NMiniKQL::DeserializeRuntimeNode(serialized, Env_);
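
        // The program is expected to be a two-item BlockAsTuple whose first
        // item is a callable listing the kernel arguments.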
        auto pgmTop = AS_CALLABLE("BlockAsTuple", Pgm_);
        MKQL_ENSURE(pgmTop->GetInputsCount() == 2, "Expected tuple of 2 items");
        auto argsNode = pgmTop->GetInput(0);
        MKQL_ENSURE(!argsNode.IsImmediate() && argsNode.GetNode()->GetType()->IsCallable(), "Expected callable");
        auto argsCallable = static_cast<NKikimr::NMiniKQL::TCallable*>(argsNode.GetNode());
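
        // Enumerate every node reachable from the program root so the pattern
        // compiler can process the whole graph.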
        Explorer_.Walk(Pgm_.GetNode(), Env_);
        NKikimr::NMiniKQL::TComputationPatternOpts opts(Alloc_.Ref(), Env_, nodeFactory,
            &functionRegistry, NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception,
            "OFF", NKikimr::NMiniKQL::EGraphPerProcess::Multi);
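
        // Each item of the argument tuple becomes an entry point of the pattern.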
        std::vector<NKikimr::NMiniKQL::TNode*> entryPoints;
        if (argsCallable->GetType()->GetName() == "BlockAsTuple") {
            for (ui32 i = 0; i < argsCallable->GetInputsCount(); ++i) {
                entryPoints.emplace_back(argsCallable->GetInput(i).GetNode());
            }
        }
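
        // Switch the allocator to locked ref-counting before compiling the
        // pattern, since the kernels outlive this call.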
        Alloc_.Ref().UseRefLocking = true;
        Pattern_ = NKikimr::NMiniKQL::MakeComputationPattern(Explorer_, Pgm_, entryPoints, opts);
        RandomProvider_ = CreateDefaultRandomProvider();
        TimeProvider_ = CreateDefaultTimeProvider();
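
        // Clone the compiled pattern into an executable graph and pull the
        // kernels topology out of it.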
        Graph_ = Pattern_->Clone(opts.ToComputationOptions(*RandomProvider_, *TimeProvider_));
        NKikimr::NMiniKQL::TBindTerminator terminator(Graph_->GetTerminator());
        Topology_ = Graph_->GetKernelsTopology();
        MKQL_ENSURE(Topology_->Items.size() >= 3, "Expected at least 3 kernels");
    }

    ~TLoader() {
        // Re-attach the allocator so the members are destroyed on this thread.
        Alloc_.Acquire();
    }

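    // The last three topology items are not user kernels and are excluded;
    // this matches the size check at the end of Init().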
    ui32 GetKernelsCount() const {
        return Topology_->Items.size() - 3;
    }

    const arrow::compute::ScalarKernel* GetKernel(ui32 index) const {
        MKQL_ENSURE(index < Topology_->Items.size() - 3, "Bad kernel index");
        return &Topology_->Items[index].Node->GetArrowKernel();
    }

private:
    NKikimr::NMiniKQL::TScopedAlloc Alloc_;
    NKikimr::NMiniKQL::TTypeEnvironment Env_;
    NKikimr::NMiniKQL::TRuntimeNode Pgm_;
    NKikimr::NMiniKQL::TExploringNodeVisitor Explorer_;
    NKikimr::NMiniKQL::IComputationPattern::TPtr Pattern_;
    TIntrusivePtr<IRandomProvider> RandomProvider_;
    TIntrusivePtr<ITimeProvider> TimeProvider_;
    THolder<NKikimr::NMiniKQL::IComputationGraph> Graph_;
    const NKikimr::NMiniKQL::TArrowKernelsTopology* Topology_ = nullptr;
};

} // namespace
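
// Deserializes a kernel program and returns its scalar kernels. Each returned
// shared_ptr captures the loader in its deleter, so the graph that owns the
// kernels stays alive until the last kernel pointer is dropped.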
std::vector<std::shared_ptr<const arrow::compute::ScalarKernel>> LoadKernels(const TString& serialized,
    const NKikimr::NMiniKQL::IFunctionRegistry& functionRegistry,
    const NKikimr::NMiniKQL::TComputationNodeFactory& nodeFactory) {
    auto loader = std::make_shared<TLoader>();
    loader->Init(serialized, functionRegistry, nodeFactory);
    std::vector<std::shared_ptr<const arrow::compute::ScalarKernel>> ret(loader->GetKernelsCount());
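    // The deleter frees nothing: the kernels are owned by the loader's graph.
    // It only extends the loader's lifetime to that of the returned pointers.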
    auto deleter = [loader](const arrow::compute::ScalarKernel*) {};
    for (ui32 i = 0; i < ret.size(); ++i) {
        // Note: the kernels are handed out in reverse topology order.
        ret[i] = std::shared_ptr<const arrow::compute::ScalarKernel>(loader->GetKernel(ret.size() - 1 - i), deleter);
    }

    return ret;
}

} // namespace NYql