123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591 |
- /*
- *
- * Copyright 2020 gRPC authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
- package xds
- import (
- "context"
- "crypto/tls"
- "crypto/x509"
- "errors"
- "fmt"
- "net"
- "os"
- "strings"
- "testing"
- "time"
- "google.golang.org/grpc/credentials"
- "google.golang.org/grpc/credentials/tls/certprovider"
- icredentials "google.golang.org/grpc/internal/credentials"
- xdsinternal "google.golang.org/grpc/internal/credentials/xds"
- "google.golang.org/grpc/internal/grpctest"
- "google.golang.org/grpc/internal/testutils"
- "google.golang.org/grpc/internal/xds/matcher"
- "google.golang.org/grpc/resolver"
- "google.golang.org/grpc/testdata"
- )
- const (
- defaultTestTimeout = 1 * time.Second
- defaultTestShortTimeout = 10 * time.Millisecond
- defaultTestCertSAN = "abc.test.example.com"
- authority = "authority"
- )
- type s struct {
- grpctest.Tester
- }
- func Test(t *testing.T) {
- grpctest.RunSubTests(t, s{})
- }
- // Helper function to create a real TLS client credentials which is used as
- // fallback credentials from multiple tests.
- func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials {
- creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
- if err != nil {
- t.Fatal(err)
- }
- return creds
- }
- // testServer is a no-op server which listens on a local TCP port for incoming
- // connections, and performs a manual TLS handshake on the received raw
- // connection using a user specified handshake function. It then makes the
- // result of the handshake operation available through a channel for tests to
- // inspect. Tests should stop the testServer as part of their cleanup.
- type testServer struct {
- lis net.Listener
- address string // Listening address of the test server.
- handshakeFunc testHandshakeFunc // Test specified handshake function.
- hsResult *testutils.Channel // Channel to deliver handshake results.
- }
- // handshakeResult wraps the result of the handshake operation on the test
- // server. It consists of TLS connection state and an error, if the handshake
- // failed. This result is delivered on the `hsResult` channel on the testServer.
- type handshakeResult struct {
- connState tls.ConnectionState
- err error
- }
- // Configurable handshake function for the testServer. Tests can set this to
- // simulate different conditions like handshake success, failure, timeout etc.
- type testHandshakeFunc func(net.Conn) handshakeResult
- // newTestServerWithHandshakeFunc starts a new testServer which listens for
- // connections on a local TCP port, and uses the provided custom handshake
- // function to perform TLS handshake.
- func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer {
- ts := &testServer{
- handshakeFunc: f,
- hsResult: testutils.NewChannel(),
- }
- ts.start()
- return ts
- }
- // starts actually starts listening on a local TCP port, and spawns a goroutine
- // to handle new connections.
- func (ts *testServer) start() error {
- lis, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- return err
- }
- ts.lis = lis
- ts.address = lis.Addr().String()
- go ts.handleConn()
- return nil
- }
- // handleconn accepts a new raw connection, and invokes the test provided
- // handshake function to perform TLS handshake, and returns the result on the
- // `hsResult` channel.
- func (ts *testServer) handleConn() {
- for {
- rawConn, err := ts.lis.Accept()
- if err != nil {
- // Once the listeners closed, Accept() will return with an error.
- return
- }
- hsr := ts.handshakeFunc(rawConn)
- ts.hsResult.Send(hsr)
- }
- }
- // stop closes the associated listener which causes the connection handling
- // goroutine to exit.
- func (ts *testServer) stop() {
- ts.lis.Close()
- }
- // A handshake function which simulates a successful handshake without client
- // authentication (server does not request for client certificate during the
- // handshake here).
- func testServerTLSHandshake(rawConn net.Conn) handshakeResult {
- cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
- if err != nil {
- return handshakeResult{err: err}
- }
- cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
- conn := tls.Server(rawConn, cfg)
- if err := conn.Handshake(); err != nil {
- return handshakeResult{err: err}
- }
- return handshakeResult{connState: conn.ConnectionState()}
- }
- // A handshake function which simulates a successful handshake with mutual
- // authentication.
- func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
- cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
- if err != nil {
- return handshakeResult{err: err}
- }
- pemData, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem"))
- if err != nil {
- return handshakeResult{err: err}
- }
- roots := x509.NewCertPool()
- roots.AppendCertsFromPEM(pemData)
- cfg := &tls.Config{
- Certificates: []tls.Certificate{cert},
- ClientCAs: roots,
- }
- conn := tls.Server(rawConn, cfg)
- if err := conn.Handshake(); err != nil {
- return handshakeResult{err: err}
- }
- return handshakeResult{connState: conn.ConnectionState()}
- }
- // fakeProvider is an implementation of the certprovider.Provider interface
- // which returns the configured key material and error in calls to
- // KeyMaterial().
- type fakeProvider struct {
- km *certprovider.KeyMaterial
- err error
- }
- func (f *fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
- return f.km, f.err
- }
- func (f *fakeProvider) Close() {}
- // makeIdentityProvider creates a new instance of the fakeProvider returning the
- // identity key material specified in the provider file paths.
- func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider {
- t.Helper()
- cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
- if err != nil {
- t.Fatal(err)
- }
- return &fakeProvider{km: &certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}}
- }
- // makeRootProvider creates a new instance of the fakeProvider returning the
- // root key material specified in the provider file paths.
- func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
- pemData, err := os.ReadFile(testdata.Path(caPath))
- if err != nil {
- t.Fatal(err)
- }
- roots := x509.NewCertPool()
- roots.AppendCertsFromPEM(pemData)
- return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
- }
- // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
- // context value added to it.
- func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context {
- // Creating the HandshakeInfo and adding it to the attributes is very
- // similar to what the CDS balancer would do when it intercepts calls to
- // NewSubConn().
- info := xdsinternal.NewHandshakeInfo(root, identity)
- if sanExactMatch != "" {
- info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)})
- }
- addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info)
- // Moving the attributes from the resolver.Address to the context passed to
- // the handshaker is done in the transport layer. Since we directly call the
- // handshaker in these tests, we need to do the same here.
- return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
- }
- // compareAuthInfo compares the AuthInfo received on the client side after a
- // successful handshake with the authInfo available on the testServer.
- func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error {
- if ai.AuthType() != "tls" {
- return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls")
- }
- info, ok := ai.(credentials.TLSInfo)
- if !ok {
- return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})
- }
- gotState := info.State
- // Read the handshake result from the testServer which contains the TLS
- // connection state and compare it with the one received on the client-side.
- val, err := ts.hsResult.Receive(ctx)
- if err != nil {
- return fmt.Errorf("testServer failed to return handshake result: %v", err)
- }
- hsr := val.(handshakeResult)
- if hsr.err != nil {
- return fmt.Errorf("testServer handshake failure: %v", hsr.err)
- }
- // AuthInfo contains a variety of information. We only verify a subset here.
- // This is the same subset which is verified in TLS credentials tests.
- if err := compareConnState(gotState, hsr.connState); err != nil {
- return err
- }
- return nil
- }
- func compareConnState(got, want tls.ConnectionState) error {
- switch {
- case got.Version != want.Version:
- return fmt.Errorf("TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version)
- case got.HandshakeComplete != want.HandshakeComplete:
- return fmt.Errorf("TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete)
- case got.CipherSuite != want.CipherSuite:
- return fmt.Errorf("TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite)
- case got.NegotiatedProtocol != want.NegotiatedProtocol:
- return fmt.Errorf("TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol)
- }
- return nil
- }
- // TestClientCredsWithoutFallback verifies that the call to
- // NewClientCredentials() fails when no fallback is specified.
- func (s) TestClientCredsWithoutFallback(t *testing.T) {
- if _, err := NewClientCredentials(ClientOptions{}); err == nil {
- t.Fatal("NewClientCredentials() succeeded without specifying fallback")
- }
- }
- // TestClientCredsInvalidHandshakeInfo verifies scenarios where the passed in
- // HandshakeInfo is invalid because it does not contain the expected certificate
- // providers.
- func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
- opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
- creds, err := NewClientCredentials(opts)
- if err != nil {
- t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
- }
- pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
- ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "")
- if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil {
- t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo")
- }
- }
- // TestClientCredsProviderFailure verifies the cases where an expected
- // certificate provider is missing in the HandshakeInfo value in the context.
- func (s) TestClientCredsProviderFailure(t *testing.T) {
- opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
- creds, err := NewClientCredentials(opts)
- if err != nil {
- t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
- }
- tests := []struct {
- desc string
- rootProvider certprovider.Provider
- identityProvider certprovider.Provider
- wantErr string
- }{
- {
- desc: "erroring root provider",
- rootProvider: &fakeProvider{err: errors.New("root provider error")},
- wantErr: "root provider error",
- },
- {
- desc: "erroring identity provider",
- rootProvider: &fakeProvider{km: &certprovider.KeyMaterial{}},
- identityProvider: &fakeProvider{err: errors.New("identity provider error")},
- wantErr: "identity provider error",
- },
- }
- for _, test := range tests {
- t.Run(test.desc, func(t *testing.T) {
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
- ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "")
- if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) {
- t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
- }
- })
- }
- }
- // TestClientCredsSuccess verifies successful client handshake cases.
- func (s) TestClientCredsSuccess(t *testing.T) {
- tests := []struct {
- desc string
- handshakeFunc testHandshakeFunc
- handshakeInfoCtx func(ctx context.Context) context.Context
- }{
- {
- desc: "fallback",
- handshakeFunc: testServerTLSHandshake,
- handshakeInfoCtx: func(ctx context.Context) context.Context {
- // Since we don't add a HandshakeInfo to the context, the
- // ClientHandshake() method will delegate to the fallback.
- return ctx
- },
- },
- {
- desc: "TLS",
- handshakeFunc: testServerTLSHandshake,
- handshakeInfoCtx: func(ctx context.Context) context.Context {
- return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
- },
- },
- {
- desc: "mTLS",
- handshakeFunc: testServerMutualTLSHandshake,
- handshakeInfoCtx: func(ctx context.Context) context.Context {
- return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN)
- },
- },
- {
- desc: "mTLS with no acceptedSANs specified",
- handshakeFunc: testServerMutualTLSHandshake,
- handshakeInfoCtx: func(ctx context.Context) context.Context {
- return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "")
- },
- },
- }
- for _, test := range tests {
- t.Run(test.desc, func(t *testing.T) {
- ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
- defer ts.stop()
- opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
- creds, err := NewClientCredentials(opts)
- if err != nil {
- t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
- }
- conn, err := net.Dial("tcp", ts.address)
- if err != nil {
- t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
- }
- defer conn.Close()
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
- _, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn)
- if err != nil {
- t.Fatalf("ClientHandshake() returned failed: %q", err)
- }
- if err := compareAuthInfo(ctx, ts, ai); err != nil {
- t.Fatal(err)
- }
- })
- }
- }
- func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
- clientDone := make(chan struct{})
- // A handshake function which simulates a handshake timeout from the
- // server-side by simply blocking on the client-side handshake to timeout
- // and not writing any handshake data.
- hErr := errors.New("server handshake error")
- ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
- <-clientDone
- return handshakeResult{err: hErr}
- })
- defer ts.stop()
- opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
- creds, err := NewClientCredentials(opts)
- if err != nil {
- t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
- }
- conn, err := net.Dial("tcp", ts.address)
- if err != nil {
- t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
- }
- defer conn.Close()
- sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
- defer sCancel()
- ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
- if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
- t.Fatal("ClientHandshake() succeeded when expected to timeout")
- }
- close(clientDone)
- // Read the handshake result from the testServer and make sure the expected
- // error is returned.
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
- val, err := ts.hsResult.Receive(ctx)
- if err != nil {
- t.Fatalf("testServer failed to return handshake result: %v", err)
- }
- hsr := val.(handshakeResult)
- if hsr.err != hErr {
- t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr)
- }
- }
- // TestClientCredsHandshakeFailure verifies different handshake failure cases.
- func (s) TestClientCredsHandshakeFailure(t *testing.T) {
- tests := []struct {
- desc string
- handshakeFunc testHandshakeFunc
- rootProvider certprovider.Provider
- san string
- wantErr string
- }{
- {
- desc: "cert validation failure",
- handshakeFunc: testServerTLSHandshake,
- rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"),
- san: defaultTestCertSAN,
- wantErr: "x509: certificate signed by unknown authority",
- },
- {
- desc: "SAN mismatch",
- handshakeFunc: testServerTLSHandshake,
- rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"),
- san: "bad-san",
- wantErr: "do not match any of the accepted SANs",
- },
- }
- for _, test := range tests {
- t.Run(test.desc, func(t *testing.T) {
- ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
- defer ts.stop()
- opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
- creds, err := NewClientCredentials(opts)
- if err != nil {
- t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
- }
- conn, err := net.Dial("tcp", ts.address)
- if err != nil {
- t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
- }
- defer conn.Close()
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
- ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san)
- if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
- t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr)
- }
- })
- }
- }
- // TestClientCredsProviderSwitch verifies the case where the first attempt of
- // ClientHandshake fails because of a handshake failure. Then we update the
- // certificate provider and the second attempt succeeds. This is an
- // approximation of the flow of events when the control plane specifies new
- // security config which results in new certificate providers being used.
- func (s) TestClientCredsProviderSwitch(t *testing.T) {
- ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
- defer ts.stop()
- opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
- creds, err := NewClientCredentials(opts)
- if err != nil {
- t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
- }
- conn, err := net.Dial("tcp", ts.address)
- if err != nil {
- t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
- }
- defer conn.Close()
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
- // Create a root provider which will fail the handshake because it does not
- // use the correct trust roots.
- root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
- handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil)
- handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)})
- // We need to repeat most of what newTestContextWithHandshakeInfo() does
- // here because we need access to the underlying HandshakeInfo so that we
- // can update it before the next call to ClientHandshake().
- addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
- ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
- if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
- t.Fatal("ClientHandshake() succeeded when expected to fail")
- }
- // Drain the result channel on the test server so that we can inspect the
- // result for the next handshake.
- _, err = ts.hsResult.Receive(ctx)
- if err != nil {
- t.Errorf("testServer failed to return handshake result: %v", err)
- }
- conn, err = net.Dial("tcp", ts.address)
- if err != nil {
- t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
- }
- defer conn.Close()
- // Create a new root provider which uses the correct trust roots. And update
- // the HandshakeInfo with the new provider.
- root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
- handshakeInfo.SetRootCertProvider(root2)
- _, ai, err := creds.ClientHandshake(ctx, authority, conn)
- if err != nil {
- t.Fatalf("ClientHandshake() returned failed: %q", err)
- }
- if err := compareAuthInfo(ctx, ts, ai); err != nil {
- t.Fatal(err)
- }
- }
- // TestClientClone verifies the Clone() method on client credentials.
- func (s) TestClientClone(t *testing.T) {
- opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
- orig, err := NewClientCredentials(opts)
- if err != nil {
- t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
- }
- // The credsImpl does not have any exported fields, and it does not make
- // sense to use any cmp options to look deep into. So, all we make sure here
- // is that the cloned object points to a different location in memory.
- if clone := orig.Clone(); clone == orig {
- t.Fatal("return value from Clone() doesn't point to new credentials instance")
- }
- }
- func newStringP(s string) *string {
- return &s
- }
|