requester.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. package main
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "fmt"
  6. "io"
  7. "net"
  8. url2 "net/url"
  9. "os"
  10. "os/signal"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "syscall"
  16. "time"
  17. "github.com/valyala/fasthttp"
  18. "github.com/valyala/fasthttp/fasthttpproxy"
  19. "go.uber.org/automaxprocs/maxprocs"
  20. "golang.org/x/time/rate"
  21. )
  22. var (
  23. startTime = time.Now()
  24. sendOnCloseError interface{}
  25. )
  26. type ReportRecord struct {
  27. cost time.Duration
  28. code int
  29. error string
  30. readBytes int64
  31. writeBytes int64
  32. }
  33. var recordPool = sync.Pool{
  34. New: func() interface{} { return new(ReportRecord) },
  35. }
  36. func init() {
  37. // Honoring env GOMAXPROCS
  38. _, _ = maxprocs.Set()
  39. defer func() {
  40. sendOnCloseError = recover()
  41. }()
  42. func() {
  43. cc := make(chan struct{}, 1)
  44. close(cc)
  45. cc <- struct{}{}
  46. }()
  47. }
  48. type MyConn struct {
  49. net.Conn
  50. r, w *int64
  51. }
  52. func NewMyConn(conn net.Conn, r, w *int64) (*MyConn, error) {
  53. myConn := &MyConn{Conn: conn, r: r, w: w}
  54. return myConn, nil
  55. }
  56. func (c *MyConn) Read(b []byte) (n int, err error) {
  57. sz, err := c.Conn.Read(b)
  58. if err == nil {
  59. atomic.AddInt64(c.r, int64(sz))
  60. }
  61. return sz, err
  62. }
  63. func (c *MyConn) Write(b []byte) (n int, err error) {
  64. sz, err := c.Conn.Write(b)
  65. if err == nil {
  66. atomic.AddInt64(c.w, int64(sz))
  67. }
  68. return sz, err
  69. }
  70. func ThroughputInterceptorDial(dial fasthttp.DialFunc, r *int64, w *int64) fasthttp.DialFunc {
  71. return func(addr string) (net.Conn, error) {
  72. conn, err := dial(addr)
  73. if err != nil {
  74. return nil, err
  75. }
  76. return NewMyConn(conn, r, w)
  77. }
  78. }
  79. type Requester struct {
  80. concurrency int
  81. reqRate *rate.Limit
  82. requests int64
  83. duration time.Duration
  84. clientOpt *ClientOpt
  85. httpClient *fasthttp.HostClient
  86. httpHeader *fasthttp.RequestHeader
  87. errWriter io.Writer
  88. recordChan chan *ReportRecord
  89. closeOnce sync.Once
  90. wg sync.WaitGroup
  91. readBytes int64
  92. writeBytes int64
  93. cancel func()
  94. }
  95. type ClientOpt struct {
  96. url string
  97. method string
  98. headers []string
  99. bodyBytes []byte
  100. bodyFile string
  101. certPath string
  102. keyPath string
  103. insecure bool
  104. maxConns int
  105. doTimeout time.Duration
  106. readTimeout time.Duration
  107. writeTimeout time.Duration
  108. dialTimeout time.Duration
  109. socks5Proxy string
  110. contentType string
  111. host string
  112. }
  113. func NewRequester(concurrency int, requests int64, duration time.Duration, reqRate *rate.Limit, errWriter io.Writer, clientOpt *ClientOpt) (*Requester, error) {
  114. maxResult := concurrency * 100
  115. if maxResult > 8192 {
  116. maxResult = 8192
  117. }
  118. r := &Requester{
  119. concurrency: concurrency,
  120. reqRate: reqRate,
  121. requests: requests,
  122. duration: duration,
  123. errWriter: errWriter,
  124. clientOpt: clientOpt,
  125. recordChan: make(chan *ReportRecord, maxResult),
  126. }
  127. client, header, err := buildRequestClient(clientOpt, &r.readBytes, &r.writeBytes)
  128. if err != nil {
  129. return nil, err
  130. }
  131. r.httpClient = client
  132. r.httpHeader = header
  133. return r, nil
  134. }
  135. func addMissingPort(addr string, isTLS bool) string {
  136. n := strings.Index(addr, ":")
  137. if n >= 0 {
  138. return addr
  139. }
  140. port := 80
  141. if isTLS {
  142. port = 443
  143. }
  144. return net.JoinHostPort(addr, strconv.Itoa(port))
  145. }
  146. func buildTLSConfig(opt *ClientOpt) (*tls.Config, error) {
  147. var certs []tls.Certificate
  148. if opt.certPath != "" && opt.keyPath != "" {
  149. c, err := tls.LoadX509KeyPair(opt.certPath, opt.keyPath)
  150. if err != nil {
  151. return nil, err
  152. }
  153. certs = append(certs, c)
  154. }
  155. return &tls.Config{
  156. InsecureSkipVerify: opt.insecure,
  157. Certificates: certs,
  158. }, nil
  159. }
  160. func buildRequestClient(opt *ClientOpt, r *int64, w *int64) (*fasthttp.HostClient, *fasthttp.RequestHeader, error) {
  161. u, err := url2.Parse(opt.url)
  162. if err != nil {
  163. return nil, nil, err
  164. }
  165. httpClient := &fasthttp.HostClient{
  166. Addr: addMissingPort(u.Host, u.Scheme == "https"),
  167. IsTLS: u.Scheme == "https",
  168. Name: "plow",
  169. MaxConns: opt.maxConns,
  170. ReadTimeout: opt.readTimeout,
  171. WriteTimeout: opt.writeTimeout,
  172. DisableHeaderNamesNormalizing: true,
  173. }
  174. if opt.socks5Proxy != "" {
  175. if !strings.Contains(opt.socks5Proxy, "://") {
  176. opt.socks5Proxy = "socks5://" + opt.socks5Proxy
  177. }
  178. httpClient.Dial = fasthttpproxy.FasthttpSocksDialer(opt.socks5Proxy)
  179. } else {
  180. httpClient.Dial = fasthttpproxy.FasthttpProxyHTTPDialerTimeout(opt.dialTimeout)
  181. }
  182. httpClient.Dial = ThroughputInterceptorDial(httpClient.Dial, r, w)
  183. tlsConfig, err := buildTLSConfig(opt)
  184. if err != nil {
  185. return nil, nil, err
  186. }
  187. httpClient.TLSConfig = tlsConfig
  188. var requestHeader fasthttp.RequestHeader
  189. if opt.contentType != "" {
  190. requestHeader.SetContentType(opt.contentType)
  191. }
  192. if opt.host != "" {
  193. requestHeader.SetHost(opt.host)
  194. } else {
  195. requestHeader.SetHost(u.Host)
  196. }
  197. requestHeader.SetMethod(opt.method)
  198. requestHeader.SetRequestURI(u.RequestURI())
  199. for _, h := range opt.headers {
  200. n := strings.SplitN(h, ":", 2)
  201. if len(n) != 2 {
  202. return nil, nil, fmt.Errorf("invalid header: %s", h)
  203. }
  204. requestHeader.Set(n[0], n[1])
  205. }
  206. return httpClient, &requestHeader, nil
  207. }
  208. func (r *Requester) Cancel() {
  209. r.cancel()
  210. }
  211. func (r *Requester) RecordChan() <-chan *ReportRecord {
  212. return r.recordChan
  213. }
  214. func (r *Requester) closeRecord() {
  215. r.closeOnce.Do(func() {
  216. close(r.recordChan)
  217. })
  218. }
  219. func (r *Requester) DoRequest(req *fasthttp.Request, resp *fasthttp.Response, rr *ReportRecord) {
  220. t1 := time.Since(startTime)
  221. var err error
  222. if r.clientOpt.doTimeout > 0 {
  223. err = r.httpClient.DoTimeout(req, resp, r.clientOpt.doTimeout)
  224. } else {
  225. err = r.httpClient.Do(req, resp)
  226. }
  227. if err != nil {
  228. rr.cost = time.Since(startTime) - t1
  229. rr.error = err.Error()
  230. return
  231. }
  232. writeTo := io.Discard
  233. if resp.StatusCode() >= 500 {
  234. writeTo = r.errWriter
  235. _, _ = r.errWriter.Write([]byte(fmt.Sprintf("\n%d %s\n", resp.StatusCode(), rr.cost)))
  236. _, _ = r.errWriter.Write([]byte(fmt.Sprintf("%s", &resp.Header)))
  237. }
  238. err = resp.BodyWriteTo(writeTo)
  239. if err != nil {
  240. rr.cost = time.Since(startTime) - t1
  241. rr.error = err.Error()
  242. return
  243. }
  244. rr.cost = time.Since(startTime) - t1
  245. rr.code = resp.StatusCode()
  246. rr.error = ""
  247. }
  248. func (r *Requester) Run() {
  249. // handle ctrl-c
  250. sigs := make(chan os.Signal, 1)
  251. signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
  252. defer signal.Stop(sigs)
  253. ctx, cancelFunc := context.WithCancel(context.Background())
  254. r.cancel = cancelFunc
  255. go func() {
  256. <-sigs
  257. r.closeRecord()
  258. cancelFunc()
  259. }()
  260. startTime = time.Now()
  261. if r.duration > 0 {
  262. time.AfterFunc(r.duration, func() {
  263. r.closeRecord()
  264. cancelFunc()
  265. })
  266. }
  267. var limiter *rate.Limiter
  268. if r.reqRate != nil {
  269. limiter = rate.NewLimiter(*r.reqRate, 1)
  270. }
  271. semaphore := r.requests
  272. for i := 0; i < r.concurrency; i++ {
  273. r.wg.Add(1)
  274. go func() {
  275. defer func() {
  276. r.wg.Done()
  277. v := recover()
  278. if v != nil && v != sendOnCloseError {
  279. panic(v)
  280. }
  281. }()
  282. req := &fasthttp.Request{}
  283. resp := &fasthttp.Response{}
  284. r.httpHeader.CopyTo(&req.Header)
  285. if r.httpClient.IsTLS {
  286. req.URI().SetScheme("https")
  287. req.URI().SetHostBytes(req.Header.Host())
  288. }
  289. for {
  290. select {
  291. case <-ctx.Done():
  292. return
  293. default:
  294. }
  295. if limiter != nil {
  296. err := limiter.Wait(ctx)
  297. if err != nil {
  298. continue
  299. }
  300. }
  301. if r.requests > 0 && atomic.AddInt64(&semaphore, -1) < 0 {
  302. cancelFunc()
  303. return
  304. }
  305. if r.clientOpt.bodyFile != "" {
  306. file, err := os.Open(r.clientOpt.bodyFile)
  307. if err != nil {
  308. rr := recordPool.Get().(*ReportRecord)
  309. rr.cost = 0
  310. rr.error = err.Error()
  311. rr.readBytes = atomic.LoadInt64(&r.readBytes)
  312. rr.writeBytes = atomic.LoadInt64(&r.writeBytes)
  313. r.recordChan <- rr
  314. continue
  315. }
  316. req.SetBodyStream(file, -1)
  317. } else {
  318. req.SetBodyRaw(r.clientOpt.bodyBytes)
  319. }
  320. resp.Reset()
  321. rr := recordPool.Get().(*ReportRecord)
  322. r.DoRequest(req, resp, rr)
  323. rr.readBytes = atomic.LoadInt64(&r.readBytes)
  324. rr.writeBytes = atomic.LoadInt64(&r.writeBytes)
  325. r.recordChan <- rr
  326. }
  327. }()
  328. }
  329. r.wg.Wait()
  330. r.closeRecord()
  331. }