proto_ast_antlr4.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. #pragma once
  2. #include <yql/essentials/parser/proto_ast/common.h>
  3. #ifdef ERROR
  4. #undef ERROR
  5. #endif
  6. #include <contrib/libs/antlr4_cpp_runtime/src/antlr4-runtime.h>
  7. namespace antlr4 {
  8. class ANTLR4CPP_PUBLIC YqlErrorListener : public BaseErrorListener {
  9. NProtoAST::IErrorCollector* errors;
  10. bool* error;
  11. public:
  12. YqlErrorListener(NProtoAST::IErrorCollector* errors, bool* error);
  13. virtual void syntaxError(Recognizer *recognizer, Token * offendingSymbol, size_t line, size_t charPositionInLine,
  14. const std::string &msg, std::exception_ptr e) override;
  15. };
  16. }
  17. namespace NProtoAST {
  18. template <>
  19. inline void InvalidToken<antlr4::Token>(IOutputStream& err, const antlr4::Token* token) {
  20. if (token) {
  21. if (token->getInputStream()) {
  22. err << " '" << token->getText() << "'";
  23. } else {
  24. err << ABSENCE;
  25. }
  26. }
  27. }
  28. template <typename TParser, typename TLexer>
  29. class TProtoASTBuilder4 {
  30. public:
  31. TProtoASTBuilder4(TStringBuf data, const TString& queryName = "query", google::protobuf::Arena* arena = nullptr)
  32. : QueryName(queryName)
  33. , InputStream(data)
  34. , Lexer(&InputStream)
  35. , TokenStream(&Lexer)
  36. , Parser(&TokenStream, arena)
  37. {
  38. }
  39. google::protobuf::Message* BuildAST(IErrorCollector& errors) {
  40. // TODO: find a better way to break on lexer errors
  41. typename antlr4::YqlErrorListener listener(&errors, &Parser.error);
  42. Parser.removeErrorListeners();
  43. Parser.addErrorListener(&listener);
  44. try {
  45. auto result = Parser.Parse(&errors);
  46. Parser.removeErrorListener(&listener);
  47. Parser.error = false;
  48. return result;
  49. } catch (const TTooManyErrors&) {
  50. Parser.removeErrorListener(&listener);
  51. Parser.error = false;
  52. return nullptr;
  53. } catch (...) {
  54. errors.Error(0, 0, CurrentExceptionMessage());
  55. Parser.removeErrorListener(&listener);
  56. Parser.error = false;
  57. return nullptr;
  58. }
  59. }
  60. private:
  61. TString QueryName;
  62. antlr4::ANTLRInputStream InputStream;
  63. TLexer Lexer;
  64. antlr4::CommonTokenStream TokenStream;
  65. TParser Parser;
  66. };
  67. template <typename TLexer>
  68. class TLexerTokensCollector4 {
  69. public:
  70. TLexerTokensCollector4(TStringBuf data, const TString& queryName = "query")
  71. : QueryName(queryName)
  72. , InputStream(std::string(data))
  73. , Lexer(&InputStream)
  74. {
  75. }
  76. void CollectTokens(IErrorCollector& errors, const NSQLTranslation::ILexer::TTokenCallback& onNextToken) {
  77. try {
  78. bool error = false;
  79. typename antlr4::YqlErrorListener listener(&errors, &error);
  80. Lexer.removeErrorListeners();
  81. Lexer.addErrorListener(&listener);
  82. for (;;) {
  83. auto token = Lexer.nextToken();
  84. auto type = token->getType();
  85. const bool isEOF = type == TLexer::EOF;
  86. NSQLTranslation::TParsedToken last;
  87. last.Name = GetTokenName(type);
  88. last.Content = token->getText();
  89. last.Line = token->getLine();
  90. last.LinePos = token->getCharPositionInLine();
  91. onNextToken(std::move(last));
  92. if (isEOF) {
  93. break;
  94. }
  95. }
  96. } catch (const TTooManyErrors&) {
  97. } catch (...) {
  98. errors.Error(0, 0, CurrentExceptionMessage());
  99. }
  100. }
  101. private:
  102. TString GetTokenName(size_t type) const {
  103. auto res = Lexer.getVocabulary().getSymbolicName(type);
  104. if (res != ""){
  105. return TString(res);
  106. }
  107. return TString(INVALID_TOKEN_NAME);
  108. }
  109. TString QueryName;
  110. antlr4::ANTLRInputStream InputStream;
  111. TLexer Lexer;
  112. };
  113. } // namespace NProtoAST