kos/entry/gateway.go

192 lines
3.6 KiB
Go
Raw Normal View History

2023-04-23 17:57:36 +08:00
package entry
import (
"bytes"
"context"
"errors"
"github.com/sourcegraph/conc"
"io"
"net"
"sync/atomic"
"time"
)
// minFeatureLength is the number of leading bytes the gateway reads from each
// new connection to choose a listener. Bind rejects features shorter than this.
const minFeatureLength = 3
// Errors reported by Gateway.Bind.
var (
	// ErrShortFeature is returned when a feature has fewer than
	// minFeatureLength bytes.
	ErrShortFeature = errors.New("short feature")

	// ErrInvalidListener is returned when the supplied net.Listener is not a
	// *Listener.
	ErrInvalidListener = errors.New("invalid listener")
)
// Feature is the byte signature that selects a listener. An incoming
// connection is routed by comparing its first minFeatureLength bytes with the
// start of each registered Feature.
type Feature []byte

// listenerEntity pairs one registered feature with the listener that
// receives matching connections.
type listenerEntity struct {
	feature  Feature
	listener *Listener
}

// Gateway accepts TCP connections on a single address and dispatches each one
// to a registered Listener based on the connection's leading bytes.
type Gateway struct {
	ctx        context.Context         // derived in Start, cancelled by Stop
	cancelFunc context.CancelCauseFunc // cancels ctx
	l          net.Listener            // TCP listener created by Start
	ch         chan net.Conn           // queue from the accept goroutine to the worker
	address    string                  // address passed to net.Listen("tcp", ...)
	state      *State                  // counters, updated with sync/atomic
	waitGroup  conc.WaitGroup          // tracks the accept and worker goroutines
	listeners  []*listenerEntity       // feature -> listener bindings, matched in order
	exitFlag   int32                   // flipped 0 -> 1 by the first Stop call
}
// handle reads the first minFeatureLength bytes of conn and hands the
// connection to the first registered listener whose feature starts with
// those bytes.
//
// state.Concurrency is incremented up front. If the connection is not handed
// off, the counter is decremented again, state.Request.Discarded is
// incremented and conn is closed. That happens on a deadline error, on a read
// error, or when no feature matches. On hand-off the counter is left
// incremented.
// NOTE(review): presumably the wrapper from wrapConn decrements it on Close;
// confirm.
//
// NOTE(review): handle runs synchronously on the single worker goroutine. A
// client that sends fewer than minFeatureLength bytes therefore stalls every
// other queued connection for up to featureReadTimeout. It also reads
// gw.listeners without synchronization, which races with Bind/Release if
// those are called while the gateway is running.
func (gw *Gateway) handle(conn net.Conn) {
	// featureReadTimeout bounds how long a client may take to send its
	// feature bytes.
	const featureReadTimeout = 30 * time.Second

	// handedOff is only touched on this goroutine, so no atomics are needed.
	handedOff := false
	atomic.AddInt32(&gw.state.Concurrency, 1)
	defer func() {
		if !handedOff {
			atomic.AddInt32(&gw.state.Concurrency, -1)
			atomic.AddInt64(&gw.state.Request.Discarded, 1)
			_ = conn.Close()
		}
	}()

	if err := conn.SetReadDeadline(time.Now().Add(featureReadTimeout)); err != nil {
		return
	}
	feature := make([]byte, minFeatureLength)
	// io.ReadFull either fills the whole buffer or returns an error.
	if _, err := io.ReadFull(conn, feature); err != nil {
		return
	}
	// Clear the deadline so it does not apply to the listener's later reads.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		return
	}
	for _, l := range gw.listeners {
		if bytes.HasPrefix(l.feature, feature) {
			handedOff = true
			// The already-consumed feature bytes are passed along with conn.
			// NOTE(review): presumably wrapConn replays them to the reader;
			// confirm.
			l.listener.Receive(wrapConn(conn, gw.state, feature))
			return
		}
	}
}
// accept runs the accept loop. Each connection from gw.l is queued on gw.ch
// for the worker and counted in state.Request.Total. The loop returns when
// Accept fails, for example after Stop closes the listener. It also returns
// when the gateway context is done while a connection is waiting to be
// queued. state.Accepting is 1 while the loop runs and 0 after it returns.
func (gw *Gateway) accept() {
	atomic.StoreInt32(&gw.state.Accepting, 1)
	defer atomic.StoreInt32(&gw.state.Accepting, 0)
	for {
		conn, err := gw.l.Accept()
		if err != nil {
			// NOTE(review): every Accept error ends the loop, including
			// transient ones such as EMFILE; consider retrying with backoff.
			return
		}
		select {
		case gw.ch <- conn:
			atomic.AddInt64(&gw.state.Request.Total, 1)
		case <-gw.ctx.Done():
			// Nobody will ever receive this connection; close it instead of
			// leaking the socket.
			_ = conn.Close()
			return
		}
	}
}
// worker takes connections off gw.ch and processes them one at a time with
// handle. It returns when the gateway context is done or gw.ch is closed.
// state.Processing is 1 while it runs and 0 after it returns.
func (gw *Gateway) worker() {
	atomic.StoreInt32(&gw.state.Processing, 1)
	defer atomic.StoreInt32(&gw.state.Processing, 0)
	for {
		select {
		case <-gw.ctx.Done():
			return
		case conn, ok := <-gw.ch:
			if !ok {
				// A closed channel is always ready to receive. Without this
				// return the loop would spin forever.
				return
			}
			gw.handle(conn)
		}
	}
}
// Bind routes connections whose leading bytes match feature to listener,
// which must be a *Listener. If feature is already bound, its listener is
// replaced; otherwise a new binding is appended. A copy of feature is stored,
// so the caller may reuse or modify its slice afterwards.
//
// It returns ErrShortFeature if feature is shorter than minFeatureLength, and
// ErrInvalidListener if listener is not a *Listener.
//
// NOTE(review): handle compares only the first minFeatureLength bytes, so two
// features that share that prefix are both accepted here, but only the one
// bound first ever receives connections. gw.listeners is also read by handle
// without synchronization, so binding while the gateway runs is a data race.
func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
	if len(feature) < minFeatureLength {
		return ErrShortFeature
	}
	ls, ok := listener.(*Listener)
	if !ok {
		return ErrInvalidListener
	}
	for _, l := range gw.listeners {
		if bytes.Equal(l.feature, feature) {
			l.listener = ls
			return nil
		}
	}
	gw.listeners = append(gw.listeners, &listenerEntity{
		// Copy so later mutation of the caller's slice cannot change routing.
		feature:  bytes.Clone(feature),
		listener: ls,
	})
	return nil
}
// Apply creates a new Listener on the gateway's address and binds every
// feature of at least minFeatureLength bytes to it. Shorter features are
// skipped silently.
//
// Start must have succeeded first: the address comes from gw.l, and calling
// Addr on a nil gw.l panics.
//
// If a Bind call fails, its error is returned with a nil listener. Features
// bound before the failure remain registered.
func (gw *Gateway) Apply(feature ...Feature) (listener net.Listener, err error) {
	listener = newListener(gw.l.Addr())
	for _, code := range feature {
		if len(code) < minFeatureLength {
			continue
		}
		if err = gw.Bind(code, listener); err != nil {
			return nil, err
		}
	}
	return listener, nil
}
// Release removes every binding whose feature equals feature. Unknown
// features are ignored, and the bound listener itself is not closed.
//
// NOTE(review): handle reads gw.listeners on the worker goroutine without
// synchronization, so calling Release while the gateway is running races
// with that read.
func (gw *Gateway) Release(feature Feature) {
	// Filter in place rather than deleting from the slice being ranged over.
	// Deleting that way skips the element shifted into a removed slot.
	kept := gw.listeners[:0]
	for _, l := range gw.listeners {
		if !bytes.Equal(l.feature, feature) {
			kept = append(kept, l)
		}
	}
	// Nil out the vacated tail so removed entities, and the listeners they
	// reference, are not kept alive by the backing array.
	for i := len(kept); i < len(gw.listeners); i++ {
		gw.listeners[i] = nil
	}
	gw.listeners = kept
}
// State returns the gateway's live counter set. The pointer is shared with
// the running goroutines, not a snapshot. Those goroutines update its fields
// with sync/atomic, so readers should load them atomically too.
func (gw *Gateway) State() *State {
	current := gw.state
	return current
}
// Start listens on gw.address over TCP and launches the worker and accept
// goroutines.
//
// The gateway's internal context is derived from ctx. Cancelling ctx makes
// the worker return, and makes accept return the next time it tries to queue
// a connection. Only Stop closes the TCP listener itself.
//
// If net.Listen fails, the derived context is cancelled with that error so
// it is not leaked, and the error is returned unchanged.
func (gw *Gateway) Start(ctx context.Context) (err error) {
	gw.ctx, gw.cancelFunc = context.WithCancelCause(ctx)
	if gw.l, err = net.Listen("tcp", gw.address); err != nil {
		gw.cancelFunc(err)
		return err
	}
	gw.waitGroup.Go(gw.worker)
	gw.waitGroup.Go(gw.accept)
	return nil
}
// Stop shuts the gateway down. Only the first call has any effect; later
// calls return nil. The steps, in order:
//   - cancel the gateway context with io.ErrClosedPipe;
//   - close the TCP listener, returning its Close error;
//   - wait for the accept and worker goroutines to exit;
//   - close gw.ch, then close every queued connection that never reached the
//     worker, counting each one in state.Request.Discarded.
//
// Stop does not panic when Start was never called or failed, because the
// context and listener are nil-checked.
func (gw *Gateway) Stop() (err error) {
	if !atomic.CompareAndSwapInt32(&gw.exitFlag, 0, 1) {
		return nil
	}
	if gw.cancelFunc != nil {
		gw.cancelFunc(io.ErrClosedPipe)
	}
	if gw.l != nil {
		err = gw.l.Close()
	}
	gw.waitGroup.Wait()
	if gw.ch != nil {
		close(gw.ch)
		// Both the producer (accept) and the consumer (worker) have exited.
		// Drain the buffer so queued sockets are not leaked.
		for conn := range gw.ch {
			atomic.AddInt64(&gw.state.Request.Discarded, 1)
			_ = conn.Close()
		}
	}
	return err
}
// New returns a Gateway that will listen on address once Start is called.
// The address is passed to net.Listen with network "tcp". Accepted
// connections wait in a queue of up to 10 entries until the worker takes them.
func New(address string) *Gateway {
	return &Gateway{
		address: address,
		ch:      make(chan net.Conn, 10),
		state:   new(State),
	}
}