__init__.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import sys
  2. import logging
  3. from pathlib import Path
  4. import subprocess
  5. from .differ import Differ
  6. LOGGER = logging.getLogger(__name__)
  7. def setup_logger():
  8. options = dict(
  9. level=logging.DEBUG,
  10. format='%(levelname)s: %(message)s',
  11. datefmt='%Y-%m-%d %H:%M:%S',
  12. stream=sys.stderr
  13. )
  14. logging.basicConfig(**options)
  15. setup_logger()
  16. def find_sql_tests(path):
  17. tests = []
  18. for sql_file in Path(path).glob('*.sql'):
  19. if not sql_file.is_file():
  20. LOGGER.warning("'%s' is not a file", sql_file.absolute())
  21. continue
  22. out_files = list(get_out_files(sql_file))
  23. if not out_files:
  24. LOGGER.warning("No .out files found for '%s'", sql_file.absolute())
  25. continue
  26. tests.append((sql_file.stem, (sql_file, out_files)))
  27. return tests
  28. def load_init_scripts_for_testcase(testcase_name, init_scripts_cfg, init_scripts_dir):
  29. with open(init_scripts_cfg, 'r') as cfg:
  30. for lineno, line in enumerate(cfg, 1):
  31. cfgline = line.strip().split(':')
  32. if len(cfgline) != 2:
  33. LOGGER.info("Bad line %d in init scripts configuration '%s'", lineno, init_scripts_cfg)
  34. continue
  35. if cfgline[0].strip() == testcase_name:
  36. break
  37. else:
  38. return []
  39. avail_scripts = frozenset(s.stem for s in init_scripts_dir.glob("*.sql"))
  40. scripts = [(init_scripts_dir / s).with_suffix(".sql") for s in cfgline[1].split() if s in avail_scripts]
  41. if scripts:
  42. LOGGER.debug("Init scripts: %s", ", ".join(s.stem for s in scripts))
  43. return scripts
  44. def run_sql_test(sql, out, tmp_path, runner, udfs, init_scripts_cfg, init_scripts_dir):
  45. args = [runner, "--datadir", tmp_path]
  46. for udf in udfs:
  47. args.append("--udf")
  48. args.append(udf)
  49. LOGGER.debug("Loading init scripts for '%s' from '%s'", sql.stem, init_scripts_cfg)
  50. init_scripts = load_init_scripts_for_testcase(sql.stem, init_scripts_cfg, Path(init_scripts_dir))
  51. if init_scripts:
  52. LOGGER.debug("Executing init scripts for '%s'", sql.stem)
  53. for script in init_scripts:
  54. LOGGER.debug("Executing init script '%s'", script.name)
  55. with open(script, 'rb') as f:
  56. pi = subprocess.run(args,
  57. stdin=f, stdout=subprocess.PIPE, stderr=sys.stderr, check=True)
  58. LOGGER.debug("Running %s '%s' -> [%s]", runner, sql, ', '.join("'{}'".format(a) for a in out))
  59. with open(sql, 'rb') as f:
  60. pi = subprocess.run(args,
  61. stdin=f, stdout=subprocess.PIPE, stderr=sys.stderr, check=True)
  62. min_diff = sys.maxsize
  63. best_match = out[0]
  64. best_diff = ''
  65. for out_file in out:
  66. with open(out_file, 'rb') as f:
  67. out_data = f.read()
  68. last_diff = Differ.diff(pi.stdout, out_data)
  69. diff_len = len(last_diff)
  70. if diff_len == 0:
  71. return
  72. if diff_len < min_diff:
  73. min_diff = diff_len
  74. best_match = out_file
  75. best_diff = last_diff
  76. LOGGER.info("No exact match for '%s'. Best match is '%s'", sql, best_match)
  77. for line in best_diff:
  78. LOGGER.debug(line)
  79. # We need assert to fail the test properly
  80. assert min_diff == 0, \
  81. f"pgrun output does not match out-file for {sql}. Diff:\n" + ''.join(d.decode('utf8') for d in best_diff)[:1024]
  82. def get_out_files(sql_file):
  83. base_name = sql_file.stem
  84. out_file = sql_file.with_suffix('.out')
  85. if out_file.is_file():
  86. yield out_file
  87. for i in range(1, 10):
  88. nth_out_file = out_file.with_stem('{}_{}'.format(base_name, i))
  89. if not nth_out_file.is_file():
  90. break
  91. yield nth_out_file