STSCredentialsProvider.cpp 6.4 KB

  1. /**
  2. * Copyright, Inc. or its affiliates. All Rights Reserved.
  3. * SPDX-License-Identifier: Apache-2.0.
  4. */
  5. #include <aws/core/auth/STSCredentialsProvider.h>
  6. #include <aws/core/config/AWSProfileConfigLoader.h>
  7. #include <aws/core/platform/Environment.h>
  8. #include <aws/core/platform/FileSystem.h>
  9. #include <aws/core/utils/logging/LogMacros.h>
  10. #include <aws/core/utils/StringUtils.h>
  11. #include <aws/core/utils/FileSystemUtils.h>
  12. #include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h>
  13. #include <aws/core/utils/StringUtils.h>
  14. #include <aws/core/utils/UUID.h>
  15. #include <cstdlib>
  16. #include <fstream>
  17. #include <string.h>
  18. #include <climits>
  19. using namespace Aws::Utils;
  20. using namespace Aws::Utils::Logging;
  21. using namespace Aws::Auth;
  22. using namespace Aws::Internal;
  23. using namespace Aws::FileSystem;
  24. using namespace Aws::Client;
  25. using Aws::Utils::Threading::ReaderLockGuard;
  26. using Aws::Utils::Threading::WriterLockGuard;
  // Tag passed to every AWS_LOGSTREAM_* call and to Aws::MakeShared/MakeUnique in this provider.
  static constexpr char STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG[] = "STSAssumeRoleWithWebIdentityCredentialsProvider";
  29. STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentialsProvider() :
  30. m_initialized(false)
  31. {
  32. // check environment variables
  33. Aws::String tmpRegion = Aws::Environment::GetEnv("AWS_DEFAULT_REGION");
  34. m_roleArn = Aws::Environment::GetEnv("AWS_ROLE_ARN");
  35. m_tokenFile = Aws::Environment::GetEnv("AWS_WEB_IDENTITY_TOKEN_FILE");
  36. m_sessionName = Aws::Environment::GetEnv("AWS_ROLE_SESSION_NAME");
  37. // check profile_config if either m_roleArn or m_tokenFile is not loaded from environment variable
  38. // region source is not enforced, but we need it to construct sts endpoint, if we can't find from environment, we should check if it's set in config file.
  39. if (m_roleArn.empty() || m_tokenFile.empty() || tmpRegion.empty())
  40. {
  41. auto profile = Aws::Config::GetCachedConfigProfile(Aws::Auth::GetConfigProfileName());
  42. if (tmpRegion.empty())
  43. {
  44. tmpRegion = profile.GetRegion();
  45. }
  46. // If either of these two were not found from environment, use whatever found for all three in config file
  47. if (m_roleArn.empty() || m_tokenFile.empty())
  48. {
  49. m_roleArn = profile.GetRoleArn();
  50. m_tokenFile = profile.GetValue("web_identity_token_file");
  51. m_sessionName = profile.GetValue("role_session_name");
  52. }
  53. }
  54. if (m_tokenFile.empty())
  55. {
  56. AWS_LOGSTREAM_WARN(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Token file must be specified to use STS AssumeRole web identity creds provider.");
  57. return; // No need to do further constructing
  58. }
  59. else
  60. {
  61. AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved token_file from profile_config or environment variable to be " << m_tokenFile);
  62. }
  63. if (m_roleArn.empty())
  64. {
  65. AWS_LOGSTREAM_WARN(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "RoleArn must be specified to use STS AssumeRole web identity creds provider.");
  66. return; // No need to do further constructing
  67. }
  68. else
  69. {
  70. AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved role_arn from profile_config or environment variable to be " << m_roleArn);
  71. }
  72. if (tmpRegion.empty())
  73. {
  74. tmpRegion = Aws::Region::US_EAST_1;
  75. }
  76. else
  77. {
  78. AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved region from profile_config or environment variable to be " << tmpRegion);
  79. }
  80. if (m_sessionName.empty())
  81. {
  82. m_sessionName = Aws::Utils::UUID::RandomUUID();
  83. }
  84. else
  85. {
  86. AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved session_name from profile_config or environment variable to be " << m_sessionName);
  87. }
  88. Aws::Client::ClientConfiguration config;
  89. config.scheme = Aws::Http::Scheme::HTTPS;
  90. config.region = tmpRegion;
  91. Aws::Vector<Aws::String> retryableErrors;
  92. retryableErrors.push_back("IDPCommunicationError");
  93. retryableErrors.push_back("InvalidIdentityToken");
  94. config.retryStrategy = Aws::MakeShared<SpecifiedRetryableErrorsRetryStrategy>(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, retryableErrors, 3/*maxRetries*/);
  95. m_client = Aws::MakeUnique<Aws::Internal::STSCredentialsClient>(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, config);
  96. m_initialized = true;
  97. AWS_LOGSTREAM_INFO(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Creating STS AssumeRole with web identity creds provider.");
  98. }
  99. AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
  100. {
  101. // A valid client means required information like role arn and token file were constructed correctly.
  102. // We can use this provider to load creds, otherwise, we can just return empty creds.
  103. if (!m_initialized)
  104. {
  105. return Aws::Auth::AWSCredentials();
  106. }
  107. RefreshIfExpired();
  108. ReaderLockGuard guard(m_reloadLock);
  109. return m_credentials;
  110. }
  111. void STSAssumeRoleWebIdentityCredentialsProvider::Reload()
  112. {
  113. AWS_LOGSTREAM_INFO(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Credentials have expired, attempting to renew from STS.");
  114. Aws::IFStream tokenFile(m_tokenFile.c_str());
  115. if(tokenFile)
  116. {
  117. Aws::String token((std::istreambuf_iterator<char>(tokenFile)), std::istreambuf_iterator<char>());
  118. m_token = token;
  119. }
  120. else
  121. {
  122. AWS_LOGSTREAM_ERROR(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Can't open token file: " << m_tokenFile);
  123. return;
  124. }
  125. STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request {m_sessionName, m_roleArn, m_token};
  126. auto result = m_client->GetAssumeRoleWithWebIdentityCredentials(request);
  127. AWS_LOGSTREAM_TRACE(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Successfully retrieved credentials with AWS_ACCESS_KEY: " << result.creds.GetAWSAccessKeyId());
  128. m_credentials = result.creds;
  129. }
  130. bool STSAssumeRoleWebIdentityCredentialsProvider::ExpiresSoon() const
  131. {
  132. return ((m_credentials.GetExpiration() - Aws::Utils::DateTime::Now()).count() < STS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD);
  133. }
  134. void STSAssumeRoleWebIdentityCredentialsProvider::RefreshIfExpired()
  135. {
  136. ReaderLockGuard guard(m_reloadLock);
  137. if (!m_credentials.IsEmpty() && !ExpiresSoon())
  138. {
  139. return;
  140. }
  141. guard.UpgradeToWriterLock();
  142. if (!m_credentials.IsExpiredOrEmpty() && !ExpiresSoon()) // double-checked lock to avoid refreshing twice
  143. {
  144. return;
  145. }
  146. Reload();
  147. }