// aeus/transport/http/server.go
//
// (321 lines, 8.5 KiB, Go)

package http
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/pprof"
"net/url"
"os"
"path"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
"time"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/errors"
"git.nobla.cn/golang/aeus/pkg/logger"
netutil "git.nobla.cn/golang/aeus/pkg/net"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// Server is the HTTP transport: a gin engine served by net/http, with the
// aeus middleware chain applied to every request through requestInterceptor.
// Construct it with New.
type Server struct {
	// ctx is the context passed to Start.
	ctx context.Context
	// opts holds the configuration assembled by New from its Option callbacks.
	opts *options
	// uri is the endpoint reported by Endpoint; its Host is filled in by
	// createListener and Start switches Scheme to "https" when a
	// certificate/key pair is configured.
	uri *url.URL
	// serve is the net/http server created by Start and shut down by Stop.
	serve *http.Server
	// engine is the gin router every route is registered on.
	engine *gin.Engine
	// NOTE(review): once is not referenced in this file; confirm whether it is still needed.
	once sync.Once
	// listener is opened lazily by createListener (via Endpoint or Start).
	listener net.Listener
	// fs is the static file system installed by Webroot; notFoundHandle
	// consults it for GET requests that match no route.
	fs *filesystem
	// middlewares are the aeus middlewares added with Use; requestInterceptor
	// chains them around every request.
	middlewares []middleware.Middleware
	// Logger is used for lifecycle messages; Start sets it from the logger
	// option or falls back to logger.Default().
	Logger logger.Logger
}
// Endpoint returns the URL of this server. If no listener exists yet, it opens
// one first so that s.uri.Host can be derived (via netutil.TrulyAddr) from the
// configured address and the real listener. On failure the listen error is
// returned together with an empty string. ctx is not used.
func (s *Server) Endpoint(ctx context.Context) (string, error) {
	var endpoint string
	err := s.createListener()
	if err == nil {
		endpoint = s.uri.String()
	}
	return endpoint, err
}
// GET registers h for GET requests whose path matches pattern.
func (s *Server) GET(pattern string, h HandleFunc) {
	s.engine.GET(pattern, s.adaptHandleFunc(h))
}

// PUT registers h for PUT requests whose path matches pattern.
func (s *Server) PUT(pattern string, h HandleFunc) {
	s.engine.PUT(pattern, s.adaptHandleFunc(h))
}

// POST registers h for POST requests whose path matches pattern.
func (s *Server) POST(pattern string, h HandleFunc) {
	s.engine.POST(pattern, s.adaptHandleFunc(h))
}

// HEAD registers h for HEAD requests whose path matches pattern.
func (s *Server) HEAD(pattern string, h HandleFunc) {
	s.engine.HEAD(pattern, s.adaptHandleFunc(h))
}

// PATCH registers h for PATCH requests whose path matches pattern.
func (s *Server) PATCH(pattern string, h HandleFunc) {
	s.engine.PATCH(pattern, s.adaptHandleFunc(h))
}

// DELETE registers h for DELETE requests whose path matches pattern.
func (s *Server) DELETE(pattern string, h HandleFunc) {
	s.engine.DELETE(pattern, s.adaptHandleFunc(h))
}

// adaptHandleFunc converts an aeus HandleFunc into a gin handler. Each request
// gets a context from newContext that is handed back through putContext once
// h returns. The release is deferred so it also happens when h panics (e.g.
// under a recovery middleware). NOTE(review): newContext/putContext presumably
// implement pooling; confirm putContext fully resets the context.
func (s *Server) adaptHandleFunc(h HandleFunc) gin.HandlerFunc {
	return func(ctx *gin.Context) {
		childCtx := newContext(ctx)
		defer putContext(childCtx)
		h(childCtx)
	}
}
// Use appends middlewares to the aeus chain that requestInterceptor wraps
// around each request. requestInterceptor rebuilds the chain on every request,
// so middlewares added after New still apply to later requests.
func (s *Server) Use(middlewares ...middleware.Middleware) {
	if len(middlewares) == 0 {
		return
	}
	s.middlewares = append(s.middlewares, middlewares...)
}
// Handle registers a plain net/http handler for the given method and uri.
// Unlike GET/POST/etc., no aeus context is created: the handler receives the
// gin response writer and the raw *http.Request.
func (s *Server) Handle(method string, uri string, handler http.HandlerFunc) {
	s.engine.Handle(method, uri, s.wrapHandle(handler))
}
// Webroot installs fs as the static file tree that notFoundHandle consults for
// GET requests matching no registered route. The tree is configured with the
// given prefix, directory access is denied, and "/index.html" is the index file.
// NOTE(review): newFS receives time.Now(); staticHandle passes s.fs.modtime to
// http.ServeContent, so this is presumably the modification time reported for
// every file.
func (s *Server) Webroot(prefix string, fs http.FileSystem) {
	static := newFS(time.Now(), fs)
	static.SetPrefix(prefix)
	static.DenyAccessDirectory()
	static.SetIndexFile("/index.html")
	s.fs = static
}
// shouldCompress reports whether a static response to req should be gzip
// encoded. Three conditions must all hold:
//   - the client's Accept-Encoding mentions gzip;
//   - the request is not a connection upgrade;
//   - the file extension of the request path is listed in assetsExtensions.
//
// Content codings and connection options are case-insensitive tokens
// (RFC 9110), and proxies commonly send "Connection: upgrade", so both headers
// are lower-cased before matching.
func (s *Server) shouldCompress(req *http.Request) bool {
	acceptEncoding := strings.ToLower(req.Header.Get(headerAcceptEncoding))
	connection := strings.ToLower(req.Header.Get("Connection"))
	if !strings.Contains(acceptEncoding, "gzip") || strings.Contains(connection, "upgrade") {
		return false
	}
	// Only asset types listed in assetsExtensions are compressed.
	return slices.Contains(assetsExtensions, filepath.Ext(req.URL.Path))
}
// staticHandle writes the already-opened file fp as the response to ctx via
// http.ServeContent and then aborts the gin chain.
//
// Non-directory files larger than 8 KiB that shouldCompress accepts are gzip
// encoded on the fly (modelled on https://github.com/gin-contrib/gzip).
// If fp cannot be stat'ed, nothing is written and the chain is NOT aborted,
// so the caller can fall back to its own response.
func (s *Server) staticHandle(ctx *gin.Context, fp http.File) {
	uri := path.Clean(ctx.Request.URL.Path)
	fi, err := fp.Stat()
	if err != nil {
		return
	}
	// Byte-range requests are served uncompressed. http.ServeContent computes
	// ranges over the raw file, so a gzip-encoded partial body would not match
	// the Content-Range it advertises.
	compress := !fi.IsDir() &&
		fi.Size() > 8192 &&
		ctx.Request.Header.Get("Range") == "" &&
		s.shouldCompress(ctx.Request)
	if compress {
		rawWriter := ctx.Writer
		gzWriter := newGzipWriter()
		gzWriter.Reset(rawWriter)
		ctx.Header(headerContentEncoding, "gzip")
		rawWriter.Header().Add(headerVary, headerAcceptEncoding)
		// A re-encoded body cannot keep a strong validator, so weaken any ETag
		// already set on the RESPONSE. The request's headers are irrelevant here.
		if etag := rawWriter.Header().Get("ETag"); etag != "" && !strings.HasPrefix(etag, "W/") {
			ctx.Header("ETag", "W/"+etag)
		}
		ctx.Writer = &gzipWriter{rawWriter, gzWriter}
		defer func() {
			if ctx.Writer.Size() < 0 {
				// No body was written (e.g. 304 Not Modified): don't emit a gzip footer.
				gzWriter.Reset(io.Discard)
			}
			gzWriter.Close()
			// Restore the plain writer before pooling gzWriter, so nothing later in
			// this request can write through a writer another request may reuse.
			// Content-Length is intentionally not set: once Size() >= 0 gin has
			// already flushed the headers, so it could not take effect.
			ctx.Writer = rawWriter
			putGzipWriter(gzWriter)
		}()
	}
	http.ServeContent(ctx.Writer, ctx.Request, path.Base(uri), s.fs.modtime, fp)
	ctx.Abort()
}
// notFoundHandle is the engine's NoRoute handler. For GET requests it first
// tries the static file tree installed by Webroot. It answers with a JSON 404
// only when no file was served.
func (s *Server) notFoundHandle(ctx *gin.Context) {
	if s.fs != nil && ctx.Request.Method == http.MethodGet {
		uri := path.Clean(ctx.Request.URL.Path)
		if fp, err := s.fs.Open(uri); err == nil {
			s.staticHandle(ctx, fp)
			fp.Close()
			// staticHandle aborts the chain once it has served the file. Writing
			// the 404 afterwards would rewrite a pending 304 Not Modified into a
			// 404, or append JSON to the file body.
			if ctx.IsAborted() {
				return
			}
		}
	}
	ctx.JSON(http.StatusNotFound, newResponse(errors.NotFound, "Not Found", nil))
}
// CORSInterceptor returns a gin middleware that adds permissive CORS headers
// to every response.
//
// Every OPTIONS request is treated as a preflight. It receives the allowed
// methods and is answered immediately with 204 No Content. Any other request
// gets the origin headers and continues down the chain. In both cases the
// requested Access-Control-Request-Headers value, when present, is echoed back
// as Access-Control-Allow-Headers.
//
// NOTE(review): per the Fetch spec, browsers reject credentialed responses
// whose Access-Control-Allow-Origin is "*". The Allow-Credentials header
// therefore has no effect as configured; confirm the intended policy.
func (s *Server) CORSInterceptor() gin.HandlerFunc {
	return func(c *gin.Context) {
		header := c.Writer.Header()
		preflight := c.Request.Method == "OPTIONS"

		header.Add("Vary", "Origin")
		if preflight {
			header.Add("Vary", "Access-Control-Request-Method")
			header.Add("Vary", "Access-Control-Request-Headers")
		}
		header.Set("Access-Control-Allow-Origin", "*")
		header.Set("Access-Control-Allow-Credentials", "true")
		if preflight {
			header.Set("Access-Control-Allow-Methods", "GET,HEAD,PUT,PATCH,POST,DELETE")
		}
		if requested := c.Request.Header.Get("Access-Control-Request-Headers"); requested != "" {
			header.Set("Access-Control-Allow-Headers", requested)
		}

		if preflight {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	}
}
// requestInterceptor returns the gin middleware that runs each request through
// the aeus middleware chain registered with Use.
//
// Before invoking the chain it prepares the request metadata:
//   - attaches httpMetadataReader (over the request headers) and
//     httpMetadataWriter (over the response writer) via TeeReader/TeeWriter —
//     presumably so metadata is read from and echoed to HTTP headers (confirm);
//   - generates a request ID when none is present;
//   - records the protocol, request path and method.
//
// The remaining gin handlers run as the innermost link of the chain. If the
// chain returns an error, the request is aborted with status 500:
//   - an *errors.Error is rendered as a JSON response with its code and message;
//   - any other error is passed to gin's AbortWithError.
func (s *Server) requestInterceptor() gin.HandlerFunc {
	return func(c *gin.Context) {
		reqCtx := c.Request.Context()
		// innermost runs the rest of the gin chain under the context produced by
		// the aeus middlewares and surfaces the last error gin recorded.
		innermost := func(ctx context.Context) error {
			c.Request = c.Request.WithContext(ctx)
			c.Next()
			if last := c.Errors.Last(); last != nil {
				return last.Err
			}
			return nil
		}
		chain := middleware.Chain(s.middlewares...)(innermost)

		md := metadata.FromContext(reqCtx)
		md.TeeReader(&httpMetadataReader{hd: c.Request.Header})
		md.TeeWriter(&httpMetadataWriter{w: c.Writer})
		if !md.Has(metadata.RequestIDKey) {
			md.Set(metadata.RequestIDKey, uuid.New().String())
		}
		md.Set(metadata.RequestProtocolKey, Protocol)
		md.Set(metadata.RequestPathKey, c.Request.URL.Path)
		md.Set("method", c.Request.Method)

		err := chain(metadata.NewContext(reqCtx, md))
		if err == nil {
			return
		}
		switch e := err.(type) {
		case *errors.Error:
			c.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(e.Code, e.Message, nil))
		default:
			c.AbortWithError(http.StatusInternalServerError, err)
		}
	}
}
// wrapHandle adapts a net/http handler function to a gin handler. The handler
// is called with the gin response writer and the underlying *http.Request.
func (s *Server) wrapHandle(f http.HandlerFunc) gin.HandlerFunc {
	return gin.WrapF(f)
}
// createListener opens the listener on opts.network/opts.address if none
// exists yet, and records the host derived by netutil.TrulyAddr in s.uri.
// Once a listener exists, further calls are no-ops.
//
// When a certificate/key pair is configured, the scheme is switched to https
// here rather than only in Start. That way Endpoint reports the right URL even
// when it is called before (or concurrently with) Start.
func (s *Server) createListener() error {
	if s.listener != nil {
		return nil
	}
	ln, err := net.Listen(s.opts.network, s.opts.address)
	if err != nil {
		return err
	}
	s.listener = ln
	s.uri.Host = netutil.TrulyAddr(s.opts.address, ln)
	if s.opts.certFile != "" && s.opts.keyFile != "" {
		s.uri.Scheme = "https"
	}
	return nil
}
// Start builds the underlying *http.Server and resolves the logger (the logger
// option first, then logger.Default()). It registers the NoRoute fallback and,
// in debug mode, the net/http/pprof endpoints. It then blocks serving on the
// listener, using TLS when both certFile and keyFile are set.
//
// http.ErrServerClosed, which Serve returns after Stop, is reported as nil.
// Any other serve or listen error is returned.
func (s *Server) Start(ctx context.Context) (err error) {
	s.serve = &http.Server{
		Addr:    s.opts.address,
		Handler: s.engine,
	}
	if s.opts.logger != nil {
		s.Logger = s.opts.logger
	}
	if s.Logger == nil {
		s.Logger = logger.Default()
	}
	s.ctx = ctx
	if s.opts.debug {
		// pprof.Index renders the profile list at "/debug/pprof/" and serves any
		// named runtime profile at "/debug/pprof/<name>". Register every standard
		// profile the index links to, so none falls through to the JSON 404.
		for _, name := range []string{"", "allocs", "block", "goroutine", "heap", "mutex", "threadcreate"} {
			s.engine.Handle(http.MethodGet, "/debug/pprof/"+name, s.wrapHandle(pprof.Index))
		}
		s.engine.Handle(http.MethodGet, "/debug/pprof/cmdline", s.wrapHandle(pprof.Cmdline))
		s.engine.Handle(http.MethodGet, "/debug/pprof/profile", s.wrapHandle(pprof.Profile))
		s.engine.Handle(http.MethodGet, "/debug/pprof/symbol", s.wrapHandle(pprof.Symbol))
		s.engine.Handle(http.MethodGet, "/debug/pprof/trace", s.wrapHandle(pprof.Trace))
	}
	if err = s.createListener(); err != nil {
		return err
	}
	s.engine.NoRoute(s.notFoundHandle)
	s.Logger.Infof(ctx, "http server listen on: %s", s.uri.Host)
	if s.opts.certFile != "" && s.opts.keyFile != "" {
		s.uri.Scheme = "https"
		err = s.serve.ServeTLS(s.listener, s.opts.certFile, s.opts.keyFile)
	} else {
		err = s.serve.Serve(s.listener)
	}
	// Serve/ServeTLS always return a non-nil error; ErrServerClosed only means
	// the server was shut down.
	if errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return err
}
// Stop gracefully shuts the server down, waiting for in-flight requests until
// ctx is done (see http.Server.Shutdown).
//
// Stop is safe to call when Start never ran. In that case there is no
// *http.Server to shut down, so only a listener already opened by Endpoint is
// closed (and forgotten, so a later Start can open a fresh one). The previous
// code dereferenced a nil s.serve / s.Logger in this situation.
func (s *Server) Stop(ctx context.Context) (err error) {
	lg := s.Logger
	if lg == nil {
		lg = logger.Default()
	}
	if s.serve != nil {
		err = s.serve.Shutdown(ctx)
	} else if s.listener != nil {
		err = s.listener.Close()
		s.listener = nil
	}
	lg.Infof(ctx, "http server stopped")
	return err
}
// New creates a Server with the following defaults:
//   - network "tcp";
//   - address ":<HTTP_PORT>" — an unset or non-numeric HTTP_PORT yields ":0";
//   - endpoint scheme "http".
//
// The option callbacks are then applied in order. gin is switched to release
// mode unless the debug option is set; note this is gin's process-wide mode.
// The engine gets the CORS interceptor (when enabled) followed by the request
// interceptor that runs the aeus middleware chain.
func New(cbs ...Option) *Server {
	port, _ := strconv.Atoi(os.Getenv("HTTP_PORT"))
	opts := &options{
		network: "tcp",
		address: fmt.Sprintf(":%d", port),
	}
	for _, apply := range cbs {
		apply(opts)
	}
	if !opts.debug {
		gin.SetMode(gin.ReleaseMode)
	}
	svr := &Server{
		uri:    &url.URL{Scheme: "http"},
		opts:   opts,
		engine: gin.New(opts.ginOptions...),
	}
	if opts.enableCORS {
		svr.engine.Use(svr.CORSInterceptor())
	}
	svr.engine.Use(svr.requestInterceptor())
	return svr
}