test_grpc.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. # Copyright 2016 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import datetime
  15. import os
  16. import time
  17. import mock
  18. import pytest # type: ignore
  19. from google.auth import _helpers
  20. from google.auth import credentials
  21. from google.auth import environment_vars
  22. from google.auth import exceptions
  23. from google.auth import transport
  24. from google.oauth2 import service_account
  25. try:
  26. # pylint: disable=ungrouped-imports
  27. import grpc # type: ignore
  28. import google.auth.transport.grpc
  29. HAS_GRPC = True
  30. except ImportError: # pragma: NO COVER
  31. HAS_GRPC = False
  32. import yatest.common as yc
  33. DATA_DIR = os.path.join(os.path.dirname(yc.source_path(__file__)), "..", "data")
  34. METADATA_PATH = os.path.join(DATA_DIR, "context_aware_metadata.json")
  35. with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
  36. PRIVATE_KEY_BYTES = fh.read()
  37. with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
  38. PUBLIC_CERT_BYTES = fh.read()
  39. pytestmark = pytest.mark.skipif(not HAS_GRPC, reason="gRPC is unavailable.")
  40. class CredentialsStub(credentials.Credentials):
  41. def __init__(self, token="token"):
  42. super(CredentialsStub, self).__init__()
  43. self.token = token
  44. self.expiry = None
  45. def refresh(self, request):
  46. self.token += "1"
  47. def with_quota_project(self, quota_project_id):
  48. raise NotImplementedError()
  49. class TestAuthMetadataPlugin(object):
  50. def test_call_no_refresh(self):
  51. credentials = CredentialsStub()
  52. request = mock.create_autospec(transport.Request)
  53. plugin = google.auth.transport.grpc.AuthMetadataPlugin(credentials, request)
  54. context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
  55. context.method_name = mock.sentinel.method_name
  56. context.service_url = mock.sentinel.service_url
  57. callback = mock.create_autospec(grpc.AuthMetadataPluginCallback)
  58. plugin(context, callback)
  59. time.sleep(2)
  60. callback.assert_called_once_with(
  61. [("authorization", "Bearer {}".format(credentials.token))], None
  62. )
  63. def test_call_refresh(self):
  64. credentials = CredentialsStub()
  65. credentials.expiry = datetime.datetime.min + _helpers.REFRESH_THRESHOLD
  66. request = mock.create_autospec(transport.Request)
  67. plugin = google.auth.transport.grpc.AuthMetadataPlugin(credentials, request)
  68. context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
  69. context.method_name = mock.sentinel.method_name
  70. context.service_url = mock.sentinel.service_url
  71. callback = mock.create_autospec(grpc.AuthMetadataPluginCallback)
  72. plugin(context, callback)
  73. time.sleep(2)
  74. assert credentials.token == "token1"
  75. callback.assert_called_once_with(
  76. [("authorization", "Bearer {}".format(credentials.token))], None
  77. )
  78. def test__get_authorization_headers_with_service_account(self):
  79. credentials = mock.create_autospec(service_account.Credentials)
  80. request = mock.create_autospec(transport.Request)
  81. plugin = google.auth.transport.grpc.AuthMetadataPlugin(credentials, request)
  82. context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
  83. context.method_name = "methodName"
  84. context.service_url = "https://pubsub.googleapis.com/methodName"
  85. plugin._get_authorization_headers(context)
  86. credentials._create_self_signed_jwt.assert_called_once_with(None)
  87. def test__get_authorization_headers_with_service_account_and_default_host(self):
  88. credentials = mock.create_autospec(service_account.Credentials)
  89. request = mock.create_autospec(transport.Request)
  90. default_host = "pubsub.googleapis.com"
  91. plugin = google.auth.transport.grpc.AuthMetadataPlugin(
  92. credentials, request, default_host=default_host
  93. )
  94. context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
  95. context.method_name = "methodName"
  96. context.service_url = "https://pubsub.googleapis.com/methodName"
  97. plugin._get_authorization_headers(context)
  98. credentials._create_self_signed_jwt.assert_called_once_with(
  99. "https://{}/".format(default_host)
  100. )
  101. @mock.patch(
  102. "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
  103. )
  104. @mock.patch("grpc.composite_channel_credentials", autospec=True)
  105. @mock.patch("grpc.metadata_call_credentials", autospec=True)
  106. @mock.patch("grpc.ssl_channel_credentials", autospec=True)
  107. @mock.patch("grpc.secure_channel", autospec=True)
  108. class TestSecureAuthorizedChannel(object):
  109. @mock.patch(
  110. "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
  111. )
  112. @mock.patch(
  113. "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
  114. )
  115. def test_secure_authorized_channel_adc(
  116. self,
  117. check_dca_metadata_path,
  118. read_dca_metadata_file,
  119. secure_channel,
  120. ssl_channel_credentials,
  121. metadata_call_credentials,
  122. composite_channel_credentials,
  123. get_client_ssl_credentials,
  124. ):
  125. credentials = CredentialsStub()
  126. request = mock.create_autospec(transport.Request)
  127. target = "example.com:80"
  128. # Mock the context aware metadata and client cert/key so mTLS SSL channel
  129. # will be used.
  130. check_dca_metadata_path.return_value = METADATA_PATH
  131. read_dca_metadata_file.return_value = {
  132. "cert_provider_command": ["some command"]
  133. }
  134. get_client_ssl_credentials.return_value = (
  135. True,
  136. PUBLIC_CERT_BYTES,
  137. PRIVATE_KEY_BYTES,
  138. None,
  139. )
  140. channel = None
  141. with mock.patch.dict(
  142. os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
  143. ):
  144. channel = google.auth.transport.grpc.secure_authorized_channel(
  145. credentials, request, target, options=mock.sentinel.options
  146. )
  147. # Check the auth plugin construction.
  148. auth_plugin = metadata_call_credentials.call_args[0][0]
  149. assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
  150. assert auth_plugin._credentials == credentials
  151. assert auth_plugin._request == request
  152. # Check the ssl channel call.
  153. ssl_channel_credentials.assert_called_once_with(
  154. certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
  155. )
  156. # Check the composite credentials call.
  157. composite_channel_credentials.assert_called_once_with(
  158. ssl_channel_credentials.return_value, metadata_call_credentials.return_value
  159. )
  160. # Check the channel call.
  161. secure_channel.assert_called_once_with(
  162. target,
  163. composite_channel_credentials.return_value,
  164. options=mock.sentinel.options,
  165. )
  166. assert channel == secure_channel.return_value
  167. @mock.patch("google.auth.transport.grpc.SslCredentials", autospec=True)
  168. def test_secure_authorized_channel_adc_without_client_cert_env(
  169. self,
  170. ssl_credentials_adc_method,
  171. secure_channel,
  172. ssl_channel_credentials,
  173. metadata_call_credentials,
  174. composite_channel_credentials,
  175. get_client_ssl_credentials,
  176. ):
  177. # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE
  178. # environment variable is not set.
  179. credentials = CredentialsStub()
  180. request = mock.create_autospec(transport.Request)
  181. target = "example.com:80"
  182. channel = google.auth.transport.grpc.secure_authorized_channel(
  183. credentials, request, target, options=mock.sentinel.options
  184. )
  185. # Check the auth plugin construction.
  186. auth_plugin = metadata_call_credentials.call_args[0][0]
  187. assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
  188. assert auth_plugin._credentials == credentials
  189. assert auth_plugin._request == request
  190. # Check the ssl channel call.
  191. ssl_channel_credentials.assert_called_once()
  192. ssl_credentials_adc_method.assert_not_called()
  193. # Check the composite credentials call.
  194. composite_channel_credentials.assert_called_once_with(
  195. ssl_channel_credentials.return_value, metadata_call_credentials.return_value
  196. )
  197. # Check the channel call.
  198. secure_channel.assert_called_once_with(
  199. target,
  200. composite_channel_credentials.return_value,
  201. options=mock.sentinel.options,
  202. )
  203. assert channel == secure_channel.return_value
  204. def test_secure_authorized_channel_explicit_ssl(
  205. self,
  206. secure_channel,
  207. ssl_channel_credentials,
  208. metadata_call_credentials,
  209. composite_channel_credentials,
  210. get_client_ssl_credentials,
  211. ):
  212. credentials = mock.Mock()
  213. request = mock.Mock()
  214. target = "example.com:80"
  215. ssl_credentials = mock.Mock()
  216. google.auth.transport.grpc.secure_authorized_channel(
  217. credentials, request, target, ssl_credentials=ssl_credentials
  218. )
  219. # Since explicit SSL credentials are provided, get_client_ssl_credentials
  220. # shouldn't be called.
  221. assert not get_client_ssl_credentials.called
  222. # Check the ssl channel call.
  223. assert not ssl_channel_credentials.called
  224. # Check the composite credentials call.
  225. composite_channel_credentials.assert_called_once_with(
  226. ssl_credentials, metadata_call_credentials.return_value
  227. )
  228. def test_secure_authorized_channel_mutual_exclusive(
  229. self,
  230. secure_channel,
  231. ssl_channel_credentials,
  232. metadata_call_credentials,
  233. composite_channel_credentials,
  234. get_client_ssl_credentials,
  235. ):
  236. credentials = mock.Mock()
  237. request = mock.Mock()
  238. target = "example.com:80"
  239. ssl_credentials = mock.Mock()
  240. client_cert_callback = mock.Mock()
  241. with pytest.raises(ValueError):
  242. google.auth.transport.grpc.secure_authorized_channel(
  243. credentials,
  244. request,
  245. target,
  246. ssl_credentials=ssl_credentials,
  247. client_cert_callback=client_cert_callback,
  248. )
  249. def test_secure_authorized_channel_with_client_cert_callback_success(
  250. self,
  251. secure_channel,
  252. ssl_channel_credentials,
  253. metadata_call_credentials,
  254. composite_channel_credentials,
  255. get_client_ssl_credentials,
  256. ):
  257. credentials = mock.Mock()
  258. request = mock.Mock()
  259. target = "example.com:80"
  260. client_cert_callback = mock.Mock()
  261. client_cert_callback.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES)
  262. with mock.patch.dict(
  263. os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
  264. ):
  265. google.auth.transport.grpc.secure_authorized_channel(
  266. credentials, request, target, client_cert_callback=client_cert_callback
  267. )
  268. client_cert_callback.assert_called_once()
  269. # Check we are using the cert and key provided by client_cert_callback.
  270. ssl_channel_credentials.assert_called_once_with(
  271. certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
  272. )
  273. # Check the composite credentials call.
  274. composite_channel_credentials.assert_called_once_with(
  275. ssl_channel_credentials.return_value, metadata_call_credentials.return_value
  276. )
  277. @mock.patch(
  278. "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
  279. )
  280. @mock.patch(
  281. "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
  282. )
  283. def test_secure_authorized_channel_with_client_cert_callback_failure(
  284. self,
  285. check_dca_metadata_path,
  286. read_dca_metadata_file,
  287. secure_channel,
  288. ssl_channel_credentials,
  289. metadata_call_credentials,
  290. composite_channel_credentials,
  291. get_client_ssl_credentials,
  292. ):
  293. credentials = mock.Mock()
  294. request = mock.Mock()
  295. target = "example.com:80"
  296. client_cert_callback = mock.Mock()
  297. client_cert_callback.side_effect = Exception("callback exception")
  298. with pytest.raises(Exception) as excinfo:
  299. with mock.patch.dict(
  300. os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
  301. ):
  302. google.auth.transport.grpc.secure_authorized_channel(
  303. credentials,
  304. request,
  305. target,
  306. client_cert_callback=client_cert_callback,
  307. )
  308. assert str(excinfo.value) == "callback exception"
  309. def test_secure_authorized_channel_cert_callback_without_client_cert_env(
  310. self,
  311. secure_channel,
  312. ssl_channel_credentials,
  313. metadata_call_credentials,
  314. composite_channel_credentials,
  315. get_client_ssl_credentials,
  316. ):
  317. # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE
  318. # environment variable is not set.
  319. credentials = mock.Mock()
  320. request = mock.Mock()
  321. target = "example.com:80"
  322. client_cert_callback = mock.Mock()
  323. google.auth.transport.grpc.secure_authorized_channel(
  324. credentials, request, target, client_cert_callback=client_cert_callback
  325. )
  326. # Check client_cert_callback is not called because GOOGLE_API_USE_CLIENT_CERTIFICATE
  327. # is not set.
  328. client_cert_callback.assert_not_called()
  329. ssl_channel_credentials.assert_called_once()
  330. # Check the composite credentials call.
  331. composite_channel_credentials.assert_called_once_with(
  332. ssl_channel_credentials.return_value, metadata_call_credentials.return_value
  333. )
  334. @mock.patch("grpc.ssl_channel_credentials", autospec=True)
  335. @mock.patch(
  336. "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
  337. )
  338. @mock.patch("google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True)
  339. @mock.patch(
  340. "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
  341. )
  342. class TestSslCredentials(object):
  343. def test_no_context_aware_metadata(
  344. self,
  345. mock_check_dca_metadata_path,
  346. mock_read_dca_metadata_file,
  347. mock_get_client_ssl_credentials,
  348. mock_ssl_channel_credentials,
  349. ):
  350. # Mock that the metadata file doesn't exist.
  351. mock_check_dca_metadata_path.return_value = None
  352. with mock.patch.dict(
  353. os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
  354. ):
  355. ssl_credentials = google.auth.transport.grpc.SslCredentials()
  356. # Since no context aware metadata is found, we wouldn't call
  357. # get_client_ssl_credentials, and the SSL channel credentials created is
  358. # non mTLS.
  359. assert ssl_credentials.ssl_credentials is not None
  360. assert not ssl_credentials.is_mtls
  361. mock_get_client_ssl_credentials.assert_not_called()
  362. mock_ssl_channel_credentials.assert_called_once_with()
  363. def test_get_client_ssl_credentials_failure(
  364. self,
  365. mock_check_dca_metadata_path,
  366. mock_read_dca_metadata_file,
  367. mock_get_client_ssl_credentials,
  368. mock_ssl_channel_credentials,
  369. ):
  370. mock_check_dca_metadata_path.return_value = METADATA_PATH
  371. mock_read_dca_metadata_file.return_value = {
  372. "cert_provider_command": ["some command"]
  373. }
  374. # Mock that client cert and key are not loaded and exception is raised.
  375. mock_get_client_ssl_credentials.side_effect = exceptions.ClientCertError()
  376. with pytest.raises(exceptions.MutualTLSChannelError):
  377. with mock.patch.dict(
  378. os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
  379. ):
  380. assert google.auth.transport.grpc.SslCredentials().ssl_credentials
  381. def test_get_client_ssl_credentials_success(
  382. self,
  383. mock_check_dca_metadata_path,
  384. mock_read_dca_metadata_file,
  385. mock_get_client_ssl_credentials,
  386. mock_ssl_channel_credentials,
  387. ):
  388. mock_check_dca_metadata_path.return_value = METADATA_PATH
  389. mock_read_dca_metadata_file.return_value = {
  390. "cert_provider_command": ["some command"]
  391. }
  392. mock_get_client_ssl_credentials.return_value = (
  393. True,
  394. PUBLIC_CERT_BYTES,
  395. PRIVATE_KEY_BYTES,
  396. None,
  397. )
  398. with mock.patch.dict(
  399. os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
  400. ):
  401. ssl_credentials = google.auth.transport.grpc.SslCredentials()
  402. assert ssl_credentials.ssl_credentials is not None
  403. assert ssl_credentials.is_mtls
  404. mock_get_client_ssl_credentials.assert_called_once()
  405. mock_ssl_channel_credentials.assert_called_once_with(
  406. certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
  407. )
  408. def test_get_client_ssl_credentials_without_client_cert_env(
  409. self,
  410. mock_check_dca_metadata_path,
  411. mock_read_dca_metadata_file,
  412. mock_get_client_ssl_credentials,
  413. mock_ssl_channel_credentials,
  414. ):
  415. # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE is not set.
  416. ssl_credentials = google.auth.transport.grpc.SslCredentials()
  417. assert ssl_credentials.ssl_credentials is not None
  418. assert not ssl_credentials.is_mtls
  419. mock_check_dca_metadata_path.assert_not_called()
  420. mock_read_dca_metadata_file.assert_not_called()
  421. mock_get_client_ssl_credentials.assert_not_called()
  422. mock_ssl_channel_credentials.assert_called_once()