# test_utils.py

import json
import os
import six
import re
import yatest.common
import zlib

from yql_utils import get_param as yql_get_param

from google.protobuf import text_format
import yql.essentials.providers.common.proto.gateways_config_pb2 as gateways_config_pb2

try:
    SQLRUN_PATH = yatest.common.binary_path('yql/essentials/tools/sql2yql/sql2yql')
except BaseException:
    SQLRUN_PATH = None

try:
    YQLRUN_PATH = yatest.common.binary_path('yql/tools/yqlrun/yqlrun')
except BaseException:
    YQLRUN_PATH = None


def _make_hash(x):
    if six.PY2:
        return hash(x)
    return zlib.crc32(repr(x).encode("utf-8"))
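
# Note (assumption about intent, based on the usage below): _make_hash gives a
# stable per-case hash so pytest_generate_tests_for_run can shard test cases
# across partsCount runs; a (suite, case, cfg) combination is kept only in the
# run where _make_hash((suite, case, cfg)) % partsCount == currentPart.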


def get_sql_flags():
    gateway_config = gateways_config_pb2.TGatewaysConfig()

    with open(yatest.common.source_path('yql/essentials/cfg/tests/gateways.conf')) as f:
        text_format.Merge(f.read(), gateway_config)

    if yql_get_param('SQL_FLAGS'):
        flags = yql_get_param('SQL_FLAGS').split(',')
        gateway_config.SqlCore.TranslationFlags.extend(flags)

    return gateway_config.SqlCore.TranslationFlags


try:
    SQL_FLAGS = get_sql_flags()
except BaseException:
    SQL_FLAGS = None


def recursive_glob(root, begin_template=None, end_template=None):
    for parent, dirs, files in os.walk(root):
        for filename in files:
            if begin_template is not None and not filename.startswith(begin_template):
                continue
            if end_template is not None and not filename.endswith(end_template):
                continue
            path = os.path.join(parent, filename)
            yield os.path.relpath(path, root)
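
# Example (illustrative, hypothetical paths): recursive_glob('/data/suite', end_template='.sql')
# yields matching file paths such as 'subdir/query.sql', relative to the given root.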


def pytest_generate_tests_by_template(template, metafunc, data_path):
    assert data_path is not None
    argvalues = []
    suites = [name for name in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, name))]
    for suite in suites:
        for case in sorted([sql_query_path[:-len(template)]
                            for sql_query_path in recursive_glob(os.path.join(data_path, suite), end_template=template)]):
            argvalues.append((suite, case))

    metafunc.parametrize(['suite', 'case'], argvalues)


def pytest_generate_tests_for_run(metafunc, template='.sql', suites=None, currentPart=0, partsCount=1, data_path=None, mode_expander=None):
    assert data_path is not None
    argvalues = []
    if not suites:
        suites = sorted([name for name in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, name))])
    for suite in suites:
        suite_dir = os.path.join(data_path, suite)

        # .sql's
        for case in sorted([sql_query_path[:-len(template)]
                            for sql_query_path in recursive_glob(suite_dir, end_template=template)]):
            case_program = case + template
            with open(os.path.join(suite_dir, case_program)) as f:
                if 'do not execute' in f.read():
                    continue

            # .cfg's
            configs = [
                cfg_file.replace(case + '-', '').replace('.cfg', '')
                for cfg_file in recursive_glob(suite_dir, begin_template=case + '-', end_template='.cfg')
            ]
            if os.path.exists(suite_dir + '/' + case + '.cfg'):
                configs.append('')

            to_append = []
            for cfg in sorted(configs):
                if _make_hash((suite, case, cfg)) % partsCount == currentPart:
                    to_append.append((suite, case, cfg))
            if not configs and _make_hash((suite, case, 'default.txt')) % partsCount == currentPart:
                to_append.append((suite, case, 'default.txt'))

            if mode_expander is None:
                argvalues += to_append
            else:
                argvalues += mode_expander(to_append)

    metafunc.parametrize(
        ['suite', 'case', 'cfg'] + (['what'] if mode_expander is not None else []),
        argvalues,
    )


def pytest_generate_tests_for_part(metafunc, currentPart, partsCount, data_path=None, template='.sql', mode_expander=None):
    return pytest_generate_tests_for_run(metafunc, currentPart=currentPart, partsCount=partsCount,
                                         data_path=data_path, template=template, mode_expander=mode_expander)
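
# Illustrative sketch (assumption, not part of this module): the generators above
# are intended to be driven from a suite's conftest.py via pytest's standard
# pytest_generate_tests hook, e.g.
#
#   # conftest.py (hypothetical)
#   from test_utils import pytest_generate_tests_for_run
#
#   DATA_PATH = yatest.common.source_path('path/to/suites')  # hypothetical path
#
#   def pytest_generate_tests(metafunc):
#       pytest_generate_tests_for_run(metafunc, template='.sql', data_path=DATA_PATH)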


def get_cfg_file(cfg, case):
    if cfg:
        return (case + '-' + cfg + '.cfg') if cfg != 'default.txt' else 'default.cfg'
    else:
        return case + '.cfg'
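
# Resulting file names (derived from the branches above, case name illustrative):
#   get_cfg_file('', 'mycase')            -> 'mycase.cfg'
#   get_cfg_file('opt', 'mycase')         -> 'mycase-opt.cfg'
#   get_cfg_file('default.txt', 'mycase') -> 'default.cfg'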


def validate_cfg(result):
    for r in result:
        assert r[0] in (
            "in",
            "in2",
            "out",
            "udf",
            "providers",
            "res",
            "mount",
            "canonize_peephole",
            "canonize_lineage",
            "peephole_use_blocks",
            "with_final_result_issues",
            "xfail",
            "pragma",
            "canonize_yt",
            "file",
            "http_file",
            "yt_file",
            "os",
            "param",
        ), "Unknown command in .cfg: %s" % (r[0])


def get_config(suite, case, cfg, data_path):
    assert data_path is not None
    result = []
    try:
        default_cfg = get_cfg_file('default.txt', case)
        inherit = ['canonize_peephole', 'canonize_lineage', 'peephole_use_blocks']
        with open(os.path.join(data_path, suite, default_cfg)) as cfg_file_content:
            result = [line.strip().split() for line in cfg_file_content.readlines() if line.strip() and line.strip().split()[0]]
        validate_cfg(result)
        result = [r for r in result if r[0] in inherit]
    except IOError:
        pass

    cfg_file = get_cfg_file(cfg, case)
    with open(os.path.join(data_path, suite, cfg_file)) as cfg_file_content:
        result = [line.strip().split() for line in cfg_file_content.readlines() if line.strip()] + result
    validate_cfg(result)
    return result


def load_json_file_strip_comments(path):
    with open(path) as file:
        return '\n'.join([line for line in file.readlines() if not line.startswith('#')])
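
# Example (illustrative): a parameter file may mix '#' comment lines with JSON, e.g.
#   # describes the parameter below
#   {"value": 1}
# load_json_file_strip_comments drops the lines starting with '#' so the remainder
# can be passed to json.loads (as done in get_parameters_json below).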


def get_parameters_files(suite, config, data_path):
    assert data_path is not None
    result = []
    for line in config:
        if len(line) != 3 or not line[0] == "param":
            continue
        result.append((line[1], os.path.join(data_path, suite, line[2])))

    return result


def get_parameters_json(suite, config, data_path):
    assert data_path is not None
    parameters_files = get_parameters_files(suite, config, data_path)
    data = {}
    for p in parameters_files:
        value_json = json.loads(load_json_file_strip_comments(p[1]))
        data[p[0]] = {'Data': value_json}
    return data
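
# Example (illustrative): a .cfg line ['param', 'x', 'x.json'] makes
# get_parameters_json return {'x': {'Data': <contents of data_path/suite/x.json
# parsed as JSON>}}.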


def output_dir(name):
    output_dir = yatest.common.output_path(name)
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    return output_dir


def run_sql_on_mr(name, query, kikimr):
    out_dir = output_dir(name)
    opt_file = os.path.join(out_dir, 'opt.yql')
    results_file = os.path.join(out_dir, 'results.yson')

    try:
        kikimr(
            'yql-exec -d 1 -P %s --sql --run --optimize -i /dev/stdin --oexpr %s --oresults %s' % (
                kikimr.yql_pool_id,
                opt_file,
                results_file
            ),
            stdin=query
        )
    except yatest.common.ExecutionError as e:
        runyqljob_result = e.execution_result
        assert 0, 'yql-exec finished with error: \n\n%s \n\non program: \n\n%s' % (
            runyqljob_result.std_err,
            query
        )

    return opt_file, results_file


def normalize_table(csv, fields_order=None):
    '''
    :param csv: table content
    :param fields_order: normal order of fields (default: the order of the table's header row)
    :return: normalized table content
    '''
    if not csv.strip():
        return ''

    headers = csv.splitlines()[0].strip().split(';')
    if fields_order is None:
        if len(set(headers)) < len(headers):
            # we have duplicates in case of joining tables, let's just cut headers and return as is
            return '\n'.join(csv.splitlines()[1:])
        fields_order = headers

    normalized = ''
    if any(field not in headers for field in fields_order):
        fields_order = sorted(headers)
    translator = {
        field: headers.index(field) for field in fields_order
    }

    def normalize_cell(s):
        if s == 't':
            return 'true'
        if s == 'f':
            return 'false'
        if '.' in s:
            try:
                f = float(s)
                return str(str(int(f)) if f.is_integer() else f)
            except ValueError:
                return s
        else:
            return s

    for line in csv.splitlines()[1:]:
        line = line.strip().split(';')
        normalized_cells = [normalize_cell(line[translator[field]]) for field in fields_order]
        normalized += '\n' + ';'.join(normalized_cells)

    return normalized.strip()
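
# Example (illustrative):
#   normalize_table('key;value\n1;t\n2;3.0')
# reorders cells into the header (or given fields_order) order, rewrites 't'/'f'
# as 'true'/'false' and drops the trailing '.0' of integral floats, returning
#   '1;true\n2;3'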


def replace_vars(sql_query, var_tag):
    """
    The SQL text can contain a comment like /* yt_local_var: VAR_NAME=VAR_VALUE */;
    every occurrence of VAR_NAME in the query is then replaced with VAR_VALUE.
    """
    vars = re.findall(r"\/\* {}: (.*)=(.*) \*\/".format(var_tag), sql_query)
    for var_name, var_value in vars:
        sql_query = re.sub(re.escape(var_name.strip()), var_value.strip(), sql_query)
    return sql_query
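
# Example (illustrative): for a query containing
#   /* yt_local_var: LIMIT_VALUE=10 */ SELECT * FROM Input LIMIT LIMIT_VALUE
# replace_vars(query, 'yt_local_var') replaces every occurrence of 'LIMIT_VALUE'
# (including the one inside the comment) with '10'.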