aeus/transport/cli/server.go

264 lines
5.6 KiB
Go

package cli
import (
"context"
"fmt"
"math"
"net"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/errors"
"git.nobla.cn/golang/aeus/pkg/logger"
netutil "git.nobla.cn/golang/aeus/pkg/net"
)
// Server is a command server: it accepts connections on a listener, reads
// frames from each connection and dispatches command frames to handlers
// registered on its router.
type Server struct {
	// ctx is the context passed to Start; execute uses it as the parent of
	// every per-command context.
	ctx context.Context
	// router resolves command tokens to registered commands.
	router *Router
	// opts holds the network, address and optional logger configured via New.
	opts *options
	// listener is created by createListener and closed by Stop.
	listener net.Listener
	// sequenceLocker guards sequence.
	sequenceLocker sync.Mutex
	// sequence is the last id handed out by nextSequence.
	sequence int64
	// ctxMap tracks the Context of every connection currently being
	// processed, keyed by Context.Id.
	ctxMap sync.Map
	// uri has scheme "cli"; its Host is filled in once the listener exists.
	uri *url.URL
	// exitFlag is set to 1 by Stop and read atomically by serve.
	exitFlag int32
	// middleware is the chain wrapped around every command handler.
	middleware []middleware.Middleware
	// Logger is taken from the options in Start, falling back to
	// logger.Default() when none is set.
	Logger logger.Logger
}
// Use appends the given middlewares to the server's middleware chain. The
// chain is applied, in registration order of the slice, to every command
// executed afterwards.
func (s *Server) Use(middlewares ...middleware.Middleware) {
	s.middleware = append(s.middleware, middlewares...)
}
// Handle registers cb on the router under pathname. desc is the command's
// description; when it is empty a description is derived from pathname by
// wrapCommand.
func (s *Server) Handle(pathname string, desc string, cb HandleFunc) {
	cmd := s.wrapCommand(pathname, desc, cb)
	s.router.Handle(pathname, cmd)
}
// wrapCommand builds the Command stored in the router for pathname.
//
// When desc is empty the description defaults to pathname with its leading
// "/" removed and every remaining "/" replaced by a space, e.g.
// "/user/list" becomes "user list".
func (s *Server) wrapCommand(pathname, desc string, cb HandleFunc) Command {
	if desc == "" {
		desc = strings.ReplaceAll(strings.TrimPrefix(pathname, "/"), "/", " ")
	}
	return Command{
		Path:        pathname,
		Description: desc,
		// The callback is invoked through a closure rather than stored directly.
		Handle: func(c *Context) error {
			return cb(c)
		},
	}
}
// createListener opens the listener described by the server options unless
// one already exists. On success the resolved address reported by
// netutil.TrulyAddr is stored as the host of the server URI.
func (s *Server) createListener() (err error) {
	if s.listener != nil {
		return nil
	}
	listener, listenErr := net.Listen(s.opts.network, s.opts.address)
	if listenErr != nil {
		return listenErr
	}
	s.listener = listener
	s.uri.Host = netutil.TrulyAddr(s.opts.address, listener)
	return nil
}
// applyContext returns a *Context taken from ctxPool, or a newly allocated
// one when the pool yields nothing usable.
func (s *Server) applyContext() *Context {
	// The comma-ok assertion is false for a nil pool value as well as for a
	// value of any other type, so both cases fall through to allocation.
	if pooled, ok := ctxPool.Get().(*Context); ok {
		return pooled
	}
	return &Context{}
}
// releaseContext hands ctx back to ctxPool so that a later applyContext call
// can reuse it. The context is not cleared here; process resets a context
// right after obtaining it.
func (s *Server) releaseContext(ctx *Context) {
	ctxPool.Put(ctx)
}
// execute runs the command carried by frame on behalf of ctx.
//
// The frame payload is split on whitespace and resolved through the router.
// A lookup failure is reported to the client via ctx.Error (NotFound for
// ErrNotFound, Unavailable otherwise), as is a command invoked with fewer
// arguments than it declares parameters. Otherwise positional arguments are
// bound to the declared parameter names, request metadata is attached to the
// context, and the handler is invoked through the middleware chain.
//
// When frame.Timeout is positive the command runs under a context derived
// from the server context with that timeout; otherwise it runs under the
// server context itself.
// NOTE(review): frame.Timeout is converted with time.Duration(...) as-is, so
// unless its declared type already is a duration it is read as nanoseconds —
// confirm against the Frame definition.
func (s *Server) execute(ctx *Context, frame *Frame) (err error) {
	line := string(frame.Data)
	words := strings.Fields(line)

	if frame.Timeout > 0 {
		timeoutCtx, cancel := context.WithTimeout(s.ctx, time.Duration(frame.Timeout))
		ctx.setContext(timeoutCtx)
		defer cancel()
	} else {
		ctx.setContext(s.ctx)
	}

	route, args, lookupErr := s.router.Lookup(words)
	if lookupErr != nil {
		if errors.Is(lookupErr, ErrNotFound) {
			return ctx.Error(errors.NotFound, fmt.Sprintf("Command %s not found", line))
		}
		return ctx.Error(errors.Unavailable, lookupErr.Error())
	}

	// Every declared parameter needs a positional argument.
	if len(args) < len(route.params) {
		return ctx.Error(errors.Unavailable, route.Usage())
	}

	// params stays nil when the command declares no parameters.
	var params map[string]string
	if len(route.params) > 0 {
		params = make(map[string]string)
		for idx, name := range route.params {
			params[name] = args[idx]
		}
	}
	ctx.setArgs(args)
	ctx.setParam(params)

	// The innermost handler ignores the context handed down by the chain and
	// invokes the command with the connection's Context.
	handler := func(context.Context) error {
		return route.command.Handle(ctx)
	}
	chain := middleware.Chain(s.middleware...)(handler)

	md := metadata.FromContext(ctx.ctx)
	md.Set(metadata.RequestPathKey, route.command.Path)
	md.Set(metadata.RequestProtocolKey, Protocol)
	md.TeeReader(&cliMetadataReader{ctx: ctx})
	md.TeeWriter(&cliMetadataWriter{ctx: ctx})
	ctx.ctx = metadata.NewContext(ctx.ctx, md)

	return chain(ctx.ctx)
}
// nextSequence returns the next connection id. Ids start at 1, increase by
// one per call and wrap back to 1 after math.MaxInt64 has been issued, so
// the counter never overflows. It is safe for concurrent use.
func (s *Server) nextSequence() int64 {
	s.sequenceLocker.Lock()
	defer s.sequenceLocker.Unlock()
	if s.sequence == math.MaxInt64 {
		// Reset to 0, not 1: the increment below would otherwise skip id 1
		// and restart the cycle at 2.
		s.sequence = 0
	}
	s.sequence++
	return s.sequence
}
// process serves one client connection. It loops reading frames from conn
// and answers them by type until a frame can no longer be read or a
// handshake/completer response can no longer be sent.
//
// A Context is obtained via applyContext, reset with a fresh sequence id and
// the connection, and registered in ctxMap under its Id for the lifetime of
// the call. On return the connection is closed, the ctxMap entry is removed
// and the Context is released back through releaseContext.
func (s *Server) process(conn net.Conn) {
	var (
		err   error
		frame *Frame
	)
	ctx := s.applyContext()
	ctx.reset(s.nextSequence(), conn)
	s.ctxMap.Store(ctx.Id, ctx)
	defer func() {
		_ = conn.Close() // best effort: the session is over either way
		s.ctxMap.Delete(ctx.Id)
		s.releaseContext(ctx)
	}()
	for {
		if frame, err = readFrame(conn); err != nil {
			return
		}
		// Record the sequence number of the frame being handled.
		ctx.seq = frame.Seq
		switch frame.Type {
		case PacketTypeHandshake:
			err = ctx.send(responsePayload{
				Type: PacketTypeHandshake,
				Data: &handshake{
					ID:         ctx.Id,
					Name:       "",
					Version:    "",
					OS:         runtime.GOOS,
					ServerTime: time.Now(),
					RemoteAddr: conn.RemoteAddr().String(),
				},
			})
			if err != nil {
				// A plain break here would only leave the switch and keep
				// the read loop running; return ends the session.
				return
			}
		case PacketTypeCompleter:
			err = ctx.send(responsePayload{
				Type: PacketTypeCompleter,
				Data: s.router.Completer(strings.Fields(string(frame.Data))...),
			})
			if err != nil {
				return
			}
		case PacketTypeCommand:
			// The session stays open after a failed command, exactly as
			// before (the previous break only left the switch).
			// NOTE(review): the error returned by execute may describe a
			// command-level failure rather than a broken connection —
			// confirm whether it should be logged or should end the session.
			_ = s.execute(ctx, frame)
		default:
			// Frames of an unknown type are ignored.
		}
	}
}
// serve accepts connections until the listener fails, handing each accepted
// connection to process on its own goroutine. An Accept error is reported as
// nil when exitFlag is set (Stop closed the listener) and returned unchanged
// otherwise.
func (s *Server) serve() (err error) {
	for {
		conn, acceptErr := s.listener.Accept()
		if acceptErr == nil {
			go s.process(conn)
			continue
		}
		if atomic.LoadInt32(&s.exitFlag) == 1 {
			return nil
		}
		return acceptErr
	}
}
// Start runs the server and blocks until the accept loop ends.
//
// It stores ctx as the server context, selects the logger (the one from the
// options if set, otherwise an already assigned Logger, otherwise
// logger.Default()), opens the listener, registers the built-in "/help"
// command and then serves connections. It returns the listener error if the
// listener cannot be created, otherwise whatever serve returns.
func (s *Server) Start(ctx context.Context) (err error) {
	s.ctx = ctx

	if configured := s.opts.logger; configured != nil {
		s.Logger = configured
	}
	if s.Logger == nil {
		s.Logger = logger.Default()
	}

	if err = s.createListener(); err != nil {
		return err
	}
	s.Logger.Infof(ctx, "cli server listen on: %s", s.uri.Host)

	// Built-in command that replies with the router's textual description.
	s.Handle("/help", "Display help information", func(c *Context) error {
		return c.Success(s.router.String())
	})

	return s.serve()
}
// Stop shuts the server down: it marks the server as exiting, closes the
// listener and closes the Context of every connection still registered in
// ctxMap.
//
// Only the first call does any work; later calls return nil immediately.
// The returned error is the first failure encountered (listener close first,
// then connection contexts in iteration order); later failures do not
// overwrite it. Stop may be called before Start, in which case no logger is
// available and nothing is logged.
func (s *Server) Stop(ctx context.Context) (err error) {
	if !atomic.CompareAndSwapInt32(&s.exitFlag, 0, 1) {
		return nil
	}
	if s.listener != nil {
		if err = s.listener.Close(); err != nil && s.Logger != nil {
			s.Logger.Warnf(ctx, "cli listener close error: %v", err)
		}
	}
	s.ctxMap.Range(func(_, value any) bool {
		if c, ok := value.(*Context); ok {
			// Keep the first error instead of letting each Close result
			// (including a nil one) replace what was recorded before.
			if closeErr := c.Close(); closeErr != nil && err == nil {
				err = closeErr
			}
		}
		return true
	})
	if s.Logger != nil {
		s.Logger.Info(ctx, "cli server stopped")
	}
	return err
}
// New creates a Server listening on TCP. The port is read from the CLI_PORT
// environment variable and formatted as ":<port>"; the given options are
// applied afterwards and may override both network and address.
func New(cbs ...Option) *Server {
	opts := &options{
		network: "tcp",
		address: ":0",
	}

	// The parse error is deliberately ignored: an unset or non-numeric
	// CLI_PORT yields 0, producing the address ":0".
	port, _ := strconv.Atoi(os.Getenv("CLI_PORT"))
	opts.address = fmt.Sprintf(":%d", port)

	for _, apply := range cbs {
		apply(opts)
	}

	return &Server{
		opts:   opts,
		uri:    &url.URL{Scheme: "cli"},
		router: newRouter(""),
	}
}