http_util.go 10 KB


  1. /*
  2. *
  3. * Copyright 2014 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package transport
  19. import (
  20. "bufio"
  21. "encoding/base64"
  22. "errors"
  23. "fmt"
  24. "io"
  25. "math"
  26. "net"
  27. "net/http"
  28. "net/url"
  29. "strconv"
  30. "strings"
  31. "time"
  32. "unicode/utf8"
  33. "github.com/golang/protobuf/proto"
  34. "golang.org/x/net/http2"
  35. "golang.org/x/net/http2/hpack"
  36. spb "google.golang.org/genproto/googleapis/rpc/status"
  37. "google.golang.org/grpc/codes"
  38. "google.golang.org/grpc/status"
  39. )
  const (
  	// http2MaxFrameLen is the largest HTTP/2 frame payload, in bytes, that the
  	// framer will accept when reading: 16 KiB, the initial value of the HTTP/2
  	// SETTINGS_MAX_FRAME_SIZE setting.
  	http2MaxFrameLen = 16 * 1024

  	// http2InitHeaderTableSize is the initial HPACK header table size, in bytes
  	// (the SETTINGS_HEADER_TABLE_SIZE default), used for the header decoder.
  	// See https://httpwg.org/specs/rfc7540.html#SettingValues.
  	http2InitHeaderTableSize = 4 * 1024
  )
  46. var (
  47. clientPreface = []byte(http2.ClientPreface)
  48. http2ErrConvTab = map[http2.ErrCode]codes.Code{
  49. http2.ErrCodeNo: codes.Internal,
  50. http2.ErrCodeProtocol: codes.Internal,
  51. http2.ErrCodeInternal: codes.Internal,
  52. http2.ErrCodeFlowControl: codes.ResourceExhausted,
  53. http2.ErrCodeSettingsTimeout: codes.Internal,
  54. http2.ErrCodeStreamClosed: codes.Internal,
  55. http2.ErrCodeFrameSize: codes.Internal,
  56. http2.ErrCodeRefusedStream: codes.Unavailable,
  57. http2.ErrCodeCancel: codes.Canceled,
  58. http2.ErrCodeCompression: codes.Internal,
  59. http2.ErrCodeConnect: codes.Internal,
  60. http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
  61. http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
  62. http2.ErrCodeHTTP11Required: codes.Internal,
  63. }
  64. // HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
  65. HTTPStatusConvTab = map[int]codes.Code{
  66. // 400 Bad Request - INTERNAL.
  67. http.StatusBadRequest: codes.Internal,
  68. // 401 Unauthorized - UNAUTHENTICATED.
  69. http.StatusUnauthorized: codes.Unauthenticated,
  70. // 403 Forbidden - PERMISSION_DENIED.
  71. http.StatusForbidden: codes.PermissionDenied,
  72. // 404 Not Found - UNIMPLEMENTED.
  73. http.StatusNotFound: codes.Unimplemented,
  74. // 429 Too Many Requests - UNAVAILABLE.
  75. http.StatusTooManyRequests: codes.Unavailable,
  76. // 502 Bad Gateway - UNAVAILABLE.
  77. http.StatusBadGateway: codes.Unavailable,
  78. // 503 Service Unavailable - UNAVAILABLE.
  79. http.StatusServiceUnavailable: codes.Unavailable,
  80. // 504 Gateway timeout - UNAVAILABLE.
  81. http.StatusGatewayTimeout: codes.Unavailable,
  82. }
  83. )
  // isReservedHeader reports whether hdr is an HTTP/2 header reserved by the
  // gRPC protocol. Every header not reserved here is classified as
  // user-specified metadata.
  func isReservedHeader(hdr string) bool {
  	// Any name starting with ':' is an HTTP/2 pseudo-header.
  	if len(hdr) > 0 && hdr[0] == ':' {
  		return true
  	}
  	switch hdr {
  	case "content-type", "user-agent", "te":
  		return true
  	case "grpc-message-type", "grpc-encoding", "grpc-message", "grpc-status",
  		"grpc-timeout", "grpc-status-details-bin":
  		return true
  	}
  	// grpc-previous-rpc-attempts and grpc-retry-pushback-ms are "reserved" by
  	// the protocol too, but are intentionally not listed: their API works via
  	// metadata.
  	return false
  }
  // isWhitelistedHeader reports whether hdr should still be propagated into
  // user-visible metadata even though it is classified as reserved. Only the
  // ":authority" and "user-agent" headers qualify.
  func isWhitelistedHeader(hdr string) bool {
  	return hdr == ":authority" || hdr == "user-agent"
  }
  // binHdrSuffix marks metadata keys whose values are carried base64-encoded
  // on the wire.
  const binHdrSuffix = "-bin"

  // encodeBinHeader base64-encodes v with the standard alphabet and no padding.
  func encodeBinHeader(v []byte) string {
  	return base64.RawStdEncoding.EncodeToString(v)
  }

  // decodeBinHeader base64-decodes v, accepting both padded and unpadded
  // input. A length that is a multiple of 4 is decoded with the padded
  // encoding, which also covers unpadded input that needs no padding.
  func decodeBinHeader(v string) ([]byte, error) {
  	enc := base64.RawStdEncoding
  	if len(v)%4 == 0 {
  		enc = base64.StdEncoding
  	}
  	return enc.DecodeString(v)
  }

  // encodeMetadataHeader returns the wire form of metadata value v for key k.
  // Values of "-bin" keys are base64-encoded; all others pass through.
  func encodeMetadataHeader(k, v string) string {
  	if !strings.HasSuffix(k, binHdrSuffix) {
  		return v
  	}
  	return encodeBinHeader([]byte(v))
  }

  // decodeMetadataHeader reverses encodeMetadataHeader for key k. It returns the
  // base64 decoding error, if any, for "-bin" keys.
  func decodeMetadataHeader(k, v string) (string, error) {
  	if !strings.HasSuffix(k, binHdrSuffix) {
  		return v, nil
  	}
  	decoded, err := decodeBinHeader(v)
  	return string(decoded), err
  }
  143. func decodeGRPCStatusDetails(rawDetails string) (*status.Status, error) {
  144. v, err := decodeBinHeader(rawDetails)
  145. if err != nil {
  146. return nil, err
  147. }
  148. st := &spb.Status{}
  149. if err = proto.Unmarshal(v, st); err != nil {
  150. return nil, err
  151. }
  152. return status.FromProto(st), nil
  153. }
  // timeoutUnit is the single-character unit suffix of a grpc-timeout value,
  // e.g. the 'S' in "10S".
  type timeoutUnit uint8

  const (
  	hour        timeoutUnit = 'H'
  	minute      timeoutUnit = 'M'
  	second      timeoutUnit = 'S'
  	millisecond timeoutUnit = 'm'
  	microsecond timeoutUnit = 'u'
  	nanosecond  timeoutUnit = 'n'
  )

  // timeoutUnitToDuration returns the length of one u. ok is false (and d is 0)
  // if u is not a recognized unit.
  func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
  	switch u {
  	case hour:
  		return time.Hour, true
  	case minute:
  		return time.Minute, true
  	case second:
  		return time.Second, true
  	case millisecond:
  		return time.Millisecond, true
  	case microsecond:
  		return time.Microsecond, true
  	case nanosecond:
  		return time.Nanosecond, true
  	default:
  		return 0, false
  	}
  }

  // decodeTimeout parses a grpc-timeout header value. The value is one to eight
  // ASCII decimal digits followed by a unit character (see timeoutUnit),
  // e.g. "100m".
  //
  // Sign characters are rejected. Hour values too large for time.Duration are
  // clamped to math.MaxInt64. It returns an error if s is too short, too long,
  // has an unknown unit, or has a non-digit amount.
  func decodeTimeout(s string) (time.Duration, error) {
  	size := len(s)
  	if size < 2 {
  		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
  	}
  	if size > 9 {
  		// Spec allows for 8 digits plus the unit.
  		return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
  	}
  	unit := timeoutUnit(s[size-1])
  	d, ok := timeoutUnitToDuration(unit)
  	if !ok {
  		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
  	}
  	// ParseUint, unlike ParseInt, rejects a leading '+' or '-'. With ParseInt a
  	// value such as "-9999999H" slipped past the overflow clamp below, and its
  	// product with time.Hour silently wrapped around int64.
  	t, err := strconv.ParseUint(s[:size-1], 10, 64)
  	if err != nil {
  		return 0, err
  	}
  	// At most 8 digits means t <= 99999999. That can overflow int64
  	// nanoseconds only for the hour unit, so only that case is clamped.
  	const maxHours = math.MaxInt64 / uint64(time.Hour)
  	if d == time.Hour && t > maxHours {
  		return time.Duration(math.MaxInt64), nil
  	}
  	return d * time.Duration(t), nil
  }
  const (
  	spaceByte   = ' '
  	tildeByte   = '~'
  	percentByte = '%'
  )

  // encodeGrpcMessage percent-encodes msg for the "grpc-message" header field.
  //
  // Printable ASCII bytes (' ' through '~') other than '%' pass through
  // unchanged. Every other byte becomes '%' followed by two uppercase hex
  // digits. Invalid UTF-8 is first replaced by the Unicode replacement
  // character (U+FFFD). A msg that needs no escaping is returned as-is.
  func encodeGrpcMessage(msg string) string {
  	for i := 0; i < len(msg); i++ {
  		if c := msg[i]; c < spaceByte || c > tildeByte || c == percentByte {
  			return encodeGrpcMessageUnchecked(msg)
  		}
  	}
  	return msg
  }

  // encodeGrpcMessageUnchecked applies the encodeGrpcMessage escaping rules
  // unconditionally, one rune at a time.
  func encodeGrpcMessageUnchecked(msg string) string {
  	var sb strings.Builder
  	for rest := msg; rest != ""; {
  		r, width := utf8.DecodeRuneInString(rest)
  		rest = rest[width:]
  		// Bytes are re-derived from r rather than sliced from rest.
  		// On invalid UTF-8, DecodeRuneInString yields (utf8.RuneError, 1), so
  		// string(r) is the 3-byte encoding of U+FFFD. Those bytes are >= 0x80
  		// and therefore always escaped.
  		// (fmt.Sprintf("%%%02X", utf8.RuneError) would give "%FFFD".)
  		multiByte := width > 1
  		for _, b := range []byte(string(r)) {
  			if !multiByte && b >= spaceByte && b <= tildeByte && b != percentByte {
  				sb.WriteByte(b)
  			} else {
  				fmt.Fprintf(&sb, "%%%02X", b)
  			}
  		}
  	}
  	return sb.String()
  }

  // decodeGrpcMessage reverses encodeGrpcMessage. Each "%XX" (two hex digits,
  // either case) becomes the byte it encodes, and a '%' not followed by two
  // hex digits is kept literally. If no '%' has at least two bytes after it,
  // msg is returned as-is.
  func decodeGrpcMessage(msg string) string {
  	// The first '%' has the lowest index. If it lacks two trailing bytes, no
  	// later '%' can have them either.
  	first := strings.IndexByte(msg, percentByte)
  	if first < 0 || first+2 >= len(msg) {
  		return msg
  	}
  	return decodeGrpcMessageUnchecked(msg)
  }

  // decodeGrpcMessageUnchecked performs the "%XX" decoding for
  // decodeGrpcMessage without the fast-path check.
  func decodeGrpcMessageUnchecked(msg string) string {
  	var sb strings.Builder
  	n := len(msg)
  	for i := 0; i < n; i++ {
  		c := msg[i]
  		if c == percentByte && i+2 < n {
  			if v, err := strconv.ParseUint(msg[i+1:i+3], 16, 8); err == nil {
  				sb.WriteByte(byte(v))
  				i += 2 // skip the two hex digits just consumed
  				continue
  			}
  		}
  		sb.WriteByte(c)
  	}
  	return sb.String()
  }
  // bufWriter batches writes to a net.Conn. Bytes passed to Write are staged
  // in buf and written to conn once at least batchSize bytes are staged, or
  // when Flush is called. A batchSize of 0 disables buffering.
  type bufWriter struct {
  	buf       []byte // staging area, 2*batchSize bytes long
  	offset    int    // number of bytes currently staged in buf
  	batchSize int    // flush threshold; 0 means write straight to conn
  	conn      net.Conn
  	err       error // sticky error from the first failed Flush
  }

  // newBufWriter returns a bufWriter over conn that flushes once batchSize
  // bytes are staged. batchSize must be >= 0; 0 disables buffering.
  func newBufWriter(conn net.Conn, batchSize int) *bufWriter {
  	return &bufWriter{
  		buf:       make([]byte, batchSize*2),
  		batchSize: batchSize,
  		conn:      conn,
  	}
  }

  // Write stages b, flushing to conn each time at least batchSize bytes are
  // staged. It returns the number of bytes accepted from b. On a flush error
  // that count may include bytes that never reached conn. With buffering
  // disabled, b is written directly to conn. Errors from conn are wrapped as
  // ioError.
  func (w *bufWriter) Write(b []byte) (n int, err error) {
  	if w.err != nil {
  		return 0, w.err
  	}
  	if w.batchSize == 0 { // Buffer has been disabled.
  		n, err = w.conn.Write(b)
  		return n, toIOError(err)
  	}
  	for len(b) > 0 {
  		nn := copy(w.buf[w.offset:], b)
  		b = b[nn:]
  		w.offset += nn
  		n += nn
  		if w.offset >= w.batchSize {
  			// Stop at the first failed flush. Once w.err is set, Flush returns
  			// without draining the buffer, so continuing would eventually make
  			// copy return 0 forever and spin this loop.
  			if err = w.Flush(); err != nil {
  				return n, err
  			}
  		}
  	}
  	return n, nil
  }

  // Flush writes all staged bytes to conn and empties the buffer. A failure
  // is wrapped as an ioError, recorded in w.err, and returned by every later
  // Write and Flush call.
  func (w *bufWriter) Flush() error {
  	if w.err != nil {
  		return w.err
  	}
  	if w.offset == 0 {
  		return nil
  	}
  	_, w.err = w.conn.Write(w.buf[:w.offset])
  	w.err = toIOError(w.err)
  	w.offset = 0
  	return w.err
  }

  // ioError wraps an error returned by the underlying connection, so such
  // failures can be recognized with isIOError.
  type ioError struct {
  	error
  }

  // Unwrap returns the wrapped connection error, for errors.Is / errors.As.
  func (i ioError) Unwrap() error {
  	return i.error
  }

  // isIOError reports whether err, or any error it wraps, is an ioError.
  func isIOError(err error) bool {
  	return errors.As(err, &ioError{})
  }

  // toIOError wraps a non-nil err in an ioError and returns nil for nil.
  func toIOError(err error) error {
  	if err == nil {
  		return nil
  	}
  	return ioError{error: err}
  }
  347. type framer struct {
  348. writer *bufWriter
  349. fr *http2.Framer
  350. }
  351. func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderListSize uint32) *framer {
  352. if writeBufferSize < 0 {
  353. writeBufferSize = 0
  354. }
  355. var r io.Reader = conn
  356. if readBufferSize > 0 {
  357. r = bufio.NewReaderSize(r, readBufferSize)
  358. }
  359. w := newBufWriter(conn, writeBufferSize)
  360. f := &framer{
  361. writer: w,
  362. fr: http2.NewFramer(w, r),
  363. }
  364. f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
  365. // Opt-in to Frame reuse API on framer to reduce garbage.
  366. // Frames aren't safe to read from after a subsequent call to ReadFrame.
  367. f.fr.SetReuseFrames()
  368. f.fr.MaxHeaderListSize = maxHeaderListSize
  369. f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
  370. return f
  371. }
  // parseDialTarget returns the network and address to pass to dialer.
  //
  // Targets of the form "unix:addr", "unix:/path", and "unix://..." select
  // the "unix" network. Every other target, including one that fails to parse
  // as a URL, is returned unchanged with network "tcp".
  func parseDialTarget(target string) (string, string) {
  	const defaultNetwork = "tcp"
  	colon := strings.Index(target, ":")
  	colonSlash := strings.Index(target, ":/")

  	// Handle "unix:addr" (no ":/") by hand. url.Parse would put addr in the
  	// URL's opaque part rather than its Host or Path.
  	if colon >= 0 && colonSlash < 0 && target[:colon] == "unix" {
  		return "unix", target[colon+1:]
  	}
  	if colonSlash < 0 {
  		return defaultNetwork, target
  	}

  	u, err := url.Parse(target)
  	if err != nil || u.Scheme != "unix" {
  		return defaultNetwork, target
  	}
  	// "unix:/path" and "unix:///path" carry the socket path in Path;
  	// "unix://name" carries it in Host.
  	if u.Path != "" {
  		return u.Scheme, u.Path
  	}
  	return u.Scheme, u.Host
  }