auto_signature_v4_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. package s3api
  2. import (
  3. "bytes"
  4. "crypto/md5"
  5. "crypto/sha256"
  6. "encoding/base64"
  7. "encoding/hex"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "net/http"
  13. "net/url"
  14. "sort"
  15. "strconv"
  16. "strings"
  17. "testing"
  18. "time"
  19. "unicode/utf8"
  20. )
  21. // TestIsRequestPresignedSignatureV4 - Test validates the logic for presign signature verision v4 detection.
  22. func TestIsRequestPresignedSignatureV4(t *testing.T) {
  23. testCases := []struct {
  24. inputQueryKey string
  25. inputQueryValue string
  26. expectedResult bool
  27. }{
  28. // Test case - 1.
  29. // Test case with query key ""X-Amz-Credential" set.
  30. {"", "", false},
  31. // Test case - 2.
  32. {"X-Amz-Credential", "", true},
  33. // Test case - 3.
  34. {"X-Amz-Content-Sha256", "", false},
  35. }
  36. for i, testCase := range testCases {
  37. // creating an input HTTP request.
  38. // Only the query parameters are relevant for this particular test.
  39. inputReq, err := http.NewRequest("GET", "http://example.com", nil)
  40. if err != nil {
  41. t.Fatalf("Error initializing input HTTP request: %v", err)
  42. }
  43. q := inputReq.URL.Query()
  44. q.Add(testCase.inputQueryKey, testCase.inputQueryValue)
  45. inputReq.URL.RawQuery = q.Encode()
  46. actualResult := isRequestPresignedSignatureV4(inputReq)
  47. if testCase.expectedResult != actualResult {
  48. t.Errorf("Test %d: Expected the result to `%v`, but instead got `%v`", i+1, testCase.expectedResult, actualResult)
  49. }
  50. }
  51. }
  52. // Tests is requested authenticated function, tests replies for s3 errors.
  53. func TestIsReqAuthenticated(t *testing.T) {
  54. iam := NewIdentityAccessManagement("", "")
  55. iam.identities = []*Identity{
  56. {
  57. Name: "someone",
  58. Credentials: []*Credential{
  59. {
  60. AccessKey: "access_key_1",
  61. SecretKey: "secret_key_1",
  62. },
  63. },
  64. Actions: nil,
  65. },
  66. }
  67. // List of test cases for validating http request authentication.
  68. testCases := []struct {
  69. req *http.Request
  70. s3Error ErrorCode
  71. }{
  72. // When request is unsigned, access denied is returned.
  73. {mustNewRequest("GET", "http://127.0.0.1:9000", 0, nil, t), ErrAccessDenied},
  74. // When request is properly signed, error is none.
  75. {mustNewSignedRequest("GET", "http://127.0.0.1:9000", 0, nil, t), ErrNone},
  76. }
  77. // Validates all testcases.
  78. for i, testCase := range testCases {
  79. if _, s3Error := iam.reqSignatureV4Verify(testCase.req); s3Error != testCase.s3Error {
  80. ioutil.ReadAll(testCase.req.Body)
  81. t.Fatalf("Test %d: Unexpected S3 error: want %d - got %d", i, testCase.s3Error, s3Error)
  82. }
  83. }
  84. }
  85. func TestCheckAdminRequestAuthType(t *testing.T) {
  86. iam := NewIdentityAccessManagement("", "")
  87. iam.identities = []*Identity{
  88. {
  89. Name: "someone",
  90. Credentials: []*Credential{
  91. {
  92. AccessKey: "access_key_1",
  93. SecretKey: "secret_key_1",
  94. },
  95. },
  96. Actions: nil,
  97. },
  98. }
  99. testCases := []struct {
  100. Request *http.Request
  101. ErrCode ErrorCode
  102. }{
  103. {Request: mustNewRequest("GET", "http://127.0.0.1:9000", 0, nil, t), ErrCode: ErrAccessDenied},
  104. {Request: mustNewSignedRequest("GET", "http://127.0.0.1:9000", 0, nil, t), ErrCode: ErrNone},
  105. {Request: mustNewPresignedRequest("GET", "http://127.0.0.1:9000", 0, nil, t), ErrCode: ErrNone},
  106. }
  107. for i, testCase := range testCases {
  108. if _, s3Error := iam.reqSignatureV4Verify(testCase.Request); s3Error != testCase.ErrCode {
  109. t.Errorf("Test %d: Unexpected s3error returned wanted %d, got %d", i, testCase.ErrCode, s3Error)
  110. }
  111. }
  112. }
  113. // Provides a fully populated http request instance, fails otherwise.
  114. func mustNewRequest(method string, urlStr string, contentLength int64, body io.ReadSeeker, t *testing.T) *http.Request {
  115. req, err := newTestRequest(method, urlStr, contentLength, body)
  116. if err != nil {
  117. t.Fatalf("Unable to initialize new http request %s", err)
  118. }
  119. return req
  120. }
  121. // This is similar to mustNewRequest but additionally the request
  122. // is signed with AWS Signature V4, fails if not able to do so.
  123. func mustNewSignedRequest(method string, urlStr string, contentLength int64, body io.ReadSeeker, t *testing.T) *http.Request {
  124. req := mustNewRequest(method, urlStr, contentLength, body, t)
  125. cred := &Credential{"access_key_1", "secret_key_1"}
  126. if err := signRequestV4(req, cred.AccessKey, cred.SecretKey); err != nil {
  127. t.Fatalf("Unable to inititalized new signed http request %s", err)
  128. }
  129. return req
  130. }
  131. // This is similar to mustNewRequest but additionally the request
  132. // is presigned with AWS Signature V4, fails if not able to do so.
  133. func mustNewPresignedRequest(method string, urlStr string, contentLength int64, body io.ReadSeeker, t *testing.T) *http.Request {
  134. req := mustNewRequest(method, urlStr, contentLength, body, t)
  135. cred := &Credential{"access_key_1", "secret_key_1"}
  136. if err := preSignV4(req, cred.AccessKey, cred.SecretKey, int64(10*time.Minute.Seconds())); err != nil {
  137. t.Fatalf("Unable to inititalized new signed http request %s", err)
  138. }
  139. return req
  140. }
  141. // Returns new HTTP request object.
  142. func newTestRequest(method, urlStr string, contentLength int64, body io.ReadSeeker) (*http.Request, error) {
  143. if method == "" {
  144. method = "POST"
  145. }
  146. // Save for subsequent use
  147. var hashedPayload string
  148. var md5Base64 string
  149. switch {
  150. case body == nil:
  151. hashedPayload = getSHA256Hash([]byte{})
  152. default:
  153. payloadBytes, err := ioutil.ReadAll(body)
  154. if err != nil {
  155. return nil, err
  156. }
  157. hashedPayload = getSHA256Hash(payloadBytes)
  158. md5Base64 = getMD5HashBase64(payloadBytes)
  159. }
  160. // Seek back to beginning.
  161. if body != nil {
  162. body.Seek(0, 0)
  163. } else {
  164. body = bytes.NewReader([]byte(""))
  165. }
  166. req, err := http.NewRequest(method, urlStr, body)
  167. if err != nil {
  168. return nil, err
  169. }
  170. if md5Base64 != "" {
  171. req.Header.Set("Content-Md5", md5Base64)
  172. }
  173. req.Header.Set("x-amz-content-sha256", hashedPayload)
  174. // Add Content-Length
  175. req.ContentLength = contentLength
  176. return req, nil
  177. }
  // getSHA256Hash returns the lowercase hex encoding of the SHA-256 digest of data.
  func getSHA256Hash(data []byte) string {
  	digest := sha256.Sum256(data)
  	return hex.EncodeToString(digest[:])
  }
  // getMD5HashBase64 returns the MD5 digest of data in standard (padded) base64,
  // the form newTestRequest places in the Content-Md5 header.
  func getMD5HashBase64(data []byte) string {
  	digest := md5.Sum(data)
  	return base64.StdEncoding.EncodeToString(digest[:])
  }
  // getSHA256Sum returns the raw 32-byte SHA-256 digest of data.
  func getSHA256Sum(data []byte) []byte {
  	// One-shot Sum256 avoids allocating a streaming hasher and the ignored
  	// Write error of the hash.Hash API.
  	sum := sha256.Sum256(data)
  	return sum[:]
  }
  // getMD5Sum returns the raw 16-byte MD5 digest of data.
  func getMD5Sum(data []byte) []byte {
  	digest := md5.Sum(data)
  	return digest[:]
  }
  // getMD5Hash returns the lowercase hex encoding of the MD5 digest of data.
  func getMD5Hash(data []byte) string {
  	digest := md5.Sum(data)
  	return hex.EncodeToString(digest[:])
  }
  // ignoredHeaders holds canonical header names that signRequestV4 excludes
  // from the set of signed headers.
  var ignoredHeaders = map[string]bool{
  	"Authorization":  true,
  	"Content-Length": true,
  	"Content-Type":   true,
  	"User-Agent":     true,
  }
  208. // Sign given request using Signature V4.
  209. func signRequestV4(req *http.Request, accessKey, secretKey string) error {
  210. // Get hashed payload.
  211. hashedPayload := req.Header.Get("x-amz-content-sha256")
  212. if hashedPayload == "" {
  213. return fmt.Errorf("Invalid hashed payload")
  214. }
  215. currTime := time.Now()
  216. // Set x-amz-date.
  217. req.Header.Set("x-amz-date", currTime.Format(iso8601Format))
  218. // Get header map.
  219. headerMap := make(map[string][]string)
  220. for k, vv := range req.Header {
  221. // If request header key is not in ignored headers, then add it.
  222. if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; !ok {
  223. headerMap[strings.ToLower(k)] = vv
  224. }
  225. }
  226. // Get header keys.
  227. headers := []string{"host"}
  228. for k := range headerMap {
  229. headers = append(headers, k)
  230. }
  231. sort.Strings(headers)
  232. region := "us-east-1"
  233. // Get canonical headers.
  234. var buf bytes.Buffer
  235. for _, k := range headers {
  236. buf.WriteString(k)
  237. buf.WriteByte(':')
  238. switch {
  239. case k == "host":
  240. buf.WriteString(req.URL.Host)
  241. fallthrough
  242. default:
  243. for idx, v := range headerMap[k] {
  244. if idx > 0 {
  245. buf.WriteByte(',')
  246. }
  247. buf.WriteString(v)
  248. }
  249. buf.WriteByte('\n')
  250. }
  251. }
  252. canonicalHeaders := buf.String()
  253. // Get signed headers.
  254. signedHeaders := strings.Join(headers, ";")
  255. // Get canonical query string.
  256. req.URL.RawQuery = strings.Replace(req.URL.Query().Encode(), "+", "%20", -1)
  257. // Get canonical URI.
  258. canonicalURI := EncodePath(req.URL.Path)
  259. // Get canonical request.
  260. // canonicalRequest =
  261. // <HTTPMethod>\n
  262. // <CanonicalURI>\n
  263. // <CanonicalQueryString>\n
  264. // <CanonicalHeaders>\n
  265. // <SignedHeaders>\n
  266. // <HashedPayload>
  267. //
  268. canonicalRequest := strings.Join([]string{
  269. req.Method,
  270. canonicalURI,
  271. req.URL.RawQuery,
  272. canonicalHeaders,
  273. signedHeaders,
  274. hashedPayload,
  275. }, "\n")
  276. // Get scope.
  277. scope := strings.Join([]string{
  278. currTime.Format(yyyymmdd),
  279. region,
  280. "s3",
  281. "aws4_request",
  282. }, "/")
  283. stringToSign := "AWS4-HMAC-SHA256" + "\n" + currTime.Format(iso8601Format) + "\n"
  284. stringToSign = stringToSign + scope + "\n"
  285. stringToSign = stringToSign + getSHA256Hash([]byte(canonicalRequest))
  286. date := sumHMAC([]byte("AWS4"+secretKey), []byte(currTime.Format(yyyymmdd)))
  287. regionHMAC := sumHMAC(date, []byte(region))
  288. service := sumHMAC(regionHMAC, []byte("s3"))
  289. signingKey := sumHMAC(service, []byte("aws4_request"))
  290. signature := hex.EncodeToString(sumHMAC(signingKey, []byte(stringToSign)))
  291. // final Authorization header
  292. parts := []string{
  293. "AWS4-HMAC-SHA256" + " Credential=" + accessKey + "/" + scope,
  294. "SignedHeaders=" + signedHeaders,
  295. "Signature=" + signature,
  296. }
  297. auth := strings.Join(parts, ", ")
  298. req.Header.Set("Authorization", auth)
  299. return nil
  300. }
  301. // preSignV4 presign the request, in accordance with
  302. // http://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html.
  303. func preSignV4(req *http.Request, accessKeyID, secretAccessKey string, expires int64) error {
  304. // Presign is not needed for anonymous credentials.
  305. if accessKeyID == "" || secretAccessKey == "" {
  306. return errors.New("Presign cannot be generated without access and secret keys")
  307. }
  308. region := "us-east-1"
  309. date := time.Now().UTC()
  310. scope := getScope(date, region)
  311. credential := fmt.Sprintf("%s/%s", accessKeyID, scope)
  312. // Set URL query.
  313. query := req.URL.Query()
  314. query.Set("X-Amz-Algorithm", signV4Algorithm)
  315. query.Set("X-Amz-Date", date.Format(iso8601Format))
  316. query.Set("X-Amz-Expires", strconv.FormatInt(expires, 10))
  317. query.Set("X-Amz-SignedHeaders", "host")
  318. query.Set("X-Amz-Credential", credential)
  319. query.Set("X-Amz-Content-Sha256", unsignedPayload)
  320. // "host" is the only header required to be signed for Presigned URLs.
  321. extractedSignedHeaders := make(http.Header)
  322. extractedSignedHeaders.Set("host", req.Host)
  323. queryStr := strings.Replace(query.Encode(), "+", "%20", -1)
  324. canonicalRequest := getCanonicalRequest(extractedSignedHeaders, unsignedPayload, queryStr, req.URL.Path, req.Method)
  325. stringToSign := getStringToSign(canonicalRequest, date, scope)
  326. signingKey := getSigningKey(secretAccessKey, date, region)
  327. signature := getSignature(signingKey, stringToSign)
  328. req.URL.RawQuery = query.Encode()
  329. // Add signature header to RawQuery.
  330. req.URL.RawQuery += "&X-Amz-Signature=" + url.QueryEscape(signature)
  331. // Construct the final presigned URL.
  332. return nil
  333. }
  334. // EncodePath encode the strings from UTF-8 byte representations to HTML hex escape sequences
  335. //
  336. // This is necessary since regular url.Parse() and url.Encode() functions do not support UTF-8
  337. // non english characters cannot be parsed due to the nature in which url.Encode() is written
  338. //
  339. // This function on the other hand is a direct replacement for url.Encode() technique to support
  340. // pretty much every UTF-8 character.
  341. func EncodePath(pathName string) string {
  342. if reservedObjectNames.MatchString(pathName) {
  343. return pathName
  344. }
  345. var encodedPathname string
  346. for _, s := range pathName {
  347. if 'A' <= s && s <= 'Z' || 'a' <= s && s <= 'z' || '0' <= s && s <= '9' { // §2.3 Unreserved characters (mark)
  348. encodedPathname = encodedPathname + string(s)
  349. continue
  350. }
  351. switch s {
  352. case '-', '_', '.', '~', '/': // §2.3 Unreserved characters (mark)
  353. encodedPathname = encodedPathname + string(s)
  354. continue
  355. default:
  356. len := utf8.RuneLen(s)
  357. if len < 0 {
  358. // if utf8 cannot convert return the same string as is
  359. return pathName
  360. }
  361. u := make([]byte, len)
  362. utf8.EncodeRune(u, s)
  363. for _, r := range u {
  364. hex := hex.EncodeToString([]byte{r})
  365. encodedPathname = encodedPathname + "%" + strings.ToUpper(hex)
  366. }
  367. }
  368. }
  369. return encodedPathname
  370. }