Compare commits

..

8 Commits
v0.0.9 ... main

Author SHA1 Message Date
Yavolte 2812e46444 adjuest logger interface 2025-07-04 14:26:07 +08:00
Yavolte fb585fabe6 use jwt claims 2025-06-30 11:25:55 +08:00
Yavolte 29d609ce0a fix cache bugs 2025-06-30 11:11:34 +08:00
Yavolte b0d6e3423d add gzip support 2025-06-24 15:51:18 +08:00
fcl 4792b21857 updat README 2025-06-23 17:53:41 +08:00
fcl 0d0477d32c add readme 2025-06-23 17:36:28 +08:00
fcl c027ef22bc fix log 2025-06-23 17:20:10 +08:00
Yavolte 908931bc01 add query 2025-06-23 17:00:36 +08:00
13 changed files with 398 additions and 45 deletions

View File

@ -4,5 +4,99 @@
# 环境变量
| 环境变量 | 描述 |
| --- | --- |
| AEUS_DEBUG | 是否开启debug模式 |
| HTTP_PORT | http服务端口 |
| GRPC_PORT | grpc服务端口 |
| CLI_PORT | cli服务端口 |
# 快速开始 # 快速开始
## 创建一个项目
创建项目可以使用`aeus`命令行工具进行生成:
```
aeus new github.com/your-username/your-project-name
```
如果需要创建一个带管理后台的应用, 可以使用`--admin`参数:
```
aeus new github.com/your-username/your-project-name --admin
```
## 生成`Proto`文件
服务使用`proto3`作为通信协议,因此需要生成`Proto`文件。
```
make proto
```
清理生成的文件使用:
```
make proto-clean
```
## 编译项目
编译项目可以使用`make`命令进行编译:
```
make build
```
# 目录结构
```
├── api
│ └── v1
├── cmd
│ ├── main.go
├── config
│ ├── config.go
│ └── config.yaml
├── deploy
│ └── docker
├── go.mod
├── go.sum
├── internal
│ ├── models
│ ├── scope
│ ├── service
├── Makefile
├── README.md
├── third_party
│ ├── aeus
│ ├── errors
│ ├── google
│ ├── openapi
│ ├── README.md
│ └── validate
├── vendor
├── version
│ └── version.go
├── web
└── webhook.yaml
```
| 目录 | 描述 |
| --- | --- |
| api | api定义目录 |
| cmd | 启动命令目录 |
| config | 配置目录 |
| deploy | 部署目录 |
| internal | 内部文件目录 |
| internal.service | 服务定义目录 |
| internal.models | 模型定义目录 |
| internal.scope | 服务scope定义目录,主要有全局的变量(比如DB,Redis等) |
| third_party | 第三方proto文件目录 |
| web | 前端资源目录 |

28
app.go
View File

@ -6,6 +6,7 @@ import (
"os/signal" "os/signal"
"reflect" "reflect"
"runtime" "runtime"
"strconv"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time" "time"
@ -35,12 +36,19 @@ func (s *Service) Name() string {
} }
func (s *Service) Debug() bool { func (s *Service) Debug() bool {
if s.opts != nil {
return s.opts.debug
}
return false return false
} }
func (s *Service) Version() string { func (s *Service) Version() string {
return s.opts.version if s.opts != nil {
return s.opts.version
}
return ""
} }
func (s *Service) Metadata() map[string]string { func (s *Service) Metadata() map[string]string {
if s.service == nil { if s.service == nil {
return nil return nil
@ -101,10 +109,14 @@ func (s *Service) injectVars(v any) {
continue continue
} }
fieldType := refType.Field(i) fieldType := refType.Field(i)
if fieldType.Type.Kind() != reflect.Ptr { if !(fieldType.Type.Kind() != reflect.Ptr || fieldType.Type.Kind() != reflect.Interface) {
continue continue
} }
for _, rv := range s.refValues { for _, rv := range s.refValues {
if fieldType.Type.Kind() == reflect.Interface && rv.Type().Implements(fieldType.Type) {
refValue.Field(i).Set(rv)
break
}
if fieldType.Type == rv.Type() { if fieldType.Type == rv.Type() {
refValue.Field(i).Set(rv) refValue.Field(i).Set(rv)
break break
@ -114,8 +126,11 @@ func (s *Service) injectVars(v any) {
} }
func (s *Service) preStart(ctx context.Context) (err error) { func (s *Service) preStart(ctx context.Context) (err error) {
s.Logger().Info(s.ctx, "starting") s.Logger().Info(ctx, "starting")
s.refValues = append(s.refValues, s.opts.injectVars...)
s.refValues = append(s.refValues, reflect.ValueOf(s.Logger()))
for _, ptr := range s.opts.servers { for _, ptr := range s.opts.servers {
s.injectVars(ptr)
s.refValues = append(s.refValues, reflect.ValueOf(ptr)) s.refValues = append(s.refValues, reflect.ValueOf(ptr))
} }
if s.opts.registry != nil { if s.opts.registry != nil {
@ -177,7 +192,7 @@ func (s *Service) preStart(ctx context.Context) (err error) {
o.Context = ctx o.Context = ctx
o.TTL = s.opts.registrarTimeout o.TTL = s.opts.registrarTimeout
}); err != nil { }); err != nil {
s.Logger().Warn(ctx, "service register error: %v", err) s.Logger().Warnf(ctx, "service register error: %v", err)
} }
} }
} }
@ -198,14 +213,14 @@ func (s *Service) preStop() (err error) {
}() }()
for _, srv := range s.opts.servers { for _, srv := range s.opts.servers {
if err = srv.Stop(ctx); err != nil { if err = srv.Stop(ctx); err != nil {
s.Logger().Warn(ctx, "server stop error: %v", err) s.Logger().Warnf(ctx, "server stop error: %v", err)
} }
} }
if s.opts.registry != nil { if s.opts.registry != nil {
if err = s.opts.registry.Deregister(s.service, func(o *registry.DeregisterOptions) { if err = s.opts.registry.Deregister(s.service, func(o *registry.DeregisterOptions) {
o.Context = ctx o.Context = ctx
}); err != nil { }); err != nil {
s.Logger().Warn(ctx, "server deregister error: %v", err) s.Logger().Warnf(ctx, "server deregister error: %v", err)
} }
} }
s.Logger().Info(ctx, "stopped") s.Logger().Info(ctx, "stopped")
@ -255,6 +270,7 @@ func New(cbs ...Option) *Service {
registrarTimeout: time.Second * 30, registrarTimeout: time.Second * 30,
}, },
} }
s.opts.debug, _ = strconv.ParseBool("AEUS_DEBUG")
s.opts.metadata = make(map[string]string) s.opts.metadata = make(map[string]string)
for _, cb := range cbs { for _, cb := range cbs {
cb(s.opts) cb(s.opts)

View File

@ -47,13 +47,26 @@ func WithAllow(paths ...string) Option {
if o.allows == nil { if o.allows == nil {
o.allows = make([]string, 0, 16) o.allows = make([]string, 0, 16)
} }
o.allows = append(o.allows, paths...) for _, s := range paths {
s = strings.TrimSpace(s)
if len(s) == 0 {
continue
}
o.allows = append(o.allows, s)
}
} }
} }
func WithClaims(claims reflect.Type) Option { func WithClaims(claims any) Option {
return func(o *options) { return func(o *options) {
o.claims = claims if tv, ok := claims.(reflect.Type); ok {
o.claims = tv
} else {
o.claims = reflect.TypeOf(claims)
if o.claims.Kind() == reflect.Ptr {
o.claims = o.claims.Elem()
}
}
} }
} }
@ -65,19 +78,19 @@ func WithValidate(fn Validate) Option {
// isAllowed check if the path is allowed // isAllowed check if the path is allowed
func isAllowed(uripath string, allows []string) bool { func isAllowed(uripath string, allows []string) bool {
for _, str := range allows { for _, pattern := range allows {
n := len(str) n := len(pattern)
if n == 0 { if pattern == uripath {
continue return true
} }
if n > 1 && str[n-1] == '*' { if pattern == "*" {
if strings.HasPrefix(uripath, str[:n-1]) { return true
}
if n > 1 && pattern[n-1] == '*' {
if strings.HasPrefix(uripath, pattern[:n-1]) {
return true return true
} }
} }
if str == uripath {
return true
}
} }
return false return false
} }
@ -108,9 +121,7 @@ func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware {
return err return err
} }
} }
if strings.HasPrefix(token, bearerWord) { token, _ = strings.CutPrefix(token, bearerWord)
token = strings.TrimPrefix(token, bearerWord)
}
var ( var (
ti *jwt.Token ti *jwt.Token
) )

View File

@ -2,6 +2,8 @@ package aeus
import ( import (
"context" "context"
"maps"
"reflect"
"time" "time"
"git.nobla.cn/golang/aeus/pkg/logger" "git.nobla.cn/golang/aeus/pkg/logger"
@ -18,10 +20,12 @@ type options struct {
servers []Server servers []Server
endpoints []string endpoints []string
scope Scope scope Scope
debug bool
registrarTimeout time.Duration registrarTimeout time.Duration
registry registry.Registry registry registry.Registry
serviceLoader ServiceLoader serviceLoader ServiceLoader
stopTimeout time.Duration stopTimeout time.Duration
injectVars []reflect.Value
} }
func WithName(name string) Option { func WithName(name string) Option {
@ -41,9 +45,7 @@ func WithMetadata(metadata map[string]string) Option {
if o.metadata == nil { if o.metadata == nil {
o.metadata = make(map[string]string) o.metadata = make(map[string]string)
} }
for k, v := range metadata { maps.Copy(o.metadata, metadata)
o.metadata[k] = v
}
} }
} }
@ -71,6 +73,20 @@ func WithScope(scope Scope) Option {
} }
} }
func WithDebug(debug bool) Option {
return func(o *options) {
o.debug = debug
}
}
func WithInjectVars(vars ...any) Option {
return func(o *options) {
for _, v := range vars {
o.injectVars = append(o.injectVars, reflect.ValueOf(v))
}
}
}
func WithServiceLoader(loader ServiceLoader) Option { func WithServiceLoader(loader ServiceLoader) Option {
return func(o *options) { return func(o *options) {
o.serviceLoader = loader o.serviceLoader = loader

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"time" "time"
"git.nobla.cn/golang/aeus"
) )
type redisCache struct { type redisCache struct {
@ -57,7 +59,12 @@ func (c *redisCache) String() string {
} }
func NewCache(opts ...Option) *redisCache { func NewCache(opts ...Option) *redisCache {
return &redisCache{ cache := &redisCache{
opts: newOptions(opts...), opts: newOptions(opts...),
} }
app := aeus.FromContext(cache.opts.context)
if app != nil {
cache.opts.prefix = app.Name() + ":" + cache.opts.prefix
}
return cache
} }

View File

@ -1,13 +1,16 @@
package redis package redis
import ( import (
"context"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
type ( type (
options struct { options struct {
client *redis.Client context context.Context
prefix string client *redis.Client
prefix string
} }
Option func(*options) Option func(*options)
@ -19,6 +22,12 @@ func WithClient(client *redis.Client) Option {
} }
} }
func WithContext(ctx context.Context) Option {
return func(o *options) {
o.context = ctx
}
}
func WithPrefix(prefix string) Option { func WithPrefix(prefix string) Option {
return func(o *options) { return func(o *options) {
o.prefix = prefix o.prefix = prefix

View File

@ -3,6 +3,7 @@ package logger
import ( import (
"context" "context"
"log/slog" "log/slog"
"os"
) )
var ( var (
@ -10,32 +11,56 @@ var (
) )
func init() { func init() {
log = NewLogger(slog.Default()) log = NewLogger(slog.New(
slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelDebug,
}),
))
} }
type Logger interface { type Logger interface {
Debug(ctx context.Context, format string, args ...any) Debug(ctx context.Context, format string, args ...any) //Structured context as loosely typed key-value pairs.
Debugf(ctx context.Context, format string, args ...any)
Info(ctx context.Context, format string, args ...any) Info(ctx context.Context, format string, args ...any)
Infof(ctx context.Context, format string, args ...any)
Warn(ctx context.Context, format string, args ...any) Warn(ctx context.Context, format string, args ...any)
Warnf(ctx context.Context, format string, args ...any)
Error(ctx context.Context, format string, args ...any) Error(ctx context.Context, format string, args ...any)
Errorf(ctx context.Context, format string, args ...any)
} }
func Debug(ctx context.Context, format string, args ...any) { func Debug(ctx context.Context, format string, args ...any) {
log.Debug(ctx, format, args...) log.Debug(ctx, format, args...)
} }
func Debugf(ctx context.Context, format string, args ...any) {
log.Debugf(ctx, format, args...)
}
func Info(ctx context.Context, format string, args ...any) { func Info(ctx context.Context, format string, args ...any) {
log.Debug(ctx, format, args...) log.Info(ctx, format, args...)
}
func Infof(ctx context.Context, format string, args ...any) {
log.Infof(ctx, format, args...)
} }
func Warn(ctx context.Context, format string, args ...any) { func Warn(ctx context.Context, format string, args ...any) {
log.Debug(ctx, format, args...) log.Warn(ctx, format, args...)
}
func Warnf(ctx context.Context, format string, args ...any) {
log.Warnf(ctx, format, args...)
} }
func Error(ctx context.Context, format string, args ...any) { func Error(ctx context.Context, format string, args ...any) {
log.Debug(ctx, format, args...) log.Debug(ctx, format, args...)
} }
func Errorf(ctx context.Context, format string, args ...any) {
log.Errorf(ctx, format, args...)
}
func Default() Logger { func Default() Logger {
return log return log
} }

View File

@ -11,18 +11,34 @@ type logger struct {
} }
func (l *logger) Debug(ctx context.Context, msg string, args ...any) { func (l *logger) Debug(ctx context.Context, msg string, args ...any) {
l.log.DebugContext(ctx, msg, args...)
}
func (l *logger) Debugf(ctx context.Context, msg string, args ...any) {
l.log.DebugContext(ctx, fmt.Sprintf(msg, args...)) l.log.DebugContext(ctx, fmt.Sprintf(msg, args...))
} }
func (l *logger) Info(ctx context.Context, msg string, args ...any) { func (l *logger) Info(ctx context.Context, msg string, args ...any) {
l.log.InfoContext(ctx, msg, args...)
}
func (l *logger) Infof(ctx context.Context, msg string, args ...any) {
l.log.InfoContext(ctx, fmt.Sprintf(msg, args...)) l.log.InfoContext(ctx, fmt.Sprintf(msg, args...))
} }
func (l *logger) Warn(ctx context.Context, msg string, args ...any) { func (l *logger) Warn(ctx context.Context, msg string, args ...any) {
l.log.WarnContext(ctx, msg, args...)
}
func (l *logger) Warnf(ctx context.Context, msg string, args ...any) {
l.log.WarnContext(ctx, fmt.Sprintf(msg, args...)) l.log.WarnContext(ctx, fmt.Sprintf(msg, args...))
} }
func (l *logger) Error(ctx context.Context, msg string, args ...any) { func (l *logger) Error(ctx context.Context, msg string, args ...any) {
l.log.ErrorContext(ctx, msg, args...)
}
func (l *logger) Errorf(ctx context.Context, msg string, args ...any) {
l.log.ErrorContext(ctx, fmt.Sprintf(msg, args...)) l.log.ErrorContext(ctx, fmt.Sprintf(msg, args...))
} }

View File

@ -32,6 +32,7 @@ type Server struct {
uri *url.URL uri *url.URL
exitFlag int32 exitFlag int32
middleware []middleware.Middleware middleware []middleware.Middleware
Logger logger.Logger
} }
func (svr *Server) Use(middlewares ...middleware.Middleware) { func (svr *Server) Use(middlewares ...middleware.Middleware) {
@ -134,7 +135,7 @@ func (s *Server) execute(ctx *Context, frame *Frame) (err error) {
func (svr *Server) nextSequence() int64 { func (svr *Server) nextSequence() int64 {
svr.sequenceLocker.Lock() svr.sequenceLocker.Lock()
defer svr.sequenceLocker.Unlock() defer svr.sequenceLocker.Unlock()
if svr.sequence >= math.MaxInt64 { if svr.sequence == math.MaxInt64 {
svr.sequence = 1 svr.sequence = 1
} }
svr.sequence++ svr.sequence++
@ -208,10 +209,16 @@ func (s *Server) serve() (err error) {
func (s *Server) Start(ctx context.Context) (err error) { func (s *Server) Start(ctx context.Context) (err error) {
s.ctx = ctx s.ctx = ctx
if s.opts.logger != nil {
s.Logger = s.opts.logger
}
if s.Logger == nil {
s.Logger = logger.Default()
}
if err = s.createListener(); err != nil { if err = s.createListener(); err != nil {
return return
} }
s.opts.logger.Info(ctx, "cli server listen on: %s", s.uri.Host) s.Logger.Infof(ctx, "cli server listen on: %s", s.uri.Host)
s.Handle("/help", "Display help information", func(ctx *Context) (err error) { s.Handle("/help", "Display help information", func(ctx *Context) (err error) {
return ctx.Success(s.router.String()) return ctx.Success(s.router.String())
}) })
@ -224,7 +231,9 @@ func (s *Server) Stop(ctx context.Context) (err error) {
return return
} }
if s.listener != nil { if s.listener != nil {
err = s.listener.Close() if err = s.listener.Close(); err != nil {
s.Logger.Warnf(ctx, "cli listener close error: %v", err)
}
} }
s.ctxMap.Range(func(key, value any) bool { s.ctxMap.Range(func(key, value any) bool {
if ctx, ok := value.(*Context); ok { if ctx, ok := value.(*Context); ok {
@ -232,6 +241,7 @@ func (s *Server) Stop(ctx context.Context) (err error) {
} }
return true return true
}) })
s.Logger.Info(ctx, "cli server stopped")
return return
} }
@ -240,7 +250,6 @@ func New(cbs ...Option) *Server {
opts: &options{ opts: &options{
network: "tcp", network: "tcp",
address: ":0", address: ":0",
logger: logger.Default(),
}, },
uri: &url.URL{Scheme: "cli"}, uri: &url.URL{Scheme: "cli"},
router: newRouter(""), router: newRouter(""),

View File

@ -25,6 +25,7 @@ type Server struct {
serve *grpc.Server serve *grpc.Server
listener net.Listener listener net.Listener
middlewares []middleware.Middleware middlewares []middleware.Middleware
Logger logger.Logger
} }
func (s *Server) createListener() (err error) { func (s *Server) createListener() (err error) {
@ -107,11 +108,16 @@ func (s *Server) Use(middlewares ...middleware.Middleware) {
func (s *Server) Start(ctx context.Context) (err error) { func (s *Server) Start(ctx context.Context) (err error) {
s.ctx = ctx s.ctx = ctx
if s.opts.logger != nil {
s.Logger = s.opts.logger
}
if s.Logger == nil {
s.Logger = logger.Default()
}
if err = s.createListener(); err != nil { if err = s.createListener(); err != nil {
return return
} }
s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host) s.Logger.Infof(ctx, "grpc server listen on: %s", s.uri.Host)
reflection.Register(s.serve) reflection.Register(s.serve)
s.serve.Serve(s.listener) s.serve.Serve(s.listener)
return return
@ -130,7 +136,7 @@ func (s *Server) RegisterService(sd *grpc.ServiceDesc, ss any) {
func (s *Server) Stop(ctx context.Context) (err error) { func (s *Server) Stop(ctx context.Context) (err error) {
s.serve.GracefulStop() s.serve.GracefulStop()
s.opts.logger.Info(s.ctx, "grpc server stopped") s.Logger.Infof(s.ctx, "grpc server stopped")
return return
} }
@ -138,7 +144,6 @@ func New(cbs ...Option) *Server {
svr := &Server{ svr := &Server{
opts: &options{ opts: &options{
network: "tcp", network: "tcp",
logger: logger.Default(),
grpcOpts: make([]grpc.ServerOption, 0, 10), grpcOpts: make([]grpc.ServerOption, 0, 10),
}, },
uri: &url.URL{ uri: &url.URL{

View File

@ -22,6 +22,10 @@ func (c *Context) Context() context.Context {
return c.ctx.Request.Context() return c.ctx.Request.Context()
} }
func (c *Context) Gin() *gin.Context {
return c.ctx
}
func (c *Context) Request() *http.Request { func (c *Context) Request() *http.Request {
return c.ctx.Request return c.ctx.Request
} }
@ -38,6 +42,14 @@ func (c *Context) Param(key string) string {
return c.ctx.Param(key) return c.ctx.Param(key)
} }
func (c *Context) Query(key string) string {
qs := c.ctx.Request.URL.Query()
if qs != nil {
return qs.Get(key)
}
return ""
}
func (c *Context) Bind(val any) (err error) { func (c *Context) Bind(val any) (err error) {
// if params exists, try bind params first // if params exists, try bind params first
if len(c.ctx.Params) > 0 { if len(c.ctx.Params) > 0 {

View File

@ -3,13 +3,17 @@ package http
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/http/pprof" "net/http/pprof"
"net/url" "net/url"
"os" "os"
"path" "path"
"path/filepath"
"slices"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -32,6 +36,7 @@ type Server struct {
listener net.Listener listener net.Listener
fs *filesystem fs *filesystem
middlewares []middleware.Middleware middlewares []middleware.Middleware
Logger logger.Logger
} }
func (s *Server) Endpoint(ctx context.Context) (string, error) { func (s *Server) Endpoint(ctx context.Context) (string, error) {
@ -106,14 +111,60 @@ func (s *Server) Webroot(prefix string, fs http.FileSystem) {
s.fs.SetIndexFile("/index.html") s.fs.SetIndexFile("/index.html")
} }
func (s *Server) shouldCompress(req *http.Request) bool {
if !strings.Contains(req.Header.Get(headerAcceptEncoding), "gzip") ||
strings.Contains(req.Header.Get("Connection"), "Upgrade") {
return false
}
// Check if the request path is excluded from compression
extension := filepath.Ext(req.URL.Path)
if slices.Contains(assetsExtensions, extension) {
return true
}
return false
}
func (s *Server) staticHandle(ctx *gin.Context, fp http.File) {
uri := path.Clean(ctx.Request.URL.Path)
fi, err := fp.Stat()
if err != nil {
return
}
if !fi.IsDir() {
//https://github.com/gin-contrib/gzip
if s.shouldCompress(ctx.Request) && fi.Size() > 8192 {
gzWriter := newGzipWriter()
gzWriter.Reset(ctx.Writer)
ctx.Header(headerContentEncoding, "gzip")
ctx.Writer.Header().Add(headerVary, headerAcceptEncoding)
originalEtag := ctx.GetHeader("ETag")
if originalEtag != "" && !strings.HasPrefix(originalEtag, "W/") {
ctx.Header("ETag", "W/"+originalEtag)
}
ctx.Writer = &gzipWriter{ctx.Writer, gzWriter}
defer func() {
if ctx.Writer.Size() < 0 {
gzWriter.Reset(io.Discard)
}
gzWriter.Close()
if ctx.Writer.Size() > -1 {
ctx.Header("Content-Length", strconv.Itoa(ctx.Writer.Size()))
}
putGzipWriter(gzWriter)
}()
}
}
http.ServeContent(ctx.Writer, ctx.Request, path.Base(uri), s.fs.modtime, fp)
ctx.Abort()
}
func (s *Server) notFoundHandle(ctx *gin.Context) { func (s *Server) notFoundHandle(ctx *gin.Context) {
if s.fs != nil && ctx.Request.Method == http.MethodGet { if s.fs != nil && ctx.Request.Method == http.MethodGet {
uri := path.Clean(ctx.Request.URL.Path) uri := path.Clean(ctx.Request.URL.Path)
if fp, err := s.fs.Open(uri); err == nil { if fp, err := s.fs.Open(uri); err == nil {
http.ServeContent(ctx.Writer, ctx.Request, path.Base(uri), s.fs.modtime, fp) s.staticHandle(ctx, fp)
fp.Close() fp.Close()
ctx.Abort()
return
} }
} }
ctx.JSON(http.StatusNotFound, newResponse(errors.NotFound, "Not Found", nil)) ctx.JSON(http.StatusNotFound, newResponse(errors.NotFound, "Not Found", nil))
@ -204,6 +255,12 @@ func (s *Server) Start(ctx context.Context) (err error) {
Addr: s.opts.address, Addr: s.opts.address,
Handler: s.engine, Handler: s.engine,
} }
if s.opts.logger != nil {
s.Logger = s.opts.logger
}
if s.Logger == nil {
s.Logger = logger.Default()
}
s.ctx = ctx s.ctx = ctx
if s.opts.debug { if s.opts.debug {
s.engine.Handle(http.MethodGet, "/debug/pprof/", s.wrapHandle(pprof.Index)) s.engine.Handle(http.MethodGet, "/debug/pprof/", s.wrapHandle(pprof.Index))
@ -220,7 +277,7 @@ func (s *Server) Start(ctx context.Context) (err error) {
return return
} }
s.engine.NoRoute(s.notFoundHandle) s.engine.NoRoute(s.notFoundHandle)
s.opts.logger.Info(ctx, "http server listen on: %s", s.uri.Host) s.Logger.Infof(ctx, "http server listen on: %s", s.uri.Host)
if s.opts.certFile != "" && s.opts.keyFile != "" { if s.opts.certFile != "" && s.opts.keyFile != "" {
s.uri.Scheme = "https" s.uri.Scheme = "https"
err = s.serve.ServeTLS(s.listener, s.opts.certFile, s.opts.keyFile) err = s.serve.ServeTLS(s.listener, s.opts.certFile, s.opts.keyFile)
@ -235,7 +292,7 @@ func (s *Server) Start(ctx context.Context) (err error) {
func (s *Server) Stop(ctx context.Context) (err error) { func (s *Server) Stop(ctx context.Context) (err error) {
err = s.serve.Shutdown(ctx) err = s.serve.Shutdown(ctx)
s.opts.logger.Info(ctx, "http server stopped") s.Logger.Infof(ctx, "http server stopped")
return return
} }
@ -244,7 +301,6 @@ func New(cbs ...Option) *Server {
uri: &url.URL{Scheme: "http"}, uri: &url.URL{Scheme: "http"},
opts: &options{ opts: &options{
network: "tcp", network: "tcp",
logger: logger.Default(),
}, },
} }
port, _ := strconv.Atoi(os.Getenv("HTTP_PORT")) port, _ := strconv.Atoi(os.Getenv("HTTP_PORT"))

View File

@ -1,8 +1,14 @@
package http package http
import ( import (
"bufio"
"compress/gzip"
"context" "context"
"errors"
"io"
"net"
"net/http" "net/http"
"sync"
"git.nobla.cn/golang/aeus/pkg/logger" "git.nobla.cn/golang/aeus/pkg/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -41,6 +47,77 @@ type (
} }
) )
const (
headerAcceptEncoding = "Accept-Encoding"
headerContentEncoding = "Content-Encoding"
headerVary = "Vary"
)
var (
gzPool sync.Pool
assetsExtensions = []string{".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".ttf", ".eot", ".otf"}
)
type gzipWriter struct {
gin.ResponseWriter
writer *gzip.Writer
}
func (g *gzipWriter) WriteString(s string) (int, error) {
g.Header().Del("Content-Length")
return g.writer.Write([]byte(s))
}
func (g *gzipWriter) Write(data []byte) (int, error) {
g.Header().Del("Content-Length")
return g.writer.Write(data)
}
func (g *gzipWriter) Flush() {
_ = g.writer.Flush()
g.ResponseWriter.Flush()
}
// Fix: https://github.com/mholt/caddy/issues/38
func (g *gzipWriter) WriteHeader(code int) {
g.Header().Del("Content-Length")
g.ResponseWriter.WriteHeader(code)
}
var _ http.Hijacker = (*gzipWriter)(nil)
// Hijack allows the caller to take over the connection from the HTTP server.
// After a call to Hijack, the HTTP server library will not do anything else with the connection.
// It becomes the caller's responsibility to manage and close the connection.
//
// It returns the underlying net.Conn, a buffered reader/writer for the connection, and an error
// if the ResponseWriter does not support the Hijacker interface.
func (g *gzipWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := g.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("the ResponseWriter doesn't support the Hijacker interface")
}
return hijacker.Hijack()
}
func newGzipWriter() (writer *gzip.Writer) {
v := gzPool.Get()
if v == nil {
writer, _ = gzip.NewWriterLevel(io.Discard, gzip.DefaultCompression)
} else {
if w, ok := v.(*gzip.Writer); ok {
return w
} else {
writer, _ = gzip.NewWriterLevel(io.Discard, gzip.DefaultCompression)
}
}
return
}
func putGzipWriter(writer *gzip.Writer) {
gzPool.Put(writer)
}
func WithNetwork(network string) Option { func WithNetwork(network string) Option {
return func(o *options) { return func(o *options) {
o.network = network o.network = network