split_unittest.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import argparse
  2. import tempfile
  3. import shlex
  4. import subprocess
  5. def parse_args():
  6. parser = argparse.ArgumentParser()
  7. parser.add_argument("--split-factor", type=int, default=0)
  8. parser.add_argument("--shard", type=int, default=0)
  9. parser.add_argument("command", nargs=argparse.REMAINDER)
  10. return parser.parse_args()
  11. def list_tests(binary):
  12. with tempfile.NamedTemporaryFile() as tmpfile:
  13. cmd = [binary, "--list-verbose", "--list-path", tmpfile.name]
  14. subprocess.check_call(cmd)
  15. with open(tmpfile.name) as afile:
  16. lines = afile.read().strip().split("\n")
  17. lines = [x.strip() for x in lines]
  18. return [x for x in lines if x]
  19. def get_shard_tests(args):
  20. test_names = list_tests(args.command[0])
  21. test_names = sorted(test_names)
  22. chunk_size = len(test_names) // args.split_factor
  23. not_used = len(test_names) % args.split_factor
  24. shift = chunk_size + (args.shard < not_used)
  25. start = chunk_size * args.shard + min(args.shard, not_used)
  26. end = start + shift
  27. return [] if end > len(test_names) else test_names[start:end]
  28. def get_shard_cmd_args(args):
  29. return ["+{}".format(x) for x in get_shard_tests(args)]
  30. def main():
  31. args = parse_args()
  32. if args.split_factor:
  33. shard_cmd = get_shard_cmd_args(args)
  34. if shard_cmd:
  35. cmd = args.command + shard_cmd
  36. else:
  37. print("No tests for {} shard".format(args.shard))
  38. return 0
  39. else:
  40. cmd = args.command
  41. rc = subprocess.call(cmd)
  42. if rc:
  43. print("Some tests failed. To reproduce run: {}".format(shlex.join(cmd)))
  44. return rc
  45. if __name__ == "__main__":
  46. exit(main())