# test_doc.py (6.3 KB)
  1. import yatest
  2. from yqlrun import YQLRun
  3. import cyson
  4. from multiprocessing.pool import ThreadPool
  5. import time
  6. import base64
  7. import binascii
  8. from yql_utils import get_param
  9. def run_one(item):
  10. line, input, output, should_fail = item
  11. start_time = time.time()
  12. try:
  13. support_udfs = False
  14. if "LIKE" in input:
  15. support_udfs = True
  16. yqlrun_res = YQLRun(prov='yt',
  17. use_sql2yql=False,
  18. cfg_dir='yql/essentials/cfg/udf_test',
  19. support_udfs=support_udfs).yql_exec(
  20. program="--!syntax_pg\n" + input,
  21. run_sql=True,
  22. check_error=True
  23. )
  24. dom = cyson.loads(yqlrun_res.results)
  25. elapsed_time = time.time() - start_time
  26. return (line, input, output, dom, None, elapsed_time, should_fail)
  27. except Exception as e:
  28. elapsed_time = time.time() - start_time
  29. return (line, input, output, None, e, elapsed_time, should_fail)
  30. def convert_cell(cell, output):
  31. if cell is None:
  32. value = 'NULL'
  33. elif isinstance(cell, bytes):
  34. if output.startswith("\\x"):
  35. value = "\\x" + binascii.hexlify(cell).decode("utf-8")
  36. else:
  37. value = cell.decode("utf-8")
  38. else:
  39. value = "\\x" + binascii.hexlify(base64.b64decode(cell[0])).decode("utf-8")
  40. if output.startswith("~"):
  41. value = ''
  42. output = ''
  43. return (value, output)
  44. def convert_value(data, output):
  45. if len(data) == 1:
  46. return convert_cell(data[0], output)
  47. lst = [convert_cell(x[0], x[1]) for x in zip(data, output.split(","))]
  48. return (",".join(x[0] for x in lst), ",".join(x[1] for x in lst))
  49. def test_doc():
  50. skip_before = get_param("skip_before")
  51. stop_at = get_param("stop_at")
  52. if skip_before is not None:
  53. print("WILL SKIP TESTS BEFORE: ", skip_before)
  54. if stop_at is not None:
  55. print("WILL STOP AT: ", stop_at)
  56. doc_src = yatest.common.source_path("contrib/ydb/docs/ru/core/postgresql/_includes/functions.md")
  57. with open(doc_src) as f:
  58. doc_data = f.readlines()
  59. in_code = False
  60. queue = []
  61. total = 0
  62. skipped = 0
  63. skipped_exception = 0
  64. skipped_mismatch_res = 0
  65. skipped_same_res = 0
  66. set_of = None
  67. original_line = None
  68. original_input = None
  69. multiline = None
  70. skip_in_progress = skip_before is not None
  71. should_fail = None
  72. for raw_line in doc_data:
  73. line = raw_line.strip()
  74. if stop_at is not None and line.startswith("## " + stop_at):
  75. break
  76. if skip_in_progress:
  77. if line.startswith("## " + skip_before):
  78. skip_in_progress = False
  79. continue
  80. if set_of is not None:
  81. if line.startswith("]"):
  82. queue.append((original_line, original_input, set_of, should_fail))
  83. set_of = None
  84. original_line = None
  85. original_input = None
  86. else:
  87. set_of.append(line)
  88. continue
  89. if multiline is not None:
  90. if line.endswith('"""'):
  91. multiline.append(line[0:line.index('"""')])
  92. queue.append((original_line, original_input, "".join(multiline), should_fail))
  93. multiline = None
  94. original_line = None
  95. original_input = None
  96. else:
  97. multiline.append(raw_line)
  98. continue
  99. if line.startswith("```sql"):
  100. in_code = True
  101. continue
  102. elif in_code and line.startswith("```"):
  103. in_code = False
  104. continue
  105. if not in_code:
  106. continue
  107. if "→" not in line:
  108. continue
  109. total += 1
  110. line = line.replace("~→ ", "→ ~")
  111. input, output = [x.strip() for x in line.split("→")]
  112. should_fail = False
  113. if input.startswith("#") and not input.startswith("# "):
  114. should_fail = True
  115. skipped += 1
  116. input = input[1:]
  117. if not input.startswith("SELECT"):
  118. input = "SELECT " + input
  119. if "/*" in output:
  120. output = output[:output.index("/*")].strip()
  121. if output.startswith('"""'):
  122. multiline = [output[output.index('"""') + 3:] + "\n"]
  123. original_line = line
  124. original_input = input
  125. continue
  126. elif output.startswith("'") and output.endswith("'"):
  127. output = output[1:-1]
  128. elif output.endswith("["):
  129. set_of = []
  130. original_line = line
  131. original_input = input
  132. continue
  133. queue.append((line, input, output, should_fail))
  134. with ThreadPool(16) as pool:
  135. for res in pool.map(run_one, queue):
  136. line, input, output, dom, e, elapsed_time, should_fail = res
  137. print("TEST: " + line)
  138. print("INPUT: ", input)
  139. print("OUTPUT: ", output)
  140. print("ELAPSED: ", elapsed_time)
  141. if e is not None:
  142. if not should_fail:
  143. raise e
  144. else:
  145. skipped_exception += 1
  146. print("SKIPPED, EXCEPTION")
  147. else:
  148. data = dom[0][b"Write"][0][b"Data"]
  149. print("DATA: ", data)
  150. if isinstance(output, list):
  151. pairs = [convert_value(x[0], x[1]) for x in zip(data, output)]
  152. value = [x[0] for x in pairs]
  153. output = [x[1] for x in pairs]
  154. else:
  155. value, output = convert_value(data[0], output)
  156. print("VALUE: ", value)
  157. try:
  158. assert value == output, f"Expected '{output}' but got '{value}', test: {line}"
  159. except Exception as err:
  160. if should_fail:
  161. e = err
  162. skipped_mismatch_res += 1
  163. print("SKIPPED, MISMATCH RESULT")
  164. else:
  165. raise
  166. if should_fail and e is None:
  167. print("SKIPPED, SAME RESULT")
  168. skipped_same_res += 1
  169. print("TOTAL TESTS: ", total)
  170. print("SKIPPED TESTS: ", skipped)
  171. print("SKIPPED TESTS WITH EXCEPTION: ", skipped_exception)
  172. print("SKIPPED TESTS WITH MISMATCH RESULT: ", skipped_mismatch_res)
  173. print("SKIPPED TESTS WITH SAME RESULT: ", skipped_same_res)