utils.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. from collections import OrderedDict, deque
  2. from datetime import date, time, datetime
  3. from decimal import Decimal
  4. from fractions import Fraction
  5. import ast
  6. import enum
  7. import typing
  8. class CannotEval(Exception):
  9. def __repr__(self):
  10. return self.__class__.__name__
  11. __str__ = __repr__
  12. def is_any(x, *args):
  13. return any(
  14. x is arg
  15. for arg in args
  16. )
  17. def of_type(x, *types):
  18. if is_any(type(x), *types):
  19. return x
  20. else:
  21. raise CannotEval
  22. def of_standard_types(x, *, check_dict_values: bool, deep: bool):
  23. if is_standard_types(x, check_dict_values=check_dict_values, deep=deep):
  24. return x
  25. else:
  26. raise CannotEval
  27. def is_standard_types(x, *, check_dict_values: bool, deep: bool):
  28. try:
  29. return _is_standard_types_deep(x, check_dict_values, deep)[0]
  30. except RecursionError:
  31. return False
  32. def _is_standard_types_deep(x, check_dict_values: bool, deep: bool):
  33. typ = type(x)
  34. if is_any(
  35. typ,
  36. str,
  37. int,
  38. bool,
  39. float,
  40. bytes,
  41. complex,
  42. date,
  43. time,
  44. datetime,
  45. Fraction,
  46. Decimal,
  47. type(None),
  48. object,
  49. ):
  50. return True, 0
  51. if is_any(typ, tuple, frozenset, list, set, dict, OrderedDict, deque, slice):
  52. if typ in [slice]:
  53. length = 0
  54. else:
  55. length = len(x)
  56. assert isinstance(deep, bool)
  57. if not deep:
  58. return True, length
  59. if check_dict_values and typ in (dict, OrderedDict):
  60. items = (v for pair in x.items() for v in pair)
  61. elif typ is slice:
  62. items = [x.start, x.stop, x.step]
  63. else:
  64. items = x
  65. for item in items:
  66. if length > 100000:
  67. return False, length
  68. is_standard, item_length = _is_standard_types_deep(
  69. item, check_dict_values, deep
  70. )
  71. if not is_standard:
  72. return False, length
  73. length += item_length
  74. return True, length
  75. return False, 0
  76. class _E(enum.Enum):
  77. pass
  78. class _C:
  79. def foo(self): pass # pragma: nocover
  80. def bar(self): pass # pragma: nocover
  81. @classmethod
  82. def cm(cls): pass # pragma: nocover
  83. @staticmethod
  84. def sm(): pass # pragma: nocover
  85. safe_name_samples = {
  86. "len": len,
  87. "append": list.append,
  88. "__add__": list.__add__,
  89. "insert": [].insert,
  90. "__mul__": [].__mul__,
  91. "fromkeys": dict.__dict__['fromkeys'],
  92. "is_any": is_any,
  93. "__repr__": CannotEval.__repr__,
  94. "foo": _C().foo,
  95. "bar": _C.bar,
  96. "cm": _C.cm,
  97. "sm": _C.sm,
  98. "ast": ast,
  99. "CannotEval": CannotEval,
  100. "_E": _E,
  101. }
  102. typing_annotation_samples = {
  103. name: getattr(typing, name)
  104. for name in "List Dict Tuple Set Callable Mapping".split()
  105. }
  106. safe_name_types = tuple({
  107. type(f)
  108. for f in safe_name_samples.values()
  109. })
  110. typing_annotation_types = tuple({
  111. type(f)
  112. for f in typing_annotation_samples.values()
  113. })
  114. def eq_checking_types(a, b):
  115. return type(a) is type(b) and a == b
  116. def ast_name(node):
  117. if isinstance(node, ast.Name):
  118. return node.id
  119. elif isinstance(node, ast.Attribute):
  120. return node.attr
  121. else:
  122. return None
  123. def safe_name(value):
  124. typ = type(value)
  125. if is_any(typ, *safe_name_types):
  126. return value.__name__
  127. elif value is typing.Optional:
  128. return "Optional"
  129. elif value is typing.Union:
  130. return "Union"
  131. elif is_any(typ, *typing_annotation_types):
  132. return getattr(value, "__name__", None) or getattr(value, "_name", None)
  133. else:
  134. return None
  135. def has_ast_name(value, node):
  136. value_name = safe_name(value)
  137. if type(value_name) is not str:
  138. return False
  139. return eq_checking_types(ast_name(node), value_name)
  140. def copy_ast_without_context(x):
  141. if isinstance(x, ast.AST):
  142. kwargs = {
  143. field: copy_ast_without_context(getattr(x, field))
  144. for field in x._fields
  145. if field != 'ctx'
  146. if hasattr(x, field)
  147. }
  148. a = type(x)(**kwargs)
  149. if hasattr(a, 'ctx'):
  150. # Python 3.13.0b2+ defaults to Load when we don't pass ctx
  151. # https://github.com/python/cpython/pull/118871
  152. del a.ctx
  153. return a
  154. elif isinstance(x, list):
  155. return list(map(copy_ast_without_context, x))
  156. else:
  157. return x
  158. def ensure_dict(x):
  159. """
  160. Handles invalid non-dict inputs
  161. """
  162. try:
  163. return dict(x)
  164. except Exception:
  165. return {}