package grpc import ( "context" "fmt" "net" "net/url" "os" "strconv" "git.nobla.cn/golang/aeus/metadata" "git.nobla.cn/golang/aeus/middleware" "git.nobla.cn/golang/aeus/pkg/logger" netutil "git.nobla.cn/golang/aeus/pkg/net" "github.com/google/uuid" "google.golang.org/grpc" grpcmd "google.golang.org/grpc/metadata" "google.golang.org/grpc/reflection" ) type Server struct { ctx context.Context opts *options uri *url.URL serve *grpc.Server listener net.Listener middlewares []middleware.Middleware Logger logger.Logger } func (s *Server) createListener() (err error) { if s.listener == nil { if s.listener, err = net.Listen(s.opts.network, s.opts.address); err != nil { return } s.uri.Host = netutil.TrulyAddr(s.opts.address, s.listener) } return } func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { next := func(ctx context.Context) (err error) { resp, err = handler(ctx, req) return } h := middleware.Chain(s.middlewares...)(next) md := metadata.FromContext(ctx) grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx) if ok { md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata}) } grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx) if !ok { grpcOutgoingMetadata = make(grpcmd.MD) } md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata}) if !md.Has(metadata.RequestIDKey) { md.Set(metadata.RequestIDKey, uuid.New().String()) } md.Set(metadata.RequestPathKey, info.FullMethod) md.Set(metadata.RequestProtocolKey, Protocol) ctx = metadata.NewContext(ctx, md) ctx = context.WithValue(ctx, requestValueContextKey{}, req) err = h(ctx) if grpcOutgoingMetadata.Len() > 0 { grpc.SetHeader(ctx, grpcOutgoingMetadata) } return } } func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor { return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { ctx := ss.Context() next := func(ctx context.Context) (err error) { err = handler(srv, ss) return } h := middleware.Chain(s.middlewares...)(next) md := metadata.FromContext(ctx) grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx) if ok { md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata}) } grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx) if !ok { grpcOutgoingMetadata = make(grpcmd.MD) } md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata}) if !md.Has(metadata.RequestIDKey) { md.Set(metadata.RequestIDKey, uuid.New().String()) } md.Set(metadata.RequestPathKey, info.FullMethod) md.Set(metadata.RequestProtocolKey, Protocol) ctx = metadata.NewContext(ctx, md) err = h(ctx) if grpcOutgoingMetadata.Len() > 0 { grpc.SetHeader(ctx, grpcOutgoingMetadata) } return } } func (s *Server) Use(middlewares ...middleware.Middleware) { s.middlewares = append(s.middlewares, middlewares...) } func (s *Server) Start(ctx context.Context) (err error) { s.ctx = ctx if s.opts.logger != nil { s.Logger = s.opts.logger } if s.Logger == nil { s.Logger = logger.Default() } if err = s.createListener(); err != nil { return } s.Logger.Infof(ctx, "grpc server listen on: %s", s.uri.Host) reflection.Register(s.serve) s.serve.Serve(s.listener) return } func (s *Server) Endpoint(ctx context.Context) (string, error) { if err := s.createListener(); err != nil { return "", err } return s.uri.String(), nil } func (s *Server) RegisterService(sd *grpc.ServiceDesc, ss any) { s.serve.RegisterService(sd, ss) } func (s *Server) Stop(ctx context.Context) (err error) { s.serve.GracefulStop() s.Logger.Infof(s.ctx, "grpc server stopped") return } func New(cbs ...Option) *Server { svr := &Server{ opts: &options{ network: "tcp", grpcOpts: make([]grpc.ServerOption, 0, 10), }, uri: &url.URL{ Scheme: "grpc", }, middlewares: make([]middleware.Middleware, 0, 10), } port, _ := strconv.Atoi(os.Getenv("GRPC_PORT")) svr.opts.address = fmt.Sprintf(":%d", port) for _, cb := range cbs { cb(svr.opts) } svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor())) svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainStreamInterceptor(svr.streamServerInterceptor())) svr.serve = grpc.NewServer(svr.opts.grpcOpts...) return svr }