state.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package frankenphp
  2. import (
  3. "slices"
  4. "strconv"
  5. "sync"
  6. )
  // stateID identifies which state a thread is currently in.
type stateID uint8

// The numeric values come from iota, so the declaration order below is
// significant and must not be rearranged.
const (
	// Lifecycle states of a thread.
	stateBooting stateID = iota
	stateShuttingDown
	stateDone

	// States that are safe to transition away from at any time.
	stateInactive
	stateReady

	// States used while restarting workers.
	stateRestarting
	stateYielding

	// States used while moving a thread between different handlers.
	stateTransitionRequested
	stateTransitionInProgress
	stateTransitionComplete
)
  24. type threadState struct {
  25. currentState stateID
  26. mu sync.RWMutex
  27. subscribers []stateSubscriber
  28. }
  29. type stateSubscriber struct {
  30. states []stateID
  31. ch chan struct{}
  32. }
  33. func newThreadState() *threadState {
  34. return &threadState{
  35. currentState: stateBooting,
  36. subscribers: []stateSubscriber{},
  37. mu: sync.RWMutex{},
  38. }
  39. }
  40. func (ts *threadState) is(state stateID) bool {
  41. ts.mu.RLock()
  42. ok := ts.currentState == state
  43. ts.mu.RUnlock()
  44. return ok
  45. }
  46. func (ts *threadState) compareAndSwap(compareTo stateID, swapTo stateID) bool {
  47. ts.mu.Lock()
  48. ok := ts.currentState == compareTo
  49. if ok {
  50. ts.currentState = swapTo
  51. ts.notifySubscribers(swapTo)
  52. }
  53. ts.mu.Unlock()
  54. return ok
  55. }
  56. func (ts *threadState) name() string {
  57. // TODO: return the actual name for logging/metrics
  58. return "state:" + strconv.Itoa(int(ts.get()))
  59. }
  60. func (ts *threadState) get() stateID {
  61. ts.mu.RLock()
  62. id := ts.currentState
  63. ts.mu.RUnlock()
  64. return id
  65. }
  66. func (ts *threadState) set(nextState stateID) {
  67. ts.mu.Lock()
  68. ts.currentState = nextState
  69. ts.notifySubscribers(nextState)
  70. ts.mu.Unlock()
  71. }
  72. func (ts *threadState) notifySubscribers(nextState stateID) {
  73. if len(ts.subscribers) == 0 {
  74. return
  75. }
  76. newSubscribers := []stateSubscriber{}
  77. // notify subscribers to the state change
  78. for _, sub := range ts.subscribers {
  79. if !slices.Contains(sub.states, nextState) {
  80. newSubscribers = append(newSubscribers, sub)
  81. continue
  82. }
  83. close(sub.ch)
  84. }
  85. ts.subscribers = newSubscribers
  86. }
  87. // block until the thread reaches a certain state
  88. func (ts *threadState) waitFor(states ...stateID) {
  89. ts.mu.Lock()
  90. if slices.Contains(states, ts.currentState) {
  91. ts.mu.Unlock()
  92. return
  93. }
  94. sub := stateSubscriber{
  95. states: states,
  96. ch: make(chan struct{}),
  97. }
  98. ts.subscribers = append(ts.subscribers, sub)
  99. ts.mu.Unlock()
  100. <-sub.ch
  101. }
  102. // safely request a state change from a different goroutine
  103. func (ts *threadState) requestSafeStateChange(nextState stateID) bool {
  104. ts.mu.Lock()
  105. switch ts.currentState {
  106. // disallow state changes if shutting down
  107. case stateShuttingDown, stateDone:
  108. ts.mu.Unlock()
  109. return false
  110. // ready and inactive are safe states to transition from
  111. case stateReady, stateInactive:
  112. ts.currentState = nextState
  113. ts.notifySubscribers(nextState)
  114. ts.mu.Unlock()
  115. return true
  116. }
  117. ts.mu.Unlock()
  118. // wait for the state to change to a safe state
  119. ts.waitFor(stateReady, stateInactive, stateShuttingDown)
  120. return ts.requestSafeStateChange(nextState)
  121. }