recipes.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. """Imported from the recipes section of the itertools documentation.
  2. All functions taken from the recipes section of the itertools library docs
  3. [1]_.
  4. Some backward-compatible usability improvements have been made.
  5. .. [1] http://docs.python.org/library/itertools.html#recipes
  6. """
  7. from collections import deque
  8. from itertools import (
  9. chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee
  10. )
  11. import operator
  12. from random import randrange, sample, choice
  13. from six import PY2
  14. from six.moves import filter, filterfalse, map, range, zip, zip_longest
# Public API of this module: the names exported by ``from recipes import *``.
# Every entry below is a function defined later in this file.
__all__ = [
    'accumulate',
    'all_equal',
    'consume',
    'dotproduct',
    'first_true',
    'flatten',
    'grouper',
    'iter_except',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pairwise',
    'partition',
    'powerset',
    'prepend',
    'quantify',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'tabulate',
    'tail',
    'take',
    'unique_everseen',
    'unique_justseen',
]
  45. def accumulate(iterable, func=operator.add):
  46. """
  47. Return an iterator whose items are the accumulated results of a function
  48. (specified by the optional *func* argument) that takes two arguments.
  49. By default, returns accumulated sums with :func:`operator.add`.
  50. >>> list(accumulate([1, 2, 3, 4, 5])) # Running sum
  51. [1, 3, 6, 10, 15]
  52. >>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product
  53. [1, 2, 6]
  54. >>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum
  55. [0, 1, 1, 2, 3, 3]
  56. This function is available in the ``itertools`` module for Python 3.2 and
  57. greater.
  58. """
  59. it = iter(iterable)
  60. try:
  61. total = next(it)
  62. except StopIteration:
  63. return
  64. else:
  65. yield total
  66. for element in it:
  67. total = func(total, element)
  68. yield total
  69. def take(n, iterable):
  70. """Return first *n* items of the iterable as a list.
  71. >>> take(3, range(10))
  72. [0, 1, 2]
  73. >>> take(5, range(3))
  74. [0, 1, 2]
  75. Effectively a short replacement for ``next`` based iterator consumption
  76. when you want more than one item, but less than the whole iterator.
  77. """
  78. return list(islice(iterable, n))
  79. def tabulate(function, start=0):
  80. """Return an iterator over the results of ``func(start)``,
  81. ``func(start + 1)``, ``func(start + 2)``...
  82. *func* should be a function that accepts one integer argument.
  83. If *start* is not specified it defaults to 0. It will be incremented each
  84. time the iterator is advanced.
  85. >>> square = lambda x: x ** 2
  86. >>> iterator = tabulate(square, -3)
  87. >>> take(4, iterator)
  88. [9, 4, 1, 0]
  89. """
  90. return map(function, count(start))
  91. def tail(n, iterable):
  92. """Return an iterator over the last *n* items of *iterable*.
  93. >>> t = tail(3, 'ABCDEFG')
  94. >>> list(t)
  95. ['E', 'F', 'G']
  96. """
  97. return iter(deque(iterable, maxlen=n))
  98. def consume(iterator, n=None):
  99. """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
  100. entirely.
  101. Efficiently exhausts an iterator without returning values. Defaults to
  102. consuming the whole iterator, but an optional second argument may be
  103. provided to limit consumption.
  104. >>> i = (x for x in range(10))
  105. >>> next(i)
  106. 0
  107. >>> consume(i, 3)
  108. >>> next(i)
  109. 4
  110. >>> consume(i)
  111. >>> next(i)
  112. Traceback (most recent call last):
  113. File "<stdin>", line 1, in <module>
  114. StopIteration
  115. If the iterator has fewer items remaining than the provided limit, the
  116. whole iterator will be consumed.
  117. >>> i = (x for x in range(3))
  118. >>> consume(i, 5)
  119. >>> next(i)
  120. Traceback (most recent call last):
  121. File "<stdin>", line 1, in <module>
  122. StopIteration
  123. """
  124. # Use functions that consume iterators at C speed.
  125. if n is None:
  126. # feed the entire iterator into a zero-length deque
  127. deque(iterator, maxlen=0)
  128. else:
  129. # advance to the empty slice starting at position n
  130. next(islice(iterator, n, n), None)
  131. def nth(iterable, n, default=None):
  132. """Returns the nth item or a default value.
  133. >>> l = range(10)
  134. >>> nth(l, 3)
  135. 3
  136. >>> nth(l, 20, "zebra")
  137. 'zebra'
  138. """
  139. return next(islice(iterable, n, None), default)
  140. def all_equal(iterable):
  141. """
  142. Returns ``True`` if all the elements are equal to each other.
  143. >>> all_equal('aaaa')
  144. True
  145. >>> all_equal('aaab')
  146. False
  147. """
  148. g = groupby(iterable)
  149. return next(g, True) and not next(g, False)
  150. def quantify(iterable, pred=bool):
  151. """Return the how many times the predicate is true.
  152. >>> quantify([True, False, True])
  153. 2
  154. """
  155. return sum(map(pred, iterable))
  156. def padnone(iterable):
  157. """Returns the sequence of elements and then returns ``None`` indefinitely.
  158. >>> take(5, padnone(range(3)))
  159. [0, 1, 2, None, None]
  160. Useful for emulating the behavior of the built-in :func:`map` function.
  161. See also :func:`padded`.
  162. """
  163. return chain(iterable, repeat(None))
  164. def ncycles(iterable, n):
  165. """Returns the sequence elements *n* times
  166. >>> list(ncycles(["a", "b"], 3))
  167. ['a', 'b', 'a', 'b', 'a', 'b']
  168. """
  169. return chain.from_iterable(repeat(tuple(iterable), n))
  170. def dotproduct(vec1, vec2):
  171. """Returns the dot product of the two iterables.
  172. >>> dotproduct([10, 10], [20, 20])
  173. 400
  174. """
  175. return sum(map(operator.mul, vec1, vec2))
  176. def flatten(listOfLists):
  177. """Return an iterator flattening one level of nesting in a list of lists.
  178. >>> list(flatten([[0, 1], [2, 3]]))
  179. [0, 1, 2, 3]
  180. See also :func:`collapse`, which can flatten multiple levels of nesting.
  181. """
  182. return chain.from_iterable(listOfLists)
  183. def repeatfunc(func, times=None, *args):
  184. """Call *func* with *args* repeatedly, returning an iterable over the
  185. results.
  186. If *times* is specified, the iterable will terminate after that many
  187. repetitions:
  188. >>> from operator import add
  189. >>> times = 4
  190. >>> args = 3, 5
  191. >>> list(repeatfunc(add, times, *args))
  192. [8, 8, 8, 8]
  193. If *times* is ``None`` the iterable will not terminate:
  194. >>> from random import randrange
  195. >>> times = None
  196. >>> args = 1, 11
  197. >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
  198. [2, 4, 8, 1, 8, 4]
  199. """
  200. if times is None:
  201. return starmap(func, repeat(args))
  202. return starmap(func, repeat(args, times))
  203. def pairwise(iterable):
  204. """Returns an iterator of paired items, overlapping, from the original
  205. >>> take(4, pairwise(count()))
  206. [(0, 1), (1, 2), (2, 3), (3, 4)]
  207. """
  208. a, b = tee(iterable)
  209. next(b, None)
  210. return zip(a, b)
  211. def grouper(n, iterable, fillvalue=None):
  212. """Collect data into fixed-length chunks or blocks.
  213. >>> list(grouper(3, 'ABCDEFG', 'x'))
  214. [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
  215. """
  216. args = [iter(iterable)] * n
  217. return zip_longest(fillvalue=fillvalue, *args)
  218. def roundrobin(*iterables):
  219. """Yields an item from each iterable, alternating between them.
  220. >>> list(roundrobin('ABC', 'D', 'EF'))
  221. ['A', 'D', 'E', 'B', 'F', 'C']
  222. This function produces the same output as :func:`interleave_longest`, but
  223. may perform better for some inputs (in particular when the number of
  224. iterables is small).
  225. """
  226. # Recipe credited to George Sakkis
  227. pending = len(iterables)
  228. if PY2:
  229. nexts = cycle(iter(it).next for it in iterables)
  230. else:
  231. nexts = cycle(iter(it).__next__ for it in iterables)
  232. while pending:
  233. try:
  234. for next in nexts:
  235. yield next()
  236. except StopIteration:
  237. pending -= 1
  238. nexts = cycle(islice(nexts, pending))
  239. def partition(pred, iterable):
  240. """
  241. Returns a 2-tuple of iterables derived from the input iterable.
  242. The first yields the items that have ``pred(item) == False``.
  243. The second yields the items that have ``pred(item) == True``.
  244. >>> is_odd = lambda x: x % 2 != 0
  245. >>> iterable = range(10)
  246. >>> even_items, odd_items = partition(is_odd, iterable)
  247. >>> list(even_items), list(odd_items)
  248. ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
  249. """
  250. # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
  251. t1, t2 = tee(iterable)
  252. return filterfalse(pred, t1), filter(pred, t2)
  253. def powerset(iterable):
  254. """Yields all possible subsets of the iterable.
  255. >>> list(powerset([1, 2, 3]))
  256. [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
  257. :func:`powerset` will operate on iterables that aren't :class:`set`
  258. instances, so repeated elements in the input will produce repeated elements
  259. in the output. Use :func:`unique_everseen` on the input to avoid generating
  260. duplicates:
  261. >>> seq = [1, 1, 0]
  262. >>> list(powerset(seq))
  263. [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
  264. >>> from more_itertools import unique_everseen
  265. >>> list(powerset(unique_everseen(seq)))
  266. [(), (1,), (0,), (1, 0)]
  267. """
  268. s = list(iterable)
  269. return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
  270. def unique_everseen(iterable, key=None):
  271. """
  272. Yield unique elements, preserving order.
  273. >>> list(unique_everseen('AAAABBBCCDAABBB'))
  274. ['A', 'B', 'C', 'D']
  275. >>> list(unique_everseen('ABBCcAD', str.lower))
  276. ['A', 'B', 'C', 'D']
  277. Sequences with a mix of hashable and unhashable items can be used.
  278. The function will be slower (i.e., `O(n^2)`) for unhashable items.
  279. """
  280. seenset = set()
  281. seenset_add = seenset.add
  282. seenlist = []
  283. seenlist_add = seenlist.append
  284. if key is None:
  285. for element in iterable:
  286. try:
  287. if element not in seenset:
  288. seenset_add(element)
  289. yield element
  290. except TypeError:
  291. if element not in seenlist:
  292. seenlist_add(element)
  293. yield element
  294. else:
  295. for element in iterable:
  296. k = key(element)
  297. try:
  298. if k not in seenset:
  299. seenset_add(k)
  300. yield element
  301. except TypeError:
  302. if k not in seenlist:
  303. seenlist_add(k)
  304. yield element
  305. def unique_justseen(iterable, key=None):
  306. """Yields elements in order, ignoring serial duplicates
  307. >>> list(unique_justseen('AAAABBBCCDAABBB'))
  308. ['A', 'B', 'C', 'D', 'A', 'B']
  309. >>> list(unique_justseen('ABBCcAD', str.lower))
  310. ['A', 'B', 'C', 'A', 'D']
  311. """
  312. return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
  313. def iter_except(func, exception, first=None):
  314. """Yields results from a function repeatedly until an exception is raised.
  315. Converts a call-until-exception interface to an iterator interface.
  316. Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
  317. to end the loop.
  318. >>> l = [0, 1, 2]
  319. >>> list(iter_except(l.pop, IndexError))
  320. [2, 1, 0]
  321. """
  322. try:
  323. if first is not None:
  324. yield first()
  325. while 1:
  326. yield func()
  327. except exception:
  328. pass
  329. def first_true(iterable, default=None, pred=None):
  330. """
  331. Returns the first true value in the iterable.
  332. If no true value is found, returns *default*
  333. If *pred* is not None, returns the first item for which
  334. ``pred(item) == True`` .
  335. >>> first_true(range(10))
  336. 1
  337. >>> first_true(range(10), pred=lambda x: x > 5)
  338. 6
  339. >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
  340. 'missing'
  341. """
  342. return next(filter(pred, iterable), default)
  343. def random_product(*args, **kwds):
  344. """Draw an item at random from each of the input iterables.
  345. >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
  346. ('c', 3, 'Z')
  347. If *repeat* is provided as a keyword argument, that many items will be
  348. drawn from each iterable.
  349. >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
  350. ('a', 2, 'd', 3)
  351. This equivalent to taking a random selection from
  352. ``itertools.product(*args, **kwarg)``.
  353. """
  354. pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1)
  355. return tuple(choice(pool) for pool in pools)
  356. def random_permutation(iterable, r=None):
  357. """Return a random *r* length permutation of the elements in *iterable*.
  358. If *r* is not specified or is ``None``, then *r* defaults to the length of
  359. *iterable*.
  360. >>> random_permutation(range(5)) # doctest:+SKIP
  361. (3, 4, 0, 1, 2)
  362. This equivalent to taking a random selection from
  363. ``itertools.permutations(iterable, r)``.
  364. """
  365. pool = tuple(iterable)
  366. r = len(pool) if r is None else r
  367. return tuple(sample(pool, r))
  368. def random_combination(iterable, r):
  369. """Return a random *r* length subsequence of the elements in *iterable*.
  370. >>> random_combination(range(5), 3) # doctest:+SKIP
  371. (2, 3, 4)
  372. This equivalent to taking a random selection from
  373. ``itertools.combinations(iterable, r)``.
  374. """
  375. pool = tuple(iterable)
  376. n = len(pool)
  377. indices = sorted(sample(range(n), r))
  378. return tuple(pool[i] for i in indices)
  379. def random_combination_with_replacement(iterable, r):
  380. """Return a random *r* length subsequence of elements in *iterable*,
  381. allowing individual elements to be repeated.
  382. >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
  383. (0, 0, 1, 2, 2)
  384. This equivalent to taking a random selection from
  385. ``itertools.combinations_with_replacement(iterable, r)``.
  386. """
  387. pool = tuple(iterable)
  388. n = len(pool)
  389. indices = sorted(randrange(n) for i in range(r))
  390. return tuple(pool[i] for i in indices)
  391. def nth_combination(iterable, r, index):
  392. """Equivalent to ``list(combinations(iterable, r))[index]``.
  393. The subsequences of *iterable* that are of length *r* can be ordered
  394. lexicographically. :func:`nth_combination` computes the subsequence at
  395. sort position *index* directly, without computing the previous
  396. subsequences.
  397. """
  398. pool = tuple(iterable)
  399. n = len(pool)
  400. if (r < 0) or (r > n):
  401. raise ValueError
  402. c = 1
  403. k = min(r, n - r)
  404. for i in range(1, k + 1):
  405. c = c * (n - k + i) // i
  406. if index < 0:
  407. index += c
  408. if (index < 0) or (index >= c):
  409. raise IndexError
  410. result = []
  411. while r:
  412. c, n, r = c * r // n, n - 1, r - 1
  413. while index >= c:
  414. index -= c
  415. c, n = c * (n - r) // n, n - 1
  416. result.append(pool[-1 - n])
  417. return tuple(result)
  418. def prepend(value, iterator):
  419. """Yield *value*, followed by the elements in *iterator*.
  420. >>> value = '0'
  421. >>> iterator = ['1', '2', '3']
  422. >>> list(prepend(value, iterator))
  423. ['0', '1', '2', '3']
  424. To prepend multiple values, see :func:`itertools.chain`.
  425. """
  426. return chain([value], iterator)