package auth import ( "context" "reflect" "strings" "git.nobla.cn/golang/aeus/metadata" "git.nobla.cn/golang/aeus/middleware" "git.nobla.cn/golang/aeus/pkg/errors" jwt "github.com/golang-jwt/jwt/v5" ) type authKey struct{} const ( // bearerWord the bearer key word for authorization bearerWord string = "Bearer" // bearerFormat authorization token format bearerFormat string = "Bearer %s" // authorizationKey holds the key used to store the JWT Token in the request tokenHeader. authorizationKey string = "Authorization" // reason holds the error reason. reason string = "UNAUTHORIZED" ) type Option func(*options) type Validate interface { Validate(ctx context.Context, token string) error } // Parser is a jwt parser type options struct { allows []string claims reflect.Type validate Validate } // WithAllow with allow path func WithAllow(paths ...string) Option { return func(o *options) { if o.allows == nil { o.allows = make([]string, 0, 16) } o.allows = append(o.allows, paths...) } } func WithClaims(claims reflect.Type) Option { return func(o *options) { o.claims = claims } } func WithValidate(fn Validate) Option { return func(o *options) { o.validate = fn } } // isAllowed check if the path is allowed func isAllowed(uripath string, allows []string) bool { for _, str := range allows { n := len(str) if n == 0 { continue } if n > 1 && str[n-1] == '*' { if strings.HasPrefix(uripath, str[:n-1]) { return true } } if str == uripath { return true } } return false } // JWT auth middleware func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware { opts := options{} for _, cb := range cbs { cb(&opts) } return func(next middleware.Handler) middleware.Handler { return func(ctx context.Context) (err error) { md := metadata.FromContext(ctx) if len(opts.allows) > 0 { requestPath, ok := md.Get(metadata.RequestPathKey) if ok { if isAllowed(requestPath, opts.allows) { return next(ctx) } } } token, ok := md.Get(authorizationKey) if !ok { return errors.ErrAccessDenied } if opts.validate != nil { if err = opts.validate.Validate(ctx, token); err != nil { return err } } if strings.HasPrefix(token, bearerWord) { token = strings.TrimPrefix(token, bearerWord) } var ( ti *jwt.Token ) token = strings.TrimSpace(token) if opts.claims != nil { if claims, ok := reflect.New(opts.claims).Interface().(jwt.Claims); ok { ti, err = jwt.ParseWithClaims(token, claims, keyFunc) } } if ti == nil { ti, err = jwt.Parse(token, keyFunc) } if err != nil { if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) { return errors.ErrAccessDenied } if errors.Is(err, jwt.ErrTokenNotValidYet) || errors.Is(err, jwt.ErrTokenExpired) { return errors.ErrTokenExpired } return errors.ErrPermissionDenied } if !ti.Valid { return errors.ErrPermissionDenied } ctx = NewContext(ctx, ti.Claims) return next(ctx) } } } // NewContext put auth info into context func NewContext(ctx context.Context, info jwt.Claims) context.Context { return context.WithValue(ctx, authKey{}, info) } // FromContext extract auth info from context func FromContext(ctx context.Context) (token jwt.Claims, ok bool) { token, ok = ctx.Value(authKey{}).(jwt.Claims) return }