package http import ( "context" "fmt" "io" "net" "net/http" "net/http/pprof" "net/url" "os" "path" "path/filepath" "slices" "strconv" "strings" "sync" "time" "git.nobla.cn/golang/aeus/metadata" "git.nobla.cn/golang/aeus/middleware" "git.nobla.cn/golang/aeus/pkg/errors" "git.nobla.cn/golang/aeus/pkg/logger" netutil "git.nobla.cn/golang/aeus/pkg/net" "github.com/gin-gonic/gin" "github.com/google/uuid" ) type Server struct { ctx context.Context opts *options uri *url.URL serve *http.Server engine *gin.Engine once sync.Once listener net.Listener fs *filesystem middlewares []middleware.Middleware } func (s *Server) Endpoint(ctx context.Context) (string, error) { if err := s.createListener(); err != nil { return "", err } return s.uri.String(), nil } func (s *Server) GET(pattern string, h HandleFunc) { s.engine.GET(pattern, func(ctx *gin.Context) { childCtx := newContext(ctx) h(childCtx) putContext(childCtx) }) } func (s *Server) PUT(pattern string, h HandleFunc) { s.engine.PUT(pattern, func(ctx *gin.Context) { childCtx := newContext(ctx) h(childCtx) putContext(childCtx) }) } func (s *Server) POST(pattern string, h HandleFunc) { s.engine.POST(pattern, func(ctx *gin.Context) { childCtx := newContext(ctx) h(childCtx) putContext(childCtx) }) } func (s *Server) HEAD(pattern string, h HandleFunc) { s.engine.HEAD(pattern, func(ctx *gin.Context) { childCtx := newContext(ctx) h(childCtx) putContext(childCtx) }) } func (s *Server) PATCH(pattern string, h HandleFunc) { s.engine.PATCH(pattern, func(ctx *gin.Context) { childCtx := newContext(ctx) h(childCtx) putContext(childCtx) }) } func (s *Server) DELETE(pattern string, h HandleFunc) { s.engine.DELETE(pattern, func(ctx *gin.Context) { childCtx := newContext(ctx) h(childCtx) putContext(childCtx) }) } func (s *Server) Use(middlewares ...middleware.Middleware) { s.middlewares = append(s.middlewares, middlewares...) } func (s *Server) Handle(method string, uri string, handler http.HandlerFunc) { s.engine.Handle(method, uri, func(ctx *gin.Context) { handler(ctx.Writer, ctx.Request) }) } func (s *Server) Webroot(prefix string, fs http.FileSystem) { s.fs = newFS(time.Now(), fs) s.fs.SetPrefix(prefix) s.fs.DenyAccessDirectory() s.fs.SetIndexFile("/index.html") } func (s *Server) shouldCompress(req *http.Request) bool { if !strings.Contains(req.Header.Get(headerAcceptEncoding), "gzip") || strings.Contains(req.Header.Get("Connection"), "Upgrade") { return false } // Check if the request path is excluded from compression extension := filepath.Ext(req.URL.Path) if slices.Contains(assetsExtensions, extension) { return true } return false } func (s *Server) staticHandle(ctx *gin.Context, fp http.File) { uri := path.Clean(ctx.Request.URL.Path) fi, err := fp.Stat() if err != nil { return } if !fi.IsDir() { //https://github.com/gin-contrib/gzip if s.shouldCompress(ctx.Request) && fi.Size() > 8192 { gzWriter := newGzipWriter() gzWriter.Reset(ctx.Writer) ctx.Header(headerContentEncoding, "gzip") ctx.Writer.Header().Add(headerVary, headerAcceptEncoding) originalEtag := ctx.GetHeader("ETag") if originalEtag != "" && !strings.HasPrefix(originalEtag, "W/") { ctx.Header("ETag", "W/"+originalEtag) } ctx.Writer = &gzipWriter{ctx.Writer, gzWriter} defer func() { if ctx.Writer.Size() < 0 { gzWriter.Reset(io.Discard) } gzWriter.Close() if ctx.Writer.Size() > -1 { ctx.Header("Content-Length", strconv.Itoa(ctx.Writer.Size())) } putGzipWriter(gzWriter) }() } } http.ServeContent(ctx.Writer, ctx.Request, path.Base(uri), s.fs.modtime, fp) ctx.Abort() return } func (s *Server) notFoundHandle(ctx *gin.Context) { if s.fs != nil && ctx.Request.Method == http.MethodGet { uri := path.Clean(ctx.Request.URL.Path) if fp, err := s.fs.Open(uri); err == nil { s.staticHandle(ctx, fp) fp.Close() } } ctx.JSON(http.StatusNotFound, newResponse(errors.NotFound, "Not Found", nil)) } func (s *Server) CORSInterceptor() gin.HandlerFunc { return func(c *gin.Context) { if c.Request.Method == "OPTIONS" { c.Writer.Header().Add("Vary", "Origin") c.Writer.Header().Add("Vary", "Access-Control-Request-Method") c.Writer.Header().Add("Vary", "Access-Control-Request-Headers") c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Methods", "GET,HEAD,PUT,PATCH,POST,DELETE") h := c.Request.Header.Get("Access-Control-Request-Headers") if h != "" { c.Writer.Header().Set("Access-Control-Allow-Headers", h) } c.AbortWithStatus(204) return } else { c.Writer.Header().Add("Vary", "Origin") c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") h := c.Request.Header.Get("Access-Control-Request-Headers") if h != "" { c.Writer.Header().Set("Access-Control-Allow-Headers", h) } } c.Next() } } func (s *Server) requestInterceptor() gin.HandlerFunc { return func(ginCtx *gin.Context) { ctx := ginCtx.Request.Context() next := func(ctx context.Context) error { ginCtx.Request = ginCtx.Request.WithContext(ctx) ginCtx.Next() if err := ginCtx.Errors.Last(); err != nil { return err.Err } return nil } handler := middleware.Chain(s.middlewares...)(next) md := metadata.FromContext(ctx) md.TeeReader(&httpMetadataReader{ hd: ginCtx.Request.Header, }) md.TeeWriter(&httpMetadataWriter{ w: ginCtx.Writer, }) if !md.Has(metadata.RequestIDKey) { md.Set(metadata.RequestIDKey, uuid.New().String()) } md.Set(metadata.RequestProtocolKey, Protocol) md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path) md.Set("method", ginCtx.Request.Method) ctx = metadata.NewContext(ctx, md) if err := handler(ctx); err != nil { if se, ok := err.(*errors.Error); ok { ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil)) } else { ginCtx.AbortWithError(http.StatusInternalServerError, err) } } } } func (s *Server) wrapHandle(f http.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { f(c.Writer, c.Request) } } func (s *Server) createListener() (err error) { if s.listener == nil { if s.listener, err = net.Listen(s.opts.network, s.opts.address); err != nil { return } s.uri.Host = netutil.TrulyAddr(s.opts.address, s.listener) } return } func (s *Server) Start(ctx context.Context) (err error) { s.serve = &http.Server{ Addr: s.opts.address, Handler: s.engine, } s.ctx = ctx if s.opts.debug { s.engine.Handle(http.MethodGet, "/debug/pprof/", s.wrapHandle(pprof.Index)) s.engine.Handle(http.MethodGet, "/debug/pprof/goroutine", s.wrapHandle(pprof.Index)) s.engine.Handle(http.MethodGet, "/debug/pprof/heap", s.wrapHandle(pprof.Index)) s.engine.Handle(http.MethodGet, "/debug/pprof/mutex", s.wrapHandle(pprof.Index)) s.engine.Handle(http.MethodGet, "/debug/pprof/threadcreate", s.wrapHandle(pprof.Index)) s.engine.Handle(http.MethodGet, "/debug/pprof/cmdline", s.wrapHandle(pprof.Cmdline)) s.engine.Handle(http.MethodGet, "/debug/pprof/profile", s.wrapHandle(pprof.Profile)) s.engine.Handle(http.MethodGet, "/debug/pprof/symbol", s.wrapHandle(pprof.Symbol)) s.engine.Handle(http.MethodGet, "/debug/pprof/trace", s.wrapHandle(pprof.Trace)) } if err = s.createListener(); err != nil { return } s.engine.NoRoute(s.notFoundHandle) s.opts.logger.Info(ctx, "http server listen on: %s", s.uri.Host) if s.opts.certFile != "" && s.opts.keyFile != "" { s.uri.Scheme = "https" err = s.serve.ServeTLS(s.listener, s.opts.certFile, s.opts.keyFile) } else { err = s.serve.Serve(s.listener) } if !errors.Is(err, http.ErrServerClosed) { return err } return nil } func (s *Server) Stop(ctx context.Context) (err error) { err = s.serve.Shutdown(ctx) s.opts.logger.Info(ctx, "http server stopped") return } func New(cbs ...Option) *Server { svr := &Server{ uri: &url.URL{Scheme: "http"}, opts: &options{ network: "tcp", logger: logger.Default(), }, } port, _ := strconv.Atoi(os.Getenv("HTTP_PORT")) svr.opts.address = fmt.Sprintf(":%d", port) for _, cb := range cbs { cb(svr.opts) } if !svr.opts.debug { gin.SetMode(gin.ReleaseMode) } svr.engine = gin.New(svr.opts.ginOptions...) if svr.opts.enableCORS { svr.engine.Use(svr.CORSInterceptor()) } svr.engine.Use(svr.requestInterceptor()) return svr }