diff --git a/app.go b/app.go index 279d532..7500ea2 100644 --- a/app.go +++ b/app.go @@ -144,8 +144,10 @@ func (s *Service) preStart(ctx context.Context) (err error) { } if s.opts.registry != nil { s.errGroup.Go(func() error { + childCtx, cancel := context.WithTimeout(ctx, s.opts.registrarTimeout) + defer cancel() opts := func(o *registry.RegisterOptions) { - o.Context = ctx + o.Context = childCtx } if err = s.opts.registry.Register(s.service, opts); err != nil { return err diff --git a/metadata/metadata.go b/metadata/metadata.go new file mode 100644 index 0000000..10d8d51 --- /dev/null +++ b/metadata/metadata.go @@ -0,0 +1,115 @@ +package metadata + +import ( + "context" + "maps" + "strings" +) + +type metadataKey struct{} + +// Metadata is our way of representing request headers internally. +// They're used at the RPC level and translate back and forth +// from Transport headers. +type Metadata map[string]string + +func canonicalMetadataKey(key string) string { + return strings.ToLower(key) +} + +func (md Metadata) Has(key string) bool { + _, ok := md[canonicalMetadataKey(key)] + return ok +} + +func (md Metadata) Get(key string) (string, bool) { + val, ok := md[canonicalMetadataKey(key)] + return val, ok +} + +func (md Metadata) Set(key, val string) { + md[canonicalMetadataKey(key)] = val +} + +func (md Metadata) Delete(key string) { + delete(md, canonicalMetadataKey(key)) +} + +// Copy makes a copy of the metadata. +func Copy(md Metadata) Metadata { + cmd := make(Metadata, len(md)) + maps.Copy(cmd, md) + return cmd +} + +// Delete key from metadata. +func Delete(ctx context.Context, k string) context.Context { + return Set(ctx, k, "") +} + +// Set add key with val to metadata. +func Set(ctx context.Context, k, v string) context.Context { + md, ok := FromContext(ctx) + if !ok { + md = make(Metadata) + } + k = canonicalMetadataKey(k) + if v == "" { + delete(md, k) + } else { + md[k] = v + } + return context.WithValue(ctx, metadataKey{}, md) +} + +// Get returns a single value from metadata in the context. +func Get(ctx context.Context, key string) (string, bool) { + md, ok := FromContext(ctx) + if !ok { + return "", ok + } + key = canonicalMetadataKey(key) + val, ok := md[canonicalMetadataKey(key)] + return val, ok +} + +// FromContext returns metadata from the given context. +func FromContext(ctx context.Context) (Metadata, bool) { + md, ok := ctx.Value(metadataKey{}).(Metadata) + if !ok { + return nil, ok + } + + // capitalise all values + newMD := make(Metadata, len(md)) + for k, v := range md { + newMD[canonicalMetadataKey(k)] = v + } + + return newMD, ok +} + +// NewContext creates a new context with the given metadata. +func NewContext(ctx context.Context, md Metadata) context.Context { + return context.WithValue(ctx, metadataKey{}, md) +} + +// MergeContext merges metadata to existing metadata, overwriting if specified. +func MergeContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context { + if ctx == nil { + ctx = context.Background() + } + md, _ := ctx.Value(metadataKey{}).(Metadata) + cmd := make(Metadata, len(md)) + maps.Copy(cmd, md) + for k, v := range patchMd { + if _, ok := cmd[k]; ok && !overwrite { + // skip + } else if v != "" { + cmd[k] = v + } else { + delete(cmd, k) + } + } + return context.WithValue(ctx, metadataKey{}, cmd) +} diff --git a/metadata/types.go b/metadata/types.go new file mode 100644 index 0000000..a0015c2 --- /dev/null +++ b/metadata/types.go @@ -0,0 +1,7 @@ +package metadata + +const ( + RequestIDKey = "X-AEUS-Request-ID" + RequestPathKey = "X-AEUS-Request-Path" + RequestProtocolKey = "X-AEUS-Request-Protocol" +) diff --git a/middleware/transporter.go b/middleware/transporter.go new file mode 100644 index 0000000..24f2ba1 --- /dev/null +++ b/middleware/transporter.go @@ -0,0 +1,5 @@ +package middleware + +type Transporter interface { + +} diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 901fcf4..b1f12f5 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -5,18 +5,23 @@ import ( "net" "net/url" + "git.nobla.cn/golang/aeus/metadata" + "git.nobla.cn/golang/aeus/middleware" "git.nobla.cn/golang/aeus/pkg/logger" netutil "git.nobla.cn/golang/aeus/pkg/net" + "github.com/google/uuid" "google.golang.org/grpc" + grpcmd "google.golang.org/grpc/metadata" "google.golang.org/grpc/reflection" ) type Server struct { - ctx context.Context - opts *options - uri *url.URL - serve *grpc.Server - listener net.Listener + ctx context.Context + opts *options + uri *url.URL + serve *grpc.Server + listener net.Listener + middlewares []middleware.Middleware } func (s *Server) createListener() (err error) { @@ -29,12 +34,44 @@ func (s *Server) createListener() (err error) { return } +func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + next := func(ctx context.Context) (err error) { + resp, err = handler(ctx, req) + return + } + h := middleware.Chain(s.middlewares...)(next) + md := make(metadata.Metadata) + if !md.Has(metadata.RequestIDKey) { + md.Set(metadata.RequestIDKey, uuid.New().String()) + } + md.Set(metadata.RequestPathKey, info.FullMethod) + md.Set(metadata.RequestProtocolKey, Protocol) + if gmd, ok := grpcmd.FromIncomingContext(ctx); ok { + for k, v := range gmd { + if len(v) > 0 { + md.Set(k, v[0]) + } + } + } + ctx = metadata.MergeContext(ctx, md, true) + ctx = context.WithValue(ctx, requestValueContextKey{}, req) + err = h(ctx) + return + } +} + +func (s *Server) Use(middlewares ...middleware.Middleware) { + s.middlewares = append(s.middlewares, middlewares...) +} + func (s *Server) Start(ctx context.Context) (err error) { s.ctx = ctx if err = s.createListener(); err != nil { return } s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host) + reflection.Register(s.serve) s.serve.Serve(s.listener) return @@ -60,18 +97,20 @@ func (s *Server) Stop(ctx context.Context) (err error) { func New(cbs ...Option) *Server { svr := &Server{ opts: &options{ - network: "tcp", - logger: logger.Default(), - address: ":0", + network: "tcp", + logger: logger.Default(), + address: ":0", + grpcOpts: make([]grpc.ServerOption, 0, 10), }, uri: &url.URL{ Scheme: "grpc", }, + middlewares: make([]middleware.Middleware, 0, 10), } for _, cb := range cbs { cb(svr.opts) } - gopts := []grpc.ServerOption{} - svr.serve = grpc.NewServer(gopts...) + svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor())) + svr.serve = grpc.NewServer(svr.opts.grpcOpts...) return svr } diff --git a/transport/grpc/types.go b/transport/grpc/types.go index e373540..c3ee296 100644 --- a/transport/grpc/types.go +++ b/transport/grpc/types.go @@ -9,14 +9,19 @@ import ( "google.golang.org/grpc" ) +const ( + Protocol = "grpc" +) + type ( Option func(*options) options struct { - network string - address string - logger logger.Logger - context context.Context + network string + address string + grpcOpts []grpc.ServerOption + logger logger.Logger + context context.Context } clientOptions struct { @@ -26,6 +31,8 @@ type ( } ClientOption func(*clientOptions) + + requestValueContextKey struct{} ) func WithNetwork(network string) Option { @@ -69,3 +76,10 @@ func WithGrpcDialOptions(opts ...grpc.DialOption) ClientOption { o.dialOptions = opts } } + +func GetRequestValueFromContext(ctx context.Context) any { + if ctx == nil { + return nil + } + return ctx.Value(requestValueContextKey{}) +} diff --git a/transport/http/server.go b/transport/http/server.go index a0d1da2..b8b8be7 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -8,21 +8,24 @@ import ( "net/url" "sync" + "git.nobla.cn/golang/aeus/metadata" "git.nobla.cn/golang/aeus/middleware" "git.nobla.cn/golang/aeus/pkg/errors" "git.nobla.cn/golang/aeus/pkg/logger" netutil "git.nobla.cn/golang/aeus/pkg/net" "github.com/gin-gonic/gin" + "github.com/google/uuid" ) type Server struct { - ctx context.Context - opts *options - uri *url.URL - serve *http.Server - engine *gin.Engine - once sync.Once - listener net.Listener + ctx context.Context + opts *options + uri *url.URL + serve *http.Server + engine *gin.Engine + once sync.Once + listener net.Listener + middlewares []middleware.Middleware } func (s *Server) Endpoint(ctx context.Context) (string, error) { @@ -81,25 +84,38 @@ func (s *Server) DELETE(pattern string, h HandleFunc) { } func (s *Server) Use(middlewares ...middleware.Middleware) { - for _, m := range middlewares { - s.engine.Use(s.warpMiddleware(m)) - } + s.middlewares = append(s.middlewares, middlewares...) } -func (s *Server) warpMiddleware(m middleware.Middleware) gin.HandlerFunc { +func (s *Server) requestInterceptor() gin.HandlerFunc { return func(ginCtx *gin.Context) { ctx := ginCtx.Request.Context() - handler := func(ctx context.Context) error { + next := func(ctx context.Context) error { ginCtx.Next() if err := ginCtx.Errors.Last(); err != nil { return err.Err } return nil } - wrappedHandler := m(handler) - if err := wrappedHandler(ctx); err != nil { - ginCtx.AbortWithError(http.StatusServiceUnavailable, err) - return + handler := middleware.Chain(s.middlewares...)(next) + md := make(metadata.Metadata) + for k, v := range ginCtx.Request.Header { + if len(v) > 0 { + md.Set(k, v[0]) + } + } + if !md.Has(metadata.RequestIDKey) { + md.Set(metadata.RequestIDKey, uuid.New().String()) + } + md.Set(metadata.RequestProtocolKey, Protocol) + md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path) + ctx = metadata.MergeContext(ctx, md, true) + if err := handler(ctx); err != nil { + if se, ok := err.(*errors.Error); ok { + ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil)) + } else { + ginCtx.AbortWithError(http.StatusInternalServerError, err) + } } } } @@ -175,5 +191,6 @@ func New(cbs ...Option) *Server { gin.SetMode(gin.ReleaseMode) } svr.engine = gin.New(svr.opts.ginOptions...) + svr.engine.Use(svr.requestInterceptor()) return svr } diff --git a/transport/http/types.go b/transport/http/types.go index 31ccbba..8e4fd8f 100644 --- a/transport/http/types.go +++ b/transport/http/types.go @@ -8,6 +8,10 @@ import ( "github.com/gin-gonic/gin" ) +const ( + Protocol = "http" +) + type ( Option func(*options)