164 lines
4.4 KiB
Go
164 lines
4.4 KiB
Go
package grpc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
|
|
"git.nobla.cn/golang/aeus/metadata"
|
|
"git.nobla.cn/golang/aeus/middleware"
|
|
"git.nobla.cn/golang/aeus/pkg/logger"
|
|
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
|
"github.com/google/uuid"
|
|
"google.golang.org/grpc"
|
|
grpcmd "google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/reflection"
|
|
)
|
|
|
|
type Server struct {
|
|
ctx context.Context
|
|
opts *options
|
|
uri *url.URL
|
|
serve *grpc.Server
|
|
listener net.Listener
|
|
middlewares []middleware.Middleware
|
|
Logger logger.Logger
|
|
}
|
|
|
|
func (s *Server) createListener() (err error) {
|
|
if s.listener == nil {
|
|
if s.listener, err = net.Listen(s.opts.network, s.opts.address); err != nil {
|
|
return
|
|
}
|
|
s.uri.Host = netutil.TrulyAddr(s.opts.address, s.listener)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
|
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
|
|
next := func(ctx context.Context) (err error) {
|
|
resp, err = handler(ctx, req)
|
|
return
|
|
}
|
|
h := middleware.Chain(s.middlewares...)(next)
|
|
md := metadata.FromContext(ctx)
|
|
grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx)
|
|
if ok {
|
|
md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata})
|
|
}
|
|
grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx)
|
|
if !ok {
|
|
grpcOutgoingMetadata = make(grpcmd.MD)
|
|
}
|
|
md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata})
|
|
if !md.Has(metadata.RequestIDKey) {
|
|
md.Set(metadata.RequestIDKey, uuid.New().String())
|
|
}
|
|
md.Set(metadata.RequestPathKey, info.FullMethod)
|
|
md.Set(metadata.RequestProtocolKey, Protocol)
|
|
ctx = metadata.NewContext(ctx, md)
|
|
ctx = context.WithValue(ctx, requestValueContextKey{}, req)
|
|
err = h(ctx)
|
|
if grpcOutgoingMetadata.Len() > 0 {
|
|
grpc.SetHeader(ctx, grpcOutgoingMetadata)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
|
|
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
|
|
ctx := ss.Context()
|
|
next := func(ctx context.Context) (err error) {
|
|
err = handler(srv, ss)
|
|
return
|
|
}
|
|
h := middleware.Chain(s.middlewares...)(next)
|
|
md := metadata.FromContext(ctx)
|
|
grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx)
|
|
if ok {
|
|
md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata})
|
|
}
|
|
grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx)
|
|
if !ok {
|
|
grpcOutgoingMetadata = make(grpcmd.MD)
|
|
}
|
|
md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata})
|
|
if !md.Has(metadata.RequestIDKey) {
|
|
md.Set(metadata.RequestIDKey, uuid.New().String())
|
|
}
|
|
md.Set(metadata.RequestPathKey, info.FullMethod)
|
|
md.Set(metadata.RequestProtocolKey, Protocol)
|
|
ctx = metadata.NewContext(ctx, md)
|
|
err = h(ctx)
|
|
if grpcOutgoingMetadata.Len() > 0 {
|
|
grpc.SetHeader(ctx, grpcOutgoingMetadata)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
func (s *Server) Use(middlewares ...middleware.Middleware) {
|
|
s.middlewares = append(s.middlewares, middlewares...)
|
|
}
|
|
|
|
func (s *Server) Start(ctx context.Context) (err error) {
|
|
s.ctx = ctx
|
|
if s.opts.logger != nil {
|
|
s.Logger = s.opts.logger
|
|
}
|
|
if s.Logger == nil {
|
|
s.Logger = logger.Default()
|
|
}
|
|
if err = s.createListener(); err != nil {
|
|
return
|
|
}
|
|
s.Logger.Infof(ctx, "grpc server listen on: %s", s.uri.Host)
|
|
reflection.Register(s.serve)
|
|
s.serve.Serve(s.listener)
|
|
return
|
|
}
|
|
|
|
func (s *Server) Endpoint(ctx context.Context) (string, error) {
|
|
if err := s.createListener(); err != nil {
|
|
return "", err
|
|
}
|
|
return s.uri.String(), nil
|
|
}
|
|
|
|
func (s *Server) RegisterService(sd *grpc.ServiceDesc, ss any) {
|
|
s.serve.RegisterService(sd, ss)
|
|
}
|
|
|
|
func (s *Server) Stop(ctx context.Context) (err error) {
|
|
s.serve.GracefulStop()
|
|
s.Logger.Infof(s.ctx, "grpc server stopped")
|
|
return
|
|
}
|
|
|
|
func New(cbs ...Option) *Server {
|
|
svr := &Server{
|
|
opts: &options{
|
|
network: "tcp",
|
|
grpcOpts: make([]grpc.ServerOption, 0, 10),
|
|
},
|
|
uri: &url.URL{
|
|
Scheme: "grpc",
|
|
},
|
|
middlewares: make([]middleware.Middleware, 0, 10),
|
|
}
|
|
port, _ := strconv.Atoi(os.Getenv("GRPC_PORT"))
|
|
svr.opts.address = fmt.Sprintf(":%d", port)
|
|
for _, cb := range cbs {
|
|
cb(svr.opts)
|
|
}
|
|
svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor()))
|
|
svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainStreamInterceptor(svr.streamServerInterceptor()))
|
|
svr.serve = grpc.NewServer(svr.opts.grpcOpts...)
|
|
return svr
|
|
}
|