webauthn_handler.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import abc
  2. import os
  3. import struct
  4. import subprocess
  5. from google.auth import exceptions
  6. from google.oauth2.webauthn_types import GetRequest, GetResponse
  7. class WebAuthnHandler(abc.ABC):
  8. @abc.abstractmethod
  9. def is_available(self) -> bool:
  10. """Check whether this WebAuthn handler is available"""
  11. raise NotImplementedError("is_available method must be implemented")
  12. @abc.abstractmethod
  13. def get(self, get_request: GetRequest) -> GetResponse:
  14. """WebAuthn get (assertion)"""
  15. raise NotImplementedError("get method must be implemented")
  16. class PluginHandler(WebAuthnHandler):
  17. """Offloads WebAuthn get reqeust to a pluggable command-line tool.
  18. Offloads WebAuthn get to a plugin which takes the form of a
  19. command-line tool. The command-line tool is configurable via the
  20. PluginHandler._ENV_VAR environment variable.
  21. The WebAuthn plugin should implement the following interface:
  22. Communication occurs over stdin/stdout, and messages are both sent and
  23. received in the form:
  24. [4 bytes - payload size (little-endian)][variable bytes - json payload]
  25. """
  26. _ENV_VAR = "GOOGLE_AUTH_WEBAUTHN_PLUGIN"
  27. def is_available(self) -> bool:
  28. try:
  29. self._find_plugin()
  30. except Exception:
  31. return False
  32. else:
  33. return True
  34. def get(self, get_request: GetRequest) -> GetResponse:
  35. request_json = get_request.to_json()
  36. cmd = self._find_plugin()
  37. response_json = self._call_plugin(cmd, request_json)
  38. return GetResponse.from_json(response_json)
  39. def _call_plugin(self, cmd: str, input_json: str) -> str:
  40. # Calculate length of input
  41. input_length = len(input_json)
  42. length_bytes_le = struct.pack("<I", input_length)
  43. request = length_bytes_le + input_json.encode()
  44. # Call plugin
  45. process_result = subprocess.run(
  46. [cmd], input=request, capture_output=True, check=True
  47. )
  48. # Check length of response
  49. response_len_le = process_result.stdout[:4]
  50. response_len = struct.unpack("<I", response_len_le)[0]
  51. response = process_result.stdout[4:]
  52. if response_len != len(response):
  53. raise exceptions.MalformedError(
  54. "Plugin response length {} does not match data {}".format(
  55. response_len, len(response)
  56. )
  57. )
  58. return response.decode()
  59. def _find_plugin(self) -> str:
  60. plugin_cmd = os.environ.get(PluginHandler._ENV_VAR)
  61. if plugin_cmd is None:
  62. raise exceptions.InvalidResource(
  63. "{} env var is not set".format(PluginHandler._ENV_VAR)
  64. )
  65. return plugin_cmd