package auth import ( "context" "strings" "git.nobla.cn/golang/aeus/metadata" "git.nobla.cn/golang/aeus/middleware" "git.nobla.cn/golang/aeus/pkg/errors" jwt "github.com/golang-jwt/jwt/v5" ) type authKey struct{} const ( // bearerWord the bearer key word for authorization bearerWord string = "Bearer" // bearerFormat authorization token format bearerFormat string = "Bearer %s" // authorizationKey holds the key used to store the JWT Token in the request tokenHeader. authorizationKey string = "Authorization" // reason holds the error reason. reason string = "UNAUTHORIZED" ) type Option func(*options) // Parser is a jwt parser type options struct { allows []string claims func() jwt.Claims } // WithClaims with customer claim // If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems // If you use it in Client, f only needs to return a single object to provide performance func WithClaims(f func() jwt.Claims) Option { return func(o *options) { o.claims = f } } // WithAllow with allow path func WithAllow(paths ...string) Option { return func(o *options) { if o.allows == nil { o.allows = make([]string, 0, 16) } o.allows = append(o.allows, paths...) } } // isAllowed check if the path is allowed func isAllowed(uripath string, allows []string) bool { for _, str := range allows { n := len(str) if n == 0 { continue } if n > 1 && str[n-1] == '*' { if strings.HasPrefix(uripath, str[:n-1]) { return true } } if str == uripath { return true } } return true } // JWT auth middleware func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware { opts := options{} for _, cb := range cbs { cb(&opts) } return func(next middleware.Handler) middleware.Handler { return func(ctx context.Context) (err error) { md := metadata.FromContext(ctx) if len(opts.allows) > 0 { requestPath, ok := md.Get(metadata.RequestPathKey) if ok { if isAllowed(requestPath, opts.allows) { return next(ctx) } } } authorizationValue, ok := md.Get(authorizationKey) if !ok { return errors.ErrAccessDenied } if !strings.HasPrefix(authorizationValue, bearerWord) { return errors.ErrAccessDenied } var ( ti *jwt.Token ) authorizationToken := strings.TrimSpace(strings.TrimPrefix(authorizationValue, bearerWord)) if opts.claims != nil { ti, err = jwt.ParseWithClaims(authorizationToken, opts.claims(), keyFunc) } else { ti, err = jwt.Parse(authorizationToken, keyFunc) } if err != nil { if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) { return errors.ErrAccessDenied } if errors.Is(err, jwt.ErrTokenNotValidYet) || errors.Is(err, jwt.ErrTokenExpired) { return errors.ErrTokenExpired } return errors.ErrPermissionDenied } if !ti.Valid { return errors.ErrPermissionDenied } ctx = NewContext(ctx, ti.Claims) return next(ctx) } } } // NewContext put auth info into context func NewContext(ctx context.Context, info jwt.Claims) context.Context { return context.WithValue(ctx, authKey{}, info) } // FromContext extract auth info from context func FromContext(ctx context.Context) (token jwt.Claims, ok bool) { token, ok = ctx.Value(authKey{}).(jwt.Claims) return }