package udptransfer import ( "encoding/binary" "fmt" "io" "log" "math/rand" "net" "sort" "sync" "sync/atomic" "time" "github.com/cloudflare/golibs/bytepool" ) const ( _SO_BUF_SIZE = 8 << 20 ) var ( bpool bytepool.BytePool ) type Params struct { LocalAddr string Bandwidth int64 Mtu int IsServ bool FastRetransmit bool FlatTraffic bool EnablePprof bool Stacktrace bool Debug int } type connID struct { lid uint32 rid uint32 } type Endpoint struct { udpconn *net.UDPConn state int32 idSeq uint32 isServ bool listenChan chan *Conn lRegistry map[uint32]*Conn rRegistry map[string][]uint32 mlock sync.RWMutex timeout *time.Timer params Params } func (c *connID) setRid(b []byte) { c.rid = binary.BigEndian.Uint32(b[_MAGIC_SIZE+6:]) } func init() { bpool.Init(0, 2000) rand.Seed(NowNS()) } func NewEndpoint(p *Params) (*Endpoint, error) { set_debug_params(p) if p.Bandwidth <= 0 || p.Bandwidth > 100 { return nil, fmt.Errorf("bw->(0,100]") } conn, err := net.ListenPacket("udp", p.LocalAddr) if err != nil { return nil, err } e := &Endpoint{ udpconn: conn.(*net.UDPConn), idSeq: 1, isServ: p.IsServ, listenChan: make(chan *Conn, 1), lRegistry: make(map[uint32]*Conn), rRegistry: make(map[string][]uint32), timeout: time.NewTimer(0), params: *p, } if e.isServ { e.state = _S_EST0 } else { // client e.state = _S_EST1 e.idSeq = uint32(rand.Int31()) } e.params.Bandwidth = p.Bandwidth << 20 // mbps to bps e.udpconn.SetReadBuffer(_SO_BUF_SIZE) go e.internal_listen() return e, nil } func (e *Endpoint) internal_listen() { const rtmo = time.Duration(30*time.Second) var id connID for { //var buf = make([]byte, 1600) var buf = bpool.Get(1600) e.udpconn.SetReadDeadline(time.Now().Add(rtmo)) n, addr, err := e.udpconn.ReadFromUDP(buf) if err == nil && n >= _AH_SIZE { buf = buf[:n] e.getConnID(&id, buf) switch id.lid { case 0: // new connection if e.isServ { go e.acceptNewConn(id, addr, buf) } else { dumpb("drop", buf) } case _INVALID_SEQ: dumpb("drop invalid", buf) default: // old connection e.mlock.RLock() conn := e.lRegistry[id.lid] e.mlock.RUnlock() if conn != nil { e.dispatch(conn, buf) } else { e.resetPeer(addr, id) dumpb("drop null", buf) } } } else if err != nil { // idle process if nerr, y := err.(net.Error); y && nerr.Timeout() { e.idleProcess() continue } // other errors if atomic.LoadInt32(&e.state) == _S_FIN { return } else { log.Println("Error: read sock", err) } } } } func (e *Endpoint) idleProcess() { // recycle/shrink memory bpool.Drain() e.mlock.Lock() defer e.mlock.Unlock() // reset urgent for _, c := range e.lRegistry { c.outlock.Lock() if c.outQ.size() == 0 && c.urgent != 0 { c.urgent = 0 } c.outlock.Unlock() } } func (e *Endpoint) Dial(addr string) (*Conn, error) { rAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } e.mlock.Lock() e.idSeq++ id := connID{e.idSeq, 0} conn := NewConn(e, rAddr, id) e.lRegistry[id.lid] = conn e.mlock.Unlock() if atomic.LoadInt32(&e.state) != _S_FIN { err = conn.initConnection(nil) return conn, err } return nil, io.EOF } func (e *Endpoint) acceptNewConn(id connID, addr *net.UDPAddr, buf []byte) { rKey := addr.String() e.mlock.Lock() // map: remoteAddr => remoteConnID // filter duplicated syn packets if newArr := insertRid(e.rRegistry[rKey], id.rid); newArr != nil { e.rRegistry[rKey] = newArr } else { e.mlock.Unlock() log.Println("Warn: duplicated connection", addr) return } e.idSeq++ id.lid = e.idSeq conn := NewConn(e, addr, id) e.lRegistry[id.lid] = conn e.mlock.Unlock() err := conn.initConnection(buf) if err == nil { select { case e.listenChan <- conn: case <-time.After(_10ms): log.Println("Warn: no listener") } } else { e.removeConn(id, addr) log.Println("Error: init_connection", addr, err) } } func (e *Endpoint) removeConn(id connID, addr *net.UDPAddr) { e.mlock.Lock() delete(e.lRegistry, id.lid) rKey := addr.String() if newArr := deleteRid(e.rRegistry[rKey], id.rid); newArr != nil { if len(newArr) > 0 { e.rRegistry[rKey] = newArr } else { delete(e.rRegistry, rKey) } } e.mlock.Unlock() } // net.Listener func (e *Endpoint) Close() error { state := atomic.LoadInt32(&e.state) if state > 0 && atomic.CompareAndSwapInt32(&e.state, state, _S_FIN) { err := e.udpconn.Close() e.lRegistry = nil e.rRegistry = nil select { // release listeners case e.listenChan <- nil: default: } return err } return nil } // net.Listener func (e *Endpoint) Addr() net.Addr { return e.udpconn.LocalAddr() } // net.Listener func (e *Endpoint) Accept() (net.Conn, error) { if atomic.LoadInt32(&e.state) == _S_EST0 { return <-e.listenChan, nil } else { return nil, io.EOF } } func (e *Endpoint) Listen() *Conn { if atomic.LoadInt32(&e.state) == _S_EST0 { return <-e.listenChan } else { return nil } } // tmo in MS func (e *Endpoint) ListenTimeout(tmo int64) *Conn { if tmo <= 0 { return e.Listen() } if atomic.LoadInt32(&e.state) == _S_EST0 { select { case c := <-e.listenChan: return c case <-NewTimerChan(tmo): } } return nil } func (e *Endpoint) getConnID(idPtr *connID, buf []byte) { // TODO determine magic header magicAndLen := binary.BigEndian.Uint64(buf) if int(magicAndLen&0xFFff) == len(buf) { id := binary.BigEndian.Uint64(buf[_MAGIC_SIZE+2:]) idPtr.lid = uint32(id >> 32) idPtr.rid = uint32(id) } else { idPtr.lid = _INVALID_SEQ } } func (e *Endpoint) dispatch(c *Conn, buf []byte) { e.timeout.Reset(30*time.Millisecond) select { case c.evRecv <- buf: case <-e.timeout.C: log.Println("Warn: dispatch packet failed") } } func (e *Endpoint) resetPeer(addr *net.UDPAddr, id connID) { pk := &packet{flag: _F_FIN | _F_RESET} buf := nodeOf(pk).marshall(id) e.udpconn.WriteToUDP(buf, addr) } type u32Slice []uint32 func (p u32Slice) Len() int { return len(p) } func (p u32Slice) Less(i, j int) bool { return p[i] < p[j] } func (p u32Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // if the rid is not existed in array then insert it return new array func insertRid(array []uint32, rid uint32) []uint32 { if len(array) > 0 { pos := sort.Search(len(array), func(n int) bool { return array[n] >= rid }) if pos < len(array) && array[pos] == rid { return nil } } array = append(array, rid) sort.Sort(u32Slice(array)) return array } // if rid was existed in array then delete it return new array func deleteRid(array []uint32, rid uint32) []uint32 { if len(array) > 0 { pos := sort.Search(len(array), func(n int) bool { return array[n] >= rid }) if pos < len(array) && array[pos] == rid { newArray := make([]uint32, len(array)-1) n := copy(newArray, array[:pos]) copy(newArray[n:], array[pos+1:]) return newArray } } return nil }