diff --git a/app.go b/app.go index c82a87b..ae6e1c6 100644 --- a/app.go +++ b/app.go @@ -109,10 +109,14 @@ func (s *Service) injectVars(v any) { continue } fieldType := refType.Field(i) - if fieldType.Type.Kind() != reflect.Ptr { + if !(fieldType.Type.Kind() != reflect.Ptr || fieldType.Type.Kind() != reflect.Interface) { continue } for _, rv := range s.refValues { + if fieldType.Type.Kind() == reflect.Interface && rv.Type().Implements(fieldType.Type) { + refValue.Field(i).Set(rv) + break + } if fieldType.Type == rv.Type() { refValue.Field(i).Set(rv) break @@ -123,7 +127,10 @@ func (s *Service) injectVars(v any) { func (s *Service) preStart(ctx context.Context) (err error) { s.Logger().Info(ctx, "starting") + s.refValues = append(s.refValues, s.opts.injectVars...) + s.refValues = append(s.refValues, reflect.ValueOf(s.Logger())) for _, ptr := range s.opts.servers { + s.injectVars(ptr) s.refValues = append(s.refValues, reflect.ValueOf(ptr)) } if s.opts.registry != nil { @@ -185,7 +192,7 @@ func (s *Service) preStart(ctx context.Context) (err error) { o.Context = ctx o.TTL = s.opts.registrarTimeout }); err != nil { - s.Logger().Warn(ctx, "service register error: %v", err) + s.Logger().Warnf(ctx, "service register error: %v", err) } } } @@ -206,14 +213,14 @@ func (s *Service) preStop() (err error) { }() for _, srv := range s.opts.servers { if err = srv.Stop(ctx); err != nil { - s.Logger().Warn(ctx, "server stop error: %v", err) + s.Logger().Warnf(ctx, "server stop error: %v", err) } } if s.opts.registry != nil { if err = s.opts.registry.Deregister(s.service, func(o *registry.DeregisterOptions) { o.Context = ctx }); err != nil { - s.Logger().Warn(ctx, "server deregister error: %v", err) + s.Logger().Warnf(ctx, "server deregister error: %v", err) } } s.Logger().Info(ctx, "stopped") diff --git a/options.go b/options.go index bbd1103..14284bf 100644 --- a/options.go +++ b/options.go @@ -3,6 +3,7 @@ package aeus import ( "context" "maps" + "reflect" "time" "git.nobla.cn/golang/aeus/pkg/logger" @@ -24,6 +25,7 @@ type options struct { registry registry.Registry serviceLoader ServiceLoader stopTimeout time.Duration + injectVars []reflect.Value } func WithName(name string) Option { @@ -77,6 +79,14 @@ func WithDebug(debug bool) Option { } } +func WithInjectVars(vars ...any) Option { + return func(o *options) { + for _, v := range vars { + o.injectVars = append(o.injectVars, reflect.ValueOf(v)) + } + } +} + func WithServiceLoader(loader ServiceLoader) Option { return func(o *options) { o.serviceLoader = loader diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 6ffeb33..3b7f8c7 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -3,6 +3,7 @@ package logger import ( "context" "log/slog" + "os" ) var ( @@ -10,32 +11,56 @@ var ( ) func init() { - log = NewLogger(slog.Default()) + log = NewLogger(slog.New( + slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }), + )) } type Logger interface { - Debug(ctx context.Context, format string, args ...any) + Debug(ctx context.Context, format string, args ...any) //Structured context as loosely typed key-value pairs. + Debugf(ctx context.Context, format string, args ...any) Info(ctx context.Context, format string, args ...any) + Infof(ctx context.Context, format string, args ...any) Warn(ctx context.Context, format string, args ...any) + Warnf(ctx context.Context, format string, args ...any) Error(ctx context.Context, format string, args ...any) + Errorf(ctx context.Context, format string, args ...any) } func Debug(ctx context.Context, format string, args ...any) { log.Debug(ctx, format, args...) } +func Debugf(ctx context.Context, format string, args ...any) { + log.Debugf(ctx, format, args...) +} + func Info(ctx context.Context, format string, args ...any) { - log.Debug(ctx, format, args...) + log.Info(ctx, format, args...) +} + +func Infof(ctx context.Context, format string, args ...any) { + log.Infof(ctx, format, args...) } func Warn(ctx context.Context, format string, args ...any) { - log.Debug(ctx, format, args...) + log.Warn(ctx, format, args...) +} + +func Warnf(ctx context.Context, format string, args ...any) { + log.Warnf(ctx, format, args...) } func Error(ctx context.Context, format string, args ...any) { log.Debug(ctx, format, args...) } +func Errorf(ctx context.Context, format string, args ...any) { + log.Errorf(ctx, format, args...) +} + func Default() Logger { return log } diff --git a/pkg/logger/slog.go b/pkg/logger/slog.go index 59b57a0..ab55420 100644 --- a/pkg/logger/slog.go +++ b/pkg/logger/slog.go @@ -11,18 +11,34 @@ type logger struct { } func (l *logger) Debug(ctx context.Context, msg string, args ...any) { + l.log.DebugContext(ctx, msg, args...) +} + +func (l *logger) Debugf(ctx context.Context, msg string, args ...any) { l.log.DebugContext(ctx, fmt.Sprintf(msg, args...)) } func (l *logger) Info(ctx context.Context, msg string, args ...any) { + l.log.InfoContext(ctx, msg, args...) +} + +func (l *logger) Infof(ctx context.Context, msg string, args ...any) { l.log.InfoContext(ctx, fmt.Sprintf(msg, args...)) } func (l *logger) Warn(ctx context.Context, msg string, args ...any) { + l.log.WarnContext(ctx, msg, args...) +} + +func (l *logger) Warnf(ctx context.Context, msg string, args ...any) { l.log.WarnContext(ctx, fmt.Sprintf(msg, args...)) } func (l *logger) Error(ctx context.Context, msg string, args ...any) { + l.log.ErrorContext(ctx, msg, args...) +} + +func (l *logger) Errorf(ctx context.Context, msg string, args ...any) { l.log.ErrorContext(ctx, fmt.Sprintf(msg, args...)) } diff --git a/transport/cli/server.go b/transport/cli/server.go index 87f734f..8dd1fa8 100644 --- a/transport/cli/server.go +++ b/transport/cli/server.go @@ -32,6 +32,7 @@ type Server struct { uri *url.URL exitFlag int32 middleware []middleware.Middleware + Logger logger.Logger } func (svr *Server) Use(middlewares ...middleware.Middleware) { @@ -134,7 +135,7 @@ func (s *Server) execute(ctx *Context, frame *Frame) (err error) { func (svr *Server) nextSequence() int64 { svr.sequenceLocker.Lock() defer svr.sequenceLocker.Unlock() - if svr.sequence >= math.MaxInt64 { + if svr.sequence == math.MaxInt64 { svr.sequence = 1 } svr.sequence++ @@ -208,10 +209,16 @@ func (s *Server) serve() (err error) { func (s *Server) Start(ctx context.Context) (err error) { s.ctx = ctx + if s.opts.logger != nil { + s.Logger = s.opts.logger + } + if s.Logger == nil { + s.Logger = logger.Default() + } if err = s.createListener(); err != nil { return } - s.opts.logger.Info(ctx, "cli server listen on: %s", s.uri.Host) + s.Logger.Infof(ctx, "cli server listen on: %s", s.uri.Host) s.Handle("/help", "Display help information", func(ctx *Context) (err error) { return ctx.Success(s.router.String()) }) @@ -224,7 +231,9 @@ func (s *Server) Stop(ctx context.Context) (err error) { return } if s.listener != nil { - err = s.listener.Close() + if err = s.listener.Close(); err != nil { + s.Logger.Warnf(ctx, "cli listener close error: %v", err) + } } s.ctxMap.Range(func(key, value any) bool { if ctx, ok := value.(*Context); ok { @@ -232,6 +241,7 @@ func (s *Server) Stop(ctx context.Context) (err error) { } return true }) + s.Logger.Info(ctx, "cli server stopped") return } @@ -240,7 +250,6 @@ func New(cbs ...Option) *Server { opts: &options{ network: "tcp", address: ":0", - logger: logger.Default(), }, uri: &url.URL{Scheme: "cli"}, router: newRouter(""), diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 2468526..8775cf7 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -25,6 +25,7 @@ type Server struct { serve *grpc.Server listener net.Listener middlewares []middleware.Middleware + Logger logger.Logger } func (s *Server) createListener() (err error) { @@ -107,11 +108,16 @@ func (s *Server) Use(middlewares ...middleware.Middleware) { func (s *Server) Start(ctx context.Context) (err error) { s.ctx = ctx + if s.opts.logger != nil { + s.Logger = s.opts.logger + } + if s.Logger == nil { + s.Logger = logger.Default() + } if err = s.createListener(); err != nil { return } - s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host) - + s.Logger.Infof(ctx, "grpc server listen on: %s", s.uri.Host) reflection.Register(s.serve) s.serve.Serve(s.listener) return @@ -130,7 +136,7 @@ func (s *Server) RegisterService(sd *grpc.ServiceDesc, ss any) { func (s *Server) Stop(ctx context.Context) (err error) { s.serve.GracefulStop() - s.opts.logger.Info(s.ctx, "grpc server stopped") + s.Logger.Infof(s.ctx, "grpc server stopped") return } @@ -138,7 +144,6 @@ func New(cbs ...Option) *Server { svr := &Server{ opts: &options{ network: "tcp", - logger: logger.Default(), grpcOpts: make([]grpc.ServerOption, 0, 10), }, uri: &url.URL{ diff --git a/transport/http/server.go b/transport/http/server.go index 3f3e8c0..8d2391e 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -36,6 +36,7 @@ type Server struct { listener net.Listener fs *filesystem middlewares []middleware.Middleware + Logger logger.Logger } func (s *Server) Endpoint(ctx context.Context) (string, error) { @@ -156,7 +157,6 @@ func (s *Server) staticHandle(ctx *gin.Context, fp http.File) { } http.ServeContent(ctx.Writer, ctx.Request, path.Base(uri), s.fs.modtime, fp) ctx.Abort() - return } func (s *Server) notFoundHandle(ctx *gin.Context) { @@ -255,6 +255,12 @@ func (s *Server) Start(ctx context.Context) (err error) { Addr: s.opts.address, Handler: s.engine, } + if s.opts.logger != nil { + s.Logger = s.opts.logger + } + if s.Logger == nil { + s.Logger = logger.Default() + } s.ctx = ctx if s.opts.debug { s.engine.Handle(http.MethodGet, "/debug/pprof/", s.wrapHandle(pprof.Index)) @@ -271,7 +277,7 @@ func (s *Server) Start(ctx context.Context) (err error) { return } s.engine.NoRoute(s.notFoundHandle) - s.opts.logger.Info(ctx, "http server listen on: %s", s.uri.Host) + s.Logger.Infof(ctx, "http server listen on: %s", s.uri.Host) if s.opts.certFile != "" && s.opts.keyFile != "" { s.uri.Scheme = "https" err = s.serve.ServeTLS(s.listener, s.opts.certFile, s.opts.keyFile) @@ -286,7 +292,7 @@ func (s *Server) Start(ctx context.Context) (err error) { func (s *Server) Stop(ctx context.Context) (err error) { err = s.serve.Shutdown(ctx) - s.opts.logger.Info(ctx, "http server stopped") + s.Logger.Infof(ctx, "http server stopped") return } @@ -295,7 +301,6 @@ func New(cbs ...Option) *Server { uri: &url.URL{Scheme: "http"}, opts: &options{ network: "tcp", - logger: logger.Default(), }, } port, _ := strconv.Atoi(os.Getenv("HTTP_PORT"))