diff --git a/cmd/main.go b/cmd/main.go index 6f4a353..f6a6b12 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -29,6 +29,7 @@ func main() { svr := kos.Init( kos.WithName("git.nspix.com/golang/test", "0.0.1"), kos.WithServer(&subServer{}), + kos.WithDirectHttp(), ) svr.Run() } diff --git a/entry/gateway.go b/entry/gateway.go index d8ef4c9..002ae73 100644 --- a/entry/gateway.go +++ b/entry/gateway.go @@ -37,20 +37,21 @@ type ( state *State waitGroup conc.WaitGroup listeners []*listenerEntity + direct *Listener exitFlag int32 } ) func (gw *Gateway) handle(conn net.Conn) { var ( - n int - err error - successed int32 - feature = make([]byte, minFeatureLength) + n int + err error + success int32 + feature = make([]byte, minFeatureLength) ) atomic.AddInt32(&gw.state.Concurrency, 1) defer func() { - if atomic.LoadInt32(&successed) != 1 { + if atomic.LoadInt32(&success) != 1 { atomic.AddInt32(&gw.state.Concurrency, -1) atomic.AddInt64(&gw.state.Request.Discarded, 1) _ = conn.Close() @@ -70,7 +71,7 @@ func (gw *Gateway) handle(conn net.Conn) { } for _, l := range gw.listeners { if bytes.Compare(feature[:n], l.feature[:n]) == 0 { - atomic.StoreInt32(&successed, 1) + atomic.StoreInt32(&success, 1) l.listener.Receive(wrapConn(conn, gw.state, feature[:n])) return } @@ -86,11 +87,16 @@ func (gw *Gateway) accept() { if conn, err := gw.l.Accept(); err != nil { break } else { - select { - case gw.ch <- conn: - atomic.AddInt64(&gw.state.Request.Total, 1) - case <-gw.ctx.Done(): - return + //give direct listener + if gw.direct != nil { + gw.direct.Receive(conn) + } else { + select { + case gw.ch <- conn: + atomic.AddInt64(&gw.state.Request.Total, 1) + case <-gw.ctx.Done(): + return + } } } } @@ -113,6 +119,12 @@ func (gw *Gateway) worker() { } } +func (gw *Gateway) Direct(l net.Listener) { + if ls, ok := l.(*Listener); ok { + gw.direct = ls + } +} + func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) { var ( ok bool @@ -165,7 +177,9 @@ func (gw *Gateway) Start(ctx context.Context) (err error) { if gw.l, err = net.Listen("tcp", gw.address); err != nil { return } - gw.waitGroup.Go(gw.worker) + for i := 0; i < 2; i++ { + gw.waitGroup.Go(gw.worker) + } gw.waitGroup.Go(gw.accept) return } diff --git a/entry/http/context.go b/entry/http/context.go index 72bc018..18cb232 100644 --- a/entry/http/context.go +++ b/entry/http/context.go @@ -27,18 +27,18 @@ func (ctx *Context) reset(req *http.Request, res http.ResponseWriter, ps map[str ctx.req, ctx.res, ctx.params = req, res, ps } -func (c *Context) RealIp() string { - if ip := c.Request().Header.Get("X-Forwarded-For"); ip != "" { +func (ctx *Context) RealIp() string { + if ip := ctx.Request().Header.Get("X-Forwarded-For"); ip != "" { i := strings.IndexAny(ip, ",") if i > 0 { return strings.TrimSpace(ip[:i]) } return ip } - if ip := c.Request().Header.Get("X-Real-IP"); ip != "" { + if ip := ctx.Request().Header.Get("X-Real-IP"); ip != "" { return ip } - ra, _, _ := net.SplitHostPort(c.Request().RemoteAddr) + ra, _, _ := net.SplitHostPort(ctx.Request().RemoteAddr) return ra } diff --git a/entry/http/server.go b/entry/http/server.go index e6ec8c3..bc8752e 100644 --- a/entry/http/server.go +++ b/entry/http/server.go @@ -16,10 +16,11 @@ var ( ) type Server struct { - ctx context.Context - serve *http.Server - router *router.Router - middleware []Middleware + ctx context.Context + serve *http.Server + router *router.Router + middleware []Middleware + anyRequests map[string]http.Handler } func (svr *Server) applyContext() *Context { @@ -62,6 +63,13 @@ func (svr *Server) Use(middleware ...Middleware) { svr.middleware = append(svr.middleware, middleware...) } +func (svr *Server) Any(prefix string, handle http.Handler) { + if !strings.HasSuffix(prefix, "/") { + prefix = "/" + prefix + } + svr.anyRequests[prefix] = handle +} + func (svr *Server) Handle(method string, path string, cb HandleFunc, middleware ...Middleware) { if method == "" { method = http.MethodPost @@ -149,6 +157,12 @@ func (svr *Server) handleRequest(res http.ResponseWriter, req *http.Request) { } func (svr *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + for prefix, handle := range svr.anyRequests { + if strings.HasPrefix(request.URL.Path, prefix) { + handle.ServeHTTP(writer, request) + return + } + } switch request.Method { case http.MethodOptions: svr.handleOption(writer, request) @@ -175,9 +189,10 @@ func (svr *Server) Shutdown() (err error) { func New(ctx context.Context) *Server { svr := &Server{ - ctx: ctx, - router: router.New(), - middleware: make([]Middleware, 0, 10), + ctx: ctx, + router: router.New(), + anyRequests: make(map[string]http.Handler), + middleware: make([]Middleware, 0, 10), } return svr } diff --git a/options.go b/options.go index f3ba187..8fe679f 100644 --- a/options.go +++ b/options.go @@ -12,19 +12,21 @@ import ( type ( Options struct { - Name string - Version string - Address string - Port int - EnableDebug bool //开启调试模式 - DisableHttp bool //禁用HTTP入口 - DisableCommand bool //禁用命令行入口 - DisableStateApi bool //禁用系统状态接口 - Metadata map[string]string //原数据 - Context context.Context - Signals []os.Signal - server Server - shortName string + Name string + Version string + Address string + Port int + EnableDebug bool //开启调试模式 + DisableHttp bool //禁用HTTP入口 + EnableDirectHttp bool //启用HTTP直连模式 + DisableCommand bool //禁用命令行入口 + EnableDirectCommand bool //启用命令行直连模式 + DisableStateApi bool //禁用系统状态接口 + Metadata map[string]string //原数据 + Context context.Context + Signals []os.Signal + server Server + shortName string } Option func(o *Options) @@ -67,6 +69,20 @@ func WithDebug() Option { } } +func WithDirectHttp() Option { + return func(o *Options) { + o.DisableCommand = true + o.EnableDirectHttp = true + } +} + +func WithDirectCommand() Option { + return func(o *Options) { + o.DisableHttp = true + o.EnableDirectCommand = true + } +} + func NewOptions() *Options { opts := &Options{ Name: env.Get(EnvAppName, sys.Hostname()), @@ -75,7 +91,7 @@ func NewOptions() *Options { Metadata: make(map[string]string), Signals: []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL}, } - opts.Port = int(env.Integer(EnvAppPort, 80)) + opts.Port = int(env.Integer(18080, EnvAppPort, "HTTP_PORT", "KOS_PORT")) opts.Address = env.Get(EnvAppAddress, ip.Internal()) return opts } diff --git a/service.go b/service.go index 0fabe36..2cf95a6 100644 --- a/service.go +++ b/service.go @@ -8,6 +8,7 @@ import ( "git.nspix.com/golang/kos/entry" "git.nspix.com/golang/kos/entry/cli" "git.nspix.com/golang/kos/entry/http" + _ "git.nspix.com/golang/kos/pkg/cache" "git.nspix.com/golang/kos/pkg/log" "git.nspix.com/golang/kos/util/env" "github.com/sourcegraph/conc" @@ -127,6 +128,9 @@ func (app *application) httpServe() (err error) { select { case err = <-errChan: case <-timer.C: + if app.opts.EnableDirectHttp { + app.gateway.Direct(l) + } } return } @@ -152,6 +156,9 @@ func (app *application) commandServe() (err error) { select { case err = <-errChan: case <-timer.C: + if app.opts.EnableDirectCommand { + app.gateway.Direct(l) + } } return } @@ -220,6 +227,7 @@ func (app *application) preStart() (err error) { return } } + app.plugins.Range(func(key, value any) bool { if plugin, ok := value.(Plugin); ok { if err = plugin.BeforeStart(); err != nil { diff --git a/util/env/env.go b/util/env/env.go index 303c3a5..f4e303f 100644 --- a/util/env/env.go +++ b/util/env/env.go @@ -15,7 +15,19 @@ func Get(name string, val string) string { } } -func Integer(name string, val int64) int64 { +func Getter(val string, names ...string) string { + var ( + value string + ) + for _, name := range names { + if value = strings.TrimSpace(os.Getenv(name)); value != "" { + return value + } + } + return val +} + +func Int(name string, val int64) int64 { value := Get(name, "") if n, err := strconv.ParseInt(value, 10, 64); err == nil { return n @@ -24,6 +36,15 @@ func Integer(name string, val int64) int64 { } } +func Integer(val int64, names ...string) int64 { + value := Getter("", names...) + if n, err := strconv.ParseInt(value, 10, 64); err == nil { + return n + } else { + return val + } +} + func Float(name string, val float64) float64 { value := Get(name, "") if n, err := strconv.ParseFloat(value, 64); err == nil { diff --git a/util/fetch/fetch.go b/util/fetch/fetch.go index b48bda3..32b6fde 100644 --- a/util/fetch/fetch.go +++ b/util/fetch/fetch.go @@ -192,6 +192,14 @@ func Request(ctx context.Context, urlString string, response any, cbs ...Option) return } +func Do(ctx context.Context, req *http.Request, cbs ...Option) (res *http.Response, err error) { + opts := newOptions() + for _, cb := range cbs { + cb(opts) + } + return do(ctx, req, opts) +} + func do(ctx context.Context, req *http.Request, opts *Options) (res *http.Response, err error) { if opts.Human { if req.Header.Get("User-Agent") == "" {