split_unittest.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import argparse
  2. import os
  3. import tempfile
  4. import shlex
  5. import subprocess
  6. def parse_args():
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument("--split-factor", type=int, default=0)
  9. parser.add_argument("--shard", type=int, default=0)
  10. parser.add_argument("--fork-mode", type=str, default="SEQUENTIAL")
  11. parser.add_argument("command", nargs=argparse.REMAINDER)
  12. return parser.parse_args()
  13. def get_sequential_chunk(tests, modulo, modulo_index):
  14. chunk_size = len(tests) // modulo
  15. not_used = len(tests) % modulo
  16. shift = chunk_size + (modulo_index < not_used)
  17. start = chunk_size * modulo_index + min(modulo_index, not_used)
  18. end = start + shift
  19. return [] if end > len(tests) else tests[start:end]
  20. def get_shuffled_chunk(tests, modulo, modulo_index):
  21. result_tests = []
  22. for i, test in enumerate(tests):
  23. if i % modulo == modulo_index:
  24. result_tests.append(test)
  25. return result_tests
  26. def list_tests(binary):
  27. # can't use NamedTemporaryFile or mkstemp because of child process access issues on Windows
  28. # https://stackoverflow.com/questions/66744497/python-tempfile-namedtemporaryfile-cant-use-generated-tempfile
  29. with tempfile.TemporaryDirectory() as tmp_dir:
  30. list_file = os.path.join(tmp_dir, 'list')
  31. cmd = [binary, "--list-verbose", "--list-path", list_file]
  32. subprocess.check_call(cmd)
  33. with open(list_file) as afile:
  34. lines = afile.read().strip().split("\n")
  35. lines = [x.strip() for x in lines]
  36. return [x for x in lines if x]
  37. def get_shard_tests(args):
  38. test_names = list_tests(args.command[0])
  39. test_names = sorted(test_names)
  40. if args.fork_mode == "MODULO":
  41. return get_shuffled_chunk(test_names, args.split_factor, args.shard)
  42. elif args.fork_mode == "SEQUENTIAL":
  43. return get_sequential_chunk(test_names, args.split_factor, args.shard)
  44. else:
  45. raise ValueError("detected unknown partition mode: {}".format(args.fork_mode))
  46. def get_shard_cmd_args(args):
  47. return ["+{}".format(x) for x in get_shard_tests(args)]
  48. def main():
  49. args = parse_args()
  50. if args.split_factor:
  51. shard_cmd = get_shard_cmd_args(args)
  52. if shard_cmd:
  53. cmd = args.command + shard_cmd
  54. else:
  55. print("No tests for {} shard".format(args.shard))
  56. return 0
  57. else:
  58. cmd = args.command
  59. rc = subprocess.call(cmd)
  60. if rc:
  61. print("Some tests failed. To reproduce run: {}".format(shlex.join(cmd)))
  62. return rc
  63. if __name__ == "__main__":
  64. exit(main())