diff --git a/service/auth.go b/service/auth.go index 4a8bfe8..950f655 100644 --- a/service/auth.go +++ b/service/auth.go @@ -27,6 +27,7 @@ type ( tokenStore TokenStore turnstileValidateUrl string turnstileSiteKey string + enableRefreshToken bool } turnstileRequest struct { @@ -65,6 +66,12 @@ func WithAuthCache(cache cache.Cache) AuthOption { } } +func WithRefreshToken() AuthOption { + return func(o *authOptions) { + o.enableRefreshToken = true + } +} + func WithTokenStore(store TokenStore) AuthOption { return func(o *authOptions) { o.tokenStore = store @@ -152,23 +159,25 @@ func (s *AuthService) Login(ctx context.Context, req *pb.LoginRequest) (res *pb. IssuedAt: time.Now().Unix(), ExpirationAt: time.Now().Add(time.Second * time.Duration(s.opts.ttl)).Unix(), } - refreshClaims := types.Claims{ - Uid: model.Uid, - Role: model.Role, - Admin: model.Admin, - IssuedAt: time.Now().Unix(), - ExpirationAt: time.Now().Add(time.Hour * 48).Unix(), - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) res = &pb.LoginResponse{} if res.Token, err = token.SignedString(s.opts.secret); err == nil { res.Uid = model.Uid res.Username = model.Username res.Expires = s.opts.ttl } - res.RefreshToken, err = refreshToken.SignedString(s.opts.secret) + if s.opts.enableRefreshToken { + refreshClaims := types.Claims{ + Uid: model.Uid, + Role: model.Role, + Admin: model.Admin, + IssuedAt: time.Now().Unix(), + ExpirationAt: time.Now().Add(time.Hour * 48).Unix(), + } + refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims) + res.RefreshToken, err = refreshToken.SignedString(s.opts.secret) + } loginModel := &models.Login{} loginModel.Uid = model.Uid loginModel.AccessToken = res.Token @@ -197,31 +206,34 @@ func (s *AuthService) Logout(ctx context.Context, req *pb.LogoutRequest) (res *p func (s *AuthService) RefreshToken(ctx context.Context, req *pb.RefreshTokenRequest) (res *pb.RefreshTokenResponse, err error) { var ( - token *jwt.Token + refreshToken *jwt.Token ) - if token, err = jwt.ParseWithClaims(req.RefreshToken, &types.Claims{}, func(token *jwt.Token) (interface{}, error) { + if !s.opts.enableRefreshToken { + err = errors.ErrAccessDenied + return + } + if refreshToken, err = jwt.ParseWithClaims(req.RefreshToken, &types.Claims{}, func(token *jwt.Token) (interface{}, error) { return s.opts.secret, nil }); err != nil { return } - if claims, ok := token.Claims.(*types.Claims); ok { - tokenClaims := types.Claims{ - Uid: claims.Uid, - Role: claims.Role, - Admin: claims.Admin, - IssuedAt: time.Now().Unix(), - ExpirationAt: time.Now().Add(time.Second * time.Duration(s.opts.ttl)).Unix(), - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenClaims) - res = &pb.RefreshTokenResponse{} - if res.Token, err = token.SignedString(s.opts.secret); err == nil { - res.Uid = claims.Uid - res.Expires = s.opts.ttl - return - } - } else { + refreshClaims, ok := refreshToken.Claims.(*types.Claims) + if !ok { err = errors.ErrIncompatible } + tokenClaims := types.Claims{ + Uid: refreshClaims.Uid, + Role: refreshClaims.Role, + Admin: refreshClaims.Admin, + IssuedAt: time.Now().Unix(), + ExpirationAt: time.Now().Add(time.Second * time.Duration(s.opts.ttl)).Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenClaims) + res = &pb.RefreshTokenResponse{} + if res.Token, err = token.SignedString(s.opts.secret); err == nil { + res.Uid = refreshClaims.Uid + res.Expires = s.opts.ttl + } return }