s3api_circuit_breaker.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package s3api
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/gorilla/mux"
  6. "github.com/seaweedfs/seaweedfs/weed/filer"
  7. "github.com/seaweedfs/seaweedfs/weed/glog"
  8. "github.com/seaweedfs/seaweedfs/weed/pb"
  9. "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
  10. "github.com/seaweedfs/seaweedfs/weed/pb/s3_pb"
  11. "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
  12. "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
  13. "net/http"
  14. "sync"
  15. "sync/atomic"
  16. )
  // CircuitBreaker tracks in-flight S3 requests per limitation key and rejects
  // requests that would push a counter past its configured limit.
  type CircuitBreaker struct {
  	// The embedded lock guards the counters map itself (lookups and inserts).
  	// The counter values behind the pointers are updated with sync/atomic.
  	sync.RWMutex

  	// Enabled mirrors the global section of the loaded configuration. When it
  	// is false, handlers wrapped by Limit call straight through.
  	// NOTE(review): read and written without holding the lock — confirm that
  	// configuration reloads cannot race with request handling.
  	Enabled bool

  	// counters holds the current in-flight total for each limitation key.
  	// Entries are created lazily the first time a limited key is seen.
  	counters map[string]*int64

  	// limitations maps a limitation key to its maximum allowed value. The map
  	// is replaced wholesale each time a configuration is loaded.
  	limitations map[string]int64
  }
  23. func NewCircuitBreaker(option *S3ApiServerOption) *CircuitBreaker {
  24. cb := &CircuitBreaker{
  25. counters: make(map[string]*int64),
  26. limitations: make(map[string]int64),
  27. }
  28. err := pb.WithFilerClient(false, 0, option.Filer, option.GrpcDialOption, func(client filer_pb.SeaweedFilerClient) error {
  29. content, err := filer.ReadInsideFiler(client, s3_constants.CircuitBreakerConfigDir, s3_constants.CircuitBreakerConfigFile)
  30. if errors.Is(err, filer_pb.ErrNotFound) {
  31. glog.Infof("s3 circuit breaker not configured")
  32. return nil
  33. }
  34. if err != nil {
  35. return fmt.Errorf("read S3 circuit breaker config: %v", err)
  36. }
  37. return cb.LoadS3ApiConfigurationFromBytes(content)
  38. })
  39. if err != nil {
  40. glog.Infof("s3 circuit breaker not configured correctly: %v", err)
  41. }
  42. return cb
  43. }
  44. func (cb *CircuitBreaker) LoadS3ApiConfigurationFromBytes(content []byte) error {
  45. cbCfg := &s3_pb.S3CircuitBreakerConfig{}
  46. if err := filer.ParseS3ConfigurationFromBytes(content, cbCfg); err != nil {
  47. glog.Warningf("unmarshal error: %v", err)
  48. return fmt.Errorf("unmarshal error: %v", err)
  49. }
  50. if err := cb.loadCircuitBreakerConfig(cbCfg); err != nil {
  51. return err
  52. }
  53. return nil
  54. }
  55. func (cb *CircuitBreaker) loadCircuitBreakerConfig(cfg *s3_pb.S3CircuitBreakerConfig) error {
  56. //global
  57. globalEnabled := false
  58. globalOptions := cfg.Global
  59. limitations := make(map[string]int64)
  60. if globalOptions != nil && globalOptions.Enabled && len(globalOptions.Actions) > 0 {
  61. globalEnabled = globalOptions.Enabled
  62. for action, limit := range globalOptions.Actions {
  63. limitations[action] = limit
  64. }
  65. }
  66. cb.Enabled = globalEnabled
  67. //buckets
  68. for bucket, cbOptions := range cfg.Buckets {
  69. if cbOptions.Enabled {
  70. for action, limit := range cbOptions.Actions {
  71. limitations[s3_constants.Concat(bucket, action)] = limit
  72. }
  73. }
  74. }
  75. cb.limitations = limitations
  76. return nil
  77. }
  78. func (cb *CircuitBreaker) Limit(f func(w http.ResponseWriter, r *http.Request), action string) (http.HandlerFunc, Action) {
  79. return func(w http.ResponseWriter, r *http.Request) {
  80. if !cb.Enabled {
  81. f(w, r)
  82. return
  83. }
  84. vars := mux.Vars(r)
  85. bucket := vars["bucket"]
  86. rollback, errCode := cb.limit(r, bucket, action)
  87. defer func() {
  88. for _, rf := range rollback {
  89. rf()
  90. }
  91. }()
  92. if errCode == s3err.ErrNone {
  93. f(w, r)
  94. return
  95. }
  96. s3err.WriteErrorResponse(w, r, errCode)
  97. }, Action(action)
  98. }
  99. func (cb *CircuitBreaker) limit(r *http.Request, bucket string, action string) (rollback []func(), errCode s3err.ErrorCode) {
  100. //bucket simultaneous request count
  101. bucketCountRollBack, errCode := cb.loadCounterAndCompare(s3_constants.Concat(bucket, action, s3_constants.LimitTypeCount), 1, s3err.ErrTooManyRequest)
  102. if bucketCountRollBack != nil {
  103. rollback = append(rollback, bucketCountRollBack)
  104. }
  105. if errCode != s3err.ErrNone {
  106. return
  107. }
  108. //bucket simultaneous request content bytes
  109. bucketContentLengthRollBack, errCode := cb.loadCounterAndCompare(s3_constants.Concat(bucket, action, s3_constants.LimitTypeBytes), r.ContentLength, s3err.ErrRequestBytesExceed)
  110. if bucketContentLengthRollBack != nil {
  111. rollback = append(rollback, bucketContentLengthRollBack)
  112. }
  113. if errCode != s3err.ErrNone {
  114. return
  115. }
  116. //global simultaneous request count
  117. globalCountRollBack, errCode := cb.loadCounterAndCompare(s3_constants.Concat(action, s3_constants.LimitTypeCount), 1, s3err.ErrTooManyRequest)
  118. if globalCountRollBack != nil {
  119. rollback = append(rollback, globalCountRollBack)
  120. }
  121. if errCode != s3err.ErrNone {
  122. return
  123. }
  124. //global simultaneous request content bytes
  125. globalContentLengthRollBack, errCode := cb.loadCounterAndCompare(s3_constants.Concat(action, s3_constants.LimitTypeBytes), r.ContentLength, s3err.ErrRequestBytesExceed)
  126. if globalContentLengthRollBack != nil {
  127. rollback = append(rollback, globalContentLengthRollBack)
  128. }
  129. if errCode != s3err.ErrNone {
  130. return
  131. }
  132. return
  133. }
  134. func (cb *CircuitBreaker) loadCounterAndCompare(key string, inc int64, errCode s3err.ErrorCode) (f func(), e s3err.ErrorCode) {
  135. e = s3err.ErrNone
  136. if max, ok := cb.limitations[key]; ok {
  137. cb.RLock()
  138. counter, exists := cb.counters[key]
  139. cb.RUnlock()
  140. if !exists {
  141. cb.Lock()
  142. counter, exists = cb.counters[key]
  143. if !exists {
  144. var newCounter int64
  145. counter = &newCounter
  146. cb.counters[key] = counter
  147. }
  148. cb.Unlock()
  149. }
  150. current := atomic.LoadInt64(counter)
  151. if current+inc > max {
  152. e = errCode
  153. return
  154. } else {
  155. current := atomic.AddInt64(counter, inc)
  156. f = func() {
  157. atomic.AddInt64(counter, -inc)
  158. }
  159. if current > max {
  160. e = errCode
  161. return
  162. }
  163. }
  164. }
  165. return
  166. }