123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- package client
- import (
- "crypto/tls"
- "crypto/x509"
- "fmt"
- util "github.com/seaweedfs/seaweedfs/weed/util"
- "github.com/spf13/viper"
- "io"
- "net/http"
- "net/url"
- "os"
- "strings"
- "sync"
- )
// loadSecurityConfigOnce is a package-level sync.Once guard. It is not
// referenced in this part of the file; presumably another file in the package
// uses it to run some initialisation exactly once — confirm before removing.
var loadSecurityConfigOnce sync.Once
// HTTPClient wraps an *http.Client and rewrites outgoing request URLs so they
// use the scheme this client is configured for: "https" when
// expectHttpsScheme is set, "http" otherwise.
type HTTPClient struct {
	// Client performs the actual requests.
	Client *http.Client
	// Transport is returned by GetClientTransport. When the value is built by
	// NewHttpClient it is the same transport Client uses.
	Transport *http.Transport
	// expectHttpsScheme selects "https" instead of "http" for outgoing URLs.
	expectHttpsScheme bool
}

// Do sends req after forcing its URL scheme to GetHttpScheme().
//
// Note that req.URL is modified in place. A request with a nil URL is passed
// through untouched so that http.Client reports the error instead of this
// wrapper panicking on the nil dereference.
func (httpClient *HTTPClient) Do(req *http.Request) (*http.Response, error) {
	if req != nil && req.URL != nil {
		req.URL.Scheme = httpClient.GetHttpScheme()
	}
	return httpClient.Client.Do(req)
}

// Get normalizes rawURL with NormalizeHttpScheme and issues a GET request.
func (httpClient *HTTPClient) Get(rawURL string) (*http.Response, error) {
	normalizedURL, err := httpClient.NormalizeHttpScheme(rawURL)
	if err != nil {
		return nil, err
	}
	return httpClient.Client.Get(normalizedURL)
}

// Post normalizes rawURL with NormalizeHttpScheme and issues a POST request
// with the given content type and body.
func (httpClient *HTTPClient) Post(rawURL, contentType string, body io.Reader) (*http.Response, error) {
	normalizedURL, err := httpClient.NormalizeHttpScheme(rawURL)
	if err != nil {
		return nil, err
	}
	return httpClient.Client.Post(normalizedURL, contentType, body)
}

// PostForm normalizes rawURL with NormalizeHttpScheme and issues a POST
// request with data URL-encoded as the request body.
func (httpClient *HTTPClient) PostForm(rawURL string, data url.Values) (*http.Response, error) {
	normalizedURL, err := httpClient.NormalizeHttpScheme(rawURL)
	if err != nil {
		return nil, err
	}
	return httpClient.Client.PostForm(normalizedURL, data)
}

// Head normalizes rawURL with NormalizeHttpScheme and issues a HEAD request.
func (httpClient *HTTPClient) Head(rawURL string) (*http.Response, error) {
	normalizedURL, err := httpClient.NormalizeHttpScheme(rawURL)
	if err != nil {
		return nil, err
	}
	return httpClient.Client.Head(normalizedURL)
}

// CloseIdleConnections closes idle keep-alive connections held by Client.
func (httpClient *HTTPClient) CloseIdleConnections() {
	httpClient.Client.CloseIdleConnections()
}

// GetClientTransport returns the Transport field.
func (httpClient *HTTPClient) GetClientTransport() *http.Transport {
	return httpClient.Transport
}

// GetHttpScheme returns "https" when the client expects TLS and "http"
// otherwise.
func (httpClient *HTTPClient) GetHttpScheme() string {
	if httpClient.expectHttpsScheme {
		return "https"
	}
	return "http"
}

// NormalizeHttpScheme returns rawURL rewritten to use this client's scheme.
//
// A URL that does not start with "http://" or "https://" is treated as
// scheme-less, and "<scheme>://" is prepended verbatim. The prefix is matched
// case-insensitively because URL schemes are case-insensitive (RFC 3986
// §3.1). Any other URL is parsed, its scheme is replaced, and it is
// re-serialized. A parse failure is returned as the error.
func (httpClient *HTTPClient) NormalizeHttpScheme(rawURL string) (string, error) {
	expectedScheme := httpClient.GetHttpScheme()

	hasPrefixFold := func(prefix string) bool {
		return len(rawURL) >= len(prefix) && strings.EqualFold(rawURL[:len(prefix)], prefix)
	}
	if !hasPrefixFold("http://") && !hasPrefixFold("https://") {
		return expectedScheme + "://" + rawURL, nil
	}

	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		return "", err
	}
	parsedURL.Scheme = expectedScheme
	return parsedURL.String(), nil
}
- func NewHttpClient(clientName ClientName, opts ...HttpClientOpt) (*HTTPClient, error) {
- httpClient := HTTPClient{}
- httpClient.expectHttpsScheme = checkIsHttpsClientEnabled(clientName)
- var tlsConfig *tls.Config = nil
- if httpClient.expectHttpsScheme {
- clientCertPair, err := getClientCertPair(clientName)
- if err != nil {
- return nil, err
- }
- clientCaCert, clientCaCertName, err := getClientCaCert(clientName)
- if err != nil {
- return nil, err
- }
- if clientCertPair != nil || len(clientCaCert) != 0 {
- caCertPool, err := createHTTPClientCertPool(clientCaCert, clientCaCertName)
- if err != nil {
- return nil, err
- }
- tlsConfig = &tls.Config{
- Certificates: []tls.Certificate{},
- RootCAs: caCertPool,
- InsecureSkipVerify: false,
- }
- if clientCertPair != nil {
- tlsConfig.Certificates = append(tlsConfig.Certificates, *clientCertPair)
- }
- }
- }
- httpClient.Transport = &http.Transport{
- MaxIdleConns: 1024,
- MaxIdleConnsPerHost: 1024,
- TLSClientConfig: tlsConfig,
- }
- httpClient.Client = &http.Client{
- Transport: httpClient.Transport,
- }
- for _, opt := range opts {
- opt(&httpClient)
- }
- return &httpClient, nil
- }
- func getStringOptionFromSecurityConfiguration(clientName ClientName, stringOptionName string) string {
- util.LoadSecurityConfiguration()
- return viper.GetString(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), stringOptionName))
- }
- func getBoolOptionFromSecurityConfiguration(clientName ClientName, boolOptionName string) bool {
- util.LoadSecurityConfiguration()
- return viper.GetBool(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), boolOptionName))
- }
- func checkIsHttpsClientEnabled(clientName ClientName) bool {
- return getBoolOptionFromSecurityConfiguration(clientName, "enabled")
- }
- func getFileContentFromSecurityConfiguration(clientName ClientName, fileType string) ([]byte, string, error) {
- if fileName := getStringOptionFromSecurityConfiguration(clientName, fileType); fileName != "" {
- fileContent, err := os.ReadFile(fileName)
- if err != nil {
- return nil, fileName, err
- }
- return fileContent, fileName, err
- }
- return nil, "", nil
- }
- func getClientCertPair(clientName ClientName) (*tls.Certificate, error) {
- certFileName := getStringOptionFromSecurityConfiguration(clientName, "cert")
- keyFileName := getStringOptionFromSecurityConfiguration(clientName, "key")
- if certFileName == "" && keyFileName == "" {
- return nil, nil
- }
- if certFileName != "" && keyFileName != "" {
- clientCert, err := tls.LoadX509KeyPair(certFileName, keyFileName)
- if err != nil {
- return nil, fmt.Errorf("error loading client certificate and key: %s", err)
- }
- return &clientCert, nil
- }
- return nil, fmt.Errorf("error loading key pair: key `%s` and certificate `%s`", keyFileName, certFileName)
- }
- func getClientCaCert(clientName ClientName) ([]byte, string, error) {
- return getFileContentFromSecurityConfiguration(clientName, "ca")
- }
// createHTTPClientCertPool builds an x509.CertPool from PEM-encoded content.
//
// Empty content yields an empty, non-nil pool. Content from which no
// certificate can be parsed yields an error naming fileName. Note that an
// empty pool assigned to tls.Config.RootCAs trusts no server certificate.
func createHTTPClientCertPool(certContent []byte, fileName string) (*x509.CertPool, error) {
	pool := x509.NewCertPool()
	if len(certContent) > 0 && !pool.AppendCertsFromPEM(certContent) {
		return nil, fmt.Errorf("error processing certificate in %s", fileName)
	}
	return pool, nil
}
|