_itertools.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from collections import defaultdict, deque
  2. from itertools import filterfalse
  3. def unique_everseen(iterable, key=None):
  4. "List unique elements, preserving order. Remember all elements ever seen."
  5. # unique_everseen('AAAABBBCCDAABBB') --> A B C D
  6. # unique_everseen('ABBCcAD', str.lower) --> A B C D
  7. seen = set()
  8. seen_add = seen.add
  9. if key is None:
  10. for element in filterfalse(seen.__contains__, iterable):
  11. seen_add(element)
  12. yield element
  13. else:
  14. for element in iterable:
  15. k = key(element)
  16. if k not in seen:
  17. seen_add(k)
  18. yield element
  19. # copied from more_itertools 8.8
  20. def always_iterable(obj, base_type=(str, bytes)):
  21. """If *obj* is iterable, return an iterator over its items::
  22. >>> obj = (1, 2, 3)
  23. >>> list(always_iterable(obj))
  24. [1, 2, 3]
  25. If *obj* is not iterable, return a one-item iterable containing *obj*::
  26. >>> obj = 1
  27. >>> list(always_iterable(obj))
  28. [1]
  29. If *obj* is ``None``, return an empty iterable:
  30. >>> obj = None
  31. >>> list(always_iterable(None))
  32. []
  33. By default, binary and text strings are not considered iterable::
  34. >>> obj = 'foo'
  35. >>> list(always_iterable(obj))
  36. ['foo']
  37. If *base_type* is set, objects for which ``isinstance(obj, base_type)``
  38. returns ``True`` won't be considered iterable.
  39. >>> obj = {'a': 1}
  40. >>> list(always_iterable(obj)) # Iterate over the dict's keys
  41. ['a']
  42. >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit
  43. [{'a': 1}]
  44. Set *base_type* to ``None`` to avoid any special handling and treat objects
  45. Python considers iterable as iterable:
  46. >>> obj = 'foo'
  47. >>> list(always_iterable(obj, base_type=None))
  48. ['f', 'o', 'o']
  49. """
  50. if obj is None:
  51. return iter(())
  52. if (base_type is not None) and isinstance(obj, base_type):
  53. return iter((obj,))
  54. try:
  55. return iter(obj)
  56. except TypeError:
  57. return iter((obj,))
  58. # Copied from more_itertools 10.3
  59. class bucket:
  60. """Wrap *iterable* and return an object that buckets the iterable into
  61. child iterables based on a *key* function.
  62. >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
  63. >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character
  64. >>> sorted(list(s)) # Get the keys
  65. ['a', 'b', 'c']
  66. >>> a_iterable = s['a']
  67. >>> next(a_iterable)
  68. 'a1'
  69. >>> next(a_iterable)
  70. 'a2'
  71. >>> list(s['b'])
  72. ['b1', 'b2', 'b3']
  73. The original iterable will be advanced and its items will be cached until
  74. they are used by the child iterables. This may require significant storage.
  75. By default, attempting to select a bucket to which no items belong will
  76. exhaust the iterable and cache all values.
  77. If you specify a *validator* function, selected buckets will instead be
  78. checked against it.
  79. >>> from itertools import count
  80. >>> it = count(1, 2) # Infinite sequence of odd numbers
  81. >>> key = lambda x: x % 10 # Bucket by last digit
  82. >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
  83. >>> s = bucket(it, key=key, validator=validator)
  84. >>> 2 in s
  85. False
  86. >>> list(s[2])
  87. []
  88. """
  89. def __init__(self, iterable, key, validator=None):
  90. self._it = iter(iterable)
  91. self._key = key
  92. self._cache = defaultdict(deque)
  93. self._validator = validator or (lambda x: True)
  94. def __contains__(self, value):
  95. if not self._validator(value):
  96. return False
  97. try:
  98. item = next(self[value])
  99. except StopIteration:
  100. return False
  101. else:
  102. self._cache[value].appendleft(item)
  103. return True
  104. def _get_values(self, value):
  105. """
  106. Helper to yield items from the parent iterator that match *value*.
  107. Items that don't match are stored in the local cache as they
  108. are encountered.
  109. """
  110. while True:
  111. # If we've cached some items that match the target value, emit
  112. # the first one and evict it from the cache.
  113. if self._cache[value]:
  114. yield self._cache[value].popleft()
  115. # Otherwise we need to advance the parent iterator to search for
  116. # a matching item, caching the rest.
  117. else:
  118. while True:
  119. try:
  120. item = next(self._it)
  121. except StopIteration:
  122. return
  123. item_value = self._key(item)
  124. if item_value == value:
  125. yield item
  126. break
  127. elif self._validator(item_value):
  128. self._cache[item_value].append(item)
  129. def __iter__(self):
  130. for item in self._it:
  131. item_value = self._key(item)
  132. if self._validator(item_value):
  133. self._cache[item_value].append(item)
  134. yield from self._cache.keys()
  135. def __getitem__(self, value):
  136. if not self._validator(value):
  137. return iter(())
  138. return self._get_values(value)