http_down.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  // Package httpdown provides http.ConnState enabled graceful termination of
  // http.Server.
  //
  // It is based on github.com/facebookarchive/httpdown, which is licensed under
  // the MIT License, and adds support for serving HTTP over TLS.
  5. package httpdown
  6. import (
  7. "crypto/tls"
  8. "fmt"
  9. "net"
  10. "net/http"
  11. "os"
  12. "os/signal"
  13. "sync"
  14. "syscall"
  15. "time"
  16. "github.com/facebookgo/clock"
  17. "github.com/facebookgo/stats"
  18. )
  // Grace periods that HTTP.Serve falls back to when the corresponding
  // StopTimeout or KillTimeout field is left at zero.
  const (
  	// defaultStopTimeout bounds how long Stop waits for connections to finish
  	// on their own before force closing them.
  	defaultStopTimeout = 60 * time.Second
  	// defaultKillTimeout bounds how long Stop keeps waiting after it has force
  	// closed the remaining connections.
  	defaultKillTimeout = 60 * time.Second
  )
  // Server is a running HTTP server as returned by HTTP.Serve and
  // HTTP.ListenAndServe. It accepts and serves connections in the background
  // and can be shut down gracefully without dropping active connections.
  type Server interface {
  	// Stop closes the listener and blocks until every connection has been
  	// closed, or until the configured stop and kill timeouts have elapsed.
  	Stop() error

  	// Wait blocks until the serving loop exits. It returns nil when the loop
  	// ended because Stop was called, and the loop's error otherwise. Wait
  	// must be called after Serve or ListenAndServe.
  	Wait() error
  }
  35. // HTTP defines the configuration for serving a http.Server. Multiple calls to
  36. // Serve or ListenAndServe can be made on the same HTTP instance. The default
  37. // timeouts of 1 minute each result in a maximum of 2 minutes before a Stop()
  38. // returns.
  39. type HTTP struct {
  40. // StopTimeout is the duration before we begin force closing connections.
  41. // Defaults to 1 minute.
  42. StopTimeout time.Duration
  43. // KillTimeout is the duration before which we completely give up and abort
  44. // even though we still have connected clients. This is useful when a large
  45. // number of client connections exist and closing them can take a long time.
  46. // Note, this is in addition to the StopTimeout. Defaults to 1 minute.
  47. KillTimeout time.Duration
  48. // Stats is optional. If provided, it will be used to record various metrics.
  49. Stats stats.Client
  50. // Clock allows for testing timing related functionality. Do not specify this
  51. // in production code.
  52. Clock clock.Clock
  53. // when set CertFile and KeyFile, the httpDown will start a http with TLS.
  54. // Files containing a certificate and matching private key for the
  55. // server must be provided if neither the Server's
  56. // TLSConfig.Certificates nor TLSConfig.GetCertificate are populated.
  57. // If the certificate is signed by a certificate authority, the
  58. // certFile should be the concatenation of the server's certificate,
  59. // any intermediates, and the CA's certificate.
  60. CertFile, KeyFile string
  61. }
  62. // Serve provides the low-level API which is useful if you're creating your own
  63. // net.Listener.
  64. func (h HTTP) Serve(s *http.Server, l net.Listener) Server {
  65. stopTimeout := h.StopTimeout
  66. if stopTimeout == 0 {
  67. stopTimeout = defaultStopTimeout
  68. }
  69. killTimeout := h.KillTimeout
  70. if killTimeout == 0 {
  71. killTimeout = defaultKillTimeout
  72. }
  73. klock := h.Clock
  74. if klock == nil {
  75. klock = clock.New()
  76. }
  77. ss := &server{
  78. stopTimeout: stopTimeout,
  79. killTimeout: killTimeout,
  80. stats: h.Stats,
  81. clock: klock,
  82. oldConnState: s.ConnState,
  83. listener: l,
  84. server: s,
  85. serveDone: make(chan struct{}),
  86. serveErr: make(chan error, 1),
  87. new: make(chan net.Conn),
  88. active: make(chan net.Conn),
  89. idle: make(chan net.Conn),
  90. closed: make(chan net.Conn),
  91. stop: make(chan chan struct{}),
  92. kill: make(chan chan struct{}),
  93. certFile: h.CertFile,
  94. keyFile: h.KeyFile,
  95. }
  96. s.ConnState = ss.connState
  97. go ss.manage()
  98. go ss.serve()
  99. return ss
  100. }
  101. // ListenAndServe returns a Server for the given http.Server. It is equivalent
  102. // to ListenAndServe from the standard library, but returns immediately.
  103. // Requests will be accepted in a background goroutine. If the http.Server has
  104. // a non-nil TLSConfig, a TLS enabled listener will be setup.
  105. func (h HTTP) ListenAndServe(s *http.Server) (Server, error) {
  106. addr := s.Addr
  107. if addr == "" {
  108. if s.TLSConfig == nil {
  109. addr = ":http"
  110. } else {
  111. addr = ":https"
  112. }
  113. }
  114. l, err := net.Listen("tcp", addr)
  115. if err != nil {
  116. stats.BumpSum(h.Stats, "listen.error", 1)
  117. return nil, err
  118. }
  119. if s.TLSConfig != nil {
  120. l = tls.NewListener(l, s.TLSConfig)
  121. }
  122. return h.Serve(s, l), nil
  123. }
  124. // server manages the serving process and allows for gracefully stopping it.
  125. type server struct {
  126. stopTimeout time.Duration
  127. killTimeout time.Duration
  128. stats stats.Client
  129. clock clock.Clock
  130. oldConnState func(net.Conn, http.ConnState)
  131. server *http.Server
  132. serveDone chan struct{}
  133. serveErr chan error
  134. listener net.Listener
  135. new chan net.Conn
  136. active chan net.Conn
  137. idle chan net.Conn
  138. closed chan net.Conn
  139. stop chan chan struct{}
  140. kill chan chan struct{}
  141. stopOnce sync.Once
  142. stopErr error
  143. certFile, keyFile string
  144. }
  145. func (s *server) connState(c net.Conn, cs http.ConnState) {
  146. if s.oldConnState != nil {
  147. s.oldConnState(c, cs)
  148. }
  149. switch cs {
  150. case http.StateNew:
  151. s.new <- c
  152. case http.StateActive:
  153. s.active <- c
  154. case http.StateIdle:
  155. s.idle <- c
  156. case http.StateHijacked, http.StateClosed:
  157. s.closed <- c
  158. }
  159. }
  160. func (s *server) manage() {
  161. defer func() {
  162. close(s.new)
  163. close(s.active)
  164. close(s.idle)
  165. close(s.closed)
  166. close(s.stop)
  167. close(s.kill)
  168. }()
  169. var stopDone chan struct{}
  170. conns := map[net.Conn]http.ConnState{}
  171. var countNew, countActive, countIdle float64
  172. // decConn decrements the count associated with the current state of the
  173. // given connection.
  174. decConn := func(c net.Conn) {
  175. switch conns[c] {
  176. default:
  177. panic(fmt.Errorf("unknown existing connection: %s", c))
  178. case http.StateNew:
  179. countNew--
  180. case http.StateActive:
  181. countActive--
  182. case http.StateIdle:
  183. countIdle--
  184. }
  185. }
  186. // setup a ticker to report various values every minute. if we don't have a
  187. // Stats implementation provided, we Stop it so it never ticks.
  188. statsTicker := s.clock.Ticker(time.Minute)
  189. if s.stats == nil {
  190. statsTicker.Stop()
  191. }
  192. for {
  193. select {
  194. case <-statsTicker.C:
  195. // we'll only get here when s.stats is not nil
  196. s.stats.BumpAvg("http-state.new", countNew)
  197. s.stats.BumpAvg("http-state.active", countActive)
  198. s.stats.BumpAvg("http-state.idle", countIdle)
  199. s.stats.BumpAvg("http-state.total", countNew+countActive+countIdle)
  200. case c := <-s.new:
  201. conns[c] = http.StateNew
  202. countNew++
  203. case c := <-s.active:
  204. decConn(c)
  205. countActive++
  206. conns[c] = http.StateActive
  207. case c := <-s.idle:
  208. decConn(c)
  209. countIdle++
  210. conns[c] = http.StateIdle
  211. // if we're already stopping, close it
  212. if stopDone != nil {
  213. c.Close()
  214. }
  215. case c := <-s.closed:
  216. stats.BumpSum(s.stats, "conn.closed", 1)
  217. decConn(c)
  218. delete(conns, c)
  219. // if we're waiting to stop and are all empty, we just closed the last
  220. // connection and we're done.
  221. if stopDone != nil && len(conns) == 0 {
  222. close(stopDone)
  223. return
  224. }
  225. case stopDone = <-s.stop:
  226. // if we're already all empty, we're already done
  227. if len(conns) == 0 {
  228. close(stopDone)
  229. return
  230. }
  231. // close current idle connections right away
  232. for c, cs := range conns {
  233. if cs == http.StateIdle {
  234. c.Close()
  235. }
  236. }
  237. // continue the loop and wait for all the ConnState updates which will
  238. // eventually close(stopDone) and return from this goroutine.
  239. case killDone := <-s.kill:
  240. // force close all connections
  241. stats.BumpSum(s.stats, "kill.conn.count", float64(len(conns)))
  242. for c := range conns {
  243. c.Close()
  244. }
  245. // don't block the kill.
  246. close(killDone)
  247. // continue the loop and we wait for all the ConnState updates and will
  248. // return from this goroutine when we're all done. otherwise we'll try to
  249. // send those ConnState updates on closed channels.
  250. }
  251. }
  252. }
  253. func (s *server) serve() {
  254. stats.BumpSum(s.stats, "serve", 1)
  255. if s.certFile == "" && s.keyFile == "" {
  256. s.serveErr <- s.server.Serve(s.listener)
  257. } else {
  258. s.serveErr <- s.server.ServeTLS(s.listener, s.certFile, s.keyFile)
  259. }
  260. close(s.serveDone)
  261. close(s.serveErr)
  262. }
  263. func (s *server) Wait() error {
  264. if err := <-s.serveErr; !isUseOfClosedError(err) {
  265. return err
  266. }
  267. return nil
  268. }
  269. func (s *server) Stop() error {
  270. s.stopOnce.Do(func() {
  271. defer stats.BumpTime(s.stats, "stop.time").End()
  272. stats.BumpSum(s.stats, "stop", 1)
  273. // first disable keep-alive for new connections
  274. s.server.SetKeepAlivesEnabled(false)
  275. // then close the listener so new connections can't connect come thru
  276. closeErr := s.listener.Close()
  277. <-s.serveDone
  278. // then trigger the background goroutine to stop and wait for it
  279. stopDone := make(chan struct{})
  280. s.stop <- stopDone
  281. // wait for stop
  282. select {
  283. case <-stopDone:
  284. case <-s.clock.After(s.stopTimeout):
  285. defer stats.BumpTime(s.stats, "kill.time").End()
  286. stats.BumpSum(s.stats, "kill", 1)
  287. // stop timed out, wait for kill
  288. killDone := make(chan struct{})
  289. s.kill <- killDone
  290. select {
  291. case <-killDone:
  292. case <-s.clock.After(s.killTimeout):
  293. // kill timed out, give up
  294. stats.BumpSum(s.stats, "kill.timeout", 1)
  295. }
  296. }
  297. if closeErr != nil && !isUseOfClosedError(closeErr) {
  298. stats.BumpSum(s.stats, "listener.close.error", 1)
  299. s.stopErr = closeErr
  300. }
  301. })
  302. return s.stopErr
  303. }
  // isUseOfClosedError reports whether err is the "use of closed network
  // connection" error, either bare or as the Err of a *net.OpError (one level
  // deep). A nil err is never such an error.
  func isUseOfClosedError(err error) bool {
  	const closedMsg = "use of closed network connection"
  	switch e := err.(type) {
  	case nil:
  		return false
  	case *net.OpError:
  		return e.Err.Error() == closedMsg
  	default:
  		return e.Error() == closedMsg
  	}
  }
  313. // ListenAndServe is a convenience function to serve and wait for a SIGTERM
  314. // or SIGINT before shutting down.
  315. func ListenAndServe(s *http.Server, hd *HTTP) error {
  316. if hd == nil {
  317. hd = &HTTP{}
  318. }
  319. hs, err := hd.ListenAndServe(s)
  320. if err != nil {
  321. return err
  322. }
  323. waiterr := make(chan error, 1)
  324. go func() {
  325. defer close(waiterr)
  326. waiterr <- hs.Wait()
  327. }()
  328. signals := make(chan os.Signal, 10)
  329. signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
  330. select {
  331. case err := <-waiterr:
  332. if err != nil {
  333. return err
  334. }
  335. case <-signals:
  336. signal.Stop(signals)
  337. if err := hs.Stop(); err != nil {
  338. return err
  339. }
  340. if err := <-waiterr; err != nil {
  341. return err
  342. }
  343. }
  344. return nil
  345. }