233 lines
4.8 KiB
Go
233 lines
4.8 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
"path"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.nspix.com/golang/kos/util/env"
|
|
"github.com/sourcegraph/conc"
|
|
)
|
|
|
|
// ctxPool recycles *Context values between connections (see applyContext and
// releaseContext). It has no New func, so Get returns nil when it is empty.
var ctxPool sync.Pool
|
|
|
|
type Server struct {
|
|
ctx context.Context
|
|
sequenceLocker sync.Mutex
|
|
sequence int64
|
|
ctxMap sync.Map
|
|
waitGroup conc.WaitGroup
|
|
middleware []Middleware
|
|
router *Router
|
|
l net.Listener
|
|
exitFlag int32
|
|
}
|
|
|
|
func (svr *Server) applyContext() *Context {
|
|
if v := ctxPool.Get(); v != nil {
|
|
if ctx, ok := v.(*Context); ok {
|
|
return ctx
|
|
}
|
|
}
|
|
return &Context{}
|
|
}
|
|
|
|
func (svr *Server) releaseContext(ctx *Context) {
|
|
ctxPool.Put(ctx)
|
|
}
|
|
|
|
func (svr *Server) execute(ctx *Context, frame *Frame) (err error) {
|
|
var (
|
|
params map[string]string
|
|
tokens []string
|
|
args []string
|
|
r *Router
|
|
)
|
|
cmd := string(frame.Data)
|
|
tokens = strings.Fields(cmd)
|
|
if frame.Timeout > 0 {
|
|
childCtx, cancelFunc := context.WithTimeout(svr.ctx, time.Duration(frame.Timeout))
|
|
ctx.setContext(childCtx)
|
|
defer func() {
|
|
cancelFunc()
|
|
}()
|
|
} else {
|
|
ctx.setContext(svr.ctx)
|
|
}
|
|
if r, args, err = svr.router.Lookup(tokens); err != nil {
|
|
if errors.Is(err, ErrNotFound) {
|
|
err = ctx.Error(errNotFound, fmt.Sprintf("Command %s not found", cmd))
|
|
} else {
|
|
err = ctx.Error(errExecuteFailed, err.Error())
|
|
}
|
|
} else {
|
|
if len(r.params) > len(args) {
|
|
err = ctx.Error(errExecuteFailed, r.Usage())
|
|
return
|
|
}
|
|
if len(r.params) > 0 {
|
|
params = make(map[string]string)
|
|
for i, s := range r.params {
|
|
params[s] = args[i]
|
|
}
|
|
}
|
|
ctx.setArgs(args)
|
|
ctx.setParam(params)
|
|
err = r.command.Handle(ctx)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (svr *Server) nextSequence() int64 {
|
|
svr.sequenceLocker.Lock()
|
|
defer svr.sequenceLocker.Unlock()
|
|
if svr.sequence >= math.MaxInt64 {
|
|
svr.sequence = 1
|
|
}
|
|
svr.sequence++
|
|
return svr.sequence
|
|
}
|
|
|
|
func (svr *Server) process(conn net.Conn) {
|
|
var (
|
|
err error
|
|
ctx *Context
|
|
frame *Frame
|
|
)
|
|
ctx = svr.applyContext()
|
|
ctx.reset(svr.nextSequence(), conn)
|
|
svr.ctxMap.Store(ctx.Id, ctx)
|
|
defer func() {
|
|
_ = conn.Close()
|
|
svr.ctxMap.Delete(ctx.Id)
|
|
svr.releaseContext(ctx)
|
|
}()
|
|
for {
|
|
if frame, err = readFrame(conn); err != nil {
|
|
break
|
|
}
|
|
//reset frame
|
|
ctx.seq = frame.Seq
|
|
switch frame.Type {
|
|
case PacketTypeHandshake:
|
|
if err = ctx.send(responsePayload{
|
|
Type: PacketTypeHandshake,
|
|
Data: &Info{
|
|
ID: ctx.Id,
|
|
Name: env.Get("VOX_NAME", ""),
|
|
Version: env.Get("VOX_VERSION", ""),
|
|
OS: runtime.GOOS,
|
|
ServerTime: time.Now(),
|
|
RemoteAddr: conn.RemoteAddr().String(),
|
|
},
|
|
}); err != nil {
|
|
break
|
|
}
|
|
case PacketTypeCompleter:
|
|
if err = ctx.send(responsePayload{
|
|
Type: PacketTypeCompleter,
|
|
Data: svr.router.Completer(strings.Fields(string(frame.Data))...),
|
|
}); err != nil {
|
|
break
|
|
}
|
|
case PacketTypeCommand:
|
|
if err = svr.execute(ctx, frame); err != nil {
|
|
break
|
|
}
|
|
default:
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func (svr *Server) serve() {
|
|
for {
|
|
conn, err := svr.l.Accept()
|
|
if err != nil {
|
|
break
|
|
}
|
|
svr.waitGroup.Go(func() {
|
|
svr.process(conn)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (svr *Server) wrapHandle(pathname, desc string, cb HandleFunc, middleware ...Middleware) Command {
|
|
h := func(ctx *Context) (err error) {
|
|
for i := len(svr.middleware) - 1; i >= 0; i-- {
|
|
cb = svr.middleware[i](cb)
|
|
}
|
|
for i := len(middleware) - 1; i >= 0; i-- {
|
|
cb = middleware[i](cb)
|
|
}
|
|
return cb(ctx)
|
|
}
|
|
if desc == "" {
|
|
desc = strings.Join(strings.Split(strings.TrimPrefix(pathname, "/"), "/"), " ")
|
|
}
|
|
return Command{
|
|
Path: pathname,
|
|
Handle: h,
|
|
Description: desc,
|
|
}
|
|
}
|
|
|
|
func (svr *Server) Use(middleware ...Middleware) {
|
|
svr.middleware = append(svr.middleware, middleware...)
|
|
}
|
|
|
|
func (svr *Server) Group(prefix string, commands []Command, middleware ...Middleware) {
|
|
for _, cmd := range commands {
|
|
svr.Handle(path.Join(prefix, cmd.Path), cmd.Description, cmd.Handle, middleware...)
|
|
}
|
|
}
|
|
|
|
func (svr *Server) Handle(pathname string, desc string, cb HandleFunc, middleware ...Middleware) {
|
|
svr.router.Handle(pathname, svr.wrapHandle(pathname, desc, cb, middleware...))
|
|
}
|
|
|
|
func (svr *Server) Serve(l net.Listener) (err error) {
|
|
svr.l = l
|
|
svr.Handle("/help", "Display help information", func(ctx *Context) (err error) {
|
|
return ctx.Success(svr.router.String())
|
|
})
|
|
svr.serve()
|
|
atomic.StoreInt32(&svr.exitFlag, 0)
|
|
return
|
|
}
|
|
|
|
func (svr *Server) Shutdown() (err error) {
|
|
if !atomic.CompareAndSwapInt32(&svr.exitFlag, 0, 1) {
|
|
return
|
|
}
|
|
if svr.l != nil {
|
|
err = svr.l.Close()
|
|
}
|
|
svr.ctxMap.Range(func(key, value any) bool {
|
|
if ctx, ok := value.(*Context); ok {
|
|
err = ctx.Close()
|
|
}
|
|
return true
|
|
})
|
|
svr.waitGroup.Wait()
|
|
return
|
|
}
|
|
|
|
func New(ctx context.Context) *Server {
|
|
return &Server{
|
|
ctx: ctx,
|
|
router: newRouter(""),
|
|
middleware: make([]Middleware, 0, 10),
|
|
}
|
|
}
|