split_unittest.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
import argparse
import os
import shlex
import subprocess
import sys
import tempfile
  5. def parse_args():
  6. parser = argparse.ArgumentParser()
  7. parser.add_argument("--split-factor", type=int, default=0)
  8. parser.add_argument("--shard", type=int, default=0)
  9. parser.add_argument("--fork-mode", type=str, default="SEQUENTIAL")
  10. parser.add_argument("command", nargs=argparse.REMAINDER)
  11. return parser.parse_args()
  12. def get_sequential_chunk(tests, modulo, modulo_index):
  13. chunk_size = len(tests) // modulo
  14. not_used = len(tests) % modulo
  15. shift = chunk_size + (modulo_index < not_used)
  16. start = chunk_size * modulo_index + min(modulo_index, not_used)
  17. end = start + shift
  18. return [] if end > len(tests) else tests[start:end]
  19. def get_shuffled_chunk(tests, modulo, modulo_index):
  20. result_tests = []
  21. for i, test in enumerate(tests):
  22. if i % modulo == modulo_index:
  23. result_tests.append(test)
  24. return result_tests
  25. def list_tests(binary):
  26. with tempfile.NamedTemporaryFile() as tmpfile:
  27. cmd = [binary, "--list-verbose", "--list-path", tmpfile.name]
  28. subprocess.check_call(cmd)
  29. with open(tmpfile.name) as afile:
  30. lines = afile.read().strip().split("\n")
  31. lines = [x.strip() for x in lines]
  32. return [x for x in lines if x]
  33. def get_shard_tests(args):
  34. test_names = list_tests(args.command[0])
  35. test_names = sorted(test_names)
  36. if args.fork_mode == "MODULO":
  37. return get_shuffled_chunk(test_names, args.split_factor, args.shard)
  38. elif args.fork_mode == "SEQUENTIAL":
  39. return get_sequential_chunk(test_names, args.split_factor, args.shard)
  40. else:
  41. raise ValueError("detected unknown partition mode: {}".format(args.fork_mode))
  42. def get_shard_cmd_args(args):
  43. return ["+{}".format(x) for x in get_shard_tests(args)]
  44. def main():
  45. args = parse_args()
  46. if args.split_factor:
  47. shard_cmd = get_shard_cmd_args(args)
  48. if shard_cmd:
  49. cmd = args.command + shard_cmd
  50. else:
  51. print("No tests for {} shard".format(args.shard))
  52. return 0
  53. else:
  54. cmd = args.command
  55. rc = subprocess.call(cmd)
  56. if rc:
  57. print("Some tests failed. To reproduce run: {}".format(shlex.join(cmd)))
  58. return rc
  59. if __name__ == "__main__":
  60. exit(main())