165 lines
3.6 KiB
Go
165 lines
3.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"git.nobla.cn/golang/aeus/metadata"
|
|
"git.nobla.cn/golang/aeus/middleware"
|
|
"git.nobla.cn/golang/aeus/pkg/errors"
|
|
jwt "github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// authKey is the context key under which NewContext stores the parsed JWT
// claims. Being an unexported struct type, it cannot collide with context
// keys defined by other packages.
type authKey struct{}
|
|
|
|
const (
	// bearerWord is the authentication-scheme keyword removed from the front
	// of the Authorization value before the token is parsed.
	bearerWord string = "Bearer"

	// bearerFormat looks like a fmt format string for building an
	// Authorization value ("Bearer <token>").
	// NOTE(review): not referenced in this section of the file; presumably
	// used elsewhere in the package — confirm.
	bearerFormat string = "Bearer %s"

	// authorizationKey is the request metadata key that holds the JWT,
	// optionally prefixed with the "Bearer" keyword.
	authorizationKey string = "Authorization"

	// reason holds the error reason string for authentication failures.
	// NOTE(review): not referenced in this section of the file — confirm usage.
	reason string = "UNAUTHORIZED"
)
|
|
|
|
// Option configures the JWT middleware; see WithAllow, WithClaims and
// WithValidate.
type Option func(*options)
|
|
|
|
// Validate is an optional hook run by the JWT middleware before the token is
// parsed.
type Validate interface {
	// Validate receives the raw Authorization metadata value (any "Bearer"
	// prefix is still present). A non-nil error aborts the request and is
	// returned to the caller unchanged.
	Validate(ctx context.Context, token string) error
}
|
|
|
|
// options holds the JWT middleware configuration assembled from Option values.
type options struct {
	// allows lists request-path patterns that skip authentication
	// (matched by isAllowed).
	allows []string
	// claims is the type instantiated with reflect.New for each request to
	// decode token claims into; nil means the parser's default claims are used.
	claims reflect.Type
	// validate is an optional pre-parse hook; nil disables it.
	validate Validate
}
|
|
|
|
// WithAllow with allow path
|
|
func WithAllow(paths ...string) Option {
|
|
return func(o *options) {
|
|
if o.allows == nil {
|
|
o.allows = make([]string, 0, 16)
|
|
}
|
|
for _, s := range paths {
|
|
s = strings.TrimSpace(s)
|
|
if len(s) == 0 {
|
|
continue
|
|
}
|
|
o.allows = append(o.allows, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithClaims(claims any) Option {
|
|
return func(o *options) {
|
|
if tv, ok := claims.(reflect.Type); ok {
|
|
o.claims = tv
|
|
} else {
|
|
o.claims = reflect.TypeOf(claims)
|
|
if o.claims.Kind() == reflect.Ptr {
|
|
o.claims = o.claims.Elem()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithValidate(fn Validate) Option {
|
|
return func(o *options) {
|
|
o.validate = fn
|
|
}
|
|
}
|
|
|
|
// isAllowed reports whether uripath matches any pattern in allows.
//
// Supported pattern forms:
//   - "*" matches every path;
//   - "prefix*" (non-empty prefix) matches any path starting with prefix;
//   - anything else must equal uripath exactly.
func isAllowed(uripath string, allows []string) bool {
	for _, pattern := range allows {
		if pattern == "*" || pattern == uripath {
			return true
		}
		prefix, hasWildcard := strings.CutSuffix(pattern, "*")
		if hasWildcard && prefix != "" && strings.HasPrefix(uripath, prefix) {
			return true
		}
	}
	return false
}
|
|
|
|
// JWT auth middleware
|
|
func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware {
|
|
opts := options{}
|
|
for _, cb := range cbs {
|
|
cb(&opts)
|
|
}
|
|
return func(next middleware.Handler) middleware.Handler {
|
|
return func(ctx context.Context) (err error) {
|
|
md := metadata.FromContext(ctx)
|
|
if len(opts.allows) > 0 {
|
|
requestPath, ok := md.Get(metadata.RequestPathKey)
|
|
if ok {
|
|
if isAllowed(requestPath, opts.allows) {
|
|
return next(ctx)
|
|
}
|
|
}
|
|
}
|
|
token, ok := md.Get(authorizationKey)
|
|
if !ok {
|
|
return errors.ErrAccessDenied
|
|
}
|
|
if opts.validate != nil {
|
|
if err = opts.validate.Validate(ctx, token); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
token, _ = strings.CutPrefix(token, bearerWord)
|
|
var (
|
|
ti *jwt.Token
|
|
)
|
|
token = strings.TrimSpace(token)
|
|
if opts.claims != nil {
|
|
if claims, ok := reflect.New(opts.claims).Interface().(jwt.Claims); ok {
|
|
ti, err = jwt.ParseWithClaims(token, claims, keyFunc)
|
|
}
|
|
}
|
|
if ti == nil {
|
|
ti, err = jwt.Parse(token, keyFunc)
|
|
}
|
|
if err != nil {
|
|
if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) {
|
|
return errors.ErrAccessDenied
|
|
}
|
|
if errors.Is(err, jwt.ErrTokenNotValidYet) || errors.Is(err, jwt.ErrTokenExpired) {
|
|
return errors.ErrTokenExpired
|
|
}
|
|
return errors.ErrPermissionDenied
|
|
}
|
|
if !ti.Valid {
|
|
return errors.ErrPermissionDenied
|
|
}
|
|
ctx = NewContext(ctx, ti.Claims)
|
|
return next(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
// NewContext put auth info into context
|
|
func NewContext(ctx context.Context, info jwt.Claims) context.Context {
|
|
return context.WithValue(ctx, authKey{}, info)
|
|
}
|
|
|
|
// FromContext extract auth info from context
|
|
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) {
|
|
token, ok = ctx.Value(authKey{}).(jwt.Claims)
|
|
return
|
|
}
|