kos/entry/cli/server.go

225 lines
4.6 KiB
Go

package cli
import (
"context"
"errors"
"fmt"
"math"
"net"
"path"
"runtime"
"strings"
"sync"
"time"
"git.nspix.com/golang/kos/util/env"
"github.com/sourcegraph/conc"
)
// ctxPool recycles *Context values between connections; see applyContext and
// releaseContext.
var ctxPool sync.Pool
// Server accepts connections from a listener, reads frames from each
// connection and dispatches command frames to the handlers registered on its
// router.
type Server struct {
	// ctx is the base context; execute derives per-command contexts from it.
	ctx context.Context
	// sequenceLocker guards sequence.
	sequenceLocker sync.Mutex
	// sequence is the most recently issued connection ID.
	sequence int64
	// ctxMap holds the *Context of every live connection, keyed by its Id.
	ctxMap sync.Map
	// waitGroup tracks the per-connection goroutines started by serve.
	waitGroup conc.WaitGroup
	// middleware is applied to every command, in addition to any
	// middleware passed when the command is registered.
	middleware []Middleware
	// router resolves command tokens to registered commands.
	router *Router
	// l is the listener installed by Serve.
	l net.Listener
}
// applyContext returns a *Context taken from ctxPool, or a freshly allocated
// one when the pool has nothing usable. The caller is expected to reset it
// before use.
func (svr *Server) applyContext() *Context {
	// A comma-ok assertion on a nil interface simply reports !ok, so an
	// empty pool falls through to the allocation below.
	if pooled, ok := ctxPool.Get().(*Context); ok {
		return pooled
	}
	return new(Context)
}
// releaseContext hands ctx back to ctxPool so a later connection can reuse it.
// The context is not cleared here; process resets it after applyContext.
func (svr *Server) releaseContext(ctx *Context) {
	ctxPool.Put(ctx)
}
// execute runs the command carried by frame on behalf of ctx.
//
// The frame payload is split on whitespace and resolved through the router.
// Lookup failures and missing positional parameters are reported to the
// client through ctx.Error; otherwise the positional parameters are bound by
// name, stored on ctx together with the raw args, and the command handler is
// invoked. The returned error is whatever ctx.Error or the handler returned.
func (svr *Server) execute(ctx *Context, frame *Frame) (err error) {
	var (
		route *Router
		args  []string
	)
	cmd := string(frame.Data)

	// Bind the execution context: bounded by the frame timeout when one is
	// given, otherwise the server's base context.
	if frame.Timeout > 0 {
		// NOTE(review): frame.Timeout is converted directly to a
		// time.Duration (nanoseconds) — confirm the unit the client sends.
		timeoutCtx, cancel := context.WithTimeout(svr.ctx, time.Duration(frame.Timeout))
		defer cancel()
		ctx.setContext(timeoutCtx)
	} else {
		ctx.setContext(svr.ctx)
	}

	route, args, err = svr.router.Lookup(strings.Fields(cmd))
	if err != nil {
		if errors.Is(err, ErrNotFound) {
			return ctx.Error(errNotFound, fmt.Sprintf("Command %s not found", cmd))
		}
		return ctx.Error(errExecuteFailed, err.Error())
	}

	// Every declared parameter needs a matching argument; otherwise reply
	// with the usage text.
	if len(args) < len(route.params) {
		return ctx.Error(errExecuteFailed, route.Usage())
	}

	// params stays nil when the route declares no parameters.
	var params map[string]string
	if len(route.params) > 0 {
		params = make(map[string]string)
		for idx, name := range route.params {
			params[name] = args[idx]
		}
	}
	ctx.setArgs(args)
	ctx.setParam(params)
	return route.command.Handle(ctx)
}
// nextSequence returns the next connection ID under sequenceLocker.
//
// IDs start at 1. Once the counter has reached math.MaxInt64 it is reset to 1
// before being incremented, so the first ID issued after the wrap is 2.
func (svr *Server) nextSequence() int64 {
	svr.sequenceLocker.Lock()
	defer svr.sequenceLocker.Unlock()

	next := svr.sequence
	if next >= math.MaxInt64 {
		next = 1
	}
	next++
	svr.sequence = next
	return next
}
// process serves a single client connection until a frame can no longer be
// read or a handshake/completer reply can no longer be sent.
//
// A Context is taken from the pool, reset with a fresh sequence ID and the
// connection, and registered in ctxMap. When the session ends the connection
// is closed, the Context is unregistered and handed back to the pool.
func (svr *Server) process(conn net.Conn) {
	var (
		err   error
		ctx   *Context
		frame *Frame
	)
	ctx = svr.applyContext()
	ctx.reset(svr.nextSequence(), conn)
	svr.ctxMap.Store(ctx.Id, ctx)
	defer func() {
		// Best-effort close: the session is over either way.
		_ = conn.Close()
		svr.ctxMap.Delete(ctx.Id)
		svr.releaseContext(ctx)
	}()

	// The loop is labeled because a bare break inside the switch would only
	// leave the switch, not the session loop.
session:
	for {
		if frame, err = readFrame(conn); err != nil {
			break
		}
		// Replies are correlated with the request through its sequence number.
		ctx.seq = frame.Seq
		switch frame.Type {
		case PacketTypeHandshake:
			err = ctx.send(responsePayload{
				Type: PacketTypeHandshake,
				Data: &Info{
					ID:         ctx.Id,
					Name:       env.Get("VOX_NAME", ""),
					Version:    env.Get("VOX_VERSION", ""),
					OS:         runtime.GOOS,
					ServerTime: time.Now(),
					RemoteAddr: conn.RemoteAddr().String(),
				},
			})
			if err != nil {
				// The reply could not be sent; end the session.
				break session
			}
		case PacketTypeCompleter:
			err = ctx.send(responsePayload{
				Type: PacketTypeCompleter,
				Data: svr.router.Completer(strings.Fields(string(frame.Data))...),
			})
			if err != nil {
				break session
			}
		case PacketTypeCommand:
			// NOTE(review): the error is dropped, which matches the previous
			// effective behaviour (the old break only left the switch).
			// execute may return handler errors as well as send errors, so a
			// failed command is not treated as fatal for the session —
			// confirm this is the intended policy.
			_ = svr.execute(ctx, frame)
		default:
			// Unknown frame types are ignored.
		}
	}
}
// serve accepts connections from the listener and runs process for each one
// in its own goroutine tracked by waitGroup. It returns on the first Accept
// error, which includes the listener being closed by Shutdown.
func (svr *Server) serve() {
	for {
		conn, acceptErr := svr.l.Accept()
		if acceptErr != nil {
			return
		}
		svr.waitGroup.Go(func() { svr.process(conn) })
	}
}
// wrapHandle builds the Command registered for pathname.
//
// The returned handler wraps cb with the server-wide middleware and then with
// the route-specific middleware, so on each call the route-specific
// middleware runs outermost, followed by the server-wide middleware, followed
// by cb. The chain is resolved on every invocation, which means middleware
// added through Use after the command was registered still applies.
//
// When desc is empty the description defaults to pathname with the leading
// slash removed and the remaining slashes replaced by spaces.
func (svr *Server) wrapHandle(pathname, desc string, cb HandleFunc, middleware ...Middleware) Command {
	h := func(ctx *Context) (err error) {
		// Build the chain in a local: assigning back to the captured cb
		// would stack another middleware layer on every invocation and
		// would be written concurrently by connection goroutines.
		handler := cb
		for i := len(svr.middleware) - 1; i >= 0; i-- {
			handler = svr.middleware[i](handler)
		}
		for i := len(middleware) - 1; i >= 0; i-- {
			handler = middleware[i](handler)
		}
		return handler(ctx)
	}
	if desc == "" {
		desc = strings.ReplaceAll(strings.TrimPrefix(pathname, "/"), "/", " ")
	}
	return Command{
		Path:        pathname,
		Handle:      h,
		Description: desc,
	}
}
// Use appends middleware to the server-wide middleware list, preserving the
// order in which it is given.
func (svr *Server) Use(middleware ...Middleware) {
	if len(middleware) == 0 {
		return
	}
	svr.middleware = append(svr.middleware, middleware...)
}
// Group registers every command in commands under prefix: each command's
// path is joined onto prefix and registered through Handle with the given
// middleware.
func (svr *Server) Group(prefix string, commands []Command, middleware ...Middleware) {
	for i := range commands {
		command := commands[i]
		fullPath := path.Join(prefix, command.Path)
		svr.Handle(fullPath, command.Description, command.Handle, middleware...)
	}
}
// Handle registers cb as the handler for pathname. The callback is wrapped by
// wrapHandle, which applies the server-wide and the given middleware and
// derives a description when desc is empty.
func (svr *Server) Handle(pathname string, desc string, cb HandleFunc, middleware ...Middleware) {
	command := svr.wrapHandle(pathname, desc, cb, middleware...)
	svr.router.Handle(pathname, command)
}
// Serve installs l as the server's listener, registers the built-in "/help"
// command and then accepts connections until Accept fails. It blocks for the
// lifetime of the accept loop and always returns nil.
func (svr *Server) Serve(l net.Listener) (err error) {
	svr.l = l

	// "/help" replies with the router's textual listing.
	help := func(ctx *Context) (err error) {
		return ctx.Success(svr.router.String())
	}
	svr.Handle("/help", "Display help information", help)

	svr.serve()
	return nil
}
// Shutdown closes the listener (when Serve has installed one), closes every
// connection context still registered in ctxMap and then waits for the
// connection goroutines to finish.
//
// The first error encountered while closing is returned; later close errors
// do not overwrite it.
func (svr *Server) Shutdown() (err error) {
	// Serve may never have been called, in which case there is no listener.
	if svr.l != nil {
		err = svr.l.Close()
	}
	svr.ctxMap.Range(func(key, value any) bool {
		if ctx, ok := value.(*Context); ok {
			// Keep closing the remaining contexts even after a failure,
			// but remember only the first error.
			if closeErr := ctx.Close(); closeErr != nil && err == nil {
				err = closeErr
			}
		}
		return true
	})
	svr.waitGroup.Wait()
	return err
}
// New returns a Server that uses ctx as its base context, with an empty root
// router and room pre-allocated for ten server-wide middleware entries.
func New(ctx context.Context) *Server {
	svr := new(Server)
	svr.ctx = ctx
	svr.router = newRouter("")
	svr.middleware = make([]Middleware, 0, 10)
	return svr
}