test_recipes.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170
  1. from decimal import Decimal
  2. from doctest import DocTestSuite
  3. from fractions import Fraction
  4. from functools import reduce
  5. from itertools import combinations, count, permutations
  6. from operator import mul
  7. from math import factorial
  8. from sys import version_info
  9. from unittest import TestCase, skipIf
  10. import more_itertools as mi
  11. def load_tests(loader, tests, ignore):
  12. # Add the doctests
  13. tests.addTests(DocTestSuite('more_itertools.recipes'))
  14. return tests
  15. class TakeTests(TestCase):
  16. """Tests for ``take()``"""
  17. def test_simple_take(self):
  18. """Test basic usage"""
  19. t = mi.take(5, range(10))
  20. self.assertEqual(t, [0, 1, 2, 3, 4])
  21. def test_null_take(self):
  22. """Check the null case"""
  23. t = mi.take(0, range(10))
  24. self.assertEqual(t, [])
  25. def test_negative_take(self):
  26. """Make sure taking negative items results in a ValueError"""
  27. self.assertRaises(ValueError, lambda: mi.take(-3, range(10)))
  28. def test_take_too_much(self):
  29. """Taking more than an iterator has remaining should return what the
  30. iterator has remaining.
  31. """
  32. t = mi.take(10, range(5))
  33. self.assertEqual(t, [0, 1, 2, 3, 4])
  34. class TabulateTests(TestCase):
  35. """Tests for ``tabulate()``"""
  36. def test_simple_tabulate(self):
  37. """Test the happy path"""
  38. t = mi.tabulate(lambda x: x)
  39. f = tuple([next(t) for _ in range(3)])
  40. self.assertEqual(f, (0, 1, 2))
  41. def test_count(self):
  42. """Ensure tabulate accepts specific count"""
  43. t = mi.tabulate(lambda x: 2 * x, -1)
  44. f = (next(t), next(t), next(t))
  45. self.assertEqual(f, (-2, 0, 2))
  46. class TailTests(TestCase):
  47. """Tests for ``tail()``"""
  48. def test_iterator_greater(self):
  49. """Length of iterator is greater than requested tail"""
  50. self.assertEqual(list(mi.tail(3, iter('ABCDEFG'))), list('EFG'))
  51. def test_iterator_equal(self):
  52. """Length of iterator is equal to the requested tail"""
  53. self.assertEqual(list(mi.tail(7, iter('ABCDEFG'))), list('ABCDEFG'))
  54. def test_iterator_less(self):
  55. """Length of iterator is less than requested tail"""
  56. self.assertEqual(list(mi.tail(8, iter('ABCDEFG'))), list('ABCDEFG'))
  57. def test_sized_greater(self):
  58. """Length of sized iterable is greater than requested tail"""
  59. self.assertEqual(list(mi.tail(3, 'ABCDEFG')), list('EFG'))
  60. def test_sized_equal(self):
  61. """Length of sized iterable is less than requested tail"""
  62. self.assertEqual(list(mi.tail(7, 'ABCDEFG')), list('ABCDEFG'))
  63. def test_sized_less(self):
  64. """Length of sized iterable is less than requested tail"""
  65. self.assertEqual(list(mi.tail(8, 'ABCDEFG')), list('ABCDEFG'))
  66. class ConsumeTests(TestCase):
  67. """Tests for ``consume()``"""
  68. def test_sanity(self):
  69. """Test basic functionality"""
  70. r = (x for x in range(10))
  71. mi.consume(r, 3)
  72. self.assertEqual(3, next(r))
  73. def test_null_consume(self):
  74. """Check the null case"""
  75. r = (x for x in range(10))
  76. mi.consume(r, 0)
  77. self.assertEqual(0, next(r))
  78. def test_negative_consume(self):
  79. """Check that negative consumption throws an error"""
  80. r = (x for x in range(10))
  81. self.assertRaises(ValueError, lambda: mi.consume(r, -1))
  82. def test_total_consume(self):
  83. """Check that iterator is totally consumed by default"""
  84. r = (x for x in range(10))
  85. mi.consume(r)
  86. self.assertRaises(StopIteration, lambda: next(r))
  87. class NthTests(TestCase):
  88. """Tests for ``nth()``"""
  89. def test_basic(self):
  90. """Make sure the nth item is returned"""
  91. l = range(10)
  92. for i, v in enumerate(l):
  93. self.assertEqual(mi.nth(l, i), v)
  94. def test_default(self):
  95. """Ensure a default value is returned when nth item not found"""
  96. l = range(3)
  97. self.assertEqual(mi.nth(l, 100, "zebra"), "zebra")
  98. def test_negative_item_raises(self):
  99. """Ensure asking for a negative item raises an exception"""
  100. self.assertRaises(ValueError, lambda: mi.nth(range(10), -3))
  101. class AllEqualTests(TestCase):
  102. """Tests for ``all_equal()``"""
  103. def test_true(self):
  104. """Everything is equal"""
  105. self.assertTrue(mi.all_equal('aaaaaa'))
  106. self.assertTrue(mi.all_equal([0, 0, 0, 0]))
  107. def test_false(self):
  108. """Not everything is equal"""
  109. self.assertFalse(mi.all_equal('aaaaab'))
  110. self.assertFalse(mi.all_equal([0, 0, 0, 1]))
  111. def test_tricky(self):
  112. """Not everything is identical, but everything is equal"""
  113. items = [1, complex(1, 0), 1.0]
  114. self.assertTrue(mi.all_equal(items))
  115. def test_empty(self):
  116. """Return True if the iterable is empty"""
  117. self.assertTrue(mi.all_equal(''))
  118. self.assertTrue(mi.all_equal([]))
  119. def test_one(self):
  120. """Return True if the iterable is singular"""
  121. self.assertTrue(mi.all_equal('0'))
  122. self.assertTrue(mi.all_equal([0]))
  123. class QuantifyTests(TestCase):
  124. """Tests for ``quantify()``"""
  125. def test_happy_path(self):
  126. """Make sure True count is returned"""
  127. q = [True, False, True]
  128. self.assertEqual(mi.quantify(q), 2)
  129. def test_custom_predicate(self):
  130. """Ensure non-default predicates return as expected"""
  131. q = range(10)
  132. self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5)
  133. class PadnoneTests(TestCase):
  134. def test_basic(self):
  135. iterable = range(2)
  136. for func in (mi.pad_none, mi.padnone):
  137. with self.subTest(func=func):
  138. p = func(iterable)
  139. self.assertEqual(
  140. [0, 1, None, None], [next(p) for _ in range(4)]
  141. )
  142. class NcyclesTests(TestCase):
  143. """Tests for ``nyclces()``"""
  144. def test_happy_path(self):
  145. """cycle a sequence three times"""
  146. r = ["a", "b", "c"]
  147. n = mi.ncycles(r, 3)
  148. self.assertEqual(
  149. ["a", "b", "c", "a", "b", "c", "a", "b", "c"], list(n)
  150. )
  151. def test_null_case(self):
  152. """asking for 0 cycles should return an empty iterator"""
  153. n = mi.ncycles(range(100), 0)
  154. self.assertRaises(StopIteration, lambda: next(n))
  155. def test_pathological_case(self):
  156. """asking for negative cycles should return an empty iterator"""
  157. n = mi.ncycles(range(100), -10)
  158. self.assertRaises(StopIteration, lambda: next(n))
  159. class DotproductTests(TestCase):
  160. """Tests for ``dotproduct()``'"""
  161. def test_happy_path(self):
  162. """simple dotproduct example"""
  163. self.assertEqual(400, mi.dotproduct([10, 10], [20, 20]))
  164. class FlattenTests(TestCase):
  165. """Tests for ``flatten()``"""
  166. def test_basic_usage(self):
  167. """ensure list of lists is flattened one level"""
  168. f = [[0, 1, 2], [3, 4, 5]]
  169. self.assertEqual(list(range(6)), list(mi.flatten(f)))
  170. def test_single_level(self):
  171. """ensure list of lists is flattened only one level"""
  172. f = [[0, [1, 2]], [[3, 4], 5]]
  173. self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f)))
  174. class RepeatfuncTests(TestCase):
  175. """Tests for ``repeatfunc()``"""
  176. def test_simple_repeat(self):
  177. """test simple repeated functions"""
  178. r = mi.repeatfunc(lambda: 5)
  179. self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)])
  180. def test_finite_repeat(self):
  181. """ensure limited repeat when times is provided"""
  182. r = mi.repeatfunc(lambda: 5, times=5)
  183. self.assertEqual([5, 5, 5, 5, 5], list(r))
  184. def test_added_arguments(self):
  185. """ensure arguments are applied to the function"""
  186. r = mi.repeatfunc(lambda x: x, 2, 3)
  187. self.assertEqual([3, 3], list(r))
  188. def test_null_times(self):
  189. """repeat 0 should return an empty iterator"""
  190. r = mi.repeatfunc(range, 0, 3)
  191. self.assertRaises(StopIteration, lambda: next(r))
  192. class PairwiseTests(TestCase):
  193. """Tests for ``pairwise()``"""
  194. def test_base_case(self):
  195. """ensure an iterable will return pairwise"""
  196. p = mi.pairwise([1, 2, 3])
  197. self.assertEqual([(1, 2), (2, 3)], list(p))
  198. def test_short_case(self):
  199. """ensure an empty iterator if there's not enough values to pair"""
  200. p = mi.pairwise("a")
  201. self.assertRaises(StopIteration, lambda: next(p))
  202. class GrouperTests(TestCase):
  203. def test_basic(self):
  204. seq = 'ABCDEF'
  205. for n, expected in [
  206. (3, [('A', 'B', 'C'), ('D', 'E', 'F')]),
  207. (4, [('A', 'B', 'C', 'D'), ('E', 'F', None, None)]),
  208. (5, [('A', 'B', 'C', 'D', 'E'), ('F', None, None, None, None)]),
  209. (6, [('A', 'B', 'C', 'D', 'E', 'F')]),
  210. (7, [('A', 'B', 'C', 'D', 'E', 'F', None)]),
  211. ]:
  212. with self.subTest(n=n):
  213. actual = list(mi.grouper(iter(seq), n))
  214. self.assertEqual(actual, expected)
  215. def test_fill(self):
  216. seq = 'ABCDEF'
  217. fillvalue = 'x'
  218. for n, expected in [
  219. (1, ['A', 'B', 'C', 'D', 'E', 'F']),
  220. (2, ['AB', 'CD', 'EF']),
  221. (3, ['ABC', 'DEF']),
  222. (4, ['ABCD', 'EFxx']),
  223. (5, ['ABCDE', 'Fxxxx']),
  224. (6, ['ABCDEF']),
  225. (7, ['ABCDEFx']),
  226. ]:
  227. with self.subTest(n=n):
  228. it = mi.grouper(
  229. iter(seq), n, incomplete='fill', fillvalue=fillvalue
  230. )
  231. actual = [''.join(x) for x in it]
  232. self.assertEqual(actual, expected)
  233. def test_ignore(self):
  234. seq = 'ABCDEF'
  235. for n, expected in [
  236. (1, ['A', 'B', 'C', 'D', 'E', 'F']),
  237. (2, ['AB', 'CD', 'EF']),
  238. (3, ['ABC', 'DEF']),
  239. (4, ['ABCD']),
  240. (5, ['ABCDE']),
  241. (6, ['ABCDEF']),
  242. (7, []),
  243. ]:
  244. with self.subTest(n=n):
  245. it = mi.grouper(iter(seq), n, incomplete='ignore')
  246. actual = [''.join(x) for x in it]
  247. self.assertEqual(actual, expected)
  248. def test_strict(self):
  249. seq = 'ABCDEF'
  250. for n, expected in [
  251. (1, ['A', 'B', 'C', 'D', 'E', 'F']),
  252. (2, ['AB', 'CD', 'EF']),
  253. (3, ['ABC', 'DEF']),
  254. (6, ['ABCDEF']),
  255. ]:
  256. with self.subTest(n=n):
  257. it = mi.grouper(iter(seq), n, incomplete='strict')
  258. actual = [''.join(x) for x in it]
  259. self.assertEqual(actual, expected)
  260. def test_strict_fails(self):
  261. seq = 'ABCDEF'
  262. for n in [4, 5, 7]:
  263. with self.subTest(n=n):
  264. with self.assertRaises(ValueError):
  265. list(mi.grouper(iter(seq), n, incomplete='strict'))
  266. def test_invalid_incomplete(self):
  267. with self.assertRaises(ValueError):
  268. list(mi.grouper('ABCD', 3, incomplete='bogus'))
  269. class RoundrobinTests(TestCase):
  270. """Tests for ``roundrobin()``"""
  271. def test_even_groups(self):
  272. """Ensure ordered output from evenly populated iterables"""
  273. self.assertEqual(
  274. list(mi.roundrobin('ABC', [1, 2, 3], range(3))),
  275. ['A', 1, 0, 'B', 2, 1, 'C', 3, 2],
  276. )
  277. def test_uneven_groups(self):
  278. """Ensure ordered output from unevenly populated iterables"""
  279. self.assertEqual(
  280. list(mi.roundrobin('ABCD', [1, 2], range(0))),
  281. ['A', 1, 'B', 2, 'C', 'D'],
  282. )
  283. class PartitionTests(TestCase):
  284. """Tests for ``partition()``"""
  285. def test_bool(self):
  286. lesser, greater = mi.partition(lambda x: x > 5, range(10))
  287. self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5])
  288. self.assertEqual(list(greater), [6, 7, 8, 9])
  289. def test_arbitrary(self):
  290. divisibles, remainders = mi.partition(lambda x: x % 3, range(10))
  291. self.assertEqual(list(divisibles), [0, 3, 6, 9])
  292. self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8])
  293. def test_pred_is_none(self):
  294. falses, trues = mi.partition(None, range(3))
  295. self.assertEqual(list(falses), [0])
  296. self.assertEqual(list(trues), [1, 2])
  297. class PowersetTests(TestCase):
  298. """Tests for ``powerset()``"""
  299. def test_combinatorics(self):
  300. """Ensure a proper enumeration"""
  301. p = mi.powerset([1, 2, 3])
  302. self.assertEqual(
  303. list(p), [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
  304. )
  305. class UniqueEverseenTests(TestCase):
  306. """Tests for ``unique_everseen()``"""
  307. def test_everseen(self):
  308. """ensure duplicate elements are ignored"""
  309. u = mi.unique_everseen('AAAABBBBCCDAABBB')
  310. self.assertEqual(['A', 'B', 'C', 'D'], list(u))
  311. def test_custom_key(self):
  312. """ensure the custom key comparison works"""
  313. u = mi.unique_everseen('aAbACCc', key=str.lower)
  314. self.assertEqual(list('abC'), list(u))
  315. def test_unhashable(self):
  316. """ensure things work for unhashable items"""
  317. iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
  318. u = mi.unique_everseen(iterable)
  319. self.assertEqual(list(u), ['a', [1, 2, 3]])
  320. def test_unhashable_key(self):
  321. """ensure things work for unhashable items with a custom key"""
  322. iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
  323. u = mi.unique_everseen(iterable, key=lambda x: x)
  324. self.assertEqual(list(u), ['a', [1, 2, 3]])
  325. class UniqueJustseenTests(TestCase):
  326. """Tests for ``unique_justseen()``"""
  327. def test_justseen(self):
  328. """ensure only last item is remembered"""
  329. u = mi.unique_justseen('AAAABBBCCDABB')
  330. self.assertEqual(list('ABCDAB'), list(u))
  331. def test_custom_key(self):
  332. """ensure the custom key comparison works"""
  333. u = mi.unique_justseen('AABCcAD', str.lower)
  334. self.assertEqual(list('ABCAD'), list(u))
  335. class IterExceptTests(TestCase):
  336. """Tests for ``iter_except()``"""
  337. def test_exact_exception(self):
  338. """ensure the exact specified exception is caught"""
  339. l = [1, 2, 3]
  340. i = mi.iter_except(l.pop, IndexError)
  341. self.assertEqual(list(i), [3, 2, 1])
  342. def test_generic_exception(self):
  343. """ensure the generic exception can be caught"""
  344. l = [1, 2]
  345. i = mi.iter_except(l.pop, Exception)
  346. self.assertEqual(list(i), [2, 1])
  347. def test_uncaught_exception_is_raised(self):
  348. """ensure a non-specified exception is raised"""
  349. l = [1, 2, 3]
  350. i = mi.iter_except(l.pop, KeyError)
  351. self.assertRaises(IndexError, lambda: list(i))
  352. def test_first(self):
  353. """ensure first is run before the function"""
  354. l = [1, 2, 3]
  355. f = lambda: 25
  356. i = mi.iter_except(l.pop, IndexError, f)
  357. self.assertEqual(list(i), [25, 3, 2, 1])
  358. def test_multiple(self):
  359. """ensure can catch multiple exceptions"""
  360. class Fiz(Exception):
  361. pass
  362. class Buzz(Exception):
  363. pass
  364. i = 0
  365. def fizbuzz():
  366. nonlocal i
  367. i += 1
  368. if i % 3 == 0:
  369. raise Fiz
  370. if i % 5 == 0:
  371. raise Buzz
  372. return i
  373. expected = ([1, 2], [4], [], [7, 8], [])
  374. for x in expected:
  375. self.assertEqual(list(mi.iter_except(fizbuzz, (Fiz, Buzz))), x)
  376. class FirstTrueTests(TestCase):
  377. """Tests for ``first_true()``"""
  378. def test_something_true(self):
  379. """Test with no keywords"""
  380. self.assertEqual(mi.first_true(range(10)), 1)
  381. def test_nothing_true(self):
  382. """Test default return value."""
  383. self.assertIsNone(mi.first_true([0, 0, 0]))
  384. def test_default(self):
  385. """Test with a default keyword"""
  386. self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!')
  387. def test_pred(self):
  388. """Test with a custom predicate"""
  389. self.assertEqual(
  390. mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6
  391. )
  392. class RandomProductTests(TestCase):
  393. """Tests for ``random_product()``
  394. Since random.choice() has different results with the same seed across
  395. python versions 2.x and 3.x, these tests use highly probably events to
  396. create predictable outcomes across platforms.
  397. """
  398. def test_simple_lists(self):
  399. """Ensure that one item is chosen from each list in each pair.
  400. Also ensure that each item from each list eventually appears in
  401. the chosen combinations.
  402. Odds are roughly 1 in 7.1 * 10e16 that one item from either list will
  403. not be chosen after 100 samplings of one item from each list. Just to
  404. be safe, better use a known random seed, too.
  405. """
  406. nums = [1, 2, 3]
  407. lets = ['a', 'b', 'c']
  408. n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
  409. n, m = set(n), set(m)
  410. self.assertEqual(n, set(nums))
  411. self.assertEqual(m, set(lets))
  412. self.assertEqual(len(n), len(nums))
  413. self.assertEqual(len(m), len(lets))
  414. def test_list_with_repeat(self):
  415. """ensure multiple items are chosen, and that they appear to be chosen
  416. from one list then the next, in proper order.
  417. """
  418. nums = [1, 2, 3]
  419. lets = ['a', 'b', 'c']
  420. r = list(mi.random_product(nums, lets, repeat=100))
  421. self.assertEqual(2 * 100, len(r))
  422. n, m = set(r[::2]), set(r[1::2])
  423. self.assertEqual(n, set(nums))
  424. self.assertEqual(m, set(lets))
  425. self.assertEqual(len(n), len(nums))
  426. self.assertEqual(len(m), len(lets))
  427. class RandomPermutationTests(TestCase):
  428. """Tests for ``random_permutation()``"""
  429. def test_full_permutation(self):
  430. """ensure every item from the iterable is returned in a new ordering
  431. 15 elements have a 1 in 1.3 * 10e12 of appearing in sorted order, so
  432. we fix a seed value just to be sure.
  433. """
  434. i = range(15)
  435. r = mi.random_permutation(i)
  436. self.assertEqual(set(i), set(r))
  437. if i == r:
  438. raise AssertionError("Values were not permuted")
  439. def test_partial_permutation(self):
  440. """ensure all returned items are from the iterable, that the returned
  441. permutation is of the desired length, and that all items eventually
  442. get returned.
  443. Sampling 100 permutations of length 5 from a set of 15 leaves a
  444. (2/3)^100 chance that an item will not be chosen. Multiplied by 15
  445. items, there is a 1 in 2.6e16 chance that at least 1 item will not
  446. show up in the resulting output. Using a random seed will fix that.
  447. """
  448. items = range(15)
  449. item_set = set(items)
  450. all_items = set()
  451. for _ in range(100):
  452. permutation = mi.random_permutation(items, 5)
  453. self.assertEqual(len(permutation), 5)
  454. permutation_set = set(permutation)
  455. self.assertLessEqual(permutation_set, item_set)
  456. all_items |= permutation_set
  457. self.assertEqual(all_items, item_set)
  458. class RandomCombinationTests(TestCase):
  459. """Tests for ``random_combination()``"""
  460. def test_pseudorandomness(self):
  461. """ensure different subsets of the iterable get returned over many
  462. samplings of random combinations"""
  463. items = range(15)
  464. all_items = set()
  465. for _ in range(50):
  466. combination = mi.random_combination(items, 5)
  467. all_items |= set(combination)
  468. self.assertEqual(all_items, set(items))
  469. def test_no_replacement(self):
  470. """ensure that elements are sampled without replacement"""
  471. items = range(15)
  472. for _ in range(50):
  473. combination = mi.random_combination(items, len(items))
  474. self.assertEqual(len(combination), len(set(combination)))
  475. self.assertRaises(
  476. ValueError, lambda: mi.random_combination(items, len(items) + 1)
  477. )
  478. class RandomCombinationWithReplacementTests(TestCase):
  479. """Tests for ``random_combination_with_replacement()``"""
  480. def test_replacement(self):
  481. """ensure that elements are sampled with replacement"""
  482. items = range(5)
  483. combo = mi.random_combination_with_replacement(items, len(items) * 2)
  484. self.assertEqual(2 * len(items), len(combo))
  485. if len(set(combo)) == len(combo):
  486. raise AssertionError("Combination contained no duplicates")
  487. def test_pseudorandomness(self):
  488. """ensure different subsets of the iterable get returned over many
  489. samplings of random combinations"""
  490. items = range(15)
  491. all_items = set()
  492. for _ in range(50):
  493. combination = mi.random_combination_with_replacement(items, 5)
  494. all_items |= set(combination)
  495. self.assertEqual(all_items, set(items))
  496. class NthCombinationTests(TestCase):
  497. def test_basic(self):
  498. iterable = 'abcdefg'
  499. r = 4
  500. for index, expected in enumerate(combinations(iterable, r)):
  501. actual = mi.nth_combination(iterable, r, index)
  502. self.assertEqual(actual, expected)
  503. def test_long(self):
  504. actual = mi.nth_combination(range(180), 4, 2000000)
  505. expected = (2, 12, 35, 126)
  506. self.assertEqual(actual, expected)
  507. def test_invalid_r(self):
  508. for r in (-1, 3):
  509. with self.assertRaises(ValueError):
  510. mi.nth_combination([], r, 0)
  511. def test_invalid_index(self):
  512. with self.assertRaises(IndexError):
  513. mi.nth_combination('abcdefg', 3, -36)
  514. class NthPermutationTests(TestCase):
  515. def test_r_less_than_n(self):
  516. iterable = 'abcde'
  517. r = 4
  518. for index, expected in enumerate(permutations(iterable, r)):
  519. actual = mi.nth_permutation(iterable, r, index)
  520. self.assertEqual(actual, expected)
  521. def test_r_equal_to_n(self):
  522. iterable = 'abcde'
  523. for index, expected in enumerate(permutations(iterable)):
  524. actual = mi.nth_permutation(iterable, None, index)
  525. self.assertEqual(actual, expected)
  526. def test_long(self):
  527. iterable = tuple(range(180))
  528. r = 4
  529. index = 1000000
  530. actual = mi.nth_permutation(iterable, r, index)
  531. expected = mi.nth(permutations(iterable, r), index)
  532. self.assertEqual(actual, expected)
  533. def test_null(self):
  534. actual = mi.nth_permutation([], 0, 0)
  535. expected = tuple()
  536. self.assertEqual(actual, expected)
  537. def test_negative_index(self):
  538. iterable = 'abcde'
  539. r = 4
  540. n = factorial(len(iterable)) // factorial(len(iterable) - r)
  541. for index, expected in enumerate(permutations(iterable, r)):
  542. actual = mi.nth_permutation(iterable, r, index - n)
  543. self.assertEqual(actual, expected)
  544. def test_invalid_index(self):
  545. iterable = 'abcde'
  546. r = 4
  547. n = factorial(len(iterable)) // factorial(len(iterable) - r)
  548. for index in [-1 - n, n + 1]:
  549. with self.assertRaises(IndexError):
  550. mi.nth_combination(iterable, r, index)
  551. def test_invalid_r(self):
  552. iterable = 'abcde'
  553. r = 4
  554. n = factorial(len(iterable)) // factorial(len(iterable) - r)
  555. for r in [-1, n + 1]:
  556. with self.assertRaises(ValueError):
  557. mi.nth_combination(iterable, r, 0)
  558. class PrependTests(TestCase):
  559. def test_basic(self):
  560. value = 'a'
  561. iterator = iter('bcdefg')
  562. actual = list(mi.prepend(value, iterator))
  563. expected = list('abcdefg')
  564. self.assertEqual(actual, expected)
  565. def test_multiple(self):
  566. value = 'ab'
  567. iterator = iter('cdefg')
  568. actual = tuple(mi.prepend(value, iterator))
  569. expected = ('ab',) + tuple('cdefg')
  570. self.assertEqual(actual, expected)
  571. class Convolvetests(TestCase):
  572. def test_moving_average(self):
  573. signal = iter([10, 20, 30, 40, 50])
  574. kernel = [0.5, 0.5]
  575. actual = list(mi.convolve(signal, kernel))
  576. expected = [
  577. (10 + 0) / 2,
  578. (20 + 10) / 2,
  579. (30 + 20) / 2,
  580. (40 + 30) / 2,
  581. (50 + 40) / 2,
  582. (0 + 50) / 2,
  583. ]
  584. self.assertEqual(actual, expected)
  585. def test_derivative(self):
  586. signal = iter([10, 20, 30, 40, 50])
  587. kernel = [1, -1]
  588. actual = list(mi.convolve(signal, kernel))
  589. expected = [10 - 0, 20 - 10, 30 - 20, 40 - 30, 50 - 40, 0 - 50]
  590. self.assertEqual(actual, expected)
  591. def test_infinite_signal(self):
  592. signal = count()
  593. kernel = [1, -1]
  594. actual = mi.take(5, mi.convolve(signal, kernel))
  595. expected = [0, 1, 1, 1, 1]
  596. self.assertEqual(actual, expected)
  597. class BeforeAndAfterTests(TestCase):
  598. def test_empty(self):
  599. before, after = mi.before_and_after(bool, [])
  600. self.assertEqual(list(before), [])
  601. self.assertEqual(list(after), [])
  602. def test_never_true(self):
  603. before, after = mi.before_and_after(bool, [0, False, None, ''])
  604. self.assertEqual(list(before), [])
  605. self.assertEqual(list(after), [0, False, None, ''])
  606. def test_never_false(self):
  607. before, after = mi.before_and_after(bool, [1, True, Ellipsis, ' '])
  608. self.assertEqual(list(before), [1, True, Ellipsis, ' '])
  609. self.assertEqual(list(after), [])
  610. def test_some_true(self):
  611. before, after = mi.before_and_after(bool, [1, True, 0, False])
  612. self.assertEqual(list(before), [1, True])
  613. self.assertEqual(list(after), [0, False])
  614. @staticmethod
  615. def _group_events(events):
  616. events = iter(events)
  617. while True:
  618. try:
  619. operation = next(events)
  620. except StopIteration:
  621. break
  622. assert operation in ["SUM", "MULTIPLY"]
  623. # Here, the remainder `events` is passed into `before_and_after`
  624. # again, which would be problematic if the remainder is a
  625. # generator function (as in Python 3.10 itertools recipes), since
  626. # that creates recursion. `itertools.chain` solves this problem.
  627. numbers, events = mi.before_and_after(
  628. lambda e: isinstance(e, int), events
  629. )
  630. yield (operation, numbers)
  631. def test_nested_remainder(self):
  632. events = ["SUM", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 1000
  633. events += ["MULTIPLY", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 1000
  634. for operation, numbers in self._group_events(events):
  635. if operation == "SUM":
  636. res = sum(numbers)
  637. self.assertEqual(res, 55)
  638. elif operation == "MULTIPLY":
  639. res = reduce(lambda a, b: a * b, numbers)
  640. self.assertEqual(res, 3628800)
  641. class TriplewiseTests(TestCase):
  642. def test_basic(self):
  643. for iterable, expected in [
  644. ([0], []),
  645. ([0, 1], []),
  646. ([0, 1, 2], [(0, 1, 2)]),
  647. ([0, 1, 2, 3], [(0, 1, 2), (1, 2, 3)]),
  648. ([0, 1, 2, 3, 4], [(0, 1, 2), (1, 2, 3), (2, 3, 4)]),
  649. ]:
  650. with self.subTest(expected=expected):
  651. actual = list(mi.triplewise(iterable))
  652. self.assertEqual(actual, expected)
  653. class SlidingWindowTests(TestCase):
  654. def test_basic(self):
  655. for iterable, n, expected in [
  656. ([], 1, []),
  657. ([0], 1, [(0,)]),
  658. ([0, 1], 1, [(0,), (1,)]),
  659. ([0, 1, 2], 2, [(0, 1), (1, 2)]),
  660. ([0, 1, 2], 3, [(0, 1, 2)]),
  661. ([0, 1, 2], 4, []),
  662. ([0, 1, 2, 3], 4, [(0, 1, 2, 3)]),
  663. ([0, 1, 2, 3, 4], 4, [(0, 1, 2, 3), (1, 2, 3, 4)]),
  664. ]:
  665. with self.subTest(expected=expected):
  666. actual = list(mi.sliding_window(iterable, n))
  667. self.assertEqual(actual, expected)
  668. class SubslicesTests(TestCase):
  669. def test_basic(self):
  670. for iterable, expected in [
  671. ([], []),
  672. ([1], [[1]]),
  673. ([1, 2], [[1], [1, 2], [2]]),
  674. (iter([1, 2]), [[1], [1, 2], [2]]),
  675. ([2, 1], [[2], [2, 1], [1]]),
  676. (
  677. 'ABCD',
  678. [
  679. ['A'],
  680. ['A', 'B'],
  681. ['A', 'B', 'C'],
  682. ['A', 'B', 'C', 'D'],
  683. ['B'],
  684. ['B', 'C'],
  685. ['B', 'C', 'D'],
  686. ['C'],
  687. ['C', 'D'],
  688. ['D'],
  689. ],
  690. ),
  691. ]:
  692. with self.subTest(expected=expected):
  693. actual = list(mi.subslices(iterable))
  694. self.assertEqual(actual, expected)
  695. class PolynomialFromRootsTests(TestCase):
  696. def test_basic(self):
  697. for roots, expected in [
  698. ((2, 1, -1), [1, -2, -1, 2]),
  699. ((2, 3), [1, -5, 6]),
  700. ((1, 2, 3), [1, -6, 11, -6]),
  701. ((2, 4, 1), [1, -7, 14, -8]),
  702. ]:
  703. with self.subTest(roots=roots):
  704. actual = mi.polynomial_from_roots(roots)
  705. self.assertEqual(actual, expected)
  706. class PolynomialEvalTests(TestCase):
  707. def test_basic(self):
  708. for coefficients, x, expected in [
  709. ([1, -4, -17, 60], 2, 18),
  710. ([1, -4, -17, 60], 2.5, 8.125),
  711. ([1, -4, -17, 60], Fraction(2, 3), Fraction(1274, 27)),
  712. ([1, -4, -17, 60], Decimal('1.75'), Decimal('23.359375')),
  713. ([], 2, 0),
  714. ([], 2.5, 0.0),
  715. ([], Fraction(2, 3), Fraction(0, 1)),
  716. ([], Decimal('1.75'), Decimal('0.00')),
  717. ([11], 7, 11),
  718. ([11, 2], 7, 79),
  719. ]:
  720. with self.subTest(x=x):
  721. actual = mi.polynomial_eval(coefficients, x)
  722. self.assertEqual(actual, expected)
  723. self.assertEqual(type(actual), type(x))
  724. class IterIndexTests(TestCase):
  725. def test_basic(self):
  726. iterable = 'AABCADEAF'
  727. for wrapper in (list, iter):
  728. with self.subTest(wrapper=wrapper):
  729. actual = list(mi.iter_index(wrapper(iterable), 'A'))
  730. expected = [0, 1, 4, 7]
  731. self.assertEqual(actual, expected)
  732. def test_start(self):
  733. for wrapper in (list, iter):
  734. with self.subTest(wrapper=wrapper):
  735. iterable = 'AABCADEAF'
  736. i = -1
  737. actual = []
  738. while True:
  739. try:
  740. i = next(
  741. mi.iter_index(wrapper(iterable), 'A', start=i + 1)
  742. )
  743. except StopIteration:
  744. break
  745. else:
  746. actual.append(i)
  747. expected = [0, 1, 4, 7]
  748. self.assertEqual(actual, expected)
  749. def test_stop(self):
  750. actual = list(mi.iter_index('AABCADEAF', 'A', stop=7))
  751. expected = [0, 1, 4]
  752. self.assertEqual(actual, expected)
  753. class SieveTests(TestCase):
  754. def test_basic(self):
  755. self.assertEqual(
  756. list(mi.sieve(67)),
  757. [
  758. 2,
  759. 3,
  760. 5,
  761. 7,
  762. 11,
  763. 13,
  764. 17,
  765. 19,
  766. 23,
  767. 29,
  768. 31,
  769. 37,
  770. 41,
  771. 43,
  772. 47,
  773. 53,
  774. 59,
  775. 61,
  776. ],
  777. )
  778. self.assertEqual(list(mi.sieve(68))[-1], 67)
  779. def test_prime_counts(self):
  780. for n, expected in (
  781. (100, 25),
  782. (1_000, 168),
  783. (10_000, 1229),
  784. (100_000, 9592),
  785. (1_000_000, 78498),
  786. ):
  787. with self.subTest(n=n):
  788. self.assertEqual(mi.ilen(mi.sieve(n)), expected)
  789. def test_small_numbers(self):
  790. with self.assertRaises(ValueError):
  791. list(mi.sieve(-1))
  792. for n in (0, 1, 2):
  793. with self.subTest(n=n):
  794. self.assertEqual(list(mi.sieve(n)), [])
  795. class BatchedTests(TestCase):
  796. def test_basic(self):
  797. iterable = range(1, 5 + 1)
  798. for n, expected in (
  799. (1, [(1,), (2,), (3,), (4,), (5,)]),
  800. (2, [(1, 2), (3, 4), (5,)]),
  801. (3, [(1, 2, 3), (4, 5)]),
  802. (4, [(1, 2, 3, 4), (5,)]),
  803. (5, [(1, 2, 3, 4, 5)]),
  804. (6, [(1, 2, 3, 4, 5)]),
  805. ):
  806. with self.subTest(n=n):
  807. actual = list(mi.batched(iterable, n))
  808. self.assertEqual(actual, expected)
  809. def test_strict(self):
  810. with self.assertRaises(ValueError):
  811. list(mi.batched('ABCDEFG', 3, strict=True))
  812. self.assertEqual(
  813. list(mi.batched('ABCDEF', 3, strict=True)),
  814. [('A', 'B', 'C'), ('D', 'E', 'F')],
  815. )
  816. class TransposeTests(TestCase):
  817. def test_empty(self):
  818. it = []
  819. actual = list(mi.transpose(it))
  820. expected = []
  821. self.assertEqual(actual, expected)
  822. def test_basic(self):
  823. it = [(10, 11, 12), (20, 21, 22), (30, 31, 32)]
  824. actual = list(mi.transpose(it))
  825. expected = [(10, 20, 30), (11, 21, 31), (12, 22, 32)]
  826. self.assertEqual(actual, expected)
  827. @skipIf(version_info[:2] < (3, 10), 'strict=True missing on 3.9')
  828. def test_incompatible_error(self):
  829. it = [(10, 11, 12, 13), (20, 21, 22), (30, 31, 32)]
  830. with self.assertRaises(ValueError):
  831. list(mi.transpose(it))
  832. @skipIf(version_info[:2] >= (3, 9), 'strict=True missing on 3.9')
  833. def test_incompatible_allow(self):
  834. it = [(10, 11, 12, 13), (20, 21, 22), (30, 31, 32)]
  835. actual = list(mi.transpose(it))
  836. expected = [(10, 20, 30), (11, 21, 31), (12, 22, 32)]
  837. self.assertEqual(actual, expected)
  838. class ReshapeTests(TestCase):
  839. def test_empty(self):
  840. actual = list(mi.reshape([], 3))
  841. self.assertEqual(actual, [])
  842. def test_zero(self):
  843. matrix = [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]
  844. with self.assertRaises(ValueError):
  845. list(mi.reshape(matrix, 0))
  846. def test_basic(self):
  847. matrix = [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]
  848. for cols, expected in (
  849. (
  850. 1,
  851. [
  852. (0,),
  853. (1,),
  854. (2,),
  855. (3,),
  856. (4,),
  857. (5,),
  858. (6,),
  859. (7,),
  860. (8,),
  861. (9,),
  862. (10,),
  863. (11,),
  864. ],
  865. ),
  866. (2, [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11)]),
  867. (3, [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)]),
  868. (4, [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]),
  869. (6, [(0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11)]),
  870. (12, [(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)]),
  871. ):
  872. with self.subTest(cols=cols):
  873. actual = list(mi.reshape(matrix, cols))
  874. self.assertEqual(actual, expected)
  875. class MatMulTests(TestCase):
  876. def test_n_by_n(self):
  877. actual = list(mi.matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]))
  878. expected = [(49, 80), (41, 60)]
  879. self.assertEqual(actual, expected)
  880. def test_m_by_n(self):
  881. m1 = [[2, 5], [7, 9], [3, 4]]
  882. m2 = [[7, 11, 5, 4, 9], [3, 5, 2, 6, 3]]
  883. actual = list(mi.matmul(m1, m2))
  884. expected = [
  885. (29, 47, 20, 38, 33),
  886. (76, 122, 53, 82, 90),
  887. (33, 53, 23, 36, 39),
  888. ]
  889. self.assertEqual(actual, expected)
  890. class FactorTests(TestCase):
  891. def test_basic(self):
  892. for n, expected in (
  893. (0, []),
  894. (1, []),
  895. (2, [2]),
  896. (3, [3]),
  897. (4, [2, 2]),
  898. (6, [2, 3]),
  899. (360, [2, 2, 2, 3, 3, 5]),
  900. (128_884_753_939, [128_884_753_939]),
  901. (999953 * 999983, [999953, 999983]),
  902. (909_909_090_909, [3, 3, 7, 13, 13, 751, 113797]),
  903. ):
  904. with self.subTest(n=n):
  905. actual = list(mi.factor(n))
  906. self.assertEqual(actual, expected)
  907. def test_cross_check(self):
  908. prod = lambda x: reduce(mul, x, 1)
  909. self.assertTrue(all(prod(mi.factor(n)) == n for n in range(1, 2000)))
  910. self.assertTrue(
  911. all(set(mi.factor(n)) <= set(mi.sieve(n + 1)) for n in range(2000))
  912. )
  913. self.assertTrue(
  914. all(
  915. list(mi.factor(n)) == sorted(mi.factor(n)) for n in range(2000)
  916. )
  917. )
  918. class SumOfSquaresTests(TestCase):
  919. def test_basic(self):
  920. for it, expected in (
  921. ([], 0),
  922. ([1, 2, 3], 1 + 4 + 9),
  923. ([2, 4, 6, 8], 4 + 16 + 36 + 64),
  924. ):
  925. with self.subTest(it=it):
  926. actual = mi.sum_of_squares(it)
  927. self.assertEqual(actual, expected)
  928. class PolynomialDerivativeTests(TestCase):
  929. def test_basic(self):
  930. for coefficients, expected in [
  931. ([], []),
  932. ([1], []),
  933. ([1, 2], [1]),
  934. ([1, 2, 3], [2, 2]),
  935. ([1, 2, 3, 4], [3, 4, 3]),
  936. ([1.1, 2, 3, 4], [(1.1 * 3), 4, 3]),
  937. ]:
  938. with self.subTest(coefficients=coefficients):
  939. actual = mi.polynomial_derivative(coefficients)
  940. self.assertEqual(actual, expected)
  941. class TotientTests(TestCase):
  942. def test_basic(self):
  943. for n, expected in (
  944. (1, 1),
  945. (2, 1),
  946. (3, 2),
  947. (4, 2),
  948. (9, 6),
  949. (12, 4),
  950. (128_884_753_939, 128_884_753_938),
  951. (999953 * 999983, 999952 * 999982),
  952. (6**20, 1 * 2**19 * 2 * 3**19),
  953. ):
  954. with self.subTest(n=n):
  955. self.assertEqual(mi.totient(n), expected)