123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
- import json
- import os
- import six
- import re
- import yatest.common
- import zlib
- from yql_utils import get_param as yql_get_param
- from google.protobuf import text_format
- import yql.essentials.providers.common.proto.gateways_config_pb2 as gateways_config_pb2
def _binary_path_or_none(relative_path):
    """Return ``yatest.common.binary_path(relative_path)``, or None if the lookup raises."""
    try:
        return yatest.common.binary_path(relative_path)
    except BaseException:
        return None


# Tool binaries are optional: each constant is None when its path cannot be resolved.
SQLRUN_PATH = _binary_path_or_none('yql/essentials/tools/sql2yql/sql2yql')
YQLRUN_PATH = _binary_path_or_none('yql/tools/yqlrun/yqlrun')
def _make_hash(x):
    """Return an integer hash of ``x``.

    On Python 2 this is the builtin ``hash(x)``; otherwise it is the CRC32 of the
    UTF-8 encoded ``repr(x)``.
    """
    if not six.PY2:
        encoded = repr(x).encode("utf-8")
        return zlib.crc32(encoded)
    return hash(x)
def get_sql_flags():
    """Return the SQL translation flags of the test gateways config.

    The text-format config ``yql/essentials/cfg/tests/gateways.conf`` is merged into a
    ``TGatewaysConfig`` message. When the ``SQL_FLAGS`` parameter is set, its comma-separated
    values are appended to ``SqlCore.TranslationFlags`` before that field is returned.
    """
    config = gateways_config_pb2.TGatewaysConfig()
    conf_path = yatest.common.source_path('yql/essentials/cfg/tests/gateways.conf')
    with open(conf_path) as conf_file:
        text_format.Merge(conf_file.read(), config)

    extra_flags = yql_get_param('SQL_FLAGS')
    if extra_flags:
        config.SqlCore.TranslationFlags.extend(extra_flags.split(','))
    return config.SqlCore.TranslationFlags
# Translation flags are resolved once at import time; SQL_FLAGS stays None when
# get_sql_flags() raises.
SQL_FLAGS = None
try:
    SQL_FLAGS = get_sql_flags()
except BaseException:
    pass
def recursive_glob(root, begin_template=None, end_template=None):
    """Yield the paths, relative to ``root``, of matching files found anywhere under ``root``.

    :param root: directory to walk
    :param begin_template: if not None, only file names starting with this string match
    :param end_template: if not None, only file names ending with this string match
    :return: generator of relative paths, in ``os.walk`` order

    Both filters are applied to the bare file name, not to the relative path.
    """
    for parent, _dirs, files in os.walk(root):
        for filename in files:
            matches = (
                (begin_template is None or filename.startswith(begin_template))
                and (end_template is None or filename.endswith(end_template))
            )
            if matches:
                yield os.path.relpath(os.path.join(parent, filename), root)
def pytest_generate_tests_by_template(template, metafunc, data_path):
    """Parametrize ``metafunc`` with every ``(suite, case)`` pair found under ``data_path``.

    A suite is an immediate subdirectory of ``data_path``. A case is the path, relative to the
    suite directory, of a file whose name ends with ``template``, with that suffix removed.

    :param template: file name suffix identifying case files (e.g. '.sql')
    :param metafunc: pytest metafunc object; its ``parametrize`` is called exactly once
    :param data_path: root directory containing the suites; must not be None
    """
    assert data_path is not None
    # os.listdir returns entries in arbitrary order; sort the suites so the parametrization
    # order is reproducible (pytest_generate_tests_for_run sorts them as well).
    suites = sorted(name for name in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, name)))
    argvalues = []
    for suite in suites:
        suite_dir = os.path.join(data_path, suite)
        cases = sorted(
            # Explicit end index: path[:-len(template)] would yield '' for an empty template.
            path[:len(path) - len(template)]
            for path in recursive_glob(suite_dir, end_template=template)
        )
        argvalues.extend((suite, case) for case in cases)
    metafunc.parametrize(['suite', 'case'], argvalues)
def pytest_generate_tests_for_run(metafunc, template='.sql', suites=None, currentPart=0, partsCount=1, data_path=None, mode_expander=None):
    """Parametrize ``metafunc`` with the ``(suite, case, cfg)`` triples assigned to one part.

    For every suite directory, each file ending with ``template`` is a case, unless its text
    contains 'do not execute'. The configs of a case are the ``<case>-<cfg>.cfg`` files plus ''
    when ``<case>.cfg`` exists; a case without any config gets the single config 'default.txt'.
    A triple is kept only if ``_make_hash(triple) % partsCount == currentPart``.

    :param metafunc: pytest metafunc object; its ``parametrize`` is called exactly once
    :param template: file name suffix identifying case files
    :param suites: suite names to scan; when empty, all subdirectories of ``data_path`` (sorted)
    :param currentPart: index of the part to generate tests for
    :param partsCount: total number of parts
    :param data_path: root directory containing the suites; must not be None
    :param mode_expander: optional callable applied to each case's list of triples; when given,
        an extra 'what' argument name is parametrized
    """
    assert data_path is not None

    def belongs_to_part(suite, case, cfg):
        return _make_hash((suite, case, cfg)) % partsCount == currentPart

    if not suites:
        suites = sorted(
            entry for entry in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, entry))
        )

    argvalues = []
    for suite in suites:
        suite_dir = os.path.join(data_path, suite)

        # programs
        cases = sorted(
            program_path[:-len(template)]
            for program_path in recursive_glob(suite_dir, end_template=template)
        )
        for case in cases:
            with open(os.path.join(suite_dir, case + template)) as program:
                disabled = 'do not execute' in program.read()
            if disabled:
                continue

            # configs
            cfg_prefix = case + '-'
            configs = [
                cfg_path.replace(cfg_prefix, '').replace('.cfg', '')
                for cfg_path in recursive_glob(suite_dir, begin_template=cfg_prefix, end_template='.cfg')
            ]
            if os.path.exists(suite_dir + '/' + case + '.cfg'):
                configs.append('')

            if configs:
                selected = [
                    (suite, case, cfg)
                    for cfg in sorted(configs)
                    if belongs_to_part(suite, case, cfg)
                ]
            elif belongs_to_part(suite, case, 'default.txt'):
                selected = [(suite, case, 'default.txt')]
            else:
                selected = []

            # The expander is invoked for every case, even when nothing was selected.
            if mode_expander is None:
                argvalues += selected
            else:
                argvalues += mode_expander(selected)

    argnames = ['suite', 'case', 'cfg']
    if mode_expander is not None:
        argnames = argnames + ['what']
    metafunc.parametrize(argnames, argvalues)
def pytest_generate_tests_for_part(metafunc, currentPart, partsCount, data_path=None, template='.sql', mode_expander=None):
    """Call ``pytest_generate_tests_for_run`` for one part, scanning all suites.

    Takes the same arguments as ``pytest_generate_tests_for_run`` except ``suites``, with the
    part selection given positionally, and returns that function's result.
    """
    run_options = dict(
        currentPart=currentPart,
        partsCount=partsCount,
        data_path=data_path,
        template=template,
        mode_expander=mode_expander,
    )
    return pytest_generate_tests_for_run(metafunc, **run_options)
def get_cfg_file(cfg, case):
    """Return the config file name for a case.

    :param cfg: config name; an empty value selects ``<case>.cfg``, 'default.txt' selects
        'default.cfg', anything else selects ``<case>-<cfg>.cfg``
    :param case: case name
    :return: config file name
    """
    if not cfg:
        return case + '.cfg'
    if cfg == 'default.txt':
        return 'default.cfg'
    return case + '-' + cfg + '.cfg'
def validate_cfg(result):
    """Assert that every parsed config line starts with a known command.

    :param result: iterable of token lists; the first token of each is the command name
    :raises AssertionError: on the first entry whose command is not recognized
    """
    known_commands = (
        "in",
        "in2",
        "out",
        "udf",
        "providers",
        "res",
        "mount",
        "canonize_peephole",
        "canonize_lineage",
        "peephole_use_blocks",
        "with_final_result_issues",
        "xfail",
        "pragma",
        "canonize_yt",
        "file",
        "http_file",
        "yt_file",
        "os",
        "param",
    )
    for entry in result:
        command = entry[0]
        assert command in known_commands, "Unknown command in .cfg: %s" % (command)
def get_config(suite, case, cfg, data_path):
    """Load the parsed config of a case, followed by the commands inherited from the default config.

    Each non-blank line of a config file becomes a list of whitespace-separated tokens. From the
    suite's default config only 'canonize_peephole', 'canonize_lineage' and 'peephole_use_blocks'
    entries are inherited; they are placed after the entries of the case's own config. A default
    config that cannot be read is ignored, whereas the case's own config must be readable.

    :param suite: suite directory name under ``data_path``
    :param case: case name
    :param cfg: config name as understood by ``get_cfg_file``
    :param data_path: root directory containing the suites; must not be None
    :return: list of token lists
    :raises AssertionError: if either file contains an unknown command
    """
    assert data_path is not None

    def read_commands(cfg_name):
        with open(os.path.join(data_path, suite, cfg_name)) as cfg_lines:
            tokenized = (line.split() for line in cfg_lines)
            return [tokens for tokens in tokenized if tokens]

    inheritable = ('canonize_peephole', 'canonize_lineage', 'peephole_use_blocks')
    try:
        defaults = read_commands(get_cfg_file('default.txt', case))
    except IOError:
        inherited = []
    else:
        # The whole default config is validated before non-inheritable entries are dropped.
        validate_cfg(defaults)
        inherited = [command for command in defaults if command[0] in inheritable]

    result = read_commands(get_cfg_file(cfg, case)) + inherited
    validate_cfg(result)
    return result
def load_json_file_strip_comments(path):
    """Return the text of ``path`` without the lines that start with '#'.

    Only lines whose very first character is '#' are dropped. The remaining lines keep their
    own line endings and are additionally joined with a newline, so the result contains an
    extra blank line between consecutive lines.
    """
    with open(path) as json_file:
        kept_lines = [line for line in json_file if not line.startswith('#')]
    return '\n'.join(kept_lines)
def get_parameters_files(suite, config, data_path):
    """Collect the parameter files declared in a parsed config.

    :param suite: suite directory name under ``data_path``
    :param config: list of token lists; only three-token entries whose first token is "param"
        are used, in the form ``["param", name, file_name]``
    :param data_path: root directory containing the suites; must not be None
    :return: list of ``(name, path)`` tuples where path is ``data_path/suite/file_name``
    """
    assert data_path is not None
    return [
        (entry[1], os.path.join(data_path, suite, entry[2]))
        for entry in config
        if len(entry) == 3 and entry[0] == "param"
    ]
def get_parameters_json(suite, config, data_path):
    """Load every parameter file declared in ``config`` and wrap its JSON value.

    :param suite: suite directory name under ``data_path``
    :param config: parsed config, as accepted by ``get_parameters_files``
    :param data_path: root directory containing the suites; must not be None
    :return: dict mapping each parameter name to ``{'Data': <parsed JSON of its file>}``;
        when a name is declared more than once the last declaration wins
    """
    assert data_path is not None
    return {
        name: {'Data': json.loads(load_json_file_strip_comments(path))}
        for name, path in get_parameters_files(suite, config, data_path)
    }
def output_dir(name):
    """Return ``yatest.common.output_path(name)``, creating that directory if it does not exist."""
    path = yatest.common.output_path(name)
    already_exists = os.path.isdir(path)
    if not already_exists:
        os.mkdir(path)
    return path
def run_sql_on_mr(name, query, kikimr):
    """Run ``query`` through ``yql-exec`` via the ``kikimr`` callable.

    :param name: name of the output directory (see ``output_dir``) receiving the artifacts
    :param query: SQL text passed to the command as stdin
    :param kikimr: callable taking a command line and a ``stdin`` keyword; its ``yql_pool_id``
        attribute is passed to the command as ``-P``
    :return: tuple ``(opt_file, results_file)`` with the paths given to ``--oexpr`` and ``--oresults``
    :raises AssertionError: when the command fails with ``yatest.common.ExecutionError``; the
        message contains the command's stderr and the query
    """
    out_dir = output_dir(name)
    opt_file, results_file = (
        os.path.join(out_dir, file_name) for file_name in ('opt.yql', 'results.yson')
    )

    try:
        command = 'yql-exec -d 1 -P %s --sql --run --optimize -i /dev/stdin --oexpr %s --oresults %s' % (
            kikimr.yql_pool_id,
            opt_file,
            results_file
        )
        kikimr(command, stdin=query)
    except yatest.common.ExecutionError as error:
        failed_run = error.execution_result
        assert 0, 'yql-exec finished with error: \n\n%s \n\non program: \n\n%s' % (
            failed_run.std_err,
            query
        )
    return opt_file, results_file
def normalize_table(csv, fields_order=None):
    '''
    Normalize a ';'-separated table so that tables can be compared as text.

    The first line is the header; it is not part of the result. Every data row is stripped,
    split on ';' and re-emitted with its columns arranged according to fields_order. Cells
    't' / 'f' become 'true' / 'false', and cells containing '.' that parse as floats are
    rewritten canonically (integral values lose their fractional part, e.g. '1.0' -> '1').

    :param csv: table content
    :param fields_order: desired order of fields (default: the order of the header). If any
        requested field is missing from the header, all header fields in sorted order are
        used instead.
    :return: normalized table content; '' for blank input. If fields_order is None and the
        header contains duplicate names, the data rows are returned as is.
    '''
    if not csv.strip():
        return ''

    # Split once and reuse the result for both the header and the data rows.
    lines = csv.splitlines()
    headers = lines[0].strip().split(';')
    rows = lines[1:]

    if fields_order is None:
        if len(set(headers)) < len(headers):
            # we have duplicates in case of joining tables, let's just cut headers and return as is
            return '\n'.join(rows)
        fields_order = headers

    if any(field not in headers for field in fields_order):
        fields_order = sorted(headers)

    # Resolve each output field to its column position once, not once per row.
    columns = [headers.index(field) for field in fields_order]

    def normalize_cell(s):
        if s == 't':
            return 'true'
        if s == 'f':
            return 'false'
        if '.' not in s:
            return s
        try:
            f = float(s)
        except ValueError:
            return s
        return str(int(f)) if f.is_integer() else str(f)

    # Collect rows and join once; the previous `normalized += ...` loop was quadratic.
    normalized_rows = []
    for row in rows:
        cells = row.strip().split(';')
        normalized_rows.append(';'.join(normalize_cell(cells[column]) for column in columns))
    return '\n'.join(normalized_rows).strip()
def replace_vars(sql_query, var_tag):
    """
    Sql can contain comment like /* yt_local_var: VAR_NAME=VAR_VALUE */
    it will replace VAR_NAME with VAR_VALUE within sql query

    :param sql_query: query text, possibly containing ``/* <var_tag>: NAME=VALUE */`` comments
    :param var_tag: tag that marks a variable comment; it is inserted into the search pattern
        as is, so regex metacharacters in it keep their regex meaning
    :return: the query with every occurrence of each declared name replaced by its value,
        including the occurrence inside the declaring comment itself

    Names and values are stripped of surrounding whitespace, and declarations are applied in
    the order they appear in the query.
    """
    declarations = re.findall(r"\/\* {}: (.*)=(.*) \*\/".format(var_tag), sql_query)
    for var_name, var_value in declarations:
        # Literal substitution: re.sub would treat the value as a replacement template and
        # fail on (or misinterpret) backslashes such as '\d' or '\1' inside it.
        sql_query = sql_query.replace(var_name.strip(), var_value.strip())
    return sql_query
|