test_mtls.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright 2020 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import mock
  15. import pytest # type: ignore
  16. from google.auth import exceptions
  17. from google.auth.transport import _mtls_helper
  18. from google.auth.transport import mtls
  19. @mock.patch("google.auth.transport._mtls_helper._check_config_path", autospec=True)
  20. def test_has_default_client_cert_source(check_config_path):
  21. def return_path_for_metadata(path):
  22. return mock.Mock() if path == _mtls_helper.CONTEXT_AWARE_METADATA_PATH else None
  23. check_config_path.side_effect = return_path_for_metadata
  24. assert mtls.has_default_client_cert_source()
  25. def return_path_for_cert_config(path):
  26. return (
  27. mock.Mock()
  28. if path == _mtls_helper.CERTIFICATE_CONFIGURATION_DEFAULT_PATH
  29. else None
  30. )
  31. check_config_path.side_effect = return_path_for_cert_config
  32. assert mtls.has_default_client_cert_source()
  33. check_config_path.side_effect = None
  34. check_config_path.return_value = None
  35. assert not mtls.has_default_client_cert_source()
  36. @mock.patch("google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True)
  37. @mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True)
  38. def test_default_client_cert_source(
  39. has_default_client_cert_source, get_client_cert_and_key
  40. ):
  41. # Test default client cert source doesn't exist.
  42. has_default_client_cert_source.return_value = False
  43. with pytest.raises(exceptions.MutualTLSChannelError):
  44. mtls.default_client_cert_source()
  45. # The following tests will assume default client cert source exists.
  46. has_default_client_cert_source.return_value = True
  47. # Test good callback.
  48. get_client_cert_and_key.return_value = (True, b"cert", b"key")
  49. callback = mtls.default_client_cert_source()
  50. assert callback() == (b"cert", b"key")
  51. # Test bad callback which throws exception.
  52. get_client_cert_and_key.side_effect = ValueError()
  53. callback = mtls.default_client_cert_source()
  54. with pytest.raises(exceptions.MutualTLSChannelError):
  55. callback()
  56. @mock.patch(
  57. "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
  58. )
  59. @mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True)
  60. def test_default_client_encrypted_cert_source(
  61. has_default_client_cert_source, get_client_ssl_credentials
  62. ):
  63. # Test default client cert source doesn't exist.
  64. has_default_client_cert_source.return_value = False
  65. with pytest.raises(exceptions.MutualTLSChannelError):
  66. mtls.default_client_encrypted_cert_source("cert_path", "key_path")
  67. # The following tests will assume default client cert source exists.
  68. has_default_client_cert_source.return_value = True
  69. # Test good callback.
  70. get_client_ssl_credentials.return_value = (True, b"cert", b"key", b"passphrase")
  71. callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
  72. with mock.patch("{}.open".format(__name__), return_value=mock.MagicMock()):
  73. assert callback() == ("cert_path", "key_path", b"passphrase")
  74. # Test bad callback which throws exception.
  75. get_client_ssl_credentials.side_effect = exceptions.ClientCertError()
  76. callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
  77. with pytest.raises(exceptions.MutualTLSChannelError):
  78. callback()