add metadata middleware

This commit is contained in:
Yavolte 2025-06-06 22:26:41 +08:00
parent 3abcc1ceed
commit 5403f14ccd
8 changed files with 234 additions and 31 deletions

4
app.go
View File

@ -144,8 +144,10 @@ func (s *Service) preStart(ctx context.Context) (err error) {
} }
if s.opts.registry != nil { if s.opts.registry != nil {
s.errGroup.Go(func() error { s.errGroup.Go(func() error {
childCtx, cancel := context.WithTimeout(ctx, s.opts.registrarTimeout)
defer cancel()
opts := func(o *registry.RegisterOptions) { opts := func(o *registry.RegisterOptions) {
o.Context = ctx o.Context = childCtx
} }
if err = s.opts.registry.Register(s.service, opts); err != nil { if err = s.opts.registry.Register(s.service, opts); err != nil {
return err return err

View File

@ -0,0 +1,115 @@
package metadata
import (
"context"
"maps"
"strings"
)
type metadataKey struct{}
// Metadata is our way of representing request headers internally.
// They're used at the RPC level and translate back and forth
// from Transport headers.
type Metadata map[string]string
func canonicalMetadataKey(key string) string {
return strings.ToLower(key)
}
func (md Metadata) Has(key string) bool {
_, ok := md[canonicalMetadataKey(key)]
return ok
}
func (md Metadata) Get(key string) (string, bool) {
val, ok := md[canonicalMetadataKey(key)]
return val, ok
}
func (md Metadata) Set(key, val string) {
md[canonicalMetadataKey(key)] = val
}
func (md Metadata) Delete(key string) {
delete(md, canonicalMetadataKey(key))
}
// Copy makes a copy of the metadata.
func Copy(md Metadata) Metadata {
cmd := make(Metadata, len(md))
maps.Copy(cmd, md)
return cmd
}
// Delete key from metadata.
func Delete(ctx context.Context, k string) context.Context {
return Set(ctx, k, "")
}
// Set add key with val to metadata.
func Set(ctx context.Context, k, v string) context.Context {
md, ok := FromContext(ctx)
if !ok {
md = make(Metadata)
}
k = canonicalMetadataKey(k)
if v == "" {
delete(md, k)
} else {
md[k] = v
}
return context.WithValue(ctx, metadataKey{}, md)
}
// Get returns a single value from metadata in the context.
func Get(ctx context.Context, key string) (string, bool) {
md, ok := FromContext(ctx)
if !ok {
return "", ok
}
key = canonicalMetadataKey(key)
val, ok := md[canonicalMetadataKey(key)]
return val, ok
}
// FromContext returns metadata from the given context.
func FromContext(ctx context.Context) (Metadata, bool) {
md, ok := ctx.Value(metadataKey{}).(Metadata)
if !ok {
return nil, ok
}
// capitalise all values
newMD := make(Metadata, len(md))
for k, v := range md {
newMD[canonicalMetadataKey(k)] = v
}
return newMD, ok
}
// NewContext creates a new context with the given metadata.
func NewContext(ctx context.Context, md Metadata) context.Context {
return context.WithValue(ctx, metadataKey{}, md)
}
// MergeContext merges metadata to existing metadata, overwriting if specified.
func MergeContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context {
if ctx == nil {
ctx = context.Background()
}
md, _ := ctx.Value(metadataKey{}).(Metadata)
cmd := make(Metadata, len(md))
maps.Copy(cmd, md)
for k, v := range patchMd {
if _, ok := cmd[k]; ok && !overwrite {
// skip
} else if v != "" {
cmd[k] = v
} else {
delete(cmd, k)
}
}
return context.WithValue(ctx, metadataKey{}, cmd)
}

View File

@ -0,0 +1,7 @@
package metadata
const (
RequestIDKey = "X-AEUS-Request-ID"
RequestPathKey = "X-AEUS-Request-Path"
RequestProtocolKey = "X-AEUS-Request-Protocol"
)

View File

@ -0,0 +1,5 @@
package middleware
type Transporter interface {
}

View File

@ -5,9 +5,13 @@ import (
"net" "net"
"net/url" "net/url"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/logger" "git.nobla.cn/golang/aeus/pkg/logger"
netutil "git.nobla.cn/golang/aeus/pkg/net" netutil "git.nobla.cn/golang/aeus/pkg/net"
"github.com/google/uuid"
"google.golang.org/grpc" "google.golang.org/grpc"
grpcmd "google.golang.org/grpc/metadata"
"google.golang.org/grpc/reflection" "google.golang.org/grpc/reflection"
) )
@ -17,6 +21,7 @@ type Server struct {
uri *url.URL uri *url.URL
serve *grpc.Server serve *grpc.Server
listener net.Listener listener net.Listener
middlewares []middleware.Middleware
} }
func (s *Server) createListener() (err error) { func (s *Server) createListener() (err error) {
@ -29,12 +34,44 @@ func (s *Server) createListener() (err error) {
return return
} }
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
next := func(ctx context.Context) (err error) {
resp, err = handler(ctx, req)
return
}
h := middleware.Chain(s.middlewares...)(next)
md := make(metadata.Metadata)
if !md.Has(metadata.RequestIDKey) {
md.Set(metadata.RequestIDKey, uuid.New().String())
}
md.Set(metadata.RequestPathKey, info.FullMethod)
md.Set(metadata.RequestProtocolKey, Protocol)
if gmd, ok := grpcmd.FromIncomingContext(ctx); ok {
for k, v := range gmd {
if len(v) > 0 {
md.Set(k, v[0])
}
}
}
ctx = metadata.MergeContext(ctx, md, true)
ctx = context.WithValue(ctx, requestValueContextKey{}, req)
err = h(ctx)
return
}
}
func (s *Server) Use(middlewares ...middleware.Middleware) {
s.middlewares = append(s.middlewares, middlewares...)
}
func (s *Server) Start(ctx context.Context) (err error) { func (s *Server) Start(ctx context.Context) (err error) {
s.ctx = ctx s.ctx = ctx
if err = s.createListener(); err != nil { if err = s.createListener(); err != nil {
return return
} }
s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host) s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host)
reflection.Register(s.serve) reflection.Register(s.serve)
s.serve.Serve(s.listener) s.serve.Serve(s.listener)
return return
@ -63,15 +100,17 @@ func New(cbs ...Option) *Server {
network: "tcp", network: "tcp",
logger: logger.Default(), logger: logger.Default(),
address: ":0", address: ":0",
grpcOpts: make([]grpc.ServerOption, 0, 10),
}, },
uri: &url.URL{ uri: &url.URL{
Scheme: "grpc", Scheme: "grpc",
}, },
middlewares: make([]middleware.Middleware, 0, 10),
} }
for _, cb := range cbs { for _, cb := range cbs {
cb(svr.opts) cb(svr.opts)
} }
gopts := []grpc.ServerOption{} svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor()))
svr.serve = grpc.NewServer(gopts...) svr.serve = grpc.NewServer(svr.opts.grpcOpts...)
return svr return svr
} }

View File

@ -9,12 +9,17 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
) )
const (
Protocol = "grpc"
)
type ( type (
Option func(*options) Option func(*options)
options struct { options struct {
network string network string
address string address string
grpcOpts []grpc.ServerOption
logger logger.Logger logger logger.Logger
context context.Context context context.Context
} }
@ -26,6 +31,8 @@ type (
} }
ClientOption func(*clientOptions) ClientOption func(*clientOptions)
requestValueContextKey struct{}
) )
func WithNetwork(network string) Option { func WithNetwork(network string) Option {
@ -69,3 +76,10 @@ func WithGrpcDialOptions(opts ...grpc.DialOption) ClientOption {
o.dialOptions = opts o.dialOptions = opts
} }
} }
func GetRequestValueFromContext(ctx context.Context) any {
if ctx == nil {
return nil
}
return ctx.Value(requestValueContextKey{})
}

View File

@ -8,11 +8,13 @@ import (
"net/url" "net/url"
"sync" "sync"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware" "git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/errors" "git.nobla.cn/golang/aeus/pkg/errors"
"git.nobla.cn/golang/aeus/pkg/logger" "git.nobla.cn/golang/aeus/pkg/logger"
netutil "git.nobla.cn/golang/aeus/pkg/net" netutil "git.nobla.cn/golang/aeus/pkg/net"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid"
) )
type Server struct { type Server struct {
@ -23,6 +25,7 @@ type Server struct {
engine *gin.Engine engine *gin.Engine
once sync.Once once sync.Once
listener net.Listener listener net.Listener
middlewares []middleware.Middleware
} }
func (s *Server) Endpoint(ctx context.Context) (string, error) { func (s *Server) Endpoint(ctx context.Context) (string, error) {
@ -81,25 +84,38 @@ func (s *Server) DELETE(pattern string, h HandleFunc) {
} }
func (s *Server) Use(middlewares ...middleware.Middleware) { func (s *Server) Use(middlewares ...middleware.Middleware) {
for _, m := range middlewares { s.middlewares = append(s.middlewares, middlewares...)
s.engine.Use(s.warpMiddleware(m))
}
} }
func (s *Server) warpMiddleware(m middleware.Middleware) gin.HandlerFunc { func (s *Server) requestInterceptor() gin.HandlerFunc {
return func(ginCtx *gin.Context) { return func(ginCtx *gin.Context) {
ctx := ginCtx.Request.Context() ctx := ginCtx.Request.Context()
handler := func(ctx context.Context) error { next := func(ctx context.Context) error {
ginCtx.Next() ginCtx.Next()
if err := ginCtx.Errors.Last(); err != nil { if err := ginCtx.Errors.Last(); err != nil {
return err.Err return err.Err
} }
return nil return nil
} }
wrappedHandler := m(handler) handler := middleware.Chain(s.middlewares...)(next)
if err := wrappedHandler(ctx); err != nil { md := make(metadata.Metadata)
ginCtx.AbortWithError(http.StatusServiceUnavailable, err) for k, v := range ginCtx.Request.Header {
return if len(v) > 0 {
md.Set(k, v[0])
}
}
if !md.Has(metadata.RequestIDKey) {
md.Set(metadata.RequestIDKey, uuid.New().String())
}
md.Set(metadata.RequestProtocolKey, Protocol)
md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path)
ctx = metadata.MergeContext(ctx, md, true)
if err := handler(ctx); err != nil {
if se, ok := err.(*errors.Error); ok {
ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil))
} else {
ginCtx.AbortWithError(http.StatusInternalServerError, err)
}
} }
} }
} }
@ -175,5 +191,6 @@ func New(cbs ...Option) *Server {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
svr.engine = gin.New(svr.opts.ginOptions...) svr.engine = gin.New(svr.opts.ginOptions...)
svr.engine.Use(svr.requestInterceptor())
return svr return svr
} }

View File

@ -8,6 +8,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
const (
Protocol = "http"
)
type ( type (
Option func(*options) Option func(*options)