Browse Source

feat: handle aborted connection (#95)

* Handle aborted connection

* Handle when writing as well

* return bytes written

* optimize return

* remove goroutine

* fix style

* Add tests

* add missing newline
Rob Landers 2 years ago
parent
commit
3abda4fbb6
4 changed files with 144 additions and 8 deletions
  1. 8 2
      frankenphp.c
  2. 21 6
      frankenphp.go
  3. 98 0
      frankenphp_test.go
  4. 17 0
      testdata/connectionStatusLog.php

+ 8 - 2
frankenphp.c

@@ -405,7 +405,13 @@ static size_t frankenphp_ub_write(const char *str, size_t str_length)
 		return 0;
 	}
 
-	return go_ub_write(ctx->current_request ? ctx->current_request : ctx->main_request, (char *) str, str_length);
+	struct go_ub_write_return result = go_ub_write(ctx->current_request ? ctx->current_request : ctx->main_request, (char *) str, str_length);
+
+	if (result.r1) {
+		php_handle_aborted_connection();
+	}
+
+	return result.r0;
 }
 
 static int frankenphp_send_headers(sapi_headers_struct *sapi_headers)
@@ -445,7 +451,7 @@ static void frankenphp_sapi_flush(void *server_context)
 
 	if (!ctx || ctx->current_request == 0) return;
 
-	go_sapi_flush(ctx->current_request);
+	if (go_sapi_flush(ctx->current_request)) php_handle_aborted_connection();
 }
 
 static size_t frankenphp_read_post(char *buffer, size_t count_bytes)

+ 21 - 6
frankenphp.go

@@ -124,13 +124,22 @@ type FrankenPHPContext struct {
 	populated    bool
 	authPassword string
 
-	// Whether the request is already closed
+	// Whether the request is already closed by us
 	closed sync.Once
 
 	responseWriter http.ResponseWriter
 	done           chan interface{}
 }
 
+func clientHasClosed(r *http.Request) bool {
+	select {
+	case <-r.Context().Done():
+		return true
+	default:
+		return false
+	}
+}
+
 // NewRequestWithContext creates a new FrankenPHP request context.
 func NewRequestWithContext(r *http.Request, documentRoot string, l *zap.Logger) *http.Request {
 	if l == nil {
@@ -407,7 +416,7 @@ func go_execute_script(rh unsafe.Pointer) {
 }
 
 //export go_ub_write
-func go_ub_write(rh C.uintptr_t, cString *C.char, length C.int) C.size_t {
+func go_ub_write(rh C.uintptr_t, cString *C.char, length C.int) (C.size_t, C.bool) {
 	r := cgo.Handle(rh).Value().(*http.Request)
 	fc, _ := FromContext(r.Context())
 
@@ -426,7 +435,7 @@ func go_ub_write(rh C.uintptr_t, cString *C.char, length C.int) C.size_t {
 		fc.Logger.Info(writer.(*bytes.Buffer).String())
 	}
 
-	return C.size_t(i)
+	return C.size_t(i), C.bool(clientHasClosed(r))
 }
 
 //export go_register_variables
@@ -486,20 +495,26 @@ func go_write_header(rh C.uintptr_t, status C.int) {
 }
 
 //export go_sapi_flush
-func go_sapi_flush(rh C.uintptr_t) {
+func go_sapi_flush(rh C.uintptr_t) bool {
 	r := cgo.Handle(rh).Value().(*http.Request)
 	fc := r.Context().Value(contextKey).(*FrankenPHPContext)
 
 	if fc.responseWriter == nil {
-		return
+		return true
 	}
 
 	flusher, ok := fc.responseWriter.(http.Flusher)
 	if !ok {
-		return
+		return true
+	}
+
+	if clientHasClosed(r) {
+		return true
 	}
 
 	flusher.Flush()
+
+	return false
 }
 
 //export go_read_post

+ 98 - 0
frankenphp_test.go

@@ -1,6 +1,7 @@
 package frankenphp_test
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"log"
@@ -15,6 +16,7 @@ import (
 	"strings"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/dunglas/frankenphp"
 	"github.com/stretchr/testify/assert"
@@ -388,6 +390,102 @@ func testLog(t *testing.T, opts *testOptions) {
 	}, opts)
 }
 
+func TestConnectionAbortNormal_module(t *testing.T) { testConnectionAbortNormal(t, &testOptions{}) }
+func TestConnectionAbortNormal_worker(t *testing.T) {
+	testConnectionAbortNormal(t, &testOptions{workerScript: "connectionStatusLog.php"})
+}
+func testConnectionAbortNormal(t *testing.T, opts *testOptions) {
+	logger, logs := observer.New(zap.InfoLevel)
+	opts.logger = zap.New(logger)
+
+	runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) {
+		req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/connectionStatusLog.php?i=%d", i), nil)
+		w := httptest.NewRecorder()
+
+		ctx, cancel := context.WithCancel(req.Context())
+		req = req.WithContext(ctx)
+		cancel()
+		handler(w, req)
+
+		// todo: remove conditions on wall clock to avoid race conditions/flakiness
+		time.Sleep(1000 * time.Microsecond)
+		var found bool
+		searched := fmt.Sprintf("request %d: 1", i)
+		for _, entry := range logs.All() {
+			if entry.Message == searched {
+				found = true
+				break
+			}
+		}
+
+		assert.True(t, found)
+	}, opts)
+}
+
+func TestConnectionAbortFlush_module(t *testing.T) { testConnectionAbortFlush(t, &testOptions{}) }
+func TestConnectionAbortFlush_worker(t *testing.T) {
+	testConnectionAbortFlush(t, &testOptions{workerScript: "connectionStatusLog.php"})
+}
+func testConnectionAbortFlush(t *testing.T, opts *testOptions) {
+	logger, logs := observer.New(zap.InfoLevel)
+	opts.logger = zap.New(logger)
+
+	runTest(t, func(handler func(w http.ResponseWriter, response *http.Request), _ *httptest.Server, i int) {
+		req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/connectionStatusLog.php?i=%d&flush", i), nil)
+		w := httptest.NewRecorder()
+
+		ctx, cancel := context.WithCancel(req.Context())
+		req = req.WithContext(ctx)
+		cancel()
+		handler(w, req)
+
+		// todo: remove conditions on wall clock to avoid race conditions/flakiness
+		time.Sleep(1000 * time.Microsecond)
+		var found bool
+		searched := fmt.Sprintf("request %d: 1", i)
+		for _, entry := range logs.All() {
+			if entry.Message == searched {
+				found = true
+				break
+			}
+		}
+
+		assert.True(t, found)
+	}, opts)
+}
+
+func TestConnectionAbortFinish_module(t *testing.T) { testConnectionAbortFinish(t, &testOptions{}) }
+func TestConnectionAbortFinish_worker(t *testing.T) {
+	testConnectionAbortFinish(t, &testOptions{workerScript: "connectionStatusLog.php"})
+}
+func testConnectionAbortFinish(t *testing.T, opts *testOptions) {
+	logger, logs := observer.New(zap.InfoLevel)
+	opts.logger = zap.New(logger)
+
+	runTest(t, func(handler func(w http.ResponseWriter, response *http.Request), _ *httptest.Server, i int) {
+		req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/connectionStatusLog.php?i=%d&finish", i), nil)
+		w := httptest.NewRecorder()
+
+		ctx, cancel := context.WithCancel(req.Context())
+		req = req.WithContext(ctx)
+		cancel()
+		handler(w, req)
+
+		// todo: remove conditions on wall clock to avoid race conditions/flakiness
+		time.Sleep(1000 * time.Microsecond)
+		var found bool
+		searched := fmt.Sprintf("request %d: 0", i)
+		for _, entry := range logs.All() {
+			if entry.Message == searched {
+				found = true
+				break
+			}
+		}
+
+		assert.True(t, found)
+	}, opts)
+}
+
 func TestException_module(t *testing.T) { testException(t, &testOptions{}) }
 func TestException_worker(t *testing.T) {
 	testException(t, &testOptions{workerScript: "exception.php"})

+ 17 - 0
testdata/connectionStatusLog.php

@@ -0,0 +1,17 @@
+<?php
+
+ignore_user_abort(true);
+
+require_once __DIR__.'/_executor.php';
+
+return function () {
+	if(isset($_GET['finish'])) {
+		frankenphp_finish_request();
+	}
+	echo 'hi';
+	if(isset($_GET['flush'])) {
+		flush();
+	}
+	$status = (string) connection_status();
+	error_log("request {$_GET['i']}: " . $status);
+};