test_base.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # -*- coding: utf-8 -*-
  2. import datetime
  3. from oauthlib import common
  4. from oauthlib.oauth2 import Client, InsecureTransportError, TokenExpiredError
  5. from oauthlib.oauth2.rfc6749 import utils
  6. from oauthlib.oauth2.rfc6749.clients import AUTH_HEADER, BODY, URI_QUERY
  7. from tests.unittest import TestCase
  8. class ClientTest(TestCase):
  9. client_id = "someclientid"
  10. uri = "https://example.com/path?query=world"
  11. body = "not=empty"
  12. headers = {}
  13. access_token = "token"
  14. mac_key = "secret"
  15. bearer_query = uri + "&access_token=" + access_token
  16. bearer_header = {
  17. "Authorization": "Bearer " + access_token
  18. }
  19. bearer_body = body + "&access_token=" + access_token
  20. mac_00_header = {
  21. "Authorization": 'MAC id="' + access_token + '", nonce="0:abc123",' +
  22. ' bodyhash="Yqyso8r3hR5Nm1ZFv+6AvNHrxjE=",' +
  23. ' mac="0X6aACoBY0G6xgGZVJ1IeE8dF9k="'
  24. }
  25. mac_01_header = {
  26. "Authorization": 'MAC id="' + access_token + '", ts="123456789",' +
  27. ' nonce="abc123", mac="Xuk+9oqaaKyhitkgh1CD0xrI6+s="'
  28. }
  29. def test_add_bearer_token(self):
  30. """Test a number of bearer token placements"""
  31. # Invalid token type
  32. client = Client(self.client_id, token_type="invalid")
  33. self.assertRaises(ValueError, client.add_token, self.uri)
  34. # Case-insensitive token type
  35. client = Client(self.client_id, access_token=self.access_token, token_type="bEAreR")
  36. uri, headers, body = client.add_token(self.uri, body=self.body,
  37. headers=self.headers)
  38. self.assertURLEqual(uri, self.uri)
  39. self.assertFormBodyEqual(body, self.body)
  40. self.assertEqual(headers, self.bearer_header)
  41. # Non-HTTPS
  42. insecure_uri = 'http://example.com/path?query=world'
  43. client = Client(self.client_id, access_token=self.access_token, token_type="Bearer")
  44. self.assertRaises(InsecureTransportError, client.add_token, insecure_uri,
  45. body=self.body,
  46. headers=self.headers)
  47. # Missing access token
  48. client = Client(self.client_id)
  49. self.assertRaises(ValueError, client.add_token, self.uri)
  50. # Expired token
  51. expired = 523549800
  52. expired_token = {
  53. 'expires_at': expired,
  54. }
  55. client = Client(self.client_id, token=expired_token, access_token=self.access_token, token_type="Bearer")
  56. self.assertRaises(TokenExpiredError, client.add_token, self.uri,
  57. body=self.body, headers=self.headers)
  58. # The default token placement, bearer in auth header
  59. client = Client(self.client_id, access_token=self.access_token)
  60. uri, headers, body = client.add_token(self.uri, body=self.body,
  61. headers=self.headers)
  62. self.assertURLEqual(uri, self.uri)
  63. self.assertFormBodyEqual(body, self.body)
  64. self.assertEqual(headers, self.bearer_header)
  65. # Setting default placements of tokens
  66. client = Client(self.client_id, access_token=self.access_token,
  67. default_token_placement=AUTH_HEADER)
  68. uri, headers, body = client.add_token(self.uri, body=self.body,
  69. headers=self.headers)
  70. self.assertURLEqual(uri, self.uri)
  71. self.assertFormBodyEqual(body, self.body)
  72. self.assertEqual(headers, self.bearer_header)
  73. client = Client(self.client_id, access_token=self.access_token,
  74. default_token_placement=URI_QUERY)
  75. uri, headers, body = client.add_token(self.uri, body=self.body,
  76. headers=self.headers)
  77. self.assertURLEqual(uri, self.bearer_query)
  78. self.assertFormBodyEqual(body, self.body)
  79. self.assertEqual(headers, self.headers)
  80. client = Client(self.client_id, access_token=self.access_token,
  81. default_token_placement=BODY)
  82. uri, headers, body = client.add_token(self.uri, body=self.body,
  83. headers=self.headers)
  84. self.assertURLEqual(uri, self.uri)
  85. self.assertFormBodyEqual(body, self.bearer_body)
  86. self.assertEqual(headers, self.headers)
  87. # Asking for specific placement in the add_token method
  88. client = Client(self.client_id, access_token=self.access_token)
  89. uri, headers, body = client.add_token(self.uri, body=self.body,
  90. headers=self.headers, token_placement=AUTH_HEADER)
  91. self.assertURLEqual(uri, self.uri)
  92. self.assertFormBodyEqual(body, self.body)
  93. self.assertEqual(headers, self.bearer_header)
  94. client = Client(self.client_id, access_token=self.access_token)
  95. uri, headers, body = client.add_token(self.uri, body=self.body,
  96. headers=self.headers, token_placement=URI_QUERY)
  97. self.assertURLEqual(uri, self.bearer_query)
  98. self.assertFormBodyEqual(body, self.body)
  99. self.assertEqual(headers, self.headers)
  100. client = Client(self.client_id, access_token=self.access_token)
  101. uri, headers, body = client.add_token(self.uri, body=self.body,
  102. headers=self.headers, token_placement=BODY)
  103. self.assertURLEqual(uri, self.uri)
  104. self.assertFormBodyEqual(body, self.bearer_body)
  105. self.assertEqual(headers, self.headers)
  106. # Invalid token placement
  107. client = Client(self.client_id, access_token=self.access_token)
  108. self.assertRaises(ValueError, client.add_token, self.uri, body=self.body,
  109. headers=self.headers, token_placement="invalid")
  110. client = Client(self.client_id, access_token=self.access_token,
  111. default_token_placement="invalid")
  112. self.assertRaises(ValueError, client.add_token, self.uri, body=self.body,
  113. headers=self.headers)
  114. def test_add_mac_token(self):
  115. # Missing access token
  116. client = Client(self.client_id, token_type="MAC")
  117. self.assertRaises(ValueError, client.add_token, self.uri)
  118. # Invalid hash algorithm
  119. client = Client(self.client_id, token_type="MAC",
  120. access_token=self.access_token, mac_key=self.mac_key,
  121. mac_algorithm="hmac-sha-2")
  122. self.assertRaises(ValueError, client.add_token, self.uri)
  123. orig_generate_timestamp = common.generate_timestamp
  124. orig_generate_nonce = common.generate_nonce
  125. orig_generate_age = utils.generate_age
  126. self.addCleanup(setattr, common, 'generage_timestamp', orig_generate_timestamp)
  127. self.addCleanup(setattr, common, 'generage_nonce', orig_generate_nonce)
  128. self.addCleanup(setattr, utils, 'generate_age', orig_generate_age)
  129. common.generate_timestamp = lambda: '123456789'
  130. common.generate_nonce = lambda: 'abc123'
  131. utils.generate_age = lambda *args: 0
  132. # Add the Authorization header (draft 00)
  133. client = Client(self.client_id, token_type="MAC",
  134. access_token=self.access_token, mac_key=self.mac_key,
  135. mac_algorithm="hmac-sha-1")
  136. uri, headers, body = client.add_token(self.uri, body=self.body,
  137. headers=self.headers, issue_time=datetime.datetime.now())
  138. self.assertEqual(uri, self.uri)
  139. self.assertEqual(body, self.body)
  140. self.assertEqual(headers, self.mac_00_header)
  141. # Non-HTTPS
  142. insecure_uri = 'http://example.com/path?query=world'
  143. self.assertRaises(InsecureTransportError, client.add_token, insecure_uri,
  144. body=self.body,
  145. headers=self.headers,
  146. issue_time=datetime.datetime.now())
  147. # Expired Token
  148. expired = 523549800
  149. expired_token = {
  150. 'expires_at': expired,
  151. }
  152. client = Client(self.client_id, token=expired_token, token_type="MAC",
  153. access_token=self.access_token, mac_key=self.mac_key,
  154. mac_algorithm="hmac-sha-1")
  155. self.assertRaises(TokenExpiredError, client.add_token, self.uri,
  156. body=self.body,
  157. headers=self.headers,
  158. issue_time=datetime.datetime.now())
  159. # Add the Authorization header (draft 01)
  160. client = Client(self.client_id, token_type="MAC",
  161. access_token=self.access_token, mac_key=self.mac_key,
  162. mac_algorithm="hmac-sha-1")
  163. uri, headers, body = client.add_token(self.uri, body=self.body,
  164. headers=self.headers, draft=1)
  165. self.assertEqual(uri, self.uri)
  166. self.assertEqual(body, self.body)
  167. self.assertEqual(headers, self.mac_01_header)
  168. # Non-HTTPS
  169. insecure_uri = 'http://example.com/path?query=world'
  170. self.assertRaises(InsecureTransportError, client.add_token, insecure_uri,
  171. body=self.body,
  172. headers=self.headers,
  173. draft=1)
  174. # Expired Token
  175. expired = 523549800
  176. expired_token = {
  177. 'expires_at': expired,
  178. }
  179. client = Client(self.client_id, token=expired_token, token_type="MAC",
  180. access_token=self.access_token, mac_key=self.mac_key,
  181. mac_algorithm="hmac-sha-1")
  182. self.assertRaises(TokenExpiredError, client.add_token, self.uri,
  183. body=self.body,
  184. headers=self.headers,
  185. draft=1)
  186. def test_revocation_request(self):
  187. client = Client(self.client_id)
  188. url = 'https://example.com/revoke'
  189. token = 'foobar'
  190. # Valid request
  191. u, h, b = client.prepare_token_revocation_request(url, token)
  192. self.assertEqual(u, url)
  193. self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
  194. self.assertEqual(b, 'token=%s&token_type_hint=access_token' % token)
  195. # Non-HTTPS revocation endpoint
  196. self.assertRaises(InsecureTransportError,
  197. client.prepare_token_revocation_request,
  198. 'http://example.com/revoke', token)
  199. u, h, b = client.prepare_token_revocation_request(
  200. url, token, token_type_hint='refresh_token')
  201. self.assertEqual(u, url)
  202. self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
  203. self.assertEqual(b, 'token=%s&token_type_hint=refresh_token' % token)
  204. # JSONP
  205. u, h, b = client.prepare_token_revocation_request(
  206. url, token, callback='hello.world')
  207. self.assertURLEqual(u, url + '?callback=hello.world&token=%s&token_type_hint=access_token' % token)
  208. self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
  209. self.assertEqual(b, '')
  210. def test_prepare_authorization_request(self):
  211. redirect_url = 'https://example.com/callback/'
  212. scopes = 'read'
  213. auth_url = 'https://example.com/authorize/'
  214. state = 'fake_state'
  215. client = Client(self.client_id, redirect_url=redirect_url, scope=scopes, state=state)
  216. # Non-HTTPS
  217. self.assertRaises(InsecureTransportError,
  218. client.prepare_authorization_request, 'http://example.com/authorize/')
  219. # NotImplementedError
  220. self.assertRaises(NotImplementedError, client.prepare_authorization_request, auth_url)
  221. def test_prepare_token_request(self):
  222. redirect_url = 'https://example.com/callback/'
  223. scopes = 'read'
  224. token_url = 'https://example.com/token/'
  225. state = 'fake_state'
  226. client = Client(self.client_id, scope=scopes, state=state)
  227. # Non-HTTPS
  228. self.assertRaises(InsecureTransportError,
  229. client.prepare_token_request, 'http://example.com/token/')
  230. # NotImplementedError
  231. self.assertRaises(NotImplementedError, client.prepare_token_request, token_url)
  232. def test_prepare_refresh_token_request(self):
  233. client = Client(self.client_id)
  234. url = 'https://example.com/revoke'
  235. token = 'foobar'
  236. scope = 'extra_scope'
  237. u, h, b = client.prepare_refresh_token_request(url, token)
  238. self.assertEqual(u, url)
  239. self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
  240. self.assertFormBodyEqual(b, 'grant_type=refresh_token&refresh_token=%s' % token)
  241. # Non-HTTPS revocation endpoint
  242. self.assertRaises(InsecureTransportError,
  243. client.prepare_refresh_token_request,
  244. 'http://example.com/revoke', token)
  245. # provide extra scope
  246. u, h, b = client.prepare_refresh_token_request(url, token, scope=scope)
  247. self.assertEqual(u, url)
  248. self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
  249. self.assertFormBodyEqual(b, 'grant_type=refresh_token&scope={}&refresh_token={}'.format(scope, token))
  250. # provide scope while init
  251. client = Client(self.client_id, scope=scope)
  252. u, h, b = client.prepare_refresh_token_request(url, token, scope=scope)
  253. self.assertEqual(u, url)
  254. self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'})
  255. self.assertFormBodyEqual(b, 'grant_type=refresh_token&scope={}&refresh_token={}'.format(scope, token))
  256. def test_parse_token_response_invalid_expires_at(self):
  257. token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",'
  258. ' "token_type":"example",'
  259. ' "expires_at":"2006-01-02T15:04:05Z",'
  260. ' "scope":"/profile",'
  261. ' "example_parameter":"example_value"}')
  262. token = {
  263. "access_token": "2YotnFZFEjr1zCsicMWpAA",
  264. "token_type": "example",
  265. "expires_at": "2006-01-02T15:04:05Z",
  266. "scope": ["/profile"],
  267. "example_parameter": "example_value"
  268. }
  269. client = Client(self.client_id)
  270. # Parse code and state
  271. response = client.parse_request_body_response(token_json, scope=["/profile"])
  272. self.assertEqual(response, token)
  273. self.assertEqual(None, client._expires_at)
  274. self.assertEqual(client.access_token, response.get("access_token"))
  275. self.assertEqual(client.refresh_token, response.get("refresh_token"))
  276. self.assertEqual(client.token_type, response.get("token_type"))
  277. def test_create_code_verifier_min_length(self):
  278. client = Client(self.client_id)
  279. length = 43
  280. code_verifier = client.create_code_verifier(length=length)
  281. self.assertEqual(client.code_verifier, code_verifier)
  282. def test_create_code_verifier_max_length(self):
  283. client = Client(self.client_id)
  284. length = 128
  285. code_verifier = client.create_code_verifier(length=length)
  286. self.assertEqual(client.code_verifier, code_verifier)
  287. def test_create_code_challenge_plain(self):
  288. client = Client(self.client_id)
  289. code_verifier = client.create_code_verifier(length=128)
  290. code_challenge_plain = client.create_code_challenge(code_verifier=code_verifier)
  291. # if no code_challenge_method specified, code_challenge = code_verifier
  292. self.assertEqual(code_challenge_plain, client.code_verifier)
  293. self.assertEqual(client.code_challenge_method, "plain")
  294. def test_create_code_challenge_s256(self):
  295. client = Client(self.client_id)
  296. code_verifier = client.create_code_verifier(length=128)
  297. code_challenge_s256 = client.create_code_challenge(code_verifier=code_verifier, code_challenge_method='S256')
  298. self.assertEqual(code_challenge_s256, client.code_challenge)