test_oauth2_session.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. import json
  2. import time
  3. import tempfile
  4. import shutil
  5. import os
  6. from base64 import b64encode
  7. from copy import deepcopy
  8. from unittest import TestCase
  9. from unittest import mock
  10. from oauthlib.common import urlencode
  11. from oauthlib.oauth2 import TokenExpiredError, OAuth2Error
  12. from oauthlib.oauth2 import MismatchingStateError
  13. from oauthlib.oauth2 import WebApplicationClient, MobileApplicationClient
  14. from oauthlib.oauth2 import LegacyApplicationClient, BackendApplicationClient
  15. from requests_oauthlib import OAuth2Session, TokenUpdated
  16. import requests
  17. from requests.auth import _basic_auth_str
  18. fake_time = time.time()
  19. CODE = "asdf345xdf"
  20. def fake_token(token):
  21. def fake_send(r, **kwargs):
  22. resp = mock.MagicMock()
  23. resp.text = json.dumps(token)
  24. return resp
  25. return fake_send
  26. class OAuth2SessionTest(TestCase):
  27. def setUp(self):
  28. self.token = {
  29. "token_type": "Bearer",
  30. "access_token": "asdfoiw37850234lkjsdfsdf",
  31. "refresh_token": "sldvafkjw34509s8dfsdf",
  32. "expires_in": 3600,
  33. "expires_at": fake_time + 3600,
  34. }
  35. # use someclientid:someclientsecret to easily differentiate between client and user credentials
  36. # these are the values used in oauthlib tests
  37. self.client_id = "someclientid"
  38. self.client_secret = "someclientsecret"
  39. self.user_username = "user_username"
  40. self.user_password = "user_password"
  41. self.client_WebApplication = WebApplicationClient(self.client_id, code=CODE)
  42. self.client_LegacyApplication = LegacyApplicationClient(self.client_id)
  43. self.client_BackendApplication = BackendApplicationClient(self.client_id)
  44. self.client_MobileApplication = MobileApplicationClient(self.client_id)
  45. self.clients = [
  46. self.client_WebApplication,
  47. self.client_LegacyApplication,
  48. self.client_BackendApplication,
  49. ]
  50. self.all_clients = self.clients + [self.client_MobileApplication]
  51. def test_add_token(self):
  52. token = "Bearer " + self.token["access_token"]
  53. def verifier(r, **kwargs):
  54. auth_header = r.headers.get(str("Authorization"), None)
  55. self.assertEqual(auth_header, token)
  56. resp = mock.MagicMock()
  57. resp.cookes = []
  58. return resp
  59. for client in self.all_clients:
  60. sess = OAuth2Session(client=client, token=self.token)
  61. sess.send = verifier
  62. sess.get("https://i.b")
  63. def test_mtls(self):
  64. cert = (
  65. "testsomething.example-client.pem",
  66. "testsomething.example-client-key.pem",
  67. )
  68. def verifier(r, **kwargs):
  69. self.assertIn("cert", kwargs)
  70. self.assertEqual(cert, kwargs["cert"])
  71. self.assertIn("client_id=" + self.client_id, r.body)
  72. resp = mock.MagicMock()
  73. resp.text = json.dumps(self.token)
  74. return resp
  75. for client in self.clients:
  76. sess = OAuth2Session(client=client)
  77. sess.send = verifier
  78. if isinstance(client, LegacyApplicationClient):
  79. sess.fetch_token(
  80. "https://i.b",
  81. include_client_id=True,
  82. cert=cert,
  83. username="username1",
  84. password="password1",
  85. )
  86. else:
  87. sess.fetch_token("https://i.b", include_client_id=True, cert=cert)
  88. def test_authorization_url(self):
  89. url = "https://example.com/authorize?foo=bar"
  90. web = WebApplicationClient(self.client_id)
  91. s = OAuth2Session(client=web)
  92. auth_url, state = s.authorization_url(url)
  93. self.assertIn(state, auth_url)
  94. self.assertIn(self.client_id, auth_url)
  95. self.assertIn("response_type=code", auth_url)
  96. mobile = MobileApplicationClient(self.client_id)
  97. s = OAuth2Session(client=mobile)
  98. auth_url, state = s.authorization_url(url)
  99. self.assertIn(state, auth_url)
  100. self.assertIn(self.client_id, auth_url)
  101. self.assertIn("response_type=token", auth_url)
  102. def test_pkce_authorization_url(self):
  103. url = "https://example.com/authorize?foo=bar"
  104. web = WebApplicationClient(self.client_id)
  105. s = OAuth2Session(client=web, pkce="S256")
  106. auth_url, state = s.authorization_url(url)
  107. self.assertIn(state, auth_url)
  108. self.assertIn(self.client_id, auth_url)
  109. self.assertIn("response_type=code", auth_url)
  110. self.assertIn("code_challenge=", auth_url)
  111. self.assertIn("code_challenge_method=S256", auth_url)
  112. mobile = MobileApplicationClient(self.client_id)
  113. s = OAuth2Session(client=mobile, pkce="S256")
  114. auth_url, state = s.authorization_url(url)
  115. self.assertIn(state, auth_url)
  116. self.assertIn(self.client_id, auth_url)
  117. self.assertIn("response_type=token", auth_url)
  118. self.assertIn("code_challenge=", auth_url)
  119. self.assertIn("code_challenge_method=S256", auth_url)
  120. @mock.patch("time.time", new=lambda: fake_time)
  121. def test_refresh_token_request(self):
  122. self.expired_token = dict(self.token)
  123. self.expired_token["expires_in"] = "-1"
  124. del self.expired_token["expires_at"]
  125. def fake_refresh(r, **kwargs):
  126. if "/refresh" in r.url:
  127. self.assertNotIn("Authorization", r.headers)
  128. resp = mock.MagicMock()
  129. resp.text = json.dumps(self.token)
  130. return resp
  131. # No auto refresh setup
  132. for client in self.clients:
  133. sess = OAuth2Session(client=client, token=self.expired_token)
  134. self.assertRaises(TokenExpiredError, sess.get, "https://i.b")
  135. # Auto refresh but no auto update
  136. for client in self.clients:
  137. sess = OAuth2Session(
  138. client=client,
  139. token=self.expired_token,
  140. auto_refresh_url="https://i.b/refresh",
  141. )
  142. sess.send = fake_refresh
  143. self.assertRaises(TokenUpdated, sess.get, "https://i.b")
  144. # Auto refresh and auto update
  145. def token_updater(token):
  146. self.assertEqual(token, self.token)
  147. for client in self.clients:
  148. sess = OAuth2Session(
  149. client=client,
  150. token=self.expired_token,
  151. auto_refresh_url="https://i.b/refresh",
  152. token_updater=token_updater,
  153. )
  154. sess.send = fake_refresh
  155. sess.get("https://i.b")
  156. def fake_refresh_with_auth(r, **kwargs):
  157. if "/refresh" in r.url:
  158. self.assertIn("Authorization", r.headers)
  159. encoded = b64encode(
  160. "{client_id}:{client_secret}".format(
  161. client_id=self.client_id, client_secret=self.client_secret
  162. ).encode("latin1")
  163. )
  164. content = "Basic {encoded}".format(encoded=encoded.decode("latin1"))
  165. self.assertEqual(r.headers["Authorization"], content)
  166. resp = mock.MagicMock()
  167. resp.text = json.dumps(self.token)
  168. return resp
  169. for client in self.clients:
  170. sess = OAuth2Session(
  171. client=client,
  172. token=self.expired_token,
  173. auto_refresh_url="https://i.b/refresh",
  174. token_updater=token_updater,
  175. )
  176. sess.send = fake_refresh_with_auth
  177. sess.get(
  178. "https://i.b",
  179. client_id=self.client_id,
  180. client_secret=self.client_secret,
  181. )
  182. @mock.patch("time.time", new=lambda: fake_time)
  183. def test_token_from_fragment(self):
  184. mobile = MobileApplicationClient(self.client_id)
  185. response_url = "https://i.b/callback#" + urlencode(self.token.items())
  186. sess = OAuth2Session(client=mobile)
  187. self.assertEqual(sess.token_from_fragment(response_url), self.token)
  188. @mock.patch("time.time", new=lambda: fake_time)
  189. def test_fetch_token(self):
  190. url = "https://example.com/token"
  191. for client in self.clients:
  192. sess = OAuth2Session(client=client, token=self.token)
  193. sess.send = fake_token(self.token)
  194. if isinstance(client, LegacyApplicationClient):
  195. # this client requires a username+password
  196. # if unset, an error will be raised
  197. self.assertRaises(ValueError, sess.fetch_token, url)
  198. self.assertRaises(
  199. ValueError, sess.fetch_token, url, username="username1"
  200. )
  201. self.assertRaises(
  202. ValueError, sess.fetch_token, url, password="password1"
  203. )
  204. # otherwise it will pass
  205. self.assertEqual(
  206. sess.fetch_token(url, username="username1", password="password1"),
  207. self.token,
  208. )
  209. else:
  210. self.assertEqual(sess.fetch_token(url), self.token)
  211. error = {"error": "invalid_request"}
  212. for client in self.clients:
  213. sess = OAuth2Session(client=client, token=self.token)
  214. sess.send = fake_token(error)
  215. if isinstance(client, LegacyApplicationClient):
  216. # this client requires a username+password
  217. # if unset, an error will be raised
  218. self.assertRaises(ValueError, sess.fetch_token, url)
  219. self.assertRaises(
  220. ValueError, sess.fetch_token, url, username="username1"
  221. )
  222. self.assertRaises(
  223. ValueError, sess.fetch_token, url, password="password1"
  224. )
  225. # otherwise it will pass
  226. self.assertRaises(
  227. OAuth2Error,
  228. sess.fetch_token,
  229. url,
  230. username="username1",
  231. password="password1",
  232. )
  233. else:
  234. self.assertRaises(OAuth2Error, sess.fetch_token, url)
  235. # there are different scenarios in which the `client_id` can be specified
  236. # reference `oauthlib.tests.oauth2.rfc6749.clients.test_web_application.WebApplicationClientTest.test_prepare_request_body`
  237. # this only needs to test WebApplicationClient
  238. client = self.client_WebApplication
  239. client.tester = True
  240. # this should be a tuple of (r.url, r.body, r.headers.get('Authorization'))
  241. _fetch_history = []
  242. def fake_token_history(token):
  243. def fake_send(r, **kwargs):
  244. resp = mock.MagicMock()
  245. resp.text = json.dumps(token)
  246. _fetch_history.append(
  247. (r.url, r.body, r.headers.get("Authorization", None))
  248. )
  249. return resp
  250. return fake_send
  251. sess = OAuth2Session(client=client, token=self.token)
  252. sess.send = fake_token_history(self.token)
  253. expected_auth_header = _basic_auth_str(self.client_id, self.client_secret)
  254. # scenario 1 - default request
  255. # this should send the `client_id` in the headers, as that is recommended by the RFC
  256. self.assertEqual(
  257. sess.fetch_token(url, client_secret="someclientsecret"), self.token
  258. )
  259. self.assertEqual(len(_fetch_history), 1)
  260. self.assertNotIn(
  261. "client_id", _fetch_history[0][1]
  262. ) # no `client_id` in the body
  263. self.assertNotIn(
  264. "client_secret", _fetch_history[0][1]
  265. ) # no `client_secret` in the body
  266. self.assertEqual(
  267. _fetch_history[0][2], expected_auth_header
  268. ) # ensure a Basic Authorization header
  269. # scenario 2 - force the `client_id` into the body
  270. self.assertEqual(
  271. sess.fetch_token(
  272. url, client_secret="someclientsecret", include_client_id=True
  273. ),
  274. self.token,
  275. )
  276. self.assertEqual(len(_fetch_history), 2)
  277. self.assertIn("client_id=%s" % self.client_id, _fetch_history[1][1])
  278. self.assertIn("client_secret=%s" % self.client_secret, _fetch_history[1][1])
  279. self.assertEqual(
  280. _fetch_history[1][2], None
  281. ) # ensure NO Basic Authorization header
  282. # scenario 3 - send in an auth object
  283. auth = requests.auth.HTTPBasicAuth(self.client_id, self.client_secret)
  284. self.assertEqual(sess.fetch_token(url, auth=auth), self.token)
  285. self.assertEqual(len(_fetch_history), 3)
  286. self.assertNotIn(
  287. "client_id", _fetch_history[2][1]
  288. ) # no `client_id` in the body
  289. self.assertNotIn(
  290. "client_secret", _fetch_history[2][1]
  291. ) # no `client_secret` in the body
  292. self.assertEqual(
  293. _fetch_history[2][2], expected_auth_header
  294. ) # ensure a Basic Authorization header
  295. # scenario 4 - send in a username/password combo
  296. # this should send the `client_id` in the headers, like scenario 1
  297. self.assertEqual(
  298. sess.fetch_token(
  299. url, username=self.user_username, password=self.user_password
  300. ),
  301. self.token,
  302. )
  303. self.assertEqual(len(_fetch_history), 4)
  304. self.assertNotIn(
  305. "client_id", _fetch_history[3][1]
  306. ) # no `client_id` in the body
  307. self.assertNotIn(
  308. "client_secret", _fetch_history[3][1]
  309. ) # no `client_secret` in the body
  310. self.assertEqual(
  311. _fetch_history[0][2], expected_auth_header
  312. ) # ensure a Basic Authorization header
  313. self.assertIn("username=%s" % self.user_username, _fetch_history[3][1])
  314. self.assertIn("password=%s" % self.user_password, _fetch_history[3][1])
  315. # scenario 5 - send data in `params` and not in `data` for providers
  316. # that expect data in URL
  317. self.assertEqual(
  318. sess.fetch_token(url, client_secret="somesecret", force_querystring=True),
  319. self.token,
  320. )
  321. self.assertIn("code=%s" % CODE, _fetch_history[4][0])
  322. # some quick tests for valid ways of supporting `client_secret`
  323. # scenario 2b - force the `client_id` into the body; but the `client_secret` is `None`
  324. self.assertEqual(
  325. sess.fetch_token(url, client_secret=None, include_client_id=True),
  326. self.token,
  327. )
  328. self.assertEqual(len(_fetch_history), 6)
  329. self.assertIn("client_id=%s" % self.client_id, _fetch_history[5][1])
  330. self.assertNotIn(
  331. "client_secret=", _fetch_history[5][1]
  332. ) # no `client_secret` in the body
  333. self.assertEqual(
  334. _fetch_history[5][2], None
  335. ) # ensure NO Basic Authorization header
  336. # scenario 2c - force the `client_id` into the body; but the `client_secret` is an empty string
  337. self.assertEqual(
  338. sess.fetch_token(url, client_secret="", include_client_id=True), self.token
  339. )
  340. self.assertEqual(len(_fetch_history), 7)
  341. self.assertIn("client_id=%s" % self.client_id, _fetch_history[6][1])
  342. self.assertIn("client_secret=", _fetch_history[6][1])
  343. self.assertEqual(
  344. _fetch_history[6][2], None
  345. ) # ensure NO Basic Authorization header
  346. def test_cleans_previous_token_before_fetching_new_one(self):
  347. """Makes sure the previous token is cleaned before fetching a new one.
  348. The reason behind it is that, if the previous token is expired, this
  349. method shouldn't fail with a TokenExpiredError, since it's attempting
  350. to get a new one (which shouldn't be expired).
  351. """
  352. new_token = deepcopy(self.token)
  353. past = time.time() - 7200
  354. now = time.time()
  355. self.token["expires_at"] = past
  356. new_token["expires_at"] = now + 3600
  357. url = "https://example.com/token"
  358. with mock.patch("time.time", lambda: now):
  359. for client in self.clients:
  360. sess = OAuth2Session(client=client, token=self.token)
  361. sess.send = fake_token(new_token)
  362. if isinstance(client, LegacyApplicationClient):
  363. # this client requires a username+password
  364. # if unset, an error will be raised
  365. self.assertRaises(ValueError, sess.fetch_token, url)
  366. self.assertRaises(
  367. ValueError, sess.fetch_token, url, username="username1"
  368. )
  369. self.assertRaises(
  370. ValueError, sess.fetch_token, url, password="password1"
  371. )
  372. # otherwise it will pass
  373. self.assertEqual(
  374. sess.fetch_token(
  375. url, username="username1", password="password1"
  376. ),
  377. new_token,
  378. )
  379. else:
  380. self.assertEqual(sess.fetch_token(url), new_token)
  381. def test_web_app_fetch_token(self):
  382. # Ensure the state parameter is used, see issue #105.
  383. client = OAuth2Session("someclientid", state="somestate")
  384. self.assertRaises(
  385. MismatchingStateError,
  386. client.fetch_token,
  387. "https://i.b/token",
  388. authorization_response="https://i.b/no-state?code=abc",
  389. )
  390. @mock.patch("time.time", new=lambda: fake_time)
  391. def test_pkce_web_app_fetch_token(self):
  392. url = "https://example.com/token"
  393. web = WebApplicationClient(self.client_id, code=CODE)
  394. sess = OAuth2Session(client=web, token=self.token, pkce="S256")
  395. sess.send = fake_token(self.token)
  396. sess._code_verifier = "foobar"
  397. self.assertEqual(sess.fetch_token(url), self.token)
  398. def test_client_id_proxy(self):
  399. sess = OAuth2Session("test-id")
  400. self.assertEqual(sess.client_id, "test-id")
  401. sess.client_id = "different-id"
  402. self.assertEqual(sess.client_id, "different-id")
  403. sess._client.client_id = "something-else"
  404. self.assertEqual(sess.client_id, "something-else")
  405. del sess.client_id
  406. self.assertIsNone(sess.client_id)
  407. def test_access_token_proxy(self):
  408. sess = OAuth2Session("test-id")
  409. self.assertIsNone(sess.access_token)
  410. sess.access_token = "test-token"
  411. self.assertEqual(sess.access_token, "test-token")
  412. sess._client.access_token = "different-token"
  413. self.assertEqual(sess.access_token, "different-token")
  414. del sess.access_token
  415. self.assertIsNone(sess.access_token)
  416. def test_token_proxy(self):
  417. token = {"access_token": "test-access"}
  418. sess = OAuth2Session("test-id", token=token)
  419. self.assertEqual(sess.access_token, "test-access")
  420. self.assertEqual(sess.token, token)
  421. token["access_token"] = "something-else"
  422. sess.token = token
  423. self.assertEqual(sess.access_token, "something-else")
  424. self.assertEqual(sess.token, token)
  425. sess._client.access_token = "different-token"
  426. token["access_token"] = "different-token"
  427. self.assertEqual(sess.access_token, "different-token")
  428. self.assertEqual(sess.token, token)
  429. # can't delete token attribute
  430. with self.assertRaises(AttributeError):
  431. del sess.token
  432. def test_authorized_false(self):
  433. sess = OAuth2Session("someclientid")
  434. self.assertFalse(sess.authorized)
  435. @mock.patch("time.time", new=lambda: fake_time)
  436. def test_authorized_true(self):
  437. def fake_token(token):
  438. def fake_send(r, **kwargs):
  439. resp = mock.MagicMock()
  440. resp.text = json.dumps(token)
  441. return resp
  442. return fake_send
  443. url = "https://example.com/token"
  444. for client in self.clients:
  445. sess = OAuth2Session(client=client)
  446. sess.send = fake_token(self.token)
  447. self.assertFalse(sess.authorized)
  448. if isinstance(client, LegacyApplicationClient):
  449. # this client requires a username+password
  450. # if unset, an error will be raised
  451. self.assertRaises(ValueError, sess.fetch_token, url)
  452. self.assertRaises(
  453. ValueError, sess.fetch_token, url, username="username1"
  454. )
  455. self.assertRaises(
  456. ValueError, sess.fetch_token, url, password="password1"
  457. )
  458. # otherwise it will pass
  459. sess.fetch_token(url, username="username1", password="password1")
  460. else:
  461. sess.fetch_token(url)
  462. self.assertTrue(sess.authorized)
  463. class OAuth2SessionNetrcTest(OAuth2SessionTest):
  464. """Ensure that there is no magic auth handling.
  465. By default, requests sessions have magic handling of netrc files,
  466. which is undesirable for this library because it will take
  467. precedence over manually set authentication headers.
  468. """
  469. def setUp(self):
  470. # Set up a temporary home directory
  471. self.homedir = tempfile.mkdtemp()
  472. self.prehome = os.environ.get("HOME", None)
  473. os.environ["HOME"] = self.homedir
  474. # Write a .netrc file that will cause problems
  475. netrc_loc = os.path.expanduser("~/.netrc")
  476. with open(netrc_loc, "w") as f:
  477. f.write("machine i.b\n" " password abc123\n" " login spam@eggs.co\n")
  478. super(OAuth2SessionNetrcTest, self).setUp()
  479. def tearDown(self):
  480. super(OAuth2SessionNetrcTest, self).tearDown()
  481. if self.prehome is not None:
  482. os.environ["HOME"] = self.prehome
  483. shutil.rmtree(self.homedir)