add metadata middleware
This commit is contained in:
parent
3abcc1ceed
commit
5403f14ccd
4
app.go
4
app.go
|
@ -144,8 +144,10 @@ func (s *Service) preStart(ctx context.Context) (err error) {
|
||||||
}
|
}
|
||||||
if s.opts.registry != nil {
|
if s.opts.registry != nil {
|
||||||
s.errGroup.Go(func() error {
|
s.errGroup.Go(func() error {
|
||||||
|
childCtx, cancel := context.WithTimeout(ctx, s.opts.registrarTimeout)
|
||||||
|
defer cancel()
|
||||||
opts := func(o *registry.RegisterOptions) {
|
opts := func(o *registry.RegisterOptions) {
|
||||||
o.Context = ctx
|
o.Context = childCtx
|
||||||
}
|
}
|
||||||
if err = s.opts.registry.Register(s.service, opts); err != nil {
|
if err = s.opts.registry.Register(s.service, opts); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"maps"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type metadataKey struct{}
|
||||||
|
|
||||||
|
// Metadata is our way of representing request headers internally.
|
||||||
|
// They're used at the RPC level and translate back and forth
|
||||||
|
// from Transport headers.
|
||||||
|
type Metadata map[string]string
|
||||||
|
|
||||||
|
func canonicalMetadataKey(key string) string {
|
||||||
|
return strings.ToLower(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md Metadata) Has(key string) bool {
|
||||||
|
_, ok := md[canonicalMetadataKey(key)]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md Metadata) Get(key string) (string, bool) {
|
||||||
|
val, ok := md[canonicalMetadataKey(key)]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md Metadata) Set(key, val string) {
|
||||||
|
md[canonicalMetadataKey(key)] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md Metadata) Delete(key string) {
|
||||||
|
delete(md, canonicalMetadataKey(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy makes a copy of the metadata.
|
||||||
|
func Copy(md Metadata) Metadata {
|
||||||
|
cmd := make(Metadata, len(md))
|
||||||
|
maps.Copy(cmd, md)
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete key from metadata.
|
||||||
|
func Delete(ctx context.Context, k string) context.Context {
|
||||||
|
return Set(ctx, k, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set add key with val to metadata.
|
||||||
|
func Set(ctx context.Context, k, v string) context.Context {
|
||||||
|
md, ok := FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
md = make(Metadata)
|
||||||
|
}
|
||||||
|
k = canonicalMetadataKey(k)
|
||||||
|
if v == "" {
|
||||||
|
delete(md, k)
|
||||||
|
} else {
|
||||||
|
md[k] = v
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, metadataKey{}, md)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a single value from metadata in the context.
|
||||||
|
func Get(ctx context.Context, key string) (string, bool) {
|
||||||
|
md, ok := FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", ok
|
||||||
|
}
|
||||||
|
key = canonicalMetadataKey(key)
|
||||||
|
val, ok := md[canonicalMetadataKey(key)]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext returns metadata from the given context.
|
||||||
|
func FromContext(ctx context.Context) (Metadata, bool) {
|
||||||
|
md, ok := ctx.Value(metadataKey{}).(Metadata)
|
||||||
|
if !ok {
|
||||||
|
return nil, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// capitalise all values
|
||||||
|
newMD := make(Metadata, len(md))
|
||||||
|
for k, v := range md {
|
||||||
|
newMD[canonicalMetadataKey(k)] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return newMD, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContext creates a new context with the given metadata.
|
||||||
|
func NewContext(ctx context.Context, md Metadata) context.Context {
|
||||||
|
return context.WithValue(ctx, metadataKey{}, md)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeContext merges metadata to existing metadata, overwriting if specified.
|
||||||
|
func MergeContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
md, _ := ctx.Value(metadataKey{}).(Metadata)
|
||||||
|
cmd := make(Metadata, len(md))
|
||||||
|
maps.Copy(cmd, md)
|
||||||
|
for k, v := range patchMd {
|
||||||
|
if _, ok := cmd[k]; ok && !overwrite {
|
||||||
|
// skip
|
||||||
|
} else if v != "" {
|
||||||
|
cmd[k] = v
|
||||||
|
} else {
|
||||||
|
delete(cmd, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, metadataKey{}, cmd)
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
const (
|
||||||
|
RequestIDKey = "X-AEUS-Request-ID"
|
||||||
|
RequestPathKey = "X-AEUS-Request-Path"
|
||||||
|
RequestProtocolKey = "X-AEUS-Request-Protocol"
|
||||||
|
)
|
|
@ -0,0 +1,5 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
type Transporter interface {
|
||||||
|
|
||||||
|
}
|
|
@ -5,18 +5,23 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"git.nobla.cn/golang/aeus/metadata"
|
||||||
|
"git.nobla.cn/golang/aeus/middleware"
|
||||||
"git.nobla.cn/golang/aeus/pkg/logger"
|
"git.nobla.cn/golang/aeus/pkg/logger"
|
||||||
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
||||||
|
"github.com/google/uuid"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
grpcmd "google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/reflection"
|
"google.golang.org/grpc/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
opts *options
|
opts *options
|
||||||
uri *url.URL
|
uri *url.URL
|
||||||
serve *grpc.Server
|
serve *grpc.Server
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
|
middlewares []middleware.Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) createListener() (err error) {
|
func (s *Server) createListener() (err error) {
|
||||||
|
@ -29,12 +34,44 @@ func (s *Server) createListener() (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||||
|
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
|
||||||
|
next := func(ctx context.Context) (err error) {
|
||||||
|
resp, err = handler(ctx, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h := middleware.Chain(s.middlewares...)(next)
|
||||||
|
md := make(metadata.Metadata)
|
||||||
|
if !md.Has(metadata.RequestIDKey) {
|
||||||
|
md.Set(metadata.RequestIDKey, uuid.New().String())
|
||||||
|
}
|
||||||
|
md.Set(metadata.RequestPathKey, info.FullMethod)
|
||||||
|
md.Set(metadata.RequestProtocolKey, Protocol)
|
||||||
|
if gmd, ok := grpcmd.FromIncomingContext(ctx); ok {
|
||||||
|
for k, v := range gmd {
|
||||||
|
if len(v) > 0 {
|
||||||
|
md.Set(k, v[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx = metadata.MergeContext(ctx, md, true)
|
||||||
|
ctx = context.WithValue(ctx, requestValueContextKey{}, req)
|
||||||
|
err = h(ctx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Use(middlewares ...middleware.Middleware) {
|
||||||
|
s.middlewares = append(s.middlewares, middlewares...)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) Start(ctx context.Context) (err error) {
|
func (s *Server) Start(ctx context.Context) (err error) {
|
||||||
s.ctx = ctx
|
s.ctx = ctx
|
||||||
if err = s.createListener(); err != nil {
|
if err = s.createListener(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host)
|
s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host)
|
||||||
|
|
||||||
reflection.Register(s.serve)
|
reflection.Register(s.serve)
|
||||||
s.serve.Serve(s.listener)
|
s.serve.Serve(s.listener)
|
||||||
return
|
return
|
||||||
|
@ -60,18 +97,20 @@ func (s *Server) Stop(ctx context.Context) (err error) {
|
||||||
func New(cbs ...Option) *Server {
|
func New(cbs ...Option) *Server {
|
||||||
svr := &Server{
|
svr := &Server{
|
||||||
opts: &options{
|
opts: &options{
|
||||||
network: "tcp",
|
network: "tcp",
|
||||||
logger: logger.Default(),
|
logger: logger.Default(),
|
||||||
address: ":0",
|
address: ":0",
|
||||||
|
grpcOpts: make([]grpc.ServerOption, 0, 10),
|
||||||
},
|
},
|
||||||
uri: &url.URL{
|
uri: &url.URL{
|
||||||
Scheme: "grpc",
|
Scheme: "grpc",
|
||||||
},
|
},
|
||||||
|
middlewares: make([]middleware.Middleware, 0, 10),
|
||||||
}
|
}
|
||||||
for _, cb := range cbs {
|
for _, cb := range cbs {
|
||||||
cb(svr.opts)
|
cb(svr.opts)
|
||||||
}
|
}
|
||||||
gopts := []grpc.ServerOption{}
|
svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor()))
|
||||||
svr.serve = grpc.NewServer(gopts...)
|
svr.serve = grpc.NewServer(svr.opts.grpcOpts...)
|
||||||
return svr
|
return svr
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,14 +9,19 @@ import (
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Protocol = "grpc"
|
||||||
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Option func(*options)
|
Option func(*options)
|
||||||
|
|
||||||
options struct {
|
options struct {
|
||||||
network string
|
network string
|
||||||
address string
|
address string
|
||||||
logger logger.Logger
|
grpcOpts []grpc.ServerOption
|
||||||
context context.Context
|
logger logger.Logger
|
||||||
|
context context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
clientOptions struct {
|
clientOptions struct {
|
||||||
|
@ -26,6 +31,8 @@ type (
|
||||||
}
|
}
|
||||||
|
|
||||||
ClientOption func(*clientOptions)
|
ClientOption func(*clientOptions)
|
||||||
|
|
||||||
|
requestValueContextKey struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func WithNetwork(network string) Option {
|
func WithNetwork(network string) Option {
|
||||||
|
@ -69,3 +76,10 @@ func WithGrpcDialOptions(opts ...grpc.DialOption) ClientOption {
|
||||||
o.dialOptions = opts
|
o.dialOptions = opts
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRequestValueFromContext(ctx context.Context) any {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ctx.Value(requestValueContextKey{})
|
||||||
|
}
|
||||||
|
|
|
@ -8,21 +8,24 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"git.nobla.cn/golang/aeus/metadata"
|
||||||
"git.nobla.cn/golang/aeus/middleware"
|
"git.nobla.cn/golang/aeus/middleware"
|
||||||
"git.nobla.cn/golang/aeus/pkg/errors"
|
"git.nobla.cn/golang/aeus/pkg/errors"
|
||||||
"git.nobla.cn/golang/aeus/pkg/logger"
|
"git.nobla.cn/golang/aeus/pkg/logger"
|
||||||
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
opts *options
|
opts *options
|
||||||
uri *url.URL
|
uri *url.URL
|
||||||
serve *http.Server
|
serve *http.Server
|
||||||
engine *gin.Engine
|
engine *gin.Engine
|
||||||
once sync.Once
|
once sync.Once
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
|
middlewares []middleware.Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Endpoint(ctx context.Context) (string, error) {
|
func (s *Server) Endpoint(ctx context.Context) (string, error) {
|
||||||
|
@ -81,25 +84,38 @@ func (s *Server) DELETE(pattern string, h HandleFunc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Use(middlewares ...middleware.Middleware) {
|
func (s *Server) Use(middlewares ...middleware.Middleware) {
|
||||||
for _, m := range middlewares {
|
s.middlewares = append(s.middlewares, middlewares...)
|
||||||
s.engine.Use(s.warpMiddleware(m))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) warpMiddleware(m middleware.Middleware) gin.HandlerFunc {
|
func (s *Server) requestInterceptor() gin.HandlerFunc {
|
||||||
return func(ginCtx *gin.Context) {
|
return func(ginCtx *gin.Context) {
|
||||||
ctx := ginCtx.Request.Context()
|
ctx := ginCtx.Request.Context()
|
||||||
handler := func(ctx context.Context) error {
|
next := func(ctx context.Context) error {
|
||||||
ginCtx.Next()
|
ginCtx.Next()
|
||||||
if err := ginCtx.Errors.Last(); err != nil {
|
if err := ginCtx.Errors.Last(); err != nil {
|
||||||
return err.Err
|
return err.Err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
wrappedHandler := m(handler)
|
handler := middleware.Chain(s.middlewares...)(next)
|
||||||
if err := wrappedHandler(ctx); err != nil {
|
md := make(metadata.Metadata)
|
||||||
ginCtx.AbortWithError(http.StatusServiceUnavailable, err)
|
for k, v := range ginCtx.Request.Header {
|
||||||
return
|
if len(v) > 0 {
|
||||||
|
md.Set(k, v[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !md.Has(metadata.RequestIDKey) {
|
||||||
|
md.Set(metadata.RequestIDKey, uuid.New().String())
|
||||||
|
}
|
||||||
|
md.Set(metadata.RequestProtocolKey, Protocol)
|
||||||
|
md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path)
|
||||||
|
ctx = metadata.MergeContext(ctx, md, true)
|
||||||
|
if err := handler(ctx); err != nil {
|
||||||
|
if se, ok := err.(*errors.Error); ok {
|
||||||
|
ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil))
|
||||||
|
} else {
|
||||||
|
ginCtx.AbortWithError(http.StatusInternalServerError, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -175,5 +191,6 @@ func New(cbs ...Option) *Server {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
svr.engine = gin.New(svr.opts.ginOptions...)
|
svr.engine = gin.New(svr.opts.ginOptions...)
|
||||||
|
svr.engine.Use(svr.requestInterceptor())
|
||||||
return svr
|
return svr
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,10 @@ import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Protocol = "http"
|
||||||
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Option func(*options)
|
Option func(*options)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue