package validate

import (
	"context"
	"fmt"
	"reflect"
	"strconv"
	"strings"

	"git.nobla.cn/golang/rest"
	"git.nobla.cn/golang/rest/types"
	validator "github.com/go-playground/validator/v10"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

// Validate checks model fields against the rules carried by the rest runtime
// scope before gorm creates or updates a record. It provides Name and
// Initialize(*gorm.DB) so it can be installed as a gorm plugin; build it with New.
type Validate struct {
	validator *validator.Validate
}

// Name returns the plugin name reported to gorm.
func (validate *Validate) Name() string {
	return "gorm:validate"
}

// telephoneValidate implements the "telephone" tag: the field value, formatted
// with fmt.Sprint, must match telephoneRegex.
func (validate *Validate) telephoneValidate(ctx context.Context, fl validator.FieldLevel) bool {
	return telephoneRegex.MatchString(fmt.Sprint(fl.Field().Interface()))
}

// uniqueValidate implements the "db_unique" tag. It counts rows of the current
// model whose scope.Column equals the field value and reports whether there are
// none. When the model already carries a non-zero primary key (an update), that
// row is excluded from the count; for multi-domain tables with a domain set,
// only rows of that domain are counted. Without a validateScope in ctx the
// check passes.
//
// A failing count query is recorded on the statement (so the write is aborted
// with the database error) instead of being silently treated as "unique".
func (validate *Validate) uniqueValidate(ctx context.Context, fl validator.FieldLevel) bool {
	scope, ok := ctx.Value(scopeCtxKey).(*validateScope)
	if !ok {
		return true
	}
	conds := []string{scope.Column + " = ?"}
	args := []any{fl.Field().Interface()}
	if field, pk := validate.primaryKeyOf(scope.DB.Statement, scope.Model); field != nil && pk.IsValid() && !pk.IsZero() {
		// Exclude the record being validated. The SQL must use the column
		// name (DBName), not the Go struct field name.
		conds = append(conds, field.DBName+" != ?")
		args = append(args, pk.Interface())
	}
	// Multi-domain tables only need the value to be unique within one domain.
	if scope.MultiDomain && scope.Domain != "" {
		conds = append(conds, "domain = ?")
		args = append(args, scope.Domain)
	}
	var count int64
	tx := scope.DB.Session(&gorm.Session{NewDB: true}).
		Model(scope.Model).
		Where(strings.Join(conds, " AND "), args...).
		Count(&count)
	if tx.Error != nil {
		_ = scope.DB.AddError(tx.Error)
		return true
	}
	return count == 0
}

// primaryKeyOf returns the first primary key field of stmt's schema and its
// value read from model. The returned value is invalid when the schema has no
// primary key or the value cannot be reached (model is not a struct, or an
// embedded pointer on the field's bind path is nil).
func (validate *Validate) primaryKeyOf(stmt *gorm.Statement, model any) (*schema.Field, reflect.Value) {
	if stmt.Schema == nil || len(stmt.Schema.PrimaryFields) == 0 {
		return nil, reflect.Value{}
	}
	field := stmt.Schema.PrimaryFields[0]
	value := reflect.ValueOf(model)
	for _, name := range field.BindNames {
		// Dereference the model and any embedded pointer structs on the path;
		// FieldByName panics on anything that is not a struct.
		if value = reflect.Indirect(value); value.Kind() != reflect.Struct {
			return field, reflect.Value{}
		}
		value = value.FieldByName(name)
	}
	return field, value
}

// grantRules translates a schema rule into validator rules for scenario:
// min/max bounds (when non-zero), db_unique (unless the column is the primary
// key), the custom rule type (when set), and required when scenario is listed
// in rule.Required.
func (validate *Validate) grantRules(scm *types.Schema, scenario string, rule types.Rule) []*validateRule {
	rules := make([]*validateRule, 0, 5)
	if rule.Min != 0 {
		rules = append(rules, newRule("min", strconv.Itoa(rule.Min)))
	}
	if rule.Max != 0 {
		rules = append(rules, newRule("max", strconv.Itoa(rule.Max)))
	}
	// Primary keys are not checked for uniqueness.
	if rule.Unique && !scm.Attribute.PrimaryKey {
		rules = append(rules, newRule("db_unique"))
	}
	if rule.Type != "" {
		rules = append(rules, newRule(rule.Type))
	}
	for _, v := range rule.Required {
		if v == scenario {
			rules = append(rules, newRule("required"))
			break
		}
	}
	return rules
}

// buildRules joins the valid rules into a validator tag string such as
// "required,min=1,max=10". Rules with Valid == false are omitted.
func (validate *Validate) buildRules(rs []*validateRule) string {
	var sb strings.Builder
	for _, r := range rs {
		if !r.Valid {
			continue
		}
		if sb.Len() > 0 {
			sb.WriteString(",")
		}
		if r.Value == "" {
			sb.WriteString(r.Rule)
		} else {
			sb.WriteString(r.Rule + "=" + r.Value)
		}
	}
	return sb.String()
}

// findRule returns the first rule named name, or nil if there is none.
func (validate *Validate) findRule(name string, rules []*validateRule) *validateRule {
	for _, r := range rules {
		if r.Rule == name {
			return r
		}
	}
	return nil
}

// Initialize creates the validator, hooks Validate in front of gorm's
// before-create and before-update callbacks, and registers the custom
// "telephone" and "db_unique" tags.
func (validate *Validate) Initialize(db *gorm.DB) (err error) {
	validate.validator = validator.New()
	if err = db.Callback().Create().Before("gorm:before_create").Register("rest_validate_create", validate.Validate); err != nil {
		return err
	}
	// The update hook belongs on the Update processor, which is where
	// "gorm:before_update" is registered; registering it on Create() meant
	// updates were never validated and creates were validated twice.
	if err = db.Callback().Update().Before("gorm:before_update").Register("rest_validate_update", validate.Validate); err != nil {
		return err
	}
	if err = validate.validator.RegisterValidationCtx("telephone", validate.telephoneValidate); err != nil {
		return err
	}
	if err = validate.validator.RegisterValidationCtx("db_unique", validate.uniqueValidate); err != nil {
		return err
	}
	return nil
}

// inArray reports whether v appears in vs, comparing fmt.Sprint renderings.
func (validate *Validate) inArray(v any, vs []any) bool {
	sv := fmt.Sprint(v)
	for _, s := range vs {
		if fmt.Sprint(s) == sv {
			return true
		}
	}
	return false
}

// isVisible reports whether the field described by scm should be shown (and
// therefore validated). Each visibility condition with values requires the
// referenced column's current value to be one of those values; conditions
// without values, or on unknown columns, are ignored. It expects
// stmt.ReflectValue to be a struct.
func (validate *Validate) isVisible(stmt *gorm.Statement, scm *types.Schema) bool {
	if len(scm.Attribute.Visible) <= 0 {
		return true
	}
	for _, row := range scm.Attribute.Visible {
		if len(row.Values) == 0 {
			continue
		}
		targetField := stmt.Schema.LookUpField(row.Column)
		if targetField == nil {
			continue
		}
		targetValue := stmt.ReflectValue.FieldByName(targetField.Name)
		if !targetValue.IsValid() {
			return false
		}
		if !validate.inArray(targetValue.Interface(), row.Values) {
			return false
		}
	}
	return true
}

// Validate is the gorm callback that validates the statement's model against the
// schemas of the rest runtime scope stored in the statement context.
//
// It does nothing when:
//   - the statement already failed;
//   - validations are skipped via SkipValidations;
//   - there is no model, schema, context or runtime scope;
//   - the destination is not a single struct (maps and batch slices are not
//     validated).
//
// Hidden fields are skipped. The first field that fails validation adds one
// StructError per failed tag to db, and validation then stops.
func (validate *Validate) Validate(db *gorm.DB) {
	if db.Error != nil {
		return
	}
	if skip, ok := db.Get(SkipValidations); ok {
		// A non-bool value means "do not skip" instead of panicking.
		if b, _ := skip.(bool); b {
			return
		}
	}
	stmt := db.Statement
	if stmt.Model == nil || stmt.Schema == nil || stmt.Context == nil {
		return
	}
	// Field lookups use reflect.Value.FieldByName, which panics on maps/slices.
	if stmt.ReflectValue.Kind() != reflect.Struct {
		return
	}
	runtimeScope, ok := stmt.Context.Value(rest.RuntimeScopeKey).(*types.RuntimeScope)
	if !ok || runtimeScope == nil || runtimeScope.Schemas == nil {
		return
	}
	multiDomain := stmt.Schema.LookUpField("domain") != nil
	for _, row := range runtimeScope.Schemas {
		// Stop once the statement has failed (e.g. a uniqueness query error).
		if db.Error != nil {
			break
		}
		// Hidden fields are not validated.
		if !validate.isVisible(stmt, row) {
			continue
		}
		rules := validate.grantRules(row, runtimeScope.Scenario, row.Rule)
		if len(rules) == 0 {
			continue
		}
		field := stmt.Schema.LookUpField(row.Column)
		if field == nil {
			continue
		}
		value := stmt.ReflectValue.FieldByName(field.Name)
		if !value.IsValid() {
			continue
		}
		if r := validate.findRule("required", rules); r != nil {
			// A zero bool or number is a legitimate value, so "required" is
			// dropped for those kinds.
			if value.Interface() != nil {
				switch reflect.ValueOf(value.Interface()).Kind() {
				case reflect.Bool,
					reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
					reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
					reflect.Float32, reflect.Float64:
					r.Valid = false
				}
			}
		} else if isEmpty(value.Interface()) {
			// Optional fields are only validated when they hold a value.
			continue
		}
		ctx := context.WithValue(stmt.Context, scopeCtxKey, &validateScope{
			DB:          db,
			Column:      row.Column,
			Model:       stmt.Model,
			Domain:      runtimeScope.Domain,
			MultiDomain: multiDomain,
		})
		err := validate.validator.VarCtx(ctx, value.Interface(), validate.buildRules(rules))
		if err == nil {
			continue
		}
		if errs, ok := err.(validator.ValidationErrors); ok {
			for _, e := range errs {
				_ = db.AddError(&StructError{
					Tag:     e.Tag(),
					Column:  row.Column,
					Message: formatError(row.Rule, row, e.Tag()),
				})
			}
		} else {
			_ = db.AddError(err)
		}
		break
	}
}

// New returns a Validate plugin; its validator is created by Initialize.
func New() *Validate {
	return &Validate{}
}