cbo_hints.cpp 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. #include "cbo_optimizer_new.h"
  2. #include <util/string/join.h>
  3. #include <util/string/printf.h>
  4. #include <library/cpp/iterator/zip.h>
  5. using namespace NYql;
  6. TString ToLower(TString s) {
  7. for (char& c: s) {
  8. c = std::tolower(c);
  9. }
  10. return s;
  11. }
  12. class TOptimizerHintsParser {
  13. public:
  14. TOptimizerHintsParser(const TString& text)
  15. : Pos(-1)
  16. , Size(static_cast<i32>(text.size()) - 1)
  17. , Text(text)
  18. {}
  19. TOptimizerHints Parse() {
  20. Start();
  21. return Hints;
  22. }
  23. private:
  24. void Start() {
  25. while (Pos < Size) {
  26. auto hintType = Keyword({"JoinOrder", "Leading", "JoinType", "Rows"});
  27. if (hintType == "JoinOrder" || hintType == "Leading") {
  28. JoinOrder(hintType == "Leading");
  29. } else if (hintType == "JoinType") {
  30. JoinType();
  31. } else if (hintType == "Rows"){
  32. Rows();
  33. } else {
  34. ParseError(Sprintf("Undefined hints type: %s", hintType.c_str()), Pos - hintType.size());
  35. }
  36. SkipWhiteSpaces();
  37. }
  38. }
  39. TVector<TString> CollectLabels() {
  40. TVector<TString> labels;
  41. while (auto maybeTerm = MaybeLabel()) {
  42. labels.push_back(maybeTerm.value());
  43. }
  44. return labels;
  45. }
  46. void JoinType() {
  47. i32 beginPos = Pos + 1;
  48. Keyword({"("});
  49. i32 labelsBeginPos = Pos + 1;
  50. TVector<TString> labels = CollectLabels();
  51. if (labels.size() <= 1) {
  52. ParseError(Sprintf("Bad labels for JoinType hint: %s, example of the format: JoinType(t1 t2 Shuffle)", JoinSeq(", ", labels).c_str()), labelsBeginPos);
  53. }
  54. TString reqJoinAlgoStr = std::move(labels.back());
  55. labels.pop_back();
  56. Keyword({")"});
  57. TVector<EJoinAlgoType> joinAlgos = {EJoinAlgoType::GraceJoin, EJoinAlgoType::LookupJoin, EJoinAlgoType::MapJoin};
  58. TVector<TString> joinAlgosStr = {"shuffle", "lookup", "broadcast"};
  59. for (const auto& [JoinType, joinAlgoStr]: Zip(joinAlgos, joinAlgosStr)) {
  60. if (ToLower(reqJoinAlgoStr) == joinAlgoStr) {
  61. Hints.JoinAlgoHints->PushBack(std::move(labels), JoinType, "JoinType" + Text.substr(beginPos, Pos - beginPos + 1));
  62. return;
  63. }
  64. }
  65. ParseError(Sprintf("Unknown JoinType: '%s', supported algos: [%s]", reqJoinAlgoStr.c_str(), JoinSeq(", ", joinAlgosStr).c_str()), Pos - reqJoinAlgoStr.size());
  66. Y_UNREACHABLE();
  67. }
  68. void JoinOrder(bool leading /* is keyword "Leading" or "JoinOrder" */) {
  69. i32 beginPos = Pos + 1;
  70. Keyword({"("});
  71. auto joinOrderHintTree = JoinOrderLabels();
  72. Keyword({")"});
  73. Hints.JoinOrderHints->PushBack(
  74. std::move(joinOrderHintTree),
  75. leading? "Leading" : "JoinOrder" + Text.substr(beginPos, Pos - beginPos + 1)
  76. );
  77. }
  78. std::shared_ptr<TJoinOrderHints::ITreeNode> JoinOrderLabels() {
  79. auto lhs = JoinOrderLabel();
  80. auto rhs = JoinOrderLabel();
  81. return std::make_shared<TJoinOrderHints::TJoinNode>(std::move(lhs), std::move(rhs));
  82. }
  83. std::shared_ptr<TJoinOrderHints::ITreeNode> JoinOrderLabel() {
  84. if (auto maybeLabel = MaybeLabel()) {
  85. return std::make_shared<TJoinOrderHints::TRelationNode>(std::move(maybeLabel.value()));
  86. } else if (auto maybeBracket = MaybeKeyword({"("})) {
  87. auto join = JoinOrderLabels();
  88. Keyword({")"});
  89. return join;
  90. }
  91. ParseError(Sprintf("JoinOrder args must be either a relation, either a join, example of the format: JoinOrder(t1 (t2 t3))"), Pos);
  92. Y_UNREACHABLE();
  93. }
  94. void Rows() {
  95. i32 beginPos = Pos + 1;
  96. Keyword({"("});
  97. TVector<TString> labels = CollectLabels();
  98. auto signStr = Keyword({"+", "-", "/", "*", "#"});
  99. char sign = signStr[0];
  100. auto value = Number();
  101. Keyword({")"});
  102. TCardinalityHints::ECardOperation op;
  103. switch (sign) {
  104. case '+': { op = TCardinalityHints::ECardOperation::Add; break; }
  105. case '-': { op = TCardinalityHints::ECardOperation::Subtract; break; }
  106. case '/': { op = TCardinalityHints::ECardOperation::Divide; break; }
  107. case '*': { op = TCardinalityHints::ECardOperation::Multiply; break; }
  108. case '#': { op = TCardinalityHints::ECardOperation::Replace; break; }
  109. default: {ParseError(Sprintf("Unknown operation: '%c'", sign), Pos - 1); Y_UNREACHABLE();}
  110. }
  111. Hints.CardinalityHints->PushBack(std::move(labels), op, value, "Rows" + Text.substr(beginPos, Pos - beginPos + 1));
  112. }
  113. private:
  114. // Expressions
  115. void ParseError(const TString& err, i32 pos) {
  116. auto [line, linePos] = GetLineAndLinePosFromTextPos(pos);
  117. Y_ENSURE(false, Sprintf("Optimizer hints parser error at [line:%d, pos:%d], msg: %s", line, linePos, err.c_str()));
  118. }
  119. TString Label() {
  120. return Term(LabelAllowedSymbols());
  121. }
  122. std::optional<TString> MaybeLabel() {
  123. try {
  124. return Label();
  125. } catch (...) {
  126. return std::nullopt;
  127. }
  128. }
  129. TString Term(const std::bitset<256>& allowedSym = {}) {
  130. SkipWhiteSpaces();
  131. Y_ENSURE(Pos < Size, "Expected <string>, but got end of the string.");
  132. TString term;
  133. while (Pos < Size) {
  134. try {
  135. term.push_back(Char(allowedSym));
  136. } catch (...) {
  137. break;
  138. }
  139. }
  140. if (term.empty()) {
  141. ParseError("Expected a term!", Pos);
  142. }
  143. return term;
  144. }
  145. char Char(unsigned char c) {
  146. std::bitset<256> allowed;
  147. allowed[c] = 1;
  148. return Char(allowed);
  149. }
  150. char Char(unsigned char intervalBegin, unsigned char intervalEnd) {
  151. std::bitset<256> allowed;
  152. for (size_t i = intervalBegin; i <= intervalEnd; ++i) {
  153. allowed[i] = 1;
  154. }
  155. return Char(allowed);
  156. }
  157. char Char(const std::bitset<256>& allowedSymbols = {}) {
  158. Y_ENSURE(Pos < Size, Sprintf("Expected [%s], but got end of the string.", ""));
  159. char nextSym = Text[Pos + 1];
  160. if (allowedSymbols.count() == 0) {
  161. ++Pos;
  162. return nextSym;
  163. }
  164. for (size_t i = 0; i < allowedSymbols.size(); ++i) {
  165. if (allowedSymbols[i] && tolower(i) == tolower(nextSym)) {
  166. ++Pos;
  167. return nextSym;
  168. }
  169. }
  170. ParseError(Sprintf("Expected [%s], but got [%c]", "", nextSym), Pos);
  171. Y_UNREACHABLE();
  172. }
  173. std::optional<TString> MaybeKeyword(const TVector<TString>& keywords) {
  174. try {
  175. return Keyword(keywords);
  176. } catch(...) {
  177. return std::nullopt;
  178. }
  179. }
  180. TString Keyword(const TVector<TString>& keywords) {
  181. SkipWhiteSpaces();
  182. Y_ENSURE(Pos < Size, Sprintf("Expected [%s], but got end of the string.", JoinSeq(", ", keywords).c_str()));
  183. for (const auto& keyword: keywords) {
  184. size_t lowInclude = Pos + 1;
  185. size_t highExclude = lowInclude + keyword.size();
  186. if (Text.substr(lowInclude, highExclude - lowInclude).equal(keyword)) {
  187. Pos += keyword.size();
  188. return keyword;
  189. }
  190. }
  191. ParseError(Sprintf("Expected [%s], but got [%c]", JoinSeq(", ", keywords).c_str(), Text[Pos + 1]), Pos);
  192. Y_UNREACHABLE();
  193. }
  194. double Number() {
  195. SkipWhiteSpaces();
  196. Y_ENSURE(Pos < Size, Sprintf("Expected number, but got end of the string."));
  197. TString number;
  198. if (auto maybeSign = MaybeKeyword({"+", "-"})) {
  199. number.push_back(maybeSign.value()[0]);
  200. }
  201. auto term = Term(Digits() | Chars(".-e")); // for double like 1.0 / 1e9
  202. try {
  203. return std::stod(term);
  204. } catch (...) {
  205. ParseError(Sprintf("Expected a number, got [%s]", term.c_str()), Pos - term.size());
  206. }
  207. Y_UNREACHABLE();
  208. }
  209. private:
  210. // Helpers
  211. constexpr std::bitset<256> Chars(const TString& s) {
  212. std::bitset<256> res;
  213. for (char c: s) {
  214. res[c] = 1;
  215. }
  216. return res;
  217. }
  218. constexpr std::bitset<256> Letters() {
  219. std::bitset<256> res;
  220. for (unsigned char i = 'a'; i <= 'z'; ++i) {
  221. res[i] = 1;
  222. }
  223. for (unsigned char i = 'A'; i <= 'Z'; ++i) {
  224. res[i] = 1;
  225. }
  226. return res;
  227. }
  228. constexpr std::bitset<256> Digits() {
  229. std::bitset<256> res;
  230. for (unsigned char i = '0'; i <= '9'; ++i) {
  231. res[i] = 1;
  232. }
  233. return res;
  234. }
  235. constexpr std::bitset<256> LabelAllowedSymbols() {
  236. auto labelSymbols = Digits() | Letters();
  237. labelSymbols['_'] = 1;
  238. return labelSymbols;
  239. }
  240. void SkipWhiteSpaces() {
  241. for (; Pos < Size && isspace(Text[Pos + 1]); ++Pos) {
  242. }
  243. }
  244. std::pair<i32, i32> GetLineAndLinePosFromTextPos(i32 pos) {
  245. i32 Line = 0;
  246. i32 LinePos = 0;
  247. for (i32 i = 0; i <= pos && i < static_cast<i32>(Text.size()); ++i) {
  248. if (Text[i] == '\n') {
  249. LinePos = 0;
  250. ++Line;
  251. } else {
  252. ++LinePos;
  253. }
  254. }
  255. return {Line, LinePos};
  256. }
  257. private:
  258. i32 Pos;
  259. const i32 Size;
  260. const TString& Text;
  261. private:
  262. TOptimizerHints Hints;
  263. };
  264. TOptimizerHints TOptimizerHints::Parse(const TString& text) {
  265. return TOptimizerHintsParser(text).Parse();
  266. }