aeus/transport/grpc/server.go

159 lines
4.3 KiB
Go

package grpc
import (
"context"
"fmt"
"net"
"net/url"
"os"
"strconv"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/logger"
netutil "git.nobla.cn/golang/aeus/pkg/net"
"github.com/google/uuid"
"google.golang.org/grpc"
grpcmd "google.golang.org/grpc/metadata"
"google.golang.org/grpc/reflection"
)
// Server is a gRPC transport server. It wraps a *grpc.Server, creates its
// network listener on first use, and runs the registered middlewares around
// every unary and streaming call through its own interceptors.
type Server struct {
	// ctx is the context passed to Start; Stop reuses it for logging.
	ctx context.Context
	// opts holds the configuration assembled by New (network, address,
	// logger and additional grpc.ServerOptions).
	opts *options
	// uri is the server endpoint; Scheme is "grpc" and Host is filled in
	// by createListener once the listener exists.
	uri *url.URL
	// serve is the underlying gRPC server.
	serve *grpc.Server
	// listener is created on first use by createListener.
	listener net.Listener
	// middlewares are chained around every handler by the interceptors.
	middlewares []middleware.Middleware
}
// createListener opens the listener described by opts.network/opts.address
// if none exists yet, and records the address reported by netutil.TrulyAddr
// in s.uri.Host. Subsequent calls are no-ops.
func (s *Server) createListener() (err error) {
	if s.listener != nil {
		return nil
	}
	lis, err := net.Listen(s.opts.network, s.opts.address)
	if err != nil {
		return err
	}
	s.listener = lis
	s.uri.Host = netutil.TrulyAddr(s.opts.address, lis)
	return nil
}
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
next := func(ctx context.Context) (err error) {
resp, err = handler(ctx, req)
return
}
h := middleware.Chain(s.middlewares...)(next)
md := metadata.FromContext(ctx)
grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx)
if ok {
md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata})
}
grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx)
if !ok {
grpcOutgoingMetadata = make(grpcmd.MD)
}
md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata})
if !md.Has(metadata.RequestIDKey) {
md.Set(metadata.RequestIDKey, uuid.New().String())
}
md.Set(metadata.RequestPathKey, info.FullMethod)
md.Set(metadata.RequestProtocolKey, Protocol)
ctx = metadata.NewContext(ctx, md)
ctx = context.WithValue(ctx, requestValueContextKey{}, req)
err = h(ctx)
if grpcOutgoingMetadata.Len() > 0 {
grpc.SetHeader(ctx, grpcOutgoingMetadata)
}
return
}
}
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
ctx := ss.Context()
next := func(ctx context.Context) (err error) {
err = handler(srv, ss)
return
}
h := middleware.Chain(s.middlewares...)(next)
md := metadata.FromContext(ctx)
grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx)
if ok {
md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata})
}
grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx)
if !ok {
grpcOutgoingMetadata = make(grpcmd.MD)
}
md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata})
if !md.Has(metadata.RequestIDKey) {
md.Set(metadata.RequestIDKey, uuid.New().String())
}
md.Set(metadata.RequestPathKey, info.FullMethod)
md.Set(metadata.RequestProtocolKey, Protocol)
ctx = metadata.NewContext(ctx, md)
err = h(ctx)
if grpcOutgoingMetadata.Len() > 0 {
grpc.SetHeader(ctx, grpcOutgoingMetadata)
}
return
}
}
// Use appends middlewares after any already registered. The interceptors
// read s.middlewares on every call, so the additions apply to calls handled
// afterwards.
func (s *Server) Use(middlewares ...middleware.Middleware) {
	for _, mw := range middlewares {
		s.middlewares = append(s.middlewares, mw)
	}
}
// Start opens the listener if needed, logs the listening host, registers
// gRPC server reflection and then serves on the listener, blocking until
// the server stops. ctx is remembered so Stop can log with it.
//
// It returns the listener creation error, or the error from
// grpc.Server.Serve. Serve returns nil once Stop or GracefulStop has been
// called, so a non-nil error means serving ended for another reason.
func (s *Server) Start(ctx context.Context) (err error) {
	s.ctx = ctx
	if err = s.createListener(); err != nil {
		return err
	}
	s.opts.logger.Info(ctx, "grpc server listen on: %s", s.uri.Host)
	reflection.Register(s.serve)
	// Previously discarded; propagate it so callers learn about fatal
	// listener failures instead of seeing a silent, successful return.
	return s.serve.Serve(s.listener)
}
// Endpoint returns the server URL (scheme "grpc"). It creates the listener
// first if necessary, so the host is the one recorded by createListener.
// ctx is currently unused.
func (s *Server) Endpoint(ctx context.Context) (string, error) {
	err := s.createListener()
	if err != nil {
		return "", err
	}
	endpoint := s.uri.String()
	return endpoint, nil
}
// RegisterService registers the service described by sd, implemented by ss,
// on the underlying grpc.Server. grpc-go requires registration to happen
// before Serve is invoked, i.e. before Start.
func (s *Server) RegisterService(sd *grpc.ServiceDesc, ss any) {
	srv := s.serve
	srv.RegisterService(sd, ss)
}
// Stop shuts the server down gracefully: it stops accepting connections and
// waits for in-flight RPCs to finish. If ctx is done before that happens,
// the server is stopped forcibly and ctx.Err() is returned; otherwise Stop
// returns nil. A nil ctx means "no deadline".
func (s *Server) Stop(ctx context.Context) (err error) {
	done := make(chan struct{})
	go func() {
		defer close(done)
		s.serve.GracefulStop()
	}()

	// A nil channel never fires, so without a ctx we simply wait for done.
	var expired <-chan struct{}
	if ctx != nil {
		expired = ctx.Done()
	}
	select {
	case <-done:
	case <-expired:
		// Force-close remaining connections; this also makes the pending
		// GracefulStop return, so waiting on done does not leak the goroutine.
		s.serve.Stop()
		<-done
		err = ctx.Err()
	}

	// s.ctx is only set by Start; fall back to the caller's context.
	logCtx := s.ctx
	if logCtx == nil {
		logCtx = ctx
	}
	s.opts.logger.Info(logCtx, "grpc server stopped")
	return err
}
// New builds a Server. Defaults are network "tcp", logger.Default(), and the
// address ":<GRPC_PORT>" taken from the environment; callers' Option
// functions are applied on top of these defaults. The server's own unary and
// stream interceptors are appended after any grpc.ServerOptions the options
// supplied, and the grpc.Server is created from the result.
func New(cbs ...Option) *Server {
	// An unset or malformed GRPC_PORT deliberately yields port 0, which
	// net.Listen treats as "choose a port automatically".
	port, _ := strconv.Atoi(os.Getenv("GRPC_PORT"))
	opts := &options{
		network:  "tcp",
		address:  fmt.Sprintf(":%d", port),
		logger:   logger.Default(),
		grpcOpts: make([]grpc.ServerOption, 0, 10),
	}
	for _, apply := range cbs {
		apply(opts)
	}

	svr := &Server{
		opts:        opts,
		uri:         &url.URL{Scheme: "grpc"},
		middlewares: make([]middleware.Middleware, 0, 10),
	}
	opts.grpcOpts = append(opts.grpcOpts,
		grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor()),
		grpc.ChainStreamInterceptor(svr.streamServerInterceptor()),
	)
	svr.serve = grpc.NewServer(opts.grpcOpts...)
	return svr
}