rest/plugins/validate/validation.go

276 lines
6.9 KiB
Go

package validate
import (
"context"
"fmt"
"git.nobla.cn/golang/rest"
"git.nobla.cn/golang/rest/types"
validator "github.com/go-playground/validator/v10"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"reflect"
"strconv"
"strings"
)
// Validate is a gorm plugin that validates model fields against the REST
// runtime schema rules from the gorm callbacks it registers in Initialize.
type Validate struct {
	// validator is created in Initialize, which also registers the custom
	// "telephone" and "db_unique" validations on it.
	validator *validator.Validate
}
// Name implements gorm.Plugin and returns the plugin's registration name.
func (validate *Validate) Name() string {
	const pluginName = "gorm:validate"
	return pluginName
}
// telephoneValidate implements the "telephone" tag: the field value, rendered
// with fmt.Sprint, must match telephoneRegex. ctx is unused.
func (validate *Validate) telephoneValidate(ctx context.Context, fl validator.FieldLevel) bool {
	return telephoneRegex.MatchString(fmt.Sprint(fl.Field().Interface()))
}
// uniqueValidate implements the "db_unique" tag. It counts the rows of the
// model's table whose scope.Column equals the field value and reports true
// when there are none.
//
// The *validateScope is read from ctx under scopeCtxKey; without it the value
// is accepted. If the model's first primary-key field holds a non-zero value
// (e.g. an existing record being saved), that row is excluded from the count.
// When scope.MultiDomain is set and scope.Domain is non-empty, only rows of
// that domain are counted. If the COUNT query fails, the error is recorded on
// scope.DB and the value is accepted, so a database failure is not reported
// as a duplicate.
func (validate *Validate) uniqueValidate(ctx context.Context, fl validator.FieldLevel) bool {
	scope, ok := ctx.Value(scopeCtxKey).(*validateScope)
	if !ok || scope == nil {
		return true
	}

	query := scope.DB.Session(&gorm.Session{NewDB: true}).
		Model(scope.Model).
		Where(scope.Column+" = ?", fl.Field().Interface())

	if sch := scope.DB.Statement.Schema; sch != nil && len(sch.PrimaryFields) > 0 {
		pk := sch.PrimaryFields[0]
		// Follow BindNames (the path through embedded structs) to the key
		// value. Stop on nil embedded pointers or a non-struct model (such as
		// a slice from a batch write) instead of panicking in FieldByName.
		pkValue := reflect.ValueOf(scope.Model)
		for _, name := range pk.BindNames {
			pkValue = reflect.Indirect(pkValue)
			if pkValue.Kind() != reflect.Struct {
				pkValue = reflect.Value{}
				break
			}
			pkValue = pkValue.FieldByName(name)
		}
		if len(pk.BindNames) > 0 && pkValue.IsValid() && !pkValue.IsZero() {
			// SQL needs the column name (DBName), not the Go field name:
			// a field "OrderNo" is stored in column "order_no".
			query = query.Where(pk.DBName+" != ?", pkValue.Interface())
		}
	}

	// Multi-domain models only need to be unique within the current domain.
	if scope.MultiDomain && scope.Domain != "" {
		query = query.Where("domain = ?", scope.Domain)
	}

	var count int64
	if err := query.Count(&count).Error; err != nil {
		// The session inherits any error already present on scope.DB; only
		// record new failures so the same error is not appended twice.
		if scope.DB.Error == nil {
			_ = scope.DB.AddError(err)
		}
		return true
	}
	return count == 0
}
// grantRules translates rule into the ordered list of validator tags to apply
// to the field described by scm in the given scenario.
//
// Order matters because go-playground/validator stops at the first failing
// tag:
//   - "required" (only when scenario is listed in rule.Required) comes first,
//     so an empty or nil value is reported as missing rather than as a
//     min/type failure;
//   - cheap checks (min, max, rule.Type) follow;
//   - the database-backed "db_unique" runs last, so no query is issued for a
//     value that is already invalid.
//
// Primary-key fields never get "db_unique".
func (validate *Validate) grantRules(scm *types.Schema, scenario string, rule types.Rule) []*validateRule {
	rules := make([]*validateRule, 0, 5)
	for _, s := range rule.Required {
		if s == scenario {
			rules = append(rules, newRule("required"))
			break
		}
	}
	if rule.Min != 0 {
		rules = append(rules, newRule("min", strconv.Itoa(rule.Min)))
	}
	if rule.Max != 0 {
		rules = append(rules, newRule("max", strconv.Itoa(rule.Max)))
	}
	if rule.Type != "" {
		rules = append(rules, newRule(rule.Type))
	}
	// Primary keys are not checked for uniqueness.
	if rule.Unique && !scm.Attribute.PrimaryKey {
		rules = append(rules, newRule("db_unique"))
	}
	return rules
}
// buildRules renders the enabled rules of rs as a validator tag string such as
// "required,min=3,email". Disabled rules (Valid == false) are omitted, and a
// rule with a Value is written as "rule=value".
func (validate *Validate) buildRules(rs []*validateRule) string {
	var out strings.Builder
	for _, r := range rs {
		if r.Valid {
			if out.Len() != 0 {
				out.WriteByte(',')
			}
			out.WriteString(r.Rule)
			if r.Value != "" {
				out.WriteByte('=')
				out.WriteString(r.Value)
			}
		}
	}
	return out.String()
}
// findRule returns the first rule in rules whose Rule equals name, or nil if
// there is none.
func (validate *Validate) findRule(name string, rules []*validateRule) *validateRule {
	for i := range rules {
		if candidate := rules[i]; candidate.Rule == name {
			return candidate
		}
	}
	return nil
}
// Initialize implements gorm.Plugin. It creates the validator and registers
// the custom "telephone" and "db_unique" validations. It then hooks Validate
// into gorm's create chain (before gorm:before_create) and update chain
// (before gorm:before_update).
func (validate *Validate) Initialize(db *gorm.DB) error {
	validate.validator = validator.New()
	// Register the custom validations first, so the callbacks can never run
	// against a validator that lacks them.
	if err := validate.validator.RegisterValidationCtx("telephone", validate.telephoneValidate); err != nil {
		return err
	}
	if err := validate.validator.RegisterValidationCtx("db_unique", validate.uniqueValidate); err != nil {
		return err
	}
	if err := db.Callback().Create().Before("gorm:before_create").Register("rest_validate_create", validate.Validate); err != nil {
		return err
	}
	// The update hook belongs on the Update processor. The Create processor
	// has no "gorm:before_update" callback, so gorm would append the hook to
	// the end of the create chain (after the commit). Creates would then be
	// validated twice and updates never.
	if err := db.Callback().Update().Before("gorm:before_update").Register("rest_validate_update", validate.Validate); err != nil {
		return err
	}
	return nil
}
// inArray reports whether v matches any element of vs. Values are compared by
// their fmt.Sprint rendering, so e.g. the int 1 matches the string "1".
func (validate *Validate) inArray(v any, vs []any) bool {
	want := fmt.Sprint(v)
	for i := range vs {
		if want == fmt.Sprint(vs[i]) {
			return true
		}
	}
	return false
}
// isVisible reports whether the field described by scm is shown for the record
// in stmt.ReflectValue. Every Visible condition that lists values and names a
// column known to the model must hold: that column's current value must be one
// of the listed values (compared via inArray). A condition whose column cannot
// be read on the record makes the field invisible.
func (validate *Validate) isVisible(stmt *gorm.Statement, scm *types.Schema) bool {
	for _, cond := range scm.Attribute.Visible {
		if len(cond.Values) == 0 {
			// A condition without values imposes no constraint.
			continue
		}
		dependsOn := stmt.Schema.LookUpField(cond.Column)
		if dependsOn == nil {
			continue
		}
		current := stmt.ReflectValue.FieldByName(dependsOn.Name)
		if !current.IsValid() || !validate.inArray(current.Interface(), cond.Values) {
			return false
		}
	}
	return true
}
// Validate is the gorm callback that checks the record(s) being written
// against the rules of the current REST runtime scope.
//
// It does nothing in any of these cases:
//   - the SkipValidations setting is true;
//   - the statement already carries an error;
//   - there is no model or parsed schema;
//   - the statement context has no *types.RuntimeScope under
//     rest.RuntimeScopeKey.
//
// A single struct is validated directly. For slices and arrays (batch writes),
// every struct element is validated in turn. Other value kinds are not
// validated. Validation stops at the first failing field, whose error(s) are
// added to db.
func (validate *Validate) Validate(db *gorm.DB) {
	if skip, ok := db.Get(SkipValidations); ok {
		if on, isBool := skip.(bool); isBool && on {
			return
		}
	}
	stmt := db.Statement
	// stmt.Schema is nil when gorm failed to parse the model; db.Error is
	// then already set and there is nothing meaningful to validate.
	if db.Error != nil || stmt.Model == nil || stmt.Schema == nil || stmt.Context == nil {
		return
	}
	runtimeScope, ok := stmt.Context.Value(rest.RuntimeScopeKey).(*types.RuntimeScope)
	if !ok || runtimeScope == nil || runtimeScope.Schemas == nil {
		return
	}
	// Models with a "domain" column get domain-scoped db_unique checks.
	multiDomain := stmt.Schema.LookUpField("domain") != nil

	records := stmt.ReflectValue
	switch records.Kind() {
	case reflect.Struct:
		validate.validateRecord(db, runtimeScope, records, stmt.Model, multiDomain)
	case reflect.Slice, reflect.Array:
		for i := 0; i < records.Len(); i++ {
			record := reflect.Indirect(records.Index(i))
			if record.Kind() != reflect.Struct {
				continue
			}
			// db_unique reads the primary key from the model it is given, so
			// hand it this element rather than the whole batch.
			model := record.Interface()
			if record.CanAddr() {
				model = record.Addr().Interface()
			}
			if !validate.validateRecord(db, runtimeScope, record, model, multiDomain) {
				return
			}
		}
	}
}

// validateRecord validates the struct value record field by field against
// runtimeScope.Schemas. model is passed to the db_unique rule as the record
// being written. It returns false after adding the first field failure to db,
// and true when every field passes.
func (validate *Validate) validateRecord(db *gorm.DB, runtimeScope *types.RuntimeScope, record reflect.Value, model any, multiDomain bool) bool {
	stmt := db.Statement
	for _, row := range runtimeScope.Schemas {
		// Hidden fields are not validated.
		if !validate.recordVisible(stmt, row, record) {
			continue
		}
		rules := validate.grantRules(row, runtimeScope.Scenario, row.Rule)
		if len(rules) == 0 {
			continue
		}
		field := stmt.Schema.LookUpField(row.Column)
		if field == nil {
			continue
		}
		value := record.FieldByName(field.Name)
		if !value.IsValid() {
			continue
		}
		if required := validate.findRule("required", rules); required != nil {
			// validator's "required" rejects zero values, which for bool and
			// numeric fields would forbid false/0. Disable it for those kinds;
			// the remaining rules still apply.
			switch reflect.ValueOf(value.Interface()).Kind() {
			case reflect.Bool,
				reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
				reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
				reflect.Float32, reflect.Float64:
				required.Valid = false
			}
		} else if isEmpty(value.Interface()) {
			// Optional field left empty: nothing to check.
			continue
		}
		ctx := context.WithValue(stmt.Context, scopeCtxKey, &validateScope{
			DB:          db,
			Column:      row.Column,
			Model:       model,
			Domain:      runtimeScope.Domain,
			MultiDomain: multiDomain,
		})
		err := validate.validator.VarCtx(ctx, value.Interface(), validate.buildRules(rules))
		if err == nil {
			continue
		}
		if errs, ok := err.(validator.ValidationErrors); ok {
			for _, e := range errs {
				_ = db.AddError(&StructError{
					Tag:     e.Tag(),
					Column:  row.Column,
					Message: formatError(row.Rule, row, e.Tag()),
				})
			}
		} else {
			_ = db.AddError(err)
		}
		return false
	}
	return true
}

// recordVisible reports whether the field described by scm is shown for
// record. Every Visible condition that lists values and names a column known
// to the model must hold on record; an unreadable column hides the field.
// It applies the same rules as isVisible, but reads the given record instead
// of stmt.ReflectValue, so it also works for elements of a batch write.
func (validate *Validate) recordVisible(stmt *gorm.Statement, scm *types.Schema, record reflect.Value) bool {
	for _, cond := range scm.Attribute.Visible {
		if len(cond.Values) == 0 {
			continue
		}
		target := stmt.Schema.LookUpField(cond.Column)
		if target == nil {
			continue
		}
		current := record.FieldByName(target.Name)
		if !current.IsValid() || !validate.inArray(current.Interface(), cond.Values) {
			return false
		}
	}
	return true
}
// New returns an empty plugin. Its validator is created later by Initialize.
func New() *Validate {
	return new(Validate)
}