Browse Source

chore: improve tests and add missing file (#13)

Kévin Dunglas 2 years ago
parent
commit
5012ac30cd
2 changed files with 288 additions and 0 deletions
  1. 275 0
      recorder_test.go
  2. 13 0
      testdata/early-hints.php

+ 275 - 0
recorder_test.go

@@ -0,0 +1,275 @@
+// Remove me when https://github.com/golang/go/pull/56151 will be merged
+
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package frankenphp_test
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptrace"
+	"net/textproto"
+	"strconv"
+	"strings"
+
+	"golang.org/x/net/http/httpguts"
+)
+
+// ResponseRecorder is an implementation of http.ResponseWriter that
+// records its mutations for later inspection in tests.
+type ResponseRecorder struct {
+	// Code is the HTTP response code set by WriteHeader.
+	//
+	// Note that if a Handler never calls WriteHeader or Write,
+	// this might end up being 0, rather than the implicit
+	// http.StatusOK. To get the implicit value, use the Result
+	// method.
+	Code int
+
+	// HeaderMap contains the headers explicitly set by the Handler.
+	// It is an internal detail.
+	//
+	// Deprecated: HeaderMap exists for historical compatibility
+	// and should not be used. To access the headers returned by a handler,
+	// use the Response.Header map as returned by the Result method.
+	HeaderMap http.Header
+
+	// Body is the buffer to which the Handler's Write calls are sent.
+	// If nil, the Writes are silently discarded.
+	Body *bytes.Buffer
+
+	// Flushed is whether the Handler called Flush.
+	Flushed bool
+
+	// ClientTrace is used to trace 1XX responses
+	ClientTrace *httptrace.ClientTrace
+
+	result      *http.Response // cache of Result's return value
+	snapHeader  http.Header    // snapshot of HeaderMap at first Write
+	wroteHeader bool
+}
+
+// NewRecorder returns an initialized ResponseRecorder.
+func NewRecorder() *ResponseRecorder {
+	return &ResponseRecorder{
+		HeaderMap: make(http.Header),
+		Body:      new(bytes.Buffer),
+		Code:      200,
+	}
+}
+
+// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
+// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
+const DefaultRemoteAddr = "1.2.3.4"
+
+// Header implements http.ResponseWriter. It returns the response
+// headers to mutate within a handler. To test the headers that were
+// written after a handler completes, use the Result method and see
+// the returned Response value's Header.
+func (rw *ResponseRecorder) Header() http.Header {
+	m := rw.HeaderMap
+	if m == nil {
+		m = make(http.Header)
+		rw.HeaderMap = m
+	}
+	return m
+}
+
+// writeHeader writes a header if it was not written yet and
+// detects Content-Type if needed.
+//
+// bytes or str are the beginning of the response body.
+// We pass both to avoid unnecessarily generate garbage
+// in rw.WriteString which was created for performance reasons.
+// Non-nil bytes win.
+func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
+	if rw.wroteHeader {
+		return
+	}
+	if len(str) > 512 {
+		str = str[:512]
+	}
+
+	m := rw.Header()
+
+	_, hasType := m["Content-Type"]
+	hasTE := m.Get("Transfer-Encoding") != ""
+	if !hasType && !hasTE {
+		if b == nil {
+			b = []byte(str)
+		}
+		m.Set("Content-Type", http.DetectContentType(b))
+	}
+
+	rw.WriteHeader(200)
+}
+
+// Write implements http.ResponseWriter. The data in buf is written to
+// rw.Body, if not nil.
+func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+	rw.writeHeader(buf, "")
+	if rw.Body != nil {
+		rw.Body.Write(buf)
+	}
+	return len(buf), nil
+}
+
+// WriteString implements io.StringWriter. The data in str is written
+// to rw.Body, if not nil.
+func (rw *ResponseRecorder) WriteString(str string) (int, error) {
+	rw.writeHeader(nil, str)
+	if rw.Body != nil {
+		rw.Body.WriteString(str)
+	}
+	return len(str), nil
+}
+
+func checkWriteHeaderCode(code int) {
+	// Issue 22880: require valid WriteHeader status codes.
+	// For now we only enforce that it's three digits.
+	// In the future we might block things over 599 (600 and above aren't defined
+	// at https://httpwg.org/specs/rfc7231.html#status.codes)
+	// and we might block under 200 (once we have more mature 1xx support).
+	// But for now any three digits.
+	//
+	// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
+	// no equivalent bogus thing we can realistically send in HTTP/2,
+	// so we'll consistently panic instead and help people find their bugs
+	// early. (We can't return an error from WriteHeader even if we wanted to.)
+	if code < 100 || code > 999 {
+		panic(fmt.Sprintf("invalid WriteHeader code %v", code))
+	}
+}
+
+// WriteHeader implements http.ResponseWriter.
+func (rw *ResponseRecorder) WriteHeader(code int) {
+	if rw.wroteHeader {
+		return
+	}
+
+	checkWriteHeaderCode(code)
+
+	if rw.ClientTrace != nil && code >= 100 && code < 200 {
+		if code == 100 {
+			rw.ClientTrace.Got100Continue()
+		}
+		// treat 101 as a terminal status, see issue 26161
+		if code != http.StatusSwitchingProtocols {
+			if err := rw.ClientTrace.Got1xxResponse(code, textproto.MIMEHeader(rw.HeaderMap)); err != nil {
+				panic(err)
+			}
+			return
+		}
+	}
+
+	rw.Code = code
+	rw.wroteHeader = true
+	if rw.HeaderMap == nil {
+		rw.HeaderMap = make(http.Header)
+	}
+	rw.snapHeader = rw.HeaderMap.Clone()
+}
+
+// Flush implements http.Flusher. To test whether Flush was
+// called, see rw.Flushed.
+func (rw *ResponseRecorder) Flush() {
+	if !rw.wroteHeader {
+		rw.WriteHeader(200)
+	}
+	rw.Flushed = true
+}
+
+// Result returns the response generated by the handler.
+//
+// The returned Response will have at least its StatusCode,
+// Header, Body, and optionally Trailer populated.
+// More fields may be populated in the future, so callers should
+// not DeepEqual the result in tests.
+//
+// The Response.Header is a snapshot of the headers at the time of the
+// first write call, or at the time of this call, if the handler never
+// did a write.
+//
+// The Response.Body is guaranteed to be non-nil and Body.Read call is
+// guaranteed to not return any error other than io.EOF.
+//
+// Result must only be called after the handler has finished running.
+func (rw *ResponseRecorder) Result() *http.Response {
+	if rw.result != nil {
+		return rw.result
+	}
+	if rw.snapHeader == nil {
+		rw.snapHeader = rw.HeaderMap.Clone()
+	}
+	res := &http.Response{
+		Proto:      "HTTP/1.1",
+		ProtoMajor: 1,
+		ProtoMinor: 1,
+		StatusCode: rw.Code,
+		Header:     rw.snapHeader,
+	}
+	rw.result = res
+	if res.StatusCode == 0 {
+		res.StatusCode = 200
+	}
+	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
+	if rw.Body != nil {
+		res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
+	} else {
+		res.Body = http.NoBody
+	}
+	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
+
+	if trailers, ok := rw.snapHeader["Trailer"]; ok {
+		res.Trailer = make(http.Header, len(trailers))
+		for _, k := range trailers {
+			for _, k := range strings.Split(k, ",") {
+				k = http.CanonicalHeaderKey(textproto.TrimString(k))
+				if !httpguts.ValidTrailerHeader(k) {
+					// Ignore since forbidden by RFC 7230, section 4.1.2.
+					continue
+				}
+				vv, ok := rw.HeaderMap[k]
+				if !ok {
+					continue
+				}
+				vv2 := make([]string, len(vv))
+				copy(vv2, vv)
+				res.Trailer[k] = vv2
+			}
+		}
+	}
+	for k, vv := range rw.HeaderMap {
+		if !strings.HasPrefix(k, http.TrailerPrefix) {
+			continue
+		}
+		if res.Trailer == nil {
+			res.Trailer = make(http.Header)
+		}
+		for _, v := range vv {
+			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
+		}
+	}
+	return res
+}
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+//
+// This a modified version of same function found in net/http/transfer.go. This
+// one just ignores an invalid header.
+func parseContentLength(cl string) int64 {
+	cl = textproto.TrimString(cl)
+	if cl == "" {
+		return -1
+	}
+	n, err := strconv.ParseUint(cl, 10, 63)
+	if err != nil {
+		return -1
+	}
+	return int64(n)
+}

+ 13 - 0
testdata/early-hints.php

@@ -0,0 +1,13 @@
+<?php
+
+require_once __DIR__.'/_executor.php';
+
+return function () {
+    header('Link: </style.css>; rel=preload; as=style');
+    header("Request: {$_GET['i']}");
+    headers_send(103);
+
+    header_remove('Link');
+
+    echo 'Hello';
+};