add gateway support direct connect

This commit is contained in:
fancl 2023-05-31 11:14:15 +08:00
parent 8d716e837d
commit dc88ceb73d
8 changed files with 121 additions and 38 deletions

View File

@ -29,6 +29,7 @@ func main() {
svr := kos.Init(
kos.WithName("git.nspix.com/golang/test", "0.0.1"),
kos.WithServer(&subServer{}),
kos.WithDirectHttp(),
)
svr.Run()
}

View File

@ -37,6 +37,7 @@ type (
state *State
waitGroup conc.WaitGroup
listeners []*listenerEntity
direct *Listener
exitFlag int32
}
)
@ -45,12 +46,12 @@ func (gw *Gateway) handle(conn net.Conn) {
var (
n int
err error
successed int32
success int32
feature = make([]byte, minFeatureLength)
)
atomic.AddInt32(&gw.state.Concurrency, 1)
defer func() {
if atomic.LoadInt32(&successed) != 1 {
if atomic.LoadInt32(&success) != 1 {
atomic.AddInt32(&gw.state.Concurrency, -1)
atomic.AddInt64(&gw.state.Request.Discarded, 1)
_ = conn.Close()
@ -70,7 +71,7 @@ func (gw *Gateway) handle(conn net.Conn) {
}
for _, l := range gw.listeners {
if bytes.Compare(feature[:n], l.feature[:n]) == 0 {
atomic.StoreInt32(&successed, 1)
atomic.StoreInt32(&success, 1)
l.listener.Receive(wrapConn(conn, gw.state, feature[:n]))
return
}
@ -85,6 +86,10 @@ func (gw *Gateway) accept() {
for {
if conn, err := gw.l.Accept(); err != nil {
break
} else {
//give direct listener
if gw.direct != nil {
gw.direct.Receive(conn)
} else {
select {
case gw.ch <- conn:
@ -94,6 +99,7 @@ func (gw *Gateway) accept() {
}
}
}
}
}
func (gw *Gateway) worker() {
@ -113,6 +119,12 @@ func (gw *Gateway) worker() {
}
}
func (gw *Gateway) Direct(l net.Listener) {
if ls, ok := l.(*Listener); ok {
gw.direct = ls
}
}
func (gw *Gateway) Bind(feature Feature, listener net.Listener) (err error) {
var (
ok bool
@ -165,7 +177,9 @@ func (gw *Gateway) Start(ctx context.Context) (err error) {
if gw.l, err = net.Listen("tcp", gw.address); err != nil {
return
}
for i := 0; i < 2; i++ {
gw.waitGroup.Go(gw.worker)
}
gw.waitGroup.Go(gw.accept)
return
}

View File

@ -27,18 +27,18 @@ func (ctx *Context) reset(req *http.Request, res http.ResponseWriter, ps map[str
ctx.req, ctx.res, ctx.params = req, res, ps
}
func (c *Context) RealIp() string {
if ip := c.Request().Header.Get("X-Forwarded-For"); ip != "" {
func (ctx *Context) RealIp() string {
if ip := ctx.Request().Header.Get("X-Forwarded-For"); ip != "" {
i := strings.IndexAny(ip, ",")
if i > 0 {
return strings.TrimSpace(ip[:i])
}
return ip
}
if ip := c.Request().Header.Get("X-Real-IP"); ip != "" {
if ip := ctx.Request().Header.Get("X-Real-IP"); ip != "" {
return ip
}
ra, _, _ := net.SplitHostPort(c.Request().RemoteAddr)
ra, _, _ := net.SplitHostPort(ctx.Request().RemoteAddr)
return ra
}

View File

@ -20,6 +20,7 @@ type Server struct {
serve *http.Server
router *router.Router
middleware []Middleware
anyRequests map[string]http.Handler
}
func (svr *Server) applyContext() *Context {
@ -62,6 +63,13 @@ func (svr *Server) Use(middleware ...Middleware) {
svr.middleware = append(svr.middleware, middleware...)
}
func (svr *Server) Any(prefix string, handle http.Handler) {
if !strings.HasSuffix(prefix, "/") {
prefix = "/" + prefix
}
svr.anyRequests[prefix] = handle
}
func (svr *Server) Handle(method string, path string, cb HandleFunc, middleware ...Middleware) {
if method == "" {
method = http.MethodPost
@ -149,6 +157,12 @@ func (svr *Server) handleRequest(res http.ResponseWriter, req *http.Request) {
}
func (svr *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
for prefix, handle := range svr.anyRequests {
if strings.HasPrefix(request.URL.Path, prefix) {
handle.ServeHTTP(writer, request)
return
}
}
switch request.Method {
case http.MethodOptions:
svr.handleOption(writer, request)
@ -177,6 +191,7 @@ func New(ctx context.Context) *Server {
svr := &Server{
ctx: ctx,
router: router.New(),
anyRequests: make(map[string]http.Handler),
middleware: make([]Middleware, 0, 10),
}
return svr

View File

@ -18,7 +18,9 @@ type (
Port int
EnableDebug bool //开启调试模式
DisableHttp bool //禁用HTTP入口
EnableDirectHttp bool //启用HTTP直连模式
DisableCommand bool //禁用命令行入口
EnableDirectCommand bool //启用命令行直连模式
DisableStateApi bool //禁用系统状态接口
Metadata map[string]string //原数据
Context context.Context
@ -67,6 +69,20 @@ func WithDebug() Option {
}
}
func WithDirectHttp() Option {
return func(o *Options) {
o.DisableCommand = true
o.EnableDirectHttp = true
}
}
func WithDirectCommand() Option {
return func(o *Options) {
o.DisableHttp = true
o.EnableDirectCommand = true
}
}
func NewOptions() *Options {
opts := &Options{
Name: env.Get(EnvAppName, sys.Hostname()),
@ -75,7 +91,7 @@ func NewOptions() *Options {
Metadata: make(map[string]string),
Signals: []os.Signal{syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL},
}
opts.Port = int(env.Integer(EnvAppPort, 80))
opts.Port = int(env.Integer(18080, EnvAppPort, "HTTP_PORT", "KOS_PORT"))
opts.Address = env.Get(EnvAppAddress, ip.Internal())
return opts
}

View File

@ -8,6 +8,7 @@ import (
"git.nspix.com/golang/kos/entry"
"git.nspix.com/golang/kos/entry/cli"
"git.nspix.com/golang/kos/entry/http"
_ "git.nspix.com/golang/kos/pkg/cache"
"git.nspix.com/golang/kos/pkg/log"
"git.nspix.com/golang/kos/util/env"
"github.com/sourcegraph/conc"
@ -127,6 +128,9 @@ func (app *application) httpServe() (err error) {
select {
case err = <-errChan:
case <-timer.C:
if app.opts.EnableDirectHttp {
app.gateway.Direct(l)
}
}
return
}
@ -152,6 +156,9 @@ func (app *application) commandServe() (err error) {
select {
case err = <-errChan:
case <-timer.C:
if app.opts.EnableDirectCommand {
app.gateway.Direct(l)
}
}
return
}
@ -220,6 +227,7 @@ func (app *application) preStart() (err error) {
return
}
}
app.plugins.Range(func(key, value any) bool {
if plugin, ok := value.(Plugin); ok {
if err = plugin.BeforeStart(); err != nil {

23
util/env/env.go vendored
View File

@ -15,7 +15,19 @@ func Get(name string, val string) string {
}
}
func Integer(name string, val int64) int64 {
func Getter(val string, names ...string) string {
var (
value string
)
for _, name := range names {
if value = strings.TrimSpace(os.Getenv(name)); value != "" {
return value
}
}
return val
}
func Int(name string, val int64) int64 {
value := Get(name, "")
if n, err := strconv.ParseInt(value, 10, 64); err == nil {
return n
@ -24,6 +36,15 @@ func Integer(name string, val int64) int64 {
}
}
func Integer(val int64, names ...string) int64 {
value := Getter("", names...)
if n, err := strconv.ParseInt(value, 10, 64); err == nil {
return n
} else {
return val
}
}
func Float(name string, val float64) float64 {
value := Get(name, "")
if n, err := strconv.ParseFloat(value, 64); err == nil {

View File

@ -192,6 +192,14 @@ func Request(ctx context.Context, urlString string, response any, cbs ...Option)
return
}
func Do(ctx context.Context, req *http.Request, cbs ...Option) (res *http.Response, err error) {
opts := newOptions()
for _, cb := range cbs {
cb(opts)
}
return do(ctx, req, opts)
}
func do(ctx context.Context, req *http.Request, opts *Options) (res *http.Response, err error) {
if opts.Human {
if req.Header.Get("User-Agent") == "" {