123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365 |
- package main
- import (
- "context"
- "crypto/tls"
- "fmt"
- "github.com/valyala/fasthttp"
- "github.com/valyala/fasthttp/fasthttpproxy"
- "go.uber.org/automaxprocs/maxprocs"
- "golang.org/x/time/rate"
- "io/ioutil"
- "net"
- url2 "net/url"
- "os"
- "os/signal"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "syscall"
- "time"
- )
- var (
- startTime = time.Now()
- sendOnCloseError interface{}
- )
- type ReportRecord struct {
- cost time.Duration
- code int
- error string
- readBytes int64
- writeBytes int64
- }
- var recordPool = sync.Pool{
- New: func() interface{} { return new(ReportRecord) },
- }
- func init() {
- // Honoring env GOMAXPROCS
- _, _ = maxprocs.Set()
- defer func() {
- sendOnCloseError = recover()
- }()
- func() {
- cc := make(chan struct{}, 1)
- close(cc)
- cc <- struct{}{}
- }()
- }
- type MyConn struct {
- net.Conn
- r, w *int64
- }
- func NewMyConn(conn net.Conn, r, w *int64) (*MyConn, error) {
- myConn := &MyConn{Conn: conn, r: r, w: w}
- return myConn, nil
- }
- func (c *MyConn) Read(b []byte) (n int, err error) {
- sz, err := c.Conn.Read(b)
- if err == nil {
- atomic.AddInt64(c.r, int64(sz))
- }
- return sz, err
- }
- func (c *MyConn) Write(b []byte) (n int, err error) {
- sz, err := c.Conn.Write(b)
- if err == nil {
- atomic.AddInt64(c.w, int64(sz))
- }
- return sz, err
- }
// ThroughputInterceptorDial decorates dial so that every connection it
// successfully opens is wrapped in a MyConn, which adds bytes read to *r and
// bytes written to *w. Dial errors are passed through unchanged.
func ThroughputInterceptorDial(dial fasthttp.DialFunc, r *int64, w *int64) fasthttp.DialFunc {
	counted := func(addr string) (net.Conn, error) {
		conn, err := dial(addr)
		if err == nil {
			return NewMyConn(conn, r, w)
		}
		return nil, err
	}
	return counted
}
- type Requester struct {
- concurrency int
- reqRate *rate.Limit
- requests int64
- duration time.Duration
- clientOpt *ClientOpt
- httpClient *fasthttp.HostClient
- httpHeader *fasthttp.RequestHeader
- recordChan chan *ReportRecord
- closeOnce sync.Once
- wg sync.WaitGroup
- readBytes int64
- writeBytes int64
- cancel func()
- }
- type ClientOpt struct {
- url string
- method string
- headers []string
- bodyBytes []byte
- bodyFile string
- certPath string
- keyPath string
- insecure bool
- maxConns int
- doTimeout time.Duration
- readTimeout time.Duration
- writeTimeout time.Duration
- dialTimeout time.Duration
- socks5Proxy string
- contentType string
- host string
- }
- func NewRequester(concurrency int, requests int64, duration time.Duration, reqRate *rate.Limit, clientOpt *ClientOpt) (*Requester, error) {
- maxResult := concurrency * 100
- if maxResult > 8192 {
- maxResult = 8192
- }
- r := &Requester{
- concurrency: concurrency,
- reqRate: reqRate,
- requests: requests,
- duration: duration,
- clientOpt: clientOpt,
- recordChan: make(chan *ReportRecord, maxResult),
- }
- client, header, err := buildRequestClient(clientOpt, &r.readBytes, &r.writeBytes)
- if err != nil {
- return nil, err
- }
- r.httpClient = client
- r.httpHeader = header
- return r, nil
- }
// addMissingPort returns addr unchanged if it already carries a port, and
// otherwise appends the default port for the scheme (443 if isTLS, else 80).
//
// Bracketed IPv6 literals such as "[::1]" contain ':' without having a port,
// so they are only considered to have one when something follows the ']'.
func addMissingPort(addr string, isTLS bool) string {
	port := 80
	if isTLS {
		port = 443
	}
	portStr := strconv.Itoa(port)

	if strings.HasPrefix(addr, "[") {
		if strings.HasSuffix(addr, "]") {
			// Already bracketed: net.JoinHostPort would add a second pair.
			return addr + ":" + portStr
		}
		return addr
	}
	if strings.Contains(addr, ":") {
		return addr
	}
	return net.JoinHostPort(addr, portStr)
}
- func buildTLSConfig(opt *ClientOpt) (*tls.Config, error) {
- var certs []tls.Certificate
- if opt.certPath != "" && opt.keyPath != "" {
- c, err := tls.LoadX509KeyPair(opt.certPath, opt.keyPath)
- if err != nil {
- return nil, err
- }
- certs = append(certs, c)
- }
- return &tls.Config{
- InsecureSkipVerify: opt.insecure,
- Certificates: certs,
- }, nil
- }
// buildRequestClient creates the fasthttp HostClient and the request header
// template for opt.url. The client's dialer (SOCKS5 when opt.socks5Proxy is
// set, otherwise fasthttpproxy's HTTP dialer with opt.dialTimeout) is wrapped
// with ThroughputInterceptorDial, so bytes read/written are added to *r / *w.
//
// Entries in opt.headers must have the form "Name: value"; whitespace around
// the name and the value is trimmed. An error is returned for an unparsable
// URL, an unloadable TLS key pair, or a malformed header. opt is not modified.
func buildRequestClient(opt *ClientOpt, r *int64, w *int64) (*fasthttp.HostClient, *fasthttp.RequestHeader, error) {
	u, err := url2.Parse(opt.url)
	if err != nil {
		return nil, nil, err
	}
	isTLS := u.Scheme == "https"
	httpClient := &fasthttp.HostClient{
		Addr:                          addMissingPort(u.Host, isTLS),
		IsTLS:                         isTLS,
		Name:                          "plow",
		MaxConns:                      opt.maxConns,
		ReadTimeout:                   opt.readTimeout,
		WriteTimeout:                  opt.writeTimeout,
		DisableHeaderNamesNormalizing: true,
	}

	if opt.socks5Proxy != "" {
		// Normalise into a local so the caller's ClientOpt is left untouched.
		proxy := opt.socks5Proxy
		if !strings.Contains(proxy, "://") {
			proxy = "socks5://" + proxy
		}
		httpClient.Dial = fasthttpproxy.FasthttpSocksDialer(proxy)
	} else {
		httpClient.Dial = fasthttpproxy.FasthttpProxyHTTPDialerTimeout(opt.dialTimeout)
	}
	httpClient.Dial = ThroughputInterceptorDial(httpClient.Dial, r, w)

	tlsConfig, err := buildTLSConfig(opt)
	if err != nil {
		return nil, nil, err
	}
	httpClient.TLSConfig = tlsConfig

	var requestHeader fasthttp.RequestHeader
	if opt.contentType != "" {
		requestHeader.SetContentType(opt.contentType)
	}
	if opt.host != "" {
		requestHeader.SetHost(opt.host)
	} else {
		requestHeader.SetHost(u.Host)
	}
	requestHeader.SetMethod(opt.method)
	requestHeader.SetRequestURI(u.RequestURI())
	for _, h := range opt.headers {
		kv := strings.SplitN(h, ":", 2)
		if len(kv) != 2 {
			return nil, nil, fmt.Errorf("invalid header: %s", h)
		}
		// Names are sent as-is (DisableHeaderNamesNormalizing), so spaces
		// around "Name: value" would otherwise reach the wire; an empty name
		// cannot form a valid header line.
		name := strings.TrimSpace(kv[0])
		if name == "" {
			return nil, nil, fmt.Errorf("invalid header: %s", h)
		}
		requestHeader.Set(name, strings.TrimSpace(kv[1]))
	}
	return httpClient, &requestHeader, nil
}
- func (r *Requester) Cancel() {
- r.cancel()
- }
- func (r *Requester) RecordChan() <-chan *ReportRecord {
- return r.recordChan
- }
- func (r *Requester) closeRecord() {
- r.closeOnce.Do(func() {
- close(r.recordChan)
- })
- }
// DoRequest sends req, writes the response body to ioutil.Discard, and fills
// rr with the elapsed time plus either the status code (on success) or the
// error text (on failure). Per-request timeout r.clientOpt.doTimeout is
// applied when positive.
//
// rr typically comes from recordPool and may hold values from an earlier
// request, so every outcome field (cost, code, error) is written on every path.
func (r *Requester) DoRequest(req *fasthttp.Request, resp *fasthttp.Response, rr *ReportRecord) {
	start := time.Now()
	var err error
	if r.clientOpt.doTimeout > 0 {
		err = r.httpClient.DoTimeout(req, resp, r.clientOpt.doTimeout)
	} else {
		err = r.httpClient.Do(req, resp)
	}
	if err == nil {
		// A failure while writing out the body is reported like a request error.
		err = resp.BodyWriteTo(ioutil.Discard)
	}
	rr.cost = time.Since(start)
	if err != nil {
		// Clear any status code left over from a previous use of rr.
		rr.code = 0
		rr.error = err.Error()
		return
	}
	rr.code = resp.StatusCode()
	rr.error = ""
}
// Run starts r.concurrency worker goroutines and blocks until all of them
// have exited, then closes the record channel. Workers stop when the request
// budget (r.requests, if > 0) is used up, when r.duration (if > 0) elapses,
// when Cancel is called, or when SIGINT/SIGTERM is received. Each request
// attempt is reported as a *ReportRecord on RecordChan, unless the channel has
// already been closed by an interrupt or the duration timer.
func (r *Requester) Run() {
	// handle ctrl-c
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigs)

	ctx, cancelFunc := context.WithCancel(context.Background())
	r.cancel = cancelFunc

	// signal.Stop does not close sigs, so the watcher also listens on done
	// (closed when Run returns); otherwise it would block forever when no
	// signal ever arrives.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-sigs:
			r.closeRecord()
			cancelFunc()
		case <-done:
		}
	}()

	startTime = time.Now()
	if r.duration > 0 {
		timer := time.AfterFunc(r.duration, func() {
			r.closeRecord()
			cancelFunc()
		})
		// Don't leave the timer pending if the run finishes early.
		defer timer.Stop()
	}

	var limiter *rate.Limiter
	if r.reqRate != nil {
		limiter = rate.NewLimiter(*r.reqRate, 1)
	}

	// semaphore is the remaining request budget, shared by all workers.
	semaphore := r.requests
	for i := 0; i < r.concurrency; i++ {
		r.wg.Add(1)
		go func() {
			defer func() {
				r.wg.Done()
				// recordChan may be closed (interrupt or duration timer)
				// while a worker is still sending; that panic is expected
				// and is recognised by comparing with the value captured
				// in init. Any other panic is re-raised.
				v := recover()
				if v != nil && v != sendOnCloseError {
					panic(v)
				}
			}()
			req := &fasthttp.Request{}
			resp := &fasthttp.Response{}
			r.httpHeader.CopyTo(&req.Header)
			if r.httpClient.IsTLS {
				req.URI().SetScheme("https")
				req.URI().SetHostBytes(req.Header.Host())
			}

			for {
				select {
				case <-ctx.Done():
					return
				default:
				}
				if limiter != nil {
					// Wait returns an error e.g. once ctx is cancelled; loop
					// back so the ctx.Done check above ends the worker.
					if err := limiter.Wait(ctx); err != nil {
						continue
					}
				}
				if r.requests > 0 && atomic.AddInt64(&semaphore, -1) < 0 {
					cancelFunc()
					return
				}
				if r.clientOpt.bodyFile != "" {
					file, err := os.Open(r.clientOpt.bodyFile)
					if err != nil {
						rr := recordPool.Get().(*ReportRecord)
						rr.cost = 0
						// Pooled records may still carry a status code from
						// an earlier request; clear it for this failure.
						rr.code = 0
						rr.error = err.Error()
						rr.readBytes = atomic.LoadInt64(&r.readBytes)
						rr.writeBytes = atomic.LoadInt64(&r.writeBytes)
						r.recordChan <- rr
						continue
					}
					req.SetBodyStream(file, -1)
				} else {
					req.SetBodyRaw(r.clientOpt.bodyBytes)
				}
				resp.Reset()
				rr := recordPool.Get().(*ReportRecord)
				r.DoRequest(req, resp, rr)
				rr.readBytes = atomic.LoadInt64(&r.readBytes)
				rr.writeBytes = atomic.LoadInt64(&r.writeBytes)
				r.recordChan <- rr
			}
		}()
	}
	r.wg.Wait()
	r.closeRecord()
}
|