#include "mkql_node_visitor.h"
#include "mkql_node.h"

#include <util/generic/algorithm.h>

namespace NKikimr {
namespace NMiniKQL {

using namespace NDetail;

// Cookie values written by TExploringNodeVisitor::Walk to track traversal state.
// A cookie of 0 means the node has not been reached yet; Walk resets cookies to 0 when it finishes.
constexpr ui64 IS_NODE_ENTERED = 1; // children have been scheduled, node still on the stack
constexpr ui64 IS_NODE_EXITED = 2;  // node has been appended to the collected node list

void TThrowingNodeVisitor::Visit(TTypeType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TVoidType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TNullType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TEmptyListType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TEmptyDictType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TDataType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TPgType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TStructType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TListType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TOptionalType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TDictType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TCallableType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TAnyType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TTupleType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TResourceType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TVariantType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TVoid& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TNull& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TEmptyList& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TEmptyDict& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TDataLiteral& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TStructLiteral& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TListLiteral& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TOptionalLiteral& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TDictLiteral& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TCallable& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TAny& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TTupleLiteral& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TVariantLiteral& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TStreamType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TFlowType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TTaggedType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TBlockType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::Visit(TMultiType& node) {
    Y_UNUSED(node);
    ThrowUnexpectedNodeType();
}

void TThrowingNodeVisitor::ThrowUnexpectedNodeType() {
    THROW yexception() << "Unexpected node type";
}

void TEmptyNodeVisitor::Visit(TTypeType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TVoidType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TNullType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TEmptyListType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TEmptyDictType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TDataType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TPgType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TStructType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TListType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TOptionalType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TDictType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TCallableType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TAnyType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TTupleType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TResourceType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TVariantType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TVoid& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TNull& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TEmptyList& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TEmptyDict& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TDataLiteral& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TStructLiteral& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TListLiteral& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TOptionalLiteral& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TDictLiteral& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TCallable& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TAny& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TTupleLiteral& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TVariantLiteral& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TStreamType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TFlowType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TTaggedType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TBlockType& node) {
    Y_UNUSED(node);
}

void TEmptyNodeVisitor::Visit(TMultiType& node) {
    Y_UNUSED(node);
}

// TExploringNodeVisitor::Visit overloads: each one schedules the children of a single
// node kind via AddChildNode(parent, child). The node's own type comes first, then every
// node that the kind references. Walk() pops the scheduled children afterwards.

void TExploringNodeVisitor::Visit(TTypeType& node) {
    // A TTypeType is asserted to be its own type, so there is nothing to schedule.
    Y_DEBUG_ABORT_UNLESS(&node == node.GetType());
}

// Type-only nodes: the only child is the node's type.
void TExploringNodeVisitor::Visit(TVoidType& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TNullType& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TEmptyListType& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TEmptyDictType& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TDataType& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TPgType& node) { AddChildNode(&node, *node.GetType()); }

void TExploringNodeVisitor::Visit(TStructType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    const ui32 membersCount = node.GetMembersCount();
    for (ui32 index = 0; index < membersCount; ++index) {
        AddChildNode(self, *node.GetMemberType(index));
    }
}

void TExploringNodeVisitor::Visit(TListType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetItemType());
}

void TExploringNodeVisitor::Visit(TOptionalType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetItemType());
}

void TExploringNodeVisitor::Visit(TDictType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetKeyType());
    AddChildNode(self, *node.GetPayloadType());
}

void TExploringNodeVisitor::Visit(TCallableType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetReturnType());
    const ui32 argumentsCount = node.GetArgumentsCount();
    for (ui32 index = 0; index < argumentsCount; ++index) {
        AddChildNode(self, *node.GetArgumentType(index));
    }

    // The payload is optional and only scheduled when present.
    if (auto payload = node.GetPayload()) {
        AddChildNode(self, *payload);
    }
}

void TExploringNodeVisitor::Visit(TAnyType& node) { AddChildNode(&node, *node.GetType()); }

void TExploringNodeVisitor::Visit(TTupleType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    const ui32 elementsCount = node.GetElementsCount();
    for (ui32 index = 0; index < elementsCount; ++index) {
        AddChildNode(self, *node.GetElementType(index));
    }
}

void TExploringNodeVisitor::Visit(TResourceType& node) { AddChildNode(&node, *node.GetType()); }

void TExploringNodeVisitor::Visit(TVariantType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetUnderlyingType());
}

// Literals without node-valued children.
void TExploringNodeVisitor::Visit(TVoid& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TNull& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TEmptyList& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TEmptyDict& node) { AddChildNode(&node, *node.GetType()); }
void TExploringNodeVisitor::Visit(TDataLiteral& node) { AddChildNode(&node, *node.GetType()); }

void TExploringNodeVisitor::Visit(TStructLiteral& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    const ui32 valuesCount = node.GetValuesCount();
    for (ui32 index = 0; index < valuesCount; ++index) {
        AddChildNode(self, *node.GetValue(index).GetNode());
    }
}

void TExploringNodeVisitor::Visit(TListLiteral& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    const auto itemsCount = node.GetItemsCount();
    for (ui32 index = 0; index < itemsCount; ++index) {
        AddChildNode(self, *node.GetItems()[index].GetNode());
    }
}

void TExploringNodeVisitor::Visit(TOptionalLiteral& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    if (!node.HasItem()) {
        return;
    }
    AddChildNode(self, *node.GetItem().GetNode());
}

void TExploringNodeVisitor::Visit(TDictLiteral& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    const ui32 itemsCount = node.GetItemsCount();
    for (ui32 index = 0; index < itemsCount; ++index) {
        auto entry = node.GetItem(index);
        AddChildNode(self, *entry.first.GetNode());
        AddChildNode(self, *entry.second.GetNode());
    }
}

void TExploringNodeVisitor::Visit(TCallable& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    const ui32 inputsCount = node.GetInputsCount();
    for (ui32 index = 0; index < inputsCount; ++index) {
        AddChildNode(self, *node.GetInput(index).GetNode());
    }

    // A callable that already has a result also references it.
    if (node.HasResult()) {
        AddChildNode(self, *node.GetResult().GetNode());
    }
}

void TExploringNodeVisitor::Visit(TAny& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    if (!node.HasItem()) {
        return;
    }
    AddChildNode(self, *node.GetItem().GetNode());
}

void TExploringNodeVisitor::Visit(TTupleLiteral& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    const ui32 valuesCount = node.GetValuesCount();
    for (ui32 index = 0; index < valuesCount; ++index) {
        AddChildNode(self, *node.GetValue(index).GetNode());
    }
}

void TExploringNodeVisitor::Visit(TVariantLiteral& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetItem().GetNode());
}

void TExploringNodeVisitor::Visit(TStreamType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetItemType());
}

void TExploringNodeVisitor::Visit(TFlowType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetItemType());
}

void TExploringNodeVisitor::Visit(TTaggedType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetBaseType());
}

void TExploringNodeVisitor::Visit(TBlockType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    AddChildNode(self, *node.GetItemType());
}

void TExploringNodeVisitor::Visit(TMultiType& node) {
    TNode* const self = &node;
    AddChildNode(self, *node.GetType());
    const ui32 elementsCount = node.GetElementsCount();
    for (ui32 index = 0; index < elementsCount; ++index) {
        AddChildNode(self, *node.GetElementType(index));
    }
}

// Schedules `child` for traversal by pushing it onto the active stack.
// When the consumers map is being built, the edge parent -> child is recorded as well.
// A null parent marks the traversal root, whose consumer list is reset to empty.
void TExploringNodeVisitor::AddChildNode(TNode* parent, TNode& child) {
    Stack->push_back(&child);

    if (!BuildConsumersMap) {
        return;
    }

    auto& consumers = ConsumersMap[&child];
    if (parent) {
        consumers.push_back(parent);
    } else {
        consumers = {};
    }
}

// Forgets the results of the previous Walk() and detaches from any traversal stack.
void TExploringNodeVisitor::Clear() {
    Stack = nullptr;
    ConsumersMap.clear();
    NodeList.clear();
}

// Collects every node reachable from `root` into NodeList using an iterative
// depth-first traversal over the environment's shared node stack.
//
// A node is appended to NodeList when it is seen at the top of the stack for the
// second time. By then, every child it pushed (via the Visit overloads / AddChildNode)
// has already been processed, so NodeList is in post-order: dependencies come first.
//
// Node cookies hold the traversal state:
//   0               -> not reached yet
//   IS_NODE_ENTERED -> children scheduled, node still on the stack
//   IS_NODE_EXITED  -> node already appended to NodeList
// Any larger cookie value (e.g. one left behind by another algorithm) aborts the
// process. On normal completion every collected node's cookie is reset to 0.
// There is no cleanup path if Accept() throws: cookies of already-visited nodes
// would then stay non-zero.
//
// root             - traversal start; gets an empty consumer list when consumers are tracked.
// env              - supplies the scratch stack (env.GetNodeStack()), which is cleared on entry.
//                    NOTE(review): a nested Walk on the same env would therefore clobber this one.
// terminalNodes    - nodes that are collected but not expanded (Accept is not called on them).
//                    Membership is a linear Find per newly reached node.
// buildConsumersMap - when true, ConsumersMap records for each node the parents that
//                    referenced it (one entry per edge).
// nodesCountHint   - if non-zero and consumers are tracked, reserved in ConsumersMap up front.
void TExploringNodeVisitor::Walk(TNode* root, const TTypeEnvironment& env, const std::vector<TNode*>& terminalNodes,
    bool buildConsumersMap, size_t nodesCountHint)
{
    BuildConsumersMap = buildConsumersMap;

    // Drops results of any previous walk (BuildConsumersMap is kept).
    Clear();

    if (BuildConsumersMap && nodesCountHint) {
        ConsumersMap.reserve(nodesCountHint);
    }

    Stack = &env.GetNodeStack();
    Stack->clear();
    AddChildNode(nullptr, *root);
    while (!Stack->empty()) {
        auto node = Stack->back();

        if (node->GetCookie() == 0) {
            // First visit: mark as entered and push its children above it.
            // The node itself stays on the stack until they are done.
            node->SetCookie(IS_NODE_ENTERED);

            if (Find(terminalNodes.begin(), terminalNodes.end(), node) == terminalNodes.end()) {
                node->Accept(*this);
            }
        } else {
            if (node->GetCookie() == IS_NODE_ENTERED) {
                // Second visit: all children pushed above it have been popped.
                NodeList.push_back(node);
                node->SetCookie(IS_NODE_EXITED);
            } else {
                // Already exited (node shared by several parents): just drop the duplicate.
                // Anything above IS_NODE_EXITED is a foreign cookie.
                Y_ABORT_UNLESS(node->GetCookie() <= IS_NODE_EXITED, "TNode graph should not be reused");
            }

            Stack->pop_back();
            continue;
        }
    }

    // Leave the graph clean so that it can be walked again.
    for (auto node : NodeList) {
        node->SetCookie(0);
    }

    Stack = nullptr;
}

// Nodes collected by the most recent Walk(), in post-order; emptied by Clear().
const std::vector<TNode*>& TExploringNodeVisitor::GetNodes() {
    const std::vector<TNode*>& collected = NodeList;
    return collected;
}

// Returns the parents recorded for `node` during the last Walk().
// Aborts unless that walk tracked consumers and actually reached `node`.
const TExploringNodeVisitor::TNodesVec& TExploringNodeVisitor::GetConsumerNodes(TNode& node) {
    Y_ABORT_UNLESS(BuildConsumersMap);
    auto found = ConsumersMap.find(&node);
    Y_ABORT_UNLESS(found != ConsumersMap.end());
    return found->second;
}

// Applies the rewrite functions supplied by `funcProvider` to the nodes collected by a
// preceding explorer.Walk() (in that order), and returns the possibly replaced root.
//
// InPlace == true:
//   Every node is frozen. A callable without a result whose rewrite function yields a
//   different node gets that node via TCallable::SetResult. `wereChanges` reports
//   whether any callable was rewritten.
//
// InPlace == false (copy-on-write):
//   Every node is cloned with CloneOnCallableWrite. Each explored node's replacement
//   (a rewritten callable wrapped by TCallable::Create, or a differing clone) is stored
//   in that node's cookie, and the root's cookie selects the returned root.
//   `wereChanges` is set only when the root was replaced. All cookies of explored nodes
//   are reset to 0 before returning, including when an exception propagates, so the
//   graph can be walked again afterwards.
template <bool InPlace>
TRuntimeNode SinglePassVisitCallablesImpl(TRuntimeNode root, TExploringNodeVisitor& explorer,
    const TCallableVisitFuncProvider& funcProvider, const TTypeEnvironment& env, bool& wereChanges)
{
    auto& nodes = explorer.GetNodes();

    // In copy-on-write mode explored nodes carry replacement pointers in their cookies.
    // A later TExploringNodeVisitor::Walk aborts on such cookies, so they must always be cleared.
    auto resetCookies = [&nodes]() {
        for (TNode* exploredNode : nodes) {
            exploredNode->SetCookie(0);
        }
    };

    wereChanges = false;

    try {
        for (TNode* exploredNode : nodes) {
            TNode* node;
            if (!InPlace) {
                // NOTE(review): presumably rebuilds the node when an input has a replacement
                // recorded in its cookie — confirm against TNode::CloneOnCallableWrite.
                node = exploredNode->CloneOnCallableWrite(env);
            } else {
                node = exploredNode;
                node->Freeze(env);
            }

            if (node->GetType()->IsCallable()) {
                auto& callable = static_cast<TCallable&>(*node);
                if (!callable.HasResult()) {
                    const auto& callableType = callable.GetType();
                    const auto& name = callableType->GetNameStr();
                    const auto& func = funcProvider(name);
                    if (func) {
                        TRuntimeNode result = func(callable, env);
                        result.Freeze();
                        if (result.GetNode() != node) {
                            if (InPlace) {
                                callable.SetResult(result, env);
                                wereChanges = true;
                            } else {
                                // Wrap the rewritten value in a callable of the original callable type
                                // and publish it as this node's replacement.
                                TNode* wrappedResult = TCallable::Create(result, callable.GetType(), env);
                                exploredNode->SetCookie(reinterpret_cast<ui64>(wrappedResult));
                            }

                            continue;
                        }
                    }
                }
            }

            if (!InPlace && node != exploredNode) {
                exploredNode->SetCookie(reinterpret_cast<ui64>(node));
            }
        }

        if (!InPlace) {
            auto newRoot = reinterpret_cast<TNode*>(root.GetNode()->GetCookie());
            if (newRoot) {
                root = TRuntimeNode(newRoot, root.IsImmediate());
                wereChanges = true;
            }
        }

        root.Freeze();
    } catch (...) {
        // A rewrite function (or Freeze) threw. Without this cleanup the stale replacement
        // pointers would make the next Walk over this graph abort with
        // "TNode graph should not be reused".
        if (!InPlace) {
            resetCookies();
        }
        throw;
    }

    if (!InPlace) {
        resetCookies();
    }

    return root;
}

// Runs one rewrite pass over the nodes collected by `explorer`.
// The runtime `inPlace` flag is forwarded to the matching compile-time specialization of
// SinglePassVisitCallablesImpl.
TRuntimeNode SinglePassVisitCallables(TRuntimeNode root, TExploringNodeVisitor& explorer,
    const TCallableVisitFuncProvider& funcProvider, const TTypeEnvironment& env, bool inPlace, bool& wereChanges) {
    return inPlace
        ? SinglePassVisitCallablesImpl<true>(root, explorer, funcProvider, env, wereChanges)
        : SinglePassVisitCallablesImpl<false>(root, explorer, funcProvider, env, wereChanges);
}

}
}