test__metadata.py 19 KB


# Copyright 2016 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import datetime
  15. import http.client as http_client
  16. import importlib
  17. import json
  18. import os
  19. import mock
  20. import pytest # type: ignore
  21. from google.auth import _helpers
  22. from google.auth import environment_vars
  23. from google.auth import exceptions
  24. from google.auth import transport
  25. from google.auth.compute_engine import _metadata
  26. PATH = "instance/service-accounts/default"
  27. DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
  28. SMBIOS_PRODUCT_NAME_FILE = os.path.join(DATA_DIR, "smbios_product_name")
  29. SMBIOS_PRODUCT_NAME_NONEXISTENT_FILE = os.path.join(
  30. DATA_DIR, "smbios_product_name_nonexistent"
  31. )
  32. SMBIOS_PRODUCT_NAME_NON_GOOGLE = os.path.join(
  33. DATA_DIR, "smbios_product_name_non_google"
  34. )
  35. ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE = (
  36. "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds"
  37. )
  38. MDS_PING_METRICS_HEADER_VALUE = "gl-python/3.7 auth/1.1 auth-request-type/mds"
  39. MDS_PING_REQUEST_HEADER = {
  40. "metadata-flavor": "Google",
  41. "x-goog-api-client": MDS_PING_METRICS_HEADER_VALUE,
  42. }
  43. def make_request(data, status=http_client.OK, headers=None, retry=False):
  44. response = mock.create_autospec(transport.Response, instance=True)
  45. response.status = status
  46. response.data = _helpers.to_bytes(data)
  47. response.headers = headers or {}
  48. request = mock.create_autospec(transport.Request)
  49. if retry:
  50. request.side_effect = [exceptions.TransportError(), response]
  51. else:
  52. request.return_value = response
  53. return request
  54. @pytest.mark.xfail
  55. def test_detect_gce_residency_linux_success():
  56. _metadata._GCE_PRODUCT_NAME_FILE = SMBIOS_PRODUCT_NAME_FILE
  57. assert _metadata.detect_gce_residency_linux()
  58. def test_detect_gce_residency_linux_non_google():
  59. _metadata._GCE_PRODUCT_NAME_FILE = SMBIOS_PRODUCT_NAME_NON_GOOGLE
  60. assert not _metadata.detect_gce_residency_linux()
  61. def test_detect_gce_residency_linux_nonexistent():
  62. _metadata._GCE_PRODUCT_NAME_FILE = SMBIOS_PRODUCT_NAME_NONEXISTENT_FILE
  63. assert not _metadata.detect_gce_residency_linux()
  64. def test_is_on_gce_ping_success():
  65. request = make_request("", headers=_metadata._METADATA_HEADERS)
  66. assert _metadata.is_on_gce(request)
  67. @mock.patch("os.name", new="nt")
  68. def test_is_on_gce_windows_success():
  69. request = make_request("", headers={_metadata._METADATA_FLAVOR_HEADER: "meep"})
  70. assert not _metadata.is_on_gce(request)
  71. @pytest.mark.xfail
  72. @mock.patch("os.name", new="posix")
  73. def test_is_on_gce_linux_success():
  74. request = make_request("", headers={_metadata._METADATA_FLAVOR_HEADER: "meep"})
  75. _metadata._GCE_PRODUCT_NAME_FILE = SMBIOS_PRODUCT_NAME_FILE
  76. assert _metadata.is_on_gce(request)
  77. @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
  78. def test_ping_success(mock_metrics_header_value):
  79. request = make_request("", headers=_metadata._METADATA_HEADERS)
  80. assert _metadata.ping(request)
  81. request.assert_called_once_with(
  82. method="GET",
  83. url=_metadata._METADATA_IP_ROOT,
  84. headers=MDS_PING_REQUEST_HEADER,
  85. timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
  86. )
  87. @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
  88. def test_ping_success_retry(mock_metrics_header_value):
  89. request = make_request("", headers=_metadata._METADATA_HEADERS, retry=True)
  90. assert _metadata.ping(request)
  91. request.assert_called_with(
  92. method="GET",
  93. url=_metadata._METADATA_IP_ROOT,
  94. headers=MDS_PING_REQUEST_HEADER,
  95. timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
  96. )
  97. assert request.call_count == 2
  98. @mock.patch("time.sleep", return_value=None)
  99. def test_ping_failure_bad_flavor(mock_sleep):
  100. request = make_request("", headers={_metadata._METADATA_FLAVOR_HEADER: "meep"})
  101. assert not _metadata.ping(request)
  102. @mock.patch("time.sleep", return_value=None)
  103. def test_ping_failure_connection_failed(mock_sleep):
  104. request = make_request("")
  105. request.side_effect = exceptions.TransportError()
  106. assert not _metadata.ping(request)
  107. @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
  108. def _test_ping_success_custom_root(mock_metrics_header_value):
  109. request = make_request("", headers=_metadata._METADATA_HEADERS)
  110. fake_ip = "1.2.3.4"
  111. os.environ[environment_vars.GCE_METADATA_IP] = fake_ip
  112. importlib.reload(_metadata)
  113. try:
  114. assert _metadata.ping(request)
  115. finally:
  116. del os.environ[environment_vars.GCE_METADATA_IP]
  117. importlib.reload(_metadata)
  118. request.assert_called_once_with(
  119. method="GET",
  120. url="http://" + fake_ip,
  121. headers=MDS_PING_REQUEST_HEADER,
  122. timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
  123. )
  124. def test_get_success_json():
  125. key, value = "foo", "bar"
  126. data = json.dumps({key: value})
  127. request = make_request(data, headers={"content-type": "application/json"})
  128. result = _metadata.get(request, PATH)
  129. request.assert_called_once_with(
  130. method="GET",
  131. url=_metadata._METADATA_ROOT + PATH,
  132. headers=_metadata._METADATA_HEADERS,
  133. )
  134. assert result[key] == value
  135. def test_get_success_json_content_type_charset():
  136. key, value = "foo", "bar"
  137. data = json.dumps({key: value})
  138. request = make_request(
  139. data, headers={"content-type": "application/json; charset=UTF-8"}
  140. )
  141. result = _metadata.get(request, PATH)
  142. request.assert_called_once_with(
  143. method="GET",
  144. url=_metadata._METADATA_ROOT + PATH,
  145. headers=_metadata._METADATA_HEADERS,
  146. )
  147. assert result[key] == value
  148. @mock.patch("time.sleep", return_value=None)
  149. def test_get_success_retry(mock_sleep):
  150. key, value = "foo", "bar"
  151. data = json.dumps({key: value})
  152. request = make_request(
  153. data, headers={"content-type": "application/json"}, retry=True
  154. )
  155. result = _metadata.get(request, PATH)
  156. request.assert_called_with(
  157. method="GET",
  158. url=_metadata._METADATA_ROOT + PATH,
  159. headers=_metadata._METADATA_HEADERS,
  160. )
  161. assert request.call_count == 2
  162. assert result[key] == value
  163. def test_get_success_text():
  164. data = "foobar"
  165. request = make_request(data, headers={"content-type": "text/plain"})
  166. result = _metadata.get(request, PATH)
  167. request.assert_called_once_with(
  168. method="GET",
  169. url=_metadata._METADATA_ROOT + PATH,
  170. headers=_metadata._METADATA_HEADERS,
  171. )
  172. assert result == data
  173. def test_get_success_params():
  174. data = "foobar"
  175. request = make_request(data, headers={"content-type": "text/plain"})
  176. params = {"recursive": "true"}
  177. result = _metadata.get(request, PATH, params=params)
  178. request.assert_called_once_with(
  179. method="GET",
  180. url=_metadata._METADATA_ROOT + PATH + "?recursive=true",
  181. headers=_metadata._METADATA_HEADERS,
  182. )
  183. assert result == data
  184. def test_get_success_recursive_and_params():
  185. data = "foobar"
  186. request = make_request(data, headers={"content-type": "text/plain"})
  187. params = {"recursive": "false"}
  188. result = _metadata.get(request, PATH, recursive=True, params=params)
  189. request.assert_called_once_with(
  190. method="GET",
  191. url=_metadata._METADATA_ROOT + PATH + "?recursive=true",
  192. headers=_metadata._METADATA_HEADERS,
  193. )
  194. assert result == data
  195. def test_get_success_recursive():
  196. data = "foobar"
  197. request = make_request(data, headers={"content-type": "text/plain"})
  198. result = _metadata.get(request, PATH, recursive=True)
  199. request.assert_called_once_with(
  200. method="GET",
  201. url=_metadata._METADATA_ROOT + PATH + "?recursive=true",
  202. headers=_metadata._METADATA_HEADERS,
  203. )
  204. assert result == data
  205. def _test_get_success_custom_root_new_variable():
  206. request = make_request("{}", headers={"content-type": "application/json"})
  207. fake_root = "another.metadata.service"
  208. os.environ[environment_vars.GCE_METADATA_HOST] = fake_root
  209. importlib.reload(_metadata)
  210. try:
  211. _metadata.get(request, PATH)
  212. finally:
  213. del os.environ[environment_vars.GCE_METADATA_HOST]
  214. importlib.reload(_metadata)
  215. request.assert_called_once_with(
  216. method="GET",
  217. url="http://{}/computeMetadata/v1/{}".format(fake_root, PATH),
  218. headers=_metadata._METADATA_HEADERS,
  219. )
  220. def _test_get_success_custom_root_old_variable():
  221. request = make_request("{}", headers={"content-type": "application/json"})
  222. fake_root = "another.metadata.service"
  223. os.environ[environment_vars.GCE_METADATA_ROOT] = fake_root
  224. importlib.reload(_metadata)
  225. try:
  226. _metadata.get(request, PATH)
  227. finally:
  228. del os.environ[environment_vars.GCE_METADATA_ROOT]
  229. importlib.reload(_metadata)
  230. request.assert_called_once_with(
  231. method="GET",
  232. url="http://{}/computeMetadata/v1/{}".format(fake_root, PATH),
  233. headers=_metadata._METADATA_HEADERS,
  234. )
  235. @mock.patch("time.sleep", return_value=None)
  236. def test_get_failure(mock_sleep):
  237. request = make_request("Metadata error", status=http_client.NOT_FOUND)
  238. with pytest.raises(exceptions.TransportError) as excinfo:
  239. _metadata.get(request, PATH)
  240. assert excinfo.match(r"Metadata error")
  241. request.assert_called_once_with(
  242. method="GET",
  243. url=_metadata._METADATA_ROOT + PATH,
  244. headers=_metadata._METADATA_HEADERS,
  245. )
  246. def test_get_return_none_for_not_found_error():
  247. request = make_request("Metadata error", status=http_client.NOT_FOUND)
  248. assert _metadata.get(request, PATH, return_none_for_not_found_error=True) is None
  249. request.assert_called_once_with(
  250. method="GET",
  251. url=_metadata._METADATA_ROOT + PATH,
  252. headers=_metadata._METADATA_HEADERS,
  253. )
  254. @mock.patch("time.sleep", return_value=None)
  255. def test_get_failure_connection_failed(mock_sleep):
  256. request = make_request("")
  257. request.side_effect = exceptions.TransportError("failure message")
  258. with pytest.raises(exceptions.TransportError) as excinfo:
  259. _metadata.get(request, PATH)
  260. assert excinfo.match(
  261. r"Compute Engine Metadata server unavailable due to failure message"
  262. )
  263. request.assert_called_with(
  264. method="GET",
  265. url=_metadata._METADATA_ROOT + PATH,
  266. headers=_metadata._METADATA_HEADERS,
  267. )
  268. assert request.call_count == 5
  269. def test_get_too_many_requests_retryable_error_failure():
  270. request = make_request("too many requests", status=http_client.TOO_MANY_REQUESTS)
  271. with pytest.raises(exceptions.TransportError) as excinfo:
  272. _metadata.get(request, PATH)
  273. assert excinfo.match(
  274. r"Compute Engine Metadata server unavailable due to too many requests"
  275. )
  276. request.assert_called_with(
  277. method="GET",
  278. url=_metadata._METADATA_ROOT + PATH,
  279. headers=_metadata._METADATA_HEADERS,
  280. )
  281. assert request.call_count == 5
  282. def test_get_failure_bad_json():
  283. request = make_request("{", headers={"content-type": "application/json"})
  284. with pytest.raises(exceptions.TransportError) as excinfo:
  285. _metadata.get(request, PATH)
  286. assert excinfo.match(r"invalid JSON")
  287. request.assert_called_once_with(
  288. method="GET",
  289. url=_metadata._METADATA_ROOT + PATH,
  290. headers=_metadata._METADATA_HEADERS,
  291. )
  292. def test_get_project_id():
  293. project = "example-project"
  294. request = make_request(project, headers={"content-type": "text/plain"})
  295. project_id = _metadata.get_project_id(request)
  296. request.assert_called_once_with(
  297. method="GET",
  298. url=_metadata._METADATA_ROOT + "project/project-id",
  299. headers=_metadata._METADATA_HEADERS,
  300. )
  301. assert project_id == project
  302. def test_get_universe_domain_success():
  303. request = make_request(
  304. "fake_universe_domain", headers={"content-type": "text/plain"}
  305. )
  306. universe_domain = _metadata.get_universe_domain(request)
  307. request.assert_called_once_with(
  308. method="GET",
  309. url=_metadata._METADATA_ROOT + "universe/universe-domain",
  310. headers=_metadata._METADATA_HEADERS,
  311. )
  312. assert universe_domain == "fake_universe_domain"
  313. def test_get_universe_domain_success_empty_response():
  314. request = make_request("", headers={"content-type": "text/plain"})
  315. universe_domain = _metadata.get_universe_domain(request)
  316. request.assert_called_once_with(
  317. method="GET",
  318. url=_metadata._METADATA_ROOT + "universe/universe-domain",
  319. headers=_metadata._METADATA_HEADERS,
  320. )
  321. assert universe_domain == "googleapis.com"
  322. def test_get_universe_domain_not_found():
  323. # Test that if the universe domain endpoint returns 404 error, we should
  324. # use googleapis.com as the universe domain
  325. request = make_request("not found", status=http_client.NOT_FOUND)
  326. universe_domain = _metadata.get_universe_domain(request)
  327. request.assert_called_once_with(
  328. method="GET",
  329. url=_metadata._METADATA_ROOT + "universe/universe-domain",
  330. headers=_metadata._METADATA_HEADERS,
  331. )
  332. assert universe_domain == "googleapis.com"
  333. def test_get_universe_domain_retryable_error_failure():
  334. # Test that if the universe domain endpoint returns a retryable error
  335. # we should retry.
  336. #
  337. # In this case, the error persists, and we still fail after retrying.
  338. request = make_request("too many requests", status=http_client.TOO_MANY_REQUESTS)
  339. with pytest.raises(exceptions.TransportError) as excinfo:
  340. _metadata.get_universe_domain(request)
  341. assert excinfo.match(r"Compute Engine Metadata server unavailable")
  342. request.assert_called_with(
  343. method="GET",
  344. url=_metadata._METADATA_ROOT + "universe/universe-domain",
  345. headers=_metadata._METADATA_HEADERS,
  346. )
  347. assert request.call_count == 5
  348. def test_get_universe_domain_retryable_error_success():
  349. # Test that if the universe domain endpoint returns a retryable error
  350. # we should retry.
  351. #
  352. # In this case, the error is temporary, and we succeed after retrying.
  353. request_error = make_request(
  354. "too many requests", status=http_client.TOO_MANY_REQUESTS
  355. )
  356. request_ok = make_request(
  357. "fake_universe_domain", headers={"content-type": "text/plain"}
  358. )
  359. class _RequestErrorOnce:
  360. """This class forwards the request parameters to `request_error` once.
  361. All subsequent calls are forwarded to `request_ok`.
  362. """
  363. def __init__(self, request_error, request_ok):
  364. self._request_error = request_error
  365. self._request_ok = request_ok
  366. self._call_index = 0
  367. def request(self, *args, **kwargs):
  368. if self._call_index == 0:
  369. self._call_index += 1
  370. return self._request_error(*args, **kwargs)
  371. return self._request_ok(*args, **kwargs)
  372. request = _RequestErrorOnce(request_error, request_ok).request
  373. universe_domain = _metadata.get_universe_domain(request)
  374. request_error.assert_called_once_with(
  375. method="GET",
  376. url=_metadata._METADATA_ROOT + "universe/universe-domain",
  377. headers=_metadata._METADATA_HEADERS,
  378. )
  379. request_ok.assert_called_once_with(
  380. method="GET",
  381. url=_metadata._METADATA_ROOT + "universe/universe-domain",
  382. headers=_metadata._METADATA_HEADERS,
  383. )
  384. assert universe_domain == "fake_universe_domain"
  385. def test_get_universe_domain_other_error():
  386. # Test that if the universe domain endpoint returns an error other than 404
  387. # we should throw the error
  388. request = make_request("unauthorized", status=http_client.UNAUTHORIZED)
  389. with pytest.raises(exceptions.TransportError) as excinfo:
  390. _metadata.get_universe_domain(request)
  391. assert excinfo.match(r"unauthorized")
  392. request.assert_called_once_with(
  393. method="GET",
  394. url=_metadata._METADATA_ROOT + "universe/universe-domain",
  395. headers=_metadata._METADATA_HEADERS,
  396. )
  397. @mock.patch(
  398. "google.auth.metrics.token_request_access_token_mds",
  399. return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  400. )
  401. @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
  402. def test_get_service_account_token(utcnow, mock_metrics_header_value):
  403. ttl = 500
  404. request = make_request(
  405. json.dumps({"access_token": "token", "expires_in": ttl}),
  406. headers={"content-type": "application/json"},
  407. )
  408. token, expiry = _metadata.get_service_account_token(request)
  409. request.assert_called_once_with(
  410. method="GET",
  411. url=_metadata._METADATA_ROOT + PATH + "/token",
  412. headers={
  413. "metadata-flavor": "Google",
  414. "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  415. },
  416. )
  417. assert token == "token"
  418. assert expiry == utcnow() + datetime.timedelta(seconds=ttl)
  419. @mock.patch(
  420. "google.auth.metrics.token_request_access_token_mds",
  421. return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  422. )
  423. @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
  424. def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_value):
  425. ttl = 500
  426. request = make_request(
  427. json.dumps({"access_token": "token", "expires_in": ttl}),
  428. headers={"content-type": "application/json"},
  429. )
  430. token, expiry = _metadata.get_service_account_token(request, scopes=["foo", "bar"])
  431. request.assert_called_once_with(
  432. method="GET",
  433. url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar",
  434. headers={
  435. "metadata-flavor": "Google",
  436. "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  437. },
  438. )
  439. assert token == "token"
  440. assert expiry == utcnow() + datetime.timedelta(seconds=ttl)
  441. @mock.patch(
  442. "google.auth.metrics.token_request_access_token_mds",
  443. return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  444. )
  445. @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
  446. def test_get_service_account_token_with_scopes_string(
  447. utcnow, mock_metrics_header_value
  448. ):
  449. ttl = 500
  450. request = make_request(
  451. json.dumps({"access_token": "token", "expires_in": ttl}),
  452. headers={"content-type": "application/json"},
  453. )
  454. token, expiry = _metadata.get_service_account_token(request, scopes="foo,bar")
  455. request.assert_called_once_with(
  456. method="GET",
  457. url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar",
  458. headers={
  459. "metadata-flavor": "Google",
  460. "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
  461. },
  462. )
  463. assert token == "token"
  464. assert expiry == utcnow() + datetime.timedelta(seconds=ttl)
  465. def test_get_service_account_info():
  466. key, value = "foo", "bar"
  467. request = make_request(
  468. json.dumps({key: value}), headers={"content-type": "application/json"}
  469. )
  470. info = _metadata.get_service_account_info(request)
  471. request.assert_called_once_with(
  472. method="GET",
  473. url=_metadata._METADATA_ROOT + PATH + "/?recursive=true",
  474. headers=_metadata._METADATA_HEADERS,
  475. )
  476. assert info[key] == value