// Package entry multiplexes a single TCP listener across several in-package
// Listeners, routing each accepted connection by the first bytes it sends.
package entry

import (
	"bytes"
	"context"
	"errors"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/sourcegraph/conc"
)

const (
	// minFeatureLength is the number of leading bytes read from every routed
	// connection and compared against registered features.
	minFeatureLength = 3

	// featureReadTimeout bounds how long handle waits for the feature bytes
	// before discarding the connection.
	featureReadTimeout = 30 * time.Second

	// workerCount is the number of routing goroutines launched by Start.
	workerCount = 2

	// backlogSize is the capacity of the queue between accept and the workers.
	backlogSize = 10
)

var (
	// ErrShortFeature is returned by Bind when a feature is shorter than
	// minFeatureLength bytes.
	ErrShortFeature = errors.New("short feature")
	// ErrInvalidListener is returned by Bind when the listener is not a *Listener.
	ErrInvalidListener = errors.New("invalid listener")
)

type (
	// Feature is the byte prefix that selects the Listener a connection is
	// routed to. Only its first minFeatureLength bytes are compared when
	// routing.
	// NOTE(review): Bind treats features that differ only after the first
	// minFeatureLength bytes as distinct, but routing cannot tell them apart,
	// so the earliest bound one always wins.
	Feature []byte

	// listenerEntity pairs a registered feature with its target listener.
	listenerEntity struct {
		feature  Feature
		listener *Listener
	}

	// Gateway accepts TCP connections on one address and hands each of them
	// either to the direct listener or, by feature prefix, to a bound listener.
	// Create it with New.
	Gateway struct {
		ctx        context.Context
		cancelFunc context.CancelCauseFunc
		l          net.Listener
		ch         chan net.Conn
		address    string
		state      *State
		waitGroup  conc.WaitGroup

		// mu guards listeners and direct: Bind, Release and Direct may
		// change them while the accept and worker goroutines read them.
		mu        sync.RWMutex
		listeners []*listenerEntity
		direct    *Listener

		// exitFlag is set to 1 by the first call to Stop.
		exitFlag int32
	}
)

// handle reads the feature prefix from conn and passes the connection to the
// first bound listener whose feature starts with those bytes. If the prefix
// cannot be read within featureReadTimeout, or no listener matches, the
// connection is closed, the Concurrency counter taken here is released and
// the request is counted as discarded. On success the counter is left
// incremented and conn is handed over wrapped by wrapConn (presumably the
// wrapper releases it on close — confirm in wrapConn).
func (gw *Gateway) handle(conn net.Conn) {
	atomic.AddInt32(&gw.state.Concurrency, 1)

	handed := false
	defer func() {
		if !handed {
			atomic.AddInt32(&gw.state.Concurrency, -1)
			gw.state.IncRequestDiscarded(1)
			// Best-effort: the connection is being discarded anyway.
			_ = conn.Close()
		}
	}()

	feature := make([]byte, minFeatureLength)
	if err := conn.SetReadDeadline(time.Now().Add(featureReadTimeout)); err != nil {
		return
	}
	n, err := io.ReadFull(conn, feature)
	if err != nil {
		return
	}
	// Clear the deadline so the receiving listener's later reads are not cut off.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		return
	}

	ls := gw.lookup(feature[:n])
	if ls == nil {
		return
	}
	handed = true
	ls.Receive(wrapConn(conn, gw.state, feature[:n]))
}

// lookup returns the listener of the first bound entry whose feature starts
// with prefix, or nil when none matches.
func (gw *Gateway) lookup(prefix []byte) *Listener {
	gw.mu.RLock()
	defer gw.mu.RUnlock()
	for _, e := range gw.listeners {
		if bytes.HasPrefix(e.feature, prefix) {
			return e.listener
		}
	}
	return nil
}

// accept runs the accept loop until gw.l.Accept returns an error (for
// example after Stop closes the listener). Each connection goes to the direct
// listener when one is set; otherwise it is queued for the workers and
// counted with IncRequest. A connection accepted while the gateway is shutting
// down is closed instead of being leaked.
func (gw *Gateway) accept() {
	atomic.StoreInt32(&gw.state.Accepting, 1)
	defer atomic.StoreInt32(&gw.state.Accepting, 0)

	for {
		conn, err := gw.l.Accept()
		if err != nil {
			// NOTE(review): any Accept error, including a transient one,
			// ends the loop for good.
			return
		}

		gw.mu.RLock()
		direct := gw.direct
		gw.mu.RUnlock()
		if direct != nil {
			direct.Receive(conn)
			continue
		}

		select {
		case gw.ch <- conn:
			gw.state.IncRequest(1)
		case <-gw.ctx.Done():
			// Best-effort: nobody will ever route this connection.
			_ = conn.Close()
			return
		}
	}
}

// worker routes queued connections via handle until the gateway context is
// cancelled or the queue is closed.
// NOTE(review): every worker stores Processing=0 on exit, even if another
// worker is still running.
func (gw *Gateway) worker() {
	atomic.StoreInt32(&gw.state.Processing, 1)
	defer atomic.StoreInt32(&gw.state.Processing, 0)

	for {
		select {
		case <-gw.ctx.Done():
			return
		case conn, ok := <-gw.ch:
			if !ok {
				// Queue closed: nothing more can arrive, so stop rather than spin.
				return
			}
			gw.handle(conn)
		}
	}
}

// Direct makes every accepted connection bypass feature routing and go
// straight to l. It is a no-op unless l is a *Listener.
func (gw *Gateway) Direct(l net.Listener) {
	ls, ok := l.(*Listener)
	if !ok {
		return
	}
	gw.mu.Lock()
	gw.direct = ls
	gw.mu.Unlock()
}

// Bind routes connections whose leading bytes match feature to listener.
// If feature is already bound (byte-for-byte), its listener is replaced.
// feature must be at least minFeatureLength bytes long (else ErrShortFeature)
// and listener must be a *Listener (else ErrInvalidListener). The feature
// bytes are copied, so the caller may reuse its slice afterwards.
func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
	if len(feature) < minFeatureLength {
		return ErrShortFeature
	}
	ls, ok := listener.(*Listener)
	if !ok {
		return ErrInvalidListener
	}

	gw.mu.Lock()
	defer gw.mu.Unlock()
	for _, e := range gw.listeners {
		if bytes.Equal(e.feature, feature) {
			e.listener = ls
			return nil
		}
	}
	gw.listeners = append(gw.listeners, &listenerEntity{
		feature:  Feature(bytes.Clone(feature)),
		listener: ls,
	})
	return nil
}

// Apply creates a new Listener reporting the gateway's listen address and
// binds each given feature of at least minFeatureLength bytes to it; shorter
// features are silently skipped. The listener is always returned; err is the
// first Bind error, if any. Apply reads gw.l, so it must be called after a
// successful Start.
func (gw *Gateway) Apply(feature ...Feature) (listener net.Listener, err error) {
	listener = newListener(gw.l.Addr())
	for _, code := range feature {
		if len(code) < minFeatureLength {
			continue
		}
		if bindErr := gw.Bind(code, listener); bindErr != nil && err == nil {
			err = bindErr
		}
	}
	return listener, err
}

// Release removes the binding whose feature equals feature byte-for-byte.
// It is a no-op when no such binding exists.
func (gw *Gateway) Release(feature Feature) {
	gw.mu.Lock()
	defer gw.mu.Unlock()
	for i, e := range gw.listeners {
		if !bytes.Equal(e.feature, feature) {
			continue
		}
		last := len(gw.listeners) - 1
		copy(gw.listeners[i:], gw.listeners[i+1:])
		gw.listeners[last] = nil // drop the stale pointer so it can be collected
		gw.listeners = gw.listeners[:last]
		return // Bind keeps features unique, so there is at most one match
	}
}

// State returns the gateway's shared counters.
func (gw *Gateway) State() *State {
	return gw.state
}

// Start listens on the gateway's TCP address and launches workerCount
// routing workers plus the accept loop; Stop shuts them down. If listening
// fails, the derived context is cancelled with that error and it is returned.
func (gw *Gateway) Start(ctx context.Context) (err error) {
	gw.ctx, gw.cancelFunc = context.WithCancelCause(ctx)
	if gw.l, err = net.Listen("tcp", gw.address); err != nil {
		gw.cancelFunc(err)
		return err
	}
	for i := 0; i < workerCount; i++ {
		gw.waitGroup.Go(gw.worker)
	}
	gw.waitGroup.Go(gw.accept)
	return nil
}

// Stop cancels the gateway context, closes the network listener, waits for
// the accept loop and workers to exit, then closes any connections still
// queued for routing (counting each as discarded). Only the first call does
// anything; later calls return nil. It returns the listener's Close error.
// It does not panic if Start failed or was never called.
func (gw *Gateway) Stop() (err error) {
	if !atomic.CompareAndSwapInt32(&gw.exitFlag, 0, 1) {
		return nil
	}
	if gw.cancelFunc != nil {
		gw.cancelFunc(io.ErrClosedPipe)
	}
	if gw.l != nil {
		err = gw.l.Close()
	}
	gw.waitGroup.Wait()

	// All senders have exited, so closing is safe; drain what was never routed.
	close(gw.ch)
	for conn := range gw.ch {
		// Best-effort close of a connection that will never be handled.
		_ = conn.Close()
		gw.state.IncRequestDiscarded(1)
	}
	return err
}

// New returns a Gateway that will listen on the TCP address (as accepted by
// net.Listen) once Start is called.
func New(address string) *Gateway {
	return &Gateway{
		address: address,
		state:   &State{},
		ch:      make(chan net.Conn, backlogSize),
	}
}