// aeus/middleware/auth/jwt.go
// 137 lines, 3.3 KiB, Go

package auth
import (
"context"
"strings"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/errors"
jwt "github.com/golang-jwt/jwt/v5"
)
// authKey is the context key under which the parsed JWT claims are stored.
// It is an unexported struct type, so keys defined in other packages cannot
// collide with it.
type authKey struct{}

const (
	// authorizationKey is the metadata key whose value carries the bearer token.
	authorizationKey string = "Authorization"
	// bearerWord is the authorization scheme that must prefix the token.
	bearerWord string = "Bearer"
	// bearerFormat is a printf-style template ("Bearer %s") for building an
	// Authorization value from a token.
	bearerFormat string = "Bearer %s"
	// reason holds the error reason string ("UNAUTHORIZED").
	reason string = "UNAUTHORIZED"
)
// options collects the settings applied by Option values passed to JWT.
type options struct {
	// allows lists request paths that may skip token verification.
	allows []string
	// claims, when non-nil, produces the jwt.Claims value each token is
	// parsed into.
	claims func() jwt.Claims
}

// Option customizes the JWT middleware by mutating its options.
type Option func(*options)
// WithClaims installs a factory for the jwt.Claims value that tokens are
// parsed into (via jwt.ParseWithClaims). Without it, JWT uses jwt.Parse.
//
// Server side: f must hand back a brand-new jwt.Claims on every call, or
// concurrent requests will write into the same object.
// Client side: f may return one shared object, which avoids an allocation
// per call.
func WithClaims(f func() jwt.Claims) Option {
	apply := func(o *options) {
		o.claims = f
	}
	return apply
}
// WithAllow appends paths to the list of request paths that the JWT
// middleware checks with isAllowed before requiring a token. Repeated calls
// accumulate; the list is lazily created with room for 16 entries.
func WithAllow(paths ...string) Option {
	return func(o *options) {
		list := o.allows
		if list == nil {
			list = make([]string, 0, 16)
		}
		o.allows = append(list, paths...)
	}
}
// isAllowed reports whether uripath matches at least one entry in allows.
//
// Matching rules:
//   - empty entries are ignored;
//   - an entry that equals uripath matches;
//   - an entry longer than one character ending in '*' matches every path
//     that starts with the text before the '*' (e.g. "/public/*" matches
//     "/public/a/b"). A lone "*" is therefore compared literally.
//
// If no entry matches, the path is not allowed and must be authenticated.
func isAllowed(uripath string, allows []string) bool {
	for _, pattern := range allows {
		if pattern == "" {
			continue
		}
		if pattern == uripath {
			return true
		}
		n := len(pattern)
		if n > 1 && pattern[n-1] == '*' && strings.HasPrefix(uripath, pattern[:n-1]) {
			return true
		}
	}
	// Deny by default: returning true here would let every request bypass
	// authentication as soon as any allow entry is configured.
	return false
}
// JWT returns a middleware that authenticates requests with a JWT bearer
// token.
//
// When allow paths were configured (WithAllow) and the request metadata
// carries metadata.RequestPathKey, a path accepted by isAllowed is passed
// straight to next. Every other request must carry an "Authorization"
// metadata entry of the form "Bearer <token>". The token is parsed with
// keyFunc, into the claims produced by the WithClaims factory when one is
// set. On success the parsed claims are stored in the context (retrievable
// with FromContext) before calling next.
//
// Errors:
//   - errors.ErrAccessDenied: missing or malformed Authorization value, or a
//     malformed/unverifiable token.
//   - errors.ErrTokenExpired: the token is expired or not yet valid.
//   - errors.ErrPermissionDenied: any other parse or validation failure.
func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware {
	opts := options{}
	for _, cb := range cbs {
		cb(&opts)
	}
	return func(next middleware.Handler) middleware.Handler {
		return func(ctx context.Context) (err error) {
			md := metadata.FromContext(ctx)
			if len(opts.allows) > 0 {
				if requestPath, ok := md.Get(metadata.RequestPathKey); ok && isAllowed(requestPath, opts.allows) {
					return next(ctx)
				}
			}
			authorizationValue, ok := md.Get(authorizationKey)
			if !ok {
				return errors.ErrAccessDenied
			}
			authorizationToken, ok := parseBearerToken(authorizationValue)
			if !ok {
				return errors.ErrAccessDenied
			}
			var ti *jwt.Token
			if opts.claims != nil {
				ti, err = jwt.ParseWithClaims(authorizationToken, opts.claims(), keyFunc)
			} else {
				ti, err = jwt.Parse(authorizationToken, keyFunc)
			}
			if err != nil {
				if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) {
					return errors.ErrAccessDenied
				}
				if errors.Is(err, jwt.ErrTokenNotValidYet) || errors.Is(err, jwt.ErrTokenExpired) {
					return errors.ErrTokenExpired
				}
				return errors.ErrPermissionDenied
			}
			// Defensive: treat a token not marked valid as a failure even
			// when no error was reported.
			if !ti.Valid {
				return errors.ErrPermissionDenied
			}
			ctx = NewContext(ctx, ti.Claims)
			return next(ctx)
		}
	}
}

// parseBearerToken extracts the token from an Authorization value of the form
// "Bearer <token>". The scheme is compared case-insensitively (RFC 7235
// auth-schemes are case-insensitive) and must be followed by a space, so a
// value such as "Bearerabc" is rejected. Surrounding whitespace is trimmed
// from the token. ok is false when the scheme is wrong or the token is empty.
func parseBearerToken(value string) (token string, ok bool) {
	scheme, rest, found := strings.Cut(value, " ")
	if !found || !strings.EqualFold(scheme, bearerWord) {
		return "", false
	}
	token = strings.TrimSpace(rest)
	return token, token != ""
}
// NewContext returns a child of ctx that carries info under the package's
// private authKey, so FromContext can retrieve it later.
func NewContext(ctx context.Context, info jwt.Claims) context.Context {
	key := authKey{}
	return context.WithValue(ctx, key, info)
}
// FromContext returns the claims that NewContext stored in ctx. ok is false
// when ctx holds no claims under authKey (including when a nil value was
// stored).
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) {
	stored := ctx.Value(authKey{})
	token, ok = stored.(jwt.Claims)
	return token, ok
}