oauth2_session.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. import logging
  2. from oauthlib.common import generate_token, urldecode
  3. from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError
  4. from oauthlib.oauth2 import LegacyApplicationClient
  5. from oauthlib.oauth2 import TokenExpiredError, is_secure_transport
  6. import requests
  7. log = logging.getLogger(__name__)
  8. class TokenUpdated(Warning):
  9. def __init__(self, token):
  10. super(TokenUpdated, self).__init__()
  11. self.token = token
  12. class OAuth2Session(requests.Session):
  13. """Versatile OAuth 2 extension to :class:`requests.Session`.
  14. Supports any grant type adhering to :class:`oauthlib.oauth2.Client` spec
  15. including the four core OAuth 2 grants.
  16. Can be used to create authorization urls, fetch tokens and access protected
  17. resources using the :class:`requests.Session` interface you are used to.
  18. - :class:`oauthlib.oauth2.WebApplicationClient` (default): Authorization Code Grant
  19. - :class:`oauthlib.oauth2.MobileApplicationClient`: Implicit Grant
  20. - :class:`oauthlib.oauth2.LegacyApplicationClient`: Password Credentials Grant
  21. - :class:`oauthlib.oauth2.BackendApplicationClient`: Client Credentials Grant
  22. Note that the only time you will be using Implicit Grant from python is if
  23. you are driving a user agent able to obtain URL fragments.
  24. """
  25. def __init__(
  26. self,
  27. client_id=None,
  28. client=None,
  29. auto_refresh_url=None,
  30. auto_refresh_kwargs=None,
  31. scope=None,
  32. redirect_uri=None,
  33. token=None,
  34. state=None,
  35. token_updater=None,
  36. pkce=None,
  37. **kwargs
  38. ):
  39. """Construct a new OAuth 2 client session.
  40. :param client_id: Client id obtained during registration
  41. :param client: :class:`oauthlib.oauth2.Client` to be used. Default is
  42. WebApplicationClient which is useful for any
  43. hosted application but not mobile or desktop.
  44. :param scope: List of scopes you wish to request access to
  45. :param redirect_uri: Redirect URI you registered as callback
  46. :param token: Token dictionary, must include access_token
  47. and token_type.
  48. :param state: State string used to prevent CSRF. This will be given
  49. when creating the authorization url and must be supplied
  50. when parsing the authorization response.
  51. Can be either a string or a no argument callable.
  52. :auto_refresh_url: Refresh token endpoint URL, must be HTTPS. Supply
  53. this if you wish the client to automatically refresh
  54. your access tokens.
  55. :auto_refresh_kwargs: Extra arguments to pass to the refresh token
  56. endpoint.
  57. :token_updater: Method with one argument, token, to be used to update
  58. your token database on automatic token refresh. If not
  59. set a TokenUpdated warning will be raised when a token
  60. has been refreshed. This warning will carry the token
  61. in its token argument.
  62. :param pkce: Set "S256" or "plain" to enable PKCE. Default is disabled.
  63. :param kwargs: Arguments to pass to the Session constructor.
  64. """
  65. super(OAuth2Session, self).__init__(**kwargs)
  66. self._client = client or WebApplicationClient(client_id, token=token)
  67. self.token = token or {}
  68. self._scope = scope
  69. self.redirect_uri = redirect_uri
  70. self.state = state or generate_token
  71. self._state = state
  72. self.auto_refresh_url = auto_refresh_url
  73. self.auto_refresh_kwargs = auto_refresh_kwargs or {}
  74. self.token_updater = token_updater
  75. self._pkce = pkce
  76. if self._pkce not in ["S256", "plain", None]:
  77. raise AttributeError("Wrong value for {}(.., pkce={})".format(self.__class__, self._pkce))
  78. # Ensure that requests doesn't do any automatic auth. See #278.
  79. # The default behavior can be re-enabled by setting auth to None.
  80. self.auth = lambda r: r
  81. # Allow customizations for non compliant providers through various
  82. # hooks to adjust requests and responses.
  83. self.compliance_hook = {
  84. "access_token_response": set(),
  85. "refresh_token_response": set(),
  86. "protected_request": set(),
  87. "refresh_token_request": set(),
  88. "access_token_request": set(),
  89. }
  90. @property
  91. def scope(self):
  92. """By default the scope from the client is used, except if overridden"""
  93. if self._scope is not None:
  94. return self._scope
  95. elif self._client is not None:
  96. return self._client.scope
  97. else:
  98. return None
  99. @scope.setter
  100. def scope(self, scope):
  101. self._scope = scope
  102. def new_state(self):
  103. """Generates a state string to be used in authorizations."""
  104. try:
  105. self._state = self.state()
  106. log.debug("Generated new state %s.", self._state)
  107. except TypeError:
  108. self._state = self.state
  109. log.debug("Re-using previously supplied state %s.", self._state)
  110. return self._state
  111. @property
  112. def client_id(self):
  113. return getattr(self._client, "client_id", None)
  114. @client_id.setter
  115. def client_id(self, value):
  116. self._client.client_id = value
  117. @client_id.deleter
  118. def client_id(self):
  119. del self._client.client_id
  120. @property
  121. def token(self):
  122. return getattr(self._client, "token", None)
  123. @token.setter
  124. def token(self, value):
  125. self._client.token = value
  126. self._client.populate_token_attributes(value)
  127. @property
  128. def access_token(self):
  129. return getattr(self._client, "access_token", None)
  130. @access_token.setter
  131. def access_token(self, value):
  132. self._client.access_token = value
  133. @access_token.deleter
  134. def access_token(self):
  135. del self._client.access_token
  136. @property
  137. def authorized(self):
  138. """Boolean that indicates whether this session has an OAuth token
  139. or not. If `self.authorized` is True, you can reasonably expect
  140. OAuth-protected requests to the resource to succeed. If
  141. `self.authorized` is False, you need the user to go through the OAuth
  142. authentication dance before OAuth-protected requests to the resource
  143. will succeed.
  144. """
  145. return bool(self.access_token)
  146. def authorization_url(self, url, state=None, **kwargs):
  147. """Form an authorization URL.
  148. :param url: Authorization endpoint url, must be HTTPS.
  149. :param state: An optional state string for CSRF protection. If not
  150. given it will be generated for you.
  151. :param kwargs: Extra parameters to include.
  152. :return: authorization_url, state
  153. """
  154. state = state or self.new_state()
  155. if self._pkce:
  156. self._code_verifier = self._client.create_code_verifier(43)
  157. kwargs["code_challenge_method"] = self._pkce
  158. kwargs["code_challenge"] = self._client.create_code_challenge(
  159. code_verifier=self._code_verifier,
  160. code_challenge_method=self._pkce
  161. )
  162. return (
  163. self._client.prepare_request_uri(
  164. url,
  165. redirect_uri=self.redirect_uri,
  166. scope=self.scope,
  167. state=state,
  168. **kwargs
  169. ),
  170. state,
  171. )
  172. def fetch_token(
  173. self,
  174. token_url,
  175. code=None,
  176. authorization_response=None,
  177. body="",
  178. auth=None,
  179. username=None,
  180. password=None,
  181. method="POST",
  182. force_querystring=False,
  183. timeout=None,
  184. headers=None,
  185. verify=None,
  186. proxies=None,
  187. include_client_id=None,
  188. client_secret=None,
  189. cert=None,
  190. **kwargs
  191. ):
  192. """Generic method for fetching an access token from the token endpoint.
  193. If you are using the MobileApplicationClient you will want to use
  194. `token_from_fragment` instead of `fetch_token`.
  195. The current implementation enforces the RFC guidelines.
  196. :param token_url: Token endpoint URL, must use HTTPS.
  197. :param code: Authorization code (used by WebApplicationClients).
  198. :param authorization_response: Authorization response URL, the callback
  199. URL of the request back to you. Used by
  200. WebApplicationClients instead of code.
  201. :param body: Optional application/x-www-form-urlencoded body to add the
  202. include in the token request. Prefer kwargs over body.
  203. :param auth: An auth tuple or method as accepted by `requests`.
  204. :param username: Username required by LegacyApplicationClients to appear
  205. in the request body.
  206. :param password: Password required by LegacyApplicationClients to appear
  207. in the request body.
  208. :param method: The HTTP method used to make the request. Defaults
  209. to POST, but may also be GET. Other methods should
  210. be added as needed.
  211. :param force_querystring: If True, force the request body to be sent
  212. in the querystring instead.
  213. :param timeout: Timeout of the request in seconds.
  214. :param headers: Dict to default request headers with.
  215. :param verify: Verify SSL certificate.
  216. :param proxies: The `proxies` argument is passed onto `requests`.
  217. :param include_client_id: Should the request body include the
  218. `client_id` parameter. Default is `None`,
  219. which will attempt to autodetect. This can be
  220. forced to always include (True) or never
  221. include (False).
  222. :param client_secret: The `client_secret` paired to the `client_id`.
  223. This is generally required unless provided in the
  224. `auth` tuple. If the value is `None`, it will be
  225. omitted from the request, however if the value is
  226. an empty string, an empty string will be sent.
  227. :param cert: Client certificate to send for OAuth 2.0 Mutual-TLS Client
  228. Authentication (draft-ietf-oauth-mtls). Can either be the
  229. path of a file containing the private key and certificate or
  230. a tuple of two filenames for certificate and key.
  231. :param kwargs: Extra parameters to include in the token request.
  232. :return: A token dict
  233. """
  234. if not is_secure_transport(token_url):
  235. raise InsecureTransportError()
  236. if not code and authorization_response:
  237. self._client.parse_request_uri_response(
  238. authorization_response, state=self._state
  239. )
  240. code = self._client.code
  241. elif not code and isinstance(self._client, WebApplicationClient):
  242. code = self._client.code
  243. if not code:
  244. raise ValueError(
  245. "Please supply either code or " "authorization_response parameters."
  246. )
  247. if self._pkce:
  248. if self._code_verifier is None:
  249. raise ValueError(
  250. "Code verifier is not found, authorization URL must be generated before"
  251. )
  252. kwargs["code_verifier"] = self._code_verifier
  253. # Earlier versions of this library build an HTTPBasicAuth header out of
  254. # `username` and `password`. The RFC states, however these attributes
  255. # must be in the request body and not the header.
  256. # If an upstream server is not spec compliant and requires them to
  257. # appear as an Authorization header, supply an explicit `auth` header
  258. # to this function.
  259. # This check will allow for empty strings, but not `None`.
  260. #
  261. # References
  262. # 4.3.2 - Resource Owner Password Credentials Grant
  263. # https://tools.ietf.org/html/rfc6749#section-4.3.2
  264. if isinstance(self._client, LegacyApplicationClient):
  265. if username is None:
  266. raise ValueError(
  267. "`LegacyApplicationClient` requires both the "
  268. "`username` and `password` parameters."
  269. )
  270. if password is None:
  271. raise ValueError(
  272. "The required parameter `username` was supplied, "
  273. "but `password` was not."
  274. )
  275. # merge username and password into kwargs for `prepare_request_body`
  276. if username is not None:
  277. kwargs["username"] = username
  278. if password is not None:
  279. kwargs["password"] = password
  280. # is an auth explicitly supplied?
  281. if auth is not None:
  282. # if we're dealing with the default of `include_client_id` (None):
  283. # we will assume the `auth` argument is for an RFC compliant server
  284. # and we should not send the `client_id` in the body.
  285. # This approach allows us to still force the client_id by submitting
  286. # `include_client_id=True` along with an `auth` object.
  287. if include_client_id is None:
  288. include_client_id = False
  289. # otherwise we may need to create an auth header
  290. else:
  291. # since we don't have an auth header, we MAY need to create one
  292. # it is possible that we want to send the `client_id` in the body
  293. # if so, `include_client_id` should be set to True
  294. # otherwise, we will generate an auth header
  295. if include_client_id is not True:
  296. client_id = self.client_id
  297. if client_id:
  298. log.debug(
  299. 'Encoding `client_id` "%s" with `client_secret` '
  300. "as Basic auth credentials.",
  301. client_id,
  302. )
  303. client_secret = client_secret if client_secret is not None else ""
  304. auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
  305. if include_client_id:
  306. # this was pulled out of the params
  307. # it needs to be passed into prepare_request_body
  308. if client_secret is not None:
  309. kwargs["client_secret"] = client_secret
  310. body = self._client.prepare_request_body(
  311. code=code,
  312. body=body,
  313. redirect_uri=self.redirect_uri,
  314. include_client_id=include_client_id,
  315. **kwargs
  316. )
  317. headers = headers or {
  318. "Accept": "application/json",
  319. "Content-Type": "application/x-www-form-urlencoded",
  320. }
  321. self.token = {}
  322. request_kwargs = {}
  323. if method.upper() == "POST":
  324. request_kwargs["params" if force_querystring else "data"] = dict(
  325. urldecode(body)
  326. )
  327. elif method.upper() == "GET":
  328. request_kwargs["params"] = dict(urldecode(body))
  329. else:
  330. raise ValueError("The method kwarg must be POST or GET.")
  331. for hook in self.compliance_hook["access_token_request"]:
  332. log.debug("Invoking access_token_request hook %s.", hook)
  333. token_url, headers, request_kwargs = hook(
  334. token_url, headers, request_kwargs
  335. )
  336. r = self.request(
  337. method=method,
  338. url=token_url,
  339. timeout=timeout,
  340. headers=headers,
  341. auth=auth,
  342. verify=verify,
  343. proxies=proxies,
  344. cert=cert,
  345. **request_kwargs
  346. )
  347. log.debug("Request to fetch token completed with status %s.", r.status_code)
  348. log.debug("Request url was %s", r.request.url)
  349. log.debug("Request headers were %s", r.request.headers)
  350. log.debug("Request body was %s", r.request.body)
  351. log.debug("Response headers were %s and content %s.", r.headers, r.text)
  352. log.debug(
  353. "Invoking %d token response hooks.",
  354. len(self.compliance_hook["access_token_response"]),
  355. )
  356. for hook in self.compliance_hook["access_token_response"]:
  357. log.debug("Invoking hook %s.", hook)
  358. r = hook(r)
  359. self._client.parse_request_body_response(r.text, scope=self.scope)
  360. self.token = self._client.token
  361. log.debug("Obtained token %s.", self.token)
  362. return self.token
  363. def token_from_fragment(self, authorization_response):
  364. """Parse token from the URI fragment, used by MobileApplicationClients.
  365. :param authorization_response: The full URL of the redirect back to you
  366. :return: A token dict
  367. """
  368. self._client.parse_request_uri_response(
  369. authorization_response, state=self._state
  370. )
  371. self.token = self._client.token
  372. return self.token
  373. def refresh_token(
  374. self,
  375. token_url,
  376. refresh_token=None,
  377. body="",
  378. auth=None,
  379. timeout=None,
  380. headers=None,
  381. verify=None,
  382. proxies=None,
  383. **kwargs
  384. ):
  385. """Fetch a new access token using a refresh token.
  386. :param token_url: The token endpoint, must be HTTPS.
  387. :param refresh_token: The refresh_token to use.
  388. :param body: Optional application/x-www-form-urlencoded body to add the
  389. include in the token request. Prefer kwargs over body.
  390. :param auth: An auth tuple or method as accepted by `requests`.
  391. :param timeout: Timeout of the request in seconds.
  392. :param headers: A dict of headers to be used by `requests`.
  393. :param verify: Verify SSL certificate.
  394. :param proxies: The `proxies` argument will be passed to `requests`.
  395. :param kwargs: Extra parameters to include in the token request.
  396. :return: A token dict
  397. """
  398. if not token_url:
  399. raise ValueError("No token endpoint set for auto_refresh.")
  400. if not is_secure_transport(token_url):
  401. raise InsecureTransportError()
  402. refresh_token = refresh_token or self.token.get("refresh_token")
  403. log.debug(
  404. "Adding auto refresh key word arguments %s.", self.auto_refresh_kwargs
  405. )
  406. kwargs.update(self.auto_refresh_kwargs)
  407. body = self._client.prepare_refresh_body(
  408. body=body, refresh_token=refresh_token, scope=self.scope, **kwargs
  409. )
  410. log.debug("Prepared refresh token request body %s", body)
  411. if headers is None:
  412. headers = {
  413. "Accept": "application/json",
  414. "Content-Type": ("application/x-www-form-urlencoded"),
  415. }
  416. for hook in self.compliance_hook["refresh_token_request"]:
  417. log.debug("Invoking refresh_token_request hook %s.", hook)
  418. token_url, headers, body = hook(token_url, headers, body)
  419. r = self.post(
  420. token_url,
  421. data=dict(urldecode(body)),
  422. auth=auth,
  423. timeout=timeout,
  424. headers=headers,
  425. verify=verify,
  426. withhold_token=True,
  427. proxies=proxies,
  428. )
  429. log.debug("Request to refresh token completed with status %s.", r.status_code)
  430. log.debug("Response headers were %s and content %s.", r.headers, r.text)
  431. log.debug(
  432. "Invoking %d token response hooks.",
  433. len(self.compliance_hook["refresh_token_response"]),
  434. )
  435. for hook in self.compliance_hook["refresh_token_response"]:
  436. log.debug("Invoking hook %s.", hook)
  437. r = hook(r)
  438. self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
  439. if "refresh_token" not in self.token:
  440. log.debug("No new refresh token given. Re-using old.")
  441. self.token["refresh_token"] = refresh_token
  442. return self.token
  443. def request(
  444. self,
  445. method,
  446. url,
  447. data=None,
  448. headers=None,
  449. withhold_token=False,
  450. client_id=None,
  451. client_secret=None,
  452. files=None,
  453. **kwargs
  454. ):
  455. """Intercept all requests and add the OAuth 2 token if present."""
  456. if not is_secure_transport(url):
  457. raise InsecureTransportError()
  458. if self.token and not withhold_token:
  459. log.debug(
  460. "Invoking %d protected resource request hooks.",
  461. len(self.compliance_hook["protected_request"]),
  462. )
  463. for hook in self.compliance_hook["protected_request"]:
  464. log.debug("Invoking hook %s.", hook)
  465. url, headers, data = hook(url, headers, data)
  466. log.debug("Adding token %s to request.", self.token)
  467. try:
  468. url, headers, data = self._client.add_token(
  469. url, http_method=method, body=data, headers=headers
  470. )
  471. # Attempt to retrieve and save new access token if expired
  472. except TokenExpiredError:
  473. if self.auto_refresh_url:
  474. log.debug(
  475. "Auto refresh is set, attempting to refresh at %s.",
  476. self.auto_refresh_url,
  477. )
  478. # We mustn't pass auth twice.
  479. auth = kwargs.pop("auth", None)
  480. if client_id and client_secret and (auth is None):
  481. log.debug(
  482. 'Encoding client_id "%s" with client_secret as Basic auth credentials.',
  483. client_id,
  484. )
  485. auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
  486. token = self.refresh_token(
  487. self.auto_refresh_url, auth=auth, **kwargs
  488. )
  489. if self.token_updater:
  490. log.debug(
  491. "Updating token to %s using %s.", token, self.token_updater
  492. )
  493. self.token_updater(token)
  494. url, headers, data = self._client.add_token(
  495. url, http_method=method, body=data, headers=headers
  496. )
  497. else:
  498. raise TokenUpdated(token)
  499. else:
  500. raise
  501. log.debug("Requesting url %s using method %s.", url, method)
  502. log.debug("Supplying headers %s and data %s", headers, data)
  503. log.debug("Passing through key word arguments %s.", kwargs)
  504. return super(OAuth2Session, self).request(
  505. method, url, headers=headers, data=data, files=files, **kwargs
  506. )
  507. def register_compliance_hook(self, hook_type, hook):
  508. """Register a hook for request/response tweaking.
  509. Available hooks are:
  510. access_token_response invoked before token parsing.
  511. refresh_token_response invoked before refresh token parsing.
  512. protected_request invoked before making a request.
  513. access_token_request invoked before making a token fetch request.
  514. refresh_token_request invoked before making a refresh request.
  515. If you find a new hook is needed please send a GitHub PR request
  516. or open an issue.
  517. """
  518. if hook_type not in self.compliance_hook:
  519. raise ValueError(
  520. "Hook type %s is not in %s.", hook_type, self.compliance_hook
  521. )
  522. self.compliance_hook[hook_type].add(hook)