test_base.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. from re import sub
  2. from unittest.mock import MagicMock
  3. from oauthlib.common import CaseInsensitiveDict, safe_string_equals
  4. from oauthlib.oauth1 import Client, RequestValidator
  5. from oauthlib.oauth1.rfc5849 import (
  6. SIGNATURE_HMAC, SIGNATURE_PLAINTEXT, SIGNATURE_RSA, errors,
  7. )
  8. from oauthlib.oauth1.rfc5849.endpoints import (
  9. BaseEndpoint, RequestTokenEndpoint,
  10. )
  11. from tests.unittest import TestCase
  12. URLENCODED = {"Content-Type": "application/x-www-form-urlencoded"}
  13. class BaseEndpointTest(TestCase):
  14. def setUp(self):
  15. self.validator = MagicMock(spec=RequestValidator)
  16. self.validator.allowed_signature_methods = ['HMAC-SHA1']
  17. self.validator.timestamp_lifetime = 600
  18. self.endpoint = RequestTokenEndpoint(self.validator)
  19. self.client = Client('foo', callback_uri='https://c.b/cb')
  20. self.uri, self.headers, self.body = self.client.sign(
  21. 'https://i.b/request_token')
  22. def test_ssl_enforcement(self):
  23. uri, headers, _ = self.client.sign('http://i.b/request_token')
  24. h, b, s = self.endpoint.create_request_token_response(
  25. uri, headers=headers)
  26. self.assertEqual(s, 400)
  27. self.assertIn('insecure_transport_protocol', b)
  28. def test_missing_parameters(self):
  29. h, b, s = self.endpoint.create_request_token_response(self.uri)
  30. self.assertEqual(s, 400)
  31. self.assertIn('invalid_request', b)
  32. def test_signature_methods(self):
  33. headers = {}
  34. headers['Authorization'] = self.headers['Authorization'].replace(
  35. 'HMAC', 'RSA')
  36. h, b, s = self.endpoint.create_request_token_response(
  37. self.uri, headers=headers)
  38. self.assertEqual(s, 400)
  39. self.assertIn('invalid_signature_method', b)
  40. def test_invalid_version(self):
  41. headers = {}
  42. headers['Authorization'] = self.headers['Authorization'].replace(
  43. '1.0', '2.0')
  44. h, b, s = self.endpoint.create_request_token_response(
  45. self.uri, headers=headers)
  46. self.assertEqual(s, 400)
  47. self.assertIn('invalid_request', b)
  48. def test_expired_timestamp(self):
  49. headers = {}
  50. for pattern in ('12345678901', '4567890123', '123456789K'):
  51. headers['Authorization'] = sub(r'timestamp="\d*k?"',
  52. 'timestamp="%s"' % pattern,
  53. self.headers['Authorization'])
  54. h, b, s = self.endpoint.create_request_token_response(
  55. self.uri, headers=headers)
  56. self.assertEqual(s, 400)
  57. self.assertIn('invalid_request', b)
  58. def test_client_key_check(self):
  59. self.validator.check_client_key.return_value = False
  60. h, b, s = self.endpoint.create_request_token_response(
  61. self.uri, headers=self.headers)
  62. self.assertEqual(s, 400)
  63. self.assertIn('invalid_request', b)
  64. def test_noncecheck(self):
  65. self.validator.check_nonce.return_value = False
  66. h, b, s = self.endpoint.create_request_token_response(
  67. self.uri, headers=self.headers)
  68. self.assertEqual(s, 400)
  69. self.assertIn('invalid_request', b)
  70. def test_enforce_ssl(self):
  71. """Ensure SSL is enforced by default."""
  72. v = RequestValidator()
  73. e = BaseEndpoint(v)
  74. c = Client('foo')
  75. u, h, b = c.sign('http://example.com')
  76. r = e._create_request(u, 'GET', b, h)
  77. self.assertRaises(errors.InsecureTransportError,
  78. e._check_transport_security, r)
  79. def test_multiple_source_params(self):
  80. """Check for duplicate params"""
  81. v = RequestValidator()
  82. e = BaseEndpoint(v)
  83. self.assertRaises(errors.InvalidRequestError, e._create_request,
  84. 'https://a.b/?oauth_signature_method=HMAC-SHA1',
  85. 'GET', 'oauth_version=foo', URLENCODED)
  86. headers = {'Authorization': 'OAuth oauth_signature="foo"'}
  87. headers.update(URLENCODED)
  88. self.assertRaises(errors.InvalidRequestError, e._create_request,
  89. 'https://a.b/?oauth_signature_method=HMAC-SHA1',
  90. 'GET',
  91. 'oauth_version=foo',
  92. headers)
  93. headers = {'Authorization': 'OAuth oauth_signature_method="foo"'}
  94. headers.update(URLENCODED)
  95. self.assertRaises(errors.InvalidRequestError, e._create_request,
  96. 'https://a.b/',
  97. 'GET',
  98. 'oauth_signature=foo',
  99. headers)
  100. def test_duplicate_params(self):
  101. """Ensure params are only supplied once"""
  102. v = RequestValidator()
  103. e = BaseEndpoint(v)
  104. self.assertRaises(errors.InvalidRequestError, e._create_request,
  105. 'https://a.b/?oauth_version=a&oauth_version=b',
  106. 'GET', None, URLENCODED)
  107. self.assertRaises(errors.InvalidRequestError, e._create_request,
  108. 'https://a.b/', 'GET', 'oauth_version=a&oauth_version=b',
  109. URLENCODED)
  110. def test_mandated_params(self):
  111. """Ensure all mandatory params are present."""
  112. v = RequestValidator()
  113. e = BaseEndpoint(v)
  114. r = e._create_request('https://a.b/', 'GET',
  115. 'oauth_signature=a&oauth_consumer_key=b&oauth_nonce',
  116. URLENCODED)
  117. self.assertRaises(errors.InvalidRequestError,
  118. e._check_mandatory_parameters, r)
  119. def test_oauth_version(self):
  120. """OAuth version must be 1.0 if present."""
  121. v = RequestValidator()
  122. e = BaseEndpoint(v)
  123. r = e._create_request('https://a.b/', 'GET',
  124. ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
  125. 'oauth_timestamp=a&oauth_signature_method=RSA-SHA1&'
  126. 'oauth_version=2.0'),
  127. URLENCODED)
  128. self.assertRaises(errors.InvalidRequestError,
  129. e._check_mandatory_parameters, r)
  130. def test_oauth_timestamp(self):
  131. """Check for a valid UNIX timestamp."""
  132. v = RequestValidator()
  133. e = BaseEndpoint(v)
  134. # Invalid timestamp length, must be 10
  135. r = e._create_request('https://a.b/', 'GET',
  136. ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
  137. 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
  138. 'oauth_timestamp=123456789'),
  139. URLENCODED)
  140. self.assertRaises(errors.InvalidRequestError,
  141. e._check_mandatory_parameters, r)
  142. # Invalid timestamp age, must be younger than 10 minutes
  143. r = e._create_request('https://a.b/', 'GET',
  144. ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
  145. 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
  146. 'oauth_timestamp=1234567890'),
  147. URLENCODED)
  148. self.assertRaises(errors.InvalidRequestError,
  149. e._check_mandatory_parameters, r)
  150. # Timestamp must be an integer
  151. r = e._create_request('https://a.b/', 'GET',
  152. ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
  153. 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
  154. 'oauth_timestamp=123456789a'),
  155. URLENCODED)
  156. self.assertRaises(errors.InvalidRequestError,
  157. e._check_mandatory_parameters, r)
  158. def test_case_insensitive_headers(self):
  159. """Ensure headers are case-insensitive"""
  160. v = RequestValidator()
  161. e = BaseEndpoint(v)
  162. r = e._create_request('https://a.b', 'POST',
  163. ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
  164. 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
  165. 'oauth_timestamp=123456789a'),
  166. URLENCODED)
  167. self.assertIsInstance(r.headers, CaseInsensitiveDict)
  168. def test_signature_method_validation(self):
  169. """Ensure valid signature method is used."""
  170. body = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
  171. 'oauth_version=1.0&oauth_signature_method=%s&'
  172. 'oauth_timestamp=1234567890')
  173. uri = 'https://example.com/'
  174. class HMACValidator(RequestValidator):
  175. @property
  176. def allowed_signature_methods(self):
  177. return (SIGNATURE_HMAC,)
  178. v = HMACValidator()
  179. e = BaseEndpoint(v)
  180. r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED)
  181. self.assertRaises(errors.InvalidSignatureMethodError,
  182. e._check_mandatory_parameters, r)
  183. r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED)
  184. self.assertRaises(errors.InvalidSignatureMethodError,
  185. e._check_mandatory_parameters, r)
  186. r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
  187. self.assertRaises(errors.InvalidSignatureMethodError,
  188. e._check_mandatory_parameters, r)
  189. class RSAValidator(RequestValidator):
  190. @property
  191. def allowed_signature_methods(self):
  192. return (SIGNATURE_RSA,)
  193. v = RSAValidator()
  194. e = BaseEndpoint(v)
  195. r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED)
  196. self.assertRaises(errors.InvalidSignatureMethodError,
  197. e._check_mandatory_parameters, r)
  198. r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED)
  199. self.assertRaises(errors.InvalidSignatureMethodError,
  200. e._check_mandatory_parameters, r)
  201. r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
  202. self.assertRaises(errors.InvalidSignatureMethodError,
  203. e._check_mandatory_parameters, r)
  204. class PlainValidator(RequestValidator):
  205. @property
  206. def allowed_signature_methods(self):
  207. return (SIGNATURE_PLAINTEXT,)
  208. v = PlainValidator()
  209. e = BaseEndpoint(v)
  210. r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED)
  211. self.assertRaises(errors.InvalidSignatureMethodError,
  212. e._check_mandatory_parameters, r)
  213. r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED)
  214. self.assertRaises(errors.InvalidSignatureMethodError,
  215. e._check_mandatory_parameters, r)
  216. r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
  217. self.assertRaises(errors.InvalidSignatureMethodError,
  218. e._check_mandatory_parameters, r)
  219. class ClientValidator(RequestValidator):
  220. clients = ['foo']
  221. nonces = [('foo', 'once', '1234567891', 'fez')]
  222. owners = {'foo': ['abcdefghijklmnopqrstuvxyz', 'fez']}
  223. assigned_realms = {('foo', 'abcdefghijklmnopqrstuvxyz'): 'photos'}
  224. verifiers = {('foo', 'fez'): 'shibboleth'}
  225. @property
  226. def client_key_length(self):
  227. return 1, 30
  228. @property
  229. def request_token_length(self):
  230. return 1, 30
  231. @property
  232. def access_token_length(self):
  233. return 1, 30
  234. @property
  235. def nonce_length(self):
  236. return 2, 30
  237. @property
  238. def verifier_length(self):
  239. return 2, 30
  240. @property
  241. def realms(self):
  242. return ['photos']
  243. @property
  244. def timestamp_lifetime(self):
  245. # Disabled check to allow hardcoded verification signatures
  246. return 1000000000
  247. @property
  248. def dummy_client(self):
  249. return 'dummy'
  250. @property
  251. def dummy_request_token(self):
  252. return 'dumbo'
  253. @property
  254. def dummy_access_token(self):
  255. return 'dumbo'
  256. def validate_timestamp_and_nonce(self, client_key, timestamp, nonce,
  257. request, request_token=None, access_token=None):
  258. resource_owner_key = request_token if request_token else access_token
  259. return not (client_key, nonce, timestamp, resource_owner_key) in self.nonces
  260. def validate_client_key(self, client_key):
  261. return client_key in self.clients
  262. def validate_access_token(self, client_key, access_token, request):
  263. return (self.owners.get(client_key) and
  264. access_token in self.owners.get(client_key))
  265. def validate_request_token(self, client_key, request_token, request):
  266. return (self.owners.get(client_key) and
  267. request_token in self.owners.get(client_key))
  268. def validate_requested_realm(self, client_key, realm, request):
  269. return True
  270. def validate_realm(self, client_key, access_token, request, uri=None,
  271. required_realm=None):
  272. return (client_key, access_token) in self.assigned_realms
  273. def validate_verifier(self, client_key, request_token, verifier,
  274. request):
  275. return ((client_key, request_token) in self.verifiers and
  276. safe_string_equals(verifier, self.verifiers.get(
  277. (client_key, request_token))))
  278. def validate_redirect_uri(self, client_key, redirect_uri, request):
  279. return redirect_uri.startswith('http://client.example.com/')
  280. def get_client_secret(self, client_key, request):
  281. return 'super secret'
  282. def get_access_token_secret(self, client_key, access_token, request):
  283. return 'even more secret'
  284. def get_request_token_secret(self, client_key, request_token, request):
  285. return 'even more secret'
  286. def get_rsa_key(self, client_key, request):
  287. return ("-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNA"
  288. "DCBiQKBgQDVLQCATX8iK+aZuGVdkGb6uiar\nLi/jqFwL1dYj0JLIsdQc"
  289. "KaMWtPC06K0+vI+RRZcjKc6sNB9/7kJcKN9Ekc9BUxyT\n/D09Cz47cmC"
  290. "YsUoiW7G8NSqbE4wPiVpGkJRzFAxaCWwOSSQ+lpC9vwxnvVQfOoZ1\nnp"
  291. "mWbCdA0iTxsMahwQIDAQAB\n-----END PUBLIC KEY-----")
  292. class SignatureVerificationTest(TestCase):
  293. def setUp(self):
  294. v = ClientValidator()
  295. self.e = BaseEndpoint(v)
  296. self.uri = 'https://example.com/'
  297. self.sig = ('oauth_signature=%s&'
  298. 'oauth_timestamp=1234567890&'
  299. 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
  300. 'oauth_version=1.0&'
  301. 'oauth_signature_method=%s&'
  302. 'oauth_token=abcdefghijklmnopqrstuvxyz&'
  303. 'oauth_consumer_key=foo')
  304. def test_signature_too_short(self):
  305. short_sig = ('oauth_signature=fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY&'
  306. 'oauth_timestamp=1234567890&'
  307. 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
  308. 'oauth_version=1.0&oauth_signature_method=HMAC-SHA1&'
  309. 'oauth_token=abcdefghijklmnopqrstuvxyz&'
  310. 'oauth_consumer_key=foo')
  311. r = self.e._create_request(self.uri, 'GET', short_sig, URLENCODED)
  312. self.assertFalse(self.e._check_signature(r))
  313. plain = ('oauth_signature=correctlengthbutthewrongcontent1111&'
  314. 'oauth_timestamp=1234567890&'
  315. 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
  316. 'oauth_version=1.0&oauth_signature_method=PLAINTEXT&'
  317. 'oauth_token=abcdefghijklmnopqrstuvxyz&'
  318. 'oauth_consumer_key=foo')
  319. r = self.e._create_request(self.uri, 'GET', plain, URLENCODED)
  320. self.assertFalse(self.e._check_signature(r))
  321. def test_hmac_signature(self):
  322. hmac_sig = "fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY%3D"
  323. sig = self.sig % (hmac_sig, "HMAC-SHA1")
  324. r = self.e._create_request(self.uri, 'GET', sig, URLENCODED)
  325. self.assertTrue(self.e._check_signature(r))
  326. def test_rsa_signature(self):
  327. rsa_sig = ("fxFvCx33oKlR9wDquJ%2FPsndFzJphyBa3RFPPIKi3flqK%2BJ7yIrMVbH"
  328. "YTM%2FLHPc7NChWz4F4%2FzRA%2BDN1k08xgYGSBoWJUOW6VvOQ6fbYhMA"
  329. "FkOGYbuGDbje487XMzsAcv6ZjqZHCROSCk5vofgLk2SN7RZ3OrgrFzf4in"
  330. "xetClqA%3D")
  331. sig = self.sig % (rsa_sig, "RSA-SHA1")
  332. r = self.e._create_request(self.uri, 'GET', sig, URLENCODED)
  333. self.assertTrue(self.e._check_signature(r))
  334. def test_plaintext_signature(self):
  335. plain_sig = "super%252520secret%26even%252520more%252520secret"
  336. sig = self.sig % (plain_sig, "PLAINTEXT")
  337. r = self.e._create_request(self.uri, 'GET', sig, URLENCODED)
  338. self.assertTrue(self.e._check_signature(r))