123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634 |
- #include "yql_expr_csee.h"
- #include "yql_expr_type_annotation.h"
- #include "yql_expr_optimize.h"
- #include <yql/essentials/utils/yql_panic.h>
- #include <yql/essentials/utils/log/log.h>
- #include <util/generic/hash_set.h>
- #include <util/system/env.h>
- #include <tuple>
- namespace NYql {
- namespace {
// Compile-time switch consulted by CalculateHash. When false (the current value), atom and
// callable content and outer-scope arguments are hashed by address; when true, the content
// bytes and the node's UniqueId() are hashed instead.
- static constexpr bool UseDeterminsticHash = false;
- struct TLambdaFrame {
- TLambdaFrame(const TExprNode* lambda, const TLambdaFrame* prev)
- : Lambda(lambda)
- , Prev(prev)
- {}
- TLambdaFrame() = default;
- const TExprNode* Lambda = nullptr;
- const TLambdaFrame* Prev = nullptr;
- };
- bool IsArgInScope(const TLambdaFrame& frame, const TExprNode& arg) {
- for (auto curr = &frame; curr; curr = curr->Prev) {
- if (const auto lambda = curr->Lambda) {
- YQL_ENSURE(lambda->IsLambda());
- for (ui32 i = 0U; i < lambda->Head().ChildrenSize(); ++i) {
- if (lambda->Head().Child(i) == &arg) {
- return true;
- }
- }
- }
- }
- return false;
- }
- ui16 GetDependencyLevel(const TExprNode& node) {
- if (const auto lambda = node.GetDependencyScope()->first) {
- return 1 + GetDependencyLevel(*lambda);
- }
- return 0;
- }
- enum class EDependencyScope : ui8 {
- None = 0,
- Inner = 1,
- Outer = 2,
- Mixed = Inner | Outer
- };
- EDependencyScope CheckDependencyScope(const TLambdaFrame& frame, const TExprNode& node) {
- if (!node.IsAtom()) {
- if (const auto scope = node.GetDependencyScope()) {
- const auto outerLambda = scope->first;
- const auto innerLambda = scope->second;
- if (bool innerFound = false; innerLambda || outerLambda) {
- for (auto curr = &frame; curr; curr = curr->Prev) {
- if (!innerFound && innerLambda) {
- if (curr->Lambda == innerLambda) {
- innerFound = true;
- } else {
- continue;
- }
- }
- if (curr->Lambda == outerLambda) {
- return curr->Lambda == &node ? EDependencyScope::None : EDependencyScope::Inner;
- }
- }
- return innerFound ? EDependencyScope::Mixed : EDependencyScope::Outer;
- }
- }
- }
- return EDependencyScope::None;
- }
- ui64 CalculateHash(ui16 depth, TExprNode& node, const TLambdaFrame& currFrame, const TColumnOrderStorage& coStore) {
- const auto dependency = CheckDependencyScope(currFrame, node);
- switch (dependency) {
- case EDependencyScope::None:
- if (const auto hash = node.GetHash()) {
- return hash;
- }
- break;
- case EDependencyScope::Inner:
- if (const auto hash = node.GetHashAbove()) {
- return hash;
- }
- break;
- case EDependencyScope::Outer:
- if (const auto hash = node.GetHashBelow()) {
- return hash;
- }
- break;
- case EDependencyScope::Mixed:
- break;
- }
- ui64 hash = node.GetTypeAnn()->GetHash();
- hash = CseeHash(ui32(node.Type()), hash);
- for (auto c: node.GetAllConstraints()) {
- hash = CseeHash(c->GetHash(), hash);
- }
- hash = AddColumnOrderHash(coStore.Lookup(node.UniqueId()), hash);
- switch (node.Type()) {
- case TExprNode::Atom: {
- if constexpr (UseDeterminsticHash) {
- hash = CseeHash(node.Content().data(), node.Content().size(), hash);
- } else {
- // can hash ptr due to intern
- const char* ptr = node.Content().data();
- hash = CseeHash(&ptr, sizeof(ptr), hash);
- }
- hash = CseeHash(node.GetFlagsToCompare(), hash);
- break;
- }
- case TExprNode::Callable:
- if constexpr (UseDeterminsticHash) {
- hash = CseeHash(node.Content().data(), node.Content().size(), hash);
- } else {
- // can hash ptr due to intern
- const char* ptr = node.Content().data();
- hash = CseeHash(&ptr, sizeof(ptr), hash);
- }
- [[fallthrough]];
- case TExprNode::List: {
- const auto size = node.ChildrenSize();
- hash = CseeHash(size, hash);
- if (node.UnorderedChildren()) {
- TSmallVec<ui64> hashes;
- hashes.reserve(size);
- for (ui32 i = 0U; i < node.ChildrenSize(); ++i) {
- hashes.emplace_back(CalculateHash(depth, *node.Child(i), currFrame, coStore));
- };
- std::sort(hashes.begin(), hashes.end());
- hash = std::accumulate(hashes.cbegin(), hashes.cend(), ~hash, [] (ui64 hash, ui64 childHash) {
- return CseeHash(childHash, hash);
- });
- } else {
- for (ui32 i = 0U; i < node.ChildrenSize(); ++i) {
- const auto childHash = CalculateHash(depth, *node.Child(i), currFrame, coStore);
- hash = CseeHash(childHash, hash);
- }
- }
- break;
- }
- case TExprNode::Lambda: {
- if (const ui32 size = node.ChildrenSize())
- hash = CseeHash(size, hash);
- const auto& args = node.Head();
- hash = CseeHash(args.ChildrenSize(), hash);
- for (ui32 i = 0; i < args.ChildrenSize(); ++i) {
- const auto& arg = *args.Child(i);
- hash = CseeHash(arg.GetTypeAnn()->GetHash(), hash);
- for (auto c: arg.GetAllConstraints()) {
- hash = CseeHash(c->GetHash(), hash);
- }
- }
- TLambdaFrame newFrame(&node, &currFrame);
- for (ui32 i = 1U; i < node.ChildrenSize(); ++i) {
- const auto lambdaHash = CalculateHash(depth + 1, *node.Child(i), newFrame, coStore);
- hash = CseeHash(lambdaHash, hash);
- }
- break;
- }
- case TExprNode::Argument:
- switch (dependency) {
- case EDependencyScope::Inner: {
- hash = CseeHash(GetDependencyLevel(node), hash);
- hash = CseeHash(node.GetArgIndex(), hash);
- break;
- }
- case EDependencyScope::Outer: {
- if constexpr (UseDeterminsticHash) {
- hash = CseeHash(node.UniqueId(), hash);
- } else {
- const auto ptr = &node;
- hash = CseeHash(&ptr, sizeof(ptr), hash);
- }
- break;
- }
- case EDependencyScope::None:
- case EDependencyScope::Mixed:
- Y_ABORT("Strange argument.");
- }
- break;
- case TExprNode::World:
- break;
- default:
- YQL_ENSURE(false, "Unexpected");
- }
- if (hash == 0) {
- hash = 1;
- }
- switch (dependency) {
- case EDependencyScope::None:
- node.SetHash(hash);
- break;
- case EDependencyScope::Inner:
- node.SetHashAbove(hash);
- break;
- case EDependencyScope::Outer:
- node.SetHashBelow(hash);
- break;
- case EDependencyScope::Mixed:
- break;
- }
- return hash;
- }
// Cache of equality results, keyed by the (left node, right node) address pair.
- using TEqualResults = THashMap<std::pair<const TExprNode*, const TExprNode*>, bool>;
// Forward declaration: the memoizing EqualNodes below calls DoEqualNodes, which is defined after it.
- bool DoEqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame,
- TEqualResults& visited, const TColumnOrderStorage& coStore);
- bool EqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame,
- TEqualResults& visited, const TColumnOrderStorage& coStore)
- {
- if (&left == &right) {
- return true;
- }
- auto key = std::make_pair(&left, &right);
- if (auto it = visited.find(key); it != visited.end()) {
- return it->second;
- }
- bool res = DoEqualNodes(left, currLeftFrame, right, currRightFrame, visited, coStore);
- visited[key] = res;
- return res;
- }
- bool DoEqualNodes(const TExprNode& left, TLambdaFrame& currLeftFrame, const TExprNode& right, TLambdaFrame& currRightFrame,
- TEqualResults& visited, const TColumnOrderStorage& coStore)
- {
- if (left.Type() != right.Type()) {
- return false;
- }
- if (left.GetTypeAnn() != right.GetTypeAnn()) {
- return false;
- }
- if (left.GetAllConstraints() != right.GetAllConstraints()) {
- return false;
- }
- auto l = coStore.Lookup(left.UniqueId());
- auto r = coStore.Lookup(right.UniqueId());
- if (l && r && *l != *r) {
- return false;
- }
- switch (left.Type()) {
- case TExprNode::Atom:
- // compare pointers due to intern
- return left.Content().data() == right.Content().data() && left.GetFlagsToCompare() == right.GetFlagsToCompare();
- case TExprNode::Callable:
- // compare pointers due to intern
- if (left.Content().data() != right.Content().data()) {
- return false;
- }
- [[fallthrough]];
- case TExprNode::List:
- if (left.ChildrenSize() != right.ChildrenSize()) {
- return false;
- }
- if (left.UnorderedChildren() && right.UnorderedChildren()) {
- if (2U == left.ChildrenSize()) {
- return EqualNodes(left.Head(), currLeftFrame, right.Head(), currRightFrame, visited, coStore)
- && EqualNodes(left.Tail(), currLeftFrame, right.Tail(), currRightFrame, visited, coStore)
- || EqualNodes(left.Head(), currLeftFrame, right.Tail(), currRightFrame, visited, coStore)
- && EqualNodes(left.Tail(), currLeftFrame, right.Head(), currRightFrame, visited, coStore);
- } else {
- TSmallVec<const TExprNode*> lNodes, rNodes;
- lNodes.reserve(left.ChildrenSize());
- rNodes.reserve(right.ChildrenSize());
- left.ForEachChild([&lNodes](const TExprNode& child){ return lNodes.emplace_back(&child); });
- right.ForEachChild([&rNodes](const TExprNode& child){ return rNodes.emplace_back(&child); });
- const auto order = [](const TExprNode* l, const TExprNode* r) { return l->GetHashAbove() < r->GetHashAbove(); };
- std::sort(lNodes.begin(), lNodes.end(), order);
- std::sort(rNodes.begin(), rNodes.end(), order);
- for (ui32 i = 0; i < lNodes.size(); ++i) {
- if (!EqualNodes(*lNodes[i], currLeftFrame, *rNodes[i], currRightFrame, visited, coStore)) {
- return false;
- }
- }
- }
- } else {
- for (ui32 i = 0; i < left.ChildrenSize(); ++i) {
- if (!EqualNodes(*left.Child(i), currLeftFrame, *right.Child(i), currRightFrame, visited, coStore)) {
- return false;
- }
- }
- }
- return true;
- case TExprNode::Lambda: {
- if (left.IsComplete() != right.IsComplete()) {
- return false;
- }
- if (left.ChildrenSize() != right.ChildrenSize()) {
- return false;
- }
- const auto& leftArgs = left.Head();
- const auto& rightArgs = right.Head();
- if (leftArgs.ChildrenSize() != rightArgs.ChildrenSize()) {
- return false;
- }
- for (ui32 i = 0; i < leftArgs.ChildrenSize(); ++i) {
- const auto& leftArg = *leftArgs.Child(i);
- const auto& rightArg = *rightArgs.Child(i);
- if (leftArg.GetTypeAnn() != rightArg.GetTypeAnn()) {
- return false;
- }
- if (leftArg.GetAllConstraints() != rightArg.GetAllConstraints()) {
- return false;
- }
- }
- TLambdaFrame newLeftFrame(&left, &currLeftFrame);
- TLambdaFrame newRightFrame(&right, &currRightFrame);
- for (ui32 i = 1U; i < left.ChildrenSize(); ++i) {
- if (!EqualNodes(*left.Child(i), newLeftFrame, *right.Child(i), newRightFrame, visited, coStore))
- return false;
- }
- return true;
- }
- case TExprNode::Argument: {
- if (currLeftFrame.Lambda && currRightFrame.Lambda && IsArgInScope(currLeftFrame, left) && IsArgInScope(currRightFrame, right)) {
- const ui16 leftRelativeLevel = GetDependencyLevel(left);
- const ui16 rightRelativeLevel = GetDependencyLevel(right);
- return leftRelativeLevel == rightRelativeLevel && left.GetArgIndex() == right.GetArgIndex();
- } else {
- return &left == &right;
- }
- }
- case TExprNode::Arguments:
- break;
- case TExprNode::World:
- return true;
- }
- YQL_ENSURE(false, "Unexpected");
- return false;
- }
// Cache of three-way comparison results, keyed by the (left node, right node) address pair.
- using TCompareResults = THashMap<std::pair<const TExprNode*, const TExprNode*>, int>;
// Forward declaration: the memoizing CompareNodes below calls DoCompareNodes, which is defined after it.
- int DoCompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited);
- int CompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited) {
- if (&left == &right) {
- return 0;
- }
- auto key = std::make_pair(&left, &right);
- if (auto it = visited.find(key); it != visited.end()) {
- return it->second;
- }
- int res = DoCompareNodes(left, right, visited);
- visited[key] = res;
- return res;
- }
- int DoCompareNodes(const TExprNode& left, const TExprNode& right, TCompareResults& visited) {
- if (left.Type() != right.Type()) {
- return (int)left.Type() - (int)right.Type();
- }
- switch (left.Type()) {
- case TExprNode::Atom:
- if (left.Content().size() != right.Content().size()) {
- return (int)left.Content().size() - (int)right.Content().size();
- }
- // compare pointers due to intern
- if (left.Content().data() != right.Content().data()) {
- if (const auto res = left.Content().compare(right.Content())) {
- return res;
- }
- }
- return (int)left.GetFlagsToCompare() - (int)right.GetFlagsToCompare();
- case TExprNode::Callable:
- if (left.Content().size() != right.Content().size()) {
- return (int)left.Content().size() - (int)right.Content().size();
- }
- // compare pointers due to intern
- if (left.Content().data() != right.Content().data()) {
- if (const auto res = left.Content().compare(right.Content())) {
- return res;
- }
- }
- [[fallthrough]];
- case TExprNode::List:
- if (left.ChildrenSize() != right.ChildrenSize()) {
- return (int)left.ChildrenSize() - (int)right.ChildrenSize();
- }
- for (ui32 i = 0; i < left.ChildrenSize(); ++i) {
- if (const auto res = CompareNodes(*left.Child(i), *right.Child(i), visited)) {
- return res;
- }
- }
- return 0;
- case TExprNode::Lambda: {
- if (left.ChildrenSize() != right.ChildrenSize()) {
- return (int)left.ChildrenSize() - (int)right.ChildrenSize();
- }
- const auto& leftArgs = left.Head();
- const auto& rightArgs = right.Head();
- if (leftArgs.ChildrenSize() != rightArgs.ChildrenSize()) {
- return (int)leftArgs.ChildrenSize() - (int)rightArgs.ChildrenSize();
- }
- for (ui32 i = 1U; i < left.ChildrenSize(); ++i) {
- if (const auto c = CompareNodes(*left.Child(i), *right.Child(i), visited))
- return c;
- }
- return 0;
- }
- case TExprNode::Argument:
- if (left.GetArgIndex() != right.GetArgIndex()) {
- return (int)left.GetArgIndex() - (int)right.GetArgIndex();
- }
- return (int)left.GetDependencyScope()->first->GetLambdaLevel() - (int)right.GetDependencyScope()->first->GetLambdaLevel();
- case TExprNode::Arguments:
- break;
- case TExprNode::World:
- return 0;
- }
- YQL_ENSURE(false, "Unexpected");
- return 0;
- }
- void CalculateCompletness(TExprNode& node, bool insideDependsOn, ui16 level, TNodeSet& closures,
- TNodeMap<TNodeSet>& visited, TNodeMap<TNodeSet>& visitedInsideDependsOn) {
- switch (node.Type()) {
- case TExprNode::Atom:
- node.SetDependencyScope(nullptr, nullptr);
- return;
- case TExprNode::Argument:
- closures.emplace(node.GetDependencyScope()->first);
- if (insideDependsOn) {
- node.SetUsedInDependsOn();
- }
- return;
- default: break;
- }
- const auto ins = (insideDependsOn ? visitedInsideDependsOn : visited).emplace(&node, TNodeSet{});
- if (!ins.second) {
- closures.insert(ins.first->second.cbegin(), ins.first->second.cend());
- return;
- }
- auto& internal = ins.first->second;
- if (TExprNode::Lambda == node.Type()) {
- node.SetLambdaLevel(level);
- node.Head().ForEachChild(std::bind(&TExprNode::SetDependencyScope, std::placeholders::_1, &node, &node));
- for (ui32 i = 1U; i < node.ChildrenSize(); ++i) {
- CalculateCompletness(*node.Child(i), insideDependsOn, level + 1, internal, visited, visitedInsideDependsOn);
- }
- internal.erase(&node);
- } else {
- insideDependsOn = insideDependsOn || node.IsCallable("DependsOn");
- node.ForEachChild(std::bind(&CalculateCompletness, std::placeholders::_1, insideDependsOn, level, std::ref(internal),
- std::ref(visited), std::ref(visitedInsideDependsOn)));
- }
- const TExprNode* outerLambda = nullptr;
- const TExprNode* innerLambda = nullptr;
- for (const auto lambda : internal) {
- if (!outerLambda || lambda->GetLambdaLevel() < outerLambda->GetLambdaLevel()) {
- outerLambda = lambda;
- }
- if (!innerLambda || lambda->GetLambdaLevel() > innerLambda->GetLambdaLevel()) {
- innerLambda = lambda;
- }
- }
- node.SetDependencyScope(outerLambda, innerLambda);
- closures.insert(internal.cbegin(), internal.cend());
- }
- ui64 CalcHash(TExprNode& node, const TColumnOrderStorage& coStore) {
- TLambdaFrame frame;
- return CalculateHash(0, node, frame, coStore);
- }
- bool EqualNodes(const TExprNode& left, const TExprNode& right, const TColumnOrderStorage& coStore) {
- TEqualResults visited;
- TLambdaFrame frame;
- return EqualNodes(left, frame, right, frame, visited, coStore);
- }
- TExprNode::TPtr VisitNode(TExprNode& node, TExprNode* currentLambda, ui16 level,
- std::unordered_multimap<ui64, TExprNode*>& uniqueNodes,
- std::unordered_multimap<ui64, TExprNode*>& incompleteNodes,
- TNodeMap<TExprNode*>& renames, const TColumnOrderStorage& coStore,
- const TNodeSet& reachable) {
- if (node.Type() == TExprNode::Argument) {
- return nullptr;
- }
- const auto find = renames.emplace(&node, nullptr);
- if (!find.second) {
- return find.first->second;
- }
- const auto hash = CalcHash(node, coStore);
- if (node.Type() == TExprNode::Lambda) {
- for (ui32 i = 1U; i < node.ChildrenSize(); ++i) {
- if (auto newNode = VisitNode(*node.Child(i), &node, level + 1U, uniqueNodes, incompleteNodes, renames, coStore, reachable)) {
- node.ChildRef(i) = std::move(newNode);
- }
- }
- } else {
- for (ui32 i = 0; i < node.ChildrenSize(); ++i) {
- if (auto newNode = VisitNode(*node.Child(i), currentLambda, level, uniqueNodes, incompleteNodes, renames, coStore, reachable)) {
- node.ChildRef(i) = std::move(newNode);
- }
- }
- }
- if (const auto kind = node.GetTypeAnn()->GetKind(); ETypeAnnotationKind::Flow != kind && ETypeAnnotationKind::Stream != kind || node.IsLambda()) {
- auto& nodesSet = node.IsComplete() ? uniqueNodes : incompleteNodes;
- const auto pair = nodesSet.equal_range(hash);
- auto iter = pair.first;
- while (pair.second != iter) {
- // search for duplicates
- if (iter->second->Dead()) {
- iter = nodesSet.erase(iter);
- continue;
- }
- if (!reachable.contains(iter->second)) {
- iter = nodesSet.erase(iter);
- continue;
- }
- if (!EqualNodes(node, *iter->second, coStore)) {
- #ifndef NDEBUG
- if (!GetEnv("YQL_ALLOW_CSEE_HASH_COLLISION")) {
- YQL_ENSURE(false, "Node -BEGIN-\n" << node.Dump() << "-END-" << " has same hash as -BEGIN-\n"
- << iter->second->Dump() << "-END-");
- }
- #endif
- ++iter;
- continue;
- }
- if (iter->second == &node)
- return nullptr;
- find.first->second = iter->second;
- if (node.Type() == TExprNode::Atom) {
- iter->second->NormalizeAtomFlags(node);
- }
- return iter->second;
- }
- nodesSet.emplace_hint(iter, hash, &node);
- }
- return nullptr;
- }
- }
- IGraphTransformer::TStatus UpdateCompletness(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext&) {
- YQL_PROFILE_SCOPE(DEBUG, "UpdateCompletness");
- output = input;
- // process closures
- TNodeSet closures;
- TNodeMap<TNodeSet> visited;
- TNodeMap<TNodeSet> visitedInsideDependsOn;
- CalculateCompletness(*input, false, 0, closures, visited, visitedInsideDependsOn);
- return IGraphTransformer::TStatus::Ok;
- }
- IGraphTransformer::TStatus EliminateCommonSubExpressions(const TExprNode::TPtr& input, TExprNode::TPtr& output,
- TExprContext& ctx, bool forSubGraph, const TColumnOrderStorage& coStore)
- {
- YQL_PROFILE_SCOPE(DEBUG, forSubGraph ? "EliminateCommonSubExpressionsForSubGraph" : "EliminateCommonSubExpressions");
- output = input;
- TNodeSet reachable;
- VisitExpr(*output, [&](const TExprNode& node) {
- reachable.emplace(&node);
- return true;
- });
- TNodeMap<TExprNode*> renames;
- //Cerr << "INPUT\n" << output->Dump() << "\n";
- std::unordered_multimap<ui64, TExprNode*> incompleteNodes;
- const auto newNode = VisitNode(*output, nullptr, 0, ctx.UniqueNodes, incompleteNodes, renames, coStore, reachable);
- YQL_ENSURE(forSubGraph || !newNode);
- //Cerr << "OUTPUT\n" << output->Dump() << "\n";
- return IGraphTransformer::TStatus::Ok;
- }
- int CompareNodes(const TExprNode& left, const TExprNode& right) {
- TCompareResults visited;
- return CompareNodes(left, right, visited);
- }
- }
|