impersonate.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from __future__ import annotations
  2. import re
  3. from abc import ABC
  4. from dataclasses import dataclass
  5. from typing import Any
  6. from .common import RequestHandler, register_preference
  7. from .exceptions import UnsupportedRequest
  8. from ..compat.types import NoneType
  9. from ..utils import classproperty, join_nonempty
  10. from ..utils.networking import std_headers
  11. @dataclass(order=True, frozen=True)
  12. class ImpersonateTarget:
  13. """
  14. A target for browser impersonation.
  15. Parameters:
  16. @param client: the client to impersonate
  17. @param version: the client version to impersonate
  18. @param os: the client OS to impersonate
  19. @param os_version: the client OS version to impersonate
  20. Note: None is used to indicate to match any.
  21. """
  22. client: str | None = None
  23. version: str | None = None
  24. os: str | None = None
  25. os_version: str | None = None
  26. def __post_init__(self):
  27. if self.version and not self.client:
  28. raise ValueError('client is required if version is set')
  29. if self.os_version and not self.os:
  30. raise ValueError('os is required if os_version is set')
  31. def __contains__(self, target: ImpersonateTarget):
  32. if not isinstance(target, ImpersonateTarget):
  33. return False
  34. return (
  35. (self.client is None or target.client is None or self.client == target.client)
  36. and (self.version is None or target.version is None or self.version == target.version)
  37. and (self.os is None or target.os is None or self.os == target.os)
  38. and (self.os_version is None or target.os_version is None or self.os_version == target.os_version)
  39. )
  40. def __str__(self):
  41. return f'{join_nonempty(self.client, self.version)}:{join_nonempty(self.os, self.os_version)}'.rstrip(':')
  42. @classmethod
  43. def from_str(cls, target: str):
  44. mobj = re.fullmatch(r'(?:(?P<client>[^:-]+)(?:-(?P<version>[^:-]+))?)?(?::(?:(?P<os>[^:-]+)(?:-(?P<os_version>[^:-]+))?)?)?', target)
  45. if not mobj:
  46. raise ValueError(f'Invalid impersonate target "{target}"')
  47. return cls(**mobj.groupdict())
  48. class ImpersonateRequestHandler(RequestHandler, ABC):
  49. """
  50. Base class for request handlers that support browser impersonation.
  51. This provides a method for checking the validity of the impersonate extension,
  52. which can be used in _check_extensions.
  53. Impersonate targets consist of a client, version, os and os_ver.
  54. See the ImpersonateTarget class for more details.
  55. The following may be defined:
  56. - `_SUPPORTED_IMPERSONATE_TARGET_MAP`: a dict mapping supported targets to custom object.
  57. Any Request with an impersonate target not in this list will raise an UnsupportedRequest.
  58. Set to None to disable this check.
  59. Note: Entries are in order of preference
  60. Parameters:
  61. @param impersonate: the default impersonate target to use for requests.
  62. Set to None to disable impersonation.
  63. """
  64. _SUPPORTED_IMPERSONATE_TARGET_MAP: dict[ImpersonateTarget, Any] = {}
  65. def __init__(self, *, impersonate: ImpersonateTarget = None, **kwargs):
  66. super().__init__(**kwargs)
  67. self.impersonate = impersonate
  68. def _check_impersonate_target(self, target: ImpersonateTarget):
  69. assert isinstance(target, (ImpersonateTarget, NoneType))
  70. if target is None or not self.supported_targets:
  71. return
  72. if not self.is_supported_target(target):
  73. raise UnsupportedRequest(f'Unsupported impersonate target: {target}')
  74. def _check_extensions(self, extensions):
  75. super()._check_extensions(extensions)
  76. if 'impersonate' in extensions:
  77. self._check_impersonate_target(extensions.get('impersonate'))
  78. def _validate(self, request):
  79. super()._validate(request)
  80. self._check_impersonate_target(self.impersonate)
  81. def _resolve_target(self, target: ImpersonateTarget | None):
  82. """Resolve a target to a supported target."""
  83. if target is None:
  84. return
  85. for supported_target in self.supported_targets:
  86. if target in supported_target:
  87. if self.verbose:
  88. self._logger.stdout(
  89. f'{self.RH_NAME}: resolved impersonate target {target} to {supported_target}')
  90. return supported_target
  91. @classproperty
  92. def supported_targets(cls) -> tuple[ImpersonateTarget, ...]:
  93. return tuple(cls._SUPPORTED_IMPERSONATE_TARGET_MAP.keys())
  94. def is_supported_target(self, target: ImpersonateTarget):
  95. assert isinstance(target, ImpersonateTarget)
  96. return self._resolve_target(target) is not None
  97. def _get_request_target(self, request):
  98. """Get the requested target for the request"""
  99. return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)
  100. def _get_impersonate_headers(self, request):
  101. headers = self._merge_headers(request.headers)
  102. if self._get_request_target(request) is not None:
  103. # remove all headers present in std_headers
  104. # TODO: change this to not depend on std_headers
  105. for k, v in std_headers.items():
  106. if headers.get(k) == v:
  107. headers.pop(k)
  108. return headers
  109. @register_preference(ImpersonateRequestHandler)
  110. def impersonate_preference(rh, request):
  111. if request.extensions.get('impersonate') or rh.impersonate:
  112. return 1000
  113. return 0