extract_used_columns.cpp 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #include "extract_used_columns.h"
  2. #include <yql/essentials/public/purecalc/common/inspect_input.h>
  3. #include <yql/essentials/core/yql_expr_optimize.h>
  4. #include <yql/essentials/core/expr_nodes/yql_expr_nodes.h>
  5. using namespace NYql;
  6. using namespace NYql::NPureCalc;
  7. namespace {
  8. class TUsedColumnsExtractor : public TSyncTransformerBase {
  9. private:
  10. TVector<THashSet<TString>>* const Destination_;
  11. const TVector<THashSet<TString>>& AllColumns_;
  12. TString NodeName_;
  13. bool CalculatedUsedFields_ = false;
  14. public:
  15. TUsedColumnsExtractor(
  16. TVector<THashSet<TString>>* destination,
  17. const TVector<THashSet<TString>>& allColumns,
  18. TString nodeName
  19. )
  20. : Destination_(destination)
  21. , AllColumns_(allColumns)
  22. , NodeName_(std::move(nodeName))
  23. {
  24. }
  25. TUsedColumnsExtractor(TVector<THashSet<TString>>*, TVector<THashSet<TString>>&&, TString) = delete;
  26. public:
  27. TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final {
  28. output = input;
  29. if (CalculatedUsedFields_) {
  30. return IGraphTransformer::TStatus::Ok;
  31. }
  32. bool hasError = false;
  33. *Destination_ = AllColumns_;
  34. VisitExpr(input, [&](const TExprNode::TPtr& inputExpr) {
  35. NNodes::TExprBase node(inputExpr);
  36. if (auto maybeExtract = node.Maybe<NNodes::TCoExtractMembers>()) {
  37. auto extract = maybeExtract.Cast();
  38. const auto& arg = extract.Input().Ref();
  39. if (arg.IsCallable(NodeName_)) {
  40. ui32 inputIndex;
  41. if (!TryFetchInputIndexFromSelf(arg, ctx, AllColumns_.size(), inputIndex)) {
  42. hasError = true;
  43. return false;
  44. }
  45. YQL_ENSURE(inputIndex < AllColumns_.size());
  46. auto& destinationColumnsSet = (*Destination_)[inputIndex];
  47. const auto& allColumnsSet = AllColumns_[inputIndex];
  48. destinationColumnsSet.clear();
  49. for (const auto& columnAtom : extract.Members()) {
  50. TString name = TString(columnAtom.Value());
  51. YQL_ENSURE(allColumnsSet.contains(name), "unexpected column in the input struct");
  52. destinationColumnsSet.insert(name);
  53. }
  54. }
  55. }
  56. return true;
  57. });
  58. if (hasError) {
  59. return IGraphTransformer::TStatus::Error;
  60. }
  61. CalculatedUsedFields_ = true;
  62. return IGraphTransformer::TStatus::Ok;
  63. }
  64. void Rewind() final {
  65. CalculatedUsedFields_ = false;
  66. }
  67. };
  68. }
  69. TAutoPtr<IGraphTransformer> NYql::NPureCalc::MakeUsedColumnsExtractor(
  70. TVector<THashSet<TString>>* destination,
  71. const TVector<THashSet<TString>>& allColumns,
  72. const TString& nodeName
  73. ) {
  74. return new TUsedColumnsExtractor(destination, allColumns, nodeName);
  75. }