traversal.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. from __future__ import annotations
  2. import collections
  3. import collections.abc
  4. import contextlib
  5. import functools
  6. import http.cookies
  7. import inspect
  8. import itertools
  9. import re
  10. import typing
  11. import xml.etree.ElementTree
  12. from ._utils import (
  13. IDENTITY,
  14. NO_DEFAULT,
  15. ExtractorError,
  16. LazyList,
  17. deprecation_warning,
  18. get_elements_html_by_class,
  19. get_elements_html_by_attribute,
  20. get_elements_by_attribute,
  21. get_element_by_class,
  22. get_element_html_by_attribute,
  23. get_element_by_attribute,
  24. get_element_html_by_id,
  25. get_element_by_id,
  26. get_element_html_by_class,
  27. get_elements_by_class,
  28. get_element_text_and_html_by_tag,
  29. is_iterable_like,
  30. try_call,
  31. url_or_none,
  32. variadic,
  33. )
def traverse_obj(
        obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
        casesense=True, is_user_input=NO_DEFAULT, traverse_string=False):
    """
    Safely traverse nested `dict`s and `Iterable`s

    >>> obj = [{}, {"key": "value"}]
    >>> traverse_obj(obj, (1, "key"))
    'value'

    Each of the provided `paths` is tested and the first producing a valid result will be returned.
    The next path will also be tested if the path branched but no results could be found.
    Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
    `xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`.
    Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.

    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.

    The keys in the path can be one of:
        - `None`:           Return the current object.
        - `set`:            Requires the only item in the set to be a type or function,
                            like `{type}`/`{type, type, ...}`/`{func}`. If a `type`, return only
                            values of this type. If a function, returns `func(obj)`.
        - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
        - `slice`:          Branch out and return all values in `obj[key]`.
        - `Ellipsis`:       Branch out and return a list of all values.
        - `tuple`/`list`:   Branch out and return a list of all matching values.
                            Read as: `[traverse_obj(obj, branch) for branch in branches]`.
        - `function`:       Branch out and return values filtered by the function.
                            Read as: `[value for key, value in obj if function(key, value)]`.
                            For `Iterable`s, `key` is the index of the value.
                            For `re.Match`es, `key` is the group number (0 = full match)
                            as well as additionally any group names, if given.
        - `dict`:           Transform the current object and return a matching dict.
                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
        - `any`-builtin:    Take the first matching object and return it, resetting branching.
        - `all`-builtin:    Take all matching objects and return them as a list, resetting branching.
        - `filter`-builtin: Keep only truthy values; falsy values are discarded.

        `tuple`, `list`, and `dict` all support nested paths and branches.

    If a `_RequiredError` (as raised by the callables returned from `require`)
    propagates out of a path, the next path is tried; if it happens on the last
    path, it is re-raised as an `ExtractorError` instead.

    @param paths            Paths by which to traverse.
    @param default          Value to return if the paths do not match.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, depth first. Try to avoid if using nested `dict` keys.
    @param expected_type    If a `type`, only accept final values of this type.
                            If any other callable, try to call the function on each result.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, recursively. This does respect branching paths.
    @param get_all          If `False`, return the first matching result, otherwise all matching ones.
    @param casesense        If `False`, consider string dictionary keys as case insensitive.
    @param is_user_input    Deprecated and ignored; passing it only emits a deprecation warning.

    `traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API

    @param traverse_string  Whether to traverse into objects as strings.
                            If `True`, any non-compatible object will first be
                            converted into a string and then traversed into.
                            The return value of that path will be a string instead,
                            not respecting any further branching.

    @returns                The result of the object traversal.
                            If successful, `get_all=True`, and the path branches at least once,
                            then a list of results is returned instead.
                            If no `default` is given and the last path branches, a `list` of results
                            is always returned. If a path ends on a `dict` that result will always be a `dict`.
    """
    if is_user_input is not NO_DEFAULT:
        deprecation_warning('The is_user_input parameter is deprecated and no longer works')

    # Normalizes string keys for the case-insensitive comparisons used when `casesense` is False
    casefold = lambda k: k.casefold() if isinstance(k, str) else k

    # `type_test` post-processes final values: an isinstance filter when
    # `expected_type` is a type, otherwise a call of it through `try_call`
    # (identity when no `expected_type` was given)
    if isinstance(expected_type, type):
        type_test = lambda val: val if isinstance(val, expected_type) else None
    else:
        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))

    def apply_key(key, obj, is_last):
        """
        Apply a single path `key` to `obj`.

        Returns `(branching, results)`: when `branching` is True, `results` is an
        iterable of values; otherwise it is a 1-tuple holding the single
        (possibly None) value.
        """
        branching = False
        result = None

        if obj is None and traverse_string:
            # Branching keys on a missing value produce no results at all
            if key is ... or callable(key) or isinstance(key, slice):
                branching = True
                result = ()

        elif key is None:
            result = obj

        elif isinstance(key, set):
            item = next(iter(key))
            if len(key) > 1 or isinstance(item, type):
                # Type filter: every member of the set must be a type
                assert all(isinstance(item, type) for item in key)
                if isinstance(obj, tuple(key)):
                    result = obj
            else:
                # Single callable: result is `item(obj)`, invoked through `try_call`
                result = try_call(item, args=(obj,))

        elif isinstance(key, (list, tuple)):
            # Alternatives: concatenate the results of every sub-path
            branching = True
            result = itertools.chain.from_iterable(
                apply_path(obj, branch, is_last)[0] for branch in key)

        elif key is ...:
            branching = True
            # Morsels are exposed as a dict that also carries their `key` and `value`
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                result = obj.values()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                result = obj
            elif isinstance(obj, re.Match):
                result = obj.groups()
            elif traverse_string:
                branching = False
                result = str(obj)
            else:
                result = ()

        elif callable(key):
            branching = True
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            # Build the (key, value) pairs the filter function is called with
            if isinstance(obj, collections.abc.Mapping):
                iter_obj = obj.items()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                iter_obj = enumerate(obj)
            elif isinstance(obj, re.Match):
                # Index 0 is the full match; named groups are also offered by name
                iter_obj = itertools.chain(
                    enumerate((obj.group(), *obj.groups())),
                    obj.groupdict().items())
            elif traverse_string:
                branching = False
                iter_obj = enumerate(str(obj))
            else:
                iter_obj = ()

            result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
            if not branching:  # string traversal
                result = ''.join(result)

        elif isinstance(key, dict):
            # Each dict value is a sub-path. Missing (None) values are replaced by
            # `default` only when one was given, otherwise dropped; an empty dict becomes None
            iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
            result = {
                k: v if v is not None else default for k, v in iter_obj
                if v is not None or default is not NO_DEFAULT
            } or None

        elif isinstance(obj, collections.abc.Mapping):
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            # Direct lookup when case-sensitive or the key is present as-is; otherwise
            # scan for a casefolded match (`key` was already casefolded by apply_path)
            result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
                      next((v for k, v in obj.items() if casefold(k) == key), None))

        elif isinstance(obj, re.Match):
            if isinstance(key, int) or casesense:
                with contextlib.suppress(IndexError):
                    result = obj.group(key)

            elif isinstance(key, str):
                result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)

        elif isinstance(key, (int, slice)):
            if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
                # Slices branch (many results); plain indices do not
                branching = isinstance(key, slice)
                with contextlib.suppress(IndexError):
                    result = obj[key]
            elif traverse_string:
                with contextlib.suppress(IndexError):
                    result = str(obj)[key]

        elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
            # Split off a trailing `@...` or `...()` selector, which apply_specials handles
            xpath, _, special = key.rpartition('/')
            if not special.startswith('@') and not special.endswith('()'):
                xpath = key
                special = None

            # Allow abbreviations of relative paths, absolute paths error
            if xpath.startswith('/'):
                xpath = f'.{xpath}'
            elif xpath and not xpath.startswith('./'):
                xpath = f'./{xpath}'

            def apply_specials(element):
                # Map an element through the trailing selector: `@` -> all attributes,
                # `@name` -> one attribute, `text()` -> element text
                if special is None:
                    return element
                if special == '@':
                    return element.attrib
                if special.startswith('@'):
                    return try_call(element.attrib.get, args=(special[1:],))
                if special == 'text()':
                    return element.text
                raise SyntaxError(f'apply_specials is missing case for {special!r}')

            if xpath:
                result = list(map(apply_specials, obj.iterfind(xpath)))
            else:
                result = apply_specials(obj)

        return branching, result if branching else (result,)

    def lazy_last(iterable):
        """Yield `(is_last, item)` pairs, looking only one item ahead."""
        iterator = iter(iterable)
        prev = next(iterator, NO_DEFAULT)
        if prev is NO_DEFAULT:
            return

        for item in iterator:
            yield False, prev
            prev = item

        yield True, prev

    def apply_path(start_obj, path, test_type):
        """
        Apply every key of `path` to `start_obj` in turn.

        Returns `(results, has_branched, ends_with_dict)`; `results` is a lazy iterable.
        """
        objs = (start_obj,)
        has_branched = False

        # Pre-bound so the post-loop checks work even for an empty path
        key = None
        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
            if not casesense and isinstance(key, str):
                key = key.casefold()

            if key in (any, all):
                # Collapse the current results into a single (non-branching) value
                has_branched = False
                filtered_objs = (obj for obj in objs if obj not in (None, {}))
                if key is any:
                    objs = (next(filtered_objs, None),)
                else:
                    objs = (list(filtered_objs),)
                continue

            if key is filter:
                objs = filter(None, objs)
                continue

            if __debug__ and callable(key):
                # Verify function signature
                inspect.signature(key).bind(None, None)

            new_objs = []
            for obj in objs:
                branching, results = apply_key(key, obj, last)
                has_branched |= branching
                new_objs.append(results)

            objs = itertools.chain.from_iterable(new_objs)

        # Container keys (dict/list/tuple) already ran their own sub-paths with the type test
        if test_type and not isinstance(key, (dict, list, tuple)):
            objs = map(type_test, objs)

        return objs, has_branched, isinstance(key, dict)

    def _traverse_obj(obj, path, allow_empty, test_type):
        """
        Traverse one path and collapse its results.

        Returns all results (when `get_all` and the path branched) or the first one.
        `allow_empty` lets an empty branching result become `[]`/`default`, and a
        path ending on a dict key become `{}`, instead of None.
        """
        results, has_branched, is_dict = apply_path(obj, path, test_type)
        results = LazyList(item for item in results if item not in (None, {}))
        if get_all and has_branched:
            if results:
                return results.exhaust()
            if allow_empty:
                return [] if default is NO_DEFAULT else default
            return None

        return results[0] if results else {} if allow_empty and is_dict else None

    for index, path in enumerate(paths, 1):
        # Only the final path may produce an "empty" result; earlier misses fall through
        is_last = index == len(paths)
        try:
            result = _traverse_obj(obj, path, is_last, True)
            if result is not None:
                return result
        except _RequiredError as e:
            if is_last:
                # Reraise to get cleaner stack trace
                raise ExtractorError(e.orig_msg, expected=e.expected) from None

    return None if default is NO_DEFAULT else default
  264. def value(value, /):
  265. return lambda _: value
  266. def require(name, /, *, expected=False):
  267. def func(value):
  268. if value is None:
  269. raise _RequiredError(f'Unable to extract {name}', expected=expected)
  270. return value
  271. return func
  272. class _RequiredError(ExtractorError):
  273. pass
  274. @typing.overload
  275. def subs_list_to_dict(*, lang: str | None = 'und', ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ...
  276. @typing.overload
  277. def subs_list_to_dict(subs: list[dict] | None, /, *, lang: str | None = 'und', ext: str | None = None) -> dict[str, list[dict]]: ...
  278. def subs_list_to_dict(subs: list[dict] | None = None, /, *, lang='und', ext=None):
  279. """
  280. Convert subtitles from a traversal into a subtitle dict.
  281. The path should have an `all` immediately before this function.
  282. Arguments:
  283. `ext` The default value for `ext` in the subtitle dict
  284. In the dict you can set the following additional items:
  285. `id` The subtitle id to sort the dict into
  286. `quality` The sort order for each subtitle
  287. """
  288. if subs is None:
  289. return functools.partial(subs_list_to_dict, lang=lang, ext=ext)
  290. result = collections.defaultdict(list)
  291. for sub in subs:
  292. if not url_or_none(sub.get('url')) and not sub.get('data'):
  293. continue
  294. sub_id = sub.pop('id', None)
  295. if not isinstance(sub_id, str):
  296. if not lang:
  297. continue
  298. sub_id = lang
  299. sub_ext = sub.get('ext')
  300. if not isinstance(sub_ext, str):
  301. if not ext:
  302. sub.pop('ext', None)
  303. else:
  304. sub['ext'] = ext
  305. result[sub_id].append(sub)
  306. result = dict(result)
  307. for subs in result.values():
  308. subs.sort(key=lambda x: x.pop('quality', 0) or 0)
  309. return result
  310. @typing.overload
  311. def find_element(*, attr: str, value: str, tag: str | None = None, html=False, regex=False): ...
  312. @typing.overload
  313. def find_element(*, cls: str, html=False): ...
  314. @typing.overload
  315. def find_element(*, id: str, tag: str | None = None, html=False, regex=False): ...
  316. @typing.overload
  317. def find_element(*, tag: str, html=False, regex=False): ...
  318. def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False, regex=False):
  319. # deliberately using `id=` and `cls=` for ease of readability
  320. assert tag or id or cls or (attr and value), 'One of tag, id, cls or (attr AND value) is required'
  321. ANY_TAG = r'[\w:.-]+'
  322. if attr and value:
  323. assert not cls, 'Cannot match both attr and cls'
  324. assert not id, 'Cannot match both attr and id'
  325. func = get_element_html_by_attribute if html else get_element_by_attribute
  326. return functools.partial(func, attr, value, tag=tag or ANY_TAG, escape_value=not regex)
  327. elif cls:
  328. assert not id, 'Cannot match both cls and id'
  329. assert tag is None, 'Cannot match both cls and tag'
  330. assert not regex, 'Cannot use regex with cls'
  331. func = get_element_html_by_class if html else get_element_by_class
  332. return functools.partial(func, cls)
  333. elif id:
  334. func = get_element_html_by_id if html else get_element_by_id
  335. return functools.partial(func, id, tag=tag or ANY_TAG, escape_value=not regex)
  336. index = int(bool(html))
  337. return lambda html: get_element_text_and_html_by_tag(tag, html)[index]
  338. @typing.overload
  339. def find_elements(*, cls: str, html=False): ...
  340. @typing.overload
  341. def find_elements(*, attr: str, value: str, tag: str | None = None, html=False, regex=False): ...
  342. def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False, regex=False):
  343. # deliberately using `cls=` for ease of readability
  344. assert cls or (attr and value), 'One of cls or (attr AND value) is required'
  345. if attr and value:
  346. assert not cls, 'Cannot match both attr and cls'
  347. func = get_elements_html_by_attribute if html else get_elements_by_attribute
  348. return functools.partial(func, attr, value, tag=tag or r'[\w:.-]+', escape_value=not regex)
  349. assert not tag, 'Cannot match both cls and tag'
  350. assert not regex, 'Cannot use regex with cls'
  351. func = get_elements_html_by_class if html else get_elements_by_class
  352. return functools.partial(func, cls)
  353. def trim_str(*, start=None, end=None):
  354. def trim(s):
  355. if s is None:
  356. return None
  357. start_idx = 0
  358. if start and s.startswith(start):
  359. start_idx = len(start)
  360. if end and s.endswith(end):
  361. return s[start_idx:-len(end)]
  362. return s[start_idx:]
  363. return trim
  364. def unpack(func, **kwargs):
  365. @functools.wraps(func)
  366. def inner(items):
  367. return func(*items, **kwargs)
  368. return inner
  369. def get_first(obj, *paths, **kwargs):
  370. return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)
  371. def dict_get(d, key_or_keys, default=None, skip_false_values=True):
  372. for val in map(d.get, variadic(key_or_keys)):
  373. if val is not None and (val or not skip_false_values):
  374. return val
  375. return default