xds_client_test.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. /*
  2. *
  3. * Copyright 2020 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package xds
  19. import (
  20. "context"
  21. "crypto/tls"
  22. "crypto/x509"
  23. "errors"
  24. "fmt"
  25. "net"
  26. "os"
  27. "strings"
  28. "testing"
  29. "time"
  30. "google.golang.org/grpc/credentials"
  31. "google.golang.org/grpc/credentials/tls/certprovider"
  32. icredentials "google.golang.org/grpc/internal/credentials"
  33. xdsinternal "google.golang.org/grpc/internal/credentials/xds"
  34. "google.golang.org/grpc/internal/grpctest"
  35. "google.golang.org/grpc/internal/testutils"
  36. "google.golang.org/grpc/internal/xds/matcher"
  37. "google.golang.org/grpc/resolver"
  38. "google.golang.org/grpc/testdata"
  39. )
  const (
  	// defaultTestTimeout bounds test operations that are expected to
  	// complete, such as successful handshakes and reading results off the
  	// test server's result channel.
  	defaultTestTimeout = time.Second
  	// defaultTestShortTimeout is used where an operation is expected to block
  	// until its deadline, e.g. the client handshake in the timeout test.
  	defaultTestShortTimeout = 10 * time.Millisecond
  	// defaultTestCertSAN is installed as the exact-match SAN matcher in tests
  	// that expect SAN validation of the server certificate to pass
  	// (presumably a SAN of x509/server1_cert.pem — verify against testdata).
  	defaultTestCertSAN = "abc.test.example.com"
  	// authority is the authority argument passed to ClientHandshake.
  	authority = "authority"
  )
  46. type s struct {
  47. grpctest.Tester
  48. }
  49. func Test(t *testing.T) {
  50. grpctest.RunSubTests(t, s{})
  51. }
  52. // Helper function to create a real TLS client credentials which is used as
  53. // fallback credentials from multiple tests.
  54. func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials {
  55. creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
  56. if err != nil {
  57. t.Fatal(err)
  58. }
  59. return creds
  60. }
  61. // testServer is a no-op server which listens on a local TCP port for incoming
  62. // connections, and performs a manual TLS handshake on the received raw
  63. // connection using a user specified handshake function. It then makes the
  64. // result of the handshake operation available through a channel for tests to
  65. // inspect. Tests should stop the testServer as part of their cleanup.
  66. type testServer struct {
  67. lis net.Listener
  68. address string // Listening address of the test server.
  69. handshakeFunc testHandshakeFunc // Test specified handshake function.
  70. hsResult *testutils.Channel // Channel to deliver handshake results.
  71. }
  72. // handshakeResult wraps the result of the handshake operation on the test
  73. // server. It consists of TLS connection state and an error, if the handshake
  74. // failed. This result is delivered on the `hsResult` channel on the testServer.
  75. type handshakeResult struct {
  76. connState tls.ConnectionState
  77. err error
  78. }
  79. // Configurable handshake function for the testServer. Tests can set this to
  80. // simulate different conditions like handshake success, failure, timeout etc.
  81. type testHandshakeFunc func(net.Conn) handshakeResult
  82. // newTestServerWithHandshakeFunc starts a new testServer which listens for
  83. // connections on a local TCP port, and uses the provided custom handshake
  84. // function to perform TLS handshake.
  85. func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer {
  86. ts := &testServer{
  87. handshakeFunc: f,
  88. hsResult: testutils.NewChannel(),
  89. }
  90. ts.start()
  91. return ts
  92. }
  93. // starts actually starts listening on a local TCP port, and spawns a goroutine
  94. // to handle new connections.
  95. func (ts *testServer) start() error {
  96. lis, err := net.Listen("tcp", "localhost:0")
  97. if err != nil {
  98. return err
  99. }
  100. ts.lis = lis
  101. ts.address = lis.Addr().String()
  102. go ts.handleConn()
  103. return nil
  104. }
  105. // handleconn accepts a new raw connection, and invokes the test provided
  106. // handshake function to perform TLS handshake, and returns the result on the
  107. // `hsResult` channel.
  108. func (ts *testServer) handleConn() {
  109. for {
  110. rawConn, err := ts.lis.Accept()
  111. if err != nil {
  112. // Once the listeners closed, Accept() will return with an error.
  113. return
  114. }
  115. hsr := ts.handshakeFunc(rawConn)
  116. ts.hsResult.Send(hsr)
  117. }
  118. }
  119. // stop closes the associated listener which causes the connection handling
  120. // goroutine to exit.
  121. func (ts *testServer) stop() {
  122. ts.lis.Close()
  123. }
  124. // A handshake function which simulates a successful handshake without client
  125. // authentication (server does not request for client certificate during the
  126. // handshake here).
  127. func testServerTLSHandshake(rawConn net.Conn) handshakeResult {
  128. cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
  129. if err != nil {
  130. return handshakeResult{err: err}
  131. }
  132. cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
  133. conn := tls.Server(rawConn, cfg)
  134. if err := conn.Handshake(); err != nil {
  135. return handshakeResult{err: err}
  136. }
  137. return handshakeResult{connState: conn.ConnectionState()}
  138. }
  139. // A handshake function which simulates a successful handshake with mutual
  140. // authentication.
  141. func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
  142. cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
  143. if err != nil {
  144. return handshakeResult{err: err}
  145. }
  146. pemData, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem"))
  147. if err != nil {
  148. return handshakeResult{err: err}
  149. }
  150. roots := x509.NewCertPool()
  151. roots.AppendCertsFromPEM(pemData)
  152. cfg := &tls.Config{
  153. Certificates: []tls.Certificate{cert},
  154. ClientCAs: roots,
  155. }
  156. conn := tls.Server(rawConn, cfg)
  157. if err := conn.Handshake(); err != nil {
  158. return handshakeResult{err: err}
  159. }
  160. return handshakeResult{connState: conn.ConnectionState()}
  161. }
  162. // fakeProvider is an implementation of the certprovider.Provider interface
  163. // which returns the configured key material and error in calls to
  164. // KeyMaterial().
  165. type fakeProvider struct {
  166. km *certprovider.KeyMaterial
  167. err error
  168. }
  169. func (f *fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
  170. return f.km, f.err
  171. }
  172. func (f *fakeProvider) Close() {}
  173. // makeIdentityProvider creates a new instance of the fakeProvider returning the
  174. // identity key material specified in the provider file paths.
  175. func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider {
  176. t.Helper()
  177. cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
  178. if err != nil {
  179. t.Fatal(err)
  180. }
  181. return &fakeProvider{km: &certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}}
  182. }
  183. // makeRootProvider creates a new instance of the fakeProvider returning the
  184. // root key material specified in the provider file paths.
  185. func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
  186. pemData, err := os.ReadFile(testdata.Path(caPath))
  187. if err != nil {
  188. t.Fatal(err)
  189. }
  190. roots := x509.NewCertPool()
  191. roots.AppendCertsFromPEM(pemData)
  192. return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
  193. }
  194. // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
  195. // context value added to it.
  196. func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context {
  197. // Creating the HandshakeInfo and adding it to the attributes is very
  198. // similar to what the CDS balancer would do when it intercepts calls to
  199. // NewSubConn().
  200. info := xdsinternal.NewHandshakeInfo(root, identity)
  201. if sanExactMatch != "" {
  202. info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)})
  203. }
  204. addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info)
  205. // Moving the attributes from the resolver.Address to the context passed to
  206. // the handshaker is done in the transport layer. Since we directly call the
  207. // handshaker in these tests, we need to do the same here.
  208. return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
  209. }
  210. // compareAuthInfo compares the AuthInfo received on the client side after a
  211. // successful handshake with the authInfo available on the testServer.
  212. func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error {
  213. if ai.AuthType() != "tls" {
  214. return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls")
  215. }
  216. info, ok := ai.(credentials.TLSInfo)
  217. if !ok {
  218. return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})
  219. }
  220. gotState := info.State
  221. // Read the handshake result from the testServer which contains the TLS
  222. // connection state and compare it with the one received on the client-side.
  223. val, err := ts.hsResult.Receive(ctx)
  224. if err != nil {
  225. return fmt.Errorf("testServer failed to return handshake result: %v", err)
  226. }
  227. hsr := val.(handshakeResult)
  228. if hsr.err != nil {
  229. return fmt.Errorf("testServer handshake failure: %v", hsr.err)
  230. }
  231. // AuthInfo contains a variety of information. We only verify a subset here.
  232. // This is the same subset which is verified in TLS credentials tests.
  233. if err := compareConnState(gotState, hsr.connState); err != nil {
  234. return err
  235. }
  236. return nil
  237. }
  // compareConnState reports the first difference between got and want among
  // Version, HandshakeComplete, CipherSuite and NegotiatedProtocol, checked in
  // that order. It returns nil when all four fields match.
  func compareConnState(got, want tls.ConnectionState) error {
  	if got.Version != want.Version {
  		return fmt.Errorf("TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version)
  	}
  	if got.HandshakeComplete != want.HandshakeComplete {
  		return fmt.Errorf("TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete)
  	}
  	if got.CipherSuite != want.CipherSuite {
  		return fmt.Errorf("TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite)
  	}
  	if got.NegotiatedProtocol != want.NegotiatedProtocol {
  		return fmt.Errorf("TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol)
  	}
  	return nil
  }
  251. // TestClientCredsWithoutFallback verifies that the call to
  252. // NewClientCredentials() fails when no fallback is specified.
  253. func (s) TestClientCredsWithoutFallback(t *testing.T) {
  254. if _, err := NewClientCredentials(ClientOptions{}); err == nil {
  255. t.Fatal("NewClientCredentials() succeeded without specifying fallback")
  256. }
  257. }
  258. // TestClientCredsInvalidHandshakeInfo verifies scenarios where the passed in
  259. // HandshakeInfo is invalid because it does not contain the expected certificate
  260. // providers.
  261. func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
  262. opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
  263. creds, err := NewClientCredentials(opts)
  264. if err != nil {
  265. t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
  266. }
  267. pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  268. defer cancel()
  269. ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "")
  270. if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil {
  271. t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo")
  272. }
  273. }
  274. // TestClientCredsProviderFailure verifies the cases where an expected
  275. // certificate provider is missing in the HandshakeInfo value in the context.
  276. func (s) TestClientCredsProviderFailure(t *testing.T) {
  277. opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
  278. creds, err := NewClientCredentials(opts)
  279. if err != nil {
  280. t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
  281. }
  282. tests := []struct {
  283. desc string
  284. rootProvider certprovider.Provider
  285. identityProvider certprovider.Provider
  286. wantErr string
  287. }{
  288. {
  289. desc: "erroring root provider",
  290. rootProvider: &fakeProvider{err: errors.New("root provider error")},
  291. wantErr: "root provider error",
  292. },
  293. {
  294. desc: "erroring identity provider",
  295. rootProvider: &fakeProvider{km: &certprovider.KeyMaterial{}},
  296. identityProvider: &fakeProvider{err: errors.New("identity provider error")},
  297. wantErr: "identity provider error",
  298. },
  299. }
  300. for _, test := range tests {
  301. t.Run(test.desc, func(t *testing.T) {
  302. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  303. defer cancel()
  304. ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "")
  305. if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) {
  306. t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
  307. }
  308. })
  309. }
  310. }
  311. // TestClientCredsSuccess verifies successful client handshake cases.
  312. func (s) TestClientCredsSuccess(t *testing.T) {
  313. tests := []struct {
  314. desc string
  315. handshakeFunc testHandshakeFunc
  316. handshakeInfoCtx func(ctx context.Context) context.Context
  317. }{
  318. {
  319. desc: "fallback",
  320. handshakeFunc: testServerTLSHandshake,
  321. handshakeInfoCtx: func(ctx context.Context) context.Context {
  322. // Since we don't add a HandshakeInfo to the context, the
  323. // ClientHandshake() method will delegate to the fallback.
  324. return ctx
  325. },
  326. },
  327. {
  328. desc: "TLS",
  329. handshakeFunc: testServerTLSHandshake,
  330. handshakeInfoCtx: func(ctx context.Context) context.Context {
  331. return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
  332. },
  333. },
  334. {
  335. desc: "mTLS",
  336. handshakeFunc: testServerMutualTLSHandshake,
  337. handshakeInfoCtx: func(ctx context.Context) context.Context {
  338. return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN)
  339. },
  340. },
  341. {
  342. desc: "mTLS with no acceptedSANs specified",
  343. handshakeFunc: testServerMutualTLSHandshake,
  344. handshakeInfoCtx: func(ctx context.Context) context.Context {
  345. return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "")
  346. },
  347. },
  348. }
  349. for _, test := range tests {
  350. t.Run(test.desc, func(t *testing.T) {
  351. ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
  352. defer ts.stop()
  353. opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
  354. creds, err := NewClientCredentials(opts)
  355. if err != nil {
  356. t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
  357. }
  358. conn, err := net.Dial("tcp", ts.address)
  359. if err != nil {
  360. t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
  361. }
  362. defer conn.Close()
  363. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  364. defer cancel()
  365. _, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn)
  366. if err != nil {
  367. t.Fatalf("ClientHandshake() returned failed: %q", err)
  368. }
  369. if err := compareAuthInfo(ctx, ts, ai); err != nil {
  370. t.Fatal(err)
  371. }
  372. })
  373. }
  374. }
  375. func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
  376. clientDone := make(chan struct{})
  377. // A handshake function which simulates a handshake timeout from the
  378. // server-side by simply blocking on the client-side handshake to timeout
  379. // and not writing any handshake data.
  380. hErr := errors.New("server handshake error")
  381. ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
  382. <-clientDone
  383. return handshakeResult{err: hErr}
  384. })
  385. defer ts.stop()
  386. opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
  387. creds, err := NewClientCredentials(opts)
  388. if err != nil {
  389. t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
  390. }
  391. conn, err := net.Dial("tcp", ts.address)
  392. if err != nil {
  393. t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
  394. }
  395. defer conn.Close()
  396. sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
  397. defer sCancel()
  398. ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
  399. if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
  400. t.Fatal("ClientHandshake() succeeded when expected to timeout")
  401. }
  402. close(clientDone)
  403. // Read the handshake result from the testServer and make sure the expected
  404. // error is returned.
  405. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  406. defer cancel()
  407. val, err := ts.hsResult.Receive(ctx)
  408. if err != nil {
  409. t.Fatalf("testServer failed to return handshake result: %v", err)
  410. }
  411. hsr := val.(handshakeResult)
  412. if hsr.err != hErr {
  413. t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr)
  414. }
  415. }
  416. // TestClientCredsHandshakeFailure verifies different handshake failure cases.
  417. func (s) TestClientCredsHandshakeFailure(t *testing.T) {
  418. tests := []struct {
  419. desc string
  420. handshakeFunc testHandshakeFunc
  421. rootProvider certprovider.Provider
  422. san string
  423. wantErr string
  424. }{
  425. {
  426. desc: "cert validation failure",
  427. handshakeFunc: testServerTLSHandshake,
  428. rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"),
  429. san: defaultTestCertSAN,
  430. wantErr: "x509: certificate signed by unknown authority",
  431. },
  432. {
  433. desc: "SAN mismatch",
  434. handshakeFunc: testServerTLSHandshake,
  435. rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"),
  436. san: "bad-san",
  437. wantErr: "do not match any of the accepted SANs",
  438. },
  439. }
  440. for _, test := range tests {
  441. t.Run(test.desc, func(t *testing.T) {
  442. ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
  443. defer ts.stop()
  444. opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
  445. creds, err := NewClientCredentials(opts)
  446. if err != nil {
  447. t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
  448. }
  449. conn, err := net.Dial("tcp", ts.address)
  450. if err != nil {
  451. t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
  452. }
  453. defer conn.Close()
  454. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  455. defer cancel()
  456. ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san)
  457. if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
  458. t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr)
  459. }
  460. })
  461. }
  462. }
  463. // TestClientCredsProviderSwitch verifies the case where the first attempt of
  464. // ClientHandshake fails because of a handshake failure. Then we update the
  465. // certificate provider and the second attempt succeeds. This is an
  466. // approximation of the flow of events when the control plane specifies new
  467. // security config which results in new certificate providers being used.
  468. func (s) TestClientCredsProviderSwitch(t *testing.T) {
  469. ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
  470. defer ts.stop()
  471. opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
  472. creds, err := NewClientCredentials(opts)
  473. if err != nil {
  474. t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
  475. }
  476. conn, err := net.Dial("tcp", ts.address)
  477. if err != nil {
  478. t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
  479. }
  480. defer conn.Close()
  481. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  482. defer cancel()
  483. // Create a root provider which will fail the handshake because it does not
  484. // use the correct trust roots.
  485. root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
  486. handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil)
  487. handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)})
  488. // We need to repeat most of what newTestContextWithHandshakeInfo() does
  489. // here because we need access to the underlying HandshakeInfo so that we
  490. // can update it before the next call to ClientHandshake().
  491. addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
  492. ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
  493. if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
  494. t.Fatal("ClientHandshake() succeeded when expected to fail")
  495. }
  496. // Drain the result channel on the test server so that we can inspect the
  497. // result for the next handshake.
  498. _, err = ts.hsResult.Receive(ctx)
  499. if err != nil {
  500. t.Errorf("testServer failed to return handshake result: %v", err)
  501. }
  502. conn, err = net.Dial("tcp", ts.address)
  503. if err != nil {
  504. t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
  505. }
  506. defer conn.Close()
  507. // Create a new root provider which uses the correct trust roots. And update
  508. // the HandshakeInfo with the new provider.
  509. root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
  510. handshakeInfo.SetRootCertProvider(root2)
  511. _, ai, err := creds.ClientHandshake(ctx, authority, conn)
  512. if err != nil {
  513. t.Fatalf("ClientHandshake() returned failed: %q", err)
  514. }
  515. if err := compareAuthInfo(ctx, ts, ai); err != nil {
  516. t.Fatal(err)
  517. }
  518. }
  519. // TestClientClone verifies the Clone() method on client credentials.
  520. func (s) TestClientClone(t *testing.T) {
  521. opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
  522. orig, err := NewClientCredentials(opts)
  523. if err != nil {
  524. t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
  525. }
  526. // The credsImpl does not have any exported fields, and it does not make
  527. // sense to use any cmp options to look deep into. So, all we make sure here
  528. // is that the cloned object points to a different location in memory.
  529. if clone := orig.Clone(); clone == orig {
  530. t.Fatal("return value from Clone() doesn't point to new credentials instance")
  531. }
  532. }
  // newStringP returns a pointer to a newly allocated copy of s. The tests use
  // it to supply the exact-match argument of matcher.StringMatcherForTesting.
  func newStringP(s string) *string {
  	p := new(string)
  	*p = s
  	return p
  }