// yql_optimize.cpp
  1. #include "yql_optimize.h"
  2. #include <yql/essentials/utils/log/log.h>
  3. #include <yql/essentials/utils/yql_panic.h>
  4. #include <util/generic/hash_set.h>
  5. #include <util/generic/yexception.h>
  6. namespace NYql {
  7. using namespace NNodes;
  8. class TOptimizeTransformerBase::TIgnoreOptimizationContext: public IOptimizationContext {
  9. public:
  10. TIgnoreOptimizationContext(TOptimizeTransformerBase::TGetParents getParents)
  11. : GetParents_(std::move(getParents))
  12. {
  13. }
  14. virtual ~TIgnoreOptimizationContext() = default;
  15. void RemapNode(const TExprNode& src, const TExprNode::TPtr&) final {
  16. const TParentsMap* parentsMap = GetParents_();
  17. auto parentsIt = parentsMap->find(&src);
  18. YQL_ENSURE(parentsIt != parentsMap->cend());
  19. YQL_ENSURE(parentsIt->second.size() == 1, "Bad usage of local optimizer. Try to switch to global mode");
  20. }
  21. private:
  22. TOptimizeTransformerBase::TGetParents GetParents_;
  23. };
  24. class TOptimizeTransformerBase::TRemapOptimizationContext: public IOptimizationContext {
  25. public:
  26. TRemapOptimizationContext(TNodeOnNodeOwnedMap& remaps)
  27. : Remaps_(remaps)
  28. {
  29. }
  30. virtual ~TRemapOptimizationContext() = default;
  31. void RemapNode(const TExprNode& fromNode, const TExprNode::TPtr& toNode) final {
  32. YQL_ENSURE(Remaps_.emplace(&fromNode, toNode).second, "Duplicate remap of the same node");
  33. }
  34. void SetError() {
  35. HasError_ = true;
  36. }
  37. bool CanContinue() const {
  38. return Remaps_.empty() && !HasError_;
  39. }
  40. bool HasError() const {
  41. return HasError_;
  42. }
  43. private:
  44. TNodeOnNodeOwnedMap& Remaps_;
  45. bool HasError_ = false;
  46. };
  47. TOptimizeTransformerBase::TOptimizeTransformerBase(TTypeAnnotationContext* types, NLog::EComponent logComponent, const TSet<TString>& disabledOpts)
  48. : Types(types)
  49. , LogComponent(logComponent)
  50. , DisabledOpts(disabledOpts)
  51. {
  52. }
  53. IGraphTransformer::TStatus TOptimizeTransformerBase::DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) {
  54. TOptimizeExprSettings settings(Types);
  55. IGraphTransformer::TStatus status = IGraphTransformer::TStatus::Ok;
  56. output = input;
  57. for (auto& step: Steps) {
  58. TParentsMap parentsMap;
  59. bool parentsMapInit = false;
  60. TGetParents getParents = [&input, &parentsMap, &parentsMapInit] () {
  61. if (!parentsMapInit) {
  62. GatherParents(*input, parentsMap);
  63. parentsMapInit = true;
  64. }
  65. return &parentsMap;
  66. };
  67. if (step.Global) {
  68. TNodeOnNodeOwnedMap remaps;
  69. auto optCtx = TRemapOptimizationContext{remaps};
  70. VisitExpr(output,
  71. [&optCtx, &step](const TExprNode::TPtr& node) {
  72. return optCtx.CanContinue() && !node->StartsExecution() && !step.ProcessedNodes.contains(node->UniqueId());
  73. },
  74. [this, &step, &getParents, &ctx, &optCtx](const TExprNode::TPtr& node) -> bool {
  75. if (optCtx.CanContinue() && !node->StartsExecution() && !step.ProcessedNodes.contains(node->UniqueId())) {
  76. for (auto& opt: step.Optimizers) {
  77. if (opt.Filter(node.Get())) {
  78. try {
  79. auto ret = opt.Handler(NNodes::TExprBase(node), ctx, optCtx, getParents);
  80. if (!ret) {
  81. YQL_CVLOG(NLog::ELevel::ERROR, LogComponent) << "Error applying " << opt.OptName;
  82. optCtx.SetError();
  83. } else if (auto retNode = ret.Cast(); retNode.Ptr() != node) {
  84. YQL_CVLOG(NLog::ELevel::INFO, LogComponent) << opt.OptName;
  85. optCtx.RemapNode(*node, retNode.Ptr());
  86. }
  87. } catch (...) {
  88. YQL_CVLOG(NLog::ELevel::ERROR, LogComponent) << "Error applying " << opt.OptName << ": " << CurrentExceptionMessage();
  89. throw;
  90. }
  91. }
  92. if (!optCtx.CanContinue()) {
  93. break;
  94. }
  95. }
  96. if (optCtx.CanContinue()) {
  97. step.ProcessedNodes.insert(node->UniqueId());
  98. }
  99. }
  100. return true;
  101. }
  102. );
  103. if (optCtx.HasError()) {
  104. status = IGraphTransformer::TStatus::Error;
  105. } else if (!remaps.empty()) {
  106. settings.ProcessedNodes = nullptr;
  107. status = RemapExpr(output, output, remaps, ctx, settings);
  108. }
  109. } else {
  110. settings.ProcessedNodes = &step.ProcessedNodes;
  111. status = OptimizeExpr(output, output, [this, &step, &getParents](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr {
  112. TIgnoreOptimizationContext ignoreOptCtx(getParents);
  113. for (auto& opt: step.Optimizers) {
  114. if (opt.Filter(node.Get())) {
  115. try {
  116. auto ret = opt.Handler(NNodes::TExprBase(node), ctx, ignoreOptCtx, getParents);
  117. if (!ret) {
  118. YQL_CVLOG(NLog::ELevel::ERROR, LogComponent) << "Error applying " << opt.OptName;
  119. return {};
  120. }
  121. auto retNode = ret.Cast();
  122. if (retNode.Ptr() != node) {
  123. YQL_CVLOG(NLog::ELevel::INFO, LogComponent) << opt.OptName;
  124. return retNode.Ptr();
  125. }
  126. } catch (...) {
  127. YQL_CVLOG(NLog::ELevel::ERROR, LogComponent) << "Error applying " << opt.OptName << ": " << CurrentExceptionMessage();
  128. throw;
  129. }
  130. }
  131. }
  132. return node;
  133. }, ctx, settings);
  134. }
  135. if (status.Level != IGraphTransformer::TStatus::Ok) {
  136. return status;
  137. }
  138. }
  139. return status;
  140. }
  141. void TOptimizeTransformerBase::Rewind() {
  142. for (auto& step: Steps) {
  143. step.ProcessedNodes.clear();
  144. }
  145. }
  146. TOptimizeTransformerBase::TFilter TOptimizeTransformerBase::Any() {
  147. return [] (const TExprNode* node) {
  148. Y_UNUSED(node);
  149. return true;
  150. };
  151. }
  152. TOptimizeTransformerBase::TFilter TOptimizeTransformerBase::Names(std::initializer_list<TStringBuf> names) {
  153. return [filter = THashSet<TStringBuf>(names)] (const TExprNode* node) {
  154. return node->IsCallable(filter);
  155. };
  156. }
  157. TOptimizeTransformerBase::TFilter TOptimizeTransformerBase::Or(std::initializer_list<TOptimizeTransformerBase::TFilter> filters) {
  158. return [orFilters = TVector<TFilter>(filters)] (const TExprNode* node) {
  159. for (auto& f: orFilters) {
  160. if (f(node)) {
  161. return true;
  162. }
  163. }
  164. return false;
  165. };
  166. }
  167. void TOptimizeTransformerBase::AddHandler(size_t step, TFilter filter, TStringBuf optName, THandler handler) {
  168. if (DisabledOpts.contains(optName)) {
  169. return;
  170. }
  171. if (step >= Steps.size()) {
  172. Steps.resize(step + 1);
  173. }
  174. TOptInfo opt;
  175. opt.OptName = optName;
  176. opt.Filter = filter;
  177. opt.Handler = handler;
  178. Steps[step].Optimizers.push_back(std::move(opt));
  179. }
  180. void TOptimizeTransformerBase::SetGlobal(size_t step) {
  181. if (step >= Steps.size()) {
  182. Steps.resize(step + 1);
  183. }
  184. Steps[step].Global = true;
  185. }
  186. }