package cli

import (
	"context"
	"fmt"
	"math"
	"net"
	"net/url"
	"os"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.nobla.cn/golang/aeus/metadata"
	"git.nobla.cn/golang/aeus/middleware"
	"git.nobla.cn/golang/aeus/pkg/errors"
	"git.nobla.cn/golang/aeus/pkg/logger"
	netutil "git.nobla.cn/golang/aeus/pkg/net"
)

// Server accepts connections on a network listener, reads frames from each
// connection and answers handshake, completion and command frames. Command
// frames are resolved through the router and run behind the registered
// middleware chain.
type Server struct {
	// ctx is the context passed to Start; every command context derives from it.
	ctx    context.Context
	router *Router
	opts   *options
	// listener is created lazily by createListener and closed by Stop.
	listener net.Listener
	// sequenceLocker guards sequence, the last connection ID handed out.
	sequenceLocker sync.Mutex
	sequence       int64
	// ctxMap maps connection ID -> *Context for every live connection so
	// Stop can close them.
	ctxMap sync.Map
	// uri.Host is set to the bound listener address by createListener.
	uri *url.URL
	// exitFlag is set to 1 (atomically) by Stop; serve checks it to tell a
	// shutdown from a genuine Accept failure.
	exitFlag   int32
	middleware []middleware.Middleware
	Logger     logger.Logger
}

// Use appends middlewares to the chain applied to every command. Register
// them before Start: the slice is read by connection goroutines without
// locking.
func (svr *Server) Use(middlewares ...middleware.Middleware) {
	svr.middleware = append(svr.middleware, middlewares...)
}

// Handle registers cb as the command for pathname. When desc is empty the
// description is derived from the path segments joined by spaces
// (e.g. "/user/list" becomes "user list").
func (svr *Server) Handle(pathname string, desc string, cb HandleFunc) {
	svr.router.Handle(pathname, svr.wrapCommand(pathname, desc, cb))
}

// wrapCommand builds the Command stored in the router for pathname.
func (svr *Server) wrapCommand(pathname, desc string, cb HandleFunc) Command {
	h := func(ctx *Context) (err error) {
		return cb(ctx)
	}
	if desc == "" {
		desc = strings.Join(strings.Split(strings.TrimPrefix(pathname, "/"), "/"), " ")
	}
	return Command{
		Path:        pathname,
		Handle:      h,
		Description: desc,
	}
}

// createListener opens the listener on opts.network/opts.address unless one
// already exists, and stores the address reported by netutil.TrulyAddr as
// the URI host.
func (s *Server) createListener() (err error) {
	if s.listener != nil {
		return
	}
	if s.listener, err = net.Listen(s.opts.network, s.opts.address); err == nil {
		s.uri.Host = netutil.TrulyAddr(s.opts.address, s.listener)
	}
	return
}

// applyContext takes a *Context from ctxPool, allocating a fresh one when the
// pool yields nothing usable. process calls reset on it before use.
func (s *Server) applyContext() *Context {
	if v := ctxPool.Get(); v != nil {
		if ctx, ok := v.(*Context); ok {
			return ctx
		}
	}
	return &Context{}
}

// releaseContext returns ctx to ctxPool.
func (s *Server) releaseContext(ctx *Context) {
	ctxPool.Put(ctx)
}

// execute runs the command carried by frame on ctx.
//
// The frame payload is split on whitespace and resolved against the router.
// Unknown commands, lookup failures and missing positional parameters are
// reported to the client through ctx.Error, whose result is returned. On a
// successful lookup the handler runs behind the middleware chain with the
// request path and protocol attached as metadata. When frame.Timeout is
// positive the command context is derived from the server context with that
// timeout and is cancelled when execute returns.
func (s *Server) execute(ctx *Context, frame *Frame) (err error) {
	var (
		r    *Router
		args []string
	)
	cmd := string(frame.Data)
	tokens := strings.Fields(cmd)

	if frame.Timeout > 0 {
		// NOTE(review): frame.Timeout is converted straight to time.Duration,
		// i.e. interpreted as nanoseconds — confirm against the client side.
		childCtx, cancel := context.WithTimeout(s.ctx, time.Duration(frame.Timeout))
		ctx.setContext(childCtx)
		defer cancel()
	} else {
		ctx.setContext(s.ctx)
	}

	if r, args, err = s.router.Lookup(tokens); err != nil {
		if errors.Is(err, ErrNotFound) {
			return ctx.Error(errors.NotFound, fmt.Sprintf("Command %s not found", cmd))
		}
		return ctx.Error(errors.Unavailable, err.Error())
	}
	if len(r.params) > len(args) {
		return ctx.Error(errors.Unavailable, r.Usage())
	}

	// Bind the leading positional arguments to the route's named parameters.
	var params map[string]string
	if len(r.params) > 0 {
		params = make(map[string]string, len(r.params))
		for i, name := range r.params {
			params[name] = args[i]
		}
	}
	ctx.setArgs(args)
	ctx.setParam(params)

	handler := func(c context.Context) error {
		// Hand the context produced by the middleware chain (values,
		// deadlines, cancellation) to the command; without this the handler
		// would only ever see the pre-middleware context.
		ctx.ctx = c
		return r.command.Handle(ctx)
	}
	next := middleware.Chain(s.middleware...)(handler)

	md := metadata.FromContext(ctx.ctx)
	md.Set(metadata.RequestPathKey, r.command.Path)
	md.Set(metadata.RequestProtocolKey, Protocol)
	md.TeeReader(&cliMetadataReader{ctx: ctx})
	md.TeeWriter(&cliMetadataWriter{ctx: ctx})
	ctx.ctx = metadata.NewContext(ctx.ctx, md)

	return next(ctx.ctx)
}

// nextSequence returns the next connection ID. IDs start at 1 and wrap back
// to 1 after math.MaxInt64, so an ID is never zero or negative.
func (svr *Server) nextSequence() int64 {
	svr.sequenceLocker.Lock()
	defer svr.sequenceLocker.Unlock()
	if svr.sequence == math.MaxInt64 {
		svr.sequence = 0
	}
	svr.sequence++
	return svr.sequence
}

// process serves one client connection until reading a frame fails or a
// handshake/completer reply cannot be written. While the connection is live
// its Context is registered in ctxMap so Stop can close it; afterwards the
// connection is closed and the Context is returned to the pool.
//
// An error from a command does not end the session: the client may keep
// issuing commands, and a broken connection surfaces on the next readFrame.
func (svr *Server) process(conn net.Conn) {
	ctx := svr.applyContext()
	ctx.reset(svr.nextSequence(), conn)
	svr.ctxMap.Store(ctx.Id, ctx)
	defer func() {
		_ = conn.Close()
		svr.ctxMap.Delete(ctx.Id)
		svr.releaseContext(ctx)
	}()

	for {
		frame, err := readFrame(conn)
		if err != nil {
			return
		}
		// Track the sequence number of the frame currently being answered.
		ctx.seq = frame.Seq
		switch frame.Type {
		case PacketTypeHandshake:
			err = ctx.send(responsePayload{
				Type: PacketTypeHandshake,
				Data: &handshake{
					ID:         ctx.Id,
					Name:       "",
					Version:    "",
					OS:         runtime.GOOS,
					ServerTime: time.Now(),
					RemoteAddr: conn.RemoteAddr().String(),
				},
			})
		case PacketTypeCompleter:
			err = ctx.send(responsePayload{
				Type: PacketTypeCompleter,
				Data: svr.router.Completer(strings.Fields(string(frame.Data))...),
			})
		case PacketTypeCommand:
			// Deliberately non-fatal; see the function comment.
			_ = svr.execute(ctx, frame)
		default:
			// Unknown frame types are ignored.
		}
		// A reply that could not be written means the connection is unusable.
		// (A bare `break` inside the switch would not leave this loop.)
		if err != nil {
			return
		}
	}
}

// serve accepts connections and handles each on its own goroutine until
// Accept fails. A failure after Stop has set exitFlag is treated as a clean
// shutdown and yields nil.
func (s *Server) serve() error {
	for {
		conn, err := s.listener.Accept()
		if err != nil {
			if atomic.LoadInt32(&s.exitFlag) == 1 {
				return nil
			}
			return err
		}
		go s.process(conn)
	}
}

// Start binds the listener (unless one already exists), registers the
// built-in "/help" command and serves connections until Stop is called or
// Accept fails. ctx becomes the parent of every command context. Logger is
// taken from the options when set; otherwise an already-assigned Logger is
// kept, falling back to logger.Default().
func (s *Server) Start(ctx context.Context) (err error) {
	s.ctx = ctx
	if s.opts.logger != nil {
		s.Logger = s.opts.logger
	}
	if s.Logger == nil {
		s.Logger = logger.Default()
	}
	if err = s.createListener(); err != nil {
		return
	}
	s.Logger.Infof(ctx, "cli server listen on: %s", s.uri.Host)
	s.Handle("/help", "Display help information", func(c *Context) error {
		return c.Success(s.router.String())
	})
	return s.serve()
}

// Stop closes the listener and every live connection context. Only the
// first call has any effect. Close failures are logged; the first one
// encountered (listener first, then connections) is returned.
func (s *Server) Stop(ctx context.Context) (err error) {
	if !atomic.CompareAndSwapInt32(&s.exitFlag, 0, 1) {
		return nil
	}
	log := s.Logger
	if log == nil {
		// Logger is only assigned in Start; Stop may run without it.
		log = logger.Default()
	}
	if s.listener != nil {
		if err = s.listener.Close(); err != nil {
			log.Warnf(ctx, "cli listener close error: %v", err)
		}
	}
	s.ctxMap.Range(func(_, value any) bool {
		c, ok := value.(*Context)
		if !ok {
			return true
		}
		if closeErr := c.Close(); closeErr != nil {
			log.Warnf(ctx, "cli connection close error: %v", closeErr)
			// Keep the first failure instead of letting later iterations
			// overwrite (or clear) it.
			if err == nil {
				err = closeErr
			}
		}
		return true
	})
	log.Info(ctx, "cli server stopped")
	return err
}

// New creates a Server using the "cli" URI scheme and a TCP listener
// address built from the CLI_PORT environment variable. An unset or
// non-integer CLI_PORT yields ":0", which lets the OS pick a free port.
// Options are applied afterwards and may override these defaults.
func New(cbs ...Option) *Server {
	port, _ := strconv.Atoi(os.Getenv("CLI_PORT")) // empty/invalid -> 0
	srv := &Server{
		opts: &options{
			network: "tcp",
			address: ":" + strconv.Itoa(port),
		},
		uri:    &url.URL{Scheme: "cli"},
		router: newRouter(""),
	}
	for _, cb := range cbs {
		cb(srv.opts)
	}
	return srv
}