__main__.py

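"""pg-make-test builds YQL's PostgreSQL regression test cases.

For every input .sql suite the tool splits the file into individual statements,
runs the whole suite through pgrun, matches pgrun's output against the reference
.out file(s), and writes into the destination directory only the statements whose
output matches, together with a CSV report of per-test statement coverage.
"""
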
import os
import os.path
import sys
import logging
import subprocess
from collections import Counter
from multiprocessing import Pool
from pathlib import Path
import tempfile
import shutil
import re
import csv

import click
import patch

from library.python.svn_version import svn_version

from yql.essentials.tests.postgresql.common import get_out_files, Differ

PROGRAM_NAME = "pg-make-test"
RUNNER = "../pgrun/pgrun"
SPLITTER = "../pgrun/pgrun split-statements"
INIT_SCRIPTS_CFG = "testinits.cfg"
INIT_SCRIPTS_DIR = "initscripts"
REPORT_FILE = "pg_tests.csv"

LOGGER = None


def get_logger(name, logfile, is_debug):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG if is_debug else logging.INFO)

    if logfile is not None:
        logger.addHandler(logging.FileHandler(logfile, encoding="utf-8"))

    return logger


def setup_logging(logfile, is_debug):
    global LOGGER

    LOGGER = get_logger(__file__, logfile, is_debug)


class Configuration:
    def __init__(
        self, srcdir, dstdir, udfs, patchdir, skip_tests, runner, splitter, report_path, parallel, logfile, is_debug
    ):
        self.srcdir = srcdir
        self.dstdir = dstdir
        self.udfs = udfs
        self.patchdir = patchdir
        self.skip_tests = skip_tests
        self.runner = runner
        self.splitter = splitter
        self.report_path = report_path
        self.parallel = parallel
        self.logfile = logfile
        self.is_debug = is_debug


def save_strings(fname, lst):
    with open(fname, 'wb') as f:
        for line in lst:
            f.write(line)


def argwhere1(predicate, collection, default):
    """Returns the index of the first element in collection that satisfies the predicate."""
    try:
        pos = next(i for i, item in enumerate(collection) if predicate(item))
    except StopIteration:
        return default
    else:
        return pos


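# TestCaseBuilder processes one test suite: it splits the .sql file into
# statements, runs the whole suite through pgrun, aligns pgrun's output with
# the reference .out file(s) statement by statement, and keeps only the
# statements whose output matches the reference.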
class TestCaseBuilder:
    def __init__(self, config):
        self.config = config

    def build(self, args):
        sqlfile, init_scripts = args
        is_split_logging = self.config.logfile is not None and self.config.parallel
        if is_split_logging:
            logger = get_logger(
                sqlfile.stem,
                f"{self.config.logfile.parent}/{sqlfile.stem}-{self.config.logfile.name}",
                self.config.is_debug,
            )
        else:
            logger = LOGGER

        splitted_stmts = list(self.split_sql_file(sqlfile))
        stmts_count = len(splitted_stmts)

        if init_scripts:
            logging.debug("Init scripts: %s", init_scripts)

        ressqlfile = self.config.dstdir / sqlfile.name
        resoutfile = ressqlfile.with_suffix('.out')
        reserrfile_base = resoutfile.with_suffix('.err')

        max_stmts_run = 0
        ressql = None
        resout = None

        for outfile_idx, outfile in enumerate(get_out_files(sqlfile)):
            test_name = Path(sqlfile).name
            LOGGER.info("Processing (%d) %s -> %s", os.getpid(), test_name, Path(outfile).name)
            if is_split_logging:
                logger.info("Processing (%d) %s -> %s", os.getpid(), test_name, Path(outfile).name)

            with open(outfile, 'rb') as fout:
                outdata = fout.readlines()

            only_out_stmts = Counter()
            only_pgrun_stmts = Counter()

            statements = list(self.split_out_file(splitted_stmts, outdata, logger))
            logger.debug("Matching sql statements to .out file lines")
            for (s_sql, s_out) in statements:
                stmt = '\n'.join(str(sql_line) for sql_line in s_sql)
                only_out_stmts[stmt] += 1
                logger.debug(
                    "<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n%s\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n%s\n============================",
                    stmt,
                    '\n'.join(str(out_line) for out_line in s_out),
                )

            with tempfile.TemporaryDirectory() as tempdir:
                test_out_name = Path(tempdir) / "test.out"
                test_err_name = test_out_name.with_suffix(".err")

                runner_args = self.config.runner + ["--datadir", tempdir]
                for udf in self.config.udfs:
                    runner_args.append("--udf")
                    runner_args.append(udf)

                if init_scripts:
                    init_out_name = Path(tempdir) / "init.out"
                    init_err_name = init_out_name.with_suffix(".err")

                    for init_script in init_scripts:
                        logger.debug("Running init script %s '%s'", self.config.runner, init_script)
                        with open(init_script, 'rb') as f, open(init_out_name, 'wb') as fout, open(init_err_name, 'wb') as ferr:
                            pi = subprocess.run(runner_args, stdin=f, stdout=fout, stderr=ferr)
                        if pi.returncode != 0:
                            logger.warning("%s returned error code %d", self.config.runner, pi.returncode)

                logger.debug("Running test %s '%s' -> [%s]", self.config.runner, sqlfile, outfile)
                with open(sqlfile, 'rb') as f, open(test_out_name, 'wb') as fout, open(test_err_name, 'wb') as ferr:
                    pi = subprocess.run(runner_args, stdin=f, stdout=fout, stderr=ferr)
                if pi.returncode != 0:
                    logger.warning("%s returned error code %d", self.config.runner, pi.returncode)

                with open(test_out_name, 'rb') as fresult:
                    out = fresult.readlines()
                logger.debug("Run result:\n%s", str(b'\n'.join(out)))

                real_statements = list(self.split_out_file(splitted_stmts, out, logger))
                logger.debug("Matching sql statements to pgrun's output")
                for (s_sql, s_out) in real_statements:
                    stmt = '\n'.join(str(sql_line) for sql_line in s_sql)
                    logger.debug(
                        "<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n%s\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n%s\n============================",
                        stmt,
                        '\n'.join(str(out_line) for out_line in s_out),
                    )
                    if 0 < only_out_stmts[stmt]:
                        only_out_stmts[stmt] -= 1
                        if 0 == only_out_stmts[stmt]:
                            del only_out_stmts[stmt]
                    else:
                        only_pgrun_stmts[stmt] += 1

                reserrfile = reserrfile_base if outfile_idx == 0 else reserrfile_base.with_suffix(reserrfile_base.suffix + ".{0}".format(outfile_idx))
                shutil.move(test_err_name, reserrfile)

            if only_pgrun_stmts:
                logger.info("Statements in pgrun output, but not in out file:\n%s",
                            "\n--------------------------------\n".join(stmt for stmt in only_pgrun_stmts))
            if only_out_stmts:
                logger.info("Statements in out file, but not in pgrun output:\n%s",
                            "\n--------------------------------\n".join(stmt for stmt in only_out_stmts))

            stmts_run = 0
            stmts = []
            outs = []
            assert len(statements) == len(real_statements), f"Incorrect statements split in {test_name}. Statements in out-file: {len(statements)}, statements in pgrun output: {len(real_statements)}"
            for ((l_sql, out), (r_sql, res)) in zip(statements, real_statements):
                if l_sql != r_sql:
                    logger.warning("out SQL <> pgrun SQL:\n <: %s\n >: %s", l_sql, r_sql)
                    break

                if len(Differ.diff(b''.join(out), b''.join(res))) == 0:
                    stmts.extend(l_sql)
                    outs.extend(out)
                    stmts_run += 1
                else:
                    logger.warning("out result differs from pgrun result:\n <<: %s\n >>: %s", out, res)

            if max_stmts_run < stmts_run:
                max_stmts_run = stmts_run
                ressql = stmts
                resout = outs

        if ressql is not None and resout is not None:
            LOGGER.info('Case built: %s', sqlfile.name)
            if is_split_logging:
                logger.info('Case built: %s', sqlfile.name)

            save_strings(ressqlfile, ressql)
            save_strings(resoutfile, resout)
        else:
            LOGGER.warning('Case is empty: %s', sqlfile.name)
            if is_split_logging:
                logger.warning('Case is empty: %s', sqlfile.name)

            ressqlfile.unlink(missing_ok=True)
            resoutfile.unlink(missing_ok=True)

        return Path(sqlfile).stem, stmts_count, stmts_run, round(stmts_run * 100 / stmts_count, 2)

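    # The splitter is expected to emit a delimiter line first and then echo the
    # input with that delimiter line after every statement; split_sql_file
    # reassembles individual statements from this stream.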
    def split_sql_file(self, sqlfile):
        with open(sqlfile, "rb") as f:
            pi = subprocess.run(self.config.splitter, stdin=f, stdout=subprocess.PIPE, stderr=sys.stderr, check=True)

        lines = iter(pi.stdout.splitlines(keepends=True))
        delimiter = next(lines)

        cur_stmt = []
        for line in lines:
            if line == delimiter:
                yield cur_stmt
                cur_stmt = []
                continue
            cur_stmt.append(line)

        if cur_stmt:
            yield cur_stmt

    reCopyFromStdin = re.compile(b"COPY[^;]+FROM std(?:in|out)", re.I)

    def split_out_file(self, stmts, outdata, logger):
        """Matches SQL & its output in outdata with individual SQL statements in stmts.

        Args:
            stmts ([[bytes]]): List of SQL statements, each statement being a list of lines.
            outdata ([bytes]): Contents of the out-file, as a list of lines.

        Yields:
            ([bytes], [bytes]): Pairs of matching parts of the sql and out files.
        """
        cur_stmt_out = []
        out_iter = enumerate(outdata)
        echo_none = False
        in_copy_from = False
        no_more_stmts_expected = False

        try:
            line_no, out_line = next(out_iter)
        except StopIteration:
            no_more_stmts_expected = True

        in_line_no = 0
        for i, stmt in enumerate(stmts):
            if no_more_stmts_expected:
                yield stmt, cur_stmt_out
                cur_stmt_out = []
                continue

            try:
                for stmt_line in stmt:
                    in_line_no += 1

                    if echo_none:
                        if stmt_line.startswith(b"\\set ECHO ") and not stmt_line.rstrip().endswith(b"none"):
                            echo_none = False
                        continue

                    if stmt_line.startswith(b"\\set ECHO none"):
                        echo_none = True

                    # We skip data lines of copy ... from stdin, since they aren't in the out-file
                    if in_copy_from:
                        if stmt_line.startswith(b"\\."):
                            in_copy_from = False
                        continue

                    if self.reCopyFromStdin.match(stmt_line):
                        in_copy_from = True

                    logger.debug("Line %d: %s -> %d: %s", in_line_no, stmt_line, line_no, out_line)
                    if stmt_line != out_line:
                        raise Exception(f"Mismatch at {line_no}: '{stmt_line}' != '{out_line}'")

                    cur_stmt_out.append(out_line)
                    line_no, out_line = next(out_iter)

                assert not in_copy_from, f"Missing copy from stdout table end marker \\. at line {in_line_no}"

                if echo_none:
                    continue

                try:
                    next_stmt = stmts[i + 1]
                except IndexError:
                    cur_stmt_out.append(out_line)
                    cur_stmt_out.extend(l for _, l in out_iter)
                    logger.debug("Last out:\n%s", str(b'\n'.join(cur_stmt_out)))
                    yield stmt, cur_stmt_out
                    return

                while True:
                    while out_line != next_stmt[0]:
                        logger.debug("Out: %s -> %s", next_stmt[0], out_line)
                        cur_stmt_out.append(out_line)
                        line_no, out_line = next(out_iter)

                    last_pos = argwhere1(lambda s: self.reCopyFromStdin.match(s), next_stmt, default=len(next_stmt))
                    maybe_next_stmt = outdata[line_no:line_no + last_pos]
                    logger.debug("Left: %s\nRight: %s", next_stmt, maybe_next_stmt)
                    if next_stmt[:last_pos] == maybe_next_stmt:
                        break

                    cur_stmt_out.append(out_line)
                    line_no, out_line = next(out_iter)

                yield stmt, cur_stmt_out
                cur_stmt_out = []
            except StopIteration:
                no_more_stmts_expected = True
                yield stmt, cur_stmt_out
                cur_stmt_out = []


def load_patches(patchdir):
    for p in patchdir.glob("*.patch"):
        ps = patch.fromfile(p)
        if ps is not False:
            yield p.stem, ps


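# Each non-empty line of the init scripts config maps a test case to the init
# scripts it needs, in the form "<test_name>: <script> [<script> ...]".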
reInitScriptsCfgLine = re.compile(r"^([\w.]+):\s*([\w.]+(?:\s+[\w.]+)*)$")


def load_init_scripts(initscriptscfg, initscriptsdir, tests_set):
    init_scripts_map = dict()

    if not initscriptscfg.is_file():
        LOGGER.warning("Init scripts config file is not found: %s", initscriptscfg)
        return init_scripts_map

    if not initscriptsdir.is_dir():
        LOGGER.warning("Init scripts directory is not found: %s", initscriptsdir)
        return init_scripts_map

    scripts = frozenset(s.stem for s in initscriptsdir.glob("*.sql"))

    with open(initscriptscfg, 'r') as cfg:
        for lineno, line in enumerate(cfg, 1):
            line = line.strip()
            if not line:
                continue

            m = reInitScriptsCfgLine.match(line)
            if m is None:
                LOGGER.warning("Bad line %d in init scripts config %s", lineno, initscriptscfg)
                continue

            test_name = m[1]
            if test_name not in tests_set:
                LOGGER.debug("Skipping init scripts for unknown test case %s", test_name)
                continue

            deps = [(initscriptsdir / s).with_suffix(".sql") for s in m[2].split() if s in scripts]
            if not deps:
                LOGGER.debug("No init scripts are listed for test case %s", test_name)
                continue

            init_scripts_map[test_name] = deps

    return init_scripts_map


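# patch_cases copies each .sql file and its .out file into patchdir, applies the
# matching *.patch files, and, on success, replaces the entry in cases with the
# patched copy; on any failure the original test case is used unchanged.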
def patch_cases(cases, patches, patchdir):
    for i, sql_full_name in enumerate(cases):
        sql_name = sql_full_name.name
        p = patches.get(sql_name, None)
        if p is None:
            continue

        patched_sql_full_name = patchdir / sql_name
        shutil.copyfile(sql_full_name, patched_sql_full_name)
        success = p.apply(root=patchdir)
        if not success:
            LOGGER.warning(
                "Failed to patch %s testcase. Original version %s will be used", patched_sql_full_name, sql_full_name
            )
            continue

        out_full_name = sql_full_name.with_suffix('.out')
        out_name = out_full_name.name
        patched_out_full_name = patchdir / out_name
        # .out file should be in the same dir as .sql file, so copy it anyway
        shutil.copyfile(out_full_name, patched_out_full_name)

        p = patches.get(out_name, None)
        if p is None:
            LOGGER.warning(
                "Out-file patch for %s testcase is not found. Original version %s will be used",
                patched_sql_full_name,
                sql_full_name,
            )
            continue

        success = p.apply(root=patchdir)
        if not success:
            LOGGER.warning(
                "Failed to patch out-file for %s testcase. Original version %s will be used",
                patched_sql_full_name,
                sql_full_name,
            )
            continue

        cases[i] = patched_sql_full_name
        LOGGER.info("Patched %s -> %s", sql_full_name, cases[i])


@click.command()
@click.argument("cases", type=str, nargs=-1)
@click.option(
    "--srcdir",
    "-i",
    help="Directory with SQL suites to process",
    required=True,
    multiple=False,
    type=click.Path(exists=True, file_okay=False, resolve_path=True, path_type=Path),
)
@click.option(
    "--dstdir",
    "-o",
    help="Output directory",
    required=True,
    multiple=False,
    type=click.Path(exists=True, file_okay=False, resolve_path=True, writable=True, path_type=Path),
)
@click.option(
    "--patchdir",
    "-p",
    help="Directory with patches for SQL suites",
    required=False,
    multiple=False,
    type=click.Path(exists=True, file_okay=False, resolve_path=True, path_type=Path),
)
@click.option(
    "--udf",
    "-u",
    help="Load shared library with UDF by given path",
    required=False,
    multiple=True,
    type=click.Path(dir_okay=False, resolve_path=True, path_type=Path),
)
@click.option(
    "--initscriptscfg",
    help="Config file for tests' init scripts",
    default=INIT_SCRIPTS_CFG,
    required=False,
    multiple=False,
    type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=Path),
)
@click.option(
    "--initscriptsdir",
    help="Directory with tests' init scripts",
    default=INIT_SCRIPTS_DIR,
    required=False,
    multiple=False,
    type=click.Path(exists=True, file_okay=False, resolve_path=True, path_type=Path),
)
@click.option("--skip", "-s", help="Comma-separated list of test suites to skip", multiple=False, type=click.STRING)
@click.option("--runner", help="Test runner", default=RUNNER, required=False, multiple=False, type=click.STRING)
@click.option(
    "--splitter", help="SQL statements splitter", default=SPLITTER, required=False, multiple=False, type=click.STRING
)
@click.option(
    "--report",
    "-r",
    help="Report file name",
    default=REPORT_FILE,
    required=False,
    multiple=False,
    type=click.Path(dir_okay=False, resolve_path=True, writable=True, path_type=Path),
)
@click.option("--parallel/--no-parallel", help="Build test cases in parallel", default=True, required=False)
@click.option(
    "--logfile",
    "-l",
    help="Log file",
    default=None,
    required=False,
    multiple=False,
    type=click.Path(dir_okay=False, resolve_path=True, writable=True, path_type=Path),
)
@click.option("--debug/--no-debug", help="Enable debug logging", default=False, required=False)
@click.version_option(version=svn_version(), prog_name=PROGRAM_NAME)
def cli(cases, srcdir, dstdir, patchdir, udf, initscriptscfg, initscriptsdir, skip, runner, splitter, report, parallel, logfile, debug):
    setup_logging(logfile, debug)

    if udf:
        LOGGER.debug("UDFs: %s", udf)

    if skip is not None:
        skip_tests = frozenset(
            test_name if not (test_name := s.strip()).endswith(".sql") else test_name[:-4] for s in skip.split(",")
        )
    else:
        skip_tests = frozenset()

    config = Configuration(
        srcdir, dstdir, udf, patchdir, skip_tests, runner.split(), splitter.split(), report, parallel, logfile, debug
    )

    if not cases:
        cases = [c for c in config.srcdir.glob("*.sql") if c.stem not in skip_tests]
    else:
        cases = [Path(c) if os.path.isabs(c) else config.srcdir / c for c in cases]

    init_scripts = load_init_scripts(initscriptscfg, initscriptsdir, frozenset(c.stem for c in cases))
    LOGGER.debug("Init scripts: %s", init_scripts)

    if config.patchdir is not None:
        patches = dict(load_patches(config.patchdir))
        LOGGER.info("Patches: %s", ", ".join(p for p in patches))
    else:
        patches = {}

    with tempfile.TemporaryDirectory() as tempdir:
        patch_cases(cases, patches, Path(tempdir))
        LOGGER.info("Test cases: %s", ", ".join(c.as_posix() for c in cases))

        builder = TestCaseBuilder(config)
        build_args = [(test_case, init_scripts.get(test_case.stem) or []) for test_case in cases]
        if config.parallel:
            with Pool() as pool:
                results = list(pool.imap_unordered(builder.build, build_args))
        else:
            results = [builder.build(args) for args in build_args]

    with open(config.report_path, "w", newline='') as f:
        writer = csv.writer(f, dialect="excel")
        writer.writerow(["testcase", "statements", "successful", "ratio"])
        writer.writerows(sorted(results))


if __name__ == "__main__":
    try:
        cli()
    finally:
        logging.shutdown()

# vim:tw=78:sw=4:et:ai:si