123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- import yatest
- from yqlrun import YQLRun
- import cyson
- from multiprocessing.pool import ThreadPool
- import time
- import base64
- import binascii
- from yql_utils import get_param
def run_one(item):
    """Execute a single documentation example and time it.

    ``item`` is a ``(line, input, output, should_fail)`` tuple. The query is
    prefixed with the ``--!syntax_pg`` pragma and run through ``YQLRun``.

    Returns a 7-tuple
    ``(line, input, output, dom, error, elapsed_time, should_fail)``.
    On success ``dom`` is the cyson-parsed result and ``error`` is ``None``;
    if anything raises, ``dom`` is ``None`` and ``error`` is the exception.
    ``elapsed_time`` is measured in seconds in both cases.
    """
    line, query, expected, should_fail = item
    started = time.time()
    try:
        runner = YQLRun(
            prov='yt',
            use_sql2yql=False,
            cfg_dir='yql/essentials/cfg/udf_test',
            # UDF support is switched on only for queries that mention LIKE.
            support_udfs="LIKE" in query,
        )
        run_result = runner.yql_exec(
            program="--!syntax_pg\n" + query,
            run_sql=True,
            check_error=True,
        )
        dom = cyson.loads(run_result.results)
    except Exception as e:
        # Failures are handed back to the caller, which decides whether the
        # example was expected to fail.
        return (line, query, expected, None, e, time.time() - started, should_fail)
    return (line, query, expected, dom, None, time.time() - started, should_fail)
def convert_cell(cell, output):
    """Render one result cell as text comparable with the documented value.

    Args:
        cell: ``None``, a ``bytes`` value, or an indexable whose first item is
            base64-encoded data (presumably how non-UTF-8 values arrive in the
            parsed result -- confirm against the result writer).
        output: the expected text taken from the documentation.

    Returns:
        A ``(value, output)`` pair of strings to compare for equality.

    An expected value starting with ``~`` marks the example as "do not
    compare": both sides collapse to empty strings. This is checked before
    any conversion so that a cell which cannot be decoded does not raise for
    a value that would be discarded anyway.
    """
    if output.startswith("~"):
        return ('', '')

    if cell is None:
        value = 'NULL'
    elif isinstance(cell, bytes):
        if output.startswith("\\x"):
            # The documentation shows a hex literal, so render the raw bytes
            # the same way.
            value = "\\x" + binascii.hexlify(cell).decode("utf-8")
        else:
            value = cell.decode("utf-8")
    else:
        # Non-bytes cell: first element is base64 data, shown as hex.
        value = "\\x" + binascii.hexlify(base64.b64decode(cell[0])).decode("utf-8")
    return (value, output)
def convert_value(data, output):
    """Render a result row and its expected text as comparable strings.

    A row with exactly one cell is compared against the whole expected string.
    Otherwise the expected string is split on commas and paired with the cells
    positionally; pairing stops at the shorter of the two sequences.

    Returns a ``(value, output)`` pair; for multi-cell rows both parts are the
    per-cell results joined with commas.
    """
    if len(data) == 1:
        return convert_cell(data[0], output)

    actual_parts = []
    expected_parts = []
    for cell, expected in zip(data, output.split(",")):
        actual, normalized = convert_cell(cell, expected)
        actual_parts.append(actual)
        expected_parts.append(normalized)
    return (",".join(actual_parts), ",".join(expected_parts))
def _collect_doc_cases(doc_data, skip_before, stop_at):
    """Extract runnable examples from the lines of the functions reference.

    Only lines inside ```` ```sql ```` fences that contain ``→`` are examples:
    the text left of the arrow is the query, the text right of it the expected
    result.

    Args:
        doc_data: the document as a list of raw lines.
        skip_before: if not ``None``, ignore everything up to and including
            the ``## <skip_before>`` heading.
        stop_at: if not ``None``, stop at the ``## <stop_at>`` heading.

    Returns:
        ``(queue, total, skipped)`` where ``queue`` holds
        ``(line, query, expected, should_fail)`` tuples (``expected`` is a
        string, or a list of row strings for ``[ ... ]`` blocks), ``total`` is
        the number of examples seen and ``skipped`` the number marked as
        expected failures.
    """
    queue = []
    total = 0
    skipped = 0
    in_code = False
    set_of = None  # rows collected so far for an open "[ ... ]" block
    multiline = None  # pieces collected so far for an open """ ... """ block
    original_line = None
    original_query = None
    skip_in_progress = skip_before is not None
    should_fail = None
    for raw_line in doc_data:
        line = raw_line.strip()
        if stop_at is not None and line.startswith("## " + stop_at):
            break
        if skip_in_progress:
            if line.startswith("## " + skip_before):
                skip_in_progress = False
            continue
        if set_of is not None:
            # Inside a multi-row expectation: one row per line until "]".
            if line.startswith("]"):
                queue.append((original_line, original_query, set_of, should_fail))
                set_of = None
                original_line = None
                original_query = None
            else:
                set_of.append(line)
            continue
        if multiline is not None:
            # Inside a triple-quoted expectation: keep raw lines (with their
            # newlines) until the closing quotes.
            if line.endswith('"""'):
                multiline.append(line[0:line.index('"""')])
                queue.append((original_line, original_query, "".join(multiline), should_fail))
                multiline = None
                original_line = None
                original_query = None
            else:
                multiline.append(raw_line)
            continue
        if line.startswith("```sql"):
            in_code = True
            continue
        elif in_code and line.startswith("```"):
            in_code = False
            continue
        if not in_code:
            continue
        if "→" not in line:
            continue
        total += 1
        # Move a "~" written before the arrow onto the expected value, where
        # the comparison code looks for it.
        line = line.replace("~→ ", "→ ~")
        query, output = [x.strip() for x in line.split("→")]
        should_fail = False
        # A leading "#" not followed by a space marks an example that is
        # allowed to fail; it is still executed.
        if query.startswith("#") and not query.startswith("# "):
            should_fail = True
            skipped += 1
            query = query[1:]
        if not query.startswith("SELECT"):
            query = "SELECT " + query
        if "/*" in output:
            # Drop a trailing comment from the expected value.
            output = output[:output.index("/*")].strip()
        if output.startswith('"""'):
            multiline = [output[output.index('"""') + 3:] + "\n"]
            original_line = line
            original_query = query
            continue
        elif output.startswith("'") and output.endswith("'"):
            output = output[1:-1]
        elif output.endswith("["):
            set_of = []
            original_line = line
            original_query = query
            continue
        queue.append((line, query, output, should_fail))
    return queue, total, skipped


def test_doc():
    """Run every SQL example from the PostgreSQL functions reference.

    The optional ``skip_before`` / ``stop_at`` test parameters limit the run
    to a range of ``##`` sections. Examples are executed in a pool of 16
    threads via ``run_one``. A failing or mismatching example raises unless it
    is marked as an expected failure, in which case it is only counted.
    Summary counters are printed at the end.
    """
    skip_before = get_param("skip_before")
    stop_at = get_param("stop_at")
    if skip_before is not None:
        print("WILL SKIP TESTS BEFORE: ", skip_before)
    if stop_at is not None:
        print("WILL STOP AT: ", stop_at)
    doc_src = yatest.common.source_path("contrib/ydb/docs/ru/core/postgresql/_includes/functions.md")
    # The document contains non-ASCII text (including the "→" separator), so
    # do not depend on the locale's default encoding.
    with open(doc_src, encoding="utf-8") as f:
        doc_data = f.readlines()

    queue, total, skipped = _collect_doc_cases(doc_data, skip_before, stop_at)

    skipped_exception = 0
    skipped_mismatch_res = 0
    skipped_same_res = 0
    with ThreadPool(16) as pool:
        for res in pool.map(run_one, queue):
            line, query, output, dom, e, elapsed_time, should_fail = res
            print("TEST: " + line)
            print("INPUT: ", query)
            print("OUTPUT: ", output)
            print("ELAPSED: ", elapsed_time)
            if e is not None:
                if not should_fail:
                    raise e
                else:
                    skipped_exception += 1
                    print("SKIPPED, EXCEPTION")
            else:
                data = dom[0][b"Write"][0][b"Data"]
                print("DATA: ", data)
                if isinstance(output, list):
                    # NOTE(review): zip() stops at the shorter sequence, so a
                    # different number of returned and expected rows is not
                    # reported as a mismatch -- confirm whether intended.
                    pairs = [convert_value(x[0], x[1]) for x in zip(data, output)]
                    value = [x[0] for x in pairs]
                    output = [x[1] for x in pairs]
                else:
                    value, output = convert_value(data[0], output)
                print("VALUE: ", value)
                try:
                    assert value == output, f"Expected '{output}' but got '{value}', test: {line}"
                except Exception as err:
                    if should_fail:
                        # Remember the mismatch so the example is not also
                        # counted as "same result" below.
                        e = err
                        skipped_mismatch_res += 1
                        print("SKIPPED, MISMATCH RESULT")
                    else:
                        raise
            if should_fail and e is None:
                print("SKIPPED, SAME RESULT")
                skipped_same_res += 1
    print("TOTAL TESTS: ", total)
    print("SKIPPED TESTS: ", skipped)
    print("SKIPPED TESTS WITH EXCEPTION: ", skipped_exception)
    print("SKIPPED TESTS WITH MISMATCH RESULT: ", skipped_mismatch_res)
    print("SKIPPED TESTS WITH SAME RESULT: ", skipped_same_res)
|