123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238 |
- package security
- import (
- "crypto/tls"
- "crypto/x509"
- "fmt"
- "os"
- "slices"
- "strings"
- "time"
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/util"
- "google.golang.org/grpc"
- "google.golang.org/grpc/credentials/insecure"
- "google.golang.org/grpc/credentials/tls/certprovider/pemfile"
- "google.golang.org/grpc/security/advancedtls"
- )
- const CredRefreshingInterval = time.Duration(5) * time.Hour
- type Authenticator struct {
- AllowedWildcardDomain string
- AllowedCommonNames map[string]bool
- }
- func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption, grpc.ServerOption) {
- if config == nil {
- return nil, nil
- }
- serverOptions := pemfile.Options{
- CertFile: config.GetString(component + ".cert"),
- KeyFile: config.GetString(component + ".key"),
- RefreshDuration: CredRefreshingInterval,
- }
- if serverOptions.CertFile == "" || serverOptions.KeyFile == "" {
- return nil, nil
- }
- serverIdentityProvider, err := pemfile.NewProvider(serverOptions)
- if err != nil {
- glog.Warningf("pemfile.NewProvider(%v) %v failed: %v", serverOptions, component, err)
- return nil, nil
- }
- serverRootOptions := pemfile.Options{
- RootFile: config.GetString("grpc.ca"),
- RefreshDuration: CredRefreshingInterval,
- }
- serverRootProvider, err := pemfile.NewProvider(serverRootOptions)
- if err != nil {
- glog.Warningf("pemfile.NewProvider(%v) failed: %v", serverRootOptions, err)
- return nil, nil
- }
- // Start a server and create a client using advancedtls API with Provider.
- options := &advancedtls.Options{
- IdentityOptions: advancedtls.IdentityCertificateOptions{
- IdentityProvider: serverIdentityProvider,
- },
- RootOptions: advancedtls.RootCertificateOptions{
- RootProvider: serverRootProvider,
- },
- RequireClientCert: true,
- VerificationType: advancedtls.CertVerification,
- }
- options.MinTLSVersion, err = TlsVersionByName(config.GetString("tls.min_version"))
- if err != nil {
- glog.Warningf("tls min version parse failed, %v", err)
- return nil, nil
- }
- options.MaxTLSVersion, err = TlsVersionByName(config.GetString("tls.max_version"))
- if err != nil {
- glog.Warningf("tls max version parse failed, %v", err)
- return nil, nil
- }
- options.CipherSuites, err = TlsCipherSuiteByNames(config.GetString("tls.cipher_suites"))
- if err != nil {
- glog.Warningf("tls cipher suite parse failed, %v", err)
- return nil, nil
- }
- allowedCommonNames := config.GetString(component + ".allowed_commonNames")
- allowedWildcardDomain := config.GetString("grpc.allowed_wildcard_domain")
- if allowedCommonNames != "" || allowedWildcardDomain != "" {
- allowedCommonNamesMap := make(map[string]bool)
- for _, s := range strings.Split(allowedCommonNames, ",") {
- allowedCommonNamesMap[s] = true
- }
- auther := Authenticator{
- AllowedCommonNames: allowedCommonNamesMap,
- AllowedWildcardDomain: allowedWildcardDomain,
- }
- options.AdditionalPeerVerification = auther.Authenticate
- } else {
- options.AdditionalPeerVerification = func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
- return &advancedtls.PostHandshakeVerificationResults{}, nil
- }
- }
- ta, err := advancedtls.NewServerCreds(options)
- if err != nil {
- glog.Warningf("advancedtls.NewServerCreds(%v) failed: %v", options, err)
- return nil, nil
- }
- return grpc.Creds(ta), nil
- }
- func LoadClientTLS(config *util.ViperProxy, component string) grpc.DialOption {
- if config == nil {
- return grpc.WithTransportCredentials(insecure.NewCredentials())
- }
- certFileName, keyFileName, caFileName := config.GetString(component+".cert"), config.GetString(component+".key"), config.GetString("grpc.ca")
- if certFileName == "" || keyFileName == "" || caFileName == "" {
- return grpc.WithTransportCredentials(insecure.NewCredentials())
- }
- clientOptions := pemfile.Options{
- CertFile: certFileName,
- KeyFile: keyFileName,
- RefreshDuration: CredRefreshingInterval,
- }
- clientProvider, err := pemfile.NewProvider(clientOptions)
- if err != nil {
- glog.Warningf("pemfile.NewProvider(%v) failed %v", clientOptions, err)
- return grpc.WithTransportCredentials(insecure.NewCredentials())
- }
- clientRootOptions := pemfile.Options{
- RootFile: config.GetString("grpc.ca"),
- RefreshDuration: CredRefreshingInterval,
- }
- clientRootProvider, err := pemfile.NewProvider(clientRootOptions)
- if err != nil {
- glog.Warningf("pemfile.NewProvider(%v) failed: %v", clientRootOptions, err)
- return grpc.WithTransportCredentials(insecure.NewCredentials())
- }
- options := &advancedtls.Options{
- IdentityOptions: advancedtls.IdentityCertificateOptions{
- IdentityProvider: clientProvider,
- },
- AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
- return &advancedtls.PostHandshakeVerificationResults{}, nil
- },
- RootOptions: advancedtls.RootCertificateOptions{
- RootProvider: clientRootProvider,
- },
- VerificationType: advancedtls.CertVerification,
- }
- ta, err := advancedtls.NewClientCreds(options)
- if err != nil {
- glog.Warningf("advancedtls.NewClientCreds(%v) failed: %v", options, err)
- return grpc.WithTransportCredentials(insecure.NewCredentials())
- }
- return grpc.WithTransportCredentials(ta)
- }
- func LoadClientTLSHTTP(clientCertFile string) *tls.Config {
- clientCerts, err := os.ReadFile(clientCertFile)
- if err != nil {
- glog.Fatal(err)
- }
- certPool := x509.NewCertPool()
- ok := certPool.AppendCertsFromPEM(clientCerts)
- if !ok {
- glog.Fatalf("Error processing client certificate in %s\n", clientCertFile)
- }
- return &tls.Config{
- ClientCAs: certPool,
- ClientAuth: tls.RequireAndVerifyClientCert,
- }
- }
- func (a Authenticator) Authenticate(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
- if a.AllowedWildcardDomain != "" && strings.HasSuffix(params.Leaf.Subject.CommonName, a.AllowedWildcardDomain) {
- return &advancedtls.PostHandshakeVerificationResults{}, nil
- }
- if _, ok := a.AllowedCommonNames[params.Leaf.Subject.CommonName]; ok {
- return &advancedtls.PostHandshakeVerificationResults{}, nil
- }
- err := fmt.Errorf("Authenticate: invalid subject client common name: %s", params.Leaf.Subject.CommonName)
- glog.Error(err)
- return nil, err
- }
- func FixTlsConfig(viper *util.ViperProxy, config *tls.Config) error {
- var err error
- config.MinVersion, err = TlsVersionByName(viper.GetString("tls.min_version"))
- if err != nil {
- return err
- }
- config.MaxVersion, err = TlsVersionByName(viper.GetString("tls.max_version"))
- if err != nil {
- return err
- }
- config.CipherSuites, err = TlsCipherSuiteByNames(viper.GetString("tls.cipher_suites"))
- return err
- }
- func TlsVersionByName(name string) (uint16, error) {
- switch name {
- case "":
- return 0, nil
- case "SSLv3":
- return tls.VersionSSL30, nil
- case "TLS 1.0":
- return tls.VersionTLS10, nil
- case "TLS 1.1":
- return tls.VersionTLS11, nil
- case "TLS 1.2":
- return tls.VersionTLS12, nil
- case "TLS 1.3":
- return tls.VersionTLS13, nil
- default:
- return 0, fmt.Errorf("invalid tls version %s", name)
- }
- }
- func TlsCipherSuiteByNames(cipherSuiteNames string) ([]uint16, error) {
- cipherSuiteNames = strings.TrimSpace(cipherSuiteNames)
- if cipherSuiteNames == "" {
- return nil, nil
- }
- names := strings.Split(cipherSuiteNames, ",")
- cipherSuites := tls.CipherSuites()
- cipherIds := make([]uint16, 0, len(names))
- for _, name := range names {
- name = strings.TrimSpace(name)
- index := slices.IndexFunc(cipherSuites, func(suite *tls.CipherSuite) bool {
- return name == suite.Name
- })
- if index == -1 {
- return nil, fmt.Errorf("invalid tls cipher suite name %s", name)
- }
- cipherIds = append(cipherIds, cipherSuites[index].ID)
- }
- return cipherIds, nil
- }
|