// aeus/transport/http/server.go
// (199 lines, 5.1 KiB, Go)

package http
import (
"context"
"net"
"net/http"
"net/http/pprof"
"net/url"
"sync"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/errors"
"git.nobla.cn/golang/aeus/pkg/logger"
netutil "git.nobla.cn/golang/aeus/pkg/net"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// Server is the HTTP transport server. It wraps a gin engine and a net/http
// server, and runs every request through the middlewares registered with Use.
type Server struct {
	// ctx is the context handed to Start.
	ctx context.Context
	// opts holds the configuration assembled by New from its Option callbacks.
	opts *options
	// uri is the endpoint reported by Endpoint; its host is filled in once the
	// listener is bound.
	uri *url.URL

	// serve is the underlying net/http server, created by Start.
	serve *http.Server
	// engine is the gin router that serves as serve's handler.
	engine *gin.Engine

	// once is not referenced in this file.
	// NOTE(review): confirm whether another file in the package uses it.
	once sync.Once
	// listener is bound lazily by createListener (from Endpoint or Start).
	listener net.Listener
	// middlewares are appended by Use and applied per request by requestInterceptor.
	middlewares []middleware.Middleware
}
// Endpoint returns the URL this server is reachable at. The listener is bound
// on the first call if Start has not already done so, which means the real
// port is known even when the configured address uses port 0.
// The ctx argument is not used.
func (s *Server) Endpoint(ctx context.Context) (string, error) {
	err := s.createListener()
	if err != nil {
		return "", err
	}
	endpoint := s.uri.String()
	return endpoint, nil
}
// GET registers h for GET requests whose path matches pattern. Each call
// wraps the gin context with newContext and hands the wrapper back to
// putContext once h returns.
func (s *Server) GET(pattern string, h HandleFunc) {
	adapter := func(gc *gin.Context) {
		wrapped := newContext(gc)
		h(wrapped)
		putContext(wrapped)
	}
	s.engine.Handle(http.MethodGet, pattern, adapter)
}
// PUT registers h for PUT requests whose path matches pattern. Each call
// wraps the gin context with newContext and hands the wrapper back to
// putContext once h returns.
func (s *Server) PUT(pattern string, h HandleFunc) {
	adapter := func(gc *gin.Context) {
		wrapped := newContext(gc)
		h(wrapped)
		putContext(wrapped)
	}
	s.engine.Handle(http.MethodPut, pattern, adapter)
}
// POST registers h for POST requests whose path matches pattern. Each call
// wraps the gin context with newContext and hands the wrapper back to
// putContext once h returns.
func (s *Server) POST(pattern string, h HandleFunc) {
	adapter := func(gc *gin.Context) {
		wrapped := newContext(gc)
		h(wrapped)
		putContext(wrapped)
	}
	s.engine.Handle(http.MethodPost, pattern, adapter)
}
// HEAD registers h for HEAD requests whose path matches pattern. Each call
// wraps the gin context with newContext and hands the wrapper back to
// putContext once h returns.
func (s *Server) HEAD(pattern string, h HandleFunc) {
	adapter := func(gc *gin.Context) {
		wrapped := newContext(gc)
		h(wrapped)
		putContext(wrapped)
	}
	s.engine.Handle(http.MethodHead, pattern, adapter)
}
// PATCH registers h for PATCH requests whose path matches pattern. Each call
// wraps the gin context with newContext and hands the wrapper back to
// putContext once h returns.
func (s *Server) PATCH(pattern string, h HandleFunc) {
	adapter := func(gc *gin.Context) {
		wrapped := newContext(gc)
		h(wrapped)
		putContext(wrapped)
	}
	s.engine.Handle(http.MethodPatch, pattern, adapter)
}
// DELETE registers h for DELETE requests whose path matches pattern. Each
// call wraps the gin context with newContext and hands the wrapper back to
// putContext once h returns.
func (s *Server) DELETE(pattern string, h HandleFunc) {
	adapter := func(gc *gin.Context) {
		wrapped := newContext(gc)
		h(wrapped)
		putContext(wrapped)
	}
	s.engine.Handle(http.MethodDelete, pattern, adapter)
}
// Use appends middlewares to the server's chain. requestInterceptor reads the
// chain on every request, so middlewares added after New still apply to later
// requests. They run in the order in which they were added.
func (s *Server) Use(middlewares ...middleware.Middleware) {
	for _, mw := range middlewares {
		s.middlewares = append(s.middlewares, mw)
	}
}
// requestInterceptor returns the gin middleware that New installs on the
// engine. It connects gin's handler chain to the aeus middleware chain.
//
// For every request it:
//   - tees the request metadata to the request headers and response writer,
//     sets a UUID request ID when the metadata has none, and records the
//     protocol and URL path;
//   - runs the middlewares registered via Use around the remaining gin
//     handlers (ginCtx.Next), treating the last gin error as the chain's
//     result;
//   - aborts with status 500 when the chain returns an error. An
//     *errors.Error becomes a JSON body built by newResponse from its Code and
//     Message; any other error is attached to the gin context via
//     AbortWithError.
func (s *Server) requestInterceptor() gin.HandlerFunc {
	return func(ginCtx *gin.Context) {
		ctx := ginCtx.Request.Context()
		next := func(ctx context.Context) error {
			// Middlewares may pass on a derived context (values, deadlines,
			// cancellation). Attach it to the request so the downstream gin
			// handlers see it. Previously this argument was ignored and such
			// changes were silently lost.
			if ctx != nil {
				ginCtx.Request = ginCtx.Request.WithContext(ctx)
			}
			ginCtx.Next()
			if err := ginCtx.Errors.Last(); err != nil {
				return err.Err
			}
			return nil
		}
		handler := middleware.Chain(s.middlewares...)(next)

		md := metadata.FromContext(ctx)
		md.TeeReader(&httpMetadataReader{
			hd: ginCtx.Request.Header,
		})
		md.TeeWriter(&httpMetadataWriter{
			w: ginCtx.Writer,
		})
		if !md.Has(metadata.RequestIDKey) {
			md.Set(metadata.RequestIDKey, uuid.New().String())
		}
		md.Set(metadata.RequestProtocolKey, Protocol)
		md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path)
		ctx = metadata.NewContext(ctx, md)
		ginCtx.Request = ginCtx.Request.WithContext(ctx)

		if err := handler(ctx); err != nil {
			if se, ok := err.(*errors.Error); ok {
				ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil))
			} else {
				ginCtx.AbortWithError(http.StatusInternalServerError, err)
			}
		}
	}
}
// wrapHandle adapts a standard library handler function (such as the
// net/http/pprof handlers) to gin by serving the gin context's underlying
// writer and request.
func (s *Server) wrapHandle(f http.HandlerFunc) gin.HandlerFunc {
	return func(c *gin.Context) {
		f.ServeHTTP(c.Writer, c.Request)
	}
}
// createListener binds the server's listener on the configured network and
// address, unless one is already bound. It then fills in the advertised URI:
//   - the host comes from netutil.TrulyAddr (presumably the concrete
//     host:port actually bound, e.g. when the address uses port 0 — confirm);
//   - the scheme becomes "https" when both a certificate and key file are
//     configured.
//
// Setting the scheme here, rather than only in Start, means Endpoint reports
// the correct scheme even when it is called before Start.
// NOTE(review): there is no locking, so concurrent first calls (Endpoint
// racing Start) are not synchronized.
func (s *Server) createListener() error {
	if s.listener != nil {
		return nil
	}
	listener, err := net.Listen(s.opts.network, s.opts.address)
	if err != nil {
		return err
	}
	s.listener = listener
	s.uri.Host = netutil.TrulyAddr(s.opts.address, listener)
	if s.opts.certFile != "" && s.opts.keyFile != "" {
		s.uri.Scheme = "https"
	}
	return nil
}
// Start creates the underlying http.Server and, in debug mode, registers the
// net/http/pprof routes under /debug/pprof/. It then binds the listener (if
// Endpoint has not already done so) and serves. It serves HTTPS when both a
// certificate and key file are configured, and plain HTTP otherwise.
//
// Start blocks while serving. The http.ErrServerClosed returned after Stop is
// reported as nil; any other listen or serve error is returned.
func (s *Server) Start(ctx context.Context) (err error) {
	s.serve = &http.Server{
		Addr:    s.opts.address,
		Handler: s.engine,
	}
	s.ctx = ctx
	if s.opts.debug {
		s.engine.Handle(http.MethodGet, "/debug/pprof/", s.wrapHandle(pprof.Index))
		// pprof.Index serves the named runtime profile taken from the URL path.
		// Each profile linked from the index page therefore needs its own route.
		// "allocs" and "block" used to be missing, so their index links 404ed.
		for _, profile := range []string{"allocs", "block", "goroutine", "heap", "mutex", "threadcreate"} {
			s.engine.Handle(http.MethodGet, "/debug/pprof/"+profile, s.wrapHandle(pprof.Index))
		}
		s.engine.Handle(http.MethodGet, "/debug/pprof/cmdline", s.wrapHandle(pprof.Cmdline))
		s.engine.Handle(http.MethodGet, "/debug/pprof/profile", s.wrapHandle(pprof.Profile))
		s.engine.Handle(http.MethodGet, "/debug/pprof/symbol", s.wrapHandle(pprof.Symbol))
		s.engine.Handle(http.MethodGet, "/debug/pprof/trace", s.wrapHandle(pprof.Trace))
	}
	if err = s.createListener(); err != nil {
		return err
	}
	s.opts.logger.Info(ctx, "http server listen on: %s", s.uri.Host)
	if s.opts.certFile != "" && s.opts.keyFile != "" {
		s.uri.Scheme = "https"
		err = s.serve.ServeTLS(s.listener, s.opts.certFile, s.opts.keyFile)
	} else {
		err = s.serve.Serve(s.listener)
	}
	if !errors.Is(err, http.ErrServerClosed) {
		return err
	}
	return nil
}
// Stop gracefully shuts the server down via http.Server.Shutdown. Shutdown
// closes the listeners and waits for in-flight requests until ctx is done.
// "http server stopped" is logged afterwards.
//
// If Start never ran, there is no http.Server to shut down; previously this
// panicked with a nil dereference. In that case Stop only closes a listener
// that Endpoint may have bound early, so the port is released.
func (s *Server) Stop(ctx context.Context) (err error) {
	if s.serve == nil {
		if s.listener != nil {
			err = s.listener.Close()
			s.listener = nil
		}
		return err
	}
	err = s.serve.Shutdown(ctx)
	s.opts.logger.Info(ctx, "http server stopped")
	return err
}
// New builds a Server. It starts from the defaults (network "tcp", address
// ":0", logger.Default()), applies every Option in order, and creates the gin
// engine with requestInterceptor installed.
//
// When debug is off it switches gin into release mode. gin.SetMode is
// process-wide, so this affects every gin engine in the process.
func New(cbs ...Option) *Server {
	opts := &options{
		network: "tcp",
		logger:  logger.Default(),
		address: ":0",
	}
	for _, apply := range cbs {
		apply(opts)
	}
	if !opts.debug {
		gin.SetMode(gin.ReleaseMode)
	}

	svr := &Server{
		uri:  &url.URL{Scheme: "http"},
		opts: opts,
	}
	svr.engine = gin.New(opts.ginOptions...)
	svr.engine.Use(svr.requestInterceptor())
	return svr
}