utils.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import ast
  2. import itertools
  3. import types
  4. from collections import OrderedDict, Counter, defaultdict
  5. from types import FrameType, TracebackType
  6. from typing import (
  7. Iterator, List, Tuple, Iterable, Callable, Union,
  8. TypeVar, Mapping,
  9. )
  10. from asttokens import ASTText
  11. T = TypeVar('T')
  12. R = TypeVar('R')
  13. def truncate(seq, max_length: int, middle):
  14. if len(seq) > max_length:
  15. right = (max_length - len(middle)) // 2
  16. left = max_length - len(middle) - right
  17. seq = seq[:left] + middle + seq[-right:]
  18. return seq
  19. def unique_in_order(it: Iterable[T]) -> List[T]:
  20. return list(OrderedDict.fromkeys(it))
  21. def line_range(atok: ASTText, node: ast.AST) -> Tuple[int, int]:
  22. """
  23. Returns a pair of numbers representing a half open range
  24. (i.e. suitable as arguments to the `range()` builtin)
  25. of line numbers of the given AST nodes.
  26. """
  27. if isinstance(node, getattr(ast, "match_case", ())):
  28. start, _end = line_range(atok, node.pattern)
  29. _start, end = line_range(atok, node.body[-1])
  30. return start, end
  31. else:
  32. (start, _), (end, _) = atok.get_text_positions(node, padded=False)
  33. return start, end + 1
  34. def highlight_unique(lst: List[T]) -> Iterator[Tuple[T, bool]]:
  35. counts = Counter(lst)
  36. for is_common, group in itertools.groupby(lst, key=lambda x: counts[x] > 3):
  37. if is_common:
  38. group = list(group)
  39. highlighted = [False] * len(group)
  40. def highlight_index(f):
  41. try:
  42. i = f()
  43. except ValueError:
  44. return None
  45. highlighted[i] = True
  46. return i
  47. for item in set(group):
  48. first = highlight_index(lambda: group.index(item))
  49. if first is not None:
  50. highlight_index(lambda: group.index(item, first + 1))
  51. highlight_index(lambda: -1 - group[::-1].index(item))
  52. else:
  53. highlighted = itertools.repeat(True)
  54. yield from zip(group, highlighted)
  55. def identity(x: T) -> T:
  56. return x
  57. def collapse_repeated(lst, *, collapser, mapper=identity, key=identity):
  58. keyed = list(map(key, lst))
  59. for is_highlighted, group in itertools.groupby(
  60. zip(lst, highlight_unique(keyed)),
  61. key=lambda t: t[1][1],
  62. ):
  63. original_group, highlighted_group = zip(*group)
  64. if is_highlighted:
  65. yield from map(mapper, original_group)
  66. else:
  67. keyed_group, _ = zip(*highlighted_group)
  68. yield collapser(list(original_group), list(keyed_group))
  69. def is_frame(frame_or_tb: Union[FrameType, TracebackType]) -> bool:
  70. assert_(isinstance(frame_or_tb, (types.FrameType, types.TracebackType)))
  71. return isinstance(frame_or_tb, (types.FrameType,))
  72. def iter_stack(frame_or_tb: Union[FrameType, TracebackType]) -> Iterator[Union[FrameType, TracebackType]]:
  73. current: Union[FrameType, TracebackType, None] = frame_or_tb
  74. while current:
  75. yield current
  76. if is_frame(current):
  77. current = current.f_back
  78. else:
  79. current = current.tb_next
  80. def frame_and_lineno(frame_or_tb: Union[FrameType, TracebackType]) -> Tuple[FrameType, int]:
  81. if is_frame(frame_or_tb):
  82. return frame_or_tb, frame_or_tb.f_lineno
  83. else:
  84. return frame_or_tb.tb_frame, frame_or_tb.tb_lineno
  85. def group_by_key_func(iterable: Iterable[T], key_func: Callable[[T], R]) -> Mapping[R, List[T]]:
  86. # noinspection PyUnresolvedReferences
  87. """
  88. Create a dictionary from an iterable such that the keys are the result of evaluating a key function on elements
  89. of the iterable and the values are lists of elements all of which correspond to the key.
  90. >>> def si(d): return sorted(d.items())
  91. >>> si(group_by_key_func("a bb ccc d ee fff".split(), len))
  92. [(1, ['a', 'd']), (2, ['bb', 'ee']), (3, ['ccc', 'fff'])]
  93. >>> si(group_by_key_func([-1, 0, 1, 3, 6, 8, 9, 2], lambda x: x % 2))
  94. [(0, [0, 6, 8, 2]), (1, [-1, 1, 3, 9])]
  95. """
  96. result = defaultdict(list)
  97. for item in iterable:
  98. result[key_func(item)].append(item)
  99. return result
  100. class cached_property(object):
  101. """
  102. A property that is only computed once per instance and then replaces itself
  103. with an ordinary attribute. Deleting the attribute resets the property.
  104. Based on https://github.com/pydanny/cached-property/blob/master/cached_property.py
  105. """
  106. def __init__(self, func):
  107. self.__doc__ = func.__doc__
  108. self.func = func
  109. def cached_property_wrapper(self, obj, _cls):
  110. if obj is None:
  111. return self
  112. value = obj.__dict__[self.func.__name__] = self.func(obj)
  113. return value
  114. __get__ = cached_property_wrapper
  115. def _pygmented_with_ranges(formatter, code, ranges):
  116. import pygments
  117. from pygments.lexers import get_lexer_by_name
  118. class MyLexer(type(get_lexer_by_name("python3"))):
  119. def get_tokens(self, text):
  120. length = 0
  121. for ttype, value in super().get_tokens(text):
  122. if any(start <= length < end for start, end in ranges):
  123. ttype = ttype.ExecutingNode
  124. length += len(value)
  125. yield ttype, value
  126. lexer = MyLexer(stripnl=False)
  127. try:
  128. highlighted = pygments.highlight(code, lexer, formatter)
  129. except Exception:
  130. # When pygments fails, prefer code without highlighting over crashing
  131. highlighted = code
  132. return highlighted.splitlines()
  133. def assert_(condition, error=""):
  134. if not condition:
  135. if isinstance(error, str):
  136. error = AssertionError(error)
  137. raise error
  138. # Copied from the standard traceback module pre-3.11
  139. def some_str(value):
  140. try:
  141. return str(value)
  142. except:
  143. return '<unprintable %s object>' % type(value).__name__