Просмотр исходного кода

refactor: decouple worker threads from non-worker threads (#1137)

* Decouple workers.

* Moves code to separate file.

* Cleans up the exponential backoff.

* Initial working implementation.

* Refactors php threads to take callbacks.

* Cleanup.

* Cleanup.

* Cleanup.

* Cleanup.

* Adjusts watcher logic.

* Adjusts the watcher logic.

* Fix opcache_reset race condition.

* Fixing merge conflicts and formatting.

* Prevents overlapping of TSRM reservation and script execution.

* Adjustments as suggested by @dunglas.

* Adds error assertions.

* Adds comments.

* Removes logs and explicitly compares to C.false.

* Resets check.

* Adds cast for safety.

* Fixes waitgroup overflow.

* Resolves waitgroup race condition on startup.

* Moves worker request logic to worker.go.

* Removes defer.

* Removes call from go to c.

* Fixes merge conflict.

* Adds fibers test back in.

* Refactors new thread loop approach.

* Removes redundant check.

* Adds compareAndSwap.

* Refactor: removes global waitgroups and uses a 'thread state' abstraction instead.

* Removes unnecessary method.

* Updates comment.

* Removes unnecessary booleans.

* test

* First state machine steps.

* Splits threads.

* Minimal working implementation with broken tests.

* Fixes tests.

* Refactoring.

* Fixes merge conflicts.

* Formatting

* C formatting.

* More cleanup.

* Allows for clean state transitions.

* Adds state tests.

* Adds support for thread transitioning.

* Fixes the testdata path.

* Formatting.

* Allows transitioning back to inactive state.

* Fixes go linting.

* Formatting.

* Removes duplication.

* Applies suggestions by @dunglas

* Removes redundant check.

* Locks the handler on restart.

* Removes unnecessary log.

* Changes Unpin() logic as suggested by @withinboredom

* Adds suggestions by @dunglas and resolves TODO.

* Makes restarts fully safe.

* Will make the initial startup fail even if the watcher is enabled (as is currently the case)

* Also adds compareAndSwap to the test.

* Adds comment.

* Prevents panic on initial watcher startup.
Alliballibaba2 2 месяцев назад
Родитель
Сommit
f592e0f47b
10 измененных файлов с 430 добавлено и 223 удалено
  1. 0 2
      cgi.go
  2. 25 41
      frankenphp.c
  3. 24 104
      frankenphp.go
  4. 2 1
      frankenphp.h
  5. 13 18
      frankenphp_arginfo.h
  6. 17 0
      frankenphp_test.go
  7. 108 0
      phpmainthread.go
  8. 161 0
      phpmainthread_test.go
  9. 80 17
      phpthread.go
  10. 0 40
      phpthread_test.go

+ 0 - 2
cgi.go

@@ -227,8 +227,6 @@ func go_frankenphp_release_known_variable_keys(threadIndex C.uintptr_t) {
 	for _, v := range thread.knownVariableKeys {
 		C.frankenphp_release_zend_string(v)
 	}
-	// release everything that might still be pinned to the thread
-	thread.Unpin()
 	thread.knownVariableKeys = nil
 }
 

+ 25 - 41
frankenphp.c

@@ -89,7 +89,7 @@ static void frankenphp_free_request_context() {
   free(ctx->cookie_data);
   ctx->cookie_data = NULL;
 
-  /* Is freed via thread.Unpin() at the end of each request */
+  /* Is freed via thread.Unpin() */
   SG(request_info).auth_password = NULL;
   SG(request_info).auth_user = NULL;
   SG(request_info).request_method = NULL;
@@ -243,7 +243,7 @@ PHP_FUNCTION(frankenphp_finish_request) { /* {{{ */
   php_header();
 
   if (ctx->has_active_request) {
-    go_frankenphp_finish_request(thread_index, false);
+    go_frankenphp_finish_php_request(thread_index);
   }
 
   ctx->finished = true;
@@ -443,7 +443,7 @@ PHP_FUNCTION(frankenphp_handle_request) {
 
   frankenphp_worker_request_shutdown();
   ctx->has_active_request = false;
-  go_frankenphp_finish_request(thread_index, true);
+  go_frankenphp_finish_worker_request(thread_index);
 
   RETURN_TRUE;
 }
@@ -811,9 +811,9 @@ static void set_thread_name(char *thread_name) {
 }
 
 static void *php_thread(void *arg) {
-  char thread_name[16] = {0};
-  snprintf(thread_name, 16, "php-%" PRIxPTR, (uintptr_t)arg);
   thread_index = (uintptr_t)arg;
+  char thread_name[16] = {0};
+  snprintf(thread_name, 16, "php-%" PRIxPTR, thread_index);
   set_thread_name(thread_name);
 
 #ifdef ZTS
@@ -832,7 +832,11 @@ static void *php_thread(void *arg) {
   cfg_get_string("filter.default", &default_filter);
   should_filter_var = default_filter != NULL;
 
-  while (go_handle_request(thread_index)) {
+  // loop until Go signals to stop
+  char *scriptName = NULL;
+  while ((scriptName = go_frankenphp_before_script_execution(thread_index))) {
+    go_frankenphp_after_script_execution(thread_index,
+                                         frankenphp_execute_script(scriptName));
   }
 
   go_frankenphp_release_known_variable_keys(thread_index);
@@ -841,6 +845,8 @@ static void *php_thread(void *arg) {
   ts_free_thread();
 #endif
 
+  go_frankenphp_on_thread_shutdown(thread_index);
+
   return NULL;
 }
 
@@ -858,13 +864,11 @@ static void *php_main(void *arg) {
     exit(EXIT_FAILURE);
   }
 
-  intptr_t num_threads = (intptr_t)arg;
-
   set_thread_name("php-main");
 
 #ifdef ZTS
 #if (PHP_VERSION_ID >= 80300)
-  php_tsrm_startup_ex(num_threads);
+  php_tsrm_startup_ex((intptr_t)arg);
 #else
   php_tsrm_startup();
 #endif
@@ -892,28 +896,7 @@ static void *php_main(void *arg) {
 
   frankenphp_sapi_module.startup(&frankenphp_sapi_module);
 
-  pthread_t *threads = malloc(num_threads * sizeof(pthread_t));
-  if (threads == NULL) {
-    perror("malloc failed");
-    exit(EXIT_FAILURE);
-  }
-
-  for (uintptr_t i = 0; i < num_threads; i++) {
-    if (pthread_create(&(*(threads + i)), NULL, &php_thread, (void *)i) != 0) {
-      perror("failed to create PHP thread");
-      free(threads);
-      exit(EXIT_FAILURE);
-    }
-  }
-
-  for (int i = 0; i < num_threads; i++) {
-    if (pthread_join((*(threads + i)), NULL) != 0) {
-      perror("failed to join PHP thread");
-      free(threads);
-      exit(EXIT_FAILURE);
-    }
-  }
-  free(threads);
+  go_frankenphp_main_thread_is_ready();
 
   /* channel closed, shutdown gracefully */
   frankenphp_sapi_module.shutdown(&frankenphp_sapi_module);
@@ -929,25 +912,30 @@ static void *php_main(void *arg) {
     frankenphp_sapi_module.ini_entries = NULL;
   }
 #endif
-
-  go_shutdown();
-
+  go_frankenphp_shutdown_main_thread();
   return NULL;
 }
 
-int frankenphp_init(int num_threads) {
+int frankenphp_new_main_thread(int num_threads) {
   pthread_t thread;
 
   if (pthread_create(&thread, NULL, &php_main, (void *)(intptr_t)num_threads) !=
       0) {
-    go_shutdown();
-
     return -1;
   }
 
   return pthread_detach(thread);
 }
 
+bool frankenphp_new_php_thread(uintptr_t thread_index) {
+  pthread_t thread;
+  if (pthread_create(&thread, NULL, &php_thread, (void *)thread_index) != 0) {
+    return false;
+  }
+  pthread_detach(thread);
+  return true;
+}
+
 int frankenphp_request_startup() {
   if (php_request_startup() == SUCCESS) {
     return SUCCESS;
@@ -960,8 +948,6 @@ int frankenphp_request_startup() {
 
 int frankenphp_execute_script(char *file_name) {
   if (frankenphp_request_startup() == FAILURE) {
-    free(file_name);
-    file_name = NULL;
 
     return FAILURE;
   }
@@ -970,8 +956,6 @@ int frankenphp_execute_script(char *file_name) {
 
   zend_file_handle file_handle;
   zend_stream_init_filename(&file_handle, file_name);
-  free(file_name);
-  file_name = NULL;
 
   file_handle.primary_script = 1;
 

+ 24 - 104
frankenphp.go

@@ -64,8 +64,6 @@ var (
 	ScriptExecutionError        = errors.New("error during PHP script execution")
 
 	requestChan chan *http.Request
-	done        chan struct{}
-	shutdownWG  sync.WaitGroup
 
 	loggerMu sync.RWMutex
 	logger   *zap.Logger
@@ -123,7 +121,7 @@ type FrankenPHPContext struct {
 	closed sync.Once
 
 	responseWriter http.ResponseWriter
-	exitStatus     C.int
+	exitStatus     int
 
 	done      chan interface{}
 	startedAt time.Time
@@ -244,7 +242,7 @@ func Config() PHPConfig {
 // MaxThreads is internally used during tests. It is written to, but never read and may go away in the future.
 var MaxThreads int
 
-func calculateMaxThreads(opt *opt) error {
+func calculateMaxThreads(opt *opt) (int, int, error) {
 	maxProcs := runtime.GOMAXPROCS(0) * 2
 
 	var numWorkers int
@@ -266,13 +264,13 @@ func calculateMaxThreads(opt *opt) error {
 			opt.numThreads = maxProcs
 		}
 	} else if opt.numThreads <= numWorkers {
-		return NotEnoughThreads
+		return opt.numThreads, numWorkers, NotEnoughThreads
 	}
 
 	metrics.TotalThreads(opt.numThreads)
 	MaxThreads = opt.numThreads
 
-	return nil
+	return opt.numThreads, numWorkers, nil
 }
 
 // Init starts the PHP runtime and the configured workers.
@@ -311,7 +309,7 @@ func Init(options ...Option) error {
 		metrics = opt.metrics
 	}
 
-	err := calculateMaxThreads(opt)
+	totalThreadCount, workerThreadCount, err := calculateMaxThreads(opt)
 	if err != nil {
 		return err
 	}
@@ -327,29 +325,26 @@ func Init(options ...Option) error {
 			logger.Warn(`Zend Max Execution Timers are not enabled, timeouts (e.g. "max_execution_time") are disabled, recompile PHP with the "--enable-zend-max-execution-timers" configuration option to fix this issue`)
 		}
 	} else {
-		opt.numThreads = 1
+		totalThreadCount = 1
 		logger.Warn(`ZTS is not enabled, only 1 thread will be available, recompile PHP using the "--enable-zts" configuration option or performance will be degraded`)
 	}
 
-	shutdownWG.Add(1)
-	done = make(chan struct{})
 	requestChan = make(chan *http.Request, opt.numThreads)
-	initPHPThreads(opt.numThreads)
-
-	if C.frankenphp_init(C.int(opt.numThreads)) != 0 {
-		return MainThreadCreationError
+	if err := initPHPThreads(totalThreadCount); err != nil {
+		return err
 	}
 
-	if err := initWorkers(opt.workers); err != nil {
-		return err
+	for i := 0; i < totalThreadCount-workerThreadCount; i++ {
+		thread := getInactivePHPThread()
+		convertToRegularThread(thread)
 	}
 
-	if err := restartWorkersOnFileChanges(opt.workers); err != nil {
+	if err := initWorkers(opt.workers); err != nil {
 		return err
 	}
 
 	if c := logger.Check(zapcore.InfoLevel, "FrankenPHP started 🐘"); c != nil {
-		c.Write(zap.String("php_version", Version().Version), zap.Int("num_threads", opt.numThreads))
+		c.Write(zap.String("php_version", Version().Version), zap.Int("num_threads", totalThreadCount))
 	}
 	if EmbeddedAppPath != "" {
 		if c := logger.Check(zapcore.InfoLevel, "embedded PHP app 📦"); c != nil {
@@ -363,7 +358,7 @@ func Init(options ...Option) error {
 // Shutdown stops the workers and the PHP runtime.
 func Shutdown() {
 	drainWorkers()
-	drainThreads()
+	drainPHPThreads()
 	metrics.Shutdown()
 	requestChan = nil
 
@@ -375,17 +370,6 @@ func Shutdown() {
 	logger.Debug("FrankenPHP shut down")
 }
 
-//export go_shutdown
-func go_shutdown() {
-	shutdownWG.Done()
-}
-
-func drainThreads() {
-	close(done)
-	shutdownWG.Wait()
-	phpThreads = nil
-}
-
 func getLogger() *zap.Logger {
 	loggerMu.RLock()
 	defer loggerMu.RUnlock()
@@ -466,9 +450,6 @@ func ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) error
 		return nil
 	}
 
-	shutdownWG.Add(1)
-	defer shutdownWG.Done()
-
 	fc, ok := FromContext(request.Context())
 	if !ok {
 		return InvalidRequestError
@@ -477,76 +458,25 @@ func ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) error
 	fc.responseWriter = responseWriter
 	fc.startedAt = time.Now()
 
-	isWorker := fc.responseWriter == nil
-
 	// Detect if a worker is available to handle this request
-	if !isWorker {
-		if worker, ok := workers[fc.scriptFilename]; ok {
-			metrics.StartWorkerRequest(fc.scriptFilename)
-			worker.handleRequest(request)
-			<-fc.done
-			metrics.StopWorkerRequest(fc.scriptFilename, time.Since(fc.startedAt))
-			return nil
-		} else {
-			metrics.StartRequest()
-		}
+	if worker, ok := workers[fc.scriptFilename]; ok {
+		worker.handleRequest(request, fc)
+		return nil
 	}
 
+	metrics.StartRequest()
+
 	select {
-	case <-done:
+	case <-mainThread.done:
 	case requestChan <- request:
 		<-fc.done
 	}
 
-	if !isWorker {
-		metrics.StopRequest()
-	}
+	metrics.StopRequest()
 
 	return nil
 }
 
-//export go_handle_request
-func go_handle_request(threadIndex C.uintptr_t) bool {
-	select {
-	case <-done:
-		return false
-
-	case r := <-requestChan:
-		thread := phpThreads[threadIndex]
-		thread.mainRequest = r
-
-		fc, ok := FromContext(r.Context())
-		if !ok {
-			panic(InvalidRequestError)
-		}
-		defer func() {
-			maybeCloseContext(fc)
-			thread.mainRequest = nil
-			thread.Unpin()
-		}()
-
-		if err := updateServerContext(thread, r, true, false); err != nil {
-			rejectRequest(fc.responseWriter, err.Error())
-			return true
-		}
-
-		// scriptFilename is freed in frankenphp_execute_script()
-		fc.exitStatus = C.frankenphp_execute_script(C.CString(fc.scriptFilename))
-		if fc.exitStatus < 0 {
-			panic(ScriptExecutionError)
-		}
-
-		// if the script has errored or timed out, make sure any pending worker requests are closed
-		if fc.exitStatus > 0 && thread.workerRequest != nil {
-			fc := thread.workerRequest.Context().Value(contextKey).(*FrankenPHPContext)
-			maybeCloseContext(fc)
-			thread.workerRequest = nil
-		}
-
-		return true
-	}
-}
-
 func maybeCloseContext(fc *FrankenPHPContext) {
 	fc.closed.Do(func() {
 		close(fc.done)
@@ -598,7 +528,7 @@ func go_apache_request_headers(threadIndex C.uintptr_t, hasActiveRequest bool) (
 
 	if !hasActiveRequest {
 		// worker mode, not handling a request
-		mfc := thread.mainRequest.Context().Value(contextKey).(*FrankenPHPContext)
+		mfc := thread.getActiveRequest().Context().Value(contextKey).(*FrankenPHPContext)
 
 		if c := mfc.logger.Check(zapcore.DebugLevel, "apache_request_headers() called in non-HTTP context"); c != nil {
 			c.Write(zap.String("worker", mfc.scriptFilename))
@@ -784,21 +714,11 @@ func freeArgs(argv []*C.char) {
 	}
 }
 
-func executePHPFunction(functionName string) {
+func executePHPFunction(functionName string) bool {
 	cFunctionName := C.CString(functionName)
 	defer C.free(unsafe.Pointer(cFunctionName))
 
-	success := C.frankenphp_execute_php_function(cFunctionName)
-
-	if success == 1 {
-		if c := logger.Check(zapcore.DebugLevel, "php function call successful"); c != nil {
-			c.Write(zap.String("function", functionName))
-		}
-	} else {
-		if c := logger.Check(zapcore.ErrorLevel, "php function call failed"); c != nil {
-			c.Write(zap.String("function", functionName))
-		}
-	}
+	return C.frankenphp_execute_php_function(cFunctionName) == 1
 }
 
 // Ensure that the request path does not contain null bytes

+ 2 - 1
frankenphp.h

@@ -40,7 +40,8 @@ typedef struct frankenphp_config {
 } frankenphp_config;
 frankenphp_config frankenphp_get_config();
 
-int frankenphp_init(int num_threads);
+int frankenphp_new_main_thread(int num_threads);
+bool frankenphp_new_php_thread(uintptr_t thread_index);
 
 int frankenphp_update_server_context(
     bool create, bool has_main_request, bool has_active_request,

+ 13 - 18
frankenphp_arginfo.h

@@ -36,22 +36,17 @@ ZEND_FUNCTION(frankenphp_finish_request);
 ZEND_FUNCTION(frankenphp_request_headers);
 ZEND_FUNCTION(frankenphp_response_headers);
 
+// clang-format off
 static const zend_function_entry ext_functions[] = {
-    ZEND_FE(frankenphp_handle_request, arginfo_frankenphp_handle_request)
-        ZEND_FE(headers_send, arginfo_headers_send) ZEND_FE(
-            frankenphp_finish_request, arginfo_frankenphp_finish_request)
-            ZEND_FALIAS(fastcgi_finish_request, frankenphp_finish_request,
-                        arginfo_fastcgi_finish_request)
-                ZEND_FE(frankenphp_request_headers,
-                        arginfo_frankenphp_request_headers)
-                    ZEND_FALIAS(apache_request_headers,
-                                frankenphp_request_headers,
-                                arginfo_apache_request_headers)
-                        ZEND_FALIAS(getallheaders, frankenphp_request_headers,
-                                    arginfo_getallheaders)
-                            ZEND_FE(frankenphp_response_headers,
-                                    arginfo_frankenphp_response_headers)
-                                ZEND_FALIAS(apache_response_headers,
-                                            frankenphp_response_headers,
-                                            arginfo_apache_response_headers)
-                                    ZEND_FE_END};
+  ZEND_FE(frankenphp_handle_request, arginfo_frankenphp_handle_request)
+  ZEND_FE(headers_send, arginfo_headers_send)
+  ZEND_FE(frankenphp_finish_request, arginfo_frankenphp_finish_request)
+  ZEND_FALIAS(fastcgi_finish_request, frankenphp_finish_request, arginfo_fastcgi_finish_request)
+  ZEND_FE(frankenphp_request_headers, arginfo_frankenphp_request_headers)
+  ZEND_FALIAS(apache_request_headers, frankenphp_request_headers, arginfo_apache_request_headers)
+  ZEND_FALIAS(getallheaders, frankenphp_request_headers, arginfo_getallheaders)
+  ZEND_FE(frankenphp_response_headers, arginfo_frankenphp_response_headers)
+  ZEND_FALIAS(apache_response_headers, frankenphp_response_headers, arginfo_apache_response_headers)
+  ZEND_FE_END
+};
+// clang-format on

+ 17 - 0
frankenphp_test.go

@@ -592,6 +592,23 @@ func testFiberNoCgo(t *testing.T, opts *testOptions) {
 	}, opts)
 }
 
+func TestFiberBasic_module(t *testing.T) { testFiberBasic(t, &testOptions{}) }
+func TestFiberBasic_worker(t *testing.T) {
+	testFiberBasic(t, &testOptions{workerScript: "fiber-basic.php"})
+}
+func testFiberBasic(t *testing.T, opts *testOptions) {
+	runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) {
+		req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/fiber-basic.php?i=%d", i), nil)
+		w := httptest.NewRecorder()
+		handler(w, req)
+
+		resp := w.Result()
+		body, _ := io.ReadAll(resp.Body)
+
+		assert.Equal(t, string(body), fmt.Sprintf("Fiber %d", i))
+	}, opts)
+}
+
 func TestRequestHeaders_module(t *testing.T) { testRequestHeaders(t, &testOptions{}) }
 func TestRequestHeaders_worker(t *testing.T) {
 	testRequestHeaders(t, &testOptions{workerScript: "request-headers.php"})

+ 108 - 0
phpmainthread.go

@@ -0,0 +1,108 @@
+package frankenphp
+
+// #include "frankenphp.h"
+import "C"
+import (
+	"sync"
+
+	"go.uber.org/zap"
+)
+
+// represents the main PHP thread
+// the thread needs to keep running as long as all other threads are running
+type phpMainThread struct {
+	state      *threadState
+	done       chan struct{}
+	numThreads int
+}
+
+var (
+	phpThreads []*phpThread
+	mainThread *phpMainThread
+)
+
+// reserve a fixed number of PHP threads on the Go side
+func initPHPThreads(numThreads int) error {
+	mainThread = &phpMainThread{
+		state:      newThreadState(),
+		done:       make(chan struct{}),
+		numThreads: numThreads,
+	}
+	phpThreads = make([]*phpThread, numThreads)
+
+	if err := mainThread.start(); err != nil {
+		return err
+	}
+
+	// initialize all threads as inactive
+	for i := 0; i < numThreads; i++ {
+		phpThreads[i] = newPHPThread(i)
+		convertToInactiveThread(phpThreads[i])
+	}
+
+	// start the underlying C threads
+	ready := sync.WaitGroup{}
+	ready.Add(numThreads)
+	for _, thread := range phpThreads {
+		go func() {
+			if !C.frankenphp_new_php_thread(C.uintptr_t(thread.threadIndex)) {
+				logger.Panic("unable to create thread", zap.Int("threadIndex", thread.threadIndex))
+			}
+			thread.state.waitFor(stateInactive)
+			ready.Done()
+		}()
+	}
+	ready.Wait()
+
+	return nil
+}
+
+func drainPHPThreads() {
+	doneWG := sync.WaitGroup{}
+	doneWG.Add(len(phpThreads))
+	for _, thread := range phpThreads {
+		thread.handlerMu.Lock()
+		_ = thread.state.requestSafeStateChange(stateShuttingDown)
+		close(thread.drainChan)
+	}
+	close(mainThread.done)
+	for _, thread := range phpThreads {
+		go func(thread *phpThread) {
+			thread.state.waitFor(stateDone)
+			thread.handlerMu.Unlock()
+			doneWG.Done()
+		}(thread)
+	}
+	doneWG.Wait()
+	mainThread.state.set(stateShuttingDown)
+	mainThread.state.waitFor(stateDone)
+	phpThreads = nil
+}
+
+func (mainThread *phpMainThread) start() error {
+	if C.frankenphp_new_main_thread(C.int(mainThread.numThreads)) != 0 {
+		return MainThreadCreationError
+	}
+	mainThread.state.waitFor(stateReady)
+	return nil
+}
+
+func getInactivePHPThread() *phpThread {
+	for _, thread := range phpThreads {
+		if thread.state.is(stateInactive) {
+			return thread
+		}
+	}
+	panic("not enough threads reserved")
+}
+
+//export go_frankenphp_main_thread_is_ready
+func go_frankenphp_main_thread_is_ready() {
+	mainThread.state.set(stateReady)
+	mainThread.state.waitFor(stateShuttingDown)
+}
+
+//export go_frankenphp_shutdown_main_thread
+func go_frankenphp_shutdown_main_thread() {
+	mainThread.state.set(stateDone)
+}

+ 161 - 0
phpmainthread_test.go

@@ -0,0 +1,161 @@
+package frankenphp
+
+import (
+	"io"
+	"math/rand/v2"
+	"net/http/httptest"
+	"path/filepath"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/zap"
+)
+
+var testDataPath, _ = filepath.Abs("./testdata")
+
+func TestStartAndStopTheMainThreadWithOneInactiveThread(t *testing.T) {
+	logger = zap.NewNop()                // the logger needs to not be nil
+	assert.NoError(t, initPHPThreads(1)) // reserve 1 thread
+
+	assert.Len(t, phpThreads, 1)
+	assert.Equal(t, 0, phpThreads[0].threadIndex)
+	assert.True(t, phpThreads[0].state.is(stateInactive))
+
+	drainPHPThreads()
+	assert.Nil(t, phpThreads)
+}
+
+func TestTransitionRegularThreadToWorkerThread(t *testing.T) {
+	logger = zap.NewNop()
+	assert.NoError(t, initPHPThreads(1))
+
+	// transition to regular thread
+	convertToRegularThread(phpThreads[0])
+	assert.IsType(t, &regularThread{}, phpThreads[0].handler)
+
+	// transition to worker thread
+	worker := getDummyWorker("transition-worker-1.php")
+	convertToWorkerThread(phpThreads[0], worker)
+	assert.IsType(t, &workerThread{}, phpThreads[0].handler)
+	assert.Len(t, worker.threads, 1)
+
+	// transition back to inactive thread
+	convertToInactiveThread(phpThreads[0])
+	assert.IsType(t, &inactiveThread{}, phpThreads[0].handler)
+	assert.Len(t, worker.threads, 0)
+
+	drainPHPThreads()
+	assert.Nil(t, phpThreads)
+}
+
+func TestTransitionAThreadBetween2DifferentWorkers(t *testing.T) {
+	logger = zap.NewNop()
+	assert.NoError(t, initPHPThreads(1))
+	firstWorker := getDummyWorker("transition-worker-1.php")
+	secondWorker := getDummyWorker("transition-worker-2.php")
+
+	// convert to first worker thread
+	convertToWorkerThread(phpThreads[0], firstWorker)
+	firstHandler := phpThreads[0].handler.(*workerThread)
+	assert.Same(t, firstWorker, firstHandler.worker)
+	assert.Len(t, firstWorker.threads, 1)
+	assert.Len(t, secondWorker.threads, 0)
+
+	// convert to second worker thread
+	convertToWorkerThread(phpThreads[0], secondWorker)
+	secondHandler := phpThreads[0].handler.(*workerThread)
+	assert.Same(t, secondWorker, secondHandler.worker)
+	assert.Len(t, firstWorker.threads, 0)
+	assert.Len(t, secondWorker.threads, 1)
+
+	drainPHPThreads()
+	assert.Nil(t, phpThreads)
+}
+
+func TestTransitionThreadsWhileDoingRequests(t *testing.T) {
+	numThreads := 10
+	numRequestsPerThread := 100
+	isRunning := atomic.Bool{}
+	isRunning.Store(true)
+	wg := sync.WaitGroup{}
+	worker1Path := testDataPath + "/transition-worker-1.php"
+	worker2Path := testDataPath + "/transition-worker-2.php"
+
+	assert.NoError(t, Init(
+		WithNumThreads(numThreads),
+		WithWorkers(worker1Path, 1, map[string]string{"ENV1": "foo"}, []string{}),
+		WithWorkers(worker2Path, 1, map[string]string{"ENV1": "foo"}, []string{}),
+		WithLogger(zap.NewNop()),
+	))
+
+	// randomly transition threads between regular, inactive and 2 worker threads
+	go func() {
+		for {
+			for i := 0; i < numThreads; i++ {
+				switch rand.IntN(4) {
+				case 0:
+					convertToRegularThread(phpThreads[i])
+				case 1:
+					convertToWorkerThread(phpThreads[i], workers[worker1Path])
+				case 2:
+					convertToWorkerThread(phpThreads[i], workers[worker2Path])
+				case 3:
+					convertToInactiveThread(phpThreads[i])
+				}
+				time.Sleep(time.Millisecond)
+				if !isRunning.Load() {
+					return
+				}
+			}
+		}
+	}()
+
+	// randomly do requests to the 3 endpoints
+	wg.Add(numThreads)
+	for i := 0; i < numThreads; i++ {
+		go func(i int) {
+			for j := 0; j < numRequestsPerThread; j++ {
+				switch rand.IntN(3) {
+				case 0:
+					assertRequestBody(t, "http://localhost/transition-worker-1.php", "Hello from worker 1")
+				case 1:
+					assertRequestBody(t, "http://localhost/transition-worker-2.php", "Hello from worker 2")
+				case 2:
+					assertRequestBody(t, "http://localhost/transition-regular.php", "Hello from regular thread")
+				}
+			}
+			wg.Done()
+		}(i)
+	}
+
+	wg.Wait()
+	isRunning.Store(false)
+	Shutdown()
+}
+
+func getDummyWorker(fileName string) *worker {
+	if workers == nil {
+		workers = make(map[string]*worker)
+	}
+	worker, _ := newWorker(workerOpt{
+		fileName: testDataPath + "/" + fileName,
+		num:      1,
+	})
+	return worker
+}
+
+func assertRequestBody(t *testing.T, url string, expected string) {
+	r := httptest.NewRequest("GET", url, nil)
+	w := httptest.NewRecorder()
+
+	req, err := NewRequestWithContext(r, WithRequestDocumentRoot(testDataPath, false))
+	assert.NoError(t, err)
+	err = ServeHTTP(w, req)
+	assert.NoError(t, err)
+	resp := w.Result()
+	body, _ := io.ReadAll(resp.Body)
+	assert.Equal(t, expected, string(body))
+}

+ 80 - 17
phpthread.go

@@ -1,7 +1,6 @@
 package frankenphp
 
-// #include <stdint.h>
-// #include <php_variables.h>
+// #include "frankenphp.h"
 import "C"
 import (
 	"net/http"
@@ -10,32 +9,65 @@ import (
 	"unsafe"
 )
 
-var phpThreads []*phpThread
-
+// representation of the actual underlying PHP thread
+// identified by the index in the phpThreads slice
 type phpThread struct {
 	runtime.Pinner
 
-	mainRequest       *http.Request
-	workerRequest     *http.Request
-	worker            *worker
-	requestChan       chan *http.Request
+	threadIndex       int
 	knownVariableKeys map[string]*C.zend_string
-	readiedOnce       sync.Once
+	requestChan       chan *http.Request
+	drainChan         chan struct{}
+	handlerMu         *sync.Mutex
+	handler           threadHandler
+	state             *threadState
+}
+
+// interface that defines how the callbacks from the C thread should be handled
+type threadHandler interface {
+	beforeScriptExecution() string
+	afterScriptExecution(exitStatus int)
+	getActiveRequest() *http.Request
 }
 
-func initPHPThreads(numThreads int) {
-	phpThreads = make([]*phpThread, 0, numThreads)
-	for i := 0; i < numThreads; i++ {
-		phpThreads = append(phpThreads, &phpThread{})
+func newPHPThread(threadIndex int) *phpThread {
+	return &phpThread{
+		threadIndex: threadIndex,
+		drainChan:   make(chan struct{}),
+		requestChan: make(chan *http.Request),
+		handlerMu:   &sync.Mutex{},
+		state:       newThreadState(),
 	}
 }
 
-func (thread *phpThread) getActiveRequest() *http.Request {
-	if thread.workerRequest != nil {
-		return thread.workerRequest
+// change the thread handler safely
+// must be called from outside of the PHP thread
+func (thread *phpThread) setHandler(handler threadHandler) {
+	logger.Debug("setHandler")
+	thread.handlerMu.Lock()
+	defer thread.handlerMu.Unlock()
+	if !thread.state.requestSafeStateChange(stateTransitionRequested) {
+		// no state change allowed == shutdown
+		return
 	}
+	close(thread.drainChan)
+	thread.state.waitFor(stateTransitionInProgress)
+	thread.handler = handler
+	thread.drainChan = make(chan struct{})
+	thread.state.set(stateTransitionComplete)
+}
 
-	return thread.mainRequest
+// transition to a new handler safely
+// is triggered by setHandler and executed on the PHP thread
+func (thread *phpThread) transitionToNewHandler() string {
+	thread.state.set(stateTransitionInProgress)
+	thread.state.waitFor(stateTransitionComplete)
+	// execute beforeScriptExecution of the new handler
+	return thread.handler.beforeScriptExecution()
+}
+
+func (thread *phpThread) getActiveRequest() *http.Request {
+	return thread.handler.getActiveRequest()
 }
 
 // Pin a string that is not null-terminated
@@ -50,3 +82,34 @@ func (thread *phpThread) pinString(s string) *C.char {
 func (thread *phpThread) pinCString(s string) *C.char {
 	return thread.pinString(s + "\x00")
 }
+
+//export go_frankenphp_before_script_execution
+func go_frankenphp_before_script_execution(threadIndex C.uintptr_t) *C.char {
+	thread := phpThreads[threadIndex]
+	scriptName := thread.handler.beforeScriptExecution()
+
+	// if no scriptName is passed, shut down
+	if scriptName == "" {
+		return nil
+	}
+	// return the name of the PHP script that should be executed
+	return thread.pinCString(scriptName)
+}
+
+//export go_frankenphp_after_script_execution
+func go_frankenphp_after_script_execution(threadIndex C.uintptr_t, exitStatus C.int) {
+	thread := phpThreads[threadIndex]
+	if exitStatus < 0 {
+		panic(ScriptExecutionError)
+	}
+	thread.handler.afterScriptExecution(int(exitStatus))
+
+	// unpin all memory used during script execution
+	thread.Unpin()
+}
+
+//export go_frankenphp_on_thread_shutdown
+func go_frankenphp_on_thread_shutdown(threadIndex C.uintptr_t) {
+	phpThreads[threadIndex].Unpin()
+	phpThreads[threadIndex].state.set(stateDone)
+}

+ 0 - 40
phpthread_test.go

@@ -1,40 +0,0 @@
-package frankenphp
-
-import (
-	"net/http"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestInitializeTwoPhpThreadsWithoutRequests(t *testing.T) {
-	initPHPThreads(2)
-
-	assert.Len(t, phpThreads, 2)
-	assert.NotNil(t, phpThreads[0])
-	assert.NotNil(t, phpThreads[1])
-	assert.Nil(t, phpThreads[0].mainRequest)
-	assert.Nil(t, phpThreads[0].workerRequest)
-}
-
-func TestMainRequestIsActiveRequest(t *testing.T) {
-	mainRequest := &http.Request{}
-	initPHPThreads(1)
-	thread := phpThreads[0]
-
-	thread.mainRequest = mainRequest
-
-	assert.Equal(t, mainRequest, thread.getActiveRequest())
-}
-
-func TestWorkerRequestIsActiveRequest(t *testing.T) {
-	mainRequest := &http.Request{}
-	workerRequest := &http.Request{}
-	initPHPThreads(1)
-	thread := phpThreads[0]
-
-	thread.mainRequest = mainRequest
-	thread.workerRequest = workerRequest
-
-	assert.Equal(t, workerRequest, thread.getActiveRequest())
-}

Некоторые файлы не были показаны из-за большого количества измененных файлов