commit 61ffcec858a0eaf50128ed2c05e75af3f0590710 Author: fancl Date: Wed Dec 11 17:29:01 2024 +0800 初始化仓库 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..556957c --- /dev/null +++ b/.gitignore @@ -0,0 +1,60 @@ +bin/ + +.svn/ +.godeps +./build +.cover/ +dist +_site +_posts +*.dat +.vscode +vendor + +# Go.gitignore + +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test +storage +.idea + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.local +.DS_Store + +profile + +# vim stuff +*.sw[op] + + +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +.vscode/* +!.vscode/extensions.json + +node_modules \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..69155bd --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# 数据库组件 \ No newline at end of file diff --git a/condition.go b/condition.go new file mode 100644 index 0000000..597c367 --- /dev/null +++ b/condition.go @@ -0,0 +1,144 @@ +package rest + +import ( + "context" + "encoding/json" + "git.nobla.cn/golang/kos/util/arrays" + "git.nobla.cn/golang/rest/types" + "net/http" + "strings" +) + +func findCondition(schema *types.Schema, conditions []*types.Condition) *types.Condition { + for _, cond := range conditions { + if cond.Column == schema.Column { + return cond + } + } + return nil +} + +func BuildConditions(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) { + var ( + ok bool + skip bool + formValue string + activeQuery ActiveQuery + ) + if activeQuery, ok = query.Model().(ActiveQuery); ok { + if err = activeQuery.BeforeQuery(ctx, query); err != nil { + return + } + } + if arrays.Exists(r.Method, []string{http.MethodPut, http.MethodPost}) { + conditions := make([]*types.Condition, 0) + if err = json.NewDecoder(r.Body).Decode(&conditions); err != nil { + return + } + for _, row := range schemas { + if row.Native == 0 { + continue + } + cond := findCondition(row, conditions) + if cond == nil { + continue + } + switch row.Format { + case types.FormatInteger, types.FormatFloat, types.FormatTimestamp, types.FormatDatetime, types.FormatDate, types.FormatTime: + switch cond.Expr { + case types.OperatorBetween: + if len(cond.Values) == 2 { + query.AndFilterWhere(newCondition(row.Column, cond.Values[0]).WithExpr(">=")) + query.AndFilterWhere(newCondition(row.Column, cond.Values[1]).WithExpr("<=")) + } + case types.OperatorGreaterThan: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">")) + case types.OperatorGreaterEqual: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">=")) + case types.OperatorLessThan: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<")) + case types.OperatorLessEqual: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<=")) + default: + query.AndFilterWhere(newCondition(row.Column, cond.Value)) + } + default: + switch cond.Expr { + case types.OperatorLike: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("LIKE")) + default: + query.AndFilterWhere(newCondition(row.Column, cond.Value)) + } + } + } + } else { + qs := r.URL.Query() + for _, row := range schemas { + skip = false + if skip { + continue + } + if row.Native == 0 { + continue + } + formValue = qs.Get(row.Column) + switch row.Format { + case types.FormatString, types.FormatText: + if row.Attribute.Match == types.MatchExactly { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } else { + query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE")) + } + case types.FormatTime, types.FormatDate, types.FormatDatetime, types.FormatTimestamp: + var sep string + seps := []byte{',', '/'} + for _, s := range seps { + if strings.IndexByte(formValue, s) > -1 { + sep = string(s) + } + } + if ss := strings.Split(formValue, sep); len(ss) == 2 { + query.AndFilterWhere( + newCondition(row.Column, strings.TrimSpace(ss[0])).WithExpr(">="), + newCondition(row.Column, strings.TrimSpace(ss[1])).WithExpr("<="), + ) + } else { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } + case types.FormatInteger, types.FormatFloat: + query.AndFilterWhere(newCondition(row.Column, formValue)) + default: + if row.Type == types.TypeString { + if row.Attribute.Match == types.MatchExactly { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } else { + query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE")) + } + } else { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } + } + } + } + sortPar := r.FormValue("sort") + if sortPar != "" { + sorts := strings.Split(sortPar, ",") + for _, s := range sorts { + if s[0] == '-' { + query.OrderBy(s[1:], "DESC") + } else { + if s[0] == '+' { + query.OrderBy(s[1:], "ASC") + } else { + query.OrderBy(s, "ASC") + } + } + } + } + if activeQuery, ok = query.Model().(ActiveQuery); ok { + if err = activeQuery.AfterQuery(ctx, query); err != nil { + return + } + } + return +} diff --git a/formatter.go b/formatter.go new file mode 100644 index 0000000..8d40654 --- /dev/null +++ b/formatter.go @@ -0,0 +1,240 @@ +package rest + +import ( + "context" + "database/sql" + "fmt" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "reflect" + "strconv" + "sync" + "time" +) + +var ( + DefaultFormatter = NewFormatter() + + DefaultNullDisplay = "" +) + +func init() { + DefaultFormatter.Register("string", stringFormat) + DefaultFormatter.Register("integer", integerFormat) + DefaultFormatter.Register("decimal", decimalFormat) + DefaultFormatter.Register("date", dateFormat) + DefaultFormatter.Register("time", timeFormat) + DefaultFormatter.Register("datetime", datetimeFormat) + DefaultFormatter.Register("duration", durationFormat) + DefaultFormatter.Register("dropdown", dropdownFormat) + DefaultFormatter.Register("timestamp", datetimeFormat) + DefaultFormatter.Register("percentage", percentageFormat) +} + +type FormatFunc func(ctx context.Context, value any, model any, scm *types.Schema) any + +type Formatter struct { + callbacks sync.Map +} + +func (formatter *Formatter) Register(f string, fun FormatFunc) { + formatter.callbacks.Store(f, fun) +} + +func (formatter *Formatter) Format(ctx context.Context, format string, value any, model any, scm *types.Schema) any { + v, ok := formatter.callbacks.Load(format) + if ok { + return v.(FormatFunc)(ctx, value, model, scm) + } + return value +} + +func (formatter *Formatter) getModelValue(refValue reflect.Value, schema *types.Schema, stmt *gorm.Statement) any { + if stmt.Schema == nil { + return nil + } + field := stmt.Schema.LookUpField(schema.Column) + if field == nil { + return nil + } + return refValue.FieldByName(field.Name).Interface() +} + +func (formatter *Formatter) formatModel(ctx context.Context, refValue reflect.Value, schemas []*types.Schema, stmt *gorm.Statement, format string) any { + values := make(map[string]any) + multiValues := make(map[string]multiValue) + modelValue := refValue.Interface() + refValue = reflect.Indirect(refValue) + for _, scm := range schemas { + switch format { + case types.FormatRaw: + values[scm.Column] = formatter.getModelValue(refValue, scm, stmt) + case types.FormatBoth: + v := multiValue{ + Value: formatter.getModelValue(refValue, scm, stmt), + } + v.Text = formatter.Format(ctx, scm.Format, v.Value, modelValue, scm) + multiValues[scm.Column] = v + default: + values[scm.Column] = formatter.Format(ctx, scm.Format, formatter.getModelValue(refValue, scm, stmt), modelValue, scm) + } + } + if format == types.FormatBoth { + return multiValues + } else { + return values + } +} + +func (formatter *Formatter) formatModels(ctx context.Context, val any, schemas []*types.Schema, stmt *gorm.Statement, format string) any { + refValue := reflect.Indirect(reflect.ValueOf(val)) + if refValue.Kind() != reflect.Slice { + return []any{} + } + length := refValue.Len() + values := make([]any, length) + for i := 0; i < length; i++ { + rowValue := refValue.Index(i) + modelValue := rowValue.Interface() + if formatModel, ok := modelValue.(types.FormatModel); ok { + formatModel.Format(ctx, format, schemas) + } + values[i] = formatter.formatModel(ctx, rowValue, schemas, stmt, format) + } + return values +} + +func stringFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + return fmt.Sprint(value) +} + +func integerFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + var ( + n int + ) + switch value.(type) { + case float32, float64: + n = int(reflect.ValueOf(value).Float()) + case int, int8, int16, int32, int64: + n = int(reflect.ValueOf(value).Int()) + case uint, uint8, uint16, uint32, uint64: + n = int(reflect.ValueOf(value).Uint()) + case string: + n, _ = strconv.Atoi(reflect.ValueOf(value).String()) + case []byte: + n, _ = strconv.Atoi(string(reflect.ValueOf(value).Bytes())) + } + return n +} + +func decimalFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + var ( + n float64 + ) + switch value.(type) { + case float32, float64: + n = reflect.ValueOf(value).Float() + case int, int8, int16, int32, int64: + n = float64(reflect.ValueOf(value).Int()) + case uint, uint8, uint16, uint32, uint64: + n = float64(reflect.ValueOf(value).Uint()) + case string: + n, _ = strconv.ParseFloat(reflect.ValueOf(value).String(), 64) + case []byte: + n, _ = strconv.ParseFloat(string(reflect.ValueOf(value).Bytes()), 64) + } + return n +} + +func dateFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + if t, ok := value.(time.Time); ok { + return t.Format("2006-01-02") + } + if t, ok := value.(*sql.NullTime); ok { + if t != nil && t.Valid { + return t.Time.Format("2006-01-02") + } + } + if t, ok := value.(int64); ok { + tm := time.Unix(t, 0) + return tm.Format("2006-01-02") + } + return DefaultNullDisplay +} + +func timeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + if t, ok := value.(time.Time); ok { + return t.Format("15:04:05") + } + if t, ok := value.(*sql.NullTime); ok { + if t != nil && t.Valid { + return t.Time.Format("15:04:05") + } + } + if t, ok := value.(int64); ok { + tm := time.Unix(t, 0) + return tm.Format("15:04:05") + } + return value +} + +func datetimeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + if t, ok := value.(time.Time); ok { + return t.Format("2006-01-02 15:04:05") + } + if t, ok := value.(*sql.NullTime); ok { + if t != nil && t.Valid { + return t.Time.Format("2006-01-02 15:04:05") + } + } + if t, ok := value.(int64); ok { + if t > 0 { + tm := time.Unix(t, 0) + return tm.Format("2006-01-02 15:04:05") + } + } + return DefaultNullDisplay +} + +func percentageFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + n := decimalFormat(ctx, value, model, schema).(float64) + if n <= 1 { + return fmt.Sprintf("%.2f%%", n*100) + } else { + return fmt.Sprintf("%.2f%%", n) + } +} + +func durationFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + var ( + hour int + minVal int + sec int + ) + n := integerFormat(ctx, value, model, schema).(int) + hour = n / 3600 + minVal = (n - hour*3600) / 60 + sec = n - hour*3600 - minVal*60 + return fmt.Sprintf("%02d:%02d:%02d", hour, minVal, sec) +} + +func dropdownFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + attributes := schema.Attribute + if attributes.Values != nil { + for _, v := range attributes.Values { + if v.Value == value { + return v.Label + } + } + } + return value +} + +func NewFormatter() *Formatter { + formatter := &Formatter{} + return formatter +} + +func RegisterFormat(f string, cb FormatFunc) { + DefaultFormatter.Register(f, cb) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f2bb69f --- /dev/null +++ b/go.mod @@ -0,0 +1,26 @@ +module git.nobla.cn/golang/rest + +go 1.22.9 + +require ( + git.nobla.cn/golang/kos v0.1.32 + github.com/cespare/xxhash/v2 v2.3.0 + github.com/go-playground/validator/v10 v10.23.0 + github.com/longbridgeapp/sqlparser v0.3.2 + github.com/rs/xid v1.6.0 + github.com/uole/sqlparser v0.0.1 + gorm.io/gorm v1.25.12 +) + +require ( + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/text v0.14.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4446986 --- /dev/null +++ b/go.sum @@ -0,0 +1,46 @@ +git.nobla.cn/golang/kos v0.1.32 h1:sFVCA7vKc8dPUd0cxzwExOSPX2mmMh2IuwL6cYS1pBc= +git.nobla.cn/golang/kos v0.1.32/go.mod h1:35Z070+5oB39WcVrh5DDlnVeftL/Ccmscw2MZFe9fUg= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL8sThn8IHr/sO+o= +github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M= +github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/longbridgeapp/sqlparser v0.3.2 h1:FV0dgMiv8VcksT3p10hJeqfPs8bodoehmUJ7MhBds+Y= +github.com/longbridgeapp/sqlparser v0.3.2/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/uole/sqlparser v0.0.1 h1:LLUklg6Ne5MypXQuo53QcJv/xKdxtEKM9iUuEBN/lt8= +github.com/uole/sqlparser v0.0.1/go.mod h1:CRYFz2PTm9oHM0j9GFKi1VzPy70r6GsF0b9vpnwJ4yI= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/hook.go b/hook.go new file mode 100644 index 0000000..91c81ad --- /dev/null +++ b/hook.go @@ -0,0 +1,251 @@ +package rest + +import ( + "context" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" +) + +const ( + beforeCreate = "beforeCreate" + afterCreate = "afterCreate" + beforeUpdate = "beforeUpdate" + afterUpdate = "afterUpdate" + beforeSave = "beforeSave" + afterSave = "afterSave" + beforeDelete = "beforeDelete" + afterDelete = "afterDelete" + afterExport = "afterExport" + afterImport = "afterImport" +) + +type ( + BeforeCreate func(ctx context.Context, tx *gorm.DB, model any) (err error) + AfterCreate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) + BeforeUpdate func(ctx context.Context, tx *gorm.DB, model any) (err error) + AfterUpdate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) + BeforeSave func(ctx context.Context, tx *gorm.DB, model any) (err error) + AfterSave func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) + BeforeDelete func(ctx context.Context, tx *gorm.DB, model any) (err error) + AfterDelete func(ctx context.Context, tx *gorm.DB, model any) + AfterExport func(ctx context.Context, filename string) //导出回调 + AfterImport func(ctx context.Context, result *types.ImportResult) //导入回调 + hookManager struct { + callbacks map[string][]any + } +) + +type ( + ActiveQuery interface { + BeforeQuery(ctx context.Context, query *Query) (err error) + AfterQuery(ctx context.Context, query *Query) (err error) + } +) + +func (hook *hookManager) register(spec string, cb any) { + if hook.callbacks == nil { + hook.callbacks = make(map[string][]any) + } + if _, ok := hook.callbacks[spec]; !ok { + hook.callbacks[spec] = make([]any, 0) + } + hook.callbacks[spec] = append(hook.callbacks[spec], cb) +} + +func (hook *hookManager) beforeCreate(ctx context.Context, tx *gorm.DB, model any) (err error) { + callbacks, ok := hook.callbacks[beforeCreate] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(BeforeCreate); ok { + if err = cb(ctx, tx, model); err != nil { + return err + } + } + } + return +} + +func (hook *hookManager) afterCreate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) { + callbacks, ok := hook.callbacks[afterCreate] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterCreate); ok { + cb(ctx, tx, model, diff) + } + } + return +} + +func (hook *hookManager) beforeUpdate(ctx context.Context, tx *gorm.DB, model any) (err error) { + callbacks, ok := hook.callbacks[beforeUpdate] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(BeforeUpdate); ok { + if err = cb(ctx, tx, model); err != nil { + return err + } + } + } + return +} + +func (hook *hookManager) afterUpdate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) { + callbacks, ok := hook.callbacks[afterUpdate] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterUpdate); ok { + cb(ctx, tx, model, diff) + } + } + return +} + +func (hook *hookManager) beforeSave(ctx context.Context, tx *gorm.DB, model any) (err error) { + callbacks, ok := hook.callbacks[beforeSave] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(BeforeSave); ok { + if err = cb(ctx, tx, model); err != nil { + return err + } + } + } + return +} + +func (hook *hookManager) afterSave(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) { + callbacks, ok := hook.callbacks[afterSave] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterSave); ok { + cb(ctx, tx, model, diff) + } + } + return +} + +func (hook *hookManager) beforeDelete(ctx context.Context, tx *gorm.DB, model any) (err error) { + callbacks, ok := hook.callbacks[beforeDelete] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(BeforeDelete); ok { + if err = cb(ctx, tx, model); err != nil { + return err + } + } + } + return +} + +func (hook *hookManager) afterDelete(ctx context.Context, tx *gorm.DB, model any) { + callbacks, ok := hook.callbacks[afterDelete] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterDelete); ok { + cb(ctx, tx, model) + } + } + return +} + +func (hook *hookManager) afterExport(ctx context.Context, filename string) { + callbacks, ok := hook.callbacks[afterExport] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterExport); ok { + cb(ctx, filename) + } + } + return +} + +func (hook *hookManager) afterImport(ctx context.Context, ret *types.ImportResult) { + callbacks, ok := hook.callbacks[afterImport] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterImport); ok { + cb(ctx, ret) + } + } + return +} + +func (hook *hookManager) BeforeCreate(cb BeforeCreate) { + if cb != nil { + hook.register(beforeCreate, cb) + } +} + +func (hook *hookManager) AfterCreate(cb AfterCreate) { + if cb != nil { + hook.register(afterCreate, cb) + } +} + +func (hook *hookManager) BeforeUpdate(cb BeforeUpdate) { + if cb != nil { + hook.register(beforeUpdate, cb) + } +} + +func (hook *hookManager) AfterUpdate(cb AfterUpdate) { + if cb != nil { + hook.register(afterUpdate, cb) + } +} + +func (hook *hookManager) BeforeSave(cb BeforeSave) { + if cb != nil { + hook.register(beforeSave, cb) + } +} + +func (hook *hookManager) AfterSave(cb AfterSave) { + if cb != nil { + hook.register(afterSave, cb) + } +} + +func (hook *hookManager) BeforeDelete(cb BeforeDelete) { + if cb != nil { + hook.register(beforeDelete, cb) + } +} + +func (hook *hookManager) AfterDelete(cb AfterDelete) { + if cb != nil { + hook.register(afterDelete, cb) + } +} + +func (hook *hookManager) AfterExport(cb AfterExport) { + if cb != nil { + hook.register(afterExport, cb) + } +} + +func (hook *hookManager) AfterImport(cb AfterImport) { + if cb != nil { + hook.register(afterImport, cb) + } +} diff --git a/inflector/inflector.go b/inflector/inflector.go new file mode 100644 index 0000000..7a03031 --- /dev/null +++ b/inflector/inflector.go @@ -0,0 +1,408 @@ +package inflector + +import ( + "bytes" + "fmt" + "regexp" + "strings" + "sync" +) + +// Rule represents name of the inflector rule, can be +// Plural or Singular +type Rule int + +const ( + Plural = iota + Singular +) + +// InflectorRule represents inflector rule +type InflectorRule struct { + Rules []*ruleItem + Irregular []*irregularItem + Uninflected []string + compiledIrregular *regexp.Regexp + compiledUninflected *regexp.Regexp + compiledRules []*compiledRule +} + +type ruleItem struct { + pattern string + replacement string +} + +type irregularItem struct { + word string + replacement string +} + +// compiledRule represents compiled version of Inflector.Rules. +type compiledRule struct { + replacement string + *regexp.Regexp +} + +// threadsafe access to rules and caches +var mutex sync.Mutex +var rules = make(map[Rule]*InflectorRule) + +// Words that should not be inflected +var uninflected = []string{ + `Amoyese`, `bison`, `Borghese`, `bream`, `breeches`, `britches`, `buffalo`, + `cantus`, `carp`, `chassis`, `clippers`, `cod`, `coitus`, `Congoese`, + `contretemps`, `corps`, `debris`, `diabetes`, `djinn`, `eland`, `elk`, + `equipment`, `Faroese`, `flounder`, `Foochowese`, `gallows`, `Genevese`, + `Genoese`, `Gilbertese`, `graffiti`, `headquarters`, `herpes`, `hijinks`, + `Hottentotese`, `information`, `innings`, `jackanapes`, `Kiplingese`, + `Kongoese`, `Lucchese`, `mackerel`, `Maltese`, `.*?media`, `mews`, `moose`, + `mumps`, `Nankingese`, `news`, `nexus`, `Niasese`, `Pekingese`, + `Piedmontese`, `pincers`, `Pistoiese`, `pliers`, `Portuguese`, `proceedings`, + `rabies`, `rice`, `rhinoceros`, `salmon`, `Sarawakese`, `scissors`, + `sea[- ]bass`, `series`, `Shavese`, `shears`, `siemens`, `species`, `swine`, + `testes`, `trousers`, `trout`, `tuna`, `Vermontese`, `Wenchowese`, `whiting`, + `wildebeest`, `Yengeese`, +} + +// Plural words that should not be inflected +var uninflectedPlurals = []string{ + `.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`, + `people`, +} + +// Singular words that should not be inflected +var uninflectedSingulars = []string{ + `.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`, + `.*ss`, +} + +type cache map[string]string + +// Inflected words that already cached for immediate retrieval from a given Rule +var caches = make(map[Rule]cache) + +// map of irregular words where its key is a word and its value is the replacement +var irregularMaps = make(map[Rule]cache) + +var ( + // https://github.com/golang/lint/blob/master/lint.go#L770 + commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} + commonInitialismsReplacer *strings.Replacer +) + +func init() { + rules[Plural] = &InflectorRule{ + Rules: []*ruleItem{ + {`(?i)(s)tatus$`, `${1}${2}tatuses`}, + {`(?i)(quiz)$`, `${1}zes`}, + {`(?i)^(ox)$`, `${1}${2}en`}, + {`(?i)([m|l])ouse$`, `${1}ice`}, + {`(?i)(matr|vert|ind)(ix|ex)$`, `${1}ices`}, + {`(?i)(x|ch|ss|sh)$`, `${1}es`}, + {`(?i)([^aeiouy]|qu)y$`, `${1}ies`}, + {`(?i)(hive)$`, `$1s`}, + {`(?i)(?:([^f])fe|([lre])f)$`, `${1}${2}ves`}, + {`(?i)sis$`, `ses`}, + {`(?i)([ti])um$`, `${1}a`}, + {`(?i)(p)erson$`, `${1}eople`}, + {`(?i)(m)an$`, `${1}en`}, + {`(?i)(c)hild$`, `${1}hildren`}, + {`(?i)(buffal|tomat)o$`, `${1}${2}oes`}, + {`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|vir)us$`, `${1}i`}, + {`(?i)us$`, `uses`}, + {`(?i)(alias)$`, `${1}es`}, + {`(?i)(ax|cris|test)is$`, `${1}es`}, + {`s$`, `s`}, + {`^$`, ``}, + {`$`, `s`}, + }, + Irregular: []*irregularItem{ + {`atlas`, `atlases`}, + {`beef`, `beefs`}, + {`brother`, `brothers`}, + {`cafe`, `cafes`}, + {`child`, `children`}, + {`cookie`, `cookies`}, + {`corpus`, `corpuses`}, + {`cow`, `cows`}, + {`ganglion`, `ganglions`}, + {`genie`, `genies`}, + {`genus`, `genera`}, + {`graffito`, `graffiti`}, + {`hoof`, `hoofs`}, + {`loaf`, `loaves`}, + {`man`, `men`}, + {`money`, `monies`}, + {`mongoose`, `mongooses`}, + {`move`, `moves`}, + {`mythos`, `mythoi`}, + {`niche`, `niches`}, + {`numen`, `numina`}, + {`occiput`, `occiputs`}, + {`octopus`, `octopuses`}, + {`opus`, `opuses`}, + {`ox`, `oxen`}, + {`penis`, `penises`}, + {`person`, `people`}, + {`sex`, `sexes`}, + {`soliloquy`, `soliloquies`}, + {`testis`, `testes`}, + {`trilby`, `trilbys`}, + {`turf`, `turfs`}, + {`potato`, `potatoes`}, + {`hero`, `heroes`}, + {`tooth`, `teeth`}, + {`goose`, `geese`}, + {`foot`, `feet`}, + }, + } + prepare(Plural) + + rules[Singular] = &InflectorRule{ + Rules: []*ruleItem{ + {`(?i)(s)tatuses$`, `${1}${2}tatus`}, + {`(?i)^(.*)(menu)s$`, `${1}${2}`}, + {`(?i)(quiz)zes$`, `$1`}, + {`(?i)(matr)ices$`, `${1}ix`}, + {`(?i)(vert|ind)ices$`, `${1}ex`}, + {`(?i)^(ox)en`, `$1`}, + {`(?i)(alias)(es)*$`, `$1`}, + {`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|viri?)i$`, `${1}us`}, + {`(?i)([ftw]ax)es`, `$1`}, + {`(?i)(cris|ax|test)es$`, `${1}is`}, + {`(?i)(shoe|slave)s$`, `$1`}, + {`(?i)(o)es$`, `$1`}, + {`ouses$`, `ouse`}, + {`([^a])uses$`, `${1}us`}, + {`(?i)([m|l])ice$`, `${1}ouse`}, + {`(?i)(x|ch|ss|sh)es$`, `$1`}, + {`(?i)(m)ovies$`, `${1}${2}ovie`}, + {`(?i)(s)eries$`, `${1}${2}eries`}, + {`(?i)([^aeiouy]|qu)ies$`, `${1}y`}, + {`(?i)(tive)s$`, `$1`}, + {`(?i)([lre])ves$`, `${1}f`}, + {`(?i)([^fo])ves$`, `${1}fe`}, + {`(?i)(hive)s$`, `$1`}, + {`(?i)(drive)s$`, `$1`}, + {`(?i)(^analy)ses$`, `${1}sis`}, + {`(?i)(analy|diagno|^ba|(p)arenthe|(p)rogno|(s)ynop|(t)he)ses$`, `${1}${2}sis`}, + {`(?i)([ti])a$`, `${1}um`}, + {`(?i)(p)eople$`, `${1}${2}erson`}, + {`(?i)(m)en$`, `${1}an`}, + {`(?i)(c)hildren$`, `${1}${2}hild`}, + {`(?i)(n)ews$`, `${1}${2}ews`}, + {`eaus$`, `eau`}, + {`^(.*us)$`, `$1`}, + {`(?i)s$`, ``}, + }, + Irregular: []*irregularItem{ + {`foes`, `foe`}, + {`waves`, `wave`}, + {`curves`, `curve`}, + {`atlases`, `atlas`}, + {`beefs`, `beef`}, + {`brothers`, `brother`}, + {`cafes`, `cafe`}, + {`children`, `child`}, + {`cookies`, `cookie`}, + {`corpuses`, `corpus`}, + {`cows`, `cow`}, + {`ganglions`, `ganglion`}, + {`genies`, `genie`}, + {`genera`, `genus`}, + {`graffiti`, `graffito`}, + {`hoofs`, `hoof`}, + {`loaves`, `loaf`}, + {`men`, `man`}, + {`monies`, `money`}, + {`mongooses`, `mongoose`}, + {`moves`, `move`}, + {`mythoi`, `mythos`}, + {`niches`, `niche`}, + {`numina`, `numen`}, + {`occiputs`, `occiput`}, + {`octopuses`, `octopus`}, + {`opuses`, `opus`}, + {`oxen`, `ox`}, + {`penises`, `penis`}, + {`people`, `person`}, + {`sexes`, `sex`}, + {`soliloquies`, `soliloquy`}, + {`testes`, `testis`}, + {`trilbys`, `trilby`}, + {`turfs`, `turf`}, + {`potatoes`, `potato`}, + {`heroes`, `hero`}, + {`teeth`, `tooth`}, + {`geese`, `goose`}, + {`feet`, `foot`}, + }, + } + prepare(Singular) + + commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) + for _, initialism := range commonInitialisms { + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) +} + +// prepare rule, e.g., compile the pattern. +func prepare(r Rule) error { + var reString string + + switch r { + case Plural: + // Merge global uninflected with singularsUninflected + rules[r].Uninflected = merge(uninflected, uninflectedPlurals) + case Singular: + // Merge global uninflected with singularsUninflected + rules[r].Uninflected = merge(uninflected, uninflectedSingulars) + } + + // Set InflectorRule.compiledUninflected by joining InflectorRule.Uninflected into + // a single string then compile it. + reString = fmt.Sprintf(`(?i)(^(?:%s))$`, strings.Join(rules[r].Uninflected, `|`)) + rules[r].compiledUninflected = regexp.MustCompile(reString) + + // Prepare irregularMaps + irregularMaps[r] = make(cache, len(rules[r].Irregular)) + + // Set InflectorRule.compiledIrregular by joining the irregularItem.word of Inflector.Irregular + // into a single string then compile it. + vIrregulars := make([]string, len(rules[r].Irregular)) + for i, item := range rules[r].Irregular { + vIrregulars[i] = item.word + irregularMaps[r][item.word] = item.replacement + } + reString = fmt.Sprintf(`(?i)(.*)\b((?:%s))$`, strings.Join(vIrregulars, `|`)) + rules[r].compiledIrregular = regexp.MustCompile(reString) + + // Compile all patterns in InflectorRule.Rules + rules[r].compiledRules = make([]*compiledRule, len(rules[r].Rules)) + for i, item := range rules[r].Rules { + rules[r].compiledRules[i] = &compiledRule{item.replacement, regexp.MustCompile(item.pattern)} + } + + // Prepare caches + caches[r] = make(cache) + + return nil +} + +// merge slice a and slice b +func merge(a []string, b []string) []string { + result := make([]string, len(a)+len(b)) + copy(result, a) + copy(result[len(a):], b) + + return result +} + +func getInflected(r Rule, s string) string { + mutex.Lock() + defer mutex.Unlock() + if v, ok := caches[r][s]; ok { + return v + } + + // Check for irregular words + if res := rules[r].compiledIrregular.FindStringSubmatch(s); len(res) >= 3 { + var buf bytes.Buffer + + buf.WriteString(res[1]) + buf.WriteString(s[0:1]) + buf.WriteString(irregularMaps[r][strings.ToLower(res[2])][1:]) + + // Cache it then returns + caches[r][s] = buf.String() + return caches[r][s] + } + + // Check for uninflected words + if rules[r].compiledUninflected.MatchString(s) { + caches[r][s] = s + return caches[r][s] + } + + // Check each rule + for _, re := range rules[r].compiledRules { + if re.MatchString(s) { + caches[r][s] = re.ReplaceAllString(s, re.replacement) + return caches[r][s] + } + } + + // Returns unaltered + caches[r][s] = s + return caches[r][s] +} + +// Pluralize returns string s in plural form. +func Pluralize(s string) string { + return getInflected(Plural, s) +} + +// Singularize returns string s in singular form. +func Singularize(s string) string { + return getInflected(Singular, s) +} + +var ( + camelizeReg = regexp.MustCompile(`[^A-Za-z0-9]+`) +) + +// Camelize Converts a word like "send_email" to "SendEmail" +func Camelize(s string) string { + s = camelizeReg.ReplaceAllString(s, " ") + return strings.Replace(strings.Title(s), " ", "", -1) +} + +// Camel2id Converts a word like "SendEmail" to "send_email" +func Camel2id(name string) string { + var ( + value = commonInitialismsReplacer.Replace(name) + buf strings.Builder + lastCase, nextCase, nextNumber bool // upper case == true + curCase = value[0] <= 'Z' && value[0] >= 'A' + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' + nextNumber = value[i+1] >= '0' && value[i+1] <= '9' + + if curCase { + if lastCase && (nextCase || nextNumber) { + buf.WriteRune(v + 32) + } else { + if i > 0 && value[i-1] != '_' && value[i+1] != '_' { + buf.WriteByte('_') + } + buf.WriteRune(v + 32) + } + } else { + buf.WriteRune(v) + } + + lastCase = curCase + curCase = nextCase + } + + if curCase { + if !lastCase && len(value) > 1 { + buf.WriteByte('_') + } + buf.WriteByte(value[len(value)-1] + 32) + } else { + buf.WriteByte(value[len(value)-1]) + } + ret := buf.String() + return ret +} + +// Camel2words Converts a CamelCase name into space-separated words. +// For example, 'send_email' will be converted to 'Send Email'. +func Camel2words(s string) string { + s = camelizeReg.ReplaceAllString(s, " ") + return strings.Title(s) +} diff --git a/model.go b/model.go new file mode 100644 index 0000000..82f92cf --- /dev/null +++ b/model.go @@ -0,0 +1,987 @@ +package rest + +import ( + "context" + "encoding/csv" + "encoding/json" + "fmt" + "git.nobla.cn/golang/kos/util/arrays" + "git.nobla.cn/golang/kos/util/pool" + "git.nobla.cn/golang/rest/inflector" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "path" + "reflect" + "strconv" + "strings" + "time" +) + +type Model struct { + naming types.Naming //命名规则 + value reflect.Value //模块值 + db *gorm.DB //数据库 + primaryKey string //主键 + urlPrefix string //url前缀 + disableDomain bool //禁用域 + schemaLookup types.SchemaLookupFunc //获取schema的函数 + valueLookup types.ValueLookupFunc //查看域 + statement *gorm.Statement //字段声明 + formatter *Formatter //格式化 + response types.HttpWriter //HTTP响应 + hookMgr *hookManager //钩子管理器 + dirname string //存放文件目录 +} + +var ( + RuntimeScopeKey = &types.RuntimeScope{} +) + +// getDB 获取数据库连接对象 +func (m *Model) getDB() *gorm.DB { + return m.db +} + +// getFormatter 获取格式化组件 +func (m *Model) getFormatter() *Formatter { + if m.formatter != nil { + return m.formatter + } + return DefaultFormatter +} + +// getHook 获取钩子 +func (m *Model) getHook() *hookManager { + return m.hookMgr +} + +// hasScenario 判断是否有该场景 +func (m *Model) hasScenario(s string) bool { + return true +} + +// setValue 设置字段的值 +func (m *Model) setValue(refValue reflect.Value, column string, value any) { + SetFieldValue(m.statement, refValue, column, value) +} + +func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) { + SafeSetFileValue(m.statement, refValue, column, value) +} + +// getValue 获取字段的值 +func (m *Model) getValue(refValue reflect.Value, column string) interface{} { + return GetFieldValue(m.statement, refValue, column) +} + +// hasColumn 判断指定的列是否存在 +func (m *Model) hasColumn(column string) bool { + for _, field := range m.statement.Schema.Fields { + if field.DBName == column || field.Name == column { + return true + } + } + return false +} + +// getFilename 获取文件存放目录 +func (m *Model) getFilename(domain string, spec string, name string) string { + if m.dirname == "" { + m.dirname = os.TempDir() + } + filename := path.Join(m.dirname, domain, spec, time.Now().Format("20060102"), name) + if _, err := os.Stat(path.Dir(filename)); err != nil { + _ = os.MkdirAll(path.Dir(filename), 0755) + } + return filename +} + +// findPrimaryKey 查找主键的值 +func (m *Model) findPrimaryKey(uri string, r *http.Request) string { + var ( + pos int + ) + urlPath := r.URL.Path + pos = strings.IndexByte(uri, ':') + if pos > 0 { + return urlPath[pos:] + } + return "" +} + +// parseReportColumn 解析报表的列 +func (m *Model) parseReportColumn(name, props string) *types.SelectColumn { + var ( + key string + value string + ) + column := &types.SelectColumn{ + Name: inflector.Camel2id(name), + Native: false, + } + tokens := strings.Split(props, ";") + for _, token := range tokens { + pair := strings.SplitN(token, ":", 2) + if len(pair) == 0 { + continue + } + if len(pair) == 1 { + key = strings.TrimSpace(pair[0]) + value = "" + } else { + key = strings.TrimSpace(pair[0]) + value = strings.TrimSpace(pair[1]) + } + switch key { + case "native": + column.Native = true + case "name": + column.Name = value + case "expr": + column.Expr = value + } + } + return column +} + +func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + columns := make([]string, 0) + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + scenarios := field.Tag.Get("scenarios") + if !hasToken(types.ScenarioList, scenarios) { + continue + } + isPrimary := field.Tag.Get("is_primary") + if isPrimary != "true" { + continue + } + column := m.parseReportColumn(field.Name, field.Tag.Get("report")) + if !column.Native { + continue + } + if column.Expr == "" { + columns = append(columns, dest.QuoteColumn(ctx, column.Name)) + } else { + columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name))) + } + } + columns = append(columns, "COUNT(*) AS count") + query.Select(columns...) +} + +func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + columns := make([]string, 0) + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + scenarios := field.Tag.Get("scenarios") + if !hasToken(types.ScenarioList, scenarios) { + continue + } + column := m.parseReportColumn(field.Name, field.Tag.Get("report")) + if !column.Native { + continue + } + if column.Expr == "" { + columns = append(columns, dest.QuoteColumn(ctx, column.Name)) + } else { + columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name))) + } + } + query.Select(columns...) +} + +// buildCondition 构建sql条件 +func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) { + return BuildConditions(ctx, r, query, schemas) +} + +// ModuleName 模块名称 +func (m *Model) ModuleName() string { + return m.naming.ModuleName +} + +// TableName 表的名称 +func (m *Model) TableName() string { + return m.naming.ModuleName +} + +// Fields 返回搜索的模型的字段 +func (m *Model) Fields() []*schema.Field { + return m.statement.Schema.Fields +} + +// Uri 获取请求的uri +func (m *Model) Uri(scenario string) string { + ss := make([]string, 4) + if m.urlPrefix != "" { + ss = append(ss, m.urlPrefix) + } + switch scenario { + case types.ScenarioList: + ss = append(ss, m.naming.ModuleName, m.naming.Pluralize) + case types.ScenarioView: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") + case types.ScenarioCreate: + ss = append(ss, m.naming.ModuleName, m.naming.Singular) + case types.ScenarioUpdate: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") + case types.ScenarioDelete: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") + case types.ScenarioExport: + ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export") + case types.ScenarioImport: + ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import") + } + uri := path.Join(ss...) + if !strings.HasPrefix(uri, "/") { + uri = "/" + uri + } + return uri +} + +// Method 获取HTTP请求的方法 +func (m *Model) Method(scenario string) string { + var ( + method = http.MethodGet + ) + switch scenario { + case types.ScenarioCreate: + method = http.MethodPost + case types.ScenarioUpdate: + method = http.MethodPut + case types.ScenarioDelete: + method = http.MethodDelete + } + return method +} + +// Search 实现通过HTTP方法查找数据 +func (m *Model) Search(w http.ResponseWriter, r *http.Request) { + var ( + ok bool + err error + qs url.Values + page int + pageSize int + pageIndex int + query *Query + domainName string + modelSlices reflect.Value + modelValues reflect.Value + searchSchemas []*types.Schema + listSchemas []*types.Schema + modelValue reflect.Value + scenario string + reporter types.Reporter + namerTable tableNamer + ) + qs = r.URL.Query() + page, _ = strconv.Atoi(qs.Get("page")) + pageSize, _ = strconv.Atoi(qs.Get("pagesize")) + if pageSize <= 0 { + pageSize = defaultPageSize + } + pageIndex = page + if pageIndex > 0 { + pageIndex-- + } + modelValue = reflect.New(m.value.Type()) + //这里创建指针类型,这样的话就能在format里面调用函数 + if m.value.Kind() != reflect.Ptr { + modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0) + } else { + modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0) + } + modelValues = reflect.New(modelSlices.Type()) + modelValues.Elem().Set(modelSlices) + query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface()) + domainName = m.valueLookup(types.FieldDomain, w, r) + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + Request: r, + User: m.valueLookup("user", w, r), + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioList, + }) + if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + scenario = types.ScenarioList + if arrays.Exists(r.FormValue("scenario"), allowScenario) { + scenario = r.FormValue("scenario") + } + if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + if !m.disableDomain { + if m.hasColumn(types.FieldDomain) { + query.AndWhere(newCondition(types.FieldDomain, domainName)) + } + } + if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) + return + } + // 处理表名逻辑 + if namerTable, ok = query.Model().(tableNamer); ok { + query.From(namerTable.HttpTableName(r)) + } + //处理报表逻辑 + if reporter, ok = modelValue.Interface().(types.Reporter); ok { + query.From(reporter.RealTable()) + } + res := &types.ListResponse{ + Page: page, + PageSize: pageSize, + } + if reporter == nil { + res.TotalCount = query.Limit(0).Offset(0).Count(query.Model()) + } else { + //如果是报表的情况,需要手动指定COUNT的雨具逻辑才能生效 + m.buildReporterCountColumns(childCtx, reporter, query) + res.TotalCount = query.Limit(0).Offset(0).Count(query.Model()) + + //这里需要重置一下选项,不然会出问题 + query.ResetSelect() + query.GroupBy(reporter.GroupBy(childCtx)...) + } + query.Offset(pageIndex * pageSize).Limit(pageSize) + if res.TotalCount > 0 { + if reporter != nil { + m.buildReporterQueryColumns(childCtx, reporter, query) + } + if err = query.All(modelValues.Interface()); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + // 不进行格式化输出 + res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format")) + } else { + res.Data = make([]string, 0) + } + m.response.Success(w, res) +} + +// Create 实现通过HTTP方法创建模型 +func (m *Model) Create(w http.ResponseWriter, r *http.Request) { + var ( + err error + model any + schemas []*types.Schema + diffAttrs []*types.DiffAttr + domainName string + modelValue reflect.Value + ) + modelValue = reflect.New(m.value.Type()) + model = modelValue.Interface() + if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil) + return + } + domainName = m.valueLookup(types.FieldDomain, w, r) + if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + if !m.disableDomain { + if m.hasColumn(types.FieldDomain) { + m.setValue(modelValue, types.FieldDomain, domainName) + } + } + diffAttrs = make([]*types.DiffAttr, 0, 10) + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + User: m.valueLookup("user", w, r), + Request: r, + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioCreate, + Schemas: schemas, + }) + dbSess := m.getDB().WithContext(childCtx) + if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { + if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil { + return + } + if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil { + return + } + if tabler, ok := model.(types.Tabler); ok { + errTx = tx.Table(tabler.TableName()).Save(model).Error + } else { + errTx = tx.Save(model).Error + } + if errTx != nil { + return + } + for _, row := range schemas { + diffAttrs = append(diffAttrs, &types.DiffAttr{ + Column: row.Column, + Label: row.Label, + OldValue: nil, + NewValue: m.getValue(modelValue, row.Column), + }) + } + return + }); err == nil { + res := &types.CreateResponse{ + ID: m.getValue(modelValue, m.primaryKey), + Status: "created", + } + if creator, ok := model.(afterCreated); ok { + creator.AfterCreated(childCtx, dbSess) + } + if preserver, ok := model.(afterSaved); ok { + preserver.AfterSaved(childCtx, dbSess) + } + m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs) + m.getHook().afterSave(childCtx, dbSess, model, diffAttrs) + m.response.Success(w, res) + } else { + m.response.Failure(w, types.RequestCreateFailure, err.Error(), err) + } +} + +// Update 实现通过HTTP方法更新模型 +func (m *Model) Update(w http.ResponseWriter, r *http.Request) { + var ( + err error + model any + schemas []*types.Schema + diffAttrs []*types.DiffAttr + domainName string + modelValue reflect.Value + oldValues map[string]any + ) + idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r) + modelValue = reflect.New(m.value.Type()) + model = modelValue.Interface() + domainName = m.valueLookup(types.FieldDomain, w, r) + if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + conditions := map[string]any{ + m.primaryKey: idStr, + } + if err = m.getDB().Where(conditions).First(model).Error; err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + oldValues = make(map[string]any) + for _, row := range schemas { + oldValues[row.Column] = m.getValue(modelValue, row.Column) + } + if err = json.NewDecoder(r.Body).Decode(model); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) + return + } + diffAttrs = make([]*types.DiffAttr, 0, 10) + updates := make(map[string]any) + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + Request: r, + User: m.valueLookup("user", w, r), + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioUpdate, + Schemas: schemas, + PrimaryKeyValue: idStr, + }) + dbSess := m.getDB().WithContext(childCtx) + if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { + if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil { + return + } + if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil { + return + } + for _, row := range schemas { + v := m.getValue(modelValue, row.Column) + if oldValues[row.Column] != v { + updates[row.Column] = v + diffAttrs = append(diffAttrs, &types.DiffAttr{ + Column: row.Column, + Label: row.Label, + OldValue: oldValues[row.Column], + NewValue: v, + }) + } + } + if len(updates) > 0 { + if tabler, ok := model.(types.Tabler); ok { + errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error + } else { + errTx = tx.Model(model).Updates(updates).Error + } + if errTx != nil { + return + } + } + return + }); err == nil { + if updater, ok := model.(afterUpdated); ok { + updater.AfterUpdated(childCtx, dbSess) + } + if preserver, ok := model.(afterSaved); ok { + preserver.AfterSaved(childCtx, dbSess) + } + m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs) + m.getHook().afterSave(childCtx, dbSess, model, diffAttrs) + m.response.Success(w, types.UpdateResponse{ + ID: idStr, + Status: "updated", + }) + } else { + m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil) + } +} + +// Delete 实现通过HTTP方法删除模型 +func (m *Model) Delete(w http.ResponseWriter, r *http.Request) { + var ( + err error + model any + modelValue reflect.Value + ) + idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r) + modelValue = reflect.New(m.value.Type()) + model = modelValue.Interface() + conditions := map[string]any{ + m.primaryKey: idStr, + } + if err = m.getDB().Where(conditions).First(model).Error; err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: m.valueLookup(types.FieldDomain, w, r), + User: m.valueLookup("user", w, r), + Request: r, + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioDelete, + PrimaryKeyValue: idStr, + }) + dbSess := m.getDB().WithContext(childCtx) + if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { + if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil { + return + } + if tabler, ok := model.(types.Tabler); ok { + errTx = tx.Table(tabler.TableName()).Delete(model).Error + } else { + errTx = tx.Delete(model).Error + } + if errTx != nil { + return + } + m.getHook().afterDelete(childCtx, tx, model) + return + }); err == nil { + m.response.Success(w, types.DeleteResponse{ + ID: idStr, + Status: "deleted", + }) + } else { + m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil) + } +} + +// View 查看数据详情 +func (m *Model) View(w http.ResponseWriter, r *http.Request) { + var ( + err error + model any + modelValue reflect.Value + qs url.Values + schemas []*types.Schema + scenario string + domainName string + ) + qs = r.URL.Query() + idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r) + modelValue = reflect.New(m.value.Type()) + model = modelValue.Interface() + conditions := map[string]any{ + m.primaryKey: idStr, + } + domainName = m.valueLookup(types.FieldDomain, w, r) + if err = m.getDB().Where(conditions).First(model).Error; err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + scenario = qs.Get("scenario") + if scenario == "" { + schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioView) + } else { + schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario) + } + if err == nil { + m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format"))) + } else { + m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil) + } +} + +// Export 实现通过HTTP方法导出模型 +func (m *Model) Export(w http.ResponseWriter, r *http.Request) { + var ( + err error + query *Query + modelSlices reflect.Value + modelValues reflect.Value + searchSchemas []*types.Schema + exportSchemas []*types.Schema + domainName string + fp *os.File + modelValue reflect.Value + ) + if !m.hasScenario(types.ScenarioList) { + m.response.Failure(w, types.RequestDenied, "request denied", nil) + return + } + domainName = m.valueLookup(types.FieldDomain, w, r) + filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix())) + if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil) + return + } + defer func() { + _ = fp.Close() + }() + modelValue = reflect.New(m.value.Type()) + //这里创建指针类型,这样的话就能在format里面调用函数 + if m.value.Kind() != reflect.Ptr { + modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0) + } else { + modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0) + } + modelValues = reflect.New(modelSlices.Type()) + modelValues.Elem().Set(modelSlices) + query = NewQuery(m.getDB(), modelValue.Interface()) + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + Request: r, + User: m.valueLookup("user", w, r), + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioExport, + }) + if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) + return + } + if !m.disableDomain { + if m.hasColumn(types.FieldDomain) { + query.AndWhere(newCondition(types.FieldDomain, domainName)) + } + } + // 处理表名逻辑 + if namerTable, ok := query.Model().(tableNamer); ok { + query.From(namerTable.HttpTableName(r)) + } + //处理报表逻辑 + if reporter, ok := modelValue.Interface().(types.Reporter); ok { + query.From(reporter.RealTable()) + query.GroupBy(reporter.GroupBy(childCtx)...) + m.buildReporterQueryColumns(childCtx, reporter, query) + } + if err = query.All(modelValues.Interface()); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular)) + value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "") + writer := csv.NewWriter(fp) + rows := make([]string, len(exportSchemas)) + for i, field := range exportSchemas { + rows[i] = field.Label + } + _ = writer.Write(rows) + if values, ok := value.([]any); ok { + for _, val := range values { + row, ok2 := val.(map[string]any) + if !ok2 { + continue + } + for i, field := range exportSchemas { + if v, ok := row[field.Column]; ok { + rows[i] = fmt.Sprint(v) + } else { + rows[i] = "" + } + } + _ = writer.Write(rows) + } + } + writer.Flush() + m.getHook().afterExport(childCtx, filename) + http.ServeContent(w, r, path.Base(filename), time.Now(), fp) +} + +// findSchema 查找指定的schema +func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema { + for _, row := range schemas { + if row.Label == label { + return row + } + } + return nil +} + +// importInternal 文件上传方法 +func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) { + var ( + err error + rows []string + fp *os.File + tm time.Time + fields []string + sess *gorm.DB + csvReader *csv.Reader + csvWriter *csv.Writer + modelValue reflect.Value + modelEntity any + diffAttrs []*types.DiffAttr + result *types.ImportResult + failureFp *os.File + failureFile string + ) + tm = time.Now() + result = &types.ImportResult{} + if fp, err = os.Open(filename); err != nil { + result.Code = types.ErrImportFileNotExists + goto __end + } + defer func() { + _ = fp.Close() + }() + csvReader = csv.NewReader(fp) + if rows, err = csvReader.Read(); err != nil { + result.Code = types.ErrImportFileUnavailable + goto __end + } + fields = make([]string, 0, len(rows)) + for _, s := range rows { + v := m.findSchema(s, schemas) + if v == nil { + result.Code = types.ErrImportColumnNotMatch + goto __end + } + fields = append(fields, v.Column) + } + sess = m.getDB().WithContext(ctx) + //失败文件指针 + failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix())) + if failureFp, err = os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil { + return + } + defer func() { + _ = failureFp.Close() + }() + csvWriter = csv.NewWriter(failureFp) + rows = append(rows, "Error") + _ = csvWriter.Write(rows) + diffAttrs = make([]*types.DiffAttr, len(schemas)) + for { + if rows, err = csvReader.Read(); err != nil { + break + } + result.TotalCount++ + if len(rows) != len(fields) { + continue + } + modelValue = reflect.New(m.value.Type()) + for idx, field := range fields { + m.safeSetValue(modelValue, field, rows[idx]) + } + if len(extraFields) > 0 { + for k, v := range extraFields { + m.safeSetValue(modelValue, k, v) + } + } + modelEntity = modelValue.Interface() + //写入数据 + if fast { + //如果是快速模式,直接存储数据 + if err = sess.Save(modelEntity).Error; err == nil { + result.SuccessCount++ + } else { + rows = append(rows, err.Error()) + _ = csvWriter.Write(rows) + } + } else { + if err = sess.Transaction(func(tx *gorm.DB) (errTx error) { + if errTx = m.getHook().beforeCreate(ctx, tx, modelEntity); errTx != nil { + return + } + if errTx = m.getHook().beforeSave(ctx, tx, modelEntity); errTx != nil { + return + } + if tabler, ok := modelEntity.(types.Tabler); ok { + errTx = tx.Table(tabler.TableName()).Save(modelEntity).Error + } else { + errTx = tx.Save(modelEntity).Error + } + if errTx != nil { + return + } + for idx, row := range schemas { + diffAttrs[idx] = &types.DiffAttr{ + Column: row.Column, + Label: row.Label, + NewValue: m.getValue(modelValue, row.Column), + } + } + m.getHook().afterCreate(ctx, tx, modelEntity, diffAttrs) + m.getHook().afterSave(ctx, tx, modelEntity, diffAttrs) + return + }); err == nil { + result.SuccessCount++ + } else { + rows = append(rows, err.Error()) + _ = csvWriter.Write(rows) + } + } + } + csvWriter.Flush() +__end: + result.UploadFile = filename + if result.TotalCount > result.SuccessCount { + result.FailureFile = failureFile + } + result.Duration = time.Now().Sub(tm) + m.getHook().afterImport(ctx, result) +} + +// Import 实现通过HTTP方法导入 +func (m *Model) Import(w http.ResponseWriter, r *http.Request) { + var ( + err error + fast bool + schemas []*types.Schema + rows []string + domainName string + dst *os.File + fp multipart.File + csvWriter *csv.Writer + qs url.Values + extraFields map[string]string + ) + domainName = m.valueLookup(types.FieldDomain, w, r) + if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + //这里用background的context + childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + User: m.valueLookup("user", w, r), + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioImport, + Schemas: schemas, + }) + if r.Method == http.MethodGet { + //下载导入模板 + csvWriter = csv.NewWriter(w) + rows = make([]string, 0, len(schemas)) + for _, row := range schemas { + //主键不需要导入 + if row.IsPrimaryKey == 1 { + continue + } + rows = append(rows, row.Label) + } + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular)) + err = csvWriter.Write(rows) + csvWriter.Flush() + return + } + filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix())) + if fp, _, err = r.FormFile("file"); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil) + return + } + defer func() { + _ = fp.Close() + }() + if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err == nil { + buf := pool.GetBytes(32 * 1024) + _, err = io.CopyBuffer(dst, fp, buf) + pool.PutBytes(buf) + _ = dst.Close() + } else { + m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil) + return + } + qs = r.URL.Query() + if qs != nil { + extraFields = make(map[string]string) + for k, _ := range qs { + if strings.HasPrefix(k, "_attr_") { + extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k) + } + } + } + fast, _ = strconv.ParseBool(qs.Get("__fast")) + go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields) + m.response.Success(w, types.ImportResponse{ + UID: m.valueLookup("user", w, r), + Status: "committed", + }) +} + +// newModel 创建一个模型 +func newModel(v any, db *gorm.DB, naming types.Naming) *Model { + model := &Model{ + db: db, + naming: naming, + response: &httpWriter{}, + value: reflect.Indirect(reflect.ValueOf(v)), + valueLookup: defaultValueLookup, + } + model.statement = &gorm.Statement{ + DB: model.getDB(), + ConnPool: model.getDB().ConnPool, + Clauses: map[string]clause.Clause{}, + } + if err := model.statement.Parse(v); err == nil { + if model.statement.Schema.PrimaryFieldDBNames != nil && len(model.statement.Schema.PrimaryFieldDBNames) > 0 { + model.primaryKey = model.statement.Schema.PrimaryFieldDBNames[0] + } + } + return model +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..7746e20 --- /dev/null +++ b/options.go @@ -0,0 +1,52 @@ +package rest + +import "git.nobla.cn/golang/rest/types" + +type Options struct { + urlPrefix string + moduleName string + disableDomain bool + router types.HttpRouter + writer types.HttpWriter + formatter *Formatter + dirname string //文件目录 +} + +type Option func(o *Options) + +func WithUriPrefix(s string) Option { + return func(o *Options) { + o.urlPrefix = s + } +} + +func WithModuleName(s string) Option { + return func(o *Options) { + o.moduleName = s + } +} + +// WithoutDomain 禁用域 +func WithoutDomain() Option { + return func(o *Options) { + o.disableDomain = true + } +} + +func WithHttpRouter(s types.HttpRouter) Option { + return func(o *Options) { + o.router = s + } +} + +func WithHttpWriter(s types.HttpWriter) Option { + return func(o *Options) { + o.writer = s + } +} + +func WithFormatter(s *Formatter) Option { + return func(o *Options) { + o.formatter = s + } +} diff --git a/plugins/cache/cache.go b/plugins/cache/cache.go new file mode 100644 index 0000000..ad2a8c5 --- /dev/null +++ b/plugins/cache/cache.go @@ -0,0 +1,108 @@ +package cache + +import ( + "encoding/json" + "fmt" + "git.nobla.cn/golang/kos/pkg/cache" + xxhash "github.com/cespare/xxhash/v2" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "os" + "strconv" + "time" +) + +const ( + DisableCache = "DISABLE_CACHE" + DurationKey = "gorm:cache_duration" +) + +type Cacher struct { + rawQuery func(db *gorm.DB) +} + +func (c *Cacher) Name() string { + return "gorm:cache" +} + +func (c *Cacher) Initialize(db *gorm.DB) (err error) { + c.rawQuery = db.Callback().Query().Get("gorm:query") + err = db.Callback().Query().Replace("gorm:query", c.Query) + return +} + +// buildCacheKey 构建一个缓存的KEY +func (c *Cacher) buildCacheKey(db *gorm.DB) string { + s := strconv.FormatUint(xxhash.Sum64String(db.Statement.SQL.String()+fmt.Sprintf("%v", db.Statement.Vars)), 10) + return s +} + +// getDuration 获取缓存时长 +func (c *Cacher) getDuration(db *gorm.DB) time.Duration { + var ( + ok bool + v any + duration time.Duration + ) + if v, ok = db.InstanceGet(DurationKey); !ok { + return 0 + } + if duration, ok = v.(time.Duration); !ok { + return 0 + } + return duration +} + +// tryLoad 尝试从缓存读取数据 +func (c *Cacher) tryLoad(key string, db *gorm.DB) (err error) { + var ( + ok bool + buf []byte + ) + if buf, ok = cache.Get(db.Statement.Context, key); ok { + err = json.Unmarshal(buf, db.Statement.Dest) + } else { + err = os.ErrNotExist + } + return +} + +// storeCache 存储缓存数据 +func (c *Cacher) storeCache(key string, db *gorm.DB, duration time.Duration) (err error) { + var ( + buf []byte + ) + if buf, err = json.Marshal(db.Statement.Dest); err == nil { + cache.SetEx(db.Statement.Context, key, buf, duration) + } + return +} + +func (c *Cacher) Query(db *gorm.DB) { + var ( + err error + cacheKey string + duration time.Duration + ) + duration = c.getDuration(db) + if duration <= 0 { + c.rawQuery(db) + return + } + callbacks.BuildQuerySQL(db) + cacheKey = c.buildCacheKey(db) + if err = c.tryLoad(cacheKey, db); err == nil { + return + } + c.rawQuery(db) + if db.Error == nil { + //store cache + if err = c.storeCache(cacheKey, db, duration); err != nil { + _ = db.AddError(err) + } + } +} + +func New() *Cacher { + return &Cacher{} +} diff --git a/plugins/identity/identified.go b/plugins/identity/identified.go new file mode 100644 index 0000000..c26b806 --- /dev/null +++ b/plugins/identity/identified.go @@ -0,0 +1,57 @@ +package identity + +import ( + "github.com/rs/xid" + "gorm.io/gorm" + "gorm.io/gorm/schema" + "reflect" +) + +type Identify struct { +} + +func (identity *Identify) Name() string { + return "gorm:identity" +} + +func (identity *Identify) Initialize(db *gorm.DB) (err error) { + err = db.Callback().Create().Before("gorm:create").Register("auto_identified", identity.Grant) + return +} + +func (identity *Identify) NextID() string { + return xid.New().String() +} + +func (identity *Identify) Grant(db *gorm.DB) { + var ( + err error + field *schema.Field + ) + if db.Statement.Schema == nil { + return + } + if field = db.Statement.Schema.LookUpField("ID"); field == nil { + return + } + if field.DataType != schema.String { + return + } + if db.Statement.ReflectValue.Kind() == reflect.Array || db.Statement.ReflectValue.Kind() == reflect.Slice { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue.Index(i)); zero { + if err = field.Set(db.Statement.Context, db.Statement.ReflectValue.Index(i), identity.NextID()); err != nil { + _ = db.AddError(err) + } + } + } + } else { + if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); zero { + db.Statement.SetColumn("ID", identity.NextID()) + } + } +} + +func New() *Identify { + return &Identify{} +} diff --git a/plugins/sharding/README.md b/plugins/sharding/README.md new file mode 100644 index 0000000..409de7a --- /dev/null +++ b/plugins/sharding/README.md @@ -0,0 +1,33 @@ +# 分表实现 + +首先定义一个`gorm`的模型,然后实现`shadring.Model`接口,比如如下示例 + +```go +// ShardingTable 返回增删改时候操作的数据表 +func (model *CdrLog) ShardingTable(scene string) string { + return model.TableName() +} + +// ShardingTables 返回查询时候一个范围内的表 +func (model *CdrLog) ShardingTables(ctx *sharding.Context) []string { + var ( + timestamp int64 + ) + timeRange := make([]int64, 0) + values := ctx.FindColumnValues("start_stamp") + if len(values) == 0 { + values = ctx.FindColumnValues("create_stamp") + } + if len(values) > 0 { + for _, v := range values { + timestamp, _ = strconv.ParseInt(fmt.Sprint(v), 10, 64) + timeRange = append(timeRange, timestamp) + } + } + return shard.DateTableNames(ctx.Context(), "cdr_logs", shard.ShardTypeDateMonth, timeRange) +} +``` + +`ShardingTable`方法是操作增删改的是回调具体表名的方法 + +`ShardingTables`方法是操作查询的时候,通过查询条件返回的表名的方法 \ No newline at end of file diff --git a/plugins/sharding/condition.go b/plugins/sharding/condition.go new file mode 100644 index 0000000..4b6c00f --- /dev/null +++ b/plugins/sharding/condition.go @@ -0,0 +1,120 @@ +package sharding + +const ( + ValueOperaEqual = iota + 0x10 + ValueOperaGreater + ValueOperaLess + ValueOperaRange + + ValueTypeString = iota + 0x30 + ValueTypeNumber + ValueTypeBoolean + ValueTypeNull + ValueTypeAny +) + +type ( + ColumnCondition struct { + Name string `json:"name"` + Value CondValue `json:"value"` + } + + CondValue interface { + Type() int + Opera() int + Value() any + } + + equalValue struct { + vType int + vData any + } + + rangeValue struct { + vType int + vData any + } + + greaterValue struct { + vType int + vData any + } + + lessValue struct { + vType int + vData any + } +) + +func (v *lessValue) Type() int { + return v.vType +} + +func (v *lessValue) Opera() int { + return ValueOperaLess +} + +func (v *lessValue) Value() any { + return v.vData +} + +func (v *greaterValue) Type() int { + return v.vType +} + +func (v *greaterValue) Opera() int { + return ValueOperaGreater +} + +func (v *greaterValue) Value() any { + return v.vData +} + +func (v *rangeValue) Type() int { + return v.vType +} + +func (v *rangeValue) Opera() int { + return ValueOperaRange +} + +func (v *rangeValue) Value() any { + return v.vData +} + +func (v *equalValue) Type() int { + return v.vType +} + +func (v *equalValue) Opera() int { + return ValueOperaEqual +} + +func (v *equalValue) Value() any { + return v.vData +} + +func newCondValue(vType int, op int, value any) CondValue { + switch op { + case ValueOperaGreater: + return &greaterValue{ + vType: vType, + vData: value, + } + case ValueOperaLess: + return &lessValue{ + vType: vType, + vData: value, + } + case ValueOperaRange: + return &rangeValue{ + vType: vType, + vData: value, + } + default: + return &equalValue{ + vType: vType, + vData: value, + } + } +} diff --git a/plugins/sharding/scope.go b/plugins/sharding/scope.go new file mode 100644 index 0000000..fe1bc53 --- /dev/null +++ b/plugins/sharding/scope.go @@ -0,0 +1,287 @@ +package sharding + +import ( + "context" + "github.com/longbridgeapp/sqlparser" + sqlparserX "github.com/uole/sqlparser" + "gorm.io/gorm" + "strconv" +) + +type Scope struct { + db *gorm.DB + stmt *sqlparser.SelectStatement + stmtX *sqlparserX.Select +} + +func (scope *Scope) findValue(express sqlparser.Expr) (int, any) { + var ( + vType int + vData any + ) + switch expr := express.(type) { + case *sqlparser.BindExpr: + vType = ValueTypeAny + if len(scope.db.Statement.Vars) > expr.Pos { + vData = scope.db.Statement.Vars[expr.Pos] + } else { + vType = ValueTypeNull + } + case *sqlparser.NumberLit: + vType = ValueTypeNumber + vData = expr.Value + case *sqlparser.StringLit: + vType = ValueTypeString + vData = expr.Value + case *sqlparser.BoolLit: + vType = ValueTypeBoolean + vData = expr.Value + case *sqlparser.BlobLit: + vType = ValueTypeString + vData = expr.Value + case *sqlparser.NullLit: + vType = ValueTypeNull + case *sqlparser.Range: + arr := make([]any, 2) + vType, arr[0] = scope.findValue(expr.X) + vType, arr[1] = scope.findValue(expr.Y) + vData = arr + } + return vType, vData +} + +func (scope *Scope) findValueX(express *sqlparserX.SQLVal) (int, any) { + var ( + vType int + vData any + ) + switch express.Type { + case sqlparserX.IntVal: + vType = ValueTypeNumber + vData, _ = strconv.Atoi(string(express.Val)) + case sqlparserX.FloatVal: + vType = ValueTypeNumber + vData, _ = strconv.ParseFloat(string(express.Val), 64) + case sqlparserX.ValArg: + vType = ValueTypeAny + pos, _ := strconv.Atoi(string(express.Val[2:])) + if pos > 0 { + pos = pos - 1 + if len(scope.db.Statement.Vars) > pos { + vData = scope.db.Statement.Vars[pos] + } else { + vType = ValueTypeNull + } + } else { + vType = ValueTypeNull + } + default: + vType = ValueTypeString + vData = string(express.Val) + } + return vType, vData +} + +func (scope *Scope) recursiveFindX(expr sqlparserX.Expr, column string) (conditions []*ColumnCondition) { + var ( + ok bool + andExpr *sqlparserX.AndExpr + orExpr *sqlparserX.OrExpr + parentExpr *sqlparserX.ParenExpr + comparisonExpr *sqlparserX.ComparisonExpr + rangeExpr *sqlparserX.RangeCond + coumnExpr *sqlparserX.ColName + valueExpr *sqlparserX.SQLVal + ) + conditions = make([]*ColumnCondition, 0, 2) + if comparisonExpr, ok = expr.(*sqlparserX.ComparisonExpr); ok { + if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok { + return + } + if valueExpr, ok = comparisonExpr.Right.(*sqlparserX.SQLVal); !ok { + return + } + if coumnExpr.Name.EqualString(column) { + cond := &ColumnCondition{ + Name: coumnExpr.Name.String(), + } + vType, vData := scope.findValueX(valueExpr) + switch comparisonExpr.Operator { + case sqlparserX.LessThanStr, sqlparserX.LessEqualStr: + cond.Value = newCondValue(vType, ValueOperaLess, vData) + case sqlparserX.GreaterThanStr, sqlparserX.GreaterEqualStr: + cond.Value = newCondValue(vType, ValueOperaGreater, vData) + case sqlparserX.EqualStr: + cond.Value = newCondValue(vType, ValueOperaEqual, vData) + } + if cond.Value != nil { + conditions = append(conditions, cond) + } + } + } + + if rangeExpr, ok = expr.(*sqlparserX.RangeCond); ok { + if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok { + return + } + if coumnExpr.Name.EqualString(column) { + vType := 0 + arr := make([]any, 2) + if valueExpr, ok = rangeExpr.From.(*sqlparserX.SQLVal); ok { + vType, arr[0] = scope.findValueX(valueExpr) + } + if valueExpr, ok = rangeExpr.To.(*sqlparserX.SQLVal); ok { + vType, arr[1] = scope.findValueX(valueExpr) + } + conditions = append(conditions, &ColumnCondition{ + Name: coumnExpr.Name.String(), + Value: newCondValue(vType, ValueOperaRange, arr), + }) + } + } + + if andExpr, ok = expr.(*sqlparserX.AndExpr); ok { + if andExpr.Left != nil { + conditions = append(conditions, scope.recursiveFindX(andExpr.Left, column)...) + } + if andExpr.Right != nil { + conditions = append(conditions, scope.recursiveFindX(andExpr.Right, column)...) + } + } + if orExpr, ok = expr.(*sqlparserX.OrExpr); ok { + if orExpr.Left != nil { + conditions = append(conditions, scope.recursiveFindX(orExpr.Left, column)...) + } + if orExpr.Right != nil { + conditions = append(conditions, scope.recursiveFindX(orExpr.Right, column)...) + } + } + if parentExpr, ok = expr.(*sqlparserX.ParenExpr); ok { + if parentExpr.Expr != nil { + conditions = append(conditions, scope.recursiveFindX(parentExpr.Expr, column)...) + } + } + return conditions +} + +func (scope *Scope) recursiveFind(expr sqlparser.Expr, column string) []*ColumnCondition { + var ( + ok bool + identExpr *sqlparser.Ident + binaryExpr *sqlparser.BinaryExpr + parentExpr *sqlparser.ParenExpr + conditions []*ColumnCondition + ) + conditions = make([]*ColumnCondition, 0, 2) + if parentExpr, ok = expr.(*sqlparser.ParenExpr); ok { + if parentExpr.X != nil { + if _, ok = parentExpr.X.(*sqlparser.BinaryExpr); ok { + conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...) + } + if _, ok = parentExpr.X.(*sqlparser.ParenExpr); ok { + conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...) + } + } + } + + if binaryExpr, ok = expr.(*sqlparser.BinaryExpr); ok { + if binaryExpr.X != nil { + if identExpr, ok = binaryExpr.X.(*sqlparser.Ident); ok { + if identExpr.Name == column { + cond := &ColumnCondition{ + Name: identExpr.Name, + } + vType, vData := scope.findValue(binaryExpr.Y) + switch binaryExpr.Op { + case sqlparser.LT, sqlparser.LE: + cond.Value = newCondValue(vType, ValueOperaLess, vData) + case sqlparser.GT, sqlparser.GE: + cond.Value = newCondValue(vType, ValueOperaGreater, vData) + case sqlparser.RANGE, sqlparser.BETWEEN: + cond.Value = newCondValue(vType, ValueOperaRange, vData) + case sqlparser.EQ: + cond.Value = newCondValue(vType, ValueOperaEqual, vData) + } + if cond.Value != nil { + conditions = append(conditions, cond) + } + } + } else { + if _, ok = binaryExpr.X.(*sqlparser.BinaryExpr); ok { + conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...) + } + if _, ok = binaryExpr.X.(*sqlparser.ParenExpr); ok { + conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...) + } + } + } + + if binaryExpr.Y != nil { + if _, ok = binaryExpr.Y.(*sqlparser.BinaryExpr); ok { + conditions = append(conditions, scope.recursiveFind(binaryExpr.Y, column)...) + } + } + } + + return conditions +} + +func (scope *Scope) DB() *gorm.DB { + return scope.db +} + +func (scope *Scope) Context() context.Context { + return scope.db.Statement.Context +} + +func (scope *Scope) FindCondition(column string) []*ColumnCondition { + if scope.stmtX != nil { + if scope.stmtX.Where == nil { + return []*ColumnCondition{} + } + return scope.recursiveFindX(scope.stmtX.Where.Expr, column) + } + return scope.recursiveFind(scope.stmt.Condition, column) +} + +func (scope *Scope) FindColumnValues(column string) []any { + result := make([]any, 0) + conditions := scope.FindCondition(column) + if len(conditions) == 0 { + return result + } + for _, cond := range conditions { + if cond.Value.Opera() == ValueOperaGreater { + if len(result) == 0 { + result = make([]any, 2) + } + result[0] = cond.Value.Value() + } + + if cond.Value.Opera() == ValueOperaLess { + if len(result) == 0 { + result = make([]any, 2) + } + result[1] = cond.Value.Value() + } + + if cond.Value.Opera() == ValueOperaEqual { + result = append(result, cond.Value.Value()) + break + } + + if cond.Value.Opera() == ValueOperaRange { + if vs, ok := cond.Value.Value().([]any); ok { + if len(vs) == 2 { + if len(result) == 0 { + result = make([]any, 2) + } + result[0] = vs[0] + result[1] = vs[1] + } + } + break + } + } + return result +} diff --git a/plugins/sharding/sharding.go b/plugins/sharding/sharding.go new file mode 100644 index 0000000..3d3ac70 --- /dev/null +++ b/plugins/sharding/sharding.go @@ -0,0 +1,476 @@ +package sharding + +import ( + "github.com/longbridgeapp/sqlparser" + sqlparserX "github.com/uole/sqlparser" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "reflect" + "strings" +) + +type Sharding struct { + UnionAll bool + QuoteChar byte +} + +func (plugin *Sharding) Name() string { + return "gorm:sharding" +} + +func (plugin *Sharding) Initialize(db *gorm.DB) (err error) { + if err = db.Callback().Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil { + return + } + if err = db.Callback().Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil { + return + } + if err = db.Callback().Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil { + return + } + if err = db.Callback().Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX); err != nil { + return err + } + return +} + +func (plugin *Sharding) Create(db *gorm.DB) { + var ( + ok bool + scopeModel Model + refValue reflect.Value + modelValue any + ) + if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array { + if db.Statement.ReflectValue.Len() > 0 { + refValue = db.Statement.ReflectValue.Index(0) + modelValue = refValue.Interface() + } + } else { + if db.Statement.Model != nil { + modelValue = db.Statement.Model + } else { + modelValue = db.Statement.ReflectValue.Interface() + } + } + if modelValue != nil { + if scopeModel, ok = modelValue.(Model); ok { + db.Table(scopeModel.ShardingTable(sceneCreate)) + } + } +} + +func (plugin *Sharding) Update(db *gorm.DB) { + var ( + ok bool + scopeModel Model + refValue reflect.Value + modelValue any + ) + if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array { + if db.Statement.ReflectValue.Len() > 0 { + refValue = db.Statement.ReflectValue.Index(0) + modelValue = refValue.Interface() + } + } else { + if db.Statement.Model != nil { + modelValue = db.Statement.Model + } else { + modelValue = db.Statement.ReflectValue.Interface() + } + } + if modelValue != nil { + if scopeModel, ok = modelValue.(Model); ok { + db.Table(scopeModel.ShardingTable(sceneUpdate)) + } + } +} + +func (plugin *Sharding) Delete(db *gorm.DB) { + var ( + ok bool + scopeModel Model + refValue reflect.Value + modelValue any + ) + if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array { + if db.Statement.ReflectValue.Len() > 0 { + refValue = db.Statement.ReflectValue.Index(0) + modelValue = refValue.Interface() + } + } else { + if db.Statement.Model != nil { + modelValue = db.Statement.Model + } else { + modelValue = db.Statement.ReflectValue.Interface() + } + } + if modelValue != nil { + if scopeModel, ok = modelValue.(Model); ok { + db.Table(scopeModel.ShardingTable(sceneDelete)) + } + } +} + +func (plugin *Sharding) Query(db *gorm.DB) { + var ( + err error + ok bool + shardingModel Model + modelValue any + tables []string + rawVars []any + refValue reflect.Value + tableName *sqlparser.TableName + selectStmt *sqlparser.SelectStatement + stmt sqlparser.Statement + parser *sqlparser.Parser + numOfTable int + orderByExpr []*sqlparser.OrderingTerm + limitExpr sqlparser.Expr + offsetExpr sqlparser.Expr + groupingExpr []sqlparser.Expr + havingExpr sqlparser.Expr + isCountStatement bool + countField string + ) + if db.Statement.Model != nil { + refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type()) + } else { + refValue = reflect.New(db.Statement.ReflectValue.Type()) + } + if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct { + refValue = reflect.Indirect(refValue) + } + if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice { + elemType := refValue.Type().Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + modelValue = reflect.New(elemType).Interface() + } else { + modelValue = refValue.Interface() + } + if shardingModel, ok = modelValue.(Model); !ok { + return + } + if db.Statement.SQL.Len() == 0 { + callbacks.BuildQuerySQL(db) + } + parser = sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String())) + if stmt, err = parser.ParseStatement(); err != nil { + return + } + if selectStmt, ok = stmt.(*sqlparser.SelectStatement); !ok { + return + } + tables = shardingModel.ShardingTables(&Scope{ + db: db, + stmt: selectStmt, + }) + numOfTable = len(tables) + if numOfTable <= 1 { + return + } + rawVars = make([]any, 0, len(db.Statement.Vars)) + for _, v := range db.Statement.Vars { + rawVars = append(rawVars, v) + } + //是否是查询count语句 + //如果不是count的语句,添加order和group的支持 + if v := db.Statement.Context.Value("@sql_count_statement"); v != nil { + if v == true { + isCountStatement = true + } + } + if !isCountStatement && len(*selectStmt.Columns) == 1 { + for _, column := range *selectStmt.Columns { + if expr, ok := column.Expr.(*sqlparser.Call); ok { + if expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword { + isCountStatement = true + countField = expr.String() + break + } + } + } + } + + if len(selectStmt.OrderBy) > 0 { + orderByExpr = make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy)) + for _, row := range selectStmt.OrderBy { + orderByExpr = append(orderByExpr, row) + } + selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0) + } + if len(selectStmt.GroupingElements) > 0 { + groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements)) + for _, row := range selectStmt.GroupingElements { + groupingExpr = append(groupingExpr, row) + } + if selectStmt.HavingCondition != nil { + havingExpr = selectStmt.HavingCondition + } + } + if selectStmt.Limit != nil { + limitExpr = selectStmt.Limit + selectStmt.Limit = nil + } + if selectStmt.Offset != nil { + offsetExpr = selectStmt.Offset + selectStmt.Offset = nil + } + db.Statement.SQL.Reset() + if isCountStatement { + db.Statement.SQL.WriteString("SELECT SUM(") + db.Statement.SQL.WriteByte(plugin.QuoteChar) + db.Statement.SQL.WriteString(strings.Trim(countField, "`")) + db.Statement.SQL.WriteByte(plugin.QuoteChar) + db.Statement.SQL.WriteString(") FROM (") + } else { + db.Statement.SQL.WriteString("SELECT * FROM (") + } + for i, name := range tables { + db.Statement.SQL.WriteByte('(') + if tableName, ok = selectStmt.FromItems.(*sqlparser.TableName); ok { + tableName.Name.Name = name + } + db.Statement.SQL.WriteString(selectStmt.String()) + db.Statement.SQL.WriteByte(')') + if i < numOfTable-1 { + if plugin.UnionAll { + db.Statement.SQL.WriteString(" UNION ALL ") + } else { + db.Statement.SQL.WriteString(" UNION ") + } + } + if i > 0 { + //copy vars + db.Statement.Vars = append(db.Statement.Vars, rawVars...) + } + } + db.Statement.SQL.WriteString(") tbl ") + if !isCountStatement { + if len(groupingExpr) > 0 { + db.Statement.SQL.WriteString(" GROUP BY ") + for i, expr := range groupingExpr { + if i != 0 { + db.Statement.SQL.WriteString(", ") + } + db.Statement.SQL.WriteString(expr.String()) + } + + if havingExpr != nil { + db.Statement.SQL.WriteString(" HAVING ") + db.Statement.SQL.WriteString(havingExpr.String()) + } + } + + if orderByExpr != nil && len(orderByExpr) > 0 { + db.Statement.SQL.WriteString(" ORDER BY ") + for i, term := range orderByExpr { + if i != 0 { + db.Statement.SQL.WriteString(", ") + } + db.Statement.SQL.WriteString(term.String()) + } + } + if limitExpr != nil { + db.Statement.SQL.WriteString(" LIMIT ") + db.Statement.SQL.WriteString(limitExpr.String()) + if offsetExpr != nil { + db.Statement.SQL.WriteString(" OFFSET ") + db.Statement.SQL.WriteString(offsetExpr.String()) + } + } + } + return +} + +func (plugin *Sharding) QueryX(db *gorm.DB) { + var ( + err error + ok bool + shardingModel Model + modelValue any + tables []string + rawVars []any + refValue reflect.Value + selectStmt *sqlparserX.Select + stmt sqlparserX.Statement + numOfTable int + isCountStatement bool + isPureCountStatement bool + trackedBuffer *sqlparserX.TrackedBuffer + funcExpr *sqlparserX.FuncExpr + aliasedExpr *sqlparserX.AliasedExpr + orderByExpr sqlparserX.OrderBy + groupByExpr sqlparserX.GroupBy + havingExpr *sqlparserX.Where + limitExpr *sqlparserX.Limit + ) + if db.Statement.Model != nil { + refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type()) + } else { + if !db.Statement.ReflectValue.IsValid() { + return + } + refValue = reflect.New(db.Statement.ReflectValue.Type()) + } + if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct { + refValue = reflect.Indirect(refValue) + } + if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice { + elemType := refValue.Type().Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + modelValue = reflect.New(elemType).Interface() + } else { + modelValue = refValue.Interface() + } + if shardingModel, ok = modelValue.(Model); !ok { + return + } + if db.Statement.SQL.Len() == 0 { + callbacks.BuildQuerySQL(db) + } + if stmt, err = sqlparserX.Parse(db.Statement.SQL.String()); err != nil { + return + } + if selectStmt, ok = stmt.(*sqlparserX.Select); !ok { + return + } + tables = shardingModel.ShardingTables(&Scope{ + db: db, + stmtX: selectStmt, + }) + numOfTable = len(tables) + if numOfTable <= 1 { + return + } + // 保存值 + rawVars = make([]any, 0, len(db.Statement.Vars)) + for _, v := range db.Statement.Vars { + rawVars = append(rawVars, v) + } + // 替换语句 + if selectStmt.OrderBy != nil { + orderByExpr = selectStmt.OrderBy + selectStmt.OrderBy = nil + } + if selectStmt.GroupBy != nil { + groupByExpr = selectStmt.GroupBy + //selectStmt.GroupBy = nil + } + if selectStmt.Having != nil { + havingExpr = selectStmt.Having + //selectStmt.Having = nil + } + if selectStmt.Limit != nil { + limitExpr = selectStmt.Limit + selectStmt.Limit = nil + } + // 检查是否为COUNT语句 + //如果不是count的语句,添加order和group的支持 + if v := db.Statement.Context.Value("@sql_count_statement"); v != nil { + //这里处理的是报表的情况,再报表里面需要重写COUNT语句才能获取到正确的数量 + if v == true { + isCountStatement = true + isPureCountStatement = true + } + } + //常规的COUNT逻辑 + if !isCountStatement && len(selectStmt.SelectExprs) == 1 { + for _, expr := range selectStmt.SelectExprs { + if aliasedExpr, ok = expr.(*sqlparserX.AliasedExpr); ok { + if funcExpr, ok = aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok { + if funcExpr.Name.EqualString("count") { + isCountStatement = true + break + } + } + } + } + } + // 重写SQL + db.Statement.SQL.Reset() + trackedBuffer = sqlparserX.NewTrackedBuffer(nil) + + if isCountStatement { + if isPureCountStatement { + db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (") + } else { + db.Statement.SQL.WriteString("SELECT SUM(") + db.Statement.SQL.WriteByte(plugin.QuoteChar) + db.Statement.SQL.WriteString(strings.Trim("count(*)", "`")) + db.Statement.SQL.WriteByte(plugin.QuoteChar) + db.Statement.SQL.WriteString(") FROM (") + } + } else { + if bs, ok := modelValue.(SelectBuilder); ok { + columns := bs.BuildSelect(db.Statement.Context, selectStmt.SelectExprs) + if len(columns) > 0 { + db.Statement.SQL.WriteString("SELECT " + strings.Join(columns, ",") + " FROM (") + } else { + db.Statement.SQL.WriteString("SELECT * FROM (") + } + } else { + db.Statement.SQL.WriteString("SELECT * FROM (") + } + } + for i, name := range tables { + trackedBuffer.Reset() + db.Statement.SQL.WriteByte('(') + //赋值新的表名称 + selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{ + Expr: sqlparserX.TableName{ + Name: sqlparserX.NewTableIdent(name), + }, + }} + selectStmt.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + db.Statement.SQL.WriteByte(')') + if i < numOfTable-1 { + if plugin.UnionAll { + db.Statement.SQL.WriteString(" UNION ALL ") + } else { + db.Statement.SQL.WriteString(" UNION ") + } + } + if i > 0 { + //copy vars + db.Statement.Vars = append(db.Statement.Vars, rawVars...) + } + } + db.Statement.SQL.WriteString(") tbl ") + if !isCountStatement { + //node.GroupBy, node.Having, node.OrderBy, node.Limit + if groupByExpr != nil { + trackedBuffer.Reset() + groupByExpr.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + } + if havingExpr != nil { + trackedBuffer.Reset() + havingExpr.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + } + if orderByExpr != nil { + trackedBuffer.Reset() + orderByExpr.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + } + if limitExpr != nil { + trackedBuffer.Reset() + limitExpr.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + } + } +} + +func New() *Sharding { + return &Sharding{ + UnionAll: true, + QuoteChar: '`', + } +} diff --git a/plugins/sharding/sharding_test.go b/plugins/sharding/sharding_test.go new file mode 100644 index 0000000..a26bbc6 --- /dev/null +++ b/plugins/sharding/sharding_test.go @@ -0,0 +1,61 @@ +package sharding + +import ( + "fmt" + "github.com/uole/sqlparser" + "testing" +) + +func TestFormat(t *testing.T) { + t.Log(fmt.Sprintf("%.2f%%", 0.5851888*100)) +} + +func TestSharding_Query(t *testing.T) { + //sql := "SELECT COUNT(*) FROM aaa" + sql := "SELECT uid,SUM(IF(direction='inbound',1,0)) AS inbound_times,SUM(IF(direction='inbound',IF(answer_duration>0,1,0),0)) AS inbound_answer_times,SUM(IF(direction='outbound',1,0)) AS outbound_times,SUM(IF(direction='outbound',IF(answer_duration>0,1,0),0)) AS outbound_answer_times FROM `cdr_logs` WHERE ((`domain` = 'test.cc.echo.me' OR `domain` = 'default') AND `name` <> '') AND (`create_stamp` BETWEEN 1712505600 AND 1712505608) AND `name` IN ('a','b','c') GROUP BY `uid` having uid != '' ORDER BY create_stamp DESC LIMIT 15" + stmt, err := sqlparser.Parse(sql) + if err != nil { + t.Error(err) + return + } + selectStmt, ok := stmt.(*sqlparser.Select) + if !ok { + t.Error("not select stmt") + return + } + buf := sqlparser.NewTrackedBuffer(nil) + sqlparser.NewTableIdent("test").Format(buf) + buf.Reset() + selectStmt.Format(buf) + t.Log("SQL") + t.Log(buf.String()) + buf.Reset() + t.Log("SELECT") + selectStmt.SelectExprs.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("FROM") + selectStmt.From.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("WHERE") + selectStmt.Where.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("ORDER BY") + selectStmt.OrderBy.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("GROUP BY") + selectStmt.GroupBy.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("LIMIT") + selectStmt.Limit.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("Having") + selectStmt.Having.Format(buf) + t.Log(buf.String()) + buf.Reset() +} diff --git a/plugins/sharding/types.go b/plugins/sharding/types.go new file mode 100644 index 0000000..1374c29 --- /dev/null +++ b/plugins/sharding/types.go @@ -0,0 +1,45 @@ +package sharding + +import ( + "context" + sqlparserX "github.com/uole/sqlparser" +) + +const ( + TypeDatetime = "datetime" + TypeHash = "hash" +) + +const ( + DateTypeYear = iota + 5 + DateTypeMonth + DateTypeWeek + DateTypeDay +) + +const ( + sceneCreate = "create" + sceneUpdate = "update" + sceneDelete = "delete" +) + +const ( + stmtCountKeyword = "count" +) + +type ( + Rule struct { + Type string `json:"type"` + Args int `json:"args"` + } + + Model interface { + ShardingRule() Rule + ShardingTable(scene string) string + ShardingTables(scope *Scope) []string + } + + SelectBuilder interface { + BuildSelect(ctx context.Context, expr sqlparserX.SelectExprs) []string + } +) diff --git a/plugins/validate/types.go b/plugins/validate/types.go new file mode 100644 index 0000000..c357a09 --- /dev/null +++ b/plugins/validate/types.go @@ -0,0 +1,107 @@ +package validate + +import ( + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "reflect" + "regexp" + "strconv" +) + +const ( + SkipValidations = "validations:skip_validations" +) + +var ( + scopeCtxKey = &validateScope{} + telephoneRegex = regexp.MustCompile("^\\d{5,20}$") +) + +type ( + validateScope struct { + DB *gorm.DB + Column string + Domain string + MultiDomain bool + Model interface{} + } + + StructError struct { + Tag string `json:"tag"` + Column string `json:"column"` + Message string `json:"message"` + } + + validateRule struct { + Rule string + Value string + Valid bool + } +) + +func (err *StructError) Error() string { + return err.Message +} + +func newRule(ss ...string) *validateRule { + v := &validateRule{ + Valid: true, + } + if len(ss) == 1 { + v.Rule = ss[0] + } else if len(ss) >= 2 { + v.Rule = ss[0] + v.Value = ss[1] + } + return v +} + +// formatError 格式化错误消息 +func formatError(rule types.Rule, scm *types.Schema, tag string) string { + var s string + switch tag { + case "db_unique": + s = scm.Label + "值已经存在." + break + case "required": + s = scm.Label + "值不能为空." + case "max": + if scm.Type == "string" { + s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max) + } else { + s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max) + } + case "min": + if scm.Type == "string" { + s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max) + } else { + s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max) + } + } + return s +} + +// isEmpty 判断值是否为空 +func isEmpty(val any) bool { + if val == nil { + return true + } + v := reflect.ValueOf(val) + switch v.Kind() { + case reflect.String, reflect.Array: + return v.Len() == 0 + case reflect.Map, reflect.Slice: + return v.IsNil() || v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface()) +} diff --git a/plugins/validate/validation.go b/plugins/validate/validation.go new file mode 100644 index 0000000..59d83db --- /dev/null +++ b/plugins/validate/validation.go @@ -0,0 +1,275 @@ +package validate + +import ( + "context" + "fmt" + "git.nobla.cn/golang/rest" + "git.nobla.cn/golang/rest/types" + validator "github.com/go-playground/validator/v10" + "gorm.io/gorm" + "gorm.io/gorm/schema" + "reflect" + "strconv" + "strings" +) + +type Validate struct { + validator *validator.Validate +} + +func (validate *Validate) Name() string { + return "gorm:validate" +} + +func (validate *Validate) telephoneValidate(ctx context.Context, fl validator.FieldLevel) bool { + val := fmt.Sprint(fl.Field().Interface()) + return telephoneRegex.MatchString(val) +} + +func (validate *Validate) uniqueValidate(ctx context.Context, fl validator.FieldLevel) bool { + var ( + scope *validateScope + ok bool + count int64 + field *schema.Field + primaryKeyValue reflect.Value + ) + val := fl.Field().Interface() + if scope, ok = ctx.Value(scopeCtxKey).(*validateScope); !ok { + return true + } + if len(scope.DB.Statement.Schema.PrimaryFields) > 0 { + field = scope.DB.Statement.Schema.PrimaryFields[0] + primaryKeyValue = reflect.Indirect(reflect.ValueOf(scope.Model)) + for _, n := range field.BindNames { + primaryKeyValue = primaryKeyValue.FieldByName(n) + } + } + sess := scope.DB.Session(&gorm.Session{NewDB: true}) + if primaryKeyValue.IsValid() && !primaryKeyValue.IsZero() && field != nil { + //多域校验 + if scope.MultiDomain && scope.Domain != "" { + sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ? AND domain = ?", val, primaryKeyValue.Interface(), scope.Domain).Count(&count) + } else { + sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ?", val, primaryKeyValue.Interface()).Count(&count) + } + } else { + if scope.MultiDomain && scope.Domain != "" { + sess.Model(scope.Model).Where(scope.Column+"=? AND domain = ?", val, scope.Domain).Count(&count) + } else { + sess.Model(scope.Model).Where(scope.Column+"=?", val).Count(&count) + } + } + if count > 0 { + return false + } + return true +} + +func (validate *Validate) grantRules(scm *types.Schema, scenario string, rule types.Rule) []*validateRule { + rules := make([]*validateRule, 0, 5) + if rule.Min != 0 { + rules = append(rules, newRule("min", strconv.Itoa(rule.Min))) + } + if rule.Max != 0 { + rules = append(rules, newRule("max", strconv.Itoa(rule.Max))) + } + //主键不做唯一判断 + if rule.Unique && !scm.Attribute.PrimaryKey { + rules = append(rules, newRule("db_unique")) + } + if rule.Type != "" { + rules = append(rules, newRule(rule.Type)) + } + if rule.Required != nil && len(rule.Required) > 0 { + for _, v := range rule.Required { + if v == scenario { + rules = append(rules, newRule("required")) + break + } + } + } + return rules +} + +func (validate *Validate) buildRules(rs []*validateRule) string { + var sb strings.Builder + for _, r := range rs { + if !r.Valid { + continue + } + if sb.Len() > 0 { + sb.WriteString(",") + } + if r.Value == "" { + sb.WriteString(r.Rule) + } else { + sb.WriteString(r.Rule + "=" + r.Value) + } + } + return sb.String() +} + +func (validate *Validate) findRule(name string, rules []*validateRule) *validateRule { + for _, r := range rules { + if r.Rule == name { + return r + } + } + return nil +} + +func (validate *Validate) Initialize(db *gorm.DB) (err error) { + validate.validator = validator.New() + if err = db.Callback().Create().Before("gorm:before_create").Register("model_validate", validate.Validate); err != nil { + return + } + if err = db.Callback().Create().Before("gorm:before_update").Register("model_validate", validate.Validate); err != nil { + return + } + if err = validate.validator.RegisterValidationCtx("telephone", validate.telephoneValidate); err != nil { + return + } + if err = validate.validator.RegisterValidationCtx("db_unique", validate.uniqueValidate); err != nil { + return + } + return +} + +func (validate *Validate) inArray(v any, vs []any) bool { + sv := fmt.Sprint(v) + for _, s := range vs { + if fmt.Sprint(s) == sv { + return true + } + } + return false +} + +// isVisible 判断字段是否需要显示 +func (validate *Validate) isVisible(stmt *gorm.Statement, scm *types.Schema) bool { + if len(scm.Attribute.Visible) <= 0 { + return true + } + for _, row := range scm.Attribute.Visible { + if len(row.Values) == 0 { + continue + } + targetField := stmt.Schema.LookUpField(row.Column) + if targetField == nil { + continue + } + targetValue := stmt.ReflectValue.FieldByName(targetField.Name) + if !targetValue.IsValid() { + return false + } + if !validate.inArray(targetValue.Interface(), row.Values) { + return false + } + } + return true +} + +// Validate 校验字段 +func (validate *Validate) Validate(db *gorm.DB) { + var ( + ok bool + err error + rules []*validateRule + stmt *gorm.Statement + skipValidate bool + multiDomain bool + value reflect.Value + runtimeScope *types.RuntimeScope + ) + if result, ok := db.Get(SkipValidations); ok && result.(bool) { + return + } + stmt = db.Statement + if stmt.Model == nil { + return + } + if db.Statement.Context == nil { + return + } + if val := db.Statement.Context.Value(rest.RuntimeScopeKey); val != nil { + if runtimeScope, ok = val.(*types.RuntimeScope); !ok { + return + } + } else { + return + } + if runtimeScope.Schemas == nil { + return + } + if stmt.Schema.LookUpField("domain") != nil { + multiDomain = true + } + for _, row := range runtimeScope.Schemas { + //如果字段隐藏,那么就不进行校验 + if !validate.isVisible(stmt, row) { + continue + } + if rules = validate.grantRules(row, runtimeScope.Scenario, row.Rule); len(rules) <= 0 { + continue + } + field := stmt.Schema.LookUpField(row.Column) + if field == nil { + continue + } + value = stmt.ReflectValue.FieldByName(field.Name) + if !value.IsValid() { + continue + } + skipValidate = false + if r := validate.findRule("required", rules); r != nil { + if value.Interface() != nil { + vType := reflect.ValueOf(value.Interface()) + switch vType.Kind() { + case reflect.Bool: + skipValidate = true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + skipValidate = true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + skipValidate = true + case reflect.Float32, reflect.Float64: + skipValidate = true + default: + skipValidate = false + } + } + if skipValidate { + r.Valid = false + } + } else { + if isEmpty(value.Interface()) { + continue + } + } + ctx := context.WithValue(db.Statement.Context, scopeCtxKey, &validateScope{ + DB: db, + Column: row.Column, + Model: stmt.Model, + Domain: runtimeScope.Domain, + MultiDomain: multiDomain, + }) + if err = validate.validator.VarCtx(ctx, value.Interface(), validate.buildRules(rules)); err != nil { + if errs, ok := err.(validator.ValidationErrors); ok { + for _, e := range errs { + _ = db.AddError(&StructError{ + Tag: e.Tag(), + Column: row.Column, + Message: formatError(row.Rule, row, e.Tag()), + }) + } + } else { + _ = db.AddError(err) + } + break + } + } +} + +func New() *Validate { + return &Validate{} +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..b82267d --- /dev/null +++ b/query.go @@ -0,0 +1,405 @@ +package rest + +import ( + "context" + "fmt" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "reflect" + "strconv" + "strings" +) + +type ( + Query struct { + db *gorm.DB + condition string + fields []string + params []interface{} + table string + joins []join + orderBy []string + groupBy []string + modelValue any + limit int + offset int + } + + condition struct { + Field string `json:"field"` + Value interface{} `json:"value"` + Expr string `json:"expr"` + } + + join struct { + Table string + Direction string + Conditions []*condition + } +) + +func (cond *condition) WithExpr(v string) *condition { + cond.Expr = v + return cond +} + +func (query *Query) Model() any { + return query.modelValue +} + +func (query *Query) compile() (*gorm.DB, error) { + db := query.db + if query.condition != "" { + db = db.Where(query.condition, query.params...) + } + if query.fields != nil { + db = db.Select(strings.Join(query.fields, ",")) + } + if query.table != "" { + db = db.Table(query.table) + } + if query.joins != nil && len(query.joins) > 0 { + for _, joinEntity := range query.joins { + cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...) + db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...) + } + } + if query.orderBy != nil && len(query.orderBy) > 0 { + db = db.Order(strings.Join(query.orderBy, ",")) + } + if query.groupBy != nil && len(query.groupBy) > 0 { + db = db.Group(strings.Join(query.groupBy, ",")) + } + if query.offset > 0 { + db = db.Offset(query.offset) + } + if query.limit > 0 { + db = db.Limit(query.limit) + } + return db, nil +} + +func (query *Query) decodeValue(v any) string { + refVal := reflect.Indirect(reflect.ValueOf(v)) + switch refVal.Kind() { + case reflect.Bool: + if refVal.Bool() { + return "1" + } else { + return "0" + } + case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64: + return strconv.FormatInt(refVal.Int(), 10) + case reflect.Float32, reflect.Float64: + return strconv.FormatFloat(refVal.Float(), 'f', -1, 64) + case reflect.String: + return "'" + refVal.String() + "'" + default: + return fmt.Sprint(v) + } +} + +func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) { + var ( + sb strings.Builder + ) + params = make([]interface{}, 0) + for _, cond := range conditions { + if filter { + if isEmpty(cond.Value) { + continue + } + } + if cond.Expr == "" { + cond.Expr = "=" + } + switch strings.ToUpper(cond.Expr) { + case "=", "<>", ">", "<", ">=", "<=", "!=": + if sb.Len() > 0 { + sb.WriteString(" " + operator + " ") + } + if cond.Expr == "=" && cond.Value == nil { + sb.WriteString("`" + cond.Field + "` IS NULL") + } else { + sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?") + params = append(params, cond.Value) + } + case "LIKE": + if sb.Len() > 0 { + sb.WriteString(" " + operator + " ") + } + cond.Value = fmt.Sprintf("%%%s%%", cond.Value) + sb.WriteString("`" + cond.Field + "` LIKE ?") + params = append(params, cond.Value) + case "IN": + if sb.Len() > 0 { + sb.WriteString(" " + operator + " ") + } + refVal := reflect.Indirect(reflect.ValueOf(cond.Value)) + switch refVal.Kind() { + case reflect.Slice, reflect.Array: + ss := make([]string, refVal.Len()) + for i := 0; i < refVal.Len(); i++ { + ss[i] = query.decodeValue(refVal.Index(i)) + } + sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")") + case reflect.String: + sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")") + default: + } + case "BETWEEN": + refVal := reflect.ValueOf(cond.Value) + if refVal.Kind() == reflect.Slice && refVal.Len() == 2 { + sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?") + params = append(params, refVal.Index(0), refVal.Index(1)) + } + } + } + str = sb.String() + return +} + +func (query *Query) Select(fields ...string) *Query { + if query.fields == nil { + query.fields = fields + } else { + query.fields = append(query.fields, fields...) + } + return query +} + +func (query *Query) From(table string) *Query { + query.table = table + return query +} + +func (query *Query) LeftJoin(table string, conditions ...*condition) *Query { + query.joins = append(query.joins, join{ + Table: table, + Direction: "LEFT", + Conditions: conditions, + }) + return query +} + +func (query *Query) RightJoin(table string, conditions ...*condition) *Query { + query.joins = append(query.joins, join{ + Table: table, + Direction: "RIGHT", + Conditions: conditions, + }) + return query +} + +func (query *Query) InnerJoin(table string, conditions ...*condition) *Query { + query.joins = append(query.joins, join{ + Table: table, + Direction: "INNER", + Conditions: conditions, + }) + return query +} + +func (query *Query) AddCondition(expr, column string, val any) { + if expr == "" { + expr = "=" + } + query.AndWhere(newConditionWithOperator(expr, column, val)) +} + +func (query *Query) AndWhere(conditions ...*condition) *Query { + length := len(conditions) + if length == 0 { + return query + } + cs, ps := query.buildConditions("AND", false, conditions...) + if cs == "" { + return query + } + query.params = append(query.params, ps...) + if query.condition == "" { + query.condition = cs + } else { + query.condition += " AND (" + cs + ")" + } + return query +} + +func (query *Query) AndFilterWhere(conditions ...*condition) *Query { + length := len(conditions) + if length == 0 { + return query + } + cs, ps := query.buildConditions("AND", true, conditions...) + if cs == "" { + return query + } + query.params = append(query.params, ps...) + if query.condition == "" { + query.condition = cs + } else { + query.condition += " AND " + cs + } + return query +} + +func (query *Query) OrWhere(conditions ...*condition) *Query { + length := len(conditions) + if length == 0 { + return query + } + cs, ps := query.buildConditions("OR", false, conditions...) + if cs == "" { + return query + } + query.params = append(query.params, ps...) + if query.condition == "" { + query.condition = cs + } else { + query.condition += " AND (" + cs + ")" + } + return query +} + +func (query *Query) OrFilterWhere(conditions ...*condition) *Query { + length := len(conditions) + if length == 0 { + return query + } + cs, ps := query.buildConditions("OR", true, conditions...) + if cs == "" { + return query + } + query.params = append(query.params, ps...) + if query.condition == "" { + query.condition = cs + } else { + query.condition += " AND (" + cs + ")" + } + return query +} + +func (query *Query) GroupBy(cols ...string) *Query { + query.groupBy = append(query.groupBy, cols...) + return query +} + +func (query *Query) OrderBy(col, direction string) *Query { + direction = strings.ToUpper(direction) + if direction == "" || !(direction == "ASC" || direction == "DESC") { + direction = "ASC" + } + query.orderBy = append(query.orderBy, col+" "+direction) + return query +} + +func (query *Query) Offset(i int) *Query { + query.offset = i + return query +} + +func (query *Query) Limit(i int) *Query { + query.limit = i + return query +} + +func (query *Query) ResetSelect() *Query { + query.fields = make([]string, 0) + return query +} + +func (query *Query) Count(v interface{}) (i int64) { + var ( + db *gorm.DB + err error + ) + if db, err = query.compile(); err != nil { + return + } else { + if v != nil { + refVal := reflect.ValueOf(v) + switch refVal.Kind() { + case reflect.String: + if query.table == "" { + err = db.Table(refVal.String()).Count(&i).Error + } else { + err = db.Table(query.table).Count(&i).Error + } + default: + //如果是报表的模型,这的话手动构建一条SQL语句 + if reporter, ok := v.(types.Reporter); ok { + sqlRes := &sqlCountResponse{} + childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true) + db.WithContext(childCtx).Model(reporter).First(sqlRes) + i = sqlRes.Count + } else { + err = db.Model(v).Count(&i).Error + } + } + } + } + return +} + +func (query *Query) One(v interface{}) (err error) { + var ( + db *gorm.DB + ) + if db, err = query.compile(); err != nil { + return + } else { + err = db.First(v).Error + } + return +} + +func (query *Query) All(v interface{}) (err error) { + var ( + db *gorm.DB + ) + if db, err = query.compile(); err != nil { + return + } else { + err = db.Find(v).Error + } + return +} + +func NewCondition(column, opera string, value any) *condition { + if opera == "" { + opera = "=" + } + return &condition{ + Field: column, + Value: value, + Expr: opera, + } +} + +func newCondition(field string, value interface{}) *condition { + return &condition{ + Field: field, + Value: value, + Expr: "=", + } +} + +func newConditionWithOperator(operator, field string, value interface{}) *condition { + cond := &condition{ + Field: field, + Value: value, + Expr: operator, + } + return cond +} + +func NewQuery(db *gorm.DB, model any) *Query { + return &Query{ + db: db, + modelValue: model, + params: make([]interface{}, 0), + orderBy: make([]string, 0), + groupBy: make([]string, 0), + joins: make([]join, 0), + } +} diff --git a/rest.go b/rest.go new file mode 100644 index 0000000..7ef9e92 --- /dev/null +++ b/rest.go @@ -0,0 +1,753 @@ +package rest + +import ( + "context" + "errors" + "fmt" + "git.nobla.cn/golang/kos/util/arrays" + "git.nobla.cn/golang/kos/util/reflection" + "git.nobla.cn/golang/rest/inflector" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "net/http" + "reflect" + "strconv" + "strings" + "time" +) + +var ( + modelEntities []*Model + httpRouter types.HttpRouter + hookMgr *hookManager + timeKind = reflect.TypeOf(time.Time{}).Kind() + timePtrKind = reflect.TypeOf(&time.Time{}).Kind() + + matchEnums = []string{types.MatchExactly, types.MatchFuzzy} +) + +var ( + allowScenario = []string{types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} +) + +func init() { + hookMgr = &hookManager{} + modelEntities = make([]*Model, 0) +} + +// cloneStmt 从指定的db克隆一个 Statement 对象 +func cloneStmt(db *gorm.DB) *gorm.Statement { + return &gorm.Statement{ + DB: db, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + } +} + +// dataTypeOf 推断数据的类型 +func dataTypeOf(field *schema.Field) string { + var dataType string + reflectType := field.FieldType + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + if dataType = field.Tag.Get("type"); dataType != "" { + return dataType + } + dataValue := reflect.Indirect(reflect.New(reflectType)) + switch dataValue.Kind() { + case reflect.Bool: + dataType = types.TypeBoolean + case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + dataType = types.TypeInteger + case reflect.Float32, reflect.Float64: + dataType = types.TypeFloat + default: + dataType = types.TypeString + } + return dataType +} + +// dataFormatOf 推断数据的格式 +func dataFormatOf(field *schema.Field) string { + var format string + format = field.Tag.Get("format") + if format != "" { + return format + } + //如果有枚举值,直接设置为下拉类型 + enum := field.Tag.Get("enum") + if enum != "" { + return types.FormatDropdown + } + reflectType := field.FieldType + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + //时间处理 + dataValue := reflect.Indirect(reflect.New(reflectType)) + if field.Name == "CreatedAt" || field.Name == "UpdatedAt" || field.Name == "DeletedAt" { + switch dataValue.Kind() { + case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return types.FormatTimestamp + default: + return types.FormatDatetime + } + } + if strings.Contains(strings.ToLower(field.Name), "pass") { + return types.FormatPassword + } + switch dataValue.Kind() { + case timeKind, timePtrKind: + format = types.FormatDatetime + case reflect.Bool: + format = types.FormatBoolean + case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + format = types.FormatInteger + case reflect.Float32, reflect.Float64: + format = types.FormatFloat + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + format = types.FormatDatetime + } + default: + if field.Size >= 1024 { + format = types.FormatText + } else { + format = types.FormatString + } + } + return format +} + +// fieldName 生成字段名称 +func fieldName(name string) string { + tokens := strings.Split(name, "_") + for i, s := range tokens { + tokens[i] = strings.Title(s) + } + return strings.Join(tokens, " ") +} + +// fieldNative 判断是否为原始字段 +func fieldNative(field *schema.Field) uint8 { + if _, ok := field.Tag.Lookup("virtual"); ok { + return 0 + } + return 1 +} + +// fieldRule 返回字段规则 +func fieldRule(field *schema.Field) types.Rule { + r := types.Rule{ + Required: []string{}, + } + if field.GORMDataType == schema.String { + r.Max = field.Size + } + if field.GORMDataType == schema.Int || field.GORMDataType == schema.Float || field.GORMDataType == schema.Uint { + r.Max = field.Scale + } + + rs := field.Tag.Get("rule") + if rs != "" { + ss := strings.Split(rs, ";") + for _, s := range ss { + vs := strings.SplitN(s, ":", 2) + ls := len(vs) + if ls == 0 { + continue + } + switch vs[0] { + case "required", "require": + if ls > 1 { + bs := strings.Split(vs[1], ",") + for _, i := range bs { + if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) { + r.Required = append(r.Required, i) + } + } + } else { + r.Required = []string{types.ScenarioCreate, types.ScenarioUpdate} + } + case "unique": + r.Unique = true + case "regexp": + if ls > 1 { + r.Regular = vs[1] + } + } + } + } + if field.PrimaryKey { + r.Unique = true + } + return r +} + +// fieldScenario 字段Scenarios +func fieldScenario(index int, field *schema.Field) types.Scenarios { + var ss types.Scenarios + if v, ok := field.Tag.Lookup("scenarios"); ok { + v = strings.TrimSpace(v) + if v != "" { + ss = strings.Split(v, ";") + } + } else { + if field.PrimaryKey { + ss = []string{types.ScenarioList, types.ScenarioView, types.ScenarioExport} + } else if field.Name == "CreatedAt" || field.Name == "UpdatedAt" { + ss = []string{types.ScenarioList} + } else if field.Name == "DeletedAt" || field.Name == "Namespace" { + //不添加任何显示场景 + ss = []string{} + } else { + if index < 10 { + //高级字段只配置一些简单的场景 + ss = []string{types.ScenarioSearch, types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} + } else { + //高级字段只配置一些简单的场景 + ss = []string{types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} + } + } + } + return ss +} + +// fieldPosition 字段的排序位置 +func fieldPosition(field *schema.Field, i int) int { + s := field.Tag.Get("position") + n, _ := strconv.Atoi(s) + if n > 0 { + return n + } + return i + 100 +} + +// fieldAttribute 字段属性 +func fieldAttribute(field *schema.Field) types.Attribute { + attr := types.Attribute{ + Match: types.MatchFuzzy, + PrimaryKey: field.PrimaryKey, + DefaultValue: field.DefaultValue, + Readonly: []string{}, + Disable: []string{}, + Visible: make([]types.VisibleCondition, 0), + Values: make([]types.EnumValue, 0), + Live: types.LiveValue{}, + } + if field.Name == "CreatedAt" || field.Name == "UpdatedAt" { + attr.EndOfNow = true + } + //赋值属性 + props := field.Tag.Get("props") + if props != "" { + vs := strings.Split(props, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) != 2 { + continue + } + sv := strings.TrimSpace(kv[1]) + switch strings.ToLower(strings.TrimSpace(kv[0])) { + case "icon": + attr.Icon = sv + case "match": + if arrays.Exists(sv, matchEnums) { + attr.Match = sv + } + case "endofnow", "end_of_now": + if ok, _ := strconv.ParseBool(sv); ok { + attr.EndOfNow = true + } + case "invisible": + if ok, _ := strconv.ParseBool(sv); ok { + attr.Invisible = true + } + case "suffix": + attr.Suffix = sv + case "tag": + attr.Tag = sv + case "tooltip": + attr.Tooltip = sv + case "uploadurl", "uploaduri", "upload_url", "upload_uri": + attr.UploadUrl = sv + case "description": + attr.Description = sv + case "readonly": + bs := strings.Split(sv, ",") + for _, i := range bs { + if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) { + attr.Readonly = append(attr.Readonly, i) + } + } + } + } + } + //live的赋值 + live := field.Tag.Get("live") + if live != "" { + attr.Live.Enable = true + vs := strings.Split(live, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) != 2 { + continue + } + switch kv[0] { + case "method": + attr.Live.Method = kv[1] + case "type": + if kv[1] == types.LiveTypeDropdown || kv[1] == types.LiveTypeCascader { + attr.Live.Type = kv[1] + } else { + attr.Live.Type = types.LiveTypeDropdown + } + case "url", "uri": + attr.Live.Url = kv[1] + case "columns": + attr.Live.Columns = strings.Split(kv[1], ",") + } + } + } + + dropdown := field.Tag.Get("dropdown") + if dropdown != "" { + attr.DropdownOption = &types.DropdownOption{} + vs := strings.Split(dropdown, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) == 0 { + continue + } + switch kv[0] { + case "created": + attr.DropdownOption.Created = true + case "filterable": + attr.DropdownOption.Filterable = true + case "autocomplete": + attr.DropdownOption.Autocomplete = true + case "default_first": + attr.DropdownOption.DefaultFirst = true + } + } + } + + //显示条件 + conditions := field.Tag.Get("condition") + if conditions != "" { + vs := strings.Split(conditions, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) != 2 { + continue + } + cond := types.VisibleCondition{ + Column: kv[0], + Values: make([]any, 0), + } + vv := strings.Split(kv[1], ",") + for _, x := range vv { + x = strings.TrimSpace(x) + if x == "" { + continue + } + cond.Values = append(cond.Values, x) + } + attr.Visible = append(attr.Visible, cond) + } + } + //赋值枚举值 + enumns := field.Tag.Get("enum") + if enumns != "" { + vs := strings.Split(enumns, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) != 2 { + continue + } + fv := types.EnumValue{Value: kv[0]} + //颜色分隔符 + if pos := strings.IndexByte(kv[1], '#'); pos > -1 { + fv.Label = kv[1][:pos] + fv.Color = kv[1][pos:] + } else { + fv.Label = kv[1] + } + attr.Values = append(attr.Values, fv) + } + } + if !field.Creatable { + attr.Disable = append(attr.Disable, types.ScenarioCreate) + } + if !field.Updatable { + attr.Disable = append(attr.Disable, types.ScenarioUpdate) + } + attr.Tooltip = field.Comment + return attr +} + +// autoMigrate 自动合并字段 +func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (naming string, err error) { + var ( + pos int + columnName string + columnIsExists bool + columnLabel string + schemas []*types.Schema + models []*types.Schema + stmt *gorm.Statement + ) + + stmt = cloneStmt(db) + if err = stmt.Parse(model); err != nil { + return + } + if schemas, err = GetSchemas(ctx, db, defaultDomain, module, stmt.Table); err != nil { + return + } + if len(schemas) > 0 { + pos = len(schemas) + } + models = make([]*types.Schema, 0) + for index, field := range stmt.Schema.Fields { + columnName = field.DBName + if columnName == "-" { + continue + } + if columnName == "" { + columnName = field.Name + } + columnIsExists = false + for _, sm := range schemas { + if sm.Column == columnName { + columnIsExists = true + break + } + } + if columnIsExists { + continue + } + columnLabel = field.Tag.Get("comment") + if columnLabel == "" { + columnLabel = fieldName(field.DBName) + } + isPrimaryKey := uint8(0) + if field.PrimaryKey { + isPrimaryKey = 1 + } + schemaModel := &types.Schema{ + Domain: defaultDomain, + ModuleName: module, + TableName: stmt.Table, + Enable: 1, + Column: columnName, + Label: columnLabel, + Type: strings.ToLower(dataTypeOf(field)), + Format: strings.ToLower(dataFormatOf(field)), + Native: fieldNative(field), + IsPrimaryKey: isPrimaryKey, + Rule: fieldRule(field), + Scenarios: fieldScenario(index, field), + Attribute: fieldAttribute(field), + Position: fieldPosition(field, pos), + } + //如果启用了在线调取接口功能,那么设置一下字段的format格式 + if schemaModel.Attribute.Live.Enable { + if schemaModel.Attribute.Live.Type != "" { + schemaModel.Format = schemaModel.Attribute.Live.Type + } + } + models = append(models, schemaModel) + pos++ + } + if len(models) > 0 { + err = db.Create(models).Error + } + naming = stmt.Table + return +} + +// SetHttpRouter 设置HTTP路由 +func SetHttpRouter(router types.HttpRouter) { + httpRouter = router +} + +// AutoMigrate 自动合并表的schema +func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) { + var ( + opts *Options + table string + router types.HttpRouter + ) + opts = &Options{} + for _, cb := range cbs { + cb(opts) + } + if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil { + return + } + //路由模块处理 + modelValue := newModel(model, db, types.Naming{ + Pluralize: inflector.Pluralize(table), + Singular: inflector.Singularize(table), + ModuleName: opts.moduleName, + TableName: table, + }) + modelValue.hookMgr = hookMgr + modelValue.schemaLookup = VisibleSchemas + if opts.router != nil { + router = opts.router + } + if router == nil && httpRouter != nil { + router = httpRouter + } + if opts.urlPrefix != "" { + modelValue.urlPrefix = opts.urlPrefix + } + //路由绑定操作 + if router != nil { + if modelValue.hasScenario(types.ScenarioList) { + router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioList), modelValue.Search) + } + if modelValue.hasScenario(types.ScenarioCreate) { + router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioCreate), modelValue.Create) + } + if modelValue.hasScenario(types.ScenarioUpdate) { + router.Handle(http.MethodPut, modelValue.Uri(types.ScenarioUpdate), modelValue.Update) + } + if modelValue.hasScenario(types.ScenarioDelete) { + router.Handle(http.MethodDelete, modelValue.Uri(types.ScenarioDelete), modelValue.Delete) + } + if modelValue.hasScenario(types.ScenarioView) { + router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioView), modelValue.View) + } + if modelValue.hasScenario(types.ScenarioExport) { + router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioExport), modelValue.Export) + } + if modelValue.hasScenario(types.ScenarioImport) { + router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioImport), modelValue.Import) + router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioImport), modelValue.Import) + } + } + if opts.writer != nil { + modelValue.response = opts.writer + } + if opts.formatter != nil { + modelValue.formatter = opts.formatter + } + modelValue.disableDomain = opts.disableDomain + modelEntities = append(modelEntities, modelValue) + return +} + +// CloneSchemas 克隆schemas +func CloneSchemas(ctx context.Context, db *gorm.DB, domain string) (err error) { + var ( + values []*types.Schema + schemas []*types.Schema + models []*types.Schema + ) + tx := db.WithContext(ctx) + if err = tx.Where("domain=?", defaultDomain).Find(&values).Error; err != nil { + return fmt.Errorf("schema not found") + } + tx.Where("domain=?", domain).Find(&schemas) + hasSchemaFunc := func(values []*types.Schema, hack *types.Schema) bool { + for _, row := range values { + if row.ModuleName == hack.ModuleName && row.TableName == hack.TableName && row.Column == hack.Column { + return true + } + } + return false + } + models = make([]*types.Schema, 0) + for _, row := range values { + if !hasSchemaFunc(schemas, row) { + row.Id = 0 + row.CreatedAt = 0 + row.UpdatedAt = 0 + row.Domain = domain + models = append(models, row) + } + } + if len(models) > 0 { + err = tx.Save(models).Error + } + return +} + +// GetSchemas 获取schemas +func GetSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName string) ([]*types.Schema, error) { + var ( + err error + values []*types.Schema + tx *gorm.DB + ) + values = make([]*types.Schema, 0) + if domain == "" { + domain = defaultDomain + } + if moduleName == "" || tableName == "" { + return nil, gorm.ErrInvalidField + } + if ctx != nil { + tx = db.WithContext(ctx) + } else { + tx = db + } + err = tx.Where("domain=? AND module_name=? AND table_name=?", domain, moduleName, tableName).Order("position ASC").Find(&values).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + err = nil + } + return values, err +} + +// VisibleSchemas 获取某个场景下面的schema +func VisibleSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName, scenario string) ([]*types.Schema, error) { + schemas, err := GetSchemas(ctx, db, domain, moduleName, tableName) + if err != nil { + return nil, err + } + result := make([]*types.Schema, 0, len(schemas)) + for _, row := range schemas { + if row.IsPrimaryKey == 1 { + result = append(result, row) + continue + } + if row.Scenarios.Has(scenario) { + result = append(result, row) + } + } + return result, nil +} + +// ModelTypes 查询指定模型的类型 +func ModelTypes(ctx context.Context, db *gorm.DB, model any, domainName, labelColumn, valueColumn string) (values []*types.TypeValue) { + tx := db.WithContext(ctx) + result := make([]map[string]any, 0, 10) + tx.Model(model).Select(labelColumn, valueColumn).Where("domain=?", domainName).Scan(&result) + values = make([]*types.TypeValue, 0, len(result)) + for _, pairs := range result { + feed := &types.TypeValue{} + for k, v := range pairs { + if k == labelColumn { + feed.Label = v + } + if k == valueColumn { + feed.Value = v + } + } + values = append(values, feed) + } + return values +} + +// GetFieldValue 获取模型某个字段的值 +func GetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string) interface{} { + var ( + idx = -1 + ) + refVal := reflect.Indirect(refValue) + for i, field := range stmt.Schema.Fields { + if field.DBName == column || field.Name == column { + idx = i + break + } + } + if idx > -1 { + return refVal.Field(idx).Interface() + } + return nil +} + +// SetFieldValue 设置模型某个字段的值 +func SetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) { + var ( + idx = -1 + ) + refVal := reflect.Indirect(refValue) + for i, field := range stmt.Schema.Fields { + if field.DBName == column || field.Name == column { + idx = i + break + } + } + if idx > -1 { + refVal.Field(idx).Set(reflect.ValueOf(value)) + } +} + +func SafeSetFileValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) { + var ( + idx = -1 + ) + refVal := reflect.Indirect(refValue) + for i, field := range stmt.Schema.Fields { + if field.DBName == column || field.Name == column { + idx = i + break + } + } + if idx > -1 { + _ = reflection.Assign(refVal.Field(idx), value) + } +} + +// GetModels 获取所有注册了的模块 +func GetModels() []*Model { + return modelEntities +} + +// OnBeforeCreate 创建前的回调 +func OnBeforeCreate(cb BeforeCreate) { + hookMgr.BeforeCreate(cb) +} + +// OnAfterCreate 创建后的回调 +func OnAfterCreate(cb AfterCreate) { + hookMgr.AfterCreate(cb) +} + +// OnBeforeUpdate 更新前的回调 +func OnBeforeUpdate(cb BeforeUpdate) { + hookMgr.BeforeUpdate(cb) +} + +// OnAfterUpdate 更新后的回调 +func OnAfterUpdate(cb AfterUpdate) { + hookMgr.AfterUpdate(cb) +} + +// OnBeforeSave 保存前的回调 +func OnBeforeSave(cb BeforeSave) { + hookMgr.BeforeSave(cb) +} + +// OnAfterSave 保存后的回调 +func OnAfterSave(cb AfterSave) { + hookMgr.AfterSave(cb) +} + +// OnBeforeDelete 删除前的回调 +func OnBeforeDelete(cb BeforeDelete) { + hookMgr.BeforeDelete(cb) +} + +// OnAfterDelete 删除后的回调 +func OnAfterDelete(cb AfterDelete) { + hookMgr.AfterDelete(cb) +} + +// OnAfterExport 导出后的回调 +func OnAfterExport(cb AfterExport) { + hookMgr.AfterExport(cb) +} + +// OnAfterImport 导入后的回调 +func OnAfterImport(cb AfterImport) { + hookMgr.AfterImport(cb) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..903f188 --- /dev/null +++ b/types.go @@ -0,0 +1,82 @@ +package rest + +import ( + "context" + "encoding/json" + "gorm.io/gorm" + "net/http" +) + +const ( + defaultPageSize = 15 + defaultDomain = "localhost" +) + +type ( + httpWriter struct { + } + + multiValue struct { + Text any `json:"label"` + Value any `json:"value"` + } + + stdResponse struct { + Code int `json:"code"` + Reason string `json:"reason,omitempty"` + Data any `json:"data,omitempty"` + } + + tableNamer interface { + HttpTableName(req *http.Request) string + } + + //创建后的回调,这个回调不在事物内 + afterCreated interface { + AfterCreated(ctx context.Context, tx *gorm.DB) + } + //更新后的回调,这个回调不在事物内 + afterUpdated interface { + AfterUpdated(ctx context.Context, tx *gorm.DB) + } + //保存后的回调,这个回调不在事物内 + afterSaved interface { + AfterSaved(ctx context.Context, tx *gorm.DB) + } + + sqlCountResponse struct { + Count int64 `json:"count"` + } +) + +func (h *httpWriter) Success(w http.ResponseWriter, data any) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(stdResponse{ + Code: 0, + Data: data, + }); err != nil { + w.Write([]byte(err.Error())) + } +} + +func (h *httpWriter) Failure(w http.ResponseWriter, reason int, message string, data any) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(stdResponse{ + Code: reason, + Reason: message, + Data: data, + }); err != nil { + w.Write([]byte(err.Error())) + } +} + +func defaultValueLookup(field string, w http.ResponseWriter, r *http.Request) string { + var ( + domainName string + ) + domainName = r.Header.Get(field) + if domainName == "" { + return defaultDomain + } + return domainName +} diff --git a/types/attribute.go b/types/attribute.go new file mode 100644 index 0000000..c7809b5 --- /dev/null +++ b/types/attribute.go @@ -0,0 +1,48 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" +) + +type ( + Attribute struct { + Match string `json:"match"` //匹配模式 + Tag string `json:"tag,omitempty"` //字段标签 + PrimaryKey bool `json:"primary_key"` //是否为主键 + DefaultValue string `json:"default_value"` //默认值 + Readonly []string `json:"readonly"` //只读场景 + Disable []string `json:"disable"` //禁用场景 + Visible []VisibleCondition `json:"visible"` //可见条 + Invisible bool `json:"invisible"` //不可见的字段,表示在UI界面看不到这个字段,但是这个字段需要 + EndOfNow bool `json:"end_of_now"` //最大时间为当前时间 + Values []EnumValue `json:"values"` //值 + Live LiveValue `json:"live"` //延时加载配置 + UploadUrl string `json:"upload_url,omitempty"` //上传地址 + Icon string `json:"icon,omitempty"` //显示图标 + Sort bool `json:"sort"` //是否允许排序 + Suffix string `json:"suffix,omitempty"` //追加内容 + Tooltip string `json:"tooltip,omitempty"` //字段提示信息 + Description string `json:"description,omitempty"` //字段说明信息 + DropdownOption *DropdownOption `json:"dropdown,omitempty"` //下拉选项 + } +) + +// Scan implements the Scanner interface. +func (n *Attribute) Scan(value interface{}) error { + if value == nil { + return nil + } + switch s := value.(type) { + case string: + return json.Unmarshal([]byte(s), n) + case []byte: + return json.Unmarshal(s, n) + } + return ErrDbTypeUnsupported +} + +// Value implements the driver Valuer interface. +func (n Attribute) Value() (driver.Value, error) { + return json.Marshal(n) +} diff --git a/types/error.go b/types/error.go new file mode 100644 index 0000000..94a9f8a --- /dev/null +++ b/types/error.go @@ -0,0 +1,7 @@ +package types + +import "errors" + +var ( + ErrDbTypeUnsupported = errors.New("database type unsupported") +) diff --git a/types/rule.go b/types/rule.go new file mode 100644 index 0000000..db253b8 --- /dev/null +++ b/types/rule.go @@ -0,0 +1,39 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" +) + +type Rule struct { + Min int `json:"min"` + Max int `json:"max"` + Type string `json:"type"` + Unique bool `json:"unique"` + Required []string `json:"required"` + Regular string `json:"regular"` +} + +// Value implements the driver Valuer interface. +func (n Scenarios) Value() (driver.Value, error) { + return json.Marshal(n) +} + +// Scan implements the Scanner interface. +func (n *Rule) Scan(value interface{}) error { + if value == nil { + return nil + } + switch s := value.(type) { + case string: + return json.Unmarshal([]byte(s), n) + case []byte: + return json.Unmarshal(s, n) + } + return ErrDbTypeUnsupported +} + +// Value implements the driver Valuer interface. +func (n Rule) Value() (driver.Value, error) { + return json.Marshal(n) +} diff --git a/types/schema.go b/types/schema.go new file mode 100644 index 0000000..2866135 --- /dev/null +++ b/types/schema.go @@ -0,0 +1,81 @@ +package types + +import ( + "encoding/json" +) + +type ( + LiveValue struct { + Enable bool `json:"enable"` + Type string `json:"type"` + Url string `json:"url"` + Method string `json:"method"` + Body string `json:"body"` + ContentType string `json:"content_type"` + Columns []string `json:"columns"` + } + + EnumValue struct { + Label string `json:"label"` + Value string `json:"value"` + Color string `json:"color"` + } + + VisibleCondition struct { + Column string `json:"column"` + Values []interface{} `json:"values"` + } + + DropdownOption struct { + Created bool `json:"created,omitempty"` + Filterable bool `json:"filterable,omitempty"` + Autocomplete bool `json:"autocomplete,omitempty"` + DefaultFirst bool `json:"default_first,omitempty"` + } + + Scenarios []string + + Schema struct { + Id uint64 `json:"id" gorm:"primary_key"` + CreatedAt int64 `json:"created_at" gorm:"autoCreateTime"` //创建时间 + UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime"` //更新时间 + Domain string `json:"domain" gorm:"column:domain;type:char(60);index"` //域 + ModuleName string `json:"module_name" gorm:"column:module_name;type:varchar(60);index"` //模块名称 + TableName string `json:"table_name" gorm:"column:table_name;type:varchar(120);index"` //表名称 + Enable uint8 `json:"enable" gorm:"column:enable;type:int(1)"` //是否启用 + Column string `json:"column" gorm:"type:varchar(120)"` //字段名称 + Label string `json:"label" gorm:"type:varchar(120)"` //显示名称 + Type string `json:"type" gorm:"type:varchar(120)"` //字段类型 + Format string `json:"format" gorm:"type:varchar(120)"` //字段格式 + Native uint8 `json:"native" gorm:"type:int(1)"` //是否为原生字段 + IsPrimaryKey uint8 `json:"is_primary_key" gorm:"type:int(1)"` //是否为主键 + Expression string `json:"expression" gorm:"type:varchar(526)"` //计算规则 + Scenarios Scenarios `json:"scenarios" gorm:"type:varchar(120)"` //场景 + Rule Rule `json:"rule" gorm:"type:varchar(2048)"` //字段规则 + Attribute Attribute `json:"attribute" gorm:"type:varchar(4096)"` //字段属性 + Position int `json:"position"` //字段排序位置 + } +) + +func (n Scenarios) Has(str string) bool { + for _, v := range n { + if v == str { + return true + } + } + return false +} + +// Scan implements the Scanner interface. +func (n *Scenarios) Scan(value interface{}) error { + if value == nil { + return nil + } + switch s := value.(type) { + case string: + return json.Unmarshal([]byte(s), n) + case []byte: + return json.Unmarshal(s, n) + } + return ErrDbTypeUnsupported +} diff --git a/types/types.go b/types/types.go new file mode 100644 index 0000000..ff2d9d2 --- /dev/null +++ b/types/types.go @@ -0,0 +1,231 @@ +package types + +import ( + "context" + "gorm.io/gorm" + "net/http" + "time" +) + +const ( + SceneCreate = "create" + SceneUpdate = "update" + SceneDelete = "delete" +) + +const ( + ScenarioCreate = "create" + ScenarioUpdate = "update" + ScenarioDelete = "delete" + ScenarioSearch = "search" + ScenarioExport = "export" + ScenarioImport = "import" + ScenarioList = "list" + ScenarioView = "view" + ScenarioMapping = "mapping" +) + +const ( + MatchExactly = "exactly" //精确匹配 + MatchFuzzy = "fuzzy" //模糊匹配 +) + +const ( + LiveTypeDropdown = "dropdown" + LiveTypeCascader = "cascader" +) + +const ( + TypeInteger = "integer" + TypeFloat = "float" + TypeBoolean = "boolean" + TypeString = "string" +) + +const ( + FormatInteger = "integer" + FormatFloat = "float" + FormatBoolean = "boolean" + FormatString = "string" + FormatText = "text" + FormatDropdown = "dropdown" + FormatDatetime = "datetime" + FormatDate = "date" + FormatTime = "time" + FormatTimestamp = "timestamp" + FormatPassword = "password" +) + +const ( + OperatorEqual = "eq" + OperatorGreaterThan = "gt" + OperatorGreaterEqual = "ge" + OperatorLessThan = "lt" + OperatorLessEqual = "le" + OperatorLike = "like" + OperatorBetween = "between" +) + +const ( + ErrImportFileNotExists = 1004 + ErrImportFileUnavailable = 1005 + ErrImportColumnNotMatch = 1006 +) + +const ( + defaultPageSize = 15 + defaultDomain = "localhost" +) + +const ( + RequestDenied = 8005 + RequestRecordNotFound = 8004 + RequestPayloadInvalid = 8006 + RequestCreateFailure = 8007 + RequestUpdateFailure = 8008 + RequestDeleteFailure = 8009 +) + +const ( + FormatRaw = "raw" + FormatBoth = "both" +) + +const ( + FieldDomain = "domain" +) + +type ( + // HttpWriter http 响应接口 + HttpWriter interface { + Success(w http.ResponseWriter, data any) + Failure(w http.ResponseWriter, reason int, message string, data any) + } + + // HttpRouter http 路由管理工具 + HttpRouter interface { + Handle(method string, uri string, handler http.HandlerFunc) + } + + // TypeValue 键值对数据 + TypeValue struct { + Label any `json:"label"` + Value any `json:"value"` + } + + // NestedValue 层级数据 + NestedValue struct { + Label any `json:"label"` + Value any `json:"value"` + Children []*NestedValue `json:"children,omitempty"` + } + + //ValueLookupFunc 查找域的函数 + ValueLookupFunc func(column string, w http.ResponseWriter, r *http.Request) string + + //SchemaLookupFunc 查找schema的回调函数 + SchemaLookupFunc func(ctx context.Context, db *gorm.DB, domain, module, table, scenario string) ([]*Schema, error) +) + +type ( + multiValue struct { + Text any `json:"label"` + Value any `json:"value"` + } + + Tabler interface { + TableName() string + } + + Reporter interface { + TableName() string + RealTable() string + QuoteColumn(ctx context.Context, column string) string + GroupBy(ctx context.Context) []string + } + + FormatModel interface { + Format(ctx context.Context, scene string, schemas []*Schema) + } + + SelectColumn struct { + Name string `json:"name"` + Expr string `json:"expr"` + Native bool `json:"native"` + Callback string `json:"callback"` + } + + DiffAttr struct { + Column string `json:"column"` + Label string `json:"label"` + OldValue any `json:"old_value"` + NewValue any `json:"new_value"` + } + + Condition struct { + Column string `json:"column"` + Expr string `json:"expr"` + Value any `json:"value,omitempty"` + Values []any `json:"values,omitempty"` + } + + Naming struct { + Pluralize string + Singular string + ModuleName string + TableName string + } + + RuntimeScope struct { + Domain string //域 + User string //用户 + ModuleName string //模块名称 + TableName string //表名称 + Scenario string //场景名称 + Schemas []*Schema //字段schema + Request *http.Request //HTTP请求结构 + PrimaryKeyValue any //主键 + } + + ImportResult struct { + Code int `json:"code"` + TotalCount int `json:"total_count"` + SuccessCount int `json:"success_count"` + UploadFile string `json:"upload_file"` + FailureFile string `json:"failure_file"` + Duration time.Duration `json:"duration"` + } + + sqlCountResponse struct { + Count int64 `json:"count"` + } +) + +type ( + ListResponse struct { + Page int `json:"page"` + PageSize int `json:"pagesize"` + TotalCount int64 `json:"totalCount"` + Data any `json:"data"` + } + + CreateResponse struct { + ID any `json:"id"` + Status string `json:"status"` + } + + UpdateResponse struct { + ID any `json:"id"` + Status string `json:"status"` + } + + DeleteResponse struct { + ID any `json:"id"` + Status string `json:"status"` + } + + ImportResponse struct { + UID string `json:"uid"` + Status string `json:"status"` + } +) diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..4a5fc30 --- /dev/null +++ b/utils.go @@ -0,0 +1,45 @@ +package rest + +import ( + "git.nobla.cn/golang/kos/util/arrays" + "reflect" + "strings" +) + +func hasToken(hack string, need string) bool { + if len(need) == 0 || len(hack) == 0 { + return false + } + char := []byte{',', ';'} + for _, c := range char { + if strings.IndexByte(need, c) > -1 { + return arrays.Exists(hack, strings.Split(need, string(c))) + } + } + return false +} + +func isEmpty(val any) bool { + if val == nil { + return true + } + v := reflect.ValueOf(val) + switch v.Kind() { + case reflect.String, reflect.Array: + return v.Len() == 0 + case reflect.Map, reflect.Slice: + return v.IsNil() || v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + default: + return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface()) + } +}