test_jws.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from datetime import timedelta
  2. from functools import partial
  3. import pytest
  4. from .test_serializer import TestSerializer
  5. from .test_timed import TestTimedSerializer
  6. from itsdangerous.exc import BadData
  7. from itsdangerous.exc import BadHeader
  8. from itsdangerous.exc import BadPayload
  9. from itsdangerous.exc import BadSignature
  10. from itsdangerous.exc import SignatureExpired
  11. from itsdangerous.jws import JSONWebSignatureSerializer
  12. from itsdangerous.jws import TimedJSONWebSignatureSerializer
  13. class TestJWSSerializer(TestSerializer):
  14. @pytest.fixture()
  15. def serializer_factory(self):
  16. return partial(JSONWebSignatureSerializer, secret_key="secret-key")
  17. test_signer_cls = None
  18. test_signer_kwargs = None
  19. test_fallback_signers = None
  20. test_iter_unsigners = None
  21. @pytest.mark.parametrize("algorithm_name", ("HS256", "HS384", "HS512", "none"))
  22. def test_algorithm(self, serializer_factory, algorithm_name):
  23. serializer = serializer_factory(algorithm_name=algorithm_name)
  24. assert serializer.loads(serializer.dumps("value")) == "value"
  25. def test_invalid_algorithm(self, serializer_factory):
  26. with pytest.raises(NotImplementedError) as exc_info:
  27. serializer_factory(algorithm_name="invalid")
  28. assert "not supported" in str(exc_info.value)
  29. def test_algorithm_mismatch(self, serializer_factory, serializer):
  30. other = serializer_factory(algorithm_name="HS256")
  31. other.algorithm = serializer.algorithm
  32. signed = other.dumps("value")
  33. with pytest.raises(BadHeader) as exc_info:
  34. serializer.loads(signed)
  35. assert "mismatch" in str(exc_info.value)
  36. @pytest.mark.parametrize(
  37. ("value", "exc_cls", "match"),
  38. (
  39. ("ab", BadPayload, '"."'),
  40. ("a.b", BadHeader, "base64 decode"),
  41. ("ew.b", BadPayload, "base64 decode"),
  42. ("ew.ab", BadData, "malformed"),
  43. ("W10.ab", BadHeader, "JSON object"),
  44. ),
  45. )
  46. def test_load_payload_exceptions(self, serializer, value, exc_cls, match):
  47. signer = serializer.make_signer()
  48. signed = signer.sign(value)
  49. with pytest.raises(exc_cls) as exc_info:
  50. serializer.loads(signed)
  51. assert match in str(exc_info.value)
  52. class TestTimedJWSSerializer(TestJWSSerializer, TestTimedSerializer):
  53. @pytest.fixture()
  54. def serializer_factory(self):
  55. return partial(
  56. TimedJSONWebSignatureSerializer, secret_key="secret-key", expires_in=10
  57. )
  58. def test_default_expires_in(self, serializer_factory):
  59. serializer = serializer_factory(expires_in=None)
  60. assert serializer.expires_in == serializer.DEFAULT_EXPIRES_IN
  61. test_max_age = None
  62. def test_exp(self, serializer, value, ts, freeze):
  63. signed = serializer.dumps(value)
  64. freeze.tick()
  65. assert serializer.loads(signed) == value
  66. freeze.tick(timedelta(seconds=10))
  67. with pytest.raises(SignatureExpired) as exc_info:
  68. serializer.loads(signed)
  69. assert exc_info.value.date_signed == ts
  70. assert exc_info.value.payload == value
  71. test_return_payload = None
  72. def test_return_header(self, serializer, value, ts):
  73. signed = serializer.dumps(value)
  74. payload, header = serializer.loads(signed, return_header=True)
  75. date_signed = serializer.get_issue_date(header)
  76. assert (payload, date_signed) == (value, ts)
  77. def test_missing_exp(self, serializer):
  78. header = serializer.make_header(None)
  79. del header["exp"]
  80. signer = serializer.make_signer()
  81. signed = signer.sign(serializer.dump_payload(header, "value"))
  82. with pytest.raises(BadSignature):
  83. serializer.loads(signed)
  84. @pytest.mark.parametrize("exp", ("invalid", -1))
  85. def test_invalid_exp(self, serializer, exp):
  86. header = serializer.make_header(None)
  87. header["exp"] = exp
  88. signer = serializer.make_signer()
  89. signed = signer.sign(serializer.dump_payload(header, "value"))
  90. with pytest.raises(BadHeader) as exc_info:
  91. serializer.loads(signed)
  92. assert "IntDate" in str(exc_info.value)
  93. def test_invalid_iat(self, serializer):
  94. header = serializer.make_header(None)
  95. header["iat"] = "invalid"
  96. assert serializer.get_issue_date(header) is None