_auth.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. import hashlib
  2. import os
  3. import re
  4. import time
  5. import typing
  6. from base64 import b64encode
  7. from urllib.request import parse_http_list
  8. from ._exceptions import ProtocolError
  9. from ._models import Cookies, Request, Response
  10. from ._utils import to_bytes, to_str, unquote
  11. if typing.TYPE_CHECKING: # pragma: no cover
  12. from hashlib import _Hash
  13. class Auth:
  14. """
  15. Base class for all authentication schemes.
  16. To implement a custom authentication scheme, subclass `Auth` and override
  17. the `.auth_flow()` method.
  18. If the authentication scheme does I/O such as disk access or network calls, or uses
  19. synchronization primitives such as locks, you should override `.sync_auth_flow()`
  20. and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
  21. implementations that will be used by `Client` and `AsyncClient` respectively.
  22. """
  23. requires_request_body = False
  24. requires_response_body = False
  25. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  26. """
  27. Execute the authentication flow.
  28. To dispatch a request, `yield` it:
  29. ```
  30. yield request
  31. ```
  32. The client will `.send()` the response back into the flow generator. You can
  33. access it like so:
  34. ```
  35. response = yield request
  36. ```
  37. A `return` (or reaching the end of the generator) will result in the
  38. client returning the last response obtained from the server.
  39. You can dispatch as many requests as is necessary.
  40. """
  41. yield request
  42. def sync_auth_flow(
  43. self, request: Request
  44. ) -> typing.Generator[Request, Response, None]:
  45. """
  46. Execute the authentication flow synchronously.
  47. By default, this defers to `.auth_flow()`. You should override this method
  48. when the authentication scheme does I/O and/or uses concurrency primitives.
  49. """
  50. if self.requires_request_body:
  51. request.read()
  52. flow = self.auth_flow(request)
  53. request = next(flow)
  54. while True:
  55. response = yield request
  56. if self.requires_response_body:
  57. response.read()
  58. try:
  59. request = flow.send(response)
  60. except StopIteration:
  61. break
  62. async def async_auth_flow(
  63. self, request: Request
  64. ) -> typing.AsyncGenerator[Request, Response]:
  65. """
  66. Execute the authentication flow asynchronously.
  67. By default, this defers to `.auth_flow()`. You should override this method
  68. when the authentication scheme does I/O and/or uses concurrency primitives.
  69. """
  70. if self.requires_request_body:
  71. await request.aread()
  72. flow = self.auth_flow(request)
  73. request = next(flow)
  74. while True:
  75. response = yield request
  76. if self.requires_response_body:
  77. await response.aread()
  78. try:
  79. request = flow.send(response)
  80. except StopIteration:
  81. break
  82. class FunctionAuth(Auth):
  83. """
  84. Allows the 'auth' argument to be passed as a simple callable function,
  85. that takes the request, and returns a new, modified request.
  86. """
  87. def __init__(self, func: typing.Callable[[Request], Request]) -> None:
  88. self._func = func
  89. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  90. yield self._func(request)
  91. class BasicAuth(Auth):
  92. """
  93. Allows the 'auth' argument to be passed as a (username, password) pair,
  94. and uses HTTP Basic authentication.
  95. """
  96. def __init__(
  97. self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
  98. ):
  99. self._auth_header = self._build_auth_header(username, password)
  100. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  101. request.headers["Authorization"] = self._auth_header
  102. yield request
  103. def _build_auth_header(
  104. self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
  105. ) -> str:
  106. userpass = b":".join((to_bytes(username), to_bytes(password)))
  107. token = b64encode(userpass).decode()
  108. return f"Basic {token}"
  109. class NetRCAuth(Auth):
  110. """
  111. Use a 'netrc' file to lookup basic auth credentials based on the url host.
  112. """
  113. def __init__(self, file: typing.Optional[str] = None):
  114. # Lazily import 'netrc'.
  115. # There's no need for us to load this module unless 'NetRCAuth' is being used.
  116. import netrc
  117. self._netrc_info = netrc.netrc(file)
  118. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  119. auth_info = self._netrc_info.authenticators(request.url.host)
  120. if auth_info is None or not auth_info[2]:
  121. # The netrc file did not have authentication credentials for this host.
  122. yield request
  123. else:
  124. # Build a basic auth header with credentials from the netrc file.
  125. request.headers["Authorization"] = self._build_auth_header(
  126. username=auth_info[0], password=auth_info[2]
  127. )
  128. yield request
  129. def _build_auth_header(
  130. self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
  131. ) -> str:
  132. userpass = b":".join((to_bytes(username), to_bytes(password)))
  133. token = b64encode(userpass).decode()
  134. return f"Basic {token}"
  135. class DigestAuth(Auth):
  136. _ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable[[bytes], "_Hash"]] = {
  137. "MD5": hashlib.md5,
  138. "MD5-SESS": hashlib.md5,
  139. "SHA": hashlib.sha1,
  140. "SHA-SESS": hashlib.sha1,
  141. "SHA-256": hashlib.sha256,
  142. "SHA-256-SESS": hashlib.sha256,
  143. "SHA-512": hashlib.sha512,
  144. "SHA-512-SESS": hashlib.sha512,
  145. }
  146. def __init__(
  147. self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
  148. ) -> None:
  149. self._username = to_bytes(username)
  150. self._password = to_bytes(password)
  151. self._last_challenge: typing.Optional[_DigestAuthChallenge] = None
  152. self._nonce_count = 1
  153. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  154. if self._last_challenge:
  155. request.headers["Authorization"] = self._build_auth_header(
  156. request, self._last_challenge
  157. )
  158. response = yield request
  159. if response.status_code != 401 or "www-authenticate" not in response.headers:
  160. # If the response is not a 401 then we don't
  161. # need to build an authenticated request.
  162. return
  163. for auth_header in response.headers.get_list("www-authenticate"):
  164. if auth_header.lower().startswith("digest "):
  165. break
  166. else:
  167. # If the response does not include a 'WWW-Authenticate: Digest ...'
  168. # header, then we don't need to build an authenticated request.
  169. return
  170. self._last_challenge = self._parse_challenge(request, response, auth_header)
  171. self._nonce_count = 1
  172. request.headers["Authorization"] = self._build_auth_header(
  173. request, self._last_challenge
  174. )
  175. if response.cookies:
  176. Cookies(response.cookies).set_cookie_header(request=request)
  177. yield request
  178. def _parse_challenge(
  179. self, request: Request, response: Response, auth_header: str
  180. ) -> "_DigestAuthChallenge":
  181. """
  182. Returns a challenge from a Digest WWW-Authenticate header.
  183. These take the form of:
  184. `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
  185. """
  186. scheme, _, fields = auth_header.partition(" ")
  187. # This method should only ever have been called with a Digest auth header.
  188. assert scheme.lower() == "digest"
  189. header_dict: typing.Dict[str, str] = {}
  190. for field in parse_http_list(fields):
  191. key, value = field.strip().split("=", 1)
  192. header_dict[key] = unquote(value)
  193. try:
  194. realm = header_dict["realm"].encode()
  195. nonce = header_dict["nonce"].encode()
  196. algorithm = header_dict.get("algorithm", "MD5")
  197. opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
  198. qop = header_dict["qop"].encode() if "qop" in header_dict else None
  199. return _DigestAuthChallenge(
  200. realm=realm, nonce=nonce, algorithm=algorithm, opaque=opaque, qop=qop
  201. )
  202. except KeyError as exc:
  203. message = "Malformed Digest WWW-Authenticate header"
  204. raise ProtocolError(message, request=request) from exc
  205. def _build_auth_header(
  206. self, request: Request, challenge: "_DigestAuthChallenge"
  207. ) -> str:
  208. hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()]
  209. def digest(data: bytes) -> bytes:
  210. return hash_func(data).hexdigest().encode()
  211. A1 = b":".join((self._username, challenge.realm, self._password))
  212. path = request.url.raw_path
  213. A2 = b":".join((request.method.encode(), path))
  214. # TODO: implement auth-int
  215. HA2 = digest(A2)
  216. nc_value = b"%08x" % self._nonce_count
  217. cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce)
  218. self._nonce_count += 1
  219. HA1 = digest(A1)
  220. if challenge.algorithm.lower().endswith("-sess"):
  221. HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))
  222. qop = self._resolve_qop(challenge.qop, request=request)
  223. if qop is None:
  224. digest_data = [HA1, challenge.nonce, HA2]
  225. else:
  226. digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2]
  227. key_digest = b":".join(digest_data)
  228. format_args = {
  229. "username": self._username,
  230. "realm": challenge.realm,
  231. "nonce": challenge.nonce,
  232. "uri": path,
  233. "response": digest(b":".join((HA1, key_digest))),
  234. "algorithm": challenge.algorithm.encode(),
  235. }
  236. if challenge.opaque:
  237. format_args["opaque"] = challenge.opaque
  238. if qop:
  239. format_args["qop"] = b"auth"
  240. format_args["nc"] = nc_value
  241. format_args["cnonce"] = cnonce
  242. return "Digest " + self._get_header_value(format_args)
  243. def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
  244. s = str(nonce_count).encode()
  245. s += nonce
  246. s += time.ctime().encode()
  247. s += os.urandom(8)
  248. return hashlib.sha1(s).hexdigest()[:16].encode()
  249. def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
  250. NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
  251. QUOTED_TEMPLATE = '{}="{}"'
  252. NON_QUOTED_TEMPLATE = "{}={}"
  253. header_value = ""
  254. for i, (field, value) in enumerate(header_fields.items()):
  255. if i > 0:
  256. header_value += ", "
  257. template = (
  258. QUOTED_TEMPLATE
  259. if field not in NON_QUOTED_FIELDS
  260. else NON_QUOTED_TEMPLATE
  261. )
  262. header_value += template.format(field, to_str(value))
  263. return header_value
  264. def _resolve_qop(
  265. self, qop: typing.Optional[bytes], request: Request
  266. ) -> typing.Optional[bytes]:
  267. if qop is None:
  268. return None
  269. qops = re.split(b", ?", qop)
  270. if b"auth" in qops:
  271. return b"auth"
  272. if qops == [b"auth-int"]:
  273. raise NotImplementedError("Digest auth-int support is not yet implemented")
  274. message = f'Unexpected qop value "{qop!r}" in digest auth'
  275. raise ProtocolError(message, request=request)
  276. class _DigestAuthChallenge(typing.NamedTuple):
  277. realm: bytes
  278. nonce: bytes
  279. algorithm: str
  280. opaque: typing.Optional[bytes]
  281. qop: typing.Optional[bytes]