package cli

import (
	"context"
	"errors"
	"fmt"
	"math"
	"net"
	"path"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.nobla.cn/golang/kos/util/env"
	"github.com/sourcegraph/conc"
)

var (
	// ctxPool recycles *Context values released by finished connections.
	ctxPool sync.Pool
)

// Server accepts connections on a net.Listener, reads frames from each
// connection and answers handshake, completion and command frames using the
// commands registered on its router.
type Server struct {
	// ctx is the parent context for command execution; per-frame timeouts
	// are derived from it.
	ctx context.Context
	// sequenceLocker guards sequence.
	sequenceLocker sync.Mutex
	// sequence is the last value returned by nextSequence.
	sequence int64
	// ctxMap holds the Context of every live connection, keyed by Context.Id,
	// so Shutdown can close them.
	ctxMap sync.Map
	// waitGroup tracks the per-connection goroutines started by serve.
	waitGroup conc.WaitGroup
	// middleware is applied to every handler (see wrapHandle).
	middleware []Middleware
	router     *Router
	l          net.Listener
	// exitFlag is set to 1 by Shutdown and cleared when Serve returns.
	exitFlag int32
}

// applyContext returns a *Context taken from ctxPool, or a fresh zero
// Context when the pool has none. The caller is expected to reset it.
func (svr *Server) applyContext() *Context {
	if ctx, ok := ctxPool.Get().(*Context); ok {
		return ctx
	}
	return &Context{}
}

// releaseContext returns ctx to ctxPool for reuse.
func (svr *Server) releaseContext(ctx *Context) {
	ctxPool.Put(ctx)
}

// execute resolves the command line carried in frame.Data against the router
// and runs the matching handler.
//
// When frame.Timeout > 0 the handler runs under a context derived from
// svr.ctx with that timeout; otherwise it runs under svr.ctx directly.
// Unknown commands, lookup failures and missing positional parameters are
// reported to the client through ctx.Error. The returned error is whatever
// ctx.Error or the handler returned.
func (svr *Server) execute(ctx *Context, frame *Frame) (err error) {
	cmd := string(frame.Data)
	tokens := strings.Fields(cmd)

	if frame.Timeout > 0 {
		// NOTE(review): frame.Timeout is converted straight to a Duration,
		// i.e. interpreted as nanoseconds — confirm against the client.
		childCtx, cancel := context.WithTimeout(svr.ctx, time.Duration(frame.Timeout))
		defer cancel()
		ctx.setContext(childCtx)
	} else {
		ctx.setContext(svr.ctx)
	}

	r, args, err := svr.router.Lookup(tokens)
	if err != nil {
		if errors.Is(err, ErrNotFound) {
			return ctx.Error(errNotFound, fmt.Sprintf("Command %s not found", cmd))
		}
		return ctx.Error(errExecuteFailed, err.Error())
	}
	// Every declared parameter must be bound to a positional argument.
	if len(r.params) > len(args) {
		return ctx.Error(errExecuteFailed, r.Usage())
	}

	var params map[string]string
	if len(r.params) > 0 {
		params = make(map[string]string, len(r.params))
		for i, name := range r.params {
			params[name] = args[i]
		}
	}
	ctx.setArgs(args)
	ctx.setParam(params)
	return r.command.Handle(ctx)
}

// nextSequence returns the next connection id. Ids start at 1 and wrap back
// to 1 after math.MaxInt64, so they are always positive.
func (svr *Server) nextSequence() int64 {
	svr.sequenceLocker.Lock()
	defer svr.sequenceLocker.Unlock()
	if svr.sequence >= math.MaxInt64 {
		// Reset to 0 so the increment below yields 1; resetting to 1 made the
		// first id after wrap-around 2.
		svr.sequence = 0
	}
	svr.sequence++
	return svr.sequence
}

// process serves one connection until a frame cannot be read or a frame
// cannot be answered. The connection's Context is registered in ctxMap for
// the lifetime of the loop. On exit the connection is closed, the Context is
// unregistered and returned to the pool.
func (svr *Server) process(conn net.Conn) {
	ctx := svr.applyContext()
	ctx.reset(svr.nextSequence(), conn)
	svr.ctxMap.Store(ctx.Id, ctx)
	defer func() {
		_ = conn.Close()
		svr.ctxMap.Delete(ctx.Id)
		svr.releaseContext(ctx)
	}()

	for {
		frame, err := readFrame(conn)
		if err != nil {
			return
		}
		// Record the sequence of the frame being answered (presumably used by
		// ctx.send to tag the response — confirm).
		ctx.seq = frame.Seq
		// A bare `break` inside the former switch only left the switch, so
		// the loop kept running after a failed send; stop serving instead.
		if err = svr.handleFrame(ctx, conn, frame); err != nil {
			return
		}
	}
}

// handleFrame answers a single frame according to its type; unknown frame
// types are ignored. A non-nil error (a failed send, or an error from command
// execution) tells process to drop the connection.
func (svr *Server) handleFrame(ctx *Context, conn net.Conn, frame *Frame) error {
	switch frame.Type {
	case PacketTypeHandshake:
		return ctx.send(responsePayload{
			Type: PacketTypeHandshake,
			Data: &Info{
				ID:         ctx.Id,
				Name:       env.Get("VOX_NAME", ""),
				Version:    env.Get("VOX_VERSION", ""),
				OS:         runtime.GOOS,
				ServerTime: time.Now(),
				RemoteAddr: conn.RemoteAddr().String(),
			},
		})
	case PacketTypeCompleter:
		return ctx.send(responsePayload{
			Type: PacketTypeCompleter,
			Data: svr.router.Completer(strings.Fields(string(frame.Data))...),
		})
	case PacketTypeCommand:
		return svr.execute(ctx, frame)
	default:
		return nil
	}
}

// serve accepts connections until Accept fails (for example after Shutdown
// closes the listener). Each connection is handled on its own goroutine
// tracked by waitGroup.
func (svr *Server) serve() {
	for {
		conn, err := svr.l.Accept()
		if err != nil {
			return
		}
		svr.waitGroup.Go(func() {
			svr.process(conn)
		})
	}
}

// wrapHandle builds the Command registered for pathname.
//
// Middleware is resolved on every invocation, so middleware added later via
// Use still applies. The per-route middleware wraps the server-wide
// middleware, which wraps cb. When desc is empty it is derived from pathname,
// e.g. "/user/list" becomes "user list".
func (svr *Server) wrapHandle(pathname, desc string, cb HandleFunc, middleware ...Middleware) Command {
	h := func(ctx *Context) (err error) {
		// Compose into a local: reassigning the captured cb here re-wrapped
		// the already wrapped handler on every call (middleware ran one more
		// time per call) and raced across connection goroutines.
		next := cb
		for i := len(svr.middleware) - 1; i >= 0; i-- {
			next = svr.middleware[i](next)
		}
		for i := len(middleware) - 1; i >= 0; i-- {
			next = middleware[i](next)
		}
		return next(ctx)
	}
	if desc == "" {
		desc = strings.ReplaceAll(strings.TrimPrefix(pathname, "/"), "/", " ")
	}
	return Command{
		Path:        pathname,
		Handle:      h,
		Description: desc,
	}
}

// Use appends server-wide middleware applied to every registered handler.
func (svr *Server) Use(middleware ...Middleware) {
	svr.middleware = append(svr.middleware, middleware...)
}

// Group registers each command under prefix (joined with path.Join), applying
// the given middleware to every command in the group.
func (svr *Server) Group(prefix string, commands []Command, middleware ...Middleware) {
	for _, cmd := range commands {
		svr.Handle(path.Join(prefix, cmd.Path), cmd.Description, cmd.Handle, middleware...)
	}
}

// Handle registers cb at pathname with an optional description and
// per-route middleware.
func (svr *Server) Handle(pathname string, desc string, cb HandleFunc, middleware ...Middleware) {
	svr.router.Handle(pathname, svr.wrapHandle(pathname, desc, cb, middleware...))
}

// Serve registers the built-in "/help" command and serves connections from l
// until Accept fails (e.g. because Shutdown closed the listener). It does not
// wait for in-flight connections; Shutdown does. The returned error is
// currently always nil.
func (svr *Server) Serve(l net.Listener) (err error) {
	svr.l = l
	svr.Handle("/help", "Display help information", func(ctx *Context) (err error) {
		return ctx.Success(svr.router.String())
	})
	svr.serve()
	// Clearing exitFlag lets Shutdown run again after a subsequent Serve.
	atomic.StoreInt32(&svr.exitFlag, 0)
	return
}

// Shutdown closes the listener and every live connection Context, then waits
// for all connection goroutines to finish. While exitFlag is set (from the
// first call until Serve returns) further calls return nil immediately. The
// first error encountered is returned.
func (svr *Server) Shutdown() (err error) {
	if !atomic.CompareAndSwapInt32(&svr.exitFlag, 0, 1) {
		return nil
	}
	if svr.l != nil {
		err = svr.l.Close()
	}
	svr.ctxMap.Range(func(_, value any) bool {
		if ctx, ok := value.(*Context); ok {
			// Keep the first failure; overwriting err with each Close result
			// let a successful close mask an earlier error.
			if cerr := ctx.Close(); cerr != nil && err == nil {
				err = cerr
			}
		}
		return true
	})
	svr.waitGroup.Wait()
	return err
}

// New creates a Server whose command contexts derive from ctx, with an empty
// root router and no middleware.
func New(ctx context.Context) *Server {
	return &Server{
		ctx:        ctx,
		router:     newRouter(""),
		middleware: make([]Middleware, 0, 10),
	}
}