rest/rest.go

754 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package rest
import (
"context"
"errors"
"fmt"
"git.nobla.cn/golang/kos/util/arrays"
"git.nobla.cn/golang/kos/util/reflection"
"git.nobla.cn/golang/rest/inflector"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"net/http"
"reflect"
"strconv"
"strings"
"time"
)
// Package-level state shared by the helpers in this file.
//
// hookMgr and modelEntities are initialized in their declarations rather than
// in init(): package variables are initialized before any init() function
// runs, so code elsewhere in the package (other init functions or variable
// initializers) can never observe a nil hook manager, and nothing appended to
// the registry early is wiped by a later init().
var (
	// modelEntities is the registry of Models created by AutoMigrate.
	modelEntities = make([]*Model, 0)
	// httpRouter is the fallback router installed by SetHttpRouter.
	httpRouter types.HttpRouter
	// hookMgr receives the callbacks registered through the On* functions and
	// is attached to every Model created by AutoMigrate.
	hookMgr = &hookManager{}
	// timeKind evaluates to reflect.Struct and timePtrKind to reflect.Ptr;
	// neither identifies a time value on its own.
	timeKind    = reflect.TypeOf(time.Time{}).Kind()
	timePtrKind = reflect.TypeOf(&time.Time{}).Kind()
	// matchEnums lists the accepted values of the "match" prop.
	matchEnums = []string{types.MatchExactly, types.MatchFuzzy}
)

var (
	// allowScenario lists the list/create/update/view/export scenarios.
	// NOTE(review): not referenced by the code visible here — presumably used
	// elsewhere in the package.
	allowScenario = []string{types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
)
// cloneStmt builds a fresh gorm.Statement bound to db. The new statement
// shares db's connection pool and context but starts with an empty clause set,
// so it can be used to parse a model without touching db's own statement.
func cloneStmt(db *gorm.DB) *gorm.Statement {
	src := db.Statement
	stmt := &gorm.Statement{DB: db}
	stmt.ConnPool = src.ConnPool
	stmt.Context = src.Context
	stmt.Clauses = make(map[string]clause.Clause)
	return stmt
}
// dataTypeOf infers the logical data type of field.
//
// An explicit `type` struct tag is returned verbatim. Otherwise the Go kind of
// the field (with all pointer levels stripped) decides:
//   - bool                              → types.TypeBoolean
//   - any signed/unsigned integer kind  → types.TypeInteger
//   - float32 / float64                 → types.TypeFloat
//   - everything else (strings, structs, slices, ...) → types.TypeString
func dataTypeOf(field *schema.Field) string {
	t := field.FieldType
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if tagged := field.Tag.Get("type"); tagged != "" {
		return tagged
	}
	switch t.Kind() {
	case reflect.Bool:
		return types.TypeBoolean
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return types.TypeInteger
	case reflect.Float32, reflect.Float64:
		return types.TypeFloat
	}
	return types.TypeString
}
// dataFormatOf infers the presentation format of a model field.
//
// The first matching rule wins:
//   - an explicit `format` tag is returned verbatim;
//   - a field with an `enum` tag is rendered as a dropdown;
//   - CreatedAt/UpdatedAt/DeletedAt are timestamps when integer-backed and
//     datetimes otherwise;
//   - a field whose name contains "pass" (case-insensitive) is a password;
//   - otherwise the Go kind (pointers stripped) decides: bool, integer and
//     float map to their formats; a struct maps to datetime only when it is
//     time.Time or gorm itself classifies it as schema.Time; anything else is
//     text when field.Size >= 1024 and string otherwise.
func dataFormatOf(field *schema.Field) string {
	if format := field.Tag.Get("format"); format != "" {
		return format
	}
	// Fields with enumerated values are always shown as a dropdown.
	if field.Tag.Get("enum") != "" {
		return types.FormatDropdown
	}
	reflectType := field.FieldType
	for reflectType.Kind() == reflect.Ptr {
		reflectType = reflectType.Elem()
	}
	kind := reflectType.Kind()

	// Bookkeeping time columns: integer-backed ones are treated as timestamps.
	if field.Name == "CreatedAt" || field.Name == "UpdatedAt" || field.Name == "DeletedAt" {
		switch kind {
		case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			return types.FormatTimestamp
		default:
			return types.FormatDatetime
		}
	}
	if strings.Contains(strings.ToLower(field.Name), "pass") {
		return types.FormatPassword
	}
	switch kind {
	case reflect.Bool:
		return types.FormatBoolean
	case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return types.FormatInteger
	case reflect.Float32, reflect.Float64:
		return types.FormatFloat
	case reflect.Struct:
		// The kind alone (reflect.Struct) does not identify a time value: check
		// the concrete type, and also trust gorm's own time detection.
		if reflectType == reflect.TypeOf(time.Time{}) || field.GORMDataType == schema.Time {
			return types.FormatDatetime
		}
	}
	// Non-time structs and every other kind (strings, slices, maps, ...).
	if field.Size >= 1024 {
		return types.FormatText
	}
	return types.FormatString
}
// fieldName derives a human-readable label from a snake_case column name.
//
// The name is split on underscores. Empty segments (from leading, trailing or
// doubled underscores) are dropped. The first letter of each segment is
// title-cased, and the segments are joined with single spaces. For example,
// "created_at" becomes "Created At" and "_user__name" becomes "User Name".
//
// The deprecated strings.Title is avoided: only the first rune of every
// segment is mapped with strings.ToTitle.
func fieldName(name string) string {
	tokens := strings.FieldsFunc(name, func(r rune) bool { return r == '_' })
	for i, token := range tokens {
		// end is the byte offset where the second rune starts (or len(token)
		// for single-rune tokens), so multi-byte first letters stay intact.
		end := len(token)
		for j := range token {
			if j > 0 {
				end = j
				break
			}
		}
		tokens[i] = strings.ToTitle(token[:end]) + token[end:]
	}
	return strings.Join(tokens, " ")
}
// fieldNative reports whether field is a native (non-virtual) field.
// It returns 0 when the struct tag has a `virtual` key, whatever its value,
// and 1 otherwise.
func fieldNative(field *schema.Field) uint8 {
	native := uint8(1)
	if _, virtual := field.Tag.Lookup("virtual"); virtual {
		native = 0
	}
	return native
}
// fieldRule derives the validation rule of field.
//
// Defaults:
//   - string columns get Max = field.Size;
//   - integer/float columns get Max = field.Scale;
//   - primary keys are always unique.
//
// The `rule` tag holds ";"-separated entries, each "name" or "name:arg":
//   - required / require      → required on both create and update;
//   - required:create,update  → required only in the listed scenarios
//     (only create and update are accepted);
//   - unique                  → unique;
//   - regexp:<pattern>        → the pattern, taken verbatim.
//
// Entry names and scenario names are trimmed, and names are matched
// case-insensitively, consistent with how the `props` tag is parsed.
// An entry such as " unique" was previously ignored.
func fieldRule(field *schema.Field) types.Rule {
	r := types.Rule{
		Required: []string{},
	}
	switch field.GORMDataType {
	case schema.String:
		r.Max = field.Size
	case schema.Int, schema.Float, schema.Uint:
		// NOTE(review): Scale is gorm's decimal scale, not an upper bound —
		// kept for compatibility, confirm it is intended.
		r.Max = field.Scale
	}
	if rs := field.Tag.Get("rule"); rs != "" {
		for _, entry := range strings.Split(rs, ";") {
			name, arg, hasArg := strings.Cut(entry, ":")
			switch strings.ToLower(strings.TrimSpace(name)) {
			case "required", "require":
				if !hasArg {
					r.Required = []string{types.ScenarioCreate, types.ScenarioUpdate}
					continue
				}
				for _, scenario := range strings.Split(arg, ",") {
					scenario = strings.TrimSpace(scenario)
					if arrays.Exists(scenario, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
						r.Required = append(r.Required, scenario)
					}
				}
			case "unique":
				r.Unique = true
			case "regexp":
				// Not trimmed: whitespace can be significant in a pattern.
				if hasArg {
					r.Regular = arg
				}
			}
		}
	}
	if field.PrimaryKey {
		r.Unique = true
	}
	return r
}
// fieldScenario decides in which UI scenarios a field takes part.
//
// A `scenarios` tag (";"-separated) overrides everything; a blank tag yields
// no scenarios (nil). Without the tag:
//   - primary keys: list, view, export;
//   - CreatedAt / UpdatedAt: list only;
//   - DeletedAt / Namespace: none (an empty, non-nil list);
//   - the first ten fields (index < 10): search, list, create, update, view, export;
//   - later fields: only create, update, view, export.
func fieldScenario(index int, field *schema.Field) types.Scenarios {
	if raw, tagged := field.Tag.Lookup("scenarios"); tagged {
		var ss types.Scenarios
		if raw = strings.TrimSpace(raw); raw != "" {
			ss = strings.Split(raw, ";")
		}
		return ss
	}
	switch {
	case field.PrimaryKey:
		return []string{types.ScenarioList, types.ScenarioView, types.ScenarioExport}
	case field.Name == "CreatedAt", field.Name == "UpdatedAt":
		return []string{types.ScenarioList}
	case field.Name == "DeletedAt", field.Name == "Namespace":
		// Never displayed in any scenario.
		return []string{}
	case index < 10:
		return []string{types.ScenarioSearch, types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
	default:
		// Fields past the first ten stay off the list and search screens.
		return []string{types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
	}
}
// fieldPosition returns the sort position of a field. A positive integer in
// the `position` tag is used as-is. Otherwise, including when the tag is
// missing or not a number, the position is i + 100.
func fieldPosition(field *schema.Field, i int) int {
	if n, _ := strconv.Atoi(field.Tag.Get("position")); n > 0 {
		return n
	}
	return 100 + i
}
// fieldAttribute builds the presentation attributes of a model field from its
// struct tags:
//   - props:     "key:value;..." (icon, match, endofnow, invisible, suffix, tag,
//     tooltip, uploadurl, description, readonly); keys are trimmed and
//     case-insensitive, and values are trimmed;
//   - live:      remote data source settings (method, type, url/uri, columns);
//     its presence enables Live;
//   - dropdown:  dropdown flags (created, filterable, autocomplete, default_first);
//   - condition: "column:v1,v2;..." visibility conditions;
//   - enum:      "value:label[#color];..." enumerated values.
//
// Fields gorm marks as not creatable/updatable are disabled in the create/update
// scenarios. The tooltip comes from the props tag when set there; otherwise
// it comes from the column comment. Previously the comment always overwrote
// it, which made the "tooltip" prop dead.
func fieldAttribute(field *schema.Field) types.Attribute {
	attr := types.Attribute{
		Match:        types.MatchFuzzy,
		PrimaryKey:   field.PrimaryKey,
		DefaultValue: field.DefaultValue,
		Readonly:     []string{},
		Disable:      []string{},
		Visible:      make([]types.VisibleCondition, 0),
		Values:       make([]types.EnumValue, 0),
		Live:         types.LiveValue{},
	}
	if field.Name == "CreatedAt" || field.Name == "UpdatedAt" {
		attr.EndOfNow = true
	}
	// props
	if props := field.Tag.Get("props"); props != "" {
		for _, str := range strings.Split(props, ";") {
			kv := strings.SplitN(str, ":", 2)
			if len(kv) != 2 {
				continue
			}
			sv := strings.TrimSpace(kv[1])
			switch strings.ToLower(strings.TrimSpace(kv[0])) {
			case "icon":
				attr.Icon = sv
			case "match":
				if arrays.Exists(sv, matchEnums) {
					attr.Match = sv
				}
			case "endofnow", "end_of_now":
				if ok, _ := strconv.ParseBool(sv); ok {
					attr.EndOfNow = true
				}
			case "invisible":
				if ok, _ := strconv.ParseBool(sv); ok {
					attr.Invisible = true
				}
			case "suffix":
				attr.Suffix = sv
			case "tag":
				attr.Tag = sv
			case "tooltip":
				attr.Tooltip = sv
			case "uploadurl", "uploaduri", "upload_url", "upload_uri":
				attr.UploadUrl = sv
			case "description":
				attr.Description = sv
			case "readonly":
				for _, scenario := range strings.Split(sv, ",") {
					scenario = strings.TrimSpace(scenario)
					if arrays.Exists(scenario, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
						attr.Readonly = append(attr.Readonly, scenario)
					}
				}
			}
		}
	}
	// live
	if live := field.Tag.Get("live"); live != "" {
		attr.Live.Enable = true
		for _, str := range strings.Split(live, ";") {
			kv := strings.SplitN(str, ":", 2)
			if len(kv) != 2 {
				continue
			}
			switch kv[0] {
			case "method":
				attr.Live.Method = kv[1]
			case "type":
				// Unknown types fall back to a dropdown.
				if kv[1] == types.LiveTypeDropdown || kv[1] == types.LiveTypeCascader {
					attr.Live.Type = kv[1]
				} else {
					attr.Live.Type = types.LiveTypeDropdown
				}
			case "url", "uri":
				attr.Live.Url = kv[1]
			case "columns":
				attr.Live.Columns = strings.Split(kv[1], ",")
			}
		}
	}
	// dropdown: only the flag name matters; any ":value" part is ignored.
	if dropdown := field.Tag.Get("dropdown"); dropdown != "" {
		attr.DropdownOption = &types.DropdownOption{}
		for _, str := range strings.Split(dropdown, ";") {
			flag, _, _ := strings.Cut(str, ":")
			switch flag {
			case "created":
				attr.DropdownOption.Created = true
			case "filterable":
				attr.DropdownOption.Filterable = true
			case "autocomplete":
				attr.DropdownOption.Autocomplete = true
			case "default_first":
				attr.DropdownOption.DefaultFirst = true
			}
		}
	}
	// Visibility conditions: "column:v1,v2;..."; blank values are skipped.
	if conditions := field.Tag.Get("condition"); conditions != "" {
		for _, str := range strings.Split(conditions, ";") {
			kv := strings.SplitN(str, ":", 2)
			if len(kv) != 2 {
				continue
			}
			cond := types.VisibleCondition{
				Column: kv[0],
				Values: make([]any, 0),
			}
			for _, x := range strings.Split(kv[1], ",") {
				if x = strings.TrimSpace(x); x != "" {
					cond.Values = append(cond.Values, x)
				}
			}
			attr.Visible = append(attr.Visible, cond)
		}
	}
	// Enumerated values: "value:label" with an optional "#color" suffix on the label.
	if enums := field.Tag.Get("enum"); enums != "" {
		for _, str := range strings.Split(enums, ";") {
			kv := strings.SplitN(str, ":", 2)
			if len(kv) != 2 {
				continue
			}
			fv := types.EnumValue{Value: kv[0]}
			if pos := strings.IndexByte(kv[1], '#'); pos > -1 {
				fv.Label = kv[1][:pos]
				fv.Color = kv[1][pos:]
			} else {
				fv.Label = kv[1]
			}
			attr.Values = append(attr.Values, fv)
		}
	}
	if !field.Creatable {
		attr.Disable = append(attr.Disable, types.ScenarioCreate)
	}
	if !field.Updatable {
		attr.Disable = append(attr.Disable, types.ScenarioUpdate)
	}
	// An explicit props tooltip wins; otherwise fall back to the column comment.
	if attr.Tooltip == "" {
		attr.Tooltip = field.Comment
	}
	return attr
}
// autoMigrate makes sure a types.Schema row exists in defaultDomain for every
// column of model under module.
//
// The model is parsed with a fresh statement (see cloneStmt). Columns that
// already have a row are left untouched. Rows are inserted only for new
// columns, positioned after the existing ones. The column name is the gorm
// DBName, falling back to the Go field name when DBName is empty; a DBName of
// "-" is skipped. It returns the table name gorm resolved for model.
//
// Fixes over the previous version:
//   - the default label is derived from the resolved column name, so it is
//     never empty when DBName is;
//   - the insert honors ctx like the lookup does;
//   - a column produced twice by one model is inserted only once.
func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (naming string, err error) {
	stmt := cloneStmt(db)
	if err = stmt.Parse(model); err != nil {
		return
	}
	var schemas []*types.Schema
	if schemas, err = GetSchemas(ctx, db, defaultDomain, module, stmt.Table); err != nil {
		return
	}
	// Set of columns that already have a schema row (or are queued below).
	known := make(map[string]struct{}, len(schemas))
	for _, sm := range schemas {
		known[sm.Column] = struct{}{}
	}
	pos := len(schemas)
	models := make([]*types.Schema, 0)
	for index, field := range stmt.Schema.Fields {
		columnName := field.DBName
		if columnName == "-" {
			continue
		}
		if columnName == "" {
			columnName = field.Name
		}
		if _, ok := known[columnName]; ok {
			continue
		}
		known[columnName] = struct{}{}

		columnLabel := field.Tag.Get("comment")
		if columnLabel == "" {
			columnLabel = fieldName(columnName)
		}
		isPrimaryKey := uint8(0)
		if field.PrimaryKey {
			isPrimaryKey = 1
		}
		schemaModel := &types.Schema{
			Domain:       defaultDomain,
			ModuleName:   module,
			TableName:    stmt.Table,
			Enable:       1,
			Column:       columnName,
			Label:        columnLabel,
			Type:         strings.ToLower(dataTypeOf(field)),
			Format:       strings.ToLower(dataFormatOf(field)),
			Native:       fieldNative(field),
			IsPrimaryKey: isPrimaryKey,
			Rule:         fieldRule(field),
			Scenarios:    fieldScenario(index, field),
			Attribute:    fieldAttribute(field),
			Position:     fieldPosition(field, pos),
		}
		// A live (remote-loaded) field uses its live type as the display format.
		if schemaModel.Attribute.Live.Enable && schemaModel.Attribute.Live.Type != "" {
			schemaModel.Format = schemaModel.Attribute.Live.Type
		}
		models = append(models, schemaModel)
		pos++
	}
	naming = stmt.Table
	if len(models) > 0 {
		tx := db
		if ctx != nil {
			tx = db.WithContext(ctx)
		}
		err = tx.Create(models).Error
	}
	return
}
// SetHttpRouter installs router as the package-wide default HTTP router.
// AutoMigrate falls back to it when no router is supplied through its options.
func SetHttpRouter(router types.HttpRouter) {
httpRouter = router
}
// AutoMigrate syncs the schema rows of model (see autoMigrate) and builds a
// Model for its table.
//
// When a router is available, it registers REST handlers for every scenario
// the model enables. The router is the one from the options, or else the
// package default set by SetHttpRouter. The resulting Model is appended to the
// registry returned by GetModels. cbs configure the module name, router, URL
// prefix, response writer, formatter and the disableDomain flag.
func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) {
	opts := &Options{}
	for _, apply := range cbs {
		apply(opts)
	}
	var table string
	if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil {
		return
	}

	naming := types.Naming{
		Pluralize:  inflector.Pluralize(table),
		Singular:   inflector.Singularize(table),
		ModuleName: opts.moduleName,
		TableName:  table,
	}
	mv := newModel(model, db, naming)
	mv.hookMgr = hookMgr
	mv.schemaLookup = VisibleSchemas

	// A router passed in the options wins over the package-wide default.
	var router types.HttpRouter
	switch {
	case opts.router != nil:
		router = opts.router
	case httpRouter != nil:
		router = httpRouter
	}
	// Applied before the routes are built (presumably consulted by Uri()).
	if opts.urlPrefix != "" {
		mv.urlPrefix = opts.urlPrefix
	}

	if router != nil {
		if mv.hasScenario(types.ScenarioList) {
			router.Handle(http.MethodGet, mv.Uri(types.ScenarioList), mv.Search)
		}
		if mv.hasScenario(types.ScenarioCreate) {
			router.Handle(http.MethodPost, mv.Uri(types.ScenarioCreate), mv.Create)
		}
		if mv.hasScenario(types.ScenarioUpdate) {
			router.Handle(http.MethodPut, mv.Uri(types.ScenarioUpdate), mv.Update)
		}
		if mv.hasScenario(types.ScenarioDelete) {
			router.Handle(http.MethodDelete, mv.Uri(types.ScenarioDelete), mv.Delete)
		}
		if mv.hasScenario(types.ScenarioView) {
			router.Handle(http.MethodGet, mv.Uri(types.ScenarioView), mv.View)
		}
		if mv.hasScenario(types.ScenarioExport) {
			router.Handle(http.MethodGet, mv.Uri(types.ScenarioExport), mv.Export)
		}
		if mv.hasScenario(types.ScenarioImport) {
			// Import answers both GET and POST on the same URI.
			importUri := mv.Uri(types.ScenarioImport)
			router.Handle(http.MethodGet, importUri, mv.Import)
			router.Handle(http.MethodPost, importUri, mv.Import)
		}
	}

	if opts.writer != nil {
		mv.response = opts.writer
	}
	if opts.formatter != nil {
		mv.formatter = opts.formatter
	}
	mv.disableDomain = opts.disableDomain
	modelEntities = append(modelEntities, mv)
	return nil
}
// CloneSchemas copies every schema row of defaultDomain into domain.
//
// Rows whose (module, table, column) triple already exists in the target
// domain are skipped. Copied rows get zeroed Id/CreatedAt/UpdatedAt and are
// stored with Save.
//
// Errors from either lookup are returned, wrapped with their cause. Previously
// a failed lookup of the target domain was ignored, which made every default
// row look missing and re-inserted them all as duplicates.
func CloneSchemas(ctx context.Context, db *gorm.DB, domain string) (err error) {
	tx := db.WithContext(ctx)

	var values []*types.Schema
	if err = tx.Where("domain=?", defaultDomain).Find(&values).Error; err != nil {
		return fmt.Errorf("schema not found: %w", err)
	}
	var schemas []*types.Schema
	if err = tx.Where("domain=?", domain).Find(&schemas).Error; err != nil {
		return fmt.Errorf("load schemas of domain %q: %w", domain, err)
	}

	type schemaKey struct{ module, table, column string }
	present := make(map[schemaKey]struct{}, len(schemas))
	for _, row := range schemas {
		present[schemaKey{row.ModuleName, row.TableName, row.Column}] = struct{}{}
	}

	models := make([]*types.Schema, 0, len(values))
	for _, row := range values {
		if _, ok := present[schemaKey{row.ModuleName, row.TableName, row.Column}]; ok {
			continue
		}
		// Reset identity and timestamps so the row is stored as a new record.
		row.Id = 0
		row.CreatedAt = 0
		row.UpdatedAt = 0
		row.Domain = domain
		models = append(models, row)
	}
	if len(models) == 0 {
		return nil
	}
	return tx.Save(models).Error
}
// GetSchemas loads the schema rows of (domain, moduleName, tableName), ordered
// by position ascending.
//
// An empty domain means defaultDomain. An empty moduleName or tableName yields
// gorm.ErrInvalidField. When ctx is nil, the query runs on db without
// attaching a context. No matching rows is not an error: an empty, non-nil
// slice is returned.
func GetSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName string) ([]*types.Schema, error) {
	if domain == "" {
		domain = defaultDomain
	}
	if moduleName == "" || tableName == "" {
		return nil, gorm.ErrInvalidField
	}
	tx := db
	if ctx != nil {
		tx = db.WithContext(ctx)
	}
	values := make([]*types.Schema, 0)
	err := tx.
		Where("domain=? AND module_name=? AND table_name=?", domain, moduleName, tableName).
		Order("position ASC").
		Find(&values).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		err = nil
	}
	return values, err
}
// VisibleSchemas returns the schema rows of (domain, moduleName, tableName)
// that are visible in scenario. Primary-key rows are always included; other
// rows are included only when their Scenarios contain scenario. The order
// produced by GetSchemas is preserved, and lookup errors are passed through.
func VisibleSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName, scenario string) ([]*types.Schema, error) {
	schemas, err := GetSchemas(ctx, db, domain, moduleName, tableName)
	if err != nil {
		return nil, err
	}
	visible := make([]*types.Schema, 0, len(schemas))
	for _, s := range schemas {
		if s.IsPrimaryKey == 1 || s.Scenarios.Has(scenario) {
			visible = append(visible, s)
		}
	}
	return visible, nil
}
// ModelTypes loads label/value pairs from the rows of model in domainName.
// Each pair reads the labelColumn and valueColumn columns; a missing column
// leaves the corresponding field nil.
//
// NOTE(review): the query error is not reported (the signature has no error
// return); a failed query simply yields no pairs.
func ModelTypes(ctx context.Context, db *gorm.DB, model any, domainName, labelColumn, valueColumn string) (values []*types.TypeValue) {
	rows := make([]map[string]any, 0, 10)
	db.WithContext(ctx).
		Model(model).
		Select(labelColumn, valueColumn).
		Where("domain=?", domainName).
		Scan(&rows)
	values = make([]*types.TypeValue, 0, len(rows))
	for _, row := range rows {
		values = append(values, &types.TypeValue{
			Label: row[labelColumn],
			Value: row[valueColumn],
		})
	}
	return values
}
// schemaFieldValue locates the struct field that backs column inside
// refValue. column is matched against either the gorm DBName or the Go field
// name, and the first match in stmt.Schema.Fields wins.
//
// stmt.Schema.Fields is gorm's flattened field list: fields of embedded
// structs (e.g. gorm.Model) appear individually and unexported struct fields
// are skipped. A field's position in that list is therefore NOT its index in
// the Go struct, so the field's StructField.Index path is walked instead.
// gorm encodes a step through a pointer-embedded struct as a negative index
// (-i-1). When alloc is true, nil embedded pointers on the path are allocated
// so the result can be assigned. The returned Value is invalid when the column
// is unknown or the field cannot be reached.
func schemaFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string, alloc bool) reflect.Value {
	if stmt == nil || stmt.Schema == nil {
		return reflect.Value{}
	}
	var target *schema.Field
	for _, field := range stmt.Schema.Fields {
		if field.DBName == column || field.Name == column {
			target = field
			break
		}
	}
	if target == nil || len(target.StructField.Index) == 0 {
		return reflect.Value{}
	}
	v := reflect.Indirect(refValue)
	for _, idx := range target.StructField.Index {
		if idx < 0 {
			idx = -idx - 1
		}
		// Step through (pointer-)embedded structs on the way to the leaf field.
		for v.Kind() == reflect.Ptr {
			if v.IsNil() {
				if !alloc || !v.CanSet() {
					return reflect.Value{}
				}
				v.Set(reflect.New(v.Type().Elem()))
			}
			v = v.Elem()
		}
		if v.Kind() != reflect.Struct || idx >= v.NumField() {
			return reflect.Value{}
		}
		v = v.Field(idx)
	}
	return v
}

// GetFieldValue returns the value of the model field mapped to column (gorm
// DBName or Go field name) in refValue, or nil when there is no such field.
func GetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string) interface{} {
	fv := schemaFieldValue(stmt, refValue, column, false)
	if !fv.IsValid() || !fv.CanInterface() {
		return nil
	}
	return fv.Interface()
}

// SetFieldValue assigns value to the model field mapped to column (gorm DBName
// or Go field name) in refValue; unknown columns are ignored. As with
// reflect.Value.Set, it panics when value's type is not assignable to the
// field or refValue is not addressable (pass a pointer).
func SetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) {
	if fv := schemaFieldValue(stmt, refValue, column, true); fv.IsValid() {
		fv.Set(reflect.ValueOf(value))
	}
}

// SafeSetFileValue is the lenient variant of SetFieldValue. It assigns value
// through reflection.Assign and ignores any assignment error; unknown columns
// are ignored too. The name (File rather than Field) is kept for compatibility.
func SafeSetFileValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) {
	if fv := schemaFieldValue(stmt, refValue, column, true); fv.IsValid() {
		// Best effort by design: assignment failures are deliberately dropped.
		_ = reflection.Assign(fv, value)
	}
}
// GetModels returns every Model registered by AutoMigrate, in registration order.
// NOTE(review): the package-level slice itself is returned (not a copy), so
// callers share its backing array and should treat it as read-only.
func GetModels() []*Model {
return modelEntities
}
// OnBeforeCreate registers cb as a before-create callback on the package-level
// hook manager, which AutoMigrate attaches to every Model it builds.
func OnBeforeCreate(cb BeforeCreate) {
hookMgr.BeforeCreate(cb)
}
// OnAfterCreate registers cb as an after-create callback on the package-level hook manager.
func OnAfterCreate(cb AfterCreate) {
hookMgr.AfterCreate(cb)
}
// OnBeforeUpdate registers cb as a before-update callback on the package-level hook manager.
func OnBeforeUpdate(cb BeforeUpdate) {
hookMgr.BeforeUpdate(cb)
}
// OnAfterUpdate registers cb as an after-update callback on the package-level hook manager.
func OnAfterUpdate(cb AfterUpdate) {
hookMgr.AfterUpdate(cb)
}
// OnBeforeSave registers cb as a before-save callback on the package-level hook manager.
func OnBeforeSave(cb BeforeSave) {
hookMgr.BeforeSave(cb)
}
// OnAfterSave registers cb as an after-save callback on the package-level hook manager.
func OnAfterSave(cb AfterSave) {
hookMgr.AfterSave(cb)
}
// OnBeforeDelete registers cb as a before-delete callback on the package-level hook manager.
func OnBeforeDelete(cb BeforeDelete) {
hookMgr.BeforeDelete(cb)
}
// OnAfterDelete registers cb as an after-delete callback on the package-level hook manager.
func OnAfterDelete(cb AfterDelete) {
hookMgr.AfterDelete(cb)
}
// OnAfterExport registers cb as an after-export callback on the package-level hook manager.
func OnAfterExport(cb AfterExport) {
hookMgr.AfterExport(cb)
}
// OnAfterImport registers cb as an after-import callback on the package-level hook manager.
func OnAfterImport(cb AfterImport) {
hookMgr.AfterImport(cb)
}