package rest import ( "context" "errors" "fmt" "git.nobla.cn/golang/kos/util/arrays" "git.nobla.cn/golang/kos/util/reflection" "git.nobla.cn/golang/rest/inflector" "git.nobla.cn/golang/rest/types" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "net/http" "reflect" "strconv" "strings" "time" ) var ( modelEntities []*Model httpRouter types.HttpRouter hookMgr *hookManager timeKind = reflect.TypeOf(time.Time{}).Kind() timePtrKind = reflect.TypeOf(&time.Time{}).Kind() matchEnums = []string{types.MatchExactly, types.MatchFuzzy} ) var ( allowScenario = []string{types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} ) func init() { hookMgr = &hookManager{} modelEntities = make([]*Model, 0) } // cloneStmt 从指定的db克隆一个 Statement 对象 func cloneStmt(db *gorm.DB) *gorm.Statement { return &gorm.Statement{ DB: db, ConnPool: db.Statement.ConnPool, Context: db.Statement.Context, Clauses: map[string]clause.Clause{}, } } // dataTypeOf 推断数据的类型 func dataTypeOf(field *schema.Field) string { var dataType string reflectType := field.FieldType for reflectType.Kind() == reflect.Ptr { reflectType = reflectType.Elem() } if dataType = field.Tag.Get("type"); dataType != "" { return dataType } dataValue := reflect.Indirect(reflect.New(reflectType)) switch dataValue.Kind() { case reflect.Bool: dataType = types.TypeBoolean case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: dataType = types.TypeInteger case reflect.Float32, reflect.Float64: dataType = types.TypeFloat default: dataType = types.TypeString } return dataType } // dataFormatOf 推断数据的格式 func dataFormatOf(field *schema.Field) string { var format string format = field.Tag.Get("format") if format != "" { return format } //如果有枚举值,直接设置为下拉类型 enum := field.Tag.Get("enum") if enum != "" { return types.FormatDropdown } reflectType := field.FieldType for reflectType.Kind() == reflect.Ptr { reflectType = reflectType.Elem() } //时间处理 dataValue := reflect.Indirect(reflect.New(reflectType)) if field.Name == "CreatedAt" || field.Name == "UpdatedAt" || field.Name == "DeletedAt" { switch dataValue.Kind() { case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return types.FormatTimestamp default: return types.FormatDatetime } } if strings.Contains(strings.ToLower(field.Name), "pass") { return types.FormatPassword } switch dataValue.Kind() { case timeKind, timePtrKind: format = types.FormatDatetime case reflect.Bool: format = types.FormatBoolean case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: format = types.FormatInteger case reflect.Float32, reflect.Float64: format = types.FormatFloat case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { format = types.FormatDatetime } default: if field.Size >= 1024 { format = types.FormatText } else { format = types.FormatString } } return format } // fieldName 生成字段名称 func fieldName(name string) string { tokens := strings.Split(name, "_") for i, s := range tokens { tokens[i] = strings.Title(s) } return strings.Join(tokens, " ") } // fieldNative 判断是否为原始字段 func fieldNative(field *schema.Field) uint8 { if _, ok := field.Tag.Lookup("virtual"); ok { return 0 } return 1 } // fieldRule 返回字段规则 func fieldRule(field *schema.Field) types.Rule { r := types.Rule{ Required: []string{}, } if field.GORMDataType == schema.String { r.Max = field.Size } if field.GORMDataType == schema.Int || field.GORMDataType == schema.Float || field.GORMDataType == schema.Uint { r.Max = field.Scale } rs := field.Tag.Get("rule") if rs != "" { ss := strings.Split(rs, ";") for _, s := range ss { vs := strings.SplitN(s, ":", 2) ls := len(vs) if ls == 0 { continue } switch vs[0] { case "required", "require": if ls > 1 { bs := strings.Split(vs[1], ",") for _, i := range bs { if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) { r.Required = append(r.Required, i) } } } else { r.Required = []string{types.ScenarioCreate, types.ScenarioUpdate} } case "unique": r.Unique = true case "regexp": if ls > 1 { r.Regular = vs[1] } } } } if field.PrimaryKey { r.Unique = true } return r } // fieldScenario 字段Scenarios func fieldScenario(index int, field *schema.Field) types.Scenarios { var ss types.Scenarios if v, ok := field.Tag.Lookup("scenarios"); ok { v = strings.TrimSpace(v) if v != "" { ss = strings.Split(v, ";") } } else { if field.PrimaryKey { ss = []string{types.ScenarioList, types.ScenarioView, types.ScenarioExport} } else if field.Name == "CreatedAt" || field.Name == "UpdatedAt" { ss = []string{types.ScenarioList} } else if field.Name == "DeletedAt" || field.Name == "Namespace" { //不添加任何显示场景 ss = []string{} } else { if index < 10 { //高级字段只配置一些简单的场景 ss = []string{types.ScenarioSearch, types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} } else { //高级字段只配置一些简单的场景 ss = []string{types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} } } } return ss } // fieldPosition 字段的排序位置 func fieldPosition(field *schema.Field, i int) int { s := field.Tag.Get("position") n, _ := strconv.Atoi(s) if n > 0 { return n } return i + 100 } // fieldAttribute 字段属性 func fieldAttribute(field *schema.Field) types.Attribute { attr := types.Attribute{ Match: types.MatchFuzzy, PrimaryKey: field.PrimaryKey, DefaultValue: field.DefaultValue, Readonly: []string{}, Disable: []string{}, Visible: make([]types.VisibleCondition, 0), Values: make([]types.EnumValue, 0), Live: types.LiveValue{}, } if field.Name == "CreatedAt" || field.Name == "UpdatedAt" { attr.EndOfNow = true } //赋值属性 props := field.Tag.Get("props") if props != "" { vs := strings.Split(props, ";") for _, str := range vs { kv := strings.SplitN(str, ":", 2) if len(kv) != 2 { continue } sv := strings.TrimSpace(kv[1]) switch strings.ToLower(strings.TrimSpace(kv[0])) { case "icon": attr.Icon = sv case "match": if arrays.Exists(sv, matchEnums) { attr.Match = sv } case "endofnow", "end_of_now": if ok, _ := strconv.ParseBool(sv); ok { attr.EndOfNow = true } case "invisible": if ok, _ := strconv.ParseBool(sv); ok { attr.Invisible = true } case "suffix": attr.Suffix = sv case "tag": attr.Tag = sv case "tooltip": attr.Tooltip = sv case "uploadurl", "uploaduri", "upload_url", "upload_uri": attr.UploadUrl = sv case "description": attr.Description = sv case "readonly": bs := strings.Split(sv, ",") for _, i := range bs { if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) { attr.Readonly = append(attr.Readonly, i) } } } } } //live的赋值 live := field.Tag.Get("live") if live != "" { attr.Live.Enable = true vs := strings.Split(live, ";") for _, str := range vs { kv := strings.SplitN(str, ":", 2) if len(kv) != 2 { continue } switch kv[0] { case "method": attr.Live.Method = kv[1] case "type": if kv[1] == types.LiveTypeDropdown || kv[1] == types.LiveTypeCascader { attr.Live.Type = kv[1] } else { attr.Live.Type = types.LiveTypeDropdown } case "url", "uri": attr.Live.Url = kv[1] case "columns": attr.Live.Columns = strings.Split(kv[1], ",") } } } dropdown := field.Tag.Get("dropdown") if dropdown != "" { attr.DropdownOption = &types.DropdownOption{} vs := strings.Split(dropdown, ";") for _, str := range vs { kv := strings.SplitN(str, ":", 2) if len(kv) == 0 { continue } switch kv[0] { case "created": attr.DropdownOption.Created = true case "filterable": attr.DropdownOption.Filterable = true case "autocomplete": attr.DropdownOption.Autocomplete = true case "default_first": attr.DropdownOption.DefaultFirst = true } } } //显示条件 conditions := field.Tag.Get("condition") if conditions != "" { vs := strings.Split(conditions, ";") for _, str := range vs { kv := strings.SplitN(str, ":", 2) if len(kv) != 2 { continue } cond := types.VisibleCondition{ Column: kv[0], Values: make([]any, 0), } vv := strings.Split(kv[1], ",") for _, x := range vv { x = strings.TrimSpace(x) if x == "" { continue } cond.Values = append(cond.Values, x) } attr.Visible = append(attr.Visible, cond) } } //赋值枚举值 enumns := field.Tag.Get("enum") if enumns != "" { vs := strings.Split(enumns, ";") for _, str := range vs { kv := strings.SplitN(str, ":", 2) if len(kv) != 2 { continue } fv := types.EnumValue{Value: kv[0]} //颜色分隔符 if pos := strings.IndexByte(kv[1], '#'); pos > -1 { fv.Label = kv[1][:pos] fv.Color = kv[1][pos:] } else { fv.Label = kv[1] } attr.Values = append(attr.Values, fv) } } if !field.Creatable { attr.Disable = append(attr.Disable, types.ScenarioCreate) } if !field.Updatable { attr.Disable = append(attr.Disable, types.ScenarioUpdate) } attr.Tooltip = field.Comment return attr } // autoMigrate 自动合并字段 func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (naming string, err error) { var ( pos int columnName string columnIsExists bool columnLabel string schemas []*types.Schema models []*types.Schema stmt *gorm.Statement ) stmt = cloneStmt(db) if err = stmt.Parse(model); err != nil { return } if schemas, err = GetSchemas(ctx, db, defaultDomain, module, stmt.Table); err != nil { return } if len(schemas) > 0 { pos = len(schemas) } models = make([]*types.Schema, 0) for index, field := range stmt.Schema.Fields { columnName = field.DBName if columnName == "-" { continue } if columnName == "" { columnName = field.Name } columnIsExists = false for _, sm := range schemas { if sm.Column == columnName { columnIsExists = true break } } if columnIsExists { continue } columnLabel = field.Tag.Get("comment") if columnLabel == "" { columnLabel = fieldName(field.DBName) } isPrimaryKey := uint8(0) if field.PrimaryKey { isPrimaryKey = 1 } schemaModel := &types.Schema{ Domain: defaultDomain, ModuleName: module, TableName: stmt.Table, Enable: 1, Column: columnName, Label: columnLabel, Type: strings.ToLower(dataTypeOf(field)), Format: strings.ToLower(dataFormatOf(field)), Native: fieldNative(field), IsPrimaryKey: isPrimaryKey, Rule: fieldRule(field), Scenarios: fieldScenario(index, field), Attribute: fieldAttribute(field), Position: fieldPosition(field, pos), } //如果启用了在线调取接口功能,那么设置一下字段的format格式 if schemaModel.Attribute.Live.Enable { if schemaModel.Attribute.Live.Type != "" { schemaModel.Format = schemaModel.Attribute.Live.Type } } models = append(models, schemaModel) pos++ } if len(models) > 0 { err = db.Create(models).Error } naming = stmt.Table return } // SetHttpRouter 设置HTTP路由 func SetHttpRouter(router types.HttpRouter) { httpRouter = router } // AutoMigrate 自动合并表的schema func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) { var ( opts *Options table string router types.HttpRouter ) opts = &Options{} for _, cb := range cbs { cb(opts) } if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil { return } //路由模块处理 modelValue := newModel(model, db, types.Naming{ Pluralize: inflector.Pluralize(table), Singular: inflector.Singularize(table), ModuleName: opts.moduleName, TableName: table, }) modelValue.hookMgr = hookMgr modelValue.schemaLookup = VisibleSchemas if opts.router != nil { router = opts.router } if router == nil && httpRouter != nil { router = httpRouter } if opts.urlPrefix != "" { modelValue.urlPrefix = opts.urlPrefix } //路由绑定操作 if router != nil { if modelValue.hasScenario(types.ScenarioList) { router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioList), modelValue.Search) } if modelValue.hasScenario(types.ScenarioCreate) { router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioCreate), modelValue.Create) } if modelValue.hasScenario(types.ScenarioUpdate) { router.Handle(http.MethodPut, modelValue.Uri(types.ScenarioUpdate), modelValue.Update) } if modelValue.hasScenario(types.ScenarioDelete) { router.Handle(http.MethodDelete, modelValue.Uri(types.ScenarioDelete), modelValue.Delete) } if modelValue.hasScenario(types.ScenarioView) { router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioView), modelValue.View) } if modelValue.hasScenario(types.ScenarioExport) { router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioExport), modelValue.Export) } if modelValue.hasScenario(types.ScenarioImport) { router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioImport), modelValue.Import) router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioImport), modelValue.Import) } } if opts.writer != nil { modelValue.response = opts.writer } if opts.formatter != nil { modelValue.formatter = opts.formatter } modelValue.disableDomain = opts.disableDomain modelEntities = append(modelEntities, modelValue) return } // CloneSchemas 克隆schemas func CloneSchemas(ctx context.Context, db *gorm.DB, domain string) (err error) { var ( values []*types.Schema schemas []*types.Schema models []*types.Schema ) tx := db.WithContext(ctx) if domain == "" { domain = defaultDomain } if err = tx.Where("domain=?", defaultDomain).Find(&values).Error; err != nil { return fmt.Errorf("schema not found") } tx.Where("domain=?", domain).Find(&schemas) hasSchemaFunc := func(values []*types.Schema, hack *types.Schema) bool { for _, row := range values { if row.ModuleName == hack.ModuleName && row.TableName == hack.TableName && row.Column == hack.Column { return true } } return false } models = make([]*types.Schema, 0) for _, row := range values { if !hasSchemaFunc(schemas, row) { row.Id = 0 row.CreatedAt = 0 row.UpdatedAt = 0 row.Domain = domain models = append(models, row) } } if len(models) > 0 { err = tx.Save(models).Error } return } // GetSchemas 获取schemas func GetSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName string) ([]*types.Schema, error) { var ( err error values []*types.Schema tx *gorm.DB ) values = make([]*types.Schema, 0) if domain == "" { domain = defaultDomain } if moduleName == "" || tableName == "" { return nil, gorm.ErrInvalidField } if ctx != nil { tx = db.WithContext(ctx) } else { tx = db } err = tx.Where("domain=? AND module_name=? AND table_name=?", domain, moduleName, tableName).Order("position ASC").Find(&values).Error if errors.Is(err, gorm.ErrRecordNotFound) { err = nil } return values, err } // VisibleSchemas 获取某个场景下面的schema func VisibleSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName, scenario string) ([]*types.Schema, error) { if domain == "" { domain = defaultDomain } schemas, err := GetSchemas(ctx, db, domain, moduleName, tableName) if err != nil { return nil, err } result := make([]*types.Schema, 0, len(schemas)) for _, row := range schemas { if row.IsPrimaryKey == 1 { result = append(result, row) continue } if row.Scenarios.Has(scenario) { result = append(result, row) } } return result, nil } // ModelTypes 查询指定模型的类型 func ModelTypes(ctx context.Context, db *gorm.DB, model any, domainName, labelColumn, valueColumn string) (values []*types.TypeValue) { tx := db.WithContext(ctx) result := make([]map[string]any, 0, 10) if domainName == "" { domainName = defaultDomain } tx.Model(model).Select(labelColumn, valueColumn).Where("domain=?", domainName).Scan(&result) values = make([]*types.TypeValue, 0, len(result)) for _, pairs := range result { feed := &types.TypeValue{} for k, v := range pairs { if k == labelColumn { feed.Label = v } if k == valueColumn { feed.Value = v } } values = append(values, feed) } return values } // GetFieldValue 获取模型某个字段的值 func GetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string) interface{} { var ( idx = -1 ) refVal := reflect.Indirect(refValue) for i, field := range stmt.Schema.Fields { if field.DBName == column || field.Name == column { idx = i break } } if idx > -1 { return refVal.Field(idx).Interface() } return nil } // SetFieldValue 设置模型某个字段的值 func SetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) { var ( idx = -1 ) refVal := reflect.Indirect(refValue) for i, field := range stmt.Schema.Fields { if field.DBName == column || field.Name == column { idx = i break } } if idx > -1 { refVal.Field(idx).Set(reflect.ValueOf(value)) } } func SafeSetFileValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) { var ( idx = -1 ) refVal := reflect.Indirect(refValue) for i, field := range stmt.Schema.Fields { if field.DBName == column || field.Name == column { idx = i break } } if idx > -1 { _ = reflection.Assign(refVal.Field(idx), value) } } // GetModels 获取所有注册了的模块 func GetModels() []*Model { return modelEntities } // OnBeforeCreate 创建前的回调 func OnBeforeCreate(cb BeforeCreate) { hookMgr.BeforeCreate(cb) } // OnAfterCreate 创建后的回调 func OnAfterCreate(cb AfterCreate) { hookMgr.AfterCreate(cb) } // OnBeforeUpdate 更新前的回调 func OnBeforeUpdate(cb BeforeUpdate) { hookMgr.BeforeUpdate(cb) } // OnAfterUpdate 更新后的回调 func OnAfterUpdate(cb AfterUpdate) { hookMgr.AfterUpdate(cb) } // OnBeforeSave 保存前的回调 func OnBeforeSave(cb BeforeSave) { hookMgr.BeforeSave(cb) } // OnAfterSave 保存后的回调 func OnAfterSave(cb AfterSave) { hookMgr.AfterSave(cb) } // OnBeforeDelete 删除前的回调 func OnBeforeDelete(cb BeforeDelete) { hookMgr.BeforeDelete(cb) } // OnAfterDelete 删除后的回调 func OnAfterDelete(cb AfterDelete) { hookMgr.AfterDelete(cb) } // OnAfterExport 导出后的回调 func OnAfterExport(cb AfterExport) { hookMgr.AfterExport(cb) } // OnAfterImport 导入后的回调 func OnAfterImport(cb AfterImport) { hookMgr.AfterImport(cb) }