test_splitter.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # coding: utf-8
  2. import collections
  3. def flatten_tests(test_classes):
  4. """
  5. >>> test_classes = {x: [x] for x in range(5)}
  6. >>> flatten_tests(test_classes)
  7. [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
  8. >>> test_classes = {x: [x + 1, x + 2] for x in range(2)}
  9. >>> flatten_tests(test_classes)
  10. [(0, 1), (0, 2), (1, 2), (1, 3)]
  11. """
  12. tests = []
  13. for class_name, test_names in test_classes.items():
  14. tests += [(class_name, test_name) for test_name in test_names]
  15. return tests
  16. def get_sequential_chunk(tests, modulo, modulo_index, is_sorted=False):
  17. """
  18. >>> get_sequential_chunk(range(10), 4, 0)
  19. [0, 1, 2]
  20. >>> get_sequential_chunk(range(10), 4, 1)
  21. [3, 4, 5]
  22. >>> get_sequential_chunk(range(10), 4, 2)
  23. [6, 7]
  24. >>> get_sequential_chunk(range(10), 4, 3)
  25. [8, 9]
  26. >>> get_sequential_chunk(range(10), 4, 4)
  27. []
  28. >>> get_sequential_chunk(range(10), 4, 5)
  29. []
  30. """
  31. if not is_sorted:
  32. tests = sorted(tests)
  33. chunk_size = len(tests) // modulo
  34. not_used = len(tests) % modulo
  35. shift = chunk_size + (modulo_index < not_used)
  36. start = chunk_size * modulo_index + min(modulo_index, not_used)
  37. end = start + shift
  38. return [] if end > len(tests) else tests[start:end]
  39. def get_shuffled_chunk(tests, modulo, modulo_index, is_sorted=False):
  40. """
  41. >>> get_shuffled_chunk(range(10), 4, 0)
  42. [0, 4, 8]
  43. >>> get_shuffled_chunk(range(10), 4, 1)
  44. [1, 5, 9]
  45. >>> get_shuffled_chunk(range(10), 4, 2)
  46. [2, 6]
  47. >>> get_shuffled_chunk(range(10), 4, 3)
  48. [3, 7]
  49. >>> get_shuffled_chunk(range(10), 4, 4)
  50. []
  51. >>> get_shuffled_chunk(range(10), 4, 5)
  52. []
  53. """
  54. if not is_sorted:
  55. tests = sorted(tests)
  56. result_tests = []
  57. for i, test in enumerate(tests):
  58. if i % modulo == modulo_index:
  59. result_tests.append(test)
  60. return result_tests
  61. def get_splitted_tests(test_entities, modulo, modulo_index, partition_mode, is_sorted=False):
  62. if partition_mode == 'SEQUENTIAL':
  63. return get_sequential_chunk(test_entities, modulo, modulo_index, is_sorted)
  64. elif partition_mode == 'MODULO':
  65. return get_shuffled_chunk(test_entities, modulo, modulo_index, is_sorted)
  66. else:
  67. raise ValueError("detected unknown partition mode: {}".format(partition_mode))
  68. def filter_tests_by_modulo(test_classes, modulo, modulo_index, split_by_tests, partition_mode="SEQUENTIAL"):
  69. """
  70. >>> test_classes = {x: [x] for x in range(20)}
  71. >>> filter_tests_by_modulo(test_classes, 4, 0, False)
  72. {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]}
  73. >>> filter_tests_by_modulo(test_classes, 4, 1, False)
  74. {8: [8], 9: [9], 5: [5], 6: [6], 7: [7]}
  75. >>> filter_tests_by_modulo(test_classes, 4, 2, False)
  76. {10: [10], 11: [11], 12: [12], 13: [13], 14: [14]}
  77. >>> dict(filter_tests_by_modulo(test_classes, 4, 0, True))
  78. {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]}
  79. >>> dict(filter_tests_by_modulo(test_classes, 4, 1, True))
  80. {8: [8], 9: [9], 5: [5], 6: [6], 7: [7]}
  81. """
  82. if split_by_tests:
  83. tests = get_splitted_tests(flatten_tests(test_classes), modulo, modulo_index, partition_mode)
  84. test_classes = collections.defaultdict(list)
  85. for class_name, test_name in tests:
  86. test_classes[class_name].append(test_name)
  87. return test_classes
  88. else:
  89. target_classes = get_splitted_tests(test_classes, modulo, modulo_index, partition_mode)
  90. return {class_name: test_classes[class_name] for class_name in target_classes}