recipes.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012
  1. """Imported from the recipes section of the itertools documentation.
  2. All functions taken from the recipes section of the itertools library docs
  3. [1]_.
  4. Some backward-compatible usability improvements have been made.
  5. .. [1] http://docs.python.org/library/itertools.html#recipes
  6. """
  7. import math
  8. import operator
  9. from collections import deque
  10. from collections.abc import Sized
  11. from functools import partial, reduce
  12. from itertools import (
  13. chain,
  14. combinations,
  15. compress,
  16. count,
  17. cycle,
  18. groupby,
  19. islice,
  20. product,
  21. repeat,
  22. starmap,
  23. tee,
  24. zip_longest,
  25. )
  26. from random import randrange, sample, choice
  27. from sys import hexversion
# Public API of this module: the names bound by ``from ... import *``.
# Private helpers defined in this file (e.g. _batched, _zip_equal,
# _pairwise) are deliberately not listed.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'iter_except',
    'iter_index',
    'matmul',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique_everseen',
    'unique_justseen',
]
  75. _marker = object()
  76. # zip with strict is available for Python 3.10+
  77. try:
  78. zip(strict=True)
  79. except TypeError:
  80. _zip_strict = zip
  81. else:
  82. _zip_strict = partial(zip, strict=True)
  83. # math.sumprod is available for Python 3.12+
  84. _sumprod = getattr(math, 'sumprod', lambda x, y: dotproduct(x, y))
  85. def take(n, iterable):
  86. """Return first *n* items of the iterable as a list.
  87. >>> take(3, range(10))
  88. [0, 1, 2]
  89. If there are fewer than *n* items in the iterable, all of them are
  90. returned.
  91. >>> take(10, range(3))
  92. [0, 1, 2]
  93. """
  94. return list(islice(iterable, n))
  95. def tabulate(function, start=0):
  96. """Return an iterator over the results of ``func(start)``,
  97. ``func(start + 1)``, ``func(start + 2)``...
  98. *func* should be a function that accepts one integer argument.
  99. If *start* is not specified it defaults to 0. It will be incremented each
  100. time the iterator is advanced.
  101. >>> square = lambda x: x ** 2
  102. >>> iterator = tabulate(square, -3)
  103. >>> take(4, iterator)
  104. [9, 4, 1, 0]
  105. """
  106. return map(function, count(start))
  107. def tail(n, iterable):
  108. """Return an iterator over the last *n* items of *iterable*.
  109. >>> t = tail(3, 'ABCDEFG')
  110. >>> list(t)
  111. ['E', 'F', 'G']
  112. """
  113. # If the given iterable has a length, then we can use islice to get its
  114. # final elements. Note that if the iterable is not actually Iterable,
  115. # either islice or deque will throw a TypeError. This is why we don't
  116. # check if it is Iterable.
  117. if isinstance(iterable, Sized):
  118. yield from islice(iterable, max(0, len(iterable) - n), None)
  119. else:
  120. yield from iter(deque(iterable, maxlen=n))
  121. def consume(iterator, n=None):
  122. """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
  123. entirely.
  124. Efficiently exhausts an iterator without returning values. Defaults to
  125. consuming the whole iterator, but an optional second argument may be
  126. provided to limit consumption.
  127. >>> i = (x for x in range(10))
  128. >>> next(i)
  129. 0
  130. >>> consume(i, 3)
  131. >>> next(i)
  132. 4
  133. >>> consume(i)
  134. >>> next(i)
  135. Traceback (most recent call last):
  136. File "<stdin>", line 1, in <module>
  137. StopIteration
  138. If the iterator has fewer items remaining than the provided limit, the
  139. whole iterator will be consumed.
  140. >>> i = (x for x in range(3))
  141. >>> consume(i, 5)
  142. >>> next(i)
  143. Traceback (most recent call last):
  144. File "<stdin>", line 1, in <module>
  145. StopIteration
  146. """
  147. # Use functions that consume iterators at C speed.
  148. if n is None:
  149. # feed the entire iterator into a zero-length deque
  150. deque(iterator, maxlen=0)
  151. else:
  152. # advance to the empty slice starting at position n
  153. next(islice(iterator, n, n), None)
  154. def nth(iterable, n, default=None):
  155. """Returns the nth item or a default value.
  156. >>> l = range(10)
  157. >>> nth(l, 3)
  158. 3
  159. >>> nth(l, 20, "zebra")
  160. 'zebra'
  161. """
  162. return next(islice(iterable, n, None), default)
  163. def all_equal(iterable):
  164. """
  165. Returns ``True`` if all the elements are equal to each other.
  166. >>> all_equal('aaaa')
  167. True
  168. >>> all_equal('aaab')
  169. False
  170. """
  171. g = groupby(iterable)
  172. return next(g, True) and not next(g, False)
  173. def quantify(iterable, pred=bool):
  174. """Return the how many times the predicate is true.
  175. >>> quantify([True, False, True])
  176. 2
  177. """
  178. return sum(map(pred, iterable))
  179. def pad_none(iterable):
  180. """Returns the sequence of elements and then returns ``None`` indefinitely.
  181. >>> take(5, pad_none(range(3)))
  182. [0, 1, 2, None, None]
  183. Useful for emulating the behavior of the built-in :func:`map` function.
  184. See also :func:`padded`.
  185. """
  186. return chain(iterable, repeat(None))
  187. padnone = pad_none
  188. def ncycles(iterable, n):
  189. """Returns the sequence elements *n* times
  190. >>> list(ncycles(["a", "b"], 3))
  191. ['a', 'b', 'a', 'b', 'a', 'b']
  192. """
  193. return chain.from_iterable(repeat(tuple(iterable), n))
  194. def dotproduct(vec1, vec2):
  195. """Returns the dot product of the two iterables.
  196. >>> dotproduct([10, 10], [20, 20])
  197. 400
  198. """
  199. return sum(map(operator.mul, vec1, vec2))
  200. def flatten(listOfLists):
  201. """Return an iterator flattening one level of nesting in a list of lists.
  202. >>> list(flatten([[0, 1], [2, 3]]))
  203. [0, 1, 2, 3]
  204. See also :func:`collapse`, which can flatten multiple levels of nesting.
  205. """
  206. return chain.from_iterable(listOfLists)
  207. def repeatfunc(func, times=None, *args):
  208. """Call *func* with *args* repeatedly, returning an iterable over the
  209. results.
  210. If *times* is specified, the iterable will terminate after that many
  211. repetitions:
  212. >>> from operator import add
  213. >>> times = 4
  214. >>> args = 3, 5
  215. >>> list(repeatfunc(add, times, *args))
  216. [8, 8, 8, 8]
  217. If *times* is ``None`` the iterable will not terminate:
  218. >>> from random import randrange
  219. >>> times = None
  220. >>> args = 1, 11
  221. >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
  222. [2, 4, 8, 1, 8, 4]
  223. """
  224. if times is None:
  225. return starmap(func, repeat(args))
  226. return starmap(func, repeat(args, times))
  227. def _pairwise(iterable):
  228. """Returns an iterator of paired items, overlapping, from the original
  229. >>> take(4, pairwise(count()))
  230. [(0, 1), (1, 2), (2, 3), (3, 4)]
  231. On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
  232. """
  233. a, b = tee(iterable)
  234. next(b, None)
  235. return zip(a, b)
  236. try:
  237. from itertools import pairwise as itertools_pairwise
  238. except ImportError:
  239. pairwise = _pairwise
  240. else:
  241. def pairwise(iterable):
  242. return itertools_pairwise(iterable)
  243. pairwise.__doc__ = _pairwise.__doc__
  244. class UnequalIterablesError(ValueError):
  245. def __init__(self, details=None):
  246. msg = 'Iterables have different lengths'
  247. if details is not None:
  248. msg += (': index 0 has length {}; index {} has length {}').format(
  249. *details
  250. )
  251. super().__init__(msg)
  252. def _zip_equal_generator(iterables):
  253. for combo in zip_longest(*iterables, fillvalue=_marker):
  254. for val in combo:
  255. if val is _marker:
  256. raise UnequalIterablesError()
  257. yield combo
  258. def _zip_equal(*iterables):
  259. # Check whether the iterables are all the same size.
  260. try:
  261. first_size = len(iterables[0])
  262. for i, it in enumerate(iterables[1:], 1):
  263. size = len(it)
  264. if size != first_size:
  265. raise UnequalIterablesError(details=(first_size, i, size))
  266. # All sizes are equal, we can use the built-in zip.
  267. return zip(*iterables)
  268. # If any one of the iterables didn't have a length, start reading
  269. # them until one runs out.
  270. except TypeError:
  271. return _zip_equal_generator(iterables)
  272. def grouper(iterable, n, incomplete='fill', fillvalue=None):
  273. """Group elements from *iterable* into fixed-length groups of length *n*.
  274. >>> list(grouper('ABCDEF', 3))
  275. [('A', 'B', 'C'), ('D', 'E', 'F')]
  276. The keyword arguments *incomplete* and *fillvalue* control what happens for
  277. iterables whose length is not a multiple of *n*.
  278. When *incomplete* is `'fill'`, the last group will contain instances of
  279. *fillvalue*.
  280. >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
  281. [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
  282. When *incomplete* is `'ignore'`, the last group will not be emitted.
  283. >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
  284. [('A', 'B', 'C'), ('D', 'E', 'F')]
  285. When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised.
  286. >>> it = grouper('ABCDEFG', 3, incomplete='strict')
  287. >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL
  288. Traceback (most recent call last):
  289. ...
  290. UnequalIterablesError
  291. """
  292. args = [iter(iterable)] * n
  293. if incomplete == 'fill':
  294. return zip_longest(*args, fillvalue=fillvalue)
  295. if incomplete == 'strict':
  296. return _zip_equal(*args)
  297. if incomplete == 'ignore':
  298. return zip(*args)
  299. else:
  300. raise ValueError('Expected fill, strict, or ignore')
  301. def roundrobin(*iterables):
  302. """Yields an item from each iterable, alternating between them.
  303. >>> list(roundrobin('ABC', 'D', 'EF'))
  304. ['A', 'D', 'E', 'B', 'F', 'C']
  305. This function produces the same output as :func:`interleave_longest`, but
  306. may perform better for some inputs (in particular when the number of
  307. iterables is small).
  308. """
  309. # Recipe credited to George Sakkis
  310. pending = len(iterables)
  311. nexts = cycle(iter(it).__next__ for it in iterables)
  312. while pending:
  313. try:
  314. for next in nexts:
  315. yield next()
  316. except StopIteration:
  317. pending -= 1
  318. nexts = cycle(islice(nexts, pending))
  319. def partition(pred, iterable):
  320. """
  321. Returns a 2-tuple of iterables derived from the input iterable.
  322. The first yields the items that have ``pred(item) == False``.
  323. The second yields the items that have ``pred(item) == True``.
  324. >>> is_odd = lambda x: x % 2 != 0
  325. >>> iterable = range(10)
  326. >>> even_items, odd_items = partition(is_odd, iterable)
  327. >>> list(even_items), list(odd_items)
  328. ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
  329. If *pred* is None, :func:`bool` is used.
  330. >>> iterable = [0, 1, False, True, '', ' ']
  331. >>> false_items, true_items = partition(None, iterable)
  332. >>> list(false_items), list(true_items)
  333. ([0, False, ''], [1, True, ' '])
  334. """
  335. if pred is None:
  336. pred = bool
  337. t1, t2, p = tee(iterable, 3)
  338. p1, p2 = tee(map(pred, p))
  339. return (compress(t1, map(operator.not_, p1)), compress(t2, p2))
  340. def powerset(iterable):
  341. """Yields all possible subsets of the iterable.
  342. >>> list(powerset([1, 2, 3]))
  343. [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
  344. :func:`powerset` will operate on iterables that aren't :class:`set`
  345. instances, so repeated elements in the input will produce repeated elements
  346. in the output. Use :func:`unique_everseen` on the input to avoid generating
  347. duplicates:
  348. >>> seq = [1, 1, 0]
  349. >>> list(powerset(seq))
  350. [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
  351. >>> from more_itertools import unique_everseen
  352. >>> list(powerset(unique_everseen(seq)))
  353. [(), (1,), (0,), (1, 0)]
  354. """
  355. s = list(iterable)
  356. return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
  357. def unique_everseen(iterable, key=None):
  358. """
  359. Yield unique elements, preserving order.
  360. >>> list(unique_everseen('AAAABBBCCDAABBB'))
  361. ['A', 'B', 'C', 'D']
  362. >>> list(unique_everseen('ABBCcAD', str.lower))
  363. ['A', 'B', 'C', 'D']
  364. Sequences with a mix of hashable and unhashable items can be used.
  365. The function will be slower (i.e., `O(n^2)`) for unhashable items.
  366. Remember that ``list`` objects are unhashable - you can use the *key*
  367. parameter to transform the list to a tuple (which is hashable) to
  368. avoid a slowdown.
  369. >>> iterable = ([1, 2], [2, 3], [1, 2])
  370. >>> list(unique_everseen(iterable)) # Slow
  371. [[1, 2], [2, 3]]
  372. >>> list(unique_everseen(iterable, key=tuple)) # Faster
  373. [[1, 2], [2, 3]]
  374. Similarly, you may want to convert unhashable ``set`` objects with
  375. ``key=frozenset``. For ``dict`` objects,
  376. ``key=lambda x: frozenset(x.items())`` can be used.
  377. """
  378. seenset = set()
  379. seenset_add = seenset.add
  380. seenlist = []
  381. seenlist_add = seenlist.append
  382. use_key = key is not None
  383. for element in iterable:
  384. k = key(element) if use_key else element
  385. try:
  386. if k not in seenset:
  387. seenset_add(k)
  388. yield element
  389. except TypeError:
  390. if k not in seenlist:
  391. seenlist_add(k)
  392. yield element
  393. def unique_justseen(iterable, key=None):
  394. """Yields elements in order, ignoring serial duplicates
  395. >>> list(unique_justseen('AAAABBBCCDAABBB'))
  396. ['A', 'B', 'C', 'D', 'A', 'B']
  397. >>> list(unique_justseen('ABBCcAD', str.lower))
  398. ['A', 'B', 'C', 'A', 'D']
  399. """
  400. if key is None:
  401. return map(operator.itemgetter(0), groupby(iterable))
  402. return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
  403. def iter_except(func, exception, first=None):
  404. """Yields results from a function repeatedly until an exception is raised.
  405. Converts a call-until-exception interface to an iterator interface.
  406. Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
  407. to end the loop.
  408. >>> l = [0, 1, 2]
  409. >>> list(iter_except(l.pop, IndexError))
  410. [2, 1, 0]
  411. Multiple exceptions can be specified as a stopping condition:
  412. >>> l = [1, 2, 3, '...', 4, 5, 6]
  413. >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
  414. [7, 6, 5]
  415. >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
  416. [4, 3, 2]
  417. >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
  418. []
  419. """
  420. try:
  421. if first is not None:
  422. yield first()
  423. while 1:
  424. yield func()
  425. except exception:
  426. pass
  427. def first_true(iterable, default=None, pred=None):
  428. """
  429. Returns the first true value in the iterable.
  430. If no true value is found, returns *default*
  431. If *pred* is not None, returns the first item for which
  432. ``pred(item) == True`` .
  433. >>> first_true(range(10))
  434. 1
  435. >>> first_true(range(10), pred=lambda x: x > 5)
  436. 6
  437. >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
  438. 'missing'
  439. """
  440. return next(filter(pred, iterable), default)
  441. def random_product(*args, repeat=1):
  442. """Draw an item at random from each of the input iterables.
  443. >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
  444. ('c', 3, 'Z')
  445. If *repeat* is provided as a keyword argument, that many items will be
  446. drawn from each iterable.
  447. >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
  448. ('a', 2, 'd', 3)
  449. This equivalent to taking a random selection from
  450. ``itertools.product(*args, **kwarg)``.
  451. """
  452. pools = [tuple(pool) for pool in args] * repeat
  453. return tuple(choice(pool) for pool in pools)
  454. def random_permutation(iterable, r=None):
  455. """Return a random *r* length permutation of the elements in *iterable*.
  456. If *r* is not specified or is ``None``, then *r* defaults to the length of
  457. *iterable*.
  458. >>> random_permutation(range(5)) # doctest:+SKIP
  459. (3, 4, 0, 1, 2)
  460. This equivalent to taking a random selection from
  461. ``itertools.permutations(iterable, r)``.
  462. """
  463. pool = tuple(iterable)
  464. r = len(pool) if r is None else r
  465. return tuple(sample(pool, r))
  466. def random_combination(iterable, r):
  467. """Return a random *r* length subsequence of the elements in *iterable*.
  468. >>> random_combination(range(5), 3) # doctest:+SKIP
  469. (2, 3, 4)
  470. This equivalent to taking a random selection from
  471. ``itertools.combinations(iterable, r)``.
  472. """
  473. pool = tuple(iterable)
  474. n = len(pool)
  475. indices = sorted(sample(range(n), r))
  476. return tuple(pool[i] for i in indices)
  477. def random_combination_with_replacement(iterable, r):
  478. """Return a random *r* length subsequence of elements in *iterable*,
  479. allowing individual elements to be repeated.
  480. >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
  481. (0, 0, 1, 2, 2)
  482. This equivalent to taking a random selection from
  483. ``itertools.combinations_with_replacement(iterable, r)``.
  484. """
  485. pool = tuple(iterable)
  486. n = len(pool)
  487. indices = sorted(randrange(n) for i in range(r))
  488. return tuple(pool[i] for i in indices)
  489. def nth_combination(iterable, r, index):
  490. """Equivalent to ``list(combinations(iterable, r))[index]``.
  491. The subsequences of *iterable* that are of length *r* can be ordered
  492. lexicographically. :func:`nth_combination` computes the subsequence at
  493. sort position *index* directly, without computing the previous
  494. subsequences.
  495. >>> nth_combination(range(5), 3, 5)
  496. (0, 3, 4)
  497. ``ValueError`` will be raised If *r* is negative or greater than the length
  498. of *iterable*.
  499. ``IndexError`` will be raised if the given *index* is invalid.
  500. """
  501. pool = tuple(iterable)
  502. n = len(pool)
  503. if (r < 0) or (r > n):
  504. raise ValueError
  505. c = 1
  506. k = min(r, n - r)
  507. for i in range(1, k + 1):
  508. c = c * (n - k + i) // i
  509. if index < 0:
  510. index += c
  511. if (index < 0) or (index >= c):
  512. raise IndexError
  513. result = []
  514. while r:
  515. c, n, r = c * r // n, n - 1, r - 1
  516. while index >= c:
  517. index -= c
  518. c, n = c * (n - r) // n, n - 1
  519. result.append(pool[-1 - n])
  520. return tuple(result)
  521. def prepend(value, iterator):
  522. """Yield *value*, followed by the elements in *iterator*.
  523. >>> value = '0'
  524. >>> iterator = ['1', '2', '3']
  525. >>> list(prepend(value, iterator))
  526. ['0', '1', '2', '3']
  527. To prepend multiple values, see :func:`itertools.chain`
  528. or :func:`value_chain`.
  529. """
  530. return chain([value], iterator)
  531. def convolve(signal, kernel):
  532. """Convolve the iterable *signal* with the iterable *kernel*.
  533. >>> signal = (1, 2, 3, 4, 5)
  534. >>> kernel = [3, 2, 1]
  535. >>> list(convolve(signal, kernel))
  536. [3, 8, 14, 20, 26, 14, 5]
  537. Note: the input arguments are not interchangeable, as the *kernel*
  538. is immediately consumed and stored.
  539. """
  540. # This implementation intentionally doesn't match the one in the itertools
  541. # documentation.
  542. kernel = tuple(kernel)[::-1]
  543. n = len(kernel)
  544. window = deque([0], maxlen=n) * n
  545. for x in chain(signal, repeat(0, n - 1)):
  546. window.append(x)
  547. yield _sumprod(kernel, window)
  548. def before_and_after(predicate, it):
  549. """A variant of :func:`takewhile` that allows complete access to the
  550. remainder of the iterator.
  551. >>> it = iter('ABCdEfGhI')
  552. >>> all_upper, remainder = before_and_after(str.isupper, it)
  553. >>> ''.join(all_upper)
  554. 'ABC'
  555. >>> ''.join(remainder) # takewhile() would lose the 'd'
  556. 'dEfGhI'
  557. Note that the first iterator must be fully consumed before the second
  558. iterator can generate valid results.
  559. """
  560. it = iter(it)
  561. transition = []
  562. def true_iterator():
  563. for elem in it:
  564. if predicate(elem):
  565. yield elem
  566. else:
  567. transition.append(elem)
  568. return
  569. # Note: this is different from itertools recipes to allow nesting
  570. # before_and_after remainders into before_and_after again. See tests
  571. # for an example.
  572. remainder_iterator = chain(transition, it)
  573. return true_iterator(), remainder_iterator
  574. def triplewise(iterable):
  575. """Return overlapping triplets from *iterable*.
  576. >>> list(triplewise('ABCDE'))
  577. [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
  578. """
  579. for (a, _), (b, c) in pairwise(pairwise(iterable)):
  580. yield a, b, c
  581. def sliding_window(iterable, n):
  582. """Return a sliding window of width *n* over *iterable*.
  583. >>> list(sliding_window(range(6), 4))
  584. [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
  585. If *iterable* has fewer than *n* items, then nothing is yielded:
  586. >>> list(sliding_window(range(3), 4))
  587. []
  588. For a variant with more features, see :func:`windowed`.
  589. """
  590. it = iter(iterable)
  591. window = deque(islice(it, n - 1), maxlen=n)
  592. for x in it:
  593. window.append(x)
  594. yield tuple(window)
  595. def subslices(iterable):
  596. """Return all contiguous non-empty subslices of *iterable*.
  597. >>> list(subslices('ABC'))
  598. [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]
  599. This is similar to :func:`substrings`, but emits items in a different
  600. order.
  601. """
  602. seq = list(iterable)
  603. slices = starmap(slice, combinations(range(len(seq) + 1), 2))
  604. return map(operator.getitem, repeat(seq), slices)
  605. def polynomial_from_roots(roots):
  606. """Compute a polynomial's coefficients from its roots.
  607. >>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3)
  608. >>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60
  609. [1, -4, -17, 60]
  610. """
  611. factors = zip(repeat(1), map(operator.neg, roots))
  612. return list(reduce(convolve, factors, [1]))
  613. def iter_index(iterable, value, start=0, stop=None):
  614. """Yield the index of each place in *iterable* that *value* occurs,
  615. beginning with index *start* and ending before index *stop*.
  616. See :func:`locate` for a more general means of finding the indexes
  617. associated with particular values.
  618. >>> list(iter_index('AABCADEAF', 'A'))
  619. [0, 1, 4, 7]
  620. >>> list(iter_index('AABCADEAF', 'A', 1)) # start index is inclusive
  621. [1, 4, 7]
  622. >>> list(iter_index('AABCADEAF', 'A', 1, 7)) # stop index is not inclusive
  623. [1, 4]
  624. """
  625. seq_index = getattr(iterable, 'index', None)
  626. if seq_index is None:
  627. # Slow path for general iterables
  628. it = islice(iterable, start, stop)
  629. for i, element in enumerate(it, start):
  630. if element is value or element == value:
  631. yield i
  632. else:
  633. # Fast path for sequences
  634. stop = len(iterable) if stop is None else stop
  635. i = start - 1
  636. try:
  637. while True:
  638. yield (i := seq_index(value, i + 1, stop))
  639. except ValueError:
  640. pass
  641. def sieve(n):
  642. """Yield the primes less than n.
  643. >>> list(sieve(30))
  644. [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
  645. """
  646. if n > 2:
  647. yield 2
  648. start = 3
  649. data = bytearray((0, 1)) * (n // 2)
  650. limit = math.isqrt(n) + 1
  651. for p in iter_index(data, 1, start, limit):
  652. yield from iter_index(data, 1, start, p * p)
  653. data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
  654. start = p * p
  655. yield from iter_index(data, 1, start)
  656. def _batched(iterable, n, *, strict=False):
  657. """Batch data into tuples of length *n*. If the number of items in
  658. *iterable* is not divisible by *n*:
  659. * The last batch will be shorter if *strict* is ``False``.
  660. * :exc:`ValueError` will be raised if *strict* is ``True``.
  661. >>> list(batched('ABCDEFG', 3))
  662. [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
  663. On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
  664. """
  665. if n < 1:
  666. raise ValueError('n must be at least one')
  667. it = iter(iterable)
  668. while batch := tuple(islice(it, n)):
  669. if strict and len(batch) != n:
  670. raise ValueError('batched(): incomplete batch')
  671. yield batch
  672. if hexversion >= 0x30D00A2:
  673. from itertools import batched as itertools_batched
  674. def batched(iterable, n, *, strict=False):
  675. return itertools_batched(iterable, n, strict=strict)
  676. else:
  677. batched = _batched
  678. batched.__doc__ = _batched.__doc__
  679. def transpose(it):
  680. """Swap the rows and columns of the input matrix.
  681. >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
  682. [(1, 11), (2, 22), (3, 33)]
  683. The caller should ensure that the dimensions of the input are compatible.
  684. If the input is empty, no output will be produced.
  685. """
  686. return _zip_strict(*it)
  687. def reshape(matrix, cols):
  688. """Reshape the 2-D input *matrix* to have a column count given by *cols*.
  689. >>> matrix = [(0, 1), (2, 3), (4, 5)]
  690. >>> cols = 3
  691. >>> list(reshape(matrix, cols))
  692. [(0, 1, 2), (3, 4, 5)]
  693. """
  694. return batched(chain.from_iterable(matrix), cols)
  695. def matmul(m1, m2):
  696. """Multiply two matrices.
  697. >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
  698. [(49, 80), (41, 60)]
  699. The caller should ensure that the dimensions of the input matrices are
  700. compatible with each other.
  701. """
  702. n = len(m2[0])
  703. return batched(starmap(_sumprod, product(m1, transpose(m2))), n)
  704. def factor(n):
  705. """Yield the prime factors of n.
  706. >>> list(factor(360))
  707. [2, 2, 2, 3, 3, 5]
  708. """
  709. for prime in sieve(math.isqrt(n) + 1):
  710. while not n % prime:
  711. yield prime
  712. n //= prime
  713. if n == 1:
  714. return
  715. if n > 1:
  716. yield n
  717. def polynomial_eval(coefficients, x):
  718. """Evaluate a polynomial at a specific value.
  719. Example: evaluating x^3 - 4 * x^2 - 17 * x + 60 at x = 2.5:
  720. >>> coefficients = [1, -4, -17, 60]
  721. >>> x = 2.5
  722. >>> polynomial_eval(coefficients, x)
  723. 8.125
  724. """
  725. n = len(coefficients)
  726. if n == 0:
  727. return x * 0 # coerce zero to the type of x
  728. powers = map(pow, repeat(x), reversed(range(n)))
  729. return _sumprod(coefficients, powers)
  730. def sum_of_squares(it):
  731. """Return the sum of the squares of the input values.
  732. >>> sum_of_squares([10, 20, 30])
  733. 1400
  734. """
  735. return _sumprod(*tee(it))
  736. def polynomial_derivative(coefficients):
  737. """Compute the first derivative of a polynomial.
  738. Example: evaluating the derivative of x^3 - 4 * x^2 - 17 * x + 60
  739. >>> coefficients = [1, -4, -17, 60]
  740. >>> derivative_coefficients = polynomial_derivative(coefficients)
  741. >>> derivative_coefficients
  742. [3, -8, -17]
  743. """
  744. n = len(coefficients)
  745. powers = reversed(range(1, n))
  746. return list(map(operator.mul, coefficients, powers))
  747. def totient(n):
  748. """Return the count of natural numbers up to *n* that are coprime with *n*.
  749. >>> totient(9)
  750. 6
  751. >>> totient(12)
  752. 4
  753. """
  754. for p in unique_justseen(factor(n)):
  755. n = n // p * (p - 1)
  756. return n