sts_test.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. /*
  2. *
  3. * Copyright 2020 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package sts
  19. import (
  20. "bytes"
  21. "context"
  22. "crypto/x509"
  23. "encoding/json"
  24. "errors"
  25. "fmt"
  26. "io"
  27. "net/http"
  28. "net/http/httputil"
  29. "strings"
  30. "testing"
  31. "time"
  32. "github.com/google/go-cmp/cmp"
  33. "google.golang.org/grpc/credentials"
  34. icredentials "google.golang.org/grpc/internal/credentials"
  35. "google.golang.org/grpc/internal/grpctest"
  36. "google.golang.org/grpc/internal/testutils"
  37. )
  38. const (
  39. requestedTokenType = "urn:ietf:params:oauth:token-type:access-token"
  40. actorTokenPath = "/var/run/secrets/token.jwt"
  41. actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
  42. actorTokenContents = "actorToken.jwt.contents"
  43. accessTokenContents = "access_token"
  44. subjectTokenPath = "/var/run/secrets/token.jwt"
  45. subjectTokenType = "urn:ietf:params:oauth:token-type:id_token"
  46. subjectTokenContents = "subjectToken.jwt.contents"
  47. serviceURI = "http://localhost"
  48. exampleResource = "https://backend.example.com/api"
  49. exampleAudience = "example-backend-service"
  50. testScope = "https://www.googleapis.com/auth/monitoring"
  51. defaultTestTimeout = 1 * time.Second
  52. defaultTestShortTimeout = 10 * time.Millisecond
  53. )
  54. var (
  55. goodOptions = Options{
  56. TokenExchangeServiceURI: serviceURI,
  57. Audience: exampleAudience,
  58. RequestedTokenType: requestedTokenType,
  59. SubjectTokenPath: subjectTokenPath,
  60. SubjectTokenType: subjectTokenType,
  61. }
  62. goodRequestParams = &requestParameters{
  63. GrantType: tokenExchangeGrantType,
  64. Audience: exampleAudience,
  65. Scope: defaultCloudPlatformScope,
  66. RequestedTokenType: requestedTokenType,
  67. SubjectToken: subjectTokenContents,
  68. SubjectTokenType: subjectTokenType,
  69. }
  70. goodMetadata = map[string]string{
  71. "Authorization": fmt.Sprintf("Bearer %s", accessTokenContents),
  72. }
  73. )
  74. type s struct {
  75. grpctest.Tester
  76. }
  77. func Test(t *testing.T) {
  78. grpctest.RunSubTests(t, s{})
  79. }
  80. // A struct that implements AuthInfo interface and added to the context passed
  81. // to GetRequestMetadata from tests.
  82. type testAuthInfo struct {
  83. credentials.CommonAuthInfo
  84. }
  85. func (ta testAuthInfo) AuthType() string {
  86. return "testAuthInfo"
  87. }
  88. func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context {
  89. auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}}
  90. ri := credentials.RequestInfo{
  91. Method: "testInfo",
  92. AuthInfo: auth,
  93. }
  94. return icredentials.NewRequestInfoContext(ctx, ri)
  95. }
  96. // errReader implements the io.Reader interface and returns an error from the
  97. // Read method.
  98. type errReader struct{}
  99. func (r errReader) Read(b []byte) (n int, err error) {
  100. return 0, errors.New("read error")
  101. }
  102. // We need a function to construct the response instead of simply declaring it
  103. // as a variable since the response body will be consumed by the
  104. // credentials, and therefore we will need a new one everytime.
  105. func makeGoodResponse() *http.Response {
  106. respJSON, _ := json.Marshal(responseParameters{
  107. AccessToken: accessTokenContents,
  108. IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
  109. TokenType: "Bearer",
  110. ExpiresIn: 3600,
  111. })
  112. respBody := io.NopCloser(bytes.NewReader(respJSON))
  113. return &http.Response{
  114. Status: "200 OK",
  115. StatusCode: http.StatusOK,
  116. Body: respBody,
  117. }
  118. }
  119. // Overrides the http.Client with a fakeClient which sends a good response.
  120. func overrideHTTPClientGood() (*testutils.FakeHTTPClient, func()) {
  121. fc := &testutils.FakeHTTPClient{
  122. ReqChan: testutils.NewChannel(),
  123. RespChan: testutils.NewChannel(),
  124. }
  125. fc.RespChan.Send(makeGoodResponse())
  126. origMakeHTTPDoer := makeHTTPDoer
  127. makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
  128. return fc, func() { makeHTTPDoer = origMakeHTTPDoer }
  129. }
  130. // Overrides the http.Client with the provided fakeClient.
  131. func overrideHTTPClient(fc *testutils.FakeHTTPClient) func() {
  132. origMakeHTTPDoer := makeHTTPDoer
  133. makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
  134. return func() { makeHTTPDoer = origMakeHTTPDoer }
  135. }
  136. // Overrides the subject token read to return a const which we can compare in
  137. // our tests.
  138. func overrideSubjectTokenGood() func() {
  139. origReadSubjectTokenFrom := readSubjectTokenFrom
  140. readSubjectTokenFrom = func(path string) ([]byte, error) {
  141. return []byte(subjectTokenContents), nil
  142. }
  143. return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
  144. }
  145. // Overrides the subject token read to always return an error.
  146. func overrideSubjectTokenError() func() {
  147. origReadSubjectTokenFrom := readSubjectTokenFrom
  148. readSubjectTokenFrom = func(path string) ([]byte, error) {
  149. return nil, errors.New("error reading subject token")
  150. }
  151. return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
  152. }
  153. // Overrides the actor token read to return a const which we can compare in
  154. // our tests.
  155. func overrideActorTokenGood() func() {
  156. origReadActorTokenFrom := readActorTokenFrom
  157. readActorTokenFrom = func(path string) ([]byte, error) {
  158. return []byte(actorTokenContents), nil
  159. }
  160. return func() { readActorTokenFrom = origReadActorTokenFrom }
  161. }
  162. // Overrides the actor token read to always return an error.
  163. func overrideActorTokenError() func() {
  164. origReadActorTokenFrom := readActorTokenFrom
  165. readActorTokenFrom = func(path string) ([]byte, error) {
  166. return nil, errors.New("error reading actor token")
  167. }
  168. return func() { readActorTokenFrom = origReadActorTokenFrom }
  169. }
  170. // compareRequest compares the http.Request received in the test with the
  171. // expected requestParameters specified in wantReqParams.
  172. func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error {
  173. jsonBody, err := json.Marshal(wantReqParams)
  174. if err != nil {
  175. return err
  176. }
  177. wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody))
  178. if err != nil {
  179. return fmt.Errorf("failed to create http request: %v", err)
  180. }
  181. wantReq.Header.Set("Content-Type", "application/json")
  182. wantR, err := httputil.DumpRequestOut(wantReq, true)
  183. if err != nil {
  184. return err
  185. }
  186. gotR, err := httputil.DumpRequestOut(gotRequest, true)
  187. if err != nil {
  188. return err
  189. }
  190. if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" {
  191. return fmt.Errorf("sts request diff (-want +got):\n%s", diff)
  192. }
  193. return nil
  194. }
  195. // receiveAndCompareRequest waits for a request to be sent out by the
  196. // credentials implementation using the fakeHTTPClient and compares it to an
  197. // expected goodRequest. This is expected to be called in a separate goroutine
  198. // by the tests. So, any errors encountered are pushed to an error channel
  199. // which is monitored by the test.
  200. func receiveAndCompareRequest(ReqChan *testutils.Channel, errCh chan error) {
  201. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  202. defer cancel()
  203. val, err := ReqChan.Receive(ctx)
  204. if err != nil {
  205. errCh <- err
  206. return
  207. }
  208. req := val.(*http.Request)
  209. if err := compareRequest(req, goodRequestParams); err != nil {
  210. errCh <- err
  211. return
  212. }
  213. errCh <- nil
  214. }
  215. // TestGetRequestMetadataSuccess verifies the successful case of sending an
  216. // token exchange request and processing the response.
  217. func (s) TestGetRequestMetadataSuccess(t *testing.T) {
  218. defer overrideSubjectTokenGood()()
  219. fc, cancel := overrideHTTPClientGood()
  220. defer cancel()
  221. creds, err := NewCredentials(goodOptions)
  222. if err != nil {
  223. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  224. }
  225. errCh := make(chan error, 1)
  226. go receiveAndCompareRequest(fc.ReqChan, errCh)
  227. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  228. defer cancel()
  229. gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
  230. if err != nil {
  231. t.Fatalf("creds.GetRequestMetadata() = %v", err)
  232. }
  233. if !cmp.Equal(gotMetadata, goodMetadata) {
  234. t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
  235. }
  236. if err := <-errCh; err != nil {
  237. t.Fatal(err)
  238. }
  239. // Make another call to get request metadata and this should return contents
  240. // from the cache. This will fail if the credentials tries to send a fresh
  241. // request here since we have not configured our fakeClient to return any
  242. // response on retries.
  243. gotMetadata, err = creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
  244. if err != nil {
  245. t.Fatalf("creds.GetRequestMetadata() = %v", err)
  246. }
  247. if !cmp.Equal(gotMetadata, goodMetadata) {
  248. t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
  249. }
  250. }
  251. // TestGetRequestMetadataBadSecurityLevel verifies the case where the
  252. // securityLevel specified in the context passed to GetRequestMetadata is not
  253. // sufficient.
  254. func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) {
  255. defer overrideSubjectTokenGood()()
  256. creds, err := NewCredentials(goodOptions)
  257. if err != nil {
  258. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  259. }
  260. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  261. defer cancel()
  262. gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.IntegrityOnly), "")
  263. if err == nil {
  264. t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata)
  265. }
  266. }
  267. // TestGetRequestMetadataCacheExpiry verifies the case where the cached access
  268. // token has expired, and the credentials implementation will have to send a
  269. // fresh token exchange request.
  270. func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
  271. const expiresInSecs = 1
  272. defer overrideSubjectTokenGood()()
  273. fc := &testutils.FakeHTTPClient{
  274. ReqChan: testutils.NewChannel(),
  275. RespChan: testutils.NewChannel(),
  276. }
  277. defer overrideHTTPClient(fc)()
  278. creds, err := NewCredentials(goodOptions)
  279. if err != nil {
  280. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  281. }
  282. // The fakeClient is configured to return an access_token with a one second
  283. // expiry. So, in the second iteration, the credentials will find the cache
  284. // entry, but that would have expired, and therefore we expect it to send
  285. // out a fresh request.
  286. for i := 0; i < 2; i++ {
  287. errCh := make(chan error, 1)
  288. go receiveAndCompareRequest(fc.ReqChan, errCh)
  289. respJSON, _ := json.Marshal(responseParameters{
  290. AccessToken: accessTokenContents,
  291. IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
  292. TokenType: "Bearer",
  293. ExpiresIn: expiresInSecs,
  294. })
  295. respBody := io.NopCloser(bytes.NewReader(respJSON))
  296. resp := &http.Response{
  297. Status: "200 OK",
  298. StatusCode: http.StatusOK,
  299. Body: respBody,
  300. }
  301. fc.RespChan.Send(resp)
  302. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  303. defer cancel()
  304. gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
  305. if err != nil {
  306. t.Fatalf("creds.GetRequestMetadata() = %v", err)
  307. }
  308. if !cmp.Equal(gotMetadata, goodMetadata) {
  309. t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
  310. }
  311. if err := <-errCh; err != nil {
  312. t.Fatal(err)
  313. }
  314. time.Sleep(expiresInSecs * time.Second)
  315. }
  316. }
  317. // TestGetRequestMetadataBadResponses verifies the scenario where the token
  318. // exchange server returns bad responses.
  319. func (s) TestGetRequestMetadataBadResponses(t *testing.T) {
  320. tests := []struct {
  321. name string
  322. response *http.Response
  323. }{
  324. {
  325. name: "bad JSON",
  326. response: &http.Response{
  327. Status: "200 OK",
  328. StatusCode: http.StatusOK,
  329. Body: io.NopCloser(strings.NewReader("not JSON")),
  330. },
  331. },
  332. {
  333. name: "no access token",
  334. response: &http.Response{
  335. Status: "200 OK",
  336. StatusCode: http.StatusOK,
  337. Body: io.NopCloser(strings.NewReader("{}")),
  338. },
  339. },
  340. }
  341. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  342. defer cancel()
  343. for _, test := range tests {
  344. t.Run(test.name, func(t *testing.T) {
  345. defer overrideSubjectTokenGood()()
  346. fc := &testutils.FakeHTTPClient{
  347. ReqChan: testutils.NewChannel(),
  348. RespChan: testutils.NewChannel(),
  349. }
  350. defer overrideHTTPClient(fc)()
  351. creds, err := NewCredentials(goodOptions)
  352. if err != nil {
  353. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  354. }
  355. errCh := make(chan error, 1)
  356. go receiveAndCompareRequest(fc.ReqChan, errCh)
  357. fc.RespChan.Send(test.response)
  358. if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil {
  359. t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
  360. }
  361. if err := <-errCh; err != nil {
  362. t.Fatal(err)
  363. }
  364. })
  365. }
  366. }
  367. // TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the
  368. // attempt to read the subjectToken fails.
  369. func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
  370. defer overrideSubjectTokenError()()
  371. fc, cancel := overrideHTTPClientGood()
  372. defer cancel()
  373. creds, err := NewCredentials(goodOptions)
  374. if err != nil {
  375. t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
  376. }
  377. errCh := make(chan error, 1)
  378. go func() {
  379. ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
  380. defer cancel()
  381. if _, err := fc.ReqChan.Receive(ctx); err != context.DeadlineExceeded {
  382. errCh <- err
  383. return
  384. }
  385. errCh <- nil
  386. }()
  387. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  388. defer cancel()
  389. if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil {
  390. t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
  391. }
  392. if err := <-errCh; err != nil {
  393. t.Fatal(err)
  394. }
  395. }
  396. func (s) TestNewCredentials(t *testing.T) {
  397. tests := []struct {
  398. name string
  399. opts Options
  400. errSystemRoots bool
  401. wantErr bool
  402. }{
  403. {
  404. name: "invalid options - empty subjectTokenPath",
  405. opts: Options{
  406. TokenExchangeServiceURI: serviceURI,
  407. },
  408. wantErr: true,
  409. },
  410. {
  411. name: "invalid system root certs",
  412. opts: goodOptions,
  413. errSystemRoots: true,
  414. wantErr: true,
  415. },
  416. {
  417. name: "good case",
  418. opts: goodOptions,
  419. },
  420. }
  421. for _, test := range tests {
  422. t.Run(test.name, func(t *testing.T) {
  423. if test.errSystemRoots {
  424. oldSystemRoots := loadSystemCertPool
  425. loadSystemCertPool = func() (*x509.CertPool, error) {
  426. return nil, errors.New("failed to load system cert pool")
  427. }
  428. defer func() {
  429. loadSystemCertPool = oldSystemRoots
  430. }()
  431. }
  432. creds, err := NewCredentials(test.opts)
  433. if (err != nil) != test.wantErr {
  434. t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr)
  435. }
  436. if err == nil {
  437. if !creds.RequireTransportSecurity() {
  438. t.Errorf("creds.RequireTransportSecurity() returned false")
  439. }
  440. }
  441. })
  442. }
  443. }
  444. func (s) TestValidateOptions(t *testing.T) {
  445. tests := []struct {
  446. name string
  447. opts Options
  448. wantErrPrefix string
  449. }{
  450. {
  451. name: "empty token exchange service URI",
  452. opts: Options{},
  453. wantErrPrefix: "empty token_exchange_service_uri in options",
  454. },
  455. {
  456. name: "invalid URI",
  457. opts: Options{
  458. TokenExchangeServiceURI: "\tI'm a bad URI\n",
  459. },
  460. wantErrPrefix: "invalid control character in URL",
  461. },
  462. {
  463. name: "unsupported scheme",
  464. opts: Options{
  465. TokenExchangeServiceURI: "unix:///path/to/socket",
  466. },
  467. wantErrPrefix: "scheme is not supported",
  468. },
  469. {
  470. name: "empty subjectTokenPath",
  471. opts: Options{
  472. TokenExchangeServiceURI: serviceURI,
  473. },
  474. wantErrPrefix: "required field SubjectTokenPath is not specified",
  475. },
  476. {
  477. name: "empty subjectTokenType",
  478. opts: Options{
  479. TokenExchangeServiceURI: serviceURI,
  480. SubjectTokenPath: subjectTokenPath,
  481. },
  482. wantErrPrefix: "required field SubjectTokenType is not specified",
  483. },
  484. {
  485. name: "good options",
  486. opts: goodOptions,
  487. },
  488. }
  489. for _, test := range tests {
  490. t.Run(test.name, func(t *testing.T) {
  491. err := validateOptions(test.opts)
  492. if (err != nil) != (test.wantErrPrefix != "") {
  493. t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
  494. }
  495. if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) {
  496. t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
  497. }
  498. })
  499. }
  500. }
  501. func (s) TestConstructRequest(t *testing.T) {
  502. tests := []struct {
  503. name string
  504. opts Options
  505. subjectTokenReadErr bool
  506. actorTokenReadErr bool
  507. wantReqParams *requestParameters
  508. wantErr bool
  509. }{
  510. {
  511. name: "subject token read failure",
  512. subjectTokenReadErr: true,
  513. opts: goodOptions,
  514. wantErr: true,
  515. },
  516. {
  517. name: "actor token read failure",
  518. actorTokenReadErr: true,
  519. opts: Options{
  520. TokenExchangeServiceURI: serviceURI,
  521. Audience: exampleAudience,
  522. RequestedTokenType: requestedTokenType,
  523. SubjectTokenPath: subjectTokenPath,
  524. SubjectTokenType: subjectTokenType,
  525. ActorTokenPath: actorTokenPath,
  526. ActorTokenType: actorTokenType,
  527. },
  528. wantErr: true,
  529. },
  530. {
  531. name: "default cloud platform scope",
  532. opts: goodOptions,
  533. wantReqParams: goodRequestParams,
  534. },
  535. {
  536. name: "all good",
  537. opts: Options{
  538. TokenExchangeServiceURI: serviceURI,
  539. Resource: exampleResource,
  540. Audience: exampleAudience,
  541. Scope: testScope,
  542. RequestedTokenType: requestedTokenType,
  543. SubjectTokenPath: subjectTokenPath,
  544. SubjectTokenType: subjectTokenType,
  545. ActorTokenPath: actorTokenPath,
  546. ActorTokenType: actorTokenType,
  547. },
  548. wantReqParams: &requestParameters{
  549. GrantType: tokenExchangeGrantType,
  550. Resource: exampleResource,
  551. Audience: exampleAudience,
  552. Scope: testScope,
  553. RequestedTokenType: requestedTokenType,
  554. SubjectToken: subjectTokenContents,
  555. SubjectTokenType: subjectTokenType,
  556. ActorToken: actorTokenContents,
  557. ActorTokenType: actorTokenType,
  558. },
  559. },
  560. }
  561. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  562. defer cancel()
  563. for _, test := range tests {
  564. t.Run(test.name, func(t *testing.T) {
  565. if test.subjectTokenReadErr {
  566. defer overrideSubjectTokenError()()
  567. } else {
  568. defer overrideSubjectTokenGood()()
  569. }
  570. if test.actorTokenReadErr {
  571. defer overrideActorTokenError()()
  572. } else {
  573. defer overrideActorTokenGood()()
  574. }
  575. gotRequest, err := constructRequest(ctx, test.opts)
  576. if (err != nil) != test.wantErr {
  577. t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr)
  578. }
  579. if test.wantErr {
  580. return
  581. }
  582. if err := compareRequest(gotRequest, test.wantReqParams); err != nil {
  583. t.Fatal(err)
  584. }
  585. })
  586. }
  587. }
  588. func (s) TestSendRequest(t *testing.T) {
  589. defer overrideSubjectTokenGood()()
  590. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  591. defer cancel()
  592. req, err := constructRequest(ctx, goodOptions)
  593. if err != nil {
  594. t.Fatal(err)
  595. }
  596. tests := []struct {
  597. name string
  598. resp *http.Response
  599. respErr error
  600. wantErr bool
  601. }{
  602. {
  603. name: "client error",
  604. respErr: errors.New("http.Client.Do failed"),
  605. wantErr: true,
  606. },
  607. {
  608. name: "bad response body",
  609. resp: &http.Response{
  610. Status: "200 OK",
  611. StatusCode: http.StatusOK,
  612. Body: io.NopCloser(errReader{}),
  613. },
  614. wantErr: true,
  615. },
  616. {
  617. name: "nonOK status code",
  618. resp: &http.Response{
  619. Status: "400 BadRequest",
  620. StatusCode: http.StatusBadRequest,
  621. Body: io.NopCloser(strings.NewReader("")),
  622. },
  623. wantErr: true,
  624. },
  625. {
  626. name: "good case",
  627. resp: makeGoodResponse(),
  628. },
  629. }
  630. for _, test := range tests {
  631. t.Run(test.name, func(t *testing.T) {
  632. client := &testutils.FakeHTTPClient{
  633. ReqChan: testutils.NewChannel(),
  634. RespChan: testutils.NewChannel(),
  635. Err: test.respErr,
  636. }
  637. client.RespChan.Send(test.resp)
  638. _, err := sendRequest(client, req)
  639. if (err != nil) != test.wantErr {
  640. t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
  641. }
  642. })
  643. }
  644. }
  645. func (s) TestTokenInfoFromResponse(t *testing.T) {
  646. noAccessToken, _ := json.Marshal(responseParameters{
  647. IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
  648. TokenType: "Bearer",
  649. ExpiresIn: 3600,
  650. })
  651. goodResponse, _ := json.Marshal(responseParameters{
  652. IssuedTokenType: requestedTokenType,
  653. AccessToken: accessTokenContents,
  654. TokenType: "Bearer",
  655. ExpiresIn: 3600,
  656. })
  657. tests := []struct {
  658. name string
  659. respBody []byte
  660. wantTokenInfo *tokenInfo
  661. wantErr bool
  662. }{
  663. {
  664. name: "bad JSON",
  665. respBody: []byte("not JSON"),
  666. wantErr: true,
  667. },
  668. {
  669. name: "empty response",
  670. respBody: []byte(""),
  671. wantErr: true,
  672. },
  673. {
  674. name: "non-empty response with no access token",
  675. respBody: noAccessToken,
  676. wantErr: true,
  677. },
  678. {
  679. name: "good response",
  680. respBody: goodResponse,
  681. wantTokenInfo: &tokenInfo{
  682. tokenType: "Bearer",
  683. token: accessTokenContents,
  684. },
  685. },
  686. }
  687. for _, test := range tests {
  688. t.Run(test.name, func(t *testing.T) {
  689. gotTokenInfo, err := tokenInfoFromResponse(test.respBody)
  690. if (err != nil) != test.wantErr {
  691. t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr)
  692. }
  693. if test.wantErr {
  694. return
  695. }
  696. // Can't do a cmp.Equal on the whole struct since the expiryField
  697. // is populated based on time.Now().
  698. if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token {
  699. t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo)
  700. }
  701. })
  702. }
  703. }