add metadata middleware
This commit is contained in:
parent
3abcc1ceed
commit
5403f14ccd
4
app.go
4
app.go
|
@ -144,8 +144,10 @@ func (s *Service) preStart(ctx context.Context) (err error) {
|
|||
}
|
||||
if s.opts.registry != nil {
|
||||
s.errGroup.Go(func() error {
|
||||
childCtx, cancel := context.WithTimeout(ctx, s.opts.registrarTimeout)
|
||||
defer cancel()
|
||||
opts := func(o *registry.RegisterOptions) {
|
||||
o.Context = ctx
|
||||
o.Context = childCtx
|
||||
}
|
||||
if err = s.opts.registry.Register(s.service, opts); err != nil {
|
||||
return err
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
package metadata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"maps"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type metadataKey struct{}
|
||||
|
||||
// Metadata is our way of representing request headers internally.
|
||||
// They're used at the RPC level and translate back and forth
|
||||
// from Transport headers.
|
||||
type Metadata map[string]string
|
||||
|
||||
func canonicalMetadataKey(key string) string {
|
||||
return strings.ToLower(key)
|
||||
}
|
||||
|
||||
func (md Metadata) Has(key string) bool {
|
||||
_, ok := md[canonicalMetadataKey(key)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (md Metadata) Get(key string) (string, bool) {
|
||||
val, ok := md[canonicalMetadataKey(key)]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func (md Metadata) Set(key, val string) {
|
||||
md[canonicalMetadataKey(key)] = val
|
||||
}
|
||||
|
||||
func (md Metadata) Delete(key string) {
|
||||
delete(md, canonicalMetadataKey(key))
|
||||
}
|
||||
|
||||
// Copy makes a copy of the metadata.
|
||||
func Copy(md Metadata) Metadata {
|
||||
cmd := make(Metadata, len(md))
|
||||
maps.Copy(cmd, md)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Delete key from metadata.
|
||||
func Delete(ctx context.Context, k string) context.Context {
|
||||
return Set(ctx, k, "")
|
||||
}
|
||||
|
||||
// Set add key with val to metadata.
|
||||
func Set(ctx context.Context, k, v string) context.Context {
|
||||
md, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
md = make(Metadata)
|
||||
}
|
||||
k = canonicalMetadataKey(k)
|
||||
if v == "" {
|
||||
delete(md, k)
|
||||
} else {
|
||||
md[k] = v
|
||||
}
|
||||
return context.WithValue(ctx, metadataKey{}, md)
|
||||
}
|
||||
|
||||
// Get returns a single value from metadata in the context.
|
||||
func Get(ctx context.Context, key string) (string, bool) {
|
||||
md, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
return "", ok
|
||||
}
|
||||
key = canonicalMetadataKey(key)
|
||||
val, ok := md[canonicalMetadataKey(key)]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// FromContext returns metadata from the given context.
|
||||
func FromContext(ctx context.Context) (Metadata, bool) {
|
||||
md, ok := ctx.Value(metadataKey{}).(Metadata)
|
||||
if !ok {
|
||||
return nil, ok
|
||||
}
|
||||
|
||||
// capitalise all values
|
||||
newMD := make(Metadata, len(md))
|
||||
for k, v := range md {
|
||||
newMD[canonicalMetadataKey(k)] = v
|
||||
}
|
||||
|
||||
return newMD, ok
|
||||
}
|
||||
|
||||
// NewContext creates a new context with the given metadata.
|
||||
func NewContext(ctx context.Context, md Metadata) context.Context {
|
||||
return context.WithValue(ctx, metadataKey{}, md)
|
||||
}
|
||||
|
||||
// MergeContext merges metadata to existing metadata, overwriting if specified.
|
||||
func MergeContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
md, _ := ctx.Value(metadataKey{}).(Metadata)
|
||||
cmd := make(Metadata, len(md))
|
||||
maps.Copy(cmd, md)
|
||||
for k, v := range patchMd {
|
||||
if _, ok := cmd[k]; ok && !overwrite {
|
||||
// skip
|
||||
} else if v != "" {
|
||||
cmd[k] = v
|
||||
} else {
|
||||
delete(cmd, k)
|
||||
}
|
||||
}
|
||||
return context.WithValue(ctx, metadataKey{}, cmd)
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
package metadata
|
||||
|
||||
const (
|
||||
RequestIDKey = "X-AEUS-Request-ID"
|
||||
RequestPathKey = "X-AEUS-Request-Path"
|
||||
RequestProtocolKey = "X-AEUS-Request-Protocol"
|
||||
)
|
|
@ -0,0 +1,5 @@
|
|||
package middleware
|
||||
|
||||
type Transporter interface {
|
||||
|
||||
}
|
|
@ -5,18 +5,23 @@ import (
|
|||
"net"
|
||||
"net/url"
|
||||
|
||||
"git.nobla.cn/golang/aeus/metadata"
|
||||
"git.nobla.cn/golang/aeus/middleware"
|
||||
"git.nobla.cn/golang/aeus/pkg/logger"
|
||||
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc"
|
||||
grpcmd "google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
opts *options
|
||||
uri *url.URL
|
||||
serve *grpc.Server
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
opts *options
|
||||
uri *url.URL
|
||||
serve *grpc.Server
|
||||
listener net.Listener
|
||||
middlewares []middleware.Middleware
|
||||
}
|
||||
|
||||
func (s *Server) createListener() (err error) {
|
||||
|
@ -29,12 +34,44 @@ func (s *Server) createListener() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
|
||||
next := func(ctx context.Context) (err error) {
|
||||
resp, err = handler(ctx, req)
|
||||
return
|
||||
}
|
||||
h := middleware.Chain(s.middlewares...)(next)
|
||||
md := make(metadata.Metadata)
|
||||
if !md.Has(metadata.RequestIDKey) {
|
||||
md.Set(metadata.RequestIDKey, uuid.New().String())
|
||||
}
|
||||
md.Set(metadata.RequestPathKey, info.FullMethod)
|
||||
md.Set(metadata.RequestProtocolKey, Protocol)
|
||||
if gmd, ok := grpcmd.FromIncomingContext(ctx); ok {
|
||||
for k, v := range gmd {
|
||||
if len(v) > 0 {
|
||||
md.Set(k, v[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx = metadata.MergeContext(ctx, md, true)
|
||||
ctx = context.WithValue(ctx, requestValueContextKey{}, req)
|
||||
err = h(ctx)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Use(middlewares ...middleware.Middleware) {
|
||||
s.middlewares = append(s.middlewares, middlewares...)
|
||||
}
|
||||
|
||||
func (s *Server) Start(ctx context.Context) (err error) {
|
||||
s.ctx = ctx
|
||||
if err = s.createListener(); err != nil {
|
||||
return
|
||||
}
|
||||
s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host)
|
||||
|
||||
reflection.Register(s.serve)
|
||||
s.serve.Serve(s.listener)
|
||||
return
|
||||
|
@ -60,18 +97,20 @@ func (s *Server) Stop(ctx context.Context) (err error) {
|
|||
func New(cbs ...Option) *Server {
|
||||
svr := &Server{
|
||||
opts: &options{
|
||||
network: "tcp",
|
||||
logger: logger.Default(),
|
||||
address: ":0",
|
||||
network: "tcp",
|
||||
logger: logger.Default(),
|
||||
address: ":0",
|
||||
grpcOpts: make([]grpc.ServerOption, 0, 10),
|
||||
},
|
||||
uri: &url.URL{
|
||||
Scheme: "grpc",
|
||||
},
|
||||
middlewares: make([]middleware.Middleware, 0, 10),
|
||||
}
|
||||
for _, cb := range cbs {
|
||||
cb(svr.opts)
|
||||
}
|
||||
gopts := []grpc.ServerOption{}
|
||||
svr.serve = grpc.NewServer(gopts...)
|
||||
svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor()))
|
||||
svr.serve = grpc.NewServer(svr.opts.grpcOpts...)
|
||||
return svr
|
||||
}
|
||||
|
|
|
@ -9,14 +9,19 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
Protocol = "grpc"
|
||||
)
|
||||
|
||||
type (
|
||||
Option func(*options)
|
||||
|
||||
options struct {
|
||||
network string
|
||||
address string
|
||||
logger logger.Logger
|
||||
context context.Context
|
||||
network string
|
||||
address string
|
||||
grpcOpts []grpc.ServerOption
|
||||
logger logger.Logger
|
||||
context context.Context
|
||||
}
|
||||
|
||||
clientOptions struct {
|
||||
|
@ -26,6 +31,8 @@ type (
|
|||
}
|
||||
|
||||
ClientOption func(*clientOptions)
|
||||
|
||||
requestValueContextKey struct{}
|
||||
)
|
||||
|
||||
func WithNetwork(network string) Option {
|
||||
|
@ -69,3 +76,10 @@ func WithGrpcDialOptions(opts ...grpc.DialOption) ClientOption {
|
|||
o.dialOptions = opts
|
||||
}
|
||||
}
|
||||
|
||||
func GetRequestValueFromContext(ctx context.Context) any {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
return ctx.Value(requestValueContextKey{})
|
||||
}
|
||||
|
|
|
@ -8,21 +8,24 @@ import (
|
|||
"net/url"
|
||||
"sync"
|
||||
|
||||
"git.nobla.cn/golang/aeus/metadata"
|
||||
"git.nobla.cn/golang/aeus/middleware"
|
||||
"git.nobla.cn/golang/aeus/pkg/errors"
|
||||
"git.nobla.cn/golang/aeus/pkg/logger"
|
||||
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
opts *options
|
||||
uri *url.URL
|
||||
serve *http.Server
|
||||
engine *gin.Engine
|
||||
once sync.Once
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
opts *options
|
||||
uri *url.URL
|
||||
serve *http.Server
|
||||
engine *gin.Engine
|
||||
once sync.Once
|
||||
listener net.Listener
|
||||
middlewares []middleware.Middleware
|
||||
}
|
||||
|
||||
func (s *Server) Endpoint(ctx context.Context) (string, error) {
|
||||
|
@ -81,25 +84,38 @@ func (s *Server) DELETE(pattern string, h HandleFunc) {
|
|||
}
|
||||
|
||||
func (s *Server) Use(middlewares ...middleware.Middleware) {
|
||||
for _, m := range middlewares {
|
||||
s.engine.Use(s.warpMiddleware(m))
|
||||
}
|
||||
s.middlewares = append(s.middlewares, middlewares...)
|
||||
}
|
||||
|
||||
func (s *Server) warpMiddleware(m middleware.Middleware) gin.HandlerFunc {
|
||||
func (s *Server) requestInterceptor() gin.HandlerFunc {
|
||||
return func(ginCtx *gin.Context) {
|
||||
ctx := ginCtx.Request.Context()
|
||||
handler := func(ctx context.Context) error {
|
||||
next := func(ctx context.Context) error {
|
||||
ginCtx.Next()
|
||||
if err := ginCtx.Errors.Last(); err != nil {
|
||||
return err.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
wrappedHandler := m(handler)
|
||||
if err := wrappedHandler(ctx); err != nil {
|
||||
ginCtx.AbortWithError(http.StatusServiceUnavailable, err)
|
||||
return
|
||||
handler := middleware.Chain(s.middlewares...)(next)
|
||||
md := make(metadata.Metadata)
|
||||
for k, v := range ginCtx.Request.Header {
|
||||
if len(v) > 0 {
|
||||
md.Set(k, v[0])
|
||||
}
|
||||
}
|
||||
if !md.Has(metadata.RequestIDKey) {
|
||||
md.Set(metadata.RequestIDKey, uuid.New().String())
|
||||
}
|
||||
md.Set(metadata.RequestProtocolKey, Protocol)
|
||||
md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path)
|
||||
ctx = metadata.MergeContext(ctx, md, true)
|
||||
if err := handler(ctx); err != nil {
|
||||
if se, ok := err.(*errors.Error); ok {
|
||||
ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil))
|
||||
} else {
|
||||
ginCtx.AbortWithError(http.StatusInternalServerError, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -175,5 +191,6 @@ func New(cbs ...Option) *Server {
|
|||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
svr.engine = gin.New(svr.opts.ginOptions...)
|
||||
svr.engine.Use(svr.requestInterceptor())
|
||||
return svr
|
||||
}
|
||||
|
|
|
@ -8,6 +8,10 @@ import (
|
|||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
Protocol = "http"
|
||||
)
|
||||
|
||||
type (
|
||||
Option func(*options)
|
||||
|
||||
|
|
Loading…
Reference in New Issue