# test__metadata.py 19 KB


  1. # Copyright 2016 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import datetime
  15. import http.client as http_client
  16. import importlib
  17. import json
  18. import os
  19. import mock
  20. import pytest # type: ignore
  21. from google.auth import _helpers
  22. from google.auth import environment_vars
  23. from google.auth import exceptions
  24. from google.auth import transport
  25. from google.auth.compute_engine import _metadata
  26. PATH = "instance/service-accounts/default"
  27. DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
  28. SMBIOS_PRODUCT_NAME_FILE = os.path.join(DATA_DIR, "smbios_product_name")
  29. SMBIOS_PRODUCT_NAME_NONEXISTENT_FILE = os.path.join(
  30. DATA_DIR, "smbios_product_name_nonexistent"
  31. )
  32. SMBIOS_PRODUCT_NAME_NON_GOOGLE = os.path.join(
  33. DATA_DIR, "smbios_product_name_non_google"
  34. )
  35. ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE = (
  36. "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds"
  37. )
  38. MDS_PING_METRICS_HEADER_VALUE = "gl-python/3.7 auth/1.1 auth-request-type/mds"
  39. MDS_PING_REQUEST_HEADER = {
  40. "metadata-flavor": "Google",
  41. "x-goog-api-client": MDS_PING_METRICS_HEADER_VALUE,
  42. }
  43. def make_request(data, status=http_client.OK, headers=None, retry=False):
  44. response = mock.create_autospec(transport.Response, instance=True)
  45. response.status = status
  46. response.data = _helpers.to_bytes(data)
  47. response.headers = headers or {}
  48. request = mock.create_autospec(transport.Request)
  49. if retry:
  50. request.side_effect = [exceptions.TransportError(), response]
  51. else:
  52. request.return_value = response
  53. return request
  54. @pytest.mark.xfail
  55. def test_detect_gce_residency_linux_success():
  56. _metadata._GCE_PRODUCT_NAME_FILE = SMBIOS_PRODUCT_NAME_FILE
  57. assert _metadata.detect_gce_residency_linux()
  58. def test_detect_gce_residency_linux_non_google():
  59. _metadata._GCE_PRODUCT_NAME_FILE = SMBIOS_PRODUCT_NAME_NON_GOOGLE
  60. assert not _metadata.detect_gce_residency_linux()
  61. def test_detect_gce_residency_linux_nonexistent():
  62. _metadata._GCE_PRODUCT_NAME_FILE = SMBIOS_PRODUCT_NAME_NONEXISTENT_FILE
  63. assert not _metadata.detect_gce_residency_linux()
  64. def test_is_on_gce_ping_success():
  65. request = make_request("", headers=_metadata._METADATA_HEADERS)
  66. assert _metadata.is_on_gce(request)
  67. @mock.patch("os.name", new="nt")
  68. def test_is_on_gce_windows_success():
  69. request = make_request("", headers={_metadata._METADATA_FLAVOR_HEADER: "meep"})
  70. assert not _metadata.is_on_gce(request)
  71. @pytest.mark.xfail
  72. @mock.patch("os.name", new="posix")
  73. def test_is_on_gce_linux_success():
  74. request = make_request("", headers={_metadata._METADATA_FLAVOR_HEADER: "meep"})
  75. _metadata._GCE_PRODUCT_NAME_FILE = SMBIOS_PRODUCT_NAME_FILE
  76. assert _metadata.is_on_gce(request)
  77. @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
  78. def test_ping_success(mock_metrics_header_value):
  79. request = make_request("", headers=_metadata._METADATA_HEADERS)
  80. assert _metadata.ping(request)
  81. request.assert_called_once_with(
  82. method="GET",
  83. url=_metadata._METADATA_IP_ROOT,
  84. headers=MDS_PING_REQUEST_HEADER,
  85. timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
  86. )
  87. @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
  88. def test_ping_success_retry(mock_metrics_header_value):
  89. request = make_request("", headers=_metadata._METADATA_HEADERS, retry=True)
  90. assert _metadata.ping(request)
  91. request.assert_called_with(
  92. method="GET",
  93. url=_metadata._METADATA_IP_ROOT,
  94. headers=MDS_PING_REQUEST_HEADER,
  95. timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
  96. )
  97. assert request.call_count == 2
  98. @mock.patch("time.sleep", return_value=None)
  99. def test_ping_failure_bad_flavor(mock_sleep):
  100. request = make_request("", headers={_metadata._METADATA_FLAVOR_HEADER: "meep"})
  101. assert not _metadata.ping(request)
  102. @mock.patch("time.sleep", return_value=None)
  103. def test_ping_failure_connection_failed(mock_sleep):
  104. request = make_request("")
  105. request.side_effect = exceptions.TransportError()
  106. assert not _metadata.ping(request)
  107. @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
  108. def _test_ping_success_custom_root(mock_metrics_header_value):
  109. request = make_request("", headers=_metadata._METADATA_HEADERS)
  110. fake_ip = "1.2.3.4"
  111. os.environ[environment_vars.GCE_METADATA_IP] = fake_ip
  112. importlib.reload(_metadata)
  113. try:
  114. assert _metadata.ping(request)
  115. finally:
  116. del os.environ[environment_vars.GCE_METADATA_IP]
  117. importlib.reload(_metadata)
  118. request.assert_called_once_with(
  119. method="GET",
  120. url="http://" + fake_ip,
  121. headers=MDS_PING_REQUEST_HEADER,
  122. timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
  123. )
  124. def test_get_success_json():
  125. key, value = "foo", "bar"
  126. data = json.dumps({key: value})
  127. request = make_request(data, headers={"content-type": "application/json"})
  128. result = _metadata.get(request, PATH)
  129. request.assert_called_once_with(
  130. method="GET",
  131. url=_metadata._METADATA_ROOT + PATH,
  132. headers=_metadata._METADATA_HEADERS,
  133. )
  134. assert result[key] == value
  135. def test_get_success_json_content_type_charset():
  136. key, value = "foo", "bar"
  137. data = json.dumps({key: value})
  138. request = make_request(
  139. data, headers={"content-type": "application/json; charset=UTF-8"}
  140. )
  141. result = _metadata.get(request, PATH)
  142. request.assert_called_once_with(
  143. method="GET",
  144. url=_metadata._METADATA_ROOT + PATH,
  145. headers=_metadata._METADATA_HEADERS,
  146. )
  147. assert result[key] == value
  148. @mock.patch("time.sleep", return_value=None)
  149. def test_get_success_retry(mock_sleep):
  150. key, value = "foo", "bar"
  151. data = json.dumps({key: value})
  152. request = make_request(
  153. data, headers={"content-type": "application/json"}, retry=True
  154. )
  155. result = _metadata.get(request, PATH)
  156. request.assert_called_with(
  157. method="GET",
  158. url=_metadata._METADATA_ROOT + PATH,
  159. headers=_metadata._METADATA_HEADERS,
  160. )
  161. assert request.call_count == 2
  162. assert result[key] == value
  163. def test_get_success_text():
  164. data = "foobar"
  165. request = make_request(data, headers={"content-type": "text/plain"})
  166. result = _metadata.get(request, PATH)
  167. request.assert_called_once_with(
  168. method="GET",
  169. url=_metadata._METADATA_ROOT + PATH,
  170. headers=_metadata._METADATA_HEADERS,
  171. )
  172. assert result == data
  173. def test_get_success_params():
  174. data = "foobar"
  175. request = make_request(data, headers={"content-type": "text/plain"})
  176. params = {"recursive": "true"}
  177. result = _metadata.get(request, PATH, params=params)
  178. request.assert_called_once_with(
  179. method="GET",
  180. url=_metadata._METADATA_ROOT + PATH + "?recursive=true",
  181. headers=_metadata._METADATA_HEADERS,
  182. )
  183. assert result == data
  184. def test_get_success_recursive_and_params():
  185. data = "foobar"
  186. request = make_request(data, headers={"content-type": "text/plain"})
  187. params = {"recursive": "false"}
  188. result = _metadata.get(request, PATH, recursive=True, params=params)
  189. request.assert_called_once_with(
  190. method="GET",
  191. url=_metadata._METADATA_ROOT + PATH + "?recursive=true",
  192. headers=_metadata._METADATA_HEADERS,
  193. )
  194. assert result == data
  195. def test_get_success_recursive():
  196. data = "foobar"
  197. request = make_request(data, headers={"content-type": "text/plain"})
  198. result = _metadata.get(request, PATH, recursive=True)
  199. request.assert_called_once_with(
  200. method="GET",
  201. url=_metadata._METADATA_ROOT + PATH + "?recursive=true",
  202. headers=_metadata._METADATA_HEADERS,
  203. )
  204. assert result == data
  205. def _test_get_success_custom_root_new_variable():
  206. request = make_request("{}", headers={"content-type": "application/json"})
  207. fake_root = "another.metadata.service"
  208. os.environ[environment_vars.GCE_METADATA_HOST] = fake_root
  209. importlib.reload(_metadata)
  210. try:
  211. _metadata.get(request, PATH)
  212. finally:
  213. del os.environ[environment_vars.GCE_METADATA_HOST]
  214. importlib.reload(_metadata)
  215. request.assert_called_once_with(
  216. method="GET",
  217. url="http://{}/computeMetadata/v1/{}".format(fake_root, PATH),
  218. headers=_metadata._METADATA_HEADERS,
  219. )
  220. def _test_get_success_custom_root_old_variable():
  221. request = make_request("{}", headers={"content-type": "application/json"})
  222. fake_root = "another.metadata.service"
  223. os.environ[environment_vars.GCE_METADATA_ROOT] = fake_root
  224. importlib.reload(_metadata)
  225. try:
  226. _metadata.get(request, PATH)
  227. finally:
  228. del os.environ[environment_vars.GCE_METADATA_ROOT]
  229. importlib.reload(_metadata)
  230. request.assert_called_once_with(
  231. method="GET",
  232. url="http://{}/computeMetadata/v1/{}".format(fake_root, PATH),
  233. headers=_metadata._METADATA_HEADERS,
  234. )
  235. @mock.patch("time.sleep", return_value=None)
  236. def test_get_failure(mock_sleep):
  237. request = make_request("Metadata error", status=http_client.NOT_FOUND)
  238. with pytest.raises(exceptions.TransportError) as excinfo:
  239. _metadata.get(request, PATH)
  240. assert excinfo.match(r"Metadata error")
  241. request.assert_called_once_with(
  242. method="GET",
  243. url=_metadata._METADATA_ROOT + PATH,
  244. headers=_metadata._METADATA_HEADERS,
  245. )
  246. def test_get_return_none_for_not_found_error():
  247. request = make_request("Metadata error", status=http_client.NOT_FOUND)
  248. assert _metadata.get(request, PATH, return_none_for_not_found_error=True) is None
  249. request.assert_called_once_with(
  250. method="GET",
  251. url=_metadata._METADATA_ROOT + PATH,
  252. headers=_metadata._METADATA_HEADERS,
  253. )
  254. @mock.patch("time.sleep", return_value=None)
  255. def test_get_failure_connection_failed(mock_sleep):
  256. request = make_request("")
  257. request.side_effect = exceptions.TransportError()
  258. with pytest.raises(exceptions.TransportError) as excinfo:
  259. _metadata.get(request, PATH)
  260. assert excinfo.match(r"Compute Engine Metadata server unavailable")
  261. request.assert_called_with(
  262. method="GET",
  263. url=_metadata._METADATA_ROOT + PATH,
  264. headers=_metadata._METADATA_HEADERS,
  265. )
  266. assert request.call_count == 5
  267. def test_get_failure_bad_json():
  268. request = make_request("{", headers={"content-type": "application/json"})
  269. with pytest.raises(exceptions.TransportError) as excinfo:
  270. _metadata.get(request, PATH)
  271. assert excinfo.match(r"invalid JSON")
  272. request.assert_called_once_with(
  273. method="GET",
  274. url=_metadata._METADATA_ROOT + PATH,
  275. headers=_metadata._METADATA_HEADERS,
  276. )
  277. def test_get_project_id():
  278. project = "example-project"
  279. request = make_request(project, headers={"content-type": "text/plain"})
  280. project_id = _metadata.get_project_id(request)
  281. request.assert_called_once_with(
  282. method="GET",
  283. url=_metadata._METADATA_ROOT + "project/project-id",
  284. headers=_metadata._METADATA_HEADERS,
  285. )
  286. assert project_id == project
  287. def test_get_universe_domain_success():
  288. request = make_request(
  289. "fake_universe_domain", headers={"content-type": "text/plain"}
  290. )
  291. universe_domain = _metadata.get_universe_domain(request)
  292. request.assert_called_once_with(
  293. method="GET",
  294. url=_metadata._METADATA_ROOT + "universe/universe_domain",
  295. headers=_metadata._METADATA_HEADERS,
  296. )
  297. assert universe_domain == "fake_universe_domain"
  298. def test_get_universe_domain_success_empty_response():
  299. request = make_request("", headers={"content-type": "text/plain"})
  300. universe_domain = _metadata.get_universe_domain(request)
  301. request.assert_called_once_with(
  302. method="GET",
  303. url=_metadata._METADATA_ROOT + "universe/universe_domain",
  304. headers=_metadata._METADATA_HEADERS,
  305. )
  306. assert universe_domain == "googleapis.com"
  307. def test_get_universe_domain_not_found():
  308. # Test that if the universe domain endpoint returns 404 error, we should
  309. # use googleapis.com as the universe domain
  310. request = make_request("not found", status=http_client.NOT_FOUND)
  311. universe_domain = _metadata.get_universe_domain(request)
  312. request.assert_called_once_with(
  313. method="GET",
  314. url=_metadata._METADATA_ROOT + "universe/universe_domain",
  315. headers=_metadata._METADATA_HEADERS,
  316. )
  317. assert universe_domain == "googleapis.com"
  318. def test_get_universe_domain_retryable_error_failure():
  319. # Test that if the universe domain endpoint returns a retryable error
  320. # we should retry.
  321. #
  322. # In this case, the error persists, and we still fail after retrying.
  323. request = make_request("too many requests", status=http_client.TOO_MANY_REQUESTS)
  324. with pytest.raises(exceptions.TransportError) as excinfo:
  325. _metadata.get_universe_domain(request)
  326. assert excinfo.match(r"Compute Engine Metadata server unavailable")
  327. request.assert_called_with(
  328. method="GET",
  329. url=_metadata._METADATA_ROOT + "universe/universe_domain",
  330. headers=_metadata._METADATA_HEADERS,
  331. )
  332. assert request.call_count == 5
  333. def test_get_universe_domain_retryable_error_success():
  334. # Test that if the universe domain endpoint returns a retryable error
  335. # we should retry.
  336. #
  337. # In this case, the error is temporary, and we succeed after retrying.
  338. request_error = make_request(
  339. "too many requests", status=http_client.TOO_MANY_REQUESTS
  340. )
  341. request_ok = make_request(
  342. "fake_universe_domain", headers={"content-type": "text/plain"}
  343. )
  344. class _RequestErrorOnce:
  345. """This class forwards the request parameters to `request_error` once.
  346. All subsequent calls are forwarded to `request_ok`.
  347. """
  348. def __init__(self, request_error, request_ok):
  349. self._request_error = request_error
  350. self._request_ok = request_ok
  351. self._call_index = 0
  352. def request(self, *args, **kwargs):
  353. if self._call_index == 0:
  354. self._call_index += 1
  355. return self._request_error(*args, **kwargs)
  356. return self._request_ok(*args, **kwargs)
  357. request = _RequestErrorOnce(request_error, request_ok).request
  358. universe_domain = _metadata.get_universe_domain(request)
  359. request_error.assert_called_once_with(
  360. method="GET",
  361. url=_metadata._METADATA_ROOT + "universe/universe_domain",
  362. headers=_metadata._METADATA_HEADERS,
  363. )
  364. request_ok.assert_called_once_with(
  365. method="GET",
  366. url=_metadata._METADATA_ROOT + "universe/universe_domain",
  367. headers=_metadata._METADATA_HEADERS,
  368. )
  369. assert universe_domain == "fake_universe_domain"
  370. def test_get_universe_domain_other_error():
  371. # Test that if the universe domain endpoint returns an error other than 404
  372. # we should throw the error
  373. request = make_request("unauthorized", status=http_client.UNAUTHORIZED)
  374. with pytest.raises(exceptions.TransportError) as excinfo:
  375. _metadata.get_universe_domain(request)
  376. assert excinfo.match(r"unauthorized")
  377. request.assert_called_once_with(
  378. method="GET",
  379. url=_metadata._METADATA_ROOT + "universe/universe_domain",
  380. headers=_metadata._METADATA_HEADERS,
  381. )
  382. @mock.patch(
  383. "google.auth.metrics.token_request_access_token_mds",
  384. return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  385. )
  386. @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
  387. def test_get_service_account_token(utcnow, mock_metrics_header_value):
  388. ttl = 500
  389. request = make_request(
  390. json.dumps({"access_token": "token", "expires_in": ttl}),
  391. headers={"content-type": "application/json"},
  392. )
  393. token, expiry = _metadata.get_service_account_token(request)
  394. request.assert_called_once_with(
  395. method="GET",
  396. url=_metadata._METADATA_ROOT + PATH + "/token",
  397. headers={
  398. "metadata-flavor": "Google",
  399. "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  400. },
  401. )
  402. assert token == "token"
  403. assert expiry == utcnow() + datetime.timedelta(seconds=ttl)
  404. @mock.patch(
  405. "google.auth.metrics.token_request_access_token_mds",
  406. return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  407. )
  408. @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
  409. def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_value):
  410. ttl = 500
  411. request = make_request(
  412. json.dumps({"access_token": "token", "expires_in": ttl}),
  413. headers={"content-type": "application/json"},
  414. )
  415. token, expiry = _metadata.get_service_account_token(request, scopes=["foo", "bar"])
  416. request.assert_called_once_with(
  417. method="GET",
  418. url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar",
  419. headers={
  420. "metadata-flavor": "Google",
  421. "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  422. },
  423. )
  424. assert token == "token"
  425. assert expiry == utcnow() + datetime.timedelta(seconds=ttl)
  426. @mock.patch(
  427. "google.auth.metrics.token_request_access_token_mds",
  428. return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  429. )
  430. @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
  431. def test_get_service_account_token_with_scopes_string(
  432. utcnow, mock_metrics_header_value
  433. ):
  434. ttl = 500
  435. request = make_request(
  436. json.dumps({"access_token": "token", "expires_in": ttl}),
  437. headers={"content-type": "application/json"},
  438. )
  439. token, expiry = _metadata.get_service_account_token(request, scopes="foo,bar")
  440. request.assert_called_once_with(
  441. method="GET",
  442. url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar",
  443. headers={
  444. "metadata-flavor": "Google",
  445. "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  446. },
  447. )
  448. assert token == "token"
  449. assert expiry == utcnow() + datetime.timedelta(seconds=ttl)
  450. def test_get_service_account_info():
  451. key, value = "foo", "bar"
  452. request = make_request(
  453. json.dumps({key: value}), headers={"content-type": "application/json"}
  454. )
  455. info = _metadata.get_service_account_info(request)
  456. request.assert_called_once_with(
  457. method="GET",
  458. url=_metadata._METADATA_ROOT + PATH + "/?recursive=true",
  459. headers=_metadata._METADATA_HEADERS,
  460. )
  461. assert info[key] == value