1
0

requester.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. package main
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "fmt"
  6. "github.com/valyala/fasthttp"
  7. "github.com/valyala/fasthttp/fasthttpproxy"
  8. "go.uber.org/automaxprocs/maxprocs"
  9. "golang.org/x/time/rate"
  10. "io/ioutil"
  11. "net"
  12. url2 "net/url"
  13. "os"
  14. "os/signal"
  15. "strconv"
  16. "strings"
  17. "sync"
  18. "sync/atomic"
  19. "syscall"
  20. "time"
  21. )
  22. var (
  23. startTime = time.Now()
  24. sendOnCloseError interface{}
  25. )
  26. type ReportRecord struct {
  27. cost time.Duration
  28. code int
  29. error string
  30. readBytes int64
  31. writeBytes int64
  32. }
  33. var recordPool = sync.Pool{
  34. New: func() interface{} { return new(ReportRecord) },
  35. }
  36. func init() {
  37. // Honoring env GOMAXPROCS
  38. _, _ = maxprocs.Set()
  39. defer func() {
  40. sendOnCloseError = recover()
  41. }()
  42. func() {
  43. cc := make(chan struct{}, 1)
  44. close(cc)
  45. cc <- struct{}{}
  46. }()
  47. }
  48. type MyConn struct {
  49. net.Conn
  50. r, w *int64
  51. }
  52. func NewMyConn(conn net.Conn, r, w *int64) (*MyConn, error) {
  53. myConn := &MyConn{Conn: conn, r: r, w: w}
  54. return myConn, nil
  55. }
  56. func (c *MyConn) Read(b []byte) (n int, err error) {
  57. sz, err := c.Conn.Read(b)
  58. if err == nil {
  59. atomic.AddInt64(c.r, int64(sz))
  60. }
  61. return sz, err
  62. }
  63. func (c *MyConn) Write(b []byte) (n int, err error) {
  64. sz, err := c.Conn.Write(b)
  65. if err == nil {
  66. atomic.AddInt64(c.w, int64(sz))
  67. }
  68. return sz, err
  69. }
  70. func ThroughputInterceptorDial(dial fasthttp.DialFunc, r *int64, w *int64) fasthttp.DialFunc {
  71. return func(addr string) (net.Conn, error) {
  72. conn, err := dial(addr)
  73. if err != nil {
  74. return nil, err
  75. }
  76. return NewMyConn(conn, r, w)
  77. }
  78. }
  79. type Requester struct {
  80. concurrency int
  81. reqRate *rate.Limit
  82. requests int64
  83. duration time.Duration
  84. clientOpt *ClientOpt
  85. httpClient *fasthttp.HostClient
  86. httpHeader *fasthttp.RequestHeader
  87. recordChan chan *ReportRecord
  88. closeOnce sync.Once
  89. wg sync.WaitGroup
  90. readBytes int64
  91. writeBytes int64
  92. cancel func()
  93. }
  94. type ClientOpt struct {
  95. url string
  96. method string
  97. headers []string
  98. bodyBytes []byte
  99. bodyFile string
  100. certPath string
  101. keyPath string
  102. insecure bool
  103. maxConns int
  104. doTimeout time.Duration
  105. readTimeout time.Duration
  106. writeTimeout time.Duration
  107. dialTimeout time.Duration
  108. socks5Proxy string
  109. contentType string
  110. host string
  111. }
  112. func NewRequester(concurrency int, requests int64, duration time.Duration, reqRate *rate.Limit, clientOpt *ClientOpt) (*Requester, error) {
  113. maxResult := concurrency * 100
  114. if maxResult > 8192 {
  115. maxResult = 8192
  116. }
  117. r := &Requester{
  118. concurrency: concurrency,
  119. reqRate: reqRate,
  120. requests: requests,
  121. duration: duration,
  122. clientOpt: clientOpt,
  123. recordChan: make(chan *ReportRecord, maxResult),
  124. }
  125. client, header, err := buildRequestClient(clientOpt, &r.readBytes, &r.writeBytes)
  126. if err != nil {
  127. return nil, err
  128. }
  129. r.httpClient = client
  130. r.httpHeader = header
  131. return r, nil
  132. }
  133. func addMissingPort(addr string, isTLS bool) string {
  134. n := strings.Index(addr, ":")
  135. if n >= 0 {
  136. return addr
  137. }
  138. port := 80
  139. if isTLS {
  140. port = 443
  141. }
  142. return net.JoinHostPort(addr, strconv.Itoa(port))
  143. }
  144. func buildTLSConfig(opt *ClientOpt) (*tls.Config, error) {
  145. var certs []tls.Certificate
  146. if opt.certPath != "" && opt.keyPath != "" {
  147. c, err := tls.LoadX509KeyPair(opt.certPath, opt.keyPath)
  148. if err != nil {
  149. return nil, err
  150. }
  151. certs = append(certs, c)
  152. }
  153. return &tls.Config{
  154. InsecureSkipVerify: opt.insecure,
  155. Certificates: certs,
  156. }, nil
  157. }
  158. func buildRequestClient(opt *ClientOpt, r *int64, w *int64) (*fasthttp.HostClient, *fasthttp.RequestHeader, error) {
  159. u, err := url2.Parse(opt.url)
  160. if err != nil {
  161. return nil, nil, err
  162. }
  163. httpClient := &fasthttp.HostClient{
  164. Addr: addMissingPort(u.Host, u.Scheme == "https"),
  165. IsTLS: u.Scheme == "https",
  166. Name: "plow",
  167. MaxConns: opt.maxConns,
  168. ReadTimeout: opt.readTimeout,
  169. WriteTimeout: opt.writeTimeout,
  170. DisableHeaderNamesNormalizing: true,
  171. }
  172. if opt.socks5Proxy != "" {
  173. if !strings.Contains(opt.socks5Proxy, "://") {
  174. opt.socks5Proxy = "socks5://" + opt.socks5Proxy
  175. }
  176. httpClient.Dial = fasthttpproxy.FasthttpSocksDialer(opt.socks5Proxy)
  177. } else {
  178. httpClient.Dial = fasthttpproxy.FasthttpProxyHTTPDialerTimeout(opt.dialTimeout)
  179. }
  180. httpClient.Dial = ThroughputInterceptorDial(httpClient.Dial, r, w)
  181. tlsConfig, err := buildTLSConfig(opt)
  182. if err != nil {
  183. return nil, nil, err
  184. }
  185. httpClient.TLSConfig = tlsConfig
  186. var requestHeader fasthttp.RequestHeader
  187. if opt.contentType != "" {
  188. requestHeader.SetContentType(opt.contentType)
  189. }
  190. if opt.host != "" {
  191. requestHeader.SetHost(opt.host)
  192. } else {
  193. requestHeader.SetHost(u.Host)
  194. }
  195. requestHeader.SetMethod(opt.method)
  196. requestHeader.SetRequestURI(u.RequestURI())
  197. for _, h := range opt.headers {
  198. n := strings.SplitN(h, ":", 2)
  199. if len(n) != 2 {
  200. return nil, nil, fmt.Errorf("invalid header: %s", h)
  201. }
  202. requestHeader.Set(n[0], n[1])
  203. }
  204. return httpClient, &requestHeader, nil
  205. }
  206. func (r *Requester) Cancel() {
  207. r.cancel()
  208. }
  209. func (r *Requester) RecordChan() <-chan *ReportRecord {
  210. return r.recordChan
  211. }
  212. func (r *Requester) closeRecord() {
  213. r.closeOnce.Do(func() {
  214. close(r.recordChan)
  215. })
  216. }
  217. func (r *Requester) DoRequest(req *fasthttp.Request, resp *fasthttp.Response, rr *ReportRecord) {
  218. t1 := time.Since(startTime)
  219. var err error
  220. if r.clientOpt.doTimeout > 0 {
  221. err = r.httpClient.DoTimeout(req, resp, r.clientOpt.doTimeout)
  222. } else {
  223. err = r.httpClient.Do(req, resp)
  224. }
  225. if err != nil {
  226. rr.cost = time.Since(startTime) - t1
  227. rr.error = err.Error()
  228. return
  229. }
  230. err = resp.BodyWriteTo(ioutil.Discard)
  231. if err != nil {
  232. rr.cost = time.Since(startTime) - t1
  233. rr.error = err.Error()
  234. return
  235. }
  236. rr.cost = time.Since(startTime) - t1
  237. rr.code = resp.StatusCode()
  238. rr.error = ""
  239. }
  240. func (r *Requester) Run() {
  241. // handle ctrl-c
  242. sigs := make(chan os.Signal, 1)
  243. signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
  244. defer signal.Stop(sigs)
  245. ctx, cancelFunc := context.WithCancel(context.Background())
  246. r.cancel = cancelFunc
  247. go func() {
  248. <-sigs
  249. r.closeRecord()
  250. cancelFunc()
  251. }()
  252. startTime = time.Now()
  253. if r.duration > 0 {
  254. time.AfterFunc(r.duration, func() {
  255. r.closeRecord()
  256. cancelFunc()
  257. })
  258. }
  259. var limiter *rate.Limiter
  260. if r.reqRate != nil {
  261. limiter = rate.NewLimiter(*r.reqRate, 1)
  262. }
  263. semaphore := r.requests
  264. for i := 0; i < r.concurrency; i++ {
  265. r.wg.Add(1)
  266. go func() {
  267. defer func() {
  268. r.wg.Done()
  269. v := recover()
  270. if v != nil && v != sendOnCloseError {
  271. panic(v)
  272. }
  273. }()
  274. req := &fasthttp.Request{}
  275. resp := &fasthttp.Response{}
  276. r.httpHeader.CopyTo(&req.Header)
  277. if r.httpClient.IsTLS {
  278. req.URI().SetScheme("https")
  279. req.URI().SetHostBytes(req.Header.Host())
  280. }
  281. for {
  282. select {
  283. case <-ctx.Done():
  284. return
  285. default:
  286. }
  287. if limiter != nil {
  288. err := limiter.Wait(ctx)
  289. if err != nil {
  290. continue
  291. }
  292. }
  293. if r.requests > 0 && atomic.AddInt64(&semaphore, -1) < 0 {
  294. cancelFunc()
  295. return
  296. }
  297. if r.clientOpt.bodyFile != "" {
  298. file, err := os.Open(r.clientOpt.bodyFile)
  299. if err != nil {
  300. rr := recordPool.Get().(*ReportRecord)
  301. rr.cost = 0
  302. rr.error = err.Error()
  303. rr.readBytes = atomic.LoadInt64(&r.readBytes)
  304. rr.writeBytes = atomic.LoadInt64(&r.writeBytes)
  305. r.recordChan <- rr
  306. continue
  307. }
  308. req.SetBodyStream(file, -1)
  309. } else {
  310. req.SetBodyRaw(r.clientOpt.bodyBytes)
  311. }
  312. resp.Reset()
  313. rr := recordPool.Get().(*ReportRecord)
  314. r.DoRequest(req, resp, rr)
  315. rr.readBytes = atomic.LoadInt64(&r.readBytes)
  316. rr.writeBytes = atomic.LoadInt64(&r.writeBytes)
  317. r.recordChan <- rr
  318. }
  319. }()
  320. }
  321. r.wg.Wait()
  322. r.closeRecord()
  323. }