123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- from re import sub
- from unittest.mock import MagicMock
- from oauthlib.common import CaseInsensitiveDict, safe_string_equals
- from oauthlib.oauth1 import Client, RequestValidator
- from oauthlib.oauth1.rfc5849 import (
- SIGNATURE_HMAC, SIGNATURE_PLAINTEXT, SIGNATURE_RSA, errors,
- )
- from oauthlib.oauth1.rfc5849.endpoints import (
- BaseEndpoint, RequestTokenEndpoint,
- )
- from tests.unittest import TestCase
# Content-Type header passed alongside the form-encoded request bodies that
# the tests below hand to the endpoints.
_FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded'
URLENCODED = {'Content-Type': _FORM_CONTENT_TYPE}
class BaseEndpointTest(TestCase):
    """Request parsing and validation shared by the OAuth 1 endpoints.

    The first group of tests drives a ``RequestTokenEndpoint`` backed by a
    mocked validator and inspects the HTTP response. The remaining tests
    call ``BaseEndpoint``'s private helpers (``_create_request``,
    ``_check_transport_security``, ``_check_mandatory_parameters``)
    directly on a stock ``RequestValidator``.
    """

    def setUp(self):
        # Mocked validator that allows only HMAC-SHA1 with a 600 second
        # timestamp lifetime. Unconfigured validator methods return a
        # (truthy) MagicMock unless a test overrides their return_value.
        self.validator = MagicMock(spec=RequestValidator)
        self.validator.allowed_signature_methods = ['HMAC-SHA1']
        self.validator.timestamp_lifetime = 600
        self.endpoint = RequestTokenEndpoint(self.validator)
        self.client = Client('foo', callback_uri='https://c.b/cb')
        # A correctly signed HTTPS request. Individual tests tamper with
        # its Authorization header to trigger one specific failure each.
        self.uri, self.headers, self.body = self.client.sign(
            'https://i.b/request_token')

    @staticmethod
    def _validator_allowing(method):
        """Return a ``RequestValidator`` whose only allowed method is *method*.

        ``allowed_signature_methods`` is overridden as a property in a
        subclass (as the original inline validators did) rather than set
        as an instance attribute.
        """
        class _SingleMethodValidator(RequestValidator):
            @property
            def allowed_signature_methods(self):
                return (method,)

        return _SingleMethodValidator()

    def _respond_to_authorization(self, authorization):
        """Send *authorization* as the sole header to the token endpoint.

        Returns the ``(body, status)`` pair of the endpoint response.
        """
        _, body, status = self.endpoint.create_request_token_response(
            self.uri, headers={'Authorization': authorization})
        return body, status

    def test_ssl_enforcement(self):
        """A request signed for a plain-HTTP URI is rejected with 400."""
        uri, headers, _ = self.client.sign('http://i.b/request_token')
        _, body, status = self.endpoint.create_request_token_response(
            uri, headers=headers)
        self.assertEqual(status, 400)
        self.assertIn('insecure_transport_protocol', body)

    def test_missing_parameters(self):
        """A request carrying no OAuth parameters is rejected with 400."""
        _, body, status = self.endpoint.create_request_token_response(
            self.uri)
        self.assertEqual(status, 400)
        self.assertIn('invalid_request', body)

    def test_signature_methods(self):
        """A signature method the validator does not allow is rejected."""
        # 'HMAC-SHA1' becomes 'RSA-SHA1', which setUp did not allow.
        body, status = self._respond_to_authorization(
            self.headers['Authorization'].replace('HMAC', 'RSA'))
        self.assertEqual(status, 400)
        self.assertIn('invalid_signature_method', body)

    def test_invalid_version(self):
        """An oauth_version other than 1.0 is rejected with 400."""
        body, status = self._respond_to_authorization(
            self.headers['Authorization'].replace('1.0', '2.0'))
        self.assertEqual(status, 400)
        self.assertIn('invalid_request', body)

    def test_expired_timestamp(self):
        """Malformed or out-of-range timestamps are rejected with 400."""
        # Too many digits, far in the future, and non-numeric. Each pattern
        # replaces the timestamp of the freshly signed header; subTest
        # reports every failing pattern instead of stopping at the first.
        for pattern in ('12345678901', '4567890123', '123456789K'):
            with self.subTest(timestamp=pattern):
                body, status = self._respond_to_authorization(sub(
                    r'timestamp="\d*"', 'timestamp="%s"' % pattern,
                    self.headers['Authorization']))
                self.assertEqual(status, 400)
                self.assertIn('invalid_request', body)

    def test_client_key_check(self):
        """A client key failing ``check_client_key`` is rejected with 400."""
        self.validator.check_client_key.return_value = False
        _, body, status = self.endpoint.create_request_token_response(
            self.uri, headers=self.headers)
        self.assertEqual(status, 400)
        self.assertIn('invalid_request', body)

    def test_noncecheck(self):
        """A nonce failing ``check_nonce`` is rejected with 400."""
        self.validator.check_nonce.return_value = False
        _, body, status = self.endpoint.create_request_token_response(
            self.uri, headers=self.headers)
        self.assertEqual(status, 400)
        self.assertIn('invalid_request', body)

    def test_enforce_ssl(self):
        """Ensure SSL is enforced by default."""
        e = BaseEndpoint(RequestValidator())
        u, h, b = Client('foo').sign('http://example.com')
        r = e._create_request(u, 'GET', b, h)
        self.assertRaises(errors.InsecureTransportError,
                          e._check_transport_security, r)

    def test_multiple_source_params(self):
        """OAuth params spread over several request parts are rejected."""
        e = BaseEndpoint(RequestValidator())
        # Query string + body.
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/?oauth_signature_method=HMAC-SHA1',
                          'GET', 'oauth_version=foo', URLENCODED)
        # Query string + body + Authorization header.
        headers = dict(URLENCODED,
                       Authorization='OAuth oauth_signature="foo"')
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/?oauth_signature_method=HMAC-SHA1',
                          'GET', 'oauth_version=foo', headers)
        # Authorization header + body.
        headers = dict(URLENCODED,
                       Authorization='OAuth oauth_signature_method="foo"')
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/', 'GET', 'oauth_signature=foo',
                          headers)

    def test_duplicate_params(self):
        """Ensure params are only supplied once"""
        e = BaseEndpoint(RequestValidator())
        # Repeated in the query string.
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/?oauth_version=a&oauth_version=b',
                          'GET', None, URLENCODED)
        # Repeated in the form-encoded body.
        self.assertRaises(errors.InvalidRequestError, e._create_request,
                          'https://a.b/', 'GET',
                          'oauth_version=a&oauth_version=b', URLENCODED)

    def test_mandated_params(self):
        """Ensure all mandatory params are present."""
        e = BaseEndpoint(RequestValidator())
        # oauth_nonce has no value; timestamp and signature method are absent.
        r = e._create_request(
            'https://a.b/', 'GET',
            'oauth_signature=a&oauth_consumer_key=b&oauth_nonce',
            URLENCODED)
        self.assertRaises(errors.InvalidRequestError,
                          e._check_mandatory_parameters, r)

    def test_oauth_version(self):
        """OAuth version must be 1.0 if present."""
        e = BaseEndpoint(RequestValidator())
        r = e._create_request(
            'https://a.b/', 'GET',
            ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
             'oauth_timestamp=a&oauth_signature_method=RSA-SHA1&'
             'oauth_version=2.0'),
            URLENCODED)
        self.assertRaises(errors.InvalidRequestError,
                          e._check_mandatory_parameters, r)

    def test_oauth_timestamp(self):
        """Check for a valid UNIX timestamp."""
        e = BaseEndpoint(RequestValidator())
        template = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
                    'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
                    'oauth_timestamp=%s')
        cases = (
            ('123456789', 'invalid length, must be 10'),
            ('1234567890', 'too old, must be younger than 10 minutes'),
            ('123456789a', 'must be an integer'),
        )
        for timestamp, reason in cases:
            with self.subTest(timestamp=timestamp, reason=reason):
                r = e._create_request('https://a.b/', 'GET',
                                      template % timestamp, URLENCODED)
                self.assertRaises(errors.InvalidRequestError,
                                  e._check_mandatory_parameters, r)

    def test_case_insensitive_headers(self):
        """Ensure headers are case-insensitive"""
        e = BaseEndpoint(RequestValidator())
        r = e._create_request(
            'https://a.b', 'POST',
            ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
             'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
             'oauth_timestamp=123456789a'),
            URLENCODED)
        self.assertIsInstance(r.headers, CaseInsensitiveDict)

    def test_signature_method_validation(self):
        """Ensure valid signature method is used."""
        body = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
                'oauth_version=1.0&oauth_signature_method=%s&'
                'oauth_timestamp=1234567890')
        uri = 'https://example.com/'
        # (only allowed method, methods that must then be refused)
        cases = (
            (SIGNATURE_HMAC, ('RSA-SHA1', 'PLAINTEXT', 'shibboleth')),
            (SIGNATURE_RSA, ('HMAC-SHA1', 'PLAINTEXT', 'shibboleth')),
            (SIGNATURE_PLAINTEXT, ('HMAC-SHA1', 'RSA-SHA1', 'shibboleth')),
        )
        for allowed, refused in cases:
            e = BaseEndpoint(self._validator_allowing(allowed))
            for method in refused:
                with self.subTest(allowed=allowed, method=method):
                    r = e._create_request(uri, 'GET', body % method,
                                          URLENCODED)
                    self.assertRaises(errors.InvalidSignatureMethodError,
                                      e._check_mandatory_parameters, r)
class ClientValidator(RequestValidator):
    """In-memory ``RequestValidator`` backed by small, fixed fixtures.

    The known client ``foo`` owns two tokens. Secrets and the RSA public key
    are constants, so the signatures hard-coded in this module's tests
    verify. The timestamp lifetime is made huge for the same reason.
    """

    # Registered client keys.
    clients = ['foo']
    # (client_key, nonce, timestamp, token) tuples treated as already used.
    nonces = [('foo', 'once', '1234567891', 'fez')]
    # client_key -> tokens (request or access) belonging to that client.
    owners = {'foo': ['abcdefghijklmnopqrstuvxyz', 'fez']}
    # (client_key, access_token) pairs that have a realm assigned.
    assigned_realms = {('foo', 'abcdefghijklmnopqrstuvxyz'): 'photos'}
    # (client_key, request_token) -> verifier that must be presented.
    verifiers = {('foo', 'fez'): 'shibboleth'}

    # Accepted (min, max) lengths for keys, tokens, nonces and verifiers.
    @property
    def client_key_length(self):
        return 1, 30

    @property
    def request_token_length(self):
        return 1, 30

    @property
    def access_token_length(self):
        return 1, 30

    @property
    def nonce_length(self):
        return 2, 30

    @property
    def verifier_length(self):
        return 2, 30

    @property
    def realms(self):
        return ['photos']

    @property
    def timestamp_lifetime(self):
        # Disabled check to allow hardcoded verification signatures
        return 1000000000

    @property
    def dummy_client(self):
        return 'dummy'

    @property
    def dummy_request_token(self):
        return 'dumbo'

    @property
    def dummy_access_token(self):
        return 'dumbo'

    def validate_timestamp_and_nonce(self, client_key, timestamp, nonce,
                                     request, request_token=None,
                                     access_token=None):
        """Reject any combination already recorded in ``nonces``."""
        # The request token takes precedence over the access token.
        resource_owner_key = request_token or access_token
        return ((client_key, nonce, timestamp, resource_owner_key)
                not in self.nonces)

    def validate_client_key(self, client_key):
        return client_key in self.clients

    def validate_access_token(self, client_key, access_token, request):
        """True if *access_token* is one of *client_key*'s tokens."""
        return access_token in self.owners.get(client_key, ())

    def validate_request_token(self, client_key, request_token, request):
        """True if *request_token* is one of *client_key*'s tokens."""
        return request_token in self.owners.get(client_key, ())

    def validate_requested_realm(self, client_key, realm, request):
        return True

    def validate_realm(self, client_key, access_token, request, uri=None,
                       required_realm=None):
        # Only checks that the token has *a* realm assigned; the realm
        # arguments themselves are not compared.
        return (client_key, access_token) in self.assigned_realms

    def validate_verifier(self, client_key, request_token, verifier,
                          request):
        """Compare *verifier* against the stored one in constant time."""
        expected = self.verifiers.get((client_key, request_token))
        return expected is not None and safe_string_equals(verifier,
                                                           expected)

    def validate_redirect_uri(self, client_key, redirect_uri, request):
        return redirect_uri.startswith('http://client.example.com/')

    def get_client_secret(self, client_key, request):
        return 'super secret'

    def get_access_token_secret(self, client_key, access_token, request):
        return 'even more secret'

    def get_request_token_secret(self, client_key, request_token, request):
        return 'even more secret'

    def get_rsa_key(self, client_key, request):
        """Return the PEM public key used to verify RSA-SHA1 signatures."""
        return ("-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNA"
                "DCBiQKBgQDVLQCATX8iK+aZuGVdkGb6uiar\nLi/jqFwL1dYj0JLIsdQc"
                "KaMWtPC06K0+vI+RRZcjKc6sNB9/7kJcKN9Ekc9BUxyT\n/D09Cz47cmC"
                "YsUoiW7G8NSqbE4wPiVpGkJRzFAxaCWwOSSQ+lpC9vwxnvVQfOoZ1\nnp"
                "mWbCdA0iTxsMahwQIDAQAB\n-----END PUBLIC KEY-----")
class SignatureVerificationTest(TestCase):
    """Check ``BaseEndpoint._check_signature`` against precomputed values.

    Every request shares one set of OAuth parameters (``self.sig``); only
    the signature value and the signature method vary between tests.
    """

    def setUp(self):
        self.e = BaseEndpoint(ClientValidator())
        self.uri = 'https://example.com/'
        # Form-encoded body template; fill in (signature, signature_method).
        self.sig = ('oauth_signature=%s&'
                    'oauth_timestamp=1234567890&'
                    'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
                    'oauth_version=1.0&'
                    'oauth_signature_method=%s&'
                    'oauth_token=abcdefghijklmnopqrstuvxyz&'
                    'oauth_consumer_key=foo')

    def _check(self, signature, method):
        """Build a GET request from ``self.sig`` and verify its signature.

        Returns whatever ``_check_signature`` returns for that request.
        """
        body = self.sig % (signature, method)
        r = self.e._create_request(self.uri, 'GET', body, URLENCODED)
        return self.e._check_signature(r)

    def test_signature_too_short(self):
        """Truncated or wrong signatures are rejected."""
        # The valid HMAC-SHA1 signature with its trailing '%3D' removed.
        self.assertFalse(
            self._check('fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY', 'HMAC-SHA1'))
        # Plausible-looking PLAINTEXT value with the wrong content.
        self.assertFalse(
            self._check('correctlengthbutthewrongcontent1111', 'PLAINTEXT'))

    def test_hmac_signature(self):
        """A correct HMAC-SHA1 signature verifies."""
        hmac_sig = "fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY%3D"
        self.assertTrue(self._check(hmac_sig, "HMAC-SHA1"))

    def test_rsa_signature(self):
        """A correct RSA-SHA1 signature verifies against the validator key."""
        rsa_sig = ("fxFvCx33oKlR9wDquJ%2FPsndFzJphyBa3RFPPIKi3flqK%2BJ7yIrMVbH"
                   "YTM%2FLHPc7NChWz4F4%2FzRA%2BDN1k08xgYGSBoWJUOW6VvOQ6fbYhMA"
                   "FkOGYbuGDbje487XMzsAcv6ZjqZHCROSCk5vofgLk2SN7RZ3OrgrFzf4in"
                   "xetClqA%3D")
        self.assertTrue(self._check(rsa_sig, "RSA-SHA1"))

    def test_plaintext_signature(self):
        """A correct PLAINTEXT signature verifies."""
        # Client secret and token secret joined by '&' (encoded as %26),
        # with the spaces percent-encoded repeatedly.
        plain_sig = "super%252520secret%26even%252520more%252520secret"
        self.assertTrue(self._check(plain_sig, "PLAINTEXT"))
|