|
- import hashlib
- import hmac
- import json
- from .compat import constant_time_compare, string_types
- from .exceptions import InvalidKeyError
- from .utils import (
- base64url_decode, base64url_encode, der_to_raw_signature,
- force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
- to_base64url_uint
- )
- try:
- from cryptography.hazmat.primitives import hashes
- from cryptography.hazmat.primitives.serialization import (
- load_pem_private_key, load_pem_public_key, load_ssh_public_key
- )
- from cryptography.hazmat.primitives.asymmetric.rsa import (
- RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
- rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
- )
- from cryptography.hazmat.primitives.asymmetric.ec import (
- EllipticCurvePrivateKey, EllipticCurvePublicKey
- )
- from cryptography.hazmat.primitives.asymmetric import ec, padding
- from cryptography.hazmat.backends import default_backend
- from cryptography.exceptions import InvalidSignature
- has_crypto = True
- except ImportError:
- has_crypto = False
# Names of the algorithms that get_default_algorithms() only registers when
# the optional ``cryptography`` package could be imported.
requires_cryptography = set(
    ('RS256', 'RS384', 'RS512') +
    ('ES256', 'ES384', 'ES521', 'ES512') +
    ('PS256', 'PS384', 'PS512')
)
def get_default_algorithms():
    """
    Return a new dict mapping each implemented 'alg' name to an instance of
    the Algorithm that handles it.

    'none' and the HS* (HMAC) algorithms are always present. The RS*, ES*
    and PS* algorithms are added only when ``has_crypto`` is true, i.e. when
    the optional ``cryptography`` package could be imported.
    """
    algorithms = {'none': NoneAlgorithm()}

    hmac_specs = (
        ('HS256', HMACAlgorithm.SHA256),
        ('HS384', HMACAlgorithm.SHA384),
        ('HS512', HMACAlgorithm.SHA512),
    )
    for name, hash_alg in hmac_specs:
        algorithms[name] = HMACAlgorithm(hash_alg)

    if not has_crypto:
        return algorithms

    crypto_specs = (
        ('RS256', RSAAlgorithm, RSAAlgorithm.SHA256),
        ('RS384', RSAAlgorithm, RSAAlgorithm.SHA384),
        ('RS512', RSAAlgorithm, RSAAlgorithm.SHA512),
        ('ES256', ECAlgorithm, ECAlgorithm.SHA256),
        ('ES384', ECAlgorithm, ECAlgorithm.SHA384),
        ('ES521', ECAlgorithm, ECAlgorithm.SHA512),
        # 'ES512' is kept for backward compatibility (#219 fix); like
        # 'ES521' it uses SHA-512.
        ('ES512', ECAlgorithm, ECAlgorithm.SHA512),
        ('PS256', RSAPSSAlgorithm, RSAPSSAlgorithm.SHA256),
        ('PS384', RSAPSSAlgorithm, RSAPSSAlgorithm.SHA384),
        ('PS512', RSAPSSAlgorithm, RSAPSSAlgorithm.SHA512),
    )
    for name, algorithm_cls, hash_alg in crypto_specs:
        algorithms[name] = algorithm_cls(hash_alg)

    return algorithms
class Algorithm(object):
    """
    The interface for an algorithm used to sign and verify tokens.

    Every method here only raises NotImplementedError; concrete algorithm
    classes override the operations they support.
    """

    def prepare_key(self, key):
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """
        raise NotImplementedError

    def sign(self, msg, key):
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """
        raise NotImplementedError

    def verify(self, msg, key, sig):
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.

        Implementations return True for a valid signature, False otherwise.
        """
        raise NotImplementedError

    @staticmethod
    def to_jwk(key_obj):
        """
        Serializes the given key object (of the kind this algorithm uses)
        into a JWK, returned as a JSON string.
        """
        raise NotImplementedError

    @staticmethod
    def from_jwk(jwk):
        """
        Deserializes a JWK JSON string back into a key value usable by this
        algorithm (e.g. raw bytes for HMAC, a key object for RSA).
        """
        raise NotImplementedError
class NoneAlgorithm(Algorithm):
    """
    Algorithm for alg="none": signing produces an empty signature and no
    signature is ever considered valid. Intended for when no signing or
    verification is required.
    """

    def prepare_key(self, key):
        """Accept only None ('' is treated as None); reject any other key."""
        if key == '':
            # An empty string is normalized to "no key".
            key = None
        if key is None:
            return key
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg, key):
        """Return the empty signature; msg and key are ignored."""
        return b''

    def verify(self, msg, key, sig):
        """Always report the signature as invalid."""
        return False
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """
    SHA256 = hashlib.sha256
    SHA384 = hashlib.sha384
    SHA512 = hashlib.sha512

    # Byte sequences that mark public-key material (PEM public keys, X.509
    # certificates, OpenSSH public keys). Such a value must never be used as
    # an HMAC secret. If a verifier accepts both an HS* algorithm and an
    # asymmetric one for the same key argument, anyone who knows the public
    # key could forge an HS* signature with it ("algorithm confusion").
    # 'ecdsa-sha2-' covers the OpenSSH EC keys that ECAlgorithm.prepare_key
    # loads; 'ssh-ed25519' and 'ssh-dss' are the remaining OpenSSH public key
    # types.
    _INVALID_KEY_MARKERS = (
        b'-----BEGIN PUBLIC KEY-----',
        b'-----BEGIN CERTIFICATE-----',
        b'-----BEGIN RSA PUBLIC KEY-----',
        b'ssh-rsa',
        b'ecdsa-sha2-',
        b'ssh-ed25519',
        b'ssh-dss',
    )

    def __init__(self, hash_alg):
        # hash_alg is a hashlib constructor such as HMACAlgorithm.SHA256.
        self.hash_alg = hash_alg

    def prepare_key(self, key):
        """
        Convert key to bytes (via force_bytes) and return it as the HMAC
        secret.

        Raises InvalidKeyError if the key contains any public-key or
        certificate marker listed in _INVALID_KEY_MARKERS.
        """
        key = force_bytes(key)

        if any(marker in key for marker in self._INVALID_KEY_MARKERS):
            raise InvalidKeyError(
                'The specified key is an asymmetric key or x509 certificate and'
                ' should not be used as an HMAC secret.')

        return key

    @staticmethod
    def to_jwk(key_obj):
        """Serialize a secret into an 'oct' JWK JSON string (base64url 'k')."""
        return json.dumps({
            'k': force_unicode(base64url_encode(force_bytes(key_obj))),
            'kty': 'oct'
        })

    @staticmethod
    def from_jwk(jwk):
        """
        Return the raw secret bytes from an 'oct' JWK JSON string.

        Raises InvalidKeyError if 'kty' is not 'oct'.
        """
        obj = json.loads(jwk)

        if obj.get('kty') != 'oct':
            raise InvalidKeyError('Not an HMAC key')

        return base64url_decode(obj['k'])

    def sign(self, msg, key):
        """Return the HMAC digest of msg under key."""
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg, key, sig):
        # Recompute the MAC and compare in constant time to avoid leaking
        # how many leading bytes of sig matched.
        return constant_time_compare(sig, self.sign(msg, key))
if has_crypto:
    def _verification_key(key):
        """
        Return the key object whose verify() should be called.

        cryptography's RSAPrivateKey and EllipticCurvePrivateKey do not expose
        verify() (only the public-key classes do), yet prepare_key() can
        return a private key. For a private key, this returns its public half.
        Any other object is returned unchanged.
        """
        if isinstance(key, (RSAPrivateKey, EllipticCurvePrivateKey)):
            return key.public_key()
        return key

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """
        SHA256 = hashes.SHA256
        SHA384 = hashes.SHA384
        SHA512 = hashes.SHA512

        def __init__(self, hash_alg):
            # hash_alg is one of the hashes.SHA* classes above; it is
            # instantiated on each sign/verify call.
            self.hash_alg = hash_alg

        def prepare_key(self, key):
            """
            Return an RSA key object for use with sign() and verify().

            RSA key objects are returned unchanged. str/bytes input is
            handled as follows:
            - input starting with ``ssh-rsa`` is loaded as an OpenSSH
              public key;
            - other input is loaded as a PEM private key.
            If either load raises ValueError, the input is loaded as a PEM
            public key instead.

            Raises TypeError for other input types, and InvalidKeyError if the
            parsed key is not an RSA key. Parse errors from ``cryptography``
            propagate unchanged.
            """
            if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
                return key

            if not isinstance(key, string_types):
                raise TypeError('Expecting a PEM-formatted key.')

            key = force_bytes(key)

            try:
                if key.startswith(b'ssh-rsa'):
                    key = load_ssh_public_key(key, backend=default_backend())
                else:
                    key = load_pem_private_key(key, password=None, backend=default_backend())
            except ValueError:
                key = load_pem_public_key(key, backend=default_backend())

            # The loaders accept any key family; reject e.g. an EC key here
            # instead of failing later inside sign()/verify().
            if not isinstance(key, (RSAPrivateKey, RSAPublicKey)):
                raise InvalidKeyError(
                    'Expecting an RSAPrivateKey/RSAPublicKey. Wrong key '
                    'provided for RSA algorithms.')

            return key

        @staticmethod
        def to_jwk(key_obj):
            """
            Serialize an RSA key object into a JWK JSON string.

            A private key yields n, e, d and the CRT parameters p, q, dp, dq
            and qi, with key_ops ['sign']. A public key yields n and e, with
            key_ops ['verify'].

            Raises InvalidKeyError if key_obj exposes neither private_numbers
            nor verify.
            """
            obj = None

            if getattr(key_obj, 'private_numbers', None):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    'kty': 'RSA',
                    'key_ops': ['sign'],
                    'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)),
                    'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)),
                    'd': force_unicode(to_base64url_uint(numbers.d)),
                    'p': force_unicode(to_base64url_uint(numbers.p)),
                    'q': force_unicode(to_base64url_uint(numbers.q)),
                    'dp': force_unicode(to_base64url_uint(numbers.dmp1)),
                    'dq': force_unicode(to_base64url_uint(numbers.dmq1)),
                    'qi': force_unicode(to_base64url_uint(numbers.iqmp))
                }

            elif getattr(key_obj, 'verify', None):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    'kty': 'RSA',
                    'key_ops': ['verify'],
                    'n': force_unicode(to_base64url_uint(numbers.n)),
                    'e': force_unicode(to_base64url_uint(numbers.e))
                }
            else:
                raise InvalidKeyError('Not a public or private key')

            return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk):
            """
            Build an RSA private or public key object from a JWK JSON string.

            A JWK containing d, e and n is a private key. If it also contains
            any of p, q, dp, dq or qi, it must contain all of them; if none
            are present, they are recomputed from n, e and d. Multi-prime keys
            (an 'oth' member) are rejected. A JWK with only n and e is a
            public key.

            Raises InvalidKeyError for invalid JSON, a 'kty' other than 'RSA',
            or an incomplete/unsupported parameter set.
            """
            try:
                obj = json.loads(jwk)
            except ValueError:
                raise InvalidKeyError('Key is not valid JSON')

            if obj.get('kty') != 'RSA':
                raise InvalidKeyError('Not an RSA key')

            if 'd' in obj and 'e' in obj and 'n' in obj:
                # Private key
                if 'oth' in obj:
                    raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')

                other_props = ['p', 'q', 'dp', 'dq', 'qi']
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError('RSA key must include all parameters if any are present besides d')

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj['d']),
                        p=from_base64url_uint(obj['p']),
                        q=from_base64url_uint(obj['q']),
                        dmp1=from_base64url_uint(obj['dp']),
                        dmq1=from_base64url_uint(obj['dq']),
                        iqmp=from_base64url_uint(obj['qi']),
                        public_numbers=public_numbers
                    )
                else:
                    # Only d was supplied: recover the primes and derive the
                    # CRT parameters from them.
                    d = from_base64url_uint(obj['d'])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers
                    )

                return numbers.private_key(default_backend())
            elif 'n' in obj and 'e' in obj:
                # Public key
                numbers = RSAPublicNumbers(
                    from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
                )

                return numbers.public_key(default_backend())
            else:
                raise InvalidKeyError('Not a public or private key')

        def sign(self, msg, key):
            """Sign msg with an RSA private key using PKCS#1 v1.5 padding."""
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg, key, sig):
            """
            Return True if sig is a valid PKCS#1 v1.5 signature of msg.

            key may be a public or a private RSA key; for a private key, its
            public half is used.
            """
            try:
                _verification_key(key).verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function
        """
        SHA256 = hashes.SHA256
        SHA384 = hashes.SHA384
        SHA512 = hashes.SHA512

        def __init__(self, hash_alg):
            # hash_alg is one of the hashes.SHA* classes above.
            self.hash_alg = hash_alg

        def prepare_key(self, key):
            """
            Return an elliptic-curve key object for use with sign()/verify().

            EC key objects are returned unchanged. str/bytes input is handled
            as follows:
            - input starting with ``ecdsa-sha2-`` is loaded as an OpenSSH
              public key;
            - other input is loaded as a PEM public key.
            If either load raises ValueError, the input is loaded as a PEM
            private key instead.

            Raises TypeError for other input types, and InvalidKeyError if the
            parsed key is not an EC key.
            """
            if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
                return key

            if not isinstance(key, string_types):
                raise TypeError('Expecting a PEM-formatted key.')

            key = force_bytes(key)

            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            try:
                if key.startswith(b'ecdsa-sha2-'):
                    key = load_ssh_public_key(key, backend=default_backend())
                else:
                    key = load_pem_public_key(key, backend=default_backend())
            except ValueError:
                key = load_pem_private_key(key, password=None, backend=default_backend())

            # The PEM loaders accept any key family (e.g. RSA); reject those
            # here rather than failing later on key.curve / key.sign.
            if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
                raise InvalidKeyError(
                    'Expecting an EllipticCurvePrivateKey/EllipticCurvePublicKey. '
                    'Wrong key provided for ECDSA algorithms.')

            return key

        def sign(self, msg, key):
            """
            Sign msg with an EC private key and return the raw signature
            (converted from DER by der_to_raw_signature using key.curve).
            """
            der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))

            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg, key, sig):
            """
            Return True if raw signature sig is a valid ECDSA signature of msg.

            key may be a public or a private EC key. A signature that
            raw_to_der_signature rejects (ValueError) counts as invalid.
            """
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                _verification_key(key).verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        Performs a signature using RSASSA-PSS with MGF1

        MGF1 uses the same hash as the signature, and the salt length equals
        that hash's digest size. Key handling is inherited from RSAAlgorithm.
        """

        def _pss_padding(self):
            # Shared by sign() and verify(), so both always agree on the
            # padding parameters.
            return padding.PSS(
                mgf=padding.MGF1(self.hash_alg()),
                salt_length=self.hash_alg.digest_size
            )

        def sign(self, msg, key):
            """Sign msg with an RSA private key using PSS padding."""
            return key.sign(msg, self._pss_padding(), self.hash_alg())

        def verify(self, msg, key, sig):
            """
            Return True if sig is a valid RSASSA-PSS signature of msg.

            key may be a public or a private RSA key; for a private key, its
            public half is used.
            """
            try:
                _verification_key(key).verify(sig, msg, self._pss_padding(), self.hash_alg())
                return True
            except InvalidSignature:
                return False
|