rest/rest.go

754 lines
19 KiB
Go
Raw Normal View History

2024-12-11 17:29:01 +08:00
package rest
import (
"context"
"errors"
"fmt"
"git.nobla.cn/golang/kos/util/arrays"
"git.nobla.cn/golang/kos/util/reflection"
"git.nobla.cn/golang/rest/inflector"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"net/http"
"reflect"
"strconv"
"strings"
"time"
)
var (
	// modelEntities records every model registered through AutoMigrate,
	// in registration order.
	modelEntities = make([]*Model, 0)
	// httpRouter is the package-wide fallback router installed by SetHttpRouter.
	httpRouter types.HttpRouter
	// hookMgr is the shared hook manager that every registered model uses.
	hookMgr = &hookManager{}
	// NOTE(review): Kind() of time.Time is reflect.Struct and Kind() of
	// *time.Time is reflect.Ptr, so these two values do NOT identify time
	// types; compare reflect.Type values instead when time detection is needed.
	timeKind    = reflect.TypeOf(time.Time{}).Kind()
	timePtrKind = reflect.TypeOf(&time.Time{}).Kind()
	// matchEnums lists the accepted values of the `match` prop.
	matchEnums = []string{types.MatchExactly, types.MatchFuzzy}
	// allowScenario lists the scenarios a schema may be shown in.
	allowScenario = []string{types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
)
// cloneStmt builds a fresh gorm.Statement bound to db that shares db's
// connection pool and context but starts with an empty clause set, so it can
// be used to parse models without touching db's own statement.
func cloneStmt(db *gorm.DB) *gorm.Statement {
	stmt := &gorm.Statement{DB: db}
	stmt.ConnPool = db.Statement.ConnPool
	stmt.Context = db.Statement.Context
	stmt.Clauses = make(map[string]clause.Clause)
	return stmt
}
// dataTypeOf 推断数据的类型
func dataTypeOf(field *schema.Field) string {
var dataType string
reflectType := field.FieldType
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
if dataType = field.Tag.Get("type"); dataType != "" {
return dataType
}
dataValue := reflect.Indirect(reflect.New(reflectType))
switch dataValue.Kind() {
case reflect.Bool:
dataType = types.TypeBoolean
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
dataType = types.TypeInteger
case reflect.Float32, reflect.Float64:
dataType = types.TypeFloat
default:
dataType = types.TypeString
}
return dataType
}
// dataFormatOf 推断数据的格式
func dataFormatOf(field *schema.Field) string {
var format string
format = field.Tag.Get("format")
if format != "" {
return format
}
//如果有枚举值,直接设置为下拉类型
enum := field.Tag.Get("enum")
if enum != "" {
return types.FormatDropdown
}
reflectType := field.FieldType
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
//时间处理
dataValue := reflect.Indirect(reflect.New(reflectType))
if field.Name == "CreatedAt" || field.Name == "UpdatedAt" || field.Name == "DeletedAt" {
switch dataValue.Kind() {
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return types.FormatTimestamp
default:
return types.FormatDatetime
}
}
if strings.Contains(strings.ToLower(field.Name), "pass") {
return types.FormatPassword
}
switch dataValue.Kind() {
case timeKind, timePtrKind:
format = types.FormatDatetime
case reflect.Bool:
format = types.FormatBoolean
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
format = types.FormatInteger
case reflect.Float32, reflect.Float64:
format = types.FormatFloat
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
format = types.FormatDatetime
}
default:
if field.Size >= 1024 {
format = types.FormatText
} else {
format = types.FormatString
}
}
return format
}
// fieldName 生成字段名称
func fieldName(name string) string {
tokens := strings.Split(name, "_")
for i, s := range tokens {
tokens[i] = strings.Title(s)
}
return strings.Join(tokens, " ")
}
// fieldNative reports whether a field is backed by a real database column:
// it returns 0 when the field carries a `virtual` struct tag (any value),
// and 1 otherwise.
func fieldNative(field *schema.Field) uint8 {
	var native uint8 = 1
	if _, virtual := field.Tag.Lookup("virtual"); virtual {
		native = 0
	}
	return native
}
// fieldRule builds the validation rule of a column.
//
// Max defaults to the column size for strings and to the scale for numeric
// columns. The `rule` struct tag is a ";"-separated list of entries:
//   - required / require            required on create and update
//   - required:<s1>,<s2>            required only on the listed scenarios
//     (only create/update are accepted)
//   - unique                        values must be unique
//   - regexp:<pattern>              values must match pattern (taken verbatim)
//
// Keys and scenario names are trimmed and keys are case-insensitive, so
// "required: create, update; unique" is understood. Primary keys are always
// unique, and each scenario appears at most once in Required.
func fieldRule(field *schema.Field) types.Rule {
	r := types.Rule{
		Required: []string{},
	}
	switch field.GORMDataType {
	case schema.String:
		r.Max = field.Size
	case schema.Int, schema.Float, schema.Uint:
		r.Max = field.Scale
	}
	// addRequired appends scenario to r.Required unless it is already present.
	addRequired := func(scenario string) {
		for _, s := range r.Required {
			if s == scenario {
				return
			}
		}
		r.Required = append(r.Required, scenario)
	}
	for _, entry := range strings.Split(field.Tag.Get("rule"), ";") {
		key, value, hasValue := strings.Cut(entry, ":")
		switch strings.ToLower(strings.TrimSpace(key)) {
		case "required", "require":
			if !hasValue {
				addRequired(types.ScenarioCreate)
				addRequired(types.ScenarioUpdate)
				continue
			}
			for _, scenario := range strings.Split(value, ",") {
				scenario = strings.TrimSpace(scenario)
				if arrays.Exists(scenario, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
					addRequired(scenario)
				}
			}
		case "unique":
			r.Unique = true
		case "regexp":
			if hasValue {
				// The pattern is kept verbatim: whitespace may be significant.
				r.Regular = value
			}
		}
	}
	if field.PrimaryKey {
		r.Unique = true
	}
	return r
}
// fieldScenario returns the scenarios (search, list, create, ...) in which a
// column is shown.
//
// An explicit `scenarios` struct tag ("list;view;...") takes precedence. Its
// entries are trimmed and empty entries are dropped, so an empty tag hides
// the column everywhere (nil result). Without the tag the defaults are:
//   - primary key:          list, view, export
//   - CreatedAt, UpdatedAt: list
//   - DeletedAt, Namespace: none (empty, non-nil)
//   - fields with index<10: search, list, create, update, view, export
//   - later fields:         create, update, view, export
func fieldScenario(index int, field *schema.Field) types.Scenarios {
	if v, ok := field.Tag.Lookup("scenarios"); ok {
		var ss types.Scenarios
		for _, s := range strings.Split(v, ";") {
			if s = strings.TrimSpace(s); s != "" {
				ss = append(ss, s)
			}
		}
		return ss
	}
	switch {
	case field.PrimaryKey:
		return []string{types.ScenarioList, types.ScenarioView, types.ScenarioExport}
	case field.Name == "CreatedAt" || field.Name == "UpdatedAt":
		return []string{types.ScenarioList}
	case field.Name == "DeletedAt" || field.Name == "Namespace":
		// Bookkeeping columns are never displayed.
		return []string{}
	case index < 10:
		// The leading fields of a model are the "basic" ones: show them
		// everywhere, including search.
		return []string{types.ScenarioSearch, types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
	default:
		// Later ("advanced") fields only get the simple editing/view scenarios.
		return []string{types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
	}
}
// fieldPosition returns the sort position of a column: the positive integer
// in its `position` struct tag, or i+100 when the tag is missing, not a
// number, or not positive.
func fieldPosition(field *schema.Field, i int) int {
	// A parse failure yields 0, which deliberately falls through to the default.
	if n, _ := strconv.Atoi(field.Tag.Get("position")); n > 0 {
		return n
	}
	return 100 + i
}
// fieldAttribute builds the UI attributes of a column from the struct tags
// of the given gorm field.
//
// Recognised tags, each a ";"-separated list of "key:value" entries:
//   - props:     icon, match (exactly/fuzzy), endofnow/end_of_now, invisible,
//     suffix, tag, tooltip, uploadurl/uploaduri/upload_url/upload_uri,
//     description, readonly (comma-separated create/update)
//   - live:      method, type (dropdown/cascader, anything else -> dropdown),
//     url/uri, columns (comma-separated); a non-empty tag enables live data
//   - dropdown:  created, filterable, autocomplete, default_first (flags)
//   - condition: <column>:<value>[,<value>...]
//   - enum:      <value>:<label>[#color]
//
// Keys, values and list items are trimmed, so "a:1; b:2" behaves like
// "a:1;b:2". CreatedAt/UpdatedAt default to EndOfNow. Fields gorm marks as
// not creatable/updatable are disabled in those scenarios. The gorm column
// comment is used as tooltip only when props did not set one explicitly.
func fieldAttribute(field *schema.Field) types.Attribute {
	attr := types.Attribute{
		Match:        types.MatchFuzzy,
		PrimaryKey:   field.PrimaryKey,
		DefaultValue: field.DefaultValue,
		Readonly:     []string{},
		Disable:      []string{},
		Visible:      make([]types.VisibleCondition, 0),
		Values:       make([]types.EnumValue, 0),
		Live:         types.LiveValue{},
	}
	if field.Name == "CreatedAt" || field.Name == "UpdatedAt" {
		attr.EndOfNow = true
	}
	// eachPair calls fn for every non-blank ";"-separated entry of tag, with
	// key and value trimmed; hasValue reports whether the entry contained ':'.
	eachPair := func(tag string, fn func(key, value string, hasValue bool)) {
		if tag == "" {
			return
		}
		for _, entry := range strings.Split(tag, ";") {
			if strings.TrimSpace(entry) == "" {
				continue
			}
			key, value, hasValue := strings.Cut(entry, ":")
			fn(strings.TrimSpace(key), strings.TrimSpace(value), hasValue)
		}
	}
	// splitList splits a ","-separated list, trimming items and dropping empty ones.
	splitList := func(s string) []string {
		items := make([]string, 0)
		for _, item := range strings.Split(s, ",") {
			if item = strings.TrimSpace(item); item != "" {
				items = append(items, item)
			}
		}
		return items
	}

	// Generic display properties.
	eachPair(field.Tag.Get("props"), func(key, value string, hasValue bool) {
		if !hasValue {
			return
		}
		switch strings.ToLower(key) {
		case "icon":
			attr.Icon = value
		case "match":
			if arrays.Exists(value, matchEnums) {
				attr.Match = value
			}
		case "endofnow", "end_of_now":
			if ok, _ := strconv.ParseBool(value); ok {
				attr.EndOfNow = true
			}
		case "invisible":
			if ok, _ := strconv.ParseBool(value); ok {
				attr.Invisible = true
			}
		case "suffix":
			attr.Suffix = value
		case "tag":
			attr.Tag = value
		case "tooltip":
			attr.Tooltip = value
		case "uploadurl", "uploaduri", "upload_url", "upload_uri":
			attr.UploadUrl = value
		case "description":
			attr.Description = value
		case "readonly":
			for _, scenario := range splitList(value) {
				if arrays.Exists(scenario, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
					attr.Readonly = append(attr.Readonly, scenario)
				}
			}
		}
	})

	// Values fetched live from a remote endpoint.
	if live := field.Tag.Get("live"); live != "" {
		attr.Live.Enable = true
		eachPair(live, func(key, value string, hasValue bool) {
			if !hasValue {
				return
			}
			switch key {
			case "method":
				attr.Live.Method = value
			case "type":
				if value == types.LiveTypeDropdown || value == types.LiveTypeCascader {
					attr.Live.Type = value
				} else {
					attr.Live.Type = types.LiveTypeDropdown
				}
			case "url", "uri":
				attr.Live.Url = value
			case "columns":
				attr.Live.Columns = splitList(value)
			}
		})
	}

	// Dropdown behaviour flags; entries need no value.
	if dropdown := field.Tag.Get("dropdown"); dropdown != "" {
		attr.DropdownOption = &types.DropdownOption{}
		eachPair(dropdown, func(key, _ string, _ bool) {
			switch key {
			case "created":
				attr.DropdownOption.Created = true
			case "filterable":
				attr.DropdownOption.Filterable = true
			case "autocomplete":
				attr.DropdownOption.Autocomplete = true
			case "default_first":
				attr.DropdownOption.DefaultFirst = true
			}
		})
	}

	// Visibility conditions: show this column only when <column> has one of the values.
	eachPair(field.Tag.Get("condition"), func(key, value string, hasValue bool) {
		if !hasValue {
			return
		}
		cond := types.VisibleCondition{
			Column: key,
			Values: make([]any, 0),
		}
		for _, v := range splitList(value) {
			cond.Values = append(cond.Values, v)
		}
		attr.Visible = append(attr.Visible, cond)
	})

	// Enumerated values; an optional "#color" suffix on the label sets the color.
	eachPair(field.Tag.Get("enum"), func(key, value string, hasValue bool) {
		if !hasValue {
			return
		}
		fv := types.EnumValue{Value: key}
		if label, color, found := strings.Cut(value, "#"); found {
			fv.Label = strings.TrimSpace(label)
			fv.Color = "#" + color
		} else {
			fv.Label = value
		}
		attr.Values = append(attr.Values, fv)
	})

	if !field.Creatable {
		attr.Disable = append(attr.Disable, types.ScenarioCreate)
	}
	if !field.Updatable {
		attr.Disable = append(attr.Disable, types.ScenarioUpdate)
	}
	// An explicit props tooltip wins over the column comment.
	if attr.Tooltip == "" {
		attr.Tooltip = field.Comment
	}
	return attr
}
// autoMigrate parses model and inserts a schema row (in the default domain)
// for every column of model that module's table does not have yet.
//
// Existing rows are never modified. New rows are positioned after the
// existing ones and derive their type, format, rule, scenarios and attributes
// from the field's struct tags. It returns the table name gorm resolved for
// model (also when the insert fails) and any parse, lookup or insert error.
func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (naming string, err error) {
	stmt := cloneStmt(db)
	if err = stmt.Parse(model); err != nil {
		return
	}
	var schemas []*types.Schema
	if schemas, err = GetSchemas(ctx, db, defaultDomain, module, stmt.Table); err != nil {
		return
	}
	// existing tracks columns already stored, plus the ones queued below, so a
	// column is never inserted twice.
	existing := make(map[string]struct{}, len(schemas))
	for _, sm := range schemas {
		existing[sm.Column] = struct{}{}
	}
	pos := len(schemas)
	models := make([]*types.Schema, 0, len(stmt.Schema.Fields))
	for index, field := range stmt.Schema.Fields {
		columnName := field.DBName
		if columnName == "-" {
			continue
		}
		// Fields without a database column (e.g. virtual ones) use the Go name.
		if columnName == "" {
			columnName = field.Name
		}
		if _, ok := existing[columnName]; ok {
			continue
		}
		columnLabel := field.Tag.Get("comment")
		if columnLabel == "" {
			// Derive from columnName, not DBName, so columns without a database
			// name still get a non-empty label.
			columnLabel = fieldName(columnName)
		}
		var isPrimaryKey uint8
		if field.PrimaryKey {
			isPrimaryKey = 1
		}
		schemaModel := &types.Schema{
			Domain:       defaultDomain,
			ModuleName:   module,
			TableName:    stmt.Table,
			Enable:       1,
			Column:       columnName,
			Label:        columnLabel,
			Type:         strings.ToLower(dataTypeOf(field)),
			Format:       strings.ToLower(dataFormatOf(field)),
			Native:       fieldNative(field),
			IsPrimaryKey: isPrimaryKey,
			Rule:         fieldRule(field),
			Scenarios:    fieldScenario(index, field),
			Attribute:    fieldAttribute(field),
			Position:     fieldPosition(field, pos),
		}
		// Live (remotely loaded) fields are rendered with the live widget type.
		if schemaModel.Attribute.Live.Enable && schemaModel.Attribute.Live.Type != "" {
			schemaModel.Format = schemaModel.Attribute.Live.Type
		}
		models = append(models, schemaModel)
		existing[columnName] = struct{}{}
		pos++
	}
	if len(models) > 0 {
		tx := db
		if ctx != nil {
			tx = db.WithContext(ctx)
		}
		err = tx.Create(models).Error
	}
	naming = stmt.Table
	return
}
// SetHttpRouter installs the package-wide fallback router. AutoMigrate binds
// a model's routes to it when no router is given through the options.
func SetHttpRouter(router types.HttpRouter) {
	httpRouter = router
}
// AutoMigrate registers model with the package.
//
// It first stores schema rows for any of model's columns that are not yet
// known (see autoMigrate). It then builds the REST model, applies every
// option (URL prefix, response writer, formatter, domain handling) and, when
// a router is available (options first, then the SetHttpRouter fallback),
// binds the HTTP handlers for each enabled scenario. Finally it records the
// model so GetModels returns it. An error from the schema migration aborts
// registration.
func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) {
	opts := &Options{}
	for _, cb := range cbs {
		cb(opts)
	}
	var table string
	if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil {
		return
	}
	modelValue := newModel(model, db, types.Naming{
		Pluralize:  inflector.Pluralize(table),
		Singular:   inflector.Singularize(table),
		ModuleName: opts.moduleName,
		TableName:  table,
	})
	modelValue.hookMgr = hookMgr
	modelValue.schemaLookup = VisibleSchemas
	// Apply every option before any route is bound, so the model is never
	// exposed in a half-configured state.
	if opts.urlPrefix != "" {
		modelValue.urlPrefix = opts.urlPrefix
	}
	if opts.writer != nil {
		modelValue.response = opts.writer
	}
	if opts.formatter != nil {
		modelValue.formatter = opts.formatter
	}
	modelValue.disableDomain = opts.disableDomain

	var router types.HttpRouter
	if opts.router != nil {
		router = opts.router
	}
	if router == nil && httpRouter != nil {
		router = httpRouter
	}
	if router != nil {
		if modelValue.hasScenario(types.ScenarioList) {
			router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioList), modelValue.Search)
		}
		if modelValue.hasScenario(types.ScenarioCreate) {
			router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioCreate), modelValue.Create)
		}
		if modelValue.hasScenario(types.ScenarioUpdate) {
			router.Handle(http.MethodPut, modelValue.Uri(types.ScenarioUpdate), modelValue.Update)
		}
		if modelValue.hasScenario(types.ScenarioDelete) {
			router.Handle(http.MethodDelete, modelValue.Uri(types.ScenarioDelete), modelValue.Delete)
		}
		if modelValue.hasScenario(types.ScenarioView) {
			router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioView), modelValue.View)
		}
		if modelValue.hasScenario(types.ScenarioExport) {
			router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioExport), modelValue.Export)
		}
		if modelValue.hasScenario(types.ScenarioImport) {
			// Import is reachable via GET and POST on the same URI.
			router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioImport), modelValue.Import)
			router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioImport), modelValue.Import)
		}
	}
	modelEntities = append(modelEntities, modelValue)
	return
}
// CloneSchemas copies every schema row of the default domain into domain,
// skipping rows whose (module, table, column) already exists there.
//
// Copied rows get a zero id and zeroed timestamps so they are inserted as new
// records. Errors from loading either domain's schemas are returned wrapped,
// so a failed lookup never makes every default row look "missing" and
// duplicates them.
func CloneSchemas(ctx context.Context, db *gorm.DB, domain string) (err error) {
	var (
		values  []*types.Schema
		schemas []*types.Schema
	)
	tx := db.WithContext(ctx)
	if err = tx.Where("domain=?", defaultDomain).Find(&values).Error; err != nil {
		return fmt.Errorf("schema not found: %w", err)
	}
	if err = tx.Where("domain=?", domain).Find(&schemas).Error; err != nil {
		return fmt.Errorf("load schemas of domain %q: %w", domain, err)
	}
	type schemaKey struct{ module, table, column string }
	existing := make(map[schemaKey]struct{}, len(schemas))
	for _, row := range schemas {
		existing[schemaKey{row.ModuleName, row.TableName, row.Column}] = struct{}{}
	}
	models := make([]*types.Schema, 0, len(values))
	for _, row := range values {
		if _, ok := existing[schemaKey{row.ModuleName, row.TableName, row.Column}]; ok {
			continue
		}
		row.Id = 0
		row.CreatedAt = 0
		row.UpdatedAt = 0
		row.Domain = domain
		models = append(models, row)
	}
	if len(models) > 0 {
		err = tx.Save(models).Error
	}
	return
}
// GetSchemas 获取schemas
func GetSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName string) ([]*types.Schema, error) {
var (
err error
values []*types.Schema
tx *gorm.DB
)
values = make([]*types.Schema, 0)
if domain == "" {
domain = defaultDomain
}
if moduleName == "" || tableName == "" {
return nil, gorm.ErrInvalidField
}
if ctx != nil {
tx = db.WithContext(ctx)
} else {
tx = db
}
err = tx.Where("domain=? AND module_name=? AND table_name=?", domain, moduleName, tableName).Order("position ASC").Find(&values).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
err = nil
}
return values, err
}
// VisibleSchemas returns the schema rows of a table that are shown in the
// given scenario. Primary-key columns are always included; any other column
// is included only when its Scenarios contain scenario.
func VisibleSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName, scenario string) ([]*types.Schema, error) {
	all, err := GetSchemas(ctx, db, domain, moduleName, tableName)
	if err != nil {
		return nil, err
	}
	visible := make([]*types.Schema, 0, len(all))
	for _, s := range all {
		if s.IsPrimaryKey == 1 || s.Scenarios.Has(scenario) {
			visible = append(visible, s)
		}
	}
	return visible, nil
}
// ModelTypes lists label/value pairs of model within domainName, reading the
// labelColumn and valueColumn columns of each row.
//
// It is best-effort: a failing query is not reported and simply yields an
// empty list. A column missing from a scanned row leaves that side nil.
func ModelTypes(ctx context.Context, db *gorm.DB, model any, domainName, labelColumn, valueColumn string) (values []*types.TypeValue) {
	rows := make([]map[string]any, 0, 10)
	db.WithContext(ctx).Model(model).Select(labelColumn, valueColumn).Where("domain=?", domainName).Scan(&rows)
	values = make([]*types.TypeValue, 0, len(rows))
	for _, row := range rows {
		values = append(values, &types.TypeValue{
			Label: row[labelColumn],
			Value: row[valueColumn],
		})
	}
	return values
}
// modelFieldByColumn resolves the struct field of the model value refValue
// that backs column. The column is matched against both the database column
// name and the Go field name of stmt.Schema.Fields.
//
// The lookup walks field.StructField.Index rather than using the field's
// position inside stmt.Schema.Fields. gorm skips unexported fields and
// flattens embedded structs into Fields, so that position does not match the
// struct layout. gorm marks a pointer-embedded struct with a negative index
// (-i-1). A nil embedded pointer is allocated when alloc is true and the
// value is settable; otherwise the lookup reports ok == false.
func modelFieldByColumn(stmt *gorm.Statement, refValue reflect.Value, column string, alloc bool) (reflect.Value, bool) {
	if stmt == nil || stmt.Schema == nil {
		return reflect.Value{}, false
	}
	var target *schema.Field
	for _, field := range stmt.Schema.Fields {
		if field.DBName == column || field.Name == column {
			target = field
			break
		}
	}
	v := reflect.Indirect(refValue)
	if target == nil || len(target.StructField.Index) == 0 || v.Kind() != reflect.Struct {
		return reflect.Value{}, false
	}
	last := len(target.StructField.Index) - 1
	for i, idx := range target.StructField.Index {
		if idx < 0 {
			idx = -idx - 1
		}
		if v.Kind() != reflect.Struct || idx >= v.NumField() {
			return reflect.Value{}, false
		}
		v = v.Field(idx)
		if i == last {
			break
		}
		// Step through embedded pointers on the way to the target field.
		if v.Kind() == reflect.Ptr {
			if v.IsNil() {
				if !alloc || !v.CanSet() {
					return reflect.Value{}, false
				}
				v.Set(reflect.New(v.Type().Elem()))
			}
			v = v.Elem()
		}
	}
	return v, true
}

// GetFieldValue returns the value of the model field backing column (matched
// by database column name or Go field name), or nil when no such field
// exists or its value cannot be read.
func GetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string) interface{} {
	v, ok := modelFieldByColumn(stmt, refValue, column, false)
	if !ok || !v.CanInterface() {
		return nil
	}
	return v.Interface()
}

// SetFieldValue sets the model field backing column to value. Unknown
// columns are ignored. As with reflect.Value.Set, it panics when refValue is
// not addressable or value's type is not assignable to the field; use
// SafeSetFileValue for a converting, non-panicking variant.
func SetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) {
	if v, ok := modelFieldByColumn(stmt, refValue, column, true); ok {
		v.Set(reflect.ValueOf(value))
	}
}

// SafeSetFileValue sets the model field backing column to value, converting
// it via reflection.Assign where needed. (The name keeps its historical
// "File" spelling; it operates on fields.) Unknown columns are ignored.
func SafeSetFileValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) {
	if v, ok := modelFieldByColumn(stmt, refValue, column, true); ok {
		// Best effort by design: assignment/conversion errors are ignored.
		_ = reflection.Assign(v, value)
	}
}
// GetModels returns every model registered through AutoMigrate, in
// registration order.
//
// The returned slice is a copy: callers may reorder or append to it without
// corrupting the package registry. The *Model values themselves are shared.
func GetModels() []*Model {
	models := make([]*Model, len(modelEntities))
	copy(models, modelEntities)
	return models
}
// OnBeforeCreate registers cb with the shared hook manager as a
// before-create callback.
func OnBeforeCreate(cb BeforeCreate) {
	hookMgr.BeforeCreate(cb)
}

// OnAfterCreate registers cb with the shared hook manager as an
// after-create callback.
func OnAfterCreate(cb AfterCreate) {
	hookMgr.AfterCreate(cb)
}

// OnBeforeUpdate registers cb with the shared hook manager as a
// before-update callback.
func OnBeforeUpdate(cb BeforeUpdate) {
	hookMgr.BeforeUpdate(cb)
}

// OnAfterUpdate registers cb with the shared hook manager as an
// after-update callback.
func OnAfterUpdate(cb AfterUpdate) {
	hookMgr.AfterUpdate(cb)
}

// OnBeforeSave registers cb with the shared hook manager as a
// before-save callback.
func OnBeforeSave(cb BeforeSave) {
	hookMgr.BeforeSave(cb)
}

// OnAfterSave registers cb with the shared hook manager as an
// after-save callback.
func OnAfterSave(cb AfterSave) {
	hookMgr.AfterSave(cb)
}

// OnBeforeDelete registers cb with the shared hook manager as a
// before-delete callback.
func OnBeforeDelete(cb BeforeDelete) {
	hookMgr.BeforeDelete(cb)
}

// OnAfterDelete registers cb with the shared hook manager as an
// after-delete callback.
func OnAfterDelete(cb AfterDelete) {
	hookMgr.AfterDelete(cb)
}

// OnAfterExport registers cb with the shared hook manager as an
// after-export callback.
func OnAfterExport(cb AfterExport) {
	hookMgr.AfterExport(cb)
}

// OnAfterImport registers cb with the shared hook manager as an
// after-import callback.
func OnAfterImport(cb AfterImport) {
	hookMgr.AfterImport(cb)
}