207 lines
3.8 KiB
Go
207 lines
3.8 KiB
Go
package entry
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/sourcegraph/conc"
|
|
)
|
|
|
|
// minFeatureLength is the number of leading bytes read from every incoming
// connection to choose a listener; Bind rejects features shorter than this.
const minFeatureLength = 3
|
|
|
|
// ErrShortFeature is returned by Bind when a feature has fewer than
// minFeatureLength bytes.
var ErrShortFeature = errors.New("short feature")

// ErrInvalidListener is returned by Bind when the supplied net.Listener is
// not a *Listener.
var ErrInvalidListener = errors.New("invalid listener")
|
|
|
|
type (
|
|
Feature []byte
|
|
|
|
listenerEntity struct {
|
|
feature Feature
|
|
listener *Listener
|
|
}
|
|
|
|
Gateway struct {
|
|
ctx context.Context
|
|
cancelFunc context.CancelCauseFunc
|
|
l net.Listener
|
|
ch chan net.Conn
|
|
address string
|
|
state *State
|
|
waitGroup conc.WaitGroup
|
|
listeners []*listenerEntity
|
|
direct *Listener
|
|
exitFlag int32
|
|
}
|
|
)
|
|
|
|
func (gw *Gateway) handle(conn net.Conn) {
|
|
var (
|
|
n int
|
|
err error
|
|
success int32
|
|
feature = make([]byte, minFeatureLength)
|
|
)
|
|
atomic.AddInt32(&gw.state.Concurrency, 1)
|
|
defer func() {
|
|
if atomic.LoadInt32(&success) != 1 {
|
|
atomic.AddInt32(&gw.state.Concurrency, -1)
|
|
gw.state.IncRequestDiscarded(1)
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
//set deadline
|
|
if err = conn.SetReadDeadline(time.Now().Add(time.Second * 30)); err != nil {
|
|
return
|
|
}
|
|
//read feature
|
|
if n, err = io.ReadFull(conn, feature); err != nil {
|
|
return
|
|
}
|
|
//reset deadline
|
|
if err = conn.SetReadDeadline(time.Time{}); err != nil {
|
|
return
|
|
}
|
|
for _, l := range gw.listeners {
|
|
if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
|
|
atomic.StoreInt32(&success, 1)
|
|
l.listener.Receive(wrapConn(conn, gw.state, feature[:n]))
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (gw *Gateway) accept() {
|
|
atomic.StoreInt32(&gw.state.Accepting, 1)
|
|
defer func() {
|
|
atomic.StoreInt32(&gw.state.Accepting, 0)
|
|
}()
|
|
for {
|
|
if conn, err := gw.l.Accept(); err != nil {
|
|
break
|
|
} else {
|
|
//give direct listener
|
|
if gw.direct != nil {
|
|
gw.direct.Receive(conn)
|
|
} else {
|
|
select {
|
|
case gw.ch <- conn:
|
|
gw.state.IncRequest(1)
|
|
case <-gw.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (gw *Gateway) worker() {
|
|
atomic.StoreInt32(&gw.state.Processing, 1)
|
|
defer func() {
|
|
atomic.StoreInt32(&gw.state.Processing, 0)
|
|
}()
|
|
for {
|
|
select {
|
|
case <-gw.ctx.Done():
|
|
return
|
|
case conn, ok := <-gw.ch:
|
|
if ok {
|
|
gw.handle(conn)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (gw *Gateway) Direct(l net.Listener) {
|
|
if ls, ok := l.(*Listener); ok {
|
|
gw.direct = ls
|
|
}
|
|
}
|
|
|
|
func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
|
|
var (
|
|
ok bool
|
|
ls *Listener
|
|
)
|
|
if len(feature) < minFeatureLength {
|
|
return ErrShortFeature
|
|
}
|
|
if ls, ok = listener.(*Listener); !ok {
|
|
return ErrInvalidListener
|
|
}
|
|
for _, l := range gw.listeners {
|
|
if bytes.Compare(l.feature, feature) == 0 {
|
|
l.listener = ls
|
|
return
|
|
}
|
|
}
|
|
gw.listeners = append(gw.listeners, &listenerEntity{
|
|
feature: feature,
|
|
listener: ls,
|
|
})
|
|
return
|
|
}
|
|
|
|
func (gw *Gateway) Apply(feature ...Feature) (listener net.Listener, err error) {
|
|
listener = newListener(gw.l.Addr())
|
|
for _, code := range feature {
|
|
if len(code) < minFeatureLength {
|
|
continue
|
|
}
|
|
err = gw.Bind(code, listener)
|
|
}
|
|
return listener, nil
|
|
}
|
|
|
|
func (gw *Gateway) Release(feature Feature) {
|
|
for i, l := range gw.listeners {
|
|
if bytes.Compare(l.feature, feature) == 0 {
|
|
gw.listeners = append(gw.listeners[:i], gw.listeners[i+1:]...)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (gw *Gateway) State() *State {
|
|
return gw.state
|
|
}
|
|
|
|
func (gw *Gateway) Start(ctx context.Context) (err error) {
|
|
gw.ctx, gw.cancelFunc = context.WithCancelCause(ctx)
|
|
if gw.l, err = net.Listen("tcp", gw.address); err != nil {
|
|
return
|
|
}
|
|
for i := 0; i < 2; i++ {
|
|
gw.waitGroup.Go(gw.worker)
|
|
}
|
|
gw.waitGroup.Go(gw.accept)
|
|
return
|
|
}
|
|
|
|
func (gw *Gateway) Stop() (err error) {
|
|
if !atomic.CompareAndSwapInt32(&gw.exitFlag, 0, 1) {
|
|
return
|
|
}
|
|
gw.cancelFunc(io.ErrClosedPipe)
|
|
err = gw.l.Close()
|
|
gw.waitGroup.Wait()
|
|
close(gw.ch)
|
|
return
|
|
}
|
|
|
|
func New(address string) *Gateway {
|
|
gw := &Gateway{
|
|
address: address,
|
|
state: &State{},
|
|
ch: make(chan net.Conn, 10),
|
|
}
|
|
return gw
|
|
}
|