123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329 |
#include "cbo_optimizer_new.h"

#include <bitset>
#include <cctype>
#include <optional>
#include <stdexcept>
#include <string>

#include <library/cpp/iterator/zip.h>
#include <util/string/join.h>
#include <util/string/printf.h>
- using namespace NYql;
- TString ToLower(TString s) {
- for (char& c: s) {
- c = std::tolower(c);
- }
- return s;
- }
- class TOptimizerHintsParser {
- public:
- TOptimizerHintsParser(const TString& text)
- : Pos(-1)
- , Size(static_cast<i32>(text.size()) - 1)
- , Text(text)
- {}
- TOptimizerHints Parse() {
- Start();
- return Hints;
- }
- private:
- void Start() {
- while (Pos < Size) {
- auto hintType = Keyword({"JoinOrder", "Leading", "JoinType", "Rows"});
- if (hintType == "JoinOrder" || hintType == "Leading") {
- JoinOrder(hintType == "Leading");
- } else if (hintType == "JoinType") {
- JoinType();
- } else if (hintType == "Rows"){
- Rows();
- } else {
- ParseError(Sprintf("Undefined hints type: %s", hintType.c_str()), Pos - hintType.size());
- }
- SkipWhiteSpaces();
- }
- }
- TVector<TString> CollectLabels() {
- TVector<TString> labels;
- while (auto maybeTerm = MaybeLabel()) {
- labels.push_back(maybeTerm.value());
- }
- return labels;
- }
- void JoinType() {
- i32 beginPos = Pos + 1;
- Keyword({"("});
- i32 labelsBeginPos = Pos + 1;
- TVector<TString> labels = CollectLabels();
- if (labels.size() <= 1) {
- ParseError(Sprintf("Bad labels for JoinType hint: %s, example of the format: JoinType(t1 t2 Shuffle)", JoinSeq(", ", labels).c_str()), labelsBeginPos);
- }
- TString reqJoinAlgoStr = std::move(labels.back());
- labels.pop_back();
- Keyword({")"});
- TVector<EJoinAlgoType> joinAlgos = {EJoinAlgoType::GraceJoin, EJoinAlgoType::LookupJoin, EJoinAlgoType::MapJoin};
- TVector<TString> joinAlgosStr = {"shuffle", "lookup", "broadcast"};
- for (const auto& [JoinType, joinAlgoStr]: Zip(joinAlgos, joinAlgosStr)) {
- if (ToLower(reqJoinAlgoStr) == joinAlgoStr) {
- Hints.JoinAlgoHints->PushBack(std::move(labels), JoinType, "JoinType" + Text.substr(beginPos, Pos - beginPos + 1));
- return;
- }
- }
- ParseError(Sprintf("Unknown JoinType: '%s', supported algos: [%s]", reqJoinAlgoStr.c_str(), JoinSeq(", ", joinAlgosStr).c_str()), Pos - reqJoinAlgoStr.size());
- Y_UNREACHABLE();
- }
- void JoinOrder(bool leading /* is keyword "Leading" or "JoinOrder" */) {
- i32 beginPos = Pos + 1;
- Keyword({"("});
- auto joinOrderHintTree = JoinOrderLabels();
- Keyword({")"});
- Hints.JoinOrderHints->PushBack(
- std::move(joinOrderHintTree),
- leading? "Leading" : "JoinOrder" + Text.substr(beginPos, Pos - beginPos + 1)
- );
- }
- std::shared_ptr<TJoinOrderHints::ITreeNode> JoinOrderLabels() {
- auto lhs = JoinOrderLabel();
- auto rhs = JoinOrderLabel();
- return std::make_shared<TJoinOrderHints::TJoinNode>(std::move(lhs), std::move(rhs));
- }
- std::shared_ptr<TJoinOrderHints::ITreeNode> JoinOrderLabel() {
- if (auto maybeLabel = MaybeLabel()) {
- return std::make_shared<TJoinOrderHints::TRelationNode>(std::move(maybeLabel.value()));
- } else if (auto maybeBracket = MaybeKeyword({"("})) {
- auto join = JoinOrderLabels();
- Keyword({")"});
- return join;
- }
- ParseError(Sprintf("JoinOrder args must be either a relation, either a join, example of the format: JoinOrder(t1 (t2 t3))"), Pos);
- Y_UNREACHABLE();
- }
- void Rows() {
- i32 beginPos = Pos + 1;
- Keyword({"("});
- TVector<TString> labels = CollectLabels();
- auto signStr = Keyword({"+", "-", "/", "*", "#"});
- char sign = signStr[0];
- auto value = Number();
- Keyword({")"});
- TCardinalityHints::ECardOperation op;
- switch (sign) {
- case '+': { op = TCardinalityHints::ECardOperation::Add; break; }
- case '-': { op = TCardinalityHints::ECardOperation::Subtract; break; }
- case '/': { op = TCardinalityHints::ECardOperation::Divide; break; }
- case '*': { op = TCardinalityHints::ECardOperation::Multiply; break; }
- case '#': { op = TCardinalityHints::ECardOperation::Replace; break; }
- default: {ParseError(Sprintf("Unknown operation: '%c'", sign), Pos - 1); Y_UNREACHABLE();}
- }
- Hints.CardinalityHints->PushBack(std::move(labels), op, value, "Rows" + Text.substr(beginPos, Pos - beginPos + 1));
- }
- private:
- // Expressions
- void ParseError(const TString& err, i32 pos) {
- auto [line, linePos] = GetLineAndLinePosFromTextPos(pos);
- Y_ENSURE(false, Sprintf("Optimizer hints parser error at [line:%d, pos:%d], msg: %s", line, linePos, err.c_str()));
- }
- TString Label() {
- return Term(LabelAllowedSymbols());
- }
- std::optional<TString> MaybeLabel() {
- try {
- return Label();
- } catch (...) {
- return std::nullopt;
- }
- }
- TString Term(const std::bitset<256>& allowedSym = {}) {
- SkipWhiteSpaces();
- Y_ENSURE(Pos < Size, "Expected <string>, but got end of the string.");
- TString term;
- while (Pos < Size) {
- try {
- term.push_back(Char(allowedSym));
- } catch (...) {
- break;
- }
- }
- if (term.empty()) {
- ParseError("Expected a term!", Pos);
- }
- return term;
- }
- char Char(unsigned char c) {
- std::bitset<256> allowed;
- allowed[c] = 1;
- return Char(allowed);
- }
- char Char(unsigned char intervalBegin, unsigned char intervalEnd) {
- std::bitset<256> allowed;
- for (size_t i = intervalBegin; i <= intervalEnd; ++i) {
- allowed[i] = 1;
- }
- return Char(allowed);
- }
- char Char(const std::bitset<256>& allowedSymbols = {}) {
- Y_ENSURE(Pos < Size, Sprintf("Expected [%s], but got end of the string.", ""));
- char nextSym = Text[Pos + 1];
- if (allowedSymbols.count() == 0) {
- ++Pos;
- return nextSym;
- }
- for (size_t i = 0; i < allowedSymbols.size(); ++i) {
- if (allowedSymbols[i] && tolower(i) == tolower(nextSym)) {
- ++Pos;
- return nextSym;
- }
- }
- ParseError(Sprintf("Expected [%s], but got [%c]", "", nextSym), Pos);
- Y_UNREACHABLE();
- }
- std::optional<TString> MaybeKeyword(const TVector<TString>& keywords) {
- try {
- return Keyword(keywords);
- } catch(...) {
- return std::nullopt;
- }
- }
- TString Keyword(const TVector<TString>& keywords) {
- SkipWhiteSpaces();
- Y_ENSURE(Pos < Size, Sprintf("Expected [%s], but got end of the string.", JoinSeq(", ", keywords).c_str()));
- for (const auto& keyword: keywords) {
- size_t lowInclude = Pos + 1;
- size_t highExclude = lowInclude + keyword.size();
- if (Text.substr(lowInclude, highExclude - lowInclude).equal(keyword)) {
- Pos += keyword.size();
- return keyword;
- }
- }
- ParseError(Sprintf("Expected [%s], but got [%c]", JoinSeq(", ", keywords).c_str(), Text[Pos + 1]), Pos);
- Y_UNREACHABLE();
- }
- double Number() {
- SkipWhiteSpaces();
- Y_ENSURE(Pos < Size, Sprintf("Expected number, but got end of the string."));
- TString number;
- if (auto maybeSign = MaybeKeyword({"+", "-"})) {
- number.push_back(maybeSign.value()[0]);
- }
- auto term = Term(Digits() | Chars(".-e")); // for double like 1.0 / 1e9
- try {
- return std::stod(term);
- } catch (...) {
- ParseError(Sprintf("Expected a number, got [%s]", term.c_str()), Pos - term.size());
- }
- Y_UNREACHABLE();
- }
- private:
- // Helpers
- constexpr std::bitset<256> Chars(const TString& s) {
- std::bitset<256> res;
- for (char c: s) {
- res[c] = 1;
- }
- return res;
- }
- constexpr std::bitset<256> Letters() {
- std::bitset<256> res;
- for (unsigned char i = 'a'; i <= 'z'; ++i) {
- res[i] = 1;
- }
- for (unsigned char i = 'A'; i <= 'Z'; ++i) {
- res[i] = 1;
- }
- return res;
- }
- constexpr std::bitset<256> Digits() {
- std::bitset<256> res;
- for (unsigned char i = '0'; i <= '9'; ++i) {
- res[i] = 1;
- }
- return res;
- }
- constexpr std::bitset<256> LabelAllowedSymbols() {
- auto labelSymbols = Digits() | Letters();
- labelSymbols['_'] = 1;
- return labelSymbols;
- }
- void SkipWhiteSpaces() {
- for (; Pos < Size && isspace(Text[Pos + 1]); ++Pos) {
- }
- }
- std::pair<i32, i32> GetLineAndLinePosFromTextPos(i32 pos) {
- i32 Line = 0;
- i32 LinePos = 0;
- for (i32 i = 0; i <= pos && i < static_cast<i32>(Text.size()); ++i) {
- if (Text[i] == '\n') {
- LinePos = 0;
- ++Line;
- } else {
- ++LinePos;
- }
- }
- return {Line, LinePos};
- }
- private:
- i32 Pos;
- const i32 Size;
- const TString& Text;
- private:
- TOptimizerHints Hints;
- };
- TOptimizerHints TOptimizerHints::Parse(const TString& text) {
- return TOptimizerHintsParser(text).Parse();
- }
|