# test_webauthn_handler.py (4.8 KB)
import json
import struct

import mock
import pytest  # type: ignore

from google.auth import exceptions
from google.oauth2 import webauthn_handler
from google.oauth2 import webauthn_types
  8. @pytest.fixture
  9. def os_get_stub():
  10. with mock.patch.object(
  11. webauthn_handler.os.environ,
  12. "get",
  13. return_value="gcloud_webauthn_plugin",
  14. name="fake os.environ.get",
  15. ) as mock_os_environ_get:
  16. yield mock_os_environ_get
  17. @pytest.fixture
  18. def subprocess_run_stub():
  19. with mock.patch.object(
  20. webauthn_handler.subprocess, "run", name="fake subprocess.run"
  21. ) as mock_subprocess_run:
  22. yield mock_subprocess_run
  23. def test_PluginHandler_is_available(os_get_stub):
  24. test_handler = webauthn_handler.PluginHandler()
  25. assert test_handler.is_available() is True
  26. os_get_stub.return_value = None
  27. assert test_handler.is_available() is False
  28. GET_ASSERTION_REQUEST = webauthn_types.GetRequest(
  29. origin="fake_origin",
  30. rpid="fake_rpid",
  31. challenge="fake_challenge",
  32. allow_credentials=[webauthn_types.PublicKeyCredentialDescriptor(id="fake_id_1")],
  33. )
  34. def test_malformated_get_assertion_response(os_get_stub, subprocess_run_stub):
  35. response_len = struct.pack("<I", 5)
  36. response = "1234567890"
  37. mock_response = mock.Mock()
  38. mock_response.stdout = response_len + response.encode()
  39. subprocess_run_stub.return_value = mock_response
  40. test_handler = webauthn_handler.PluginHandler()
  41. with pytest.raises(exceptions.MalformedError) as excinfo:
  42. test_handler.get(GET_ASSERTION_REQUEST)
  43. assert "Plugin response length" in str(excinfo.value)
  44. def test_failure_get_assertion(os_get_stub, subprocess_run_stub):
  45. failure_response = {
  46. "type": "getResponse",
  47. "error": "fake_plugin_get_assertion_failure",
  48. }
  49. response_json = json.dumps(failure_response).encode()
  50. response_len = struct.pack("<I", len(response_json))
  51. # process returns get response in json
  52. mock_response = mock.Mock()
  53. mock_response.stdout = response_len + response_json
  54. subprocess_run_stub.return_value = mock_response
  55. test_handler = webauthn_handler.PluginHandler()
  56. with pytest.raises(exceptions.ReauthFailError) as excinfo:
  57. test_handler.get(GET_ASSERTION_REQUEST)
  58. assert failure_response["error"] in str(excinfo.value)
  59. def test_success_get_assertion(os_get_stub, subprocess_run_stub):
  60. success_response = {
  61. "type": "public-key",
  62. "id": "fake-id",
  63. "authenticatorAttachment": "cross-platform",
  64. "clientExtensionResults": {"appid": True},
  65. "response": {
  66. "clientDataJSON": "fake_client_data_json_base64",
  67. "authenticatorData": "fake_authenticator_data_base64",
  68. "signature": "fake_signature_base64",
  69. "userHandle": "fake_user_handle_base64",
  70. },
  71. }
  72. valid_plugin_response = {"type": "getResponse", "responseData": success_response}
  73. valid_plugin_response_json = json.dumps(valid_plugin_response).encode()
  74. valid_plugin_response_len = struct.pack("<I", len(valid_plugin_response_json))
  75. # process returns get response in json
  76. mock_response = mock.Mock()
  77. mock_response.stdout = valid_plugin_response_len + valid_plugin_response_json
  78. subprocess_run_stub.return_value = mock_response
  79. # Call get()
  80. test_handler = webauthn_handler.PluginHandler()
  81. got_response = test_handler.get(GET_ASSERTION_REQUEST)
  82. # Validate expected plugin request
  83. os_get_stub.assert_called_once()
  84. subprocess_run_stub.assert_called_once()
  85. stdin_input = subprocess_run_stub.call_args.kwargs["input"]
  86. input_json_len_le = stdin_input[:4]
  87. input_json_len = struct.unpack("<I", input_json_len_le)[0]
  88. input_json = stdin_input[4:]
  89. assert len(input_json) == input_json_len
  90. input_dict = json.loads(input_json.decode("utf8"))
  91. assert input_dict == {
  92. "type": "get",
  93. "origin": "fake_origin",
  94. "requestData": {
  95. "rpid": "fake_rpid",
  96. "challenge": "fake_challenge",
  97. "allowCredentials": [{"type": "public-key", "id": "fake_id_1"}],
  98. },
  99. }
  100. # Validate get assertion response
  101. assert got_response.id == success_response["id"]
  102. assert (
  103. got_response.authenticator_attachment
  104. == success_response["authenticatorAttachment"]
  105. )
  106. assert (
  107. got_response.client_extension_results
  108. == success_response["clientExtensionResults"]
  109. )
  110. assert (
  111. got_response.response.client_data_json
  112. == success_response["response"]["clientDataJSON"]
  113. )
  114. assert (
  115. got_response.response.authenticator_data
  116. == success_response["response"]["authenticatorData"]
  117. )
  118. assert got_response.response.signature == success_response["response"]["signature"]
  119. assert (
  120. got_response.response.user_handle == success_response["response"]["userHandle"]
  121. )