251 lines
5.3 KiB
Go
251 lines
5.3 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
"net/url"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.nobla.cn/golang/aeus/metadata"
|
|
"git.nobla.cn/golang/aeus/middleware"
|
|
"git.nobla.cn/golang/aeus/pkg/errors"
|
|
"git.nobla.cn/golang/aeus/pkg/logger"
|
|
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
|
)
|
|
|
|
type Server struct {
|
|
ctx context.Context
|
|
router *Router
|
|
opts *options
|
|
listener net.Listener
|
|
sequenceLocker sync.Mutex
|
|
sequence int64
|
|
ctxMap sync.Map
|
|
uri *url.URL
|
|
exitFlag int32
|
|
middleware []middleware.Middleware
|
|
}
|
|
|
|
func (svr *Server) Use(middlewares ...middleware.Middleware) {
|
|
svr.middleware = append(svr.middleware, middlewares...)
|
|
}
|
|
|
|
func (svr *Server) Handle(pathname string, desc string, cb HandleFunc) {
|
|
svr.router.Handle(pathname, svr.wrapCommand(pathname, desc, cb))
|
|
}
|
|
|
|
func (svr *Server) wrapCommand(pathname, desc string, cb HandleFunc) Command {
|
|
h := func(ctx *Context) (err error) {
|
|
return cb(ctx)
|
|
}
|
|
if desc == "" {
|
|
desc = strings.Join(strings.Split(strings.TrimPrefix(pathname, "/"), "/"), " ")
|
|
}
|
|
return Command{
|
|
Path: pathname,
|
|
Handle: h,
|
|
Description: desc,
|
|
}
|
|
}
|
|
|
|
func (s *Server) createListener() (err error) {
|
|
if s.listener != nil {
|
|
return
|
|
}
|
|
if s.listener, err = net.Listen(s.opts.network, s.opts.address); err == nil {
|
|
s.uri.Host = netutil.TrulyAddr(s.opts.address, s.listener)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (s *Server) applyContext() *Context {
|
|
if v := ctxPool.Get(); v != nil {
|
|
if ctx, ok := v.(*Context); ok {
|
|
return ctx
|
|
}
|
|
}
|
|
return &Context{}
|
|
}
|
|
|
|
func (s *Server) releaseContext(ctx *Context) {
|
|
ctxPool.Put(ctx)
|
|
}
|
|
|
|
func (s *Server) execute(ctx *Context, frame *Frame) (err error) {
|
|
var (
|
|
params map[string]string
|
|
tokens []string
|
|
args []string
|
|
r *Router
|
|
)
|
|
cmd := string(frame.Data)
|
|
tokens = strings.Fields(cmd)
|
|
if frame.Timeout > 0 {
|
|
childCtx, cancelFunc := context.WithTimeout(s.ctx, time.Duration(frame.Timeout))
|
|
ctx.setContext(childCtx)
|
|
defer func() {
|
|
cancelFunc()
|
|
}()
|
|
} else {
|
|
ctx.setContext(s.ctx)
|
|
}
|
|
if r, args, err = s.router.Lookup(tokens); err != nil {
|
|
if errors.Is(err, ErrNotFound) {
|
|
err = ctx.Error(errors.NotFound, fmt.Sprintf("Command %s not found", cmd))
|
|
} else {
|
|
err = ctx.Error(errors.Unavailable, err.Error())
|
|
}
|
|
} else {
|
|
if len(r.params) > len(args) {
|
|
err = ctx.Error(errors.Unavailable, r.Usage())
|
|
return
|
|
}
|
|
if len(r.params) > 0 {
|
|
params = make(map[string]string)
|
|
for i, s := range r.params {
|
|
params[s] = args[i]
|
|
}
|
|
}
|
|
ctx.setArgs(args)
|
|
ctx.setParam(params)
|
|
h := func(c context.Context) error {
|
|
return r.command.Handle(ctx)
|
|
}
|
|
next := middleware.Chain(s.middleware...)(h)
|
|
md := metadata.FromContext(ctx.ctx)
|
|
md.Set(metadata.RequestPathKey, r.command.Path)
|
|
md.Set(metadata.RequestProtocolKey, Protocol)
|
|
md.TeeReader(&cliMetadataReader{ctx: ctx})
|
|
md.TeeWriter(&cliMetadataWriter{ctx: ctx})
|
|
ctx.ctx = metadata.NewContext(ctx.ctx, md)
|
|
err = next(ctx.ctx)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (svr *Server) nextSequence() int64 {
|
|
svr.sequenceLocker.Lock()
|
|
defer svr.sequenceLocker.Unlock()
|
|
if svr.sequence >= math.MaxInt64 {
|
|
svr.sequence = 1
|
|
}
|
|
svr.sequence++
|
|
return svr.sequence
|
|
}
|
|
|
|
func (svr *Server) process(conn net.Conn) {
|
|
var (
|
|
err error
|
|
ctx *Context
|
|
frame *Frame
|
|
)
|
|
ctx = svr.applyContext()
|
|
ctx.reset(svr.nextSequence(), conn)
|
|
svr.ctxMap.Store(ctx.Id, ctx)
|
|
defer func() {
|
|
_ = conn.Close()
|
|
svr.ctxMap.Delete(ctx.Id)
|
|
svr.releaseContext(ctx)
|
|
}()
|
|
for {
|
|
if frame, err = readFrame(conn); err != nil {
|
|
break
|
|
}
|
|
//reset frame
|
|
ctx.seq = frame.Seq
|
|
switch frame.Type {
|
|
case PacketTypeHandshake:
|
|
if err = ctx.send(responsePayload{
|
|
Type: PacketTypeHandshake,
|
|
Data: &handshake{
|
|
ID: ctx.Id,
|
|
Name: "",
|
|
Version: "",
|
|
OS: runtime.GOOS,
|
|
ServerTime: time.Now(),
|
|
RemoteAddr: conn.RemoteAddr().String(),
|
|
},
|
|
}); err != nil {
|
|
break
|
|
}
|
|
case PacketTypeCompleter:
|
|
if err = ctx.send(responsePayload{
|
|
Type: PacketTypeCompleter,
|
|
Data: svr.router.Completer(strings.Fields(string(frame.Data))...),
|
|
}); err != nil {
|
|
break
|
|
}
|
|
case PacketTypeCommand:
|
|
if err = svr.execute(ctx, frame); err != nil {
|
|
break
|
|
}
|
|
default:
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) serve() (err error) {
|
|
for {
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
if atomic.LoadInt32(&s.exitFlag) == 1 {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
go s.process(conn)
|
|
}
|
|
}
|
|
|
|
func (s *Server) Start(ctx context.Context) (err error) {
|
|
s.ctx = ctx
|
|
if err = s.createListener(); err != nil {
|
|
return
|
|
}
|
|
s.opts.logger.Info(ctx, "cli server listen on: %s", s.uri.Host)
|
|
s.Handle("/help", "Display help information", func(ctx *Context) (err error) {
|
|
return ctx.Success(s.router.String())
|
|
})
|
|
err = s.serve()
|
|
return
|
|
}
|
|
|
|
func (s *Server) Stop(ctx context.Context) (err error) {
|
|
if !atomic.CompareAndSwapInt32(&s.exitFlag, 0, 1) {
|
|
return
|
|
}
|
|
if s.listener != nil {
|
|
err = s.listener.Close()
|
|
}
|
|
s.ctxMap.Range(func(key, value any) bool {
|
|
if ctx, ok := value.(*Context); ok {
|
|
err = ctx.Close()
|
|
}
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
func New(cbs ...Option) *Server {
|
|
srv := &Server{
|
|
opts: &options{
|
|
network: "tcp",
|
|
address: ":0",
|
|
logger: logger.Default(),
|
|
},
|
|
uri: &url.URL{Scheme: "cli"},
|
|
router: newRouter(""),
|
|
}
|
|
for _, cb := range cbs {
|
|
cb(srv.opts)
|
|
}
|
|
return srv
|
|
}
|