add gateway support direct connect

This commit is contained in:
fancl 2023-05-31 11:14:15 +08:00
parent 8d716e837d
commit dc88ceb73d
8 changed files with 121 additions and 38 deletions

View File

@ -29,6 +29,7 @@ func main() {
svr := kos.Init( svr := kos.Init(
kos.WithName("git.nspix.com/golang/test", "0.0.1"), kos.WithName("git.nspix.com/golang/test", "0.0.1"),
kos.WithServer(&subServer{}), kos.WithServer(&subServer{}),
kos.WithDirectHttp(),
) )
svr.Run() svr.Run()
} }

View File

@ -37,20 +37,21 @@ type (
state *State state *State
waitGroup conc.WaitGroup waitGroup conc.WaitGroup
listeners []*listenerEntity listeners []*listenerEntity
direct *Listener
exitFlag int32 exitFlag int32
} }
) )
func (gw *Gateway) handle(conn net.Conn) { func (gw *Gateway) handle(conn net.Conn) {
var ( var (
n int n int
err error err error
successed int32 success int32
feature = make([]byte, minFeatureLength) feature = make([]byte, minFeatureLength)
) )
atomic.AddInt32(&gw.state.Concurrency, 1) atomic.AddInt32(&gw.state.Concurrency, 1)
defer func() { defer func() {
if atomic.LoadInt32(&successed) != 1 { if atomic.LoadInt32(&success) != 1 {
atomic.AddInt32(&gw.state.Concurrency, -1) atomic.AddInt32(&gw.state.Concurrency, -1)
atomic.AddInt64(&gw.state.Request.Discarded, 1) atomic.AddInt64(&gw.state.Request.Discarded, 1)
_ = conn.Close() _ = conn.Close()
@ -70,7 +71,7 @@ func (gw *Gateway) handle(conn net.Conn) {
} }
for _, l := range gw.listeners { for _, l := range gw.listeners {
if bytes.Compare(feature[:n], l.feature[:n]) == 0 { if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
atomic.StoreInt32(&successed, 1) atomic.StoreInt32(&success, 1)
l.listener.Receive(wrapConn(conn, gw.state, feature[:n])) l.listener.Receive(wrapConn(conn, gw.state, feature[:n]))
return return
} }
@ -86,11 +87,16 @@ func (gw *Gateway) accept() {
if conn, err := gw.l.Accept(); err != nil { if conn, err := gw.l.Accept(); err != nil {
break break
} else { } else {
select { //give direct listener
case gw.ch <- conn: if gw.direct != nil {
atomic.AddInt64(&gw.state.Request.Total, 1) gw.direct.Receive(conn)
case <-gw.ctx.Done(): } else {
return select {
case gw.ch <- conn:
atomic.AddInt64(&gw.state.Request.Total, 1)
case <-gw.ctx.Done():
return
}
} }
} }
} }
@ -113,6 +119,12 @@ func (gw *Gateway) worker() {
} }
} }
func (gw *Gateway) Direct(l net.Listener) {
if ls, ok := l.(*Listener); ok {
gw.direct = ls
}
}
func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) { func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
var ( var (
ok bool ok bool
@ -165,7 +177,9 @@ func (gw *Gateway) Start(ctx context.Context) (err error) {
if gw.l, err = net.Listen("tcp", gw.address); err != nil { if gw.l, err = net.Listen("tcp", gw.address); err != nil {
return return
} }
gw.waitGroup.Go(gw.worker) for i := 0; i < 2; i++ {
gw.waitGroup.Go(gw.worker)
}
gw.waitGroup.Go(gw.accept) gw.waitGroup.Go(gw.accept)
return return
} }

View File

@ -27,18 +27,18 @@ func (ctx *Context) reset(req *http.Request, res http.ResponseWriter, ps map[str
ctx.req, ctx.res, ctx.params = req, res, ps ctx.req, ctx.res, ctx.params = req, res, ps
} }
func (c *Context) RealIp() string { func (ctx *Context) RealIp() string {
if ip := c.Request().Header.Get("X-Forwarded-For"); ip != "" { if ip := ctx.Request().Header.Get("X-Forwarded-For"); ip != "" {
i := strings.IndexAny(ip, ",") i := strings.IndexAny(ip, ",")
if i > 0 { if i > 0 {
return strings.TrimSpace(ip[:i]) return strings.TrimSpace(ip[:i])
} }
return ip return ip
} }
if ip := c.Request().Header.Get("X-Real-IP"); ip != "" { if ip := ctx.Request().Header.Get("X-Real-IP"); ip != "" {
return ip return ip
} }
ra, _, _ := net.SplitHostPort(c.Request().RemoteAddr) ra, _, _ := net.SplitHostPort(ctx.Request().RemoteAddr)
return ra return ra
} }

View File

@ -16,10 +16,11 @@ var (
) )
type Server struct { type Server struct {
ctx context.Context ctx context.Context
serve *http.Server serve *http.Server
router *router.Router router *router.Router
middleware []Middleware middleware []Middleware
anyRequests map[string]http.Handler
} }
func (svr *Server) applyContext() *Context { func (svr *Server) applyContext() *Context {
@ -62,6 +63,13 @@ func (svr *Server) Use(middleware ...Middleware) {
svr.middleware = append(svr.middleware, middleware...) svr.middleware = append(svr.middleware, middleware...)
} }
func (svr *Server) Any(prefix string, handle http.Handler) {
if !strings.HasSuffix(prefix, "/") {
prefix = "/" + prefix
}
svr.anyRequests[prefix] = handle
}
func (svr *Server) Handle(method string, path string, cb HandleFunc, middleware ...Middleware) { func (svr *Server) Handle(method string, path string, cb HandleFunc, middleware ...Middleware) {
if method == "" { if method == "" {
method = http.MethodPost method = http.MethodPost
@ -149,6 +157,12 @@ func (svr *Server) handleRequest(res http.ResponseWriter, req *http.Request) {
} }
func (svr *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (svr *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
for prefix, handle := range svr.anyRequests {
if strings.HasPrefix(request.URL.Path, prefix) {
handle.ServeHTTP(writer, request)
return
}
}
switch request.Method { switch request.Method {
case http.MethodOptions: case http.MethodOptions:
svr.handleOption(writer, request) svr.handleOption(writer, request)
@ -175,9 +189,10 @@ func (svr *Server) Shutdown() (err error) {
func New(ctx context.Context) *Server { func New(ctx context.Context) *Server {
svr := &Server{ svr := &Server{
ctx: ctx, ctx: ctx,
router: router.New(), router: router.New(),
middleware: make([]Middleware, 0, 10), anyRequests: make(map[string]http.Handler),
middleware: make([]Middleware, 0, 10),
} }
return svr return svr
} }

View File

@ -12,19 +12,21 @@ import (
type ( type (
Options struct { Options struct {
Name string Name string
Version string Version string
Address string Address string
Port int Port int
EnableDebug bool //开启调试模式 EnableDebug bool //开启调试模式
DisableHttp bool //禁用HTTP入口 DisableHttp bool //禁用HTTP入口
DisableCommand bool //禁用命令行入口 EnableDirectHttp bool //启用HTTP直连模式
DisableStateApi bool //禁用系统状态接口 DisableCommand bool //禁用命令行入口
Metadata map[string]string //原数据 EnableDirectCommand bool //启用命令行直连模式
Context context.Context DisableStateApi bool //禁用系统状态接口
Signals []os.Signal Metadata map[string]string //原数据
server Server Context context.Context
shortName string Signals []os.Signal
server Server
shortName string
} }
Option func(o *Options) Option func(o *Options)
@ -67,6 +69,20 @@ func WithDebug() Option {
} }
} }
func WithDirectHttp() Option {
return func(o *Options) {
o.DisableCommand = true
o.EnableDirectHttp = true
}
}
func WithDirectCommand() Option {
return func(o *Options) {
o.DisableHttp = true
o.EnableDirectCommand = true
}
}
func NewOptions() *Options { func NewOptions() *Options {
opts := &Options{ opts := &Options{
Name: env.Get(EnvAppName, sys.Hostname()), Name: env.Get(EnvAppName, sys.Hostname()),
@ -75,7 +91,7 @@ func NewOptions() *Options {
Metadata: make(map[string]string), Metadata: make(map[string]string),
Signals: []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL}, Signals: []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL},
} }
opts.Port = int(env.Integer(EnvAppPort, 80)) opts.Port = int(env.Integer(18080, EnvAppPort, "HTTP_PORT", "KOS_PORT"))
opts.Address = env.Get(EnvAppAddress, ip.Internal()) opts.Address = env.Get(EnvAppAddress, ip.Internal())
return opts return opts
} }

View File

@ -8,6 +8,7 @@ import (
"git.nspix.com/golang/kos/entry" "git.nspix.com/golang/kos/entry"
"git.nspix.com/golang/kos/entry/cli" "git.nspix.com/golang/kos/entry/cli"
"git.nspix.com/golang/kos/entry/http" "git.nspix.com/golang/kos/entry/http"
_ "git.nspix.com/golang/kos/pkg/cache"
"git.nspix.com/golang/kos/pkg/log" "git.nspix.com/golang/kos/pkg/log"
"git.nspix.com/golang/kos/util/env" "git.nspix.com/golang/kos/util/env"
"github.com/sourcegraph/conc" "github.com/sourcegraph/conc"
@ -127,6 +128,9 @@ func (app *application) httpServe() (err error) {
select { select {
case err = <-errChan: case err = <-errChan:
case <-timer.C: case <-timer.C:
if app.opts.EnableDirectHttp {
app.gateway.Direct(l)
}
} }
return return
} }
@ -152,6 +156,9 @@ func (app *application) commandServe() (err error) {
select { select {
case err = <-errChan: case err = <-errChan:
case <-timer.C: case <-timer.C:
if app.opts.EnableDirectCommand {
app.gateway.Direct(l)
}
} }
return return
} }
@ -220,6 +227,7 @@ func (app *application) preStart() (err error) {
return return
} }
} }
app.plugins.Range(func(key, value any) bool { app.plugins.Range(func(key, value any) bool {
if plugin, ok := value.(Plugin); ok { if plugin, ok := value.(Plugin); ok {
if err = plugin.BeforeStart(); err != nil { if err = plugin.BeforeStart(); err != nil {

23
util/env/env.go vendored
View File

@ -15,7 +15,19 @@ func Get(name string, val string) string {
} }
} }
func Integer(name string, val int64) int64 { func Getter(val string, names ...string) string {
var (
value string
)
for _, name := range names {
if value = strings.TrimSpace(os.Getenv(name)); value != "" {
return value
}
}
return val
}
func Int(name string, val int64) int64 {
value := Get(name, "") value := Get(name, "")
if n, err := strconv.ParseInt(value, 10, 64); err == nil { if n, err := strconv.ParseInt(value, 10, 64); err == nil {
return n return n
@ -24,6 +36,15 @@ func Integer(name string, val int64) int64 {
} }
} }
func Integer(val int64, names ...string) int64 {
value := Getter("", names...)
if n, err := strconv.ParseInt(value, 10, 64); err == nil {
return n
} else {
return val
}
}
func Float(name string, val float64) float64 { func Float(name string, val float64) float64 {
value := Get(name, "") value := Get(name, "")
if n, err := strconv.ParseFloat(value, 64); err == nil { if n, err := strconv.ParseFloat(value, 64); err == nil {

View File

@ -192,6 +192,14 @@ func Request(ctx context.Context, urlString string, response any, cbs ...Option)
return return
} }
func Do(ctx context.Context, req *http.Request, cbs ...Option) (res *http.Response, err error) {
opts := newOptions()
for _, cb := range cbs {
cb(opts)
}
return do(ctx, req, opts)
}
func do(ctx context.Context, req *http.Request, opts *Options) (res *http.Response, err error) { func do(ctx context.Context, req *http.Request, opts *Options) (res *http.Response, err error) {
if opts.Human { if opts.Human {
if req.Header.Get("User-Agent") == "" { if req.Header.Get("User-Agent") == "" {