123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275 |
- package frankenphp_test
- import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "net/http/httptrace"
- "net/textproto"
- "strconv"
- "strings"
- "golang.org/x/net/http/httpguts"
- )
// ResponseRecorder captures what a handler writes through the
// http.ResponseWriter-style methods defined on it (Header, Write,
// WriteString, WriteHeader, Flush) so the outcome can be inspected
// afterwards, either via the exported fields or via Result.
type ResponseRecorder struct {
	// Code is the status code recorded by WriteHeader.
	// NewRecorder initializes it to 200.
	Code int

	// HeaderMap is the header map handed out by Header. It is allocated
	// lazily when nil.
	HeaderMap http.Header

	// Body receives the bytes passed to Write and WriteString.
	// When nil, written data is discarded.
	Body *bytes.Buffer

	// Flushed is set to true by Flush.
	Flushed bool

	// ClientTrace, when non-nil, is consulted by WriteHeader for status
	// codes in the 1xx range.
	ClientTrace *httptrace.ClientTrace

	// result caches the response built by the first call to Result.
	result *http.Response
	// snapHeader is a clone of HeaderMap taken when the header is
	// recorded (or by Result, if no header was recorded before).
	snapHeader http.Header
	// wroteHeader reports whether WriteHeader has recorded a status.
	wroteHeader bool
}
- func NewRecorder() *ResponseRecorder {
- return &ResponseRecorder{
- HeaderMap: make(http.Header),
- Body: new(bytes.Buffer),
- Code: 200,
- }
- }
// DefaultRemoteAddr is a fixed placeholder remote address.
// NOTE(review): nothing in this file references it; presumably kept for
// parity with net/http/httptest — confirm before removing.
const (
	DefaultRemoteAddr = "1.2.3.4"
)
- func (rw *ResponseRecorder) Header() http.Header {
- m := rw.HeaderMap
- if m == nil {
- m = make(http.Header)
- rw.HeaderMap = m
- }
- return m
- }
- func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
- if rw.wroteHeader {
- return
- }
- if len(str) > 512 {
- str = str[:512]
- }
- m := rw.Header()
- _, hasType := m["Content-Type"]
- hasTE := m.Get("Transfer-Encoding") != ""
- if !hasType && !hasTE {
- if b == nil {
- b = []byte(str)
- }
- m.Set("Content-Type", http.DetectContentType(b))
- }
- rw.WriteHeader(200)
- }
- func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
- rw.writeHeader(buf, "")
- if rw.Body != nil {
- rw.Body.Write(buf)
- }
- return len(buf), nil
- }
- func (rw *ResponseRecorder) WriteString(str string) (int, error) {
- rw.writeHeader(nil, str)
- if rw.Body != nil {
- rw.Body.WriteString(str)
- }
- return len(str), nil
- }
// checkWriteHeaderCode panics when code is not a three-digit value, i.e.
// when it lies outside the inclusive range 100..999. Any code inside that
// range is accepted without further validation.
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
- func (rw *ResponseRecorder) WriteHeader(code int) {
- if rw.wroteHeader {
- return
- }
- checkWriteHeaderCode(code)
- if rw.ClientTrace != nil && code >= 100 && code < 200 {
- if code == 100 {
- rw.ClientTrace.Got100Continue()
- }
-
- if code != http.StatusSwitchingProtocols {
- if err := rw.ClientTrace.Got1xxResponse(code, textproto.MIMEHeader(rw.HeaderMap)); err != nil {
- panic(err)
- }
- return
- }
- }
- rw.Code = code
- rw.wroteHeader = true
- if rw.HeaderMap == nil {
- rw.HeaderMap = make(http.Header)
- }
- rw.snapHeader = rw.HeaderMap.Clone()
- }
- func (rw *ResponseRecorder) Flush() {
- if !rw.wroteHeader {
- rw.WriteHeader(200)
- }
- rw.Flushed = true
- }
- func (rw *ResponseRecorder) Result() *http.Response {
- if rw.result != nil {
- return rw.result
- }
- if rw.snapHeader == nil {
- rw.snapHeader = rw.HeaderMap.Clone()
- }
- res := &http.Response{
- Proto: "HTTP/1.1",
- ProtoMajor: 1,
- ProtoMinor: 1,
- StatusCode: rw.Code,
- Header: rw.snapHeader,
- }
- rw.result = res
- if res.StatusCode == 0 {
- res.StatusCode = 200
- }
- res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
- if rw.Body != nil {
- res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
- } else {
- res.Body = http.NoBody
- }
- res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
- if trailers, ok := rw.snapHeader["Trailer"]; ok {
- res.Trailer = make(http.Header, len(trailers))
- for _, k := range trailers {
- for _, k := range strings.Split(k, ",") {
- k = http.CanonicalHeaderKey(textproto.TrimString(k))
- if !httpguts.ValidTrailerHeader(k) {
-
- continue
- }
- vv, ok := rw.HeaderMap[k]
- if !ok {
- continue
- }
- vv2 := make([]string, len(vv))
- copy(vv2, vv)
- res.Trailer[k] = vv2
- }
- }
- }
- for k, vv := range rw.HeaderMap {
- if !strings.HasPrefix(k, http.TrailerPrefix) {
- continue
- }
- if res.Trailer == nil {
- res.Trailer = make(http.Header)
- }
- for _, v := range vv {
- res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
- }
- }
- return res
- }
// parseContentLength converts a Content-Length header value into a length.
// Surrounding whitespace is trimmed with textproto.TrimString. It returns
// -1 when the trimmed value is empty or is not a base-10 unsigned integer
// that fits in 63 bits; otherwise it returns the parsed value.
func parseContentLength(cl string) int64 {
	trimmed := textproto.TrimString(cl)
	if trimmed == "" {
		return -1
	}
	if v, err := strconv.ParseUint(trimmed, 10, 63); err == nil {
		return int64(v)
	}
	return -1
}
|