From 61ffcec858a0eaf50128ed2c05e75af3f0590710 Mon Sep 17 00:00:00 2001 From: fancl Date: Wed, 11 Dec 2024 17:29:01 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E4=BB=93=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 60 ++ README.md | 1 + condition.go | 144 +++++ formatter.go | 240 ++++++++ go.mod | 26 + go.sum | 46 ++ hook.go | 251 ++++++++ inflector/inflector.go | 408 ++++++++++++ model.go | 987 ++++++++++++++++++++++++++++++ options.go | 52 ++ plugins/cache/cache.go | 108 ++++ plugins/identity/identified.go | 57 ++ plugins/sharding/README.md | 33 + plugins/sharding/condition.go | 120 ++++ plugins/sharding/scope.go | 287 +++++++++ plugins/sharding/sharding.go | 476 ++++++++++++++ plugins/sharding/sharding_test.go | 61 ++ plugins/sharding/types.go | 45 ++ plugins/validate/types.go | 107 ++++ plugins/validate/validation.go | 275 +++++++++ query.go | 405 ++++++++++++ rest.go | 753 +++++++++++++++++++++++ types.go | 82 +++ types/attribute.go | 48 ++ types/error.go | 7 + types/rule.go | 39 ++ types/schema.go | 81 +++ types/types.go | 231 +++++++ utils.go | 45 ++ 29 files changed, 5475 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 condition.go create mode 100644 formatter.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 hook.go create mode 100644 inflector/inflector.go create mode 100644 model.go create mode 100644 options.go create mode 100644 plugins/cache/cache.go create mode 100644 plugins/identity/identified.go create mode 100644 plugins/sharding/README.md create mode 100644 plugins/sharding/condition.go create mode 100644 plugins/sharding/scope.go create mode 100644 plugins/sharding/sharding.go create mode 100644 plugins/sharding/sharding_test.go create mode 100644 plugins/sharding/types.go create mode 100644 plugins/validate/types.go create mode 100644 plugins/validate/validation.go create mode 100644 query.go create mode 100644 rest.go create mode 100644 types.go create mode 100644 types/attribute.go create mode 100644 types/error.go create mode 100644 types/rule.go create mode 100644 types/schema.go create mode 100644 types/types.go create mode 100644 utils.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..556957c --- /dev/null +++ b/.gitignore @@ -0,0 +1,60 @@ +bin/ + +.svn/ +.godeps +./build +.cover/ +dist +_site +_posts +*.dat +.vscode +vendor + +# Go.gitignore + +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test +storage +.idea + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.local +.DS_Store + +profile + +# vim stuff +*.sw[op] + + +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +.vscode/* +!.vscode/extensions.json + +node_modules \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..69155bd --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# 数据库组件 \ No newline at end of file diff --git a/condition.go b/condition.go new file mode 100644 index 0000000..597c367 --- /dev/null +++ b/condition.go @@ -0,0 +1,144 @@ +package rest + +import ( + "context" + "encoding/json" + "git.nobla.cn/golang/kos/util/arrays" + "git.nobla.cn/golang/rest/types" + "net/http" + "strings" +) + +func findCondition(schema *types.Schema, conditions []*types.Condition) *types.Condition { + for _, cond := range conditions { + if cond.Column == schema.Column { + return cond + } + } + return nil +} + +func BuildConditions(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) { + var ( + ok bool + skip bool + formValue string + activeQuery ActiveQuery + ) + if activeQuery, ok = query.Model().(ActiveQuery); ok { + if err = activeQuery.BeforeQuery(ctx, query); err != nil { + return + } + } + if arrays.Exists(r.Method, []string{http.MethodPut, http.MethodPost}) { + conditions := make([]*types.Condition, 0) + if err = json.NewDecoder(r.Body).Decode(&conditions); err != nil { + return + } + for _, row := range schemas { + if row.Native == 0 { + continue + } + cond := findCondition(row, conditions) + if cond == nil { + continue + } + switch row.Format { + case types.FormatInteger, types.FormatFloat, types.FormatTimestamp, types.FormatDatetime, types.FormatDate, types.FormatTime: + switch cond.Expr { + case types.OperatorBetween: + if len(cond.Values) == 2 { + query.AndFilterWhere(newCondition(row.Column, cond.Values[0]).WithExpr(">=")) + query.AndFilterWhere(newCondition(row.Column, cond.Values[1]).WithExpr("<=")) + } + case types.OperatorGreaterThan: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">")) + case types.OperatorGreaterEqual: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">=")) + case types.OperatorLessThan: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<")) + case types.OperatorLessEqual: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<=")) + default: + query.AndFilterWhere(newCondition(row.Column, cond.Value)) + } + default: + switch cond.Expr { + case types.OperatorLike: + query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("LIKE")) + default: + query.AndFilterWhere(newCondition(row.Column, cond.Value)) + } + } + } + } else { + qs := r.URL.Query() + for _, row := range schemas { + skip = false + if skip { + continue + } + if row.Native == 0 { + continue + } + formValue = qs.Get(row.Column) + switch row.Format { + case types.FormatString, types.FormatText: + if row.Attribute.Match == types.MatchExactly { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } else { + query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE")) + } + case types.FormatTime, types.FormatDate, types.FormatDatetime, types.FormatTimestamp: + var sep string + seps := []byte{',', '/'} + for _, s := range seps { + if strings.IndexByte(formValue, s) > -1 { + sep = string(s) + } + } + if ss := strings.Split(formValue, sep); len(ss) == 2 { + query.AndFilterWhere( + newCondition(row.Column, strings.TrimSpace(ss[0])).WithExpr(">="), + newCondition(row.Column, strings.TrimSpace(ss[1])).WithExpr("<="), + ) + } else { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } + case types.FormatInteger, types.FormatFloat: + query.AndFilterWhere(newCondition(row.Column, formValue)) + default: + if row.Type == types.TypeString { + if row.Attribute.Match == types.MatchExactly { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } else { + query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE")) + } + } else { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } + } + } + } + sortPar := r.FormValue("sort") + if sortPar != "" { + sorts := strings.Split(sortPar, ",") + for _, s := range sorts { + if s[0] == '-' { + query.OrderBy(s[1:], "DESC") + } else { + if s[0] == '+' { + query.OrderBy(s[1:], "ASC") + } else { + query.OrderBy(s, "ASC") + } + } + } + } + if activeQuery, ok = query.Model().(ActiveQuery); ok { + if err = activeQuery.AfterQuery(ctx, query); err != nil { + return + } + } + return +} diff --git a/formatter.go b/formatter.go new file mode 100644 index 0000000..8d40654 --- /dev/null +++ b/formatter.go @@ -0,0 +1,240 @@ +package rest + +import ( + "context" + "database/sql" + "fmt" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "reflect" + "strconv" + "sync" + "time" +) + +var ( + DefaultFormatter = NewFormatter() + + DefaultNullDisplay = "" +) + +func init() { + DefaultFormatter.Register("string", stringFormat) + DefaultFormatter.Register("integer", integerFormat) + DefaultFormatter.Register("decimal", decimalFormat) + DefaultFormatter.Register("date", dateFormat) + DefaultFormatter.Register("time", timeFormat) + DefaultFormatter.Register("datetime", datetimeFormat) + DefaultFormatter.Register("duration", durationFormat) + DefaultFormatter.Register("dropdown", dropdownFormat) + DefaultFormatter.Register("timestamp", datetimeFormat) + DefaultFormatter.Register("percentage", percentageFormat) +} + +type FormatFunc func(ctx context.Context, value any, model any, scm *types.Schema) any + +type Formatter struct { + callbacks sync.Map +} + +func (formatter *Formatter) Register(f string, fun FormatFunc) { + formatter.callbacks.Store(f, fun) +} + +func (formatter *Formatter) Format(ctx context.Context, format string, value any, model any, scm *types.Schema) any { + v, ok := formatter.callbacks.Load(format) + if ok { + return v.(FormatFunc)(ctx, value, model, scm) + } + return value +} + +func (formatter *Formatter) getModelValue(refValue reflect.Value, schema *types.Schema, stmt *gorm.Statement) any { + if stmt.Schema == nil { + return nil + } + field := stmt.Schema.LookUpField(schema.Column) + if field == nil { + return nil + } + return refValue.FieldByName(field.Name).Interface() +} + +func (formatter *Formatter) formatModel(ctx context.Context, refValue reflect.Value, schemas []*types.Schema, stmt *gorm.Statement, format string) any { + values := make(map[string]any) + multiValues := make(map[string]multiValue) + modelValue := refValue.Interface() + refValue = reflect.Indirect(refValue) + for _, scm := range schemas { + switch format { + case types.FormatRaw: + values[scm.Column] = formatter.getModelValue(refValue, scm, stmt) + case types.FormatBoth: + v := multiValue{ + Value: formatter.getModelValue(refValue, scm, stmt), + } + v.Text = formatter.Format(ctx, scm.Format, v.Value, modelValue, scm) + multiValues[scm.Column] = v + default: + values[scm.Column] = formatter.Format(ctx, scm.Format, formatter.getModelValue(refValue, scm, stmt), modelValue, scm) + } + } + if format == types.FormatBoth { + return multiValues + } else { + return values + } +} + +func (formatter *Formatter) formatModels(ctx context.Context, val any, schemas []*types.Schema, stmt *gorm.Statement, format string) any { + refValue := reflect.Indirect(reflect.ValueOf(val)) + if refValue.Kind() != reflect.Slice { + return []any{} + } + length := refValue.Len() + values := make([]any, length) + for i := 0; i < length; i++ { + rowValue := refValue.Index(i) + modelValue := rowValue.Interface() + if formatModel, ok := modelValue.(types.FormatModel); ok { + formatModel.Format(ctx, format, schemas) + } + values[i] = formatter.formatModel(ctx, rowValue, schemas, stmt, format) + } + return values +} + +func stringFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + return fmt.Sprint(value) +} + +func integerFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + var ( + n int + ) + switch value.(type) { + case float32, float64: + n = int(reflect.ValueOf(value).Float()) + case int, int8, int16, int32, int64: + n = int(reflect.ValueOf(value).Int()) + case uint, uint8, uint16, uint32, uint64: + n = int(reflect.ValueOf(value).Uint()) + case string: + n, _ = strconv.Atoi(reflect.ValueOf(value).String()) + case []byte: + n, _ = strconv.Atoi(string(reflect.ValueOf(value).Bytes())) + } + return n +} + +func decimalFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + var ( + n float64 + ) + switch value.(type) { + case float32, float64: + n = reflect.ValueOf(value).Float() + case int, int8, int16, int32, int64: + n = float64(reflect.ValueOf(value).Int()) + case uint, uint8, uint16, uint32, uint64: + n = float64(reflect.ValueOf(value).Uint()) + case string: + n, _ = strconv.ParseFloat(reflect.ValueOf(value).String(), 64) + case []byte: + n, _ = strconv.ParseFloat(string(reflect.ValueOf(value).Bytes()), 64) + } + return n +} + +func dateFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + if t, ok := value.(time.Time); ok { + return t.Format("2006-01-02") + } + if t, ok := value.(*sql.NullTime); ok { + if t != nil && t.Valid { + return t.Time.Format("2006-01-02") + } + } + if t, ok := value.(int64); ok { + tm := time.Unix(t, 0) + return tm.Format("2006-01-02") + } + return DefaultNullDisplay +} + +func timeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + if t, ok := value.(time.Time); ok { + return t.Format("15:04:05") + } + if t, ok := value.(*sql.NullTime); ok { + if t != nil && t.Valid { + return t.Time.Format("15:04:05") + } + } + if t, ok := value.(int64); ok { + tm := time.Unix(t, 0) + return tm.Format("15:04:05") + } + return value +} + +func datetimeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + if t, ok := value.(time.Time); ok { + return t.Format("2006-01-02 15:04:05") + } + if t, ok := value.(*sql.NullTime); ok { + if t != nil && t.Valid { + return t.Time.Format("2006-01-02 15:04:05") + } + } + if t, ok := value.(int64); ok { + if t > 0 { + tm := time.Unix(t, 0) + return tm.Format("2006-01-02 15:04:05") + } + } + return DefaultNullDisplay +} + +func percentageFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + n := decimalFormat(ctx, value, model, schema).(float64) + if n <= 1 { + return fmt.Sprintf("%.2f%%", n*100) + } else { + return fmt.Sprintf("%.2f%%", n) + } +} + +func durationFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + var ( + hour int + minVal int + sec int + ) + n := integerFormat(ctx, value, model, schema).(int) + hour = n / 3600 + minVal = (n - hour*3600) / 60 + sec = n - hour*3600 - minVal*60 + return fmt.Sprintf("%02d:%02d:%02d", hour, minVal, sec) +} + +func dropdownFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} { + attributes := schema.Attribute + if attributes.Values != nil { + for _, v := range attributes.Values { + if v.Value == value { + return v.Label + } + } + } + return value +} + +func NewFormatter() *Formatter { + formatter := &Formatter{} + return formatter +} + +func RegisterFormat(f string, cb FormatFunc) { + DefaultFormatter.Register(f, cb) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f2bb69f --- /dev/null +++ b/go.mod @@ -0,0 +1,26 @@ +module git.nobla.cn/golang/rest + +go 1.22.9 + +require ( + git.nobla.cn/golang/kos v0.1.32 + github.com/cespare/xxhash/v2 v2.3.0 + github.com/go-playground/validator/v10 v10.23.0 + github.com/longbridgeapp/sqlparser v0.3.2 + github.com/rs/xid v1.6.0 + github.com/uole/sqlparser v0.0.1 + gorm.io/gorm v1.25.12 +) + +require ( + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/text v0.14.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4446986 --- /dev/null +++ b/go.sum @@ -0,0 +1,46 @@ +git.nobla.cn/golang/kos v0.1.32 h1:sFVCA7vKc8dPUd0cxzwExOSPX2mmMh2IuwL6cYS1pBc= +git.nobla.cn/golang/kos v0.1.32/go.mod h1:35Z070+5oB39WcVrh5DDlnVeftL/Ccmscw2MZFe9fUg= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL8sThn8IHr/sO+o= +github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M= +github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/longbridgeapp/sqlparser v0.3.2 h1:FV0dgMiv8VcksT3p10hJeqfPs8bodoehmUJ7MhBds+Y= +github.com/longbridgeapp/sqlparser v0.3.2/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/uole/sqlparser v0.0.1 h1:LLUklg6Ne5MypXQuo53QcJv/xKdxtEKM9iUuEBN/lt8= +github.com/uole/sqlparser v0.0.1/go.mod h1:CRYFz2PTm9oHM0j9GFKi1VzPy70r6GsF0b9vpnwJ4yI= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/hook.go b/hook.go new file mode 100644 index 0000000..91c81ad --- /dev/null +++ b/hook.go @@ -0,0 +1,251 @@ +package rest + +import ( + "context" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" +) + +const ( + beforeCreate = "beforeCreate" + afterCreate = "afterCreate" + beforeUpdate = "beforeUpdate" + afterUpdate = "afterUpdate" + beforeSave = "beforeSave" + afterSave = "afterSave" + beforeDelete = "beforeDelete" + afterDelete = "afterDelete" + afterExport = "afterExport" + afterImport = "afterImport" +) + +type ( + BeforeCreate func(ctx context.Context, tx *gorm.DB, model any) (err error) + AfterCreate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) + BeforeUpdate func(ctx context.Context, tx *gorm.DB, model any) (err error) + AfterUpdate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) + BeforeSave func(ctx context.Context, tx *gorm.DB, model any) (err error) + AfterSave func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) + BeforeDelete func(ctx context.Context, tx *gorm.DB, model any) (err error) + AfterDelete func(ctx context.Context, tx *gorm.DB, model any) + AfterExport func(ctx context.Context, filename string) //导出回调 + AfterImport func(ctx context.Context, result *types.ImportResult) //导入回调 + hookManager struct { + callbacks map[string][]any + } +) + +type ( + ActiveQuery interface { + BeforeQuery(ctx context.Context, query *Query) (err error) + AfterQuery(ctx context.Context, query *Query) (err error) + } +) + +func (hook *hookManager) register(spec string, cb any) { + if hook.callbacks == nil { + hook.callbacks = make(map[string][]any) + } + if _, ok := hook.callbacks[spec]; !ok { + hook.callbacks[spec] = make([]any, 0) + } + hook.callbacks[spec] = append(hook.callbacks[spec], cb) +} + +func (hook *hookManager) beforeCreate(ctx context.Context, tx *gorm.DB, model any) (err error) { + callbacks, ok := hook.callbacks[beforeCreate] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(BeforeCreate); ok { + if err = cb(ctx, tx, model); err != nil { + return err + } + } + } + return +} + +func (hook *hookManager) afterCreate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) { + callbacks, ok := hook.callbacks[afterCreate] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterCreate); ok { + cb(ctx, tx, model, diff) + } + } + return +} + +func (hook *hookManager) beforeUpdate(ctx context.Context, tx *gorm.DB, model any) (err error) { + callbacks, ok := hook.callbacks[beforeUpdate] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(BeforeUpdate); ok { + if err = cb(ctx, tx, model); err != nil { + return err + } + } + } + return +} + +func (hook *hookManager) afterUpdate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) { + callbacks, ok := hook.callbacks[afterUpdate] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterUpdate); ok { + cb(ctx, tx, model, diff) + } + } + return +} + +func (hook *hookManager) beforeSave(ctx context.Context, tx *gorm.DB, model any) (err error) { + callbacks, ok := hook.callbacks[beforeSave] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(BeforeSave); ok { + if err = cb(ctx, tx, model); err != nil { + return err + } + } + } + return +} + +func (hook *hookManager) afterSave(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) { + callbacks, ok := hook.callbacks[afterSave] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterSave); ok { + cb(ctx, tx, model, diff) + } + } + return +} + +func (hook *hookManager) beforeDelete(ctx context.Context, tx *gorm.DB, model any) (err error) { + callbacks, ok := hook.callbacks[beforeDelete] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(BeforeDelete); ok { + if err = cb(ctx, tx, model); err != nil { + return err + } + } + } + return +} + +func (hook *hookManager) afterDelete(ctx context.Context, tx *gorm.DB, model any) { + callbacks, ok := hook.callbacks[afterDelete] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterDelete); ok { + cb(ctx, tx, model) + } + } + return +} + +func (hook *hookManager) afterExport(ctx context.Context, filename string) { + callbacks, ok := hook.callbacks[afterExport] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterExport); ok { + cb(ctx, filename) + } + } + return +} + +func (hook *hookManager) afterImport(ctx context.Context, ret *types.ImportResult) { + callbacks, ok := hook.callbacks[afterImport] + if !ok { + return + } + for _, callback := range callbacks { + if cb, ok := callback.(AfterImport); ok { + cb(ctx, ret) + } + } + return +} + +func (hook *hookManager) BeforeCreate(cb BeforeCreate) { + if cb != nil { + hook.register(beforeCreate, cb) + } +} + +func (hook *hookManager) AfterCreate(cb AfterCreate) { + if cb != nil { + hook.register(afterCreate, cb) + } +} + +func (hook *hookManager) BeforeUpdate(cb BeforeUpdate) { + if cb != nil { + hook.register(beforeUpdate, cb) + } +} + +func (hook *hookManager) AfterUpdate(cb AfterUpdate) { + if cb != nil { + hook.register(afterUpdate, cb) + } +} + +func (hook *hookManager) BeforeSave(cb BeforeSave) { + if cb != nil { + hook.register(beforeSave, cb) + } +} + +func (hook *hookManager) AfterSave(cb AfterSave) { + if cb != nil { + hook.register(afterSave, cb) + } +} + +func (hook *hookManager) BeforeDelete(cb BeforeDelete) { + if cb != nil { + hook.register(beforeDelete, cb) + } +} + +func (hook *hookManager) AfterDelete(cb AfterDelete) { + if cb != nil { + hook.register(afterDelete, cb) + } +} + +func (hook *hookManager) AfterExport(cb AfterExport) { + if cb != nil { + hook.register(afterExport, cb) + } +} + +func (hook *hookManager) AfterImport(cb AfterImport) { + if cb != nil { + hook.register(afterImport, cb) + } +} diff --git a/inflector/inflector.go b/inflector/inflector.go new file mode 100644 index 0000000..7a03031 --- /dev/null +++ b/inflector/inflector.go @@ -0,0 +1,408 @@ +package inflector + +import ( + "bytes" + "fmt" + "regexp" + "strings" + "sync" +) + +// Rule represents name of the inflector rule, can be +// Plural or Singular +type Rule int + +const ( + Plural = iota + Singular +) + +// InflectorRule represents inflector rule +type InflectorRule struct { + Rules []*ruleItem + Irregular []*irregularItem + Uninflected []string + compiledIrregular *regexp.Regexp + compiledUninflected *regexp.Regexp + compiledRules []*compiledRule +} + +type ruleItem struct { + pattern string + replacement string +} + +type irregularItem struct { + word string + replacement string +} + +// compiledRule represents compiled version of Inflector.Rules. +type compiledRule struct { + replacement string + *regexp.Regexp +} + +// threadsafe access to rules and caches +var mutex sync.Mutex +var rules = make(map[Rule]*InflectorRule) + +// Words that should not be inflected +var uninflected = []string{ + `Amoyese`, `bison`, `Borghese`, `bream`, `breeches`, `britches`, `buffalo`, + `cantus`, `carp`, `chassis`, `clippers`, `cod`, `coitus`, `Congoese`, + `contretemps`, `corps`, `debris`, `diabetes`, `djinn`, `eland`, `elk`, + `equipment`, `Faroese`, `flounder`, `Foochowese`, `gallows`, `Genevese`, + `Genoese`, `Gilbertese`, `graffiti`, `headquarters`, `herpes`, `hijinks`, + `Hottentotese`, `information`, `innings`, `jackanapes`, `Kiplingese`, + `Kongoese`, `Lucchese`, `mackerel`, `Maltese`, `.*?media`, `mews`, `moose`, + `mumps`, `Nankingese`, `news`, `nexus`, `Niasese`, `Pekingese`, + `Piedmontese`, `pincers`, `Pistoiese`, `pliers`, `Portuguese`, `proceedings`, + `rabies`, `rice`, `rhinoceros`, `salmon`, `Sarawakese`, `scissors`, + `sea[- ]bass`, `series`, `Shavese`, `shears`, `siemens`, `species`, `swine`, + `testes`, `trousers`, `trout`, `tuna`, `Vermontese`, `Wenchowese`, `whiting`, + `wildebeest`, `Yengeese`, +} + +// Plural words that should not be inflected +var uninflectedPlurals = []string{ + `.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`, + `people`, +} + +// Singular words that should not be inflected +var uninflectedSingulars = []string{ + `.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`, + `.*ss`, +} + +type cache map[string]string + +// Inflected words that already cached for immediate retrieval from a given Rule +var caches = make(map[Rule]cache) + +// map of irregular words where its key is a word and its value is the replacement +var irregularMaps = make(map[Rule]cache) + +var ( + // https://github.com/golang/lint/blob/master/lint.go#L770 + commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} + commonInitialismsReplacer *strings.Replacer +) + +func init() { + rules[Plural] = &InflectorRule{ + Rules: []*ruleItem{ + {`(?i)(s)tatus$`, `${1}${2}tatuses`}, + {`(?i)(quiz)$`, `${1}zes`}, + {`(?i)^(ox)$`, `${1}${2}en`}, + {`(?i)([m|l])ouse$`, `${1}ice`}, + {`(?i)(matr|vert|ind)(ix|ex)$`, `${1}ices`}, + {`(?i)(x|ch|ss|sh)$`, `${1}es`}, + {`(?i)([^aeiouy]|qu)y$`, `${1}ies`}, + {`(?i)(hive)$`, `$1s`}, + {`(?i)(?:([^f])fe|([lre])f)$`, `${1}${2}ves`}, + {`(?i)sis$`, `ses`}, + {`(?i)([ti])um$`, `${1}a`}, + {`(?i)(p)erson$`, `${1}eople`}, + {`(?i)(m)an$`, `${1}en`}, + {`(?i)(c)hild$`, `${1}hildren`}, + {`(?i)(buffal|tomat)o$`, `${1}${2}oes`}, + {`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|vir)us$`, `${1}i`}, + {`(?i)us$`, `uses`}, + {`(?i)(alias)$`, `${1}es`}, + {`(?i)(ax|cris|test)is$`, `${1}es`}, + {`s$`, `s`}, + {`^$`, ``}, + {`$`, `s`}, + }, + Irregular: []*irregularItem{ + {`atlas`, `atlases`}, + {`beef`, `beefs`}, + {`brother`, `brothers`}, + {`cafe`, `cafes`}, + {`child`, `children`}, + {`cookie`, `cookies`}, + {`corpus`, `corpuses`}, + {`cow`, `cows`}, + {`ganglion`, `ganglions`}, + {`genie`, `genies`}, + {`genus`, `genera`}, + {`graffito`, `graffiti`}, + {`hoof`, `hoofs`}, + {`loaf`, `loaves`}, + {`man`, `men`}, + {`money`, `monies`}, + {`mongoose`, `mongooses`}, + {`move`, `moves`}, + {`mythos`, `mythoi`}, + {`niche`, `niches`}, + {`numen`, `numina`}, + {`occiput`, `occiputs`}, + {`octopus`, `octopuses`}, + {`opus`, `opuses`}, + {`ox`, `oxen`}, + {`penis`, `penises`}, + {`person`, `people`}, + {`sex`, `sexes`}, + {`soliloquy`, `soliloquies`}, + {`testis`, `testes`}, + {`trilby`, `trilbys`}, + {`turf`, `turfs`}, + {`potato`, `potatoes`}, + {`hero`, `heroes`}, + {`tooth`, `teeth`}, + {`goose`, `geese`}, + {`foot`, `feet`}, + }, + } + prepare(Plural) + + rules[Singular] = &InflectorRule{ + Rules: []*ruleItem{ + {`(?i)(s)tatuses$`, `${1}${2}tatus`}, + {`(?i)^(.*)(menu)s$`, `${1}${2}`}, + {`(?i)(quiz)zes$`, `$1`}, + {`(?i)(matr)ices$`, `${1}ix`}, + {`(?i)(vert|ind)ices$`, `${1}ex`}, + {`(?i)^(ox)en`, `$1`}, + {`(?i)(alias)(es)*$`, `$1`}, + {`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|viri?)i$`, `${1}us`}, + {`(?i)([ftw]ax)es`, `$1`}, + {`(?i)(cris|ax|test)es$`, `${1}is`}, + {`(?i)(shoe|slave)s$`, `$1`}, + {`(?i)(o)es$`, `$1`}, + {`ouses$`, `ouse`}, + {`([^a])uses$`, `${1}us`}, + {`(?i)([m|l])ice$`, `${1}ouse`}, + {`(?i)(x|ch|ss|sh)es$`, `$1`}, + {`(?i)(m)ovies$`, `${1}${2}ovie`}, + {`(?i)(s)eries$`, `${1}${2}eries`}, + {`(?i)([^aeiouy]|qu)ies$`, `${1}y`}, + {`(?i)(tive)s$`, `$1`}, + {`(?i)([lre])ves$`, `${1}f`}, + {`(?i)([^fo])ves$`, `${1}fe`}, + {`(?i)(hive)s$`, `$1`}, + {`(?i)(drive)s$`, `$1`}, + {`(?i)(^analy)ses$`, `${1}sis`}, + {`(?i)(analy|diagno|^ba|(p)arenthe|(p)rogno|(s)ynop|(t)he)ses$`, `${1}${2}sis`}, + {`(?i)([ti])a$`, `${1}um`}, + {`(?i)(p)eople$`, `${1}${2}erson`}, + {`(?i)(m)en$`, `${1}an`}, + {`(?i)(c)hildren$`, `${1}${2}hild`}, + {`(?i)(n)ews$`, `${1}${2}ews`}, + {`eaus$`, `eau`}, + {`^(.*us)$`, `$1`}, + {`(?i)s$`, ``}, + }, + Irregular: []*irregularItem{ + {`foes`, `foe`}, + {`waves`, `wave`}, + {`curves`, `curve`}, + {`atlases`, `atlas`}, + {`beefs`, `beef`}, + {`brothers`, `brother`}, + {`cafes`, `cafe`}, + {`children`, `child`}, + {`cookies`, `cookie`}, + {`corpuses`, `corpus`}, + {`cows`, `cow`}, + {`ganglions`, `ganglion`}, + {`genies`, `genie`}, + {`genera`, `genus`}, + {`graffiti`, `graffito`}, + {`hoofs`, `hoof`}, + {`loaves`, `loaf`}, + {`men`, `man`}, + {`monies`, `money`}, + {`mongooses`, `mongoose`}, + {`moves`, `move`}, + {`mythoi`, `mythos`}, + {`niches`, `niche`}, + {`numina`, `numen`}, + {`occiputs`, `occiput`}, + {`octopuses`, `octopus`}, + {`opuses`, `opus`}, + {`oxen`, `ox`}, + {`penises`, `penis`}, + {`people`, `person`}, + {`sexes`, `sex`}, + {`soliloquies`, `soliloquy`}, + {`testes`, `testis`}, + {`trilbys`, `trilby`}, + {`turfs`, `turf`}, + {`potatoes`, `potato`}, + {`heroes`, `hero`}, + {`teeth`, `tooth`}, + {`geese`, `goose`}, + {`feet`, `foot`}, + }, + } + prepare(Singular) + + commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) + for _, initialism := range commonInitialisms { + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) +} + +// prepare rule, e.g., compile the pattern. +func prepare(r Rule) error { + var reString string + + switch r { + case Plural: + // Merge global uninflected with singularsUninflected + rules[r].Uninflected = merge(uninflected, uninflectedPlurals) + case Singular: + // Merge global uninflected with singularsUninflected + rules[r].Uninflected = merge(uninflected, uninflectedSingulars) + } + + // Set InflectorRule.compiledUninflected by joining InflectorRule.Uninflected into + // a single string then compile it. + reString = fmt.Sprintf(`(?i)(^(?:%s))$`, strings.Join(rules[r].Uninflected, `|`)) + rules[r].compiledUninflected = regexp.MustCompile(reString) + + // Prepare irregularMaps + irregularMaps[r] = make(cache, len(rules[r].Irregular)) + + // Set InflectorRule.compiledIrregular by joining the irregularItem.word of Inflector.Irregular + // into a single string then compile it. + vIrregulars := make([]string, len(rules[r].Irregular)) + for i, item := range rules[r].Irregular { + vIrregulars[i] = item.word + irregularMaps[r][item.word] = item.replacement + } + reString = fmt.Sprintf(`(?i)(.*)\b((?:%s))$`, strings.Join(vIrregulars, `|`)) + rules[r].compiledIrregular = regexp.MustCompile(reString) + + // Compile all patterns in InflectorRule.Rules + rules[r].compiledRules = make([]*compiledRule, len(rules[r].Rules)) + for i, item := range rules[r].Rules { + rules[r].compiledRules[i] = &compiledRule{item.replacement, regexp.MustCompile(item.pattern)} + } + + // Prepare caches + caches[r] = make(cache) + + return nil +} + +// merge slice a and slice b +func merge(a []string, b []string) []string { + result := make([]string, len(a)+len(b)) + copy(result, a) + copy(result[len(a):], b) + + return result +} + +func getInflected(r Rule, s string) string { + mutex.Lock() + defer mutex.Unlock() + if v, ok := caches[r][s]; ok { + return v + } + + // Check for irregular words + if res := rules[r].compiledIrregular.FindStringSubmatch(s); len(res) >= 3 { + var buf bytes.Buffer + + buf.WriteString(res[1]) + buf.WriteString(s[0:1]) + buf.WriteString(irregularMaps[r][strings.ToLower(res[2])][1:]) + + // Cache it then returns + caches[r][s] = buf.String() + return caches[r][s] + } + + // Check for uninflected words + if rules[r].compiledUninflected.MatchString(s) { + caches[r][s] = s + return caches[r][s] + } + + // Check each rule + for _, re := range rules[r].compiledRules { + if re.MatchString(s) { + caches[r][s] = re.ReplaceAllString(s, re.replacement) + return caches[r][s] + } + } + + // Returns unaltered + caches[r][s] = s + return caches[r][s] +} + +// Pluralize returns string s in plural form. +func Pluralize(s string) string { + return getInflected(Plural, s) +} + +// Singularize returns string s in singular form. +func Singularize(s string) string { + return getInflected(Singular, s) +} + +var ( + camelizeReg = regexp.MustCompile(`[^A-Za-z0-9]+`) +) + +// Camelize Converts a word like "send_email" to "SendEmail" +func Camelize(s string) string { + s = camelizeReg.ReplaceAllString(s, " ") + return strings.Replace(strings.Title(s), " ", "", -1) +} + +// Camel2id Converts a word like "SendEmail" to "send_email" +func Camel2id(name string) string { + var ( + value = commonInitialismsReplacer.Replace(name) + buf strings.Builder + lastCase, nextCase, nextNumber bool // upper case == true + curCase = value[0] <= 'Z' && value[0] >= 'A' + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' + nextNumber = value[i+1] >= '0' && value[i+1] <= '9' + + if curCase { + if lastCase && (nextCase || nextNumber) { + buf.WriteRune(v + 32) + } else { + if i > 0 && value[i-1] != '_' && value[i+1] != '_' { + buf.WriteByte('_') + } + buf.WriteRune(v + 32) + } + } else { + buf.WriteRune(v) + } + + lastCase = curCase + curCase = nextCase + } + + if curCase { + if !lastCase && len(value) > 1 { + buf.WriteByte('_') + } + buf.WriteByte(value[len(value)-1] + 32) + } else { + buf.WriteByte(value[len(value)-1]) + } + ret := buf.String() + return ret +} + +// Camel2words Converts a CamelCase name into space-separated words. +// For example, 'send_email' will be converted to 'Send Email'. +func Camel2words(s string) string { + s = camelizeReg.ReplaceAllString(s, " ") + return strings.Title(s) +} diff --git a/model.go b/model.go new file mode 100644 index 0000000..82f92cf --- /dev/null +++ b/model.go @@ -0,0 +1,987 @@ +package rest + +import ( + "context" + "encoding/csv" + "encoding/json" + "fmt" + "git.nobla.cn/golang/kos/util/arrays" + "git.nobla.cn/golang/kos/util/pool" + "git.nobla.cn/golang/rest/inflector" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "path" + "reflect" + "strconv" + "strings" + "time" +) + +type Model struct { + naming types.Naming //命名规则 + value reflect.Value //模块值 + db *gorm.DB //数据库 + primaryKey string //主键 + urlPrefix string //url前缀 + disableDomain bool //禁用域 + schemaLookup types.SchemaLookupFunc //获取schema的函数 + valueLookup types.ValueLookupFunc //查看域 + statement *gorm.Statement //字段声明 + formatter *Formatter //格式化 + response types.HttpWriter //HTTP响应 + hookMgr *hookManager //钩子管理器 + dirname string //存放文件目录 +} + +var ( + RuntimeScopeKey = &types.RuntimeScope{} +) + +// getDB 获取数据库连接对象 +func (m *Model) getDB() *gorm.DB { + return m.db +} + +// getFormatter 获取格式化组件 +func (m *Model) getFormatter() *Formatter { + if m.formatter != nil { + return m.formatter + } + return DefaultFormatter +} + +// getHook 获取钩子 +func (m *Model) getHook() *hookManager { + return m.hookMgr +} + +// hasScenario 判断是否有该场景 +func (m *Model) hasScenario(s string) bool { + return true +} + +// setValue 设置字段的值 +func (m *Model) setValue(refValue reflect.Value, column string, value any) { + SetFieldValue(m.statement, refValue, column, value) +} + +func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) { + SafeSetFileValue(m.statement, refValue, column, value) +} + +// getValue 获取字段的值 +func (m *Model) getValue(refValue reflect.Value, column string) interface{} { + return GetFieldValue(m.statement, refValue, column) +} + +// hasColumn 判断指定的列是否存在 +func (m *Model) hasColumn(column string) bool { + for _, field := range m.statement.Schema.Fields { + if field.DBName == column || field.Name == column { + return true + } + } + return false +} + +// getFilename 获取文件存放目录 +func (m *Model) getFilename(domain string, spec string, name string) string { + if m.dirname == "" { + m.dirname = os.TempDir() + } + filename := path.Join(m.dirname, domain, spec, time.Now().Format("20060102"), name) + if _, err := os.Stat(path.Dir(filename)); err != nil { + _ = os.MkdirAll(path.Dir(filename), 0755) + } + return filename +} + +// findPrimaryKey 查找主键的值 +func (m *Model) findPrimaryKey(uri string, r *http.Request) string { + var ( + pos int + ) + urlPath := r.URL.Path + pos = strings.IndexByte(uri, ':') + if pos > 0 { + return urlPath[pos:] + } + return "" +} + +// parseReportColumn 解析报表的列 +func (m *Model) parseReportColumn(name, props string) *types.SelectColumn { + var ( + key string + value string + ) + column := &types.SelectColumn{ + Name: inflector.Camel2id(name), + Native: false, + } + tokens := strings.Split(props, ";") + for _, token := range tokens { + pair := strings.SplitN(token, ":", 2) + if len(pair) == 0 { + continue + } + if len(pair) == 1 { + key = strings.TrimSpace(pair[0]) + value = "" + } else { + key = strings.TrimSpace(pair[0]) + value = strings.TrimSpace(pair[1]) + } + switch key { + case "native": + column.Native = true + case "name": + column.Name = value + case "expr": + column.Expr = value + } + } + return column +} + +func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + columns := make([]string, 0) + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + scenarios := field.Tag.Get("scenarios") + if !hasToken(types.ScenarioList, scenarios) { + continue + } + isPrimary := field.Tag.Get("is_primary") + if isPrimary != "true" { + continue + } + column := m.parseReportColumn(field.Name, field.Tag.Get("report")) + if !column.Native { + continue + } + if column.Expr == "" { + columns = append(columns, dest.QuoteColumn(ctx, column.Name)) + } else { + columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name))) + } + } + columns = append(columns, "COUNT(*) AS count") + query.Select(columns...) +} + +func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + columns := make([]string, 0) + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + scenarios := field.Tag.Get("scenarios") + if !hasToken(types.ScenarioList, scenarios) { + continue + } + column := m.parseReportColumn(field.Name, field.Tag.Get("report")) + if !column.Native { + continue + } + if column.Expr == "" { + columns = append(columns, dest.QuoteColumn(ctx, column.Name)) + } else { + columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name))) + } + } + query.Select(columns...) +} + +// buildCondition 构建sql条件 +func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) { + return BuildConditions(ctx, r, query, schemas) +} + +// ModuleName 模块名称 +func (m *Model) ModuleName() string { + return m.naming.ModuleName +} + +// TableName 表的名称 +func (m *Model) TableName() string { + return m.naming.ModuleName +} + +// Fields 返回搜索的模型的字段 +func (m *Model) Fields() []*schema.Field { + return m.statement.Schema.Fields +} + +// Uri 获取请求的uri +func (m *Model) Uri(scenario string) string { + ss := make([]string, 4) + if m.urlPrefix != "" { + ss = append(ss, m.urlPrefix) + } + switch scenario { + case types.ScenarioList: + ss = append(ss, m.naming.ModuleName, m.naming.Pluralize) + case types.ScenarioView: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") + case types.ScenarioCreate: + ss = append(ss, m.naming.ModuleName, m.naming.Singular) + case types.ScenarioUpdate: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") + case types.ScenarioDelete: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") + case types.ScenarioExport: + ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export") + case types.ScenarioImport: + ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import") + } + uri := path.Join(ss...) + if !strings.HasPrefix(uri, "/") { + uri = "/" + uri + } + return uri +} + +// Method 获取HTTP请求的方法 +func (m *Model) Method(scenario string) string { + var ( + method = http.MethodGet + ) + switch scenario { + case types.ScenarioCreate: + method = http.MethodPost + case types.ScenarioUpdate: + method = http.MethodPut + case types.ScenarioDelete: + method = http.MethodDelete + } + return method +} + +// Search 实现通过HTTP方法查找数据 +func (m *Model) Search(w http.ResponseWriter, r *http.Request) { + var ( + ok bool + err error + qs url.Values + page int + pageSize int + pageIndex int + query *Query + domainName string + modelSlices reflect.Value + modelValues reflect.Value + searchSchemas []*types.Schema + listSchemas []*types.Schema + modelValue reflect.Value + scenario string + reporter types.Reporter + namerTable tableNamer + ) + qs = r.URL.Query() + page, _ = strconv.Atoi(qs.Get("page")) + pageSize, _ = strconv.Atoi(qs.Get("pagesize")) + if pageSize <= 0 { + pageSize = defaultPageSize + } + pageIndex = page + if pageIndex > 0 { + pageIndex-- + } + modelValue = reflect.New(m.value.Type()) + //这里创建指针类型,这样的话就能在format里面调用函数 + if m.value.Kind() != reflect.Ptr { + modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0) + } else { + modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0) + } + modelValues = reflect.New(modelSlices.Type()) + modelValues.Elem().Set(modelSlices) + query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface()) + domainName = m.valueLookup(types.FieldDomain, w, r) + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + Request: r, + User: m.valueLookup("user", w, r), + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioList, + }) + if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + scenario = types.ScenarioList + if arrays.Exists(r.FormValue("scenario"), allowScenario) { + scenario = r.FormValue("scenario") + } + if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + if !m.disableDomain { + if m.hasColumn(types.FieldDomain) { + query.AndWhere(newCondition(types.FieldDomain, domainName)) + } + } + if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) + return + } + // 处理表名逻辑 + if namerTable, ok = query.Model().(tableNamer); ok { + query.From(namerTable.HttpTableName(r)) + } + //处理报表逻辑 + if reporter, ok = modelValue.Interface().(types.Reporter); ok { + query.From(reporter.RealTable()) + } + res := &types.ListResponse{ + Page: page, + PageSize: pageSize, + } + if reporter == nil { + res.TotalCount = query.Limit(0).Offset(0).Count(query.Model()) + } else { + //如果是报表的情况,需要手动指定COUNT的雨具逻辑才能生效 + m.buildReporterCountColumns(childCtx, reporter, query) + res.TotalCount = query.Limit(0).Offset(0).Count(query.Model()) + + //这里需要重置一下选项,不然会出问题 + query.ResetSelect() + query.GroupBy(reporter.GroupBy(childCtx)...) + } + query.Offset(pageIndex * pageSize).Limit(pageSize) + if res.TotalCount > 0 { + if reporter != nil { + m.buildReporterQueryColumns(childCtx, reporter, query) + } + if err = query.All(modelValues.Interface()); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + // 不进行格式化输出 + res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format")) + } else { + res.Data = make([]string, 0) + } + m.response.Success(w, res) +} + +// Create 实现通过HTTP方法创建模型 +func (m *Model) Create(w http.ResponseWriter, r *http.Request) { + var ( + err error + model any + schemas []*types.Schema + diffAttrs []*types.DiffAttr + domainName string + modelValue reflect.Value + ) + modelValue = reflect.New(m.value.Type()) + model = modelValue.Interface() + if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil) + return + } + domainName = m.valueLookup(types.FieldDomain, w, r) + if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + if !m.disableDomain { + if m.hasColumn(types.FieldDomain) { + m.setValue(modelValue, types.FieldDomain, domainName) + } + } + diffAttrs = make([]*types.DiffAttr, 0, 10) + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + User: m.valueLookup("user", w, r), + Request: r, + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioCreate, + Schemas: schemas, + }) + dbSess := m.getDB().WithContext(childCtx) + if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { + if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil { + return + } + if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil { + return + } + if tabler, ok := model.(types.Tabler); ok { + errTx = tx.Table(tabler.TableName()).Save(model).Error + } else { + errTx = tx.Save(model).Error + } + if errTx != nil { + return + } + for _, row := range schemas { + diffAttrs = append(diffAttrs, &types.DiffAttr{ + Column: row.Column, + Label: row.Label, + OldValue: nil, + NewValue: m.getValue(modelValue, row.Column), + }) + } + return + }); err == nil { + res := &types.CreateResponse{ + ID: m.getValue(modelValue, m.primaryKey), + Status: "created", + } + if creator, ok := model.(afterCreated); ok { + creator.AfterCreated(childCtx, dbSess) + } + if preserver, ok := model.(afterSaved); ok { + preserver.AfterSaved(childCtx, dbSess) + } + m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs) + m.getHook().afterSave(childCtx, dbSess, model, diffAttrs) + m.response.Success(w, res) + } else { + m.response.Failure(w, types.RequestCreateFailure, err.Error(), err) + } +} + +// Update 实现通过HTTP方法更新模型 +func (m *Model) Update(w http.ResponseWriter, r *http.Request) { + var ( + err error + model any + schemas []*types.Schema + diffAttrs []*types.DiffAttr + domainName string + modelValue reflect.Value + oldValues map[string]any + ) + idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r) + modelValue = reflect.New(m.value.Type()) + model = modelValue.Interface() + domainName = m.valueLookup(types.FieldDomain, w, r) + if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + conditions := map[string]any{ + m.primaryKey: idStr, + } + if err = m.getDB().Where(conditions).First(model).Error; err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + oldValues = make(map[string]any) + for _, row := range schemas { + oldValues[row.Column] = m.getValue(modelValue, row.Column) + } + if err = json.NewDecoder(r.Body).Decode(model); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) + return + } + diffAttrs = make([]*types.DiffAttr, 0, 10) + updates := make(map[string]any) + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + Request: r, + User: m.valueLookup("user", w, r), + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioUpdate, + Schemas: schemas, + PrimaryKeyValue: idStr, + }) + dbSess := m.getDB().WithContext(childCtx) + if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { + if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil { + return + } + if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil { + return + } + for _, row := range schemas { + v := m.getValue(modelValue, row.Column) + if oldValues[row.Column] != v { + updates[row.Column] = v + diffAttrs = append(diffAttrs, &types.DiffAttr{ + Column: row.Column, + Label: row.Label, + OldValue: oldValues[row.Column], + NewValue: v, + }) + } + } + if len(updates) > 0 { + if tabler, ok := model.(types.Tabler); ok { + errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error + } else { + errTx = tx.Model(model).Updates(updates).Error + } + if errTx != nil { + return + } + } + return + }); err == nil { + if updater, ok := model.(afterUpdated); ok { + updater.AfterUpdated(childCtx, dbSess) + } + if preserver, ok := model.(afterSaved); ok { + preserver.AfterSaved(childCtx, dbSess) + } + m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs) + m.getHook().afterSave(childCtx, dbSess, model, diffAttrs) + m.response.Success(w, types.UpdateResponse{ + ID: idStr, + Status: "updated", + }) + } else { + m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil) + } +} + +// Delete 实现通过HTTP方法删除模型 +func (m *Model) Delete(w http.ResponseWriter, r *http.Request) { + var ( + err error + model any + modelValue reflect.Value + ) + idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r) + modelValue = reflect.New(m.value.Type()) + model = modelValue.Interface() + conditions := map[string]any{ + m.primaryKey: idStr, + } + if err = m.getDB().Where(conditions).First(model).Error; err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: m.valueLookup(types.FieldDomain, w, r), + User: m.valueLookup("user", w, r), + Request: r, + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioDelete, + PrimaryKeyValue: idStr, + }) + dbSess := m.getDB().WithContext(childCtx) + if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { + if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil { + return + } + if tabler, ok := model.(types.Tabler); ok { + errTx = tx.Table(tabler.TableName()).Delete(model).Error + } else { + errTx = tx.Delete(model).Error + } + if errTx != nil { + return + } + m.getHook().afterDelete(childCtx, tx, model) + return + }); err == nil { + m.response.Success(w, types.DeleteResponse{ + ID: idStr, + Status: "deleted", + }) + } else { + m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil) + } +} + +// View 查看数据详情 +func (m *Model) View(w http.ResponseWriter, r *http.Request) { + var ( + err error + model any + modelValue reflect.Value + qs url.Values + schemas []*types.Schema + scenario string + domainName string + ) + qs = r.URL.Query() + idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r) + modelValue = reflect.New(m.value.Type()) + model = modelValue.Interface() + conditions := map[string]any{ + m.primaryKey: idStr, + } + domainName = m.valueLookup(types.FieldDomain, w, r) + if err = m.getDB().Where(conditions).First(model).Error; err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + scenario = qs.Get("scenario") + if scenario == "" { + schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioView) + } else { + schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario) + } + if err == nil { + m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format"))) + } else { + m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil) + } +} + +// Export 实现通过HTTP方法导出模型 +func (m *Model) Export(w http.ResponseWriter, r *http.Request) { + var ( + err error + query *Query + modelSlices reflect.Value + modelValues reflect.Value + searchSchemas []*types.Schema + exportSchemas []*types.Schema + domainName string + fp *os.File + modelValue reflect.Value + ) + if !m.hasScenario(types.ScenarioList) { + m.response.Failure(w, types.RequestDenied, "request denied", nil) + return + } + domainName = m.valueLookup(types.FieldDomain, w, r) + filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix())) + if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil) + return + } + defer func() { + _ = fp.Close() + }() + modelValue = reflect.New(m.value.Type()) + //这里创建指针类型,这样的话就能在format里面调用函数 + if m.value.Kind() != reflect.Ptr { + modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0) + } else { + modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0) + } + modelValues = reflect.New(modelSlices.Type()) + modelValues.Elem().Set(modelSlices) + query = NewQuery(m.getDB(), modelValue.Interface()) + childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + Request: r, + User: m.valueLookup("user", w, r), + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioExport, + }) + if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) + return + } + if !m.disableDomain { + if m.hasColumn(types.FieldDomain) { + query.AndWhere(newCondition(types.FieldDomain, domainName)) + } + } + // 处理表名逻辑 + if namerTable, ok := query.Model().(tableNamer); ok { + query.From(namerTable.HttpTableName(r)) + } + //处理报表逻辑 + if reporter, ok := modelValue.Interface().(types.Reporter); ok { + query.From(reporter.RealTable()) + query.GroupBy(reporter.GroupBy(childCtx)...) + m.buildReporterQueryColumns(childCtx, reporter, query) + } + if err = query.All(modelValues.Interface()); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular)) + value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "") + writer := csv.NewWriter(fp) + rows := make([]string, len(exportSchemas)) + for i, field := range exportSchemas { + rows[i] = field.Label + } + _ = writer.Write(rows) + if values, ok := value.([]any); ok { + for _, val := range values { + row, ok2 := val.(map[string]any) + if !ok2 { + continue + } + for i, field := range exportSchemas { + if v, ok := row[field.Column]; ok { + rows[i] = fmt.Sprint(v) + } else { + rows[i] = "" + } + } + _ = writer.Write(rows) + } + } + writer.Flush() + m.getHook().afterExport(childCtx, filename) + http.ServeContent(w, r, path.Base(filename), time.Now(), fp) +} + +// findSchema 查找指定的schema +func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema { + for _, row := range schemas { + if row.Label == label { + return row + } + } + return nil +} + +// importInternal 文件上传方法 +func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) { + var ( + err error + rows []string + fp *os.File + tm time.Time + fields []string + sess *gorm.DB + csvReader *csv.Reader + csvWriter *csv.Writer + modelValue reflect.Value + modelEntity any + diffAttrs []*types.DiffAttr + result *types.ImportResult + failureFp *os.File + failureFile string + ) + tm = time.Now() + result = &types.ImportResult{} + if fp, err = os.Open(filename); err != nil { + result.Code = types.ErrImportFileNotExists + goto __end + } + defer func() { + _ = fp.Close() + }() + csvReader = csv.NewReader(fp) + if rows, err = csvReader.Read(); err != nil { + result.Code = types.ErrImportFileUnavailable + goto __end + } + fields = make([]string, 0, len(rows)) + for _, s := range rows { + v := m.findSchema(s, schemas) + if v == nil { + result.Code = types.ErrImportColumnNotMatch + goto __end + } + fields = append(fields, v.Column) + } + sess = m.getDB().WithContext(ctx) + //失败文件指针 + failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix())) + if failureFp, err = os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil { + return + } + defer func() { + _ = failureFp.Close() + }() + csvWriter = csv.NewWriter(failureFp) + rows = append(rows, "Error") + _ = csvWriter.Write(rows) + diffAttrs = make([]*types.DiffAttr, len(schemas)) + for { + if rows, err = csvReader.Read(); err != nil { + break + } + result.TotalCount++ + if len(rows) != len(fields) { + continue + } + modelValue = reflect.New(m.value.Type()) + for idx, field := range fields { + m.safeSetValue(modelValue, field, rows[idx]) + } + if len(extraFields) > 0 { + for k, v := range extraFields { + m.safeSetValue(modelValue, k, v) + } + } + modelEntity = modelValue.Interface() + //写入数据 + if fast { + //如果是快速模式,直接存储数据 + if err = sess.Save(modelEntity).Error; err == nil { + result.SuccessCount++ + } else { + rows = append(rows, err.Error()) + _ = csvWriter.Write(rows) + } + } else { + if err = sess.Transaction(func(tx *gorm.DB) (errTx error) { + if errTx = m.getHook().beforeCreate(ctx, tx, modelEntity); errTx != nil { + return + } + if errTx = m.getHook().beforeSave(ctx, tx, modelEntity); errTx != nil { + return + } + if tabler, ok := modelEntity.(types.Tabler); ok { + errTx = tx.Table(tabler.TableName()).Save(modelEntity).Error + } else { + errTx = tx.Save(modelEntity).Error + } + if errTx != nil { + return + } + for idx, row := range schemas { + diffAttrs[idx] = &types.DiffAttr{ + Column: row.Column, + Label: row.Label, + NewValue: m.getValue(modelValue, row.Column), + } + } + m.getHook().afterCreate(ctx, tx, modelEntity, diffAttrs) + m.getHook().afterSave(ctx, tx, modelEntity, diffAttrs) + return + }); err == nil { + result.SuccessCount++ + } else { + rows = append(rows, err.Error()) + _ = csvWriter.Write(rows) + } + } + } + csvWriter.Flush() +__end: + result.UploadFile = filename + if result.TotalCount > result.SuccessCount { + result.FailureFile = failureFile + } + result.Duration = time.Now().Sub(tm) + m.getHook().afterImport(ctx, result) +} + +// Import 实现通过HTTP方法导入 +func (m *Model) Import(w http.ResponseWriter, r *http.Request) { + var ( + err error + fast bool + schemas []*types.Schema + rows []string + domainName string + dst *os.File + fp multipart.File + csvWriter *csv.Writer + qs url.Values + extraFields map[string]string + ) + domainName = m.valueLookup(types.FieldDomain, w, r) + if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil { + m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) + return + } + //这里用background的context + childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{ + Domain: domainName, + User: m.valueLookup("user", w, r), + ModuleName: m.naming.ModuleName, + TableName: m.naming.TableName, + Scenario: types.ScenarioImport, + Schemas: schemas, + }) + if r.Method == http.MethodGet { + //下载导入模板 + csvWriter = csv.NewWriter(w) + rows = make([]string, 0, len(schemas)) + for _, row := range schemas { + //主键不需要导入 + if row.IsPrimaryKey == 1 { + continue + } + rows = append(rows, row.Label) + } + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular)) + err = csvWriter.Write(rows) + csvWriter.Flush() + return + } + filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix())) + if fp, _, err = r.FormFile("file"); err != nil { + m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil) + return + } + defer func() { + _ = fp.Close() + }() + if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err == nil { + buf := pool.GetBytes(32 * 1024) + _, err = io.CopyBuffer(dst, fp, buf) + pool.PutBytes(buf) + _ = dst.Close() + } else { + m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil) + return + } + qs = r.URL.Query() + if qs != nil { + extraFields = make(map[string]string) + for k, _ := range qs { + if strings.HasPrefix(k, "_attr_") { + extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k) + } + } + } + fast, _ = strconv.ParseBool(qs.Get("__fast")) + go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields) + m.response.Success(w, types.ImportResponse{ + UID: m.valueLookup("user", w, r), + Status: "committed", + }) +} + +// newModel 创建一个模型 +func newModel(v any, db *gorm.DB, naming types.Naming) *Model { + model := &Model{ + db: db, + naming: naming, + response: &httpWriter{}, + value: reflect.Indirect(reflect.ValueOf(v)), + valueLookup: defaultValueLookup, + } + model.statement = &gorm.Statement{ + DB: model.getDB(), + ConnPool: model.getDB().ConnPool, + Clauses: map[string]clause.Clause{}, + } + if err := model.statement.Parse(v); err == nil { + if model.statement.Schema.PrimaryFieldDBNames != nil && len(model.statement.Schema.PrimaryFieldDBNames) > 0 { + model.primaryKey = model.statement.Schema.PrimaryFieldDBNames[0] + } + } + return model +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..7746e20 --- /dev/null +++ b/options.go @@ -0,0 +1,52 @@ +package rest + +import "git.nobla.cn/golang/rest/types" + +type Options struct { + urlPrefix string + moduleName string + disableDomain bool + router types.HttpRouter + writer types.HttpWriter + formatter *Formatter + dirname string //文件目录 +} + +type Option func(o *Options) + +func WithUriPrefix(s string) Option { + return func(o *Options) { + o.urlPrefix = s + } +} + +func WithModuleName(s string) Option { + return func(o *Options) { + o.moduleName = s + } +} + +// WithoutDomain 禁用域 +func WithoutDomain() Option { + return func(o *Options) { + o.disableDomain = true + } +} + +func WithHttpRouter(s types.HttpRouter) Option { + return func(o *Options) { + o.router = s + } +} + +func WithHttpWriter(s types.HttpWriter) Option { + return func(o *Options) { + o.writer = s + } +} + +func WithFormatter(s *Formatter) Option { + return func(o *Options) { + o.formatter = s + } +} diff --git a/plugins/cache/cache.go b/plugins/cache/cache.go new file mode 100644 index 0000000..ad2a8c5 --- /dev/null +++ b/plugins/cache/cache.go @@ -0,0 +1,108 @@ +package cache + +import ( + "encoding/json" + "fmt" + "git.nobla.cn/golang/kos/pkg/cache" + xxhash "github.com/cespare/xxhash/v2" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "os" + "strconv" + "time" +) + +const ( + DisableCache = "DISABLE_CACHE" + DurationKey = "gorm:cache_duration" +) + +type Cacher struct { + rawQuery func(db *gorm.DB) +} + +func (c *Cacher) Name() string { + return "gorm:cache" +} + +func (c *Cacher) Initialize(db *gorm.DB) (err error) { + c.rawQuery = db.Callback().Query().Get("gorm:query") + err = db.Callback().Query().Replace("gorm:query", c.Query) + return +} + +// buildCacheKey 构建一个缓存的KEY +func (c *Cacher) buildCacheKey(db *gorm.DB) string { + s := strconv.FormatUint(xxhash.Sum64String(db.Statement.SQL.String()+fmt.Sprintf("%v", db.Statement.Vars)), 10) + return s +} + +// getDuration 获取缓存时长 +func (c *Cacher) getDuration(db *gorm.DB) time.Duration { + var ( + ok bool + v any + duration time.Duration + ) + if v, ok = db.InstanceGet(DurationKey); !ok { + return 0 + } + if duration, ok = v.(time.Duration); !ok { + return 0 + } + return duration +} + +// tryLoad 尝试从缓存读取数据 +func (c *Cacher) tryLoad(key string, db *gorm.DB) (err error) { + var ( + ok bool + buf []byte + ) + if buf, ok = cache.Get(db.Statement.Context, key); ok { + err = json.Unmarshal(buf, db.Statement.Dest) + } else { + err = os.ErrNotExist + } + return +} + +// storeCache 存储缓存数据 +func (c *Cacher) storeCache(key string, db *gorm.DB, duration time.Duration) (err error) { + var ( + buf []byte + ) + if buf, err = json.Marshal(db.Statement.Dest); err == nil { + cache.SetEx(db.Statement.Context, key, buf, duration) + } + return +} + +func (c *Cacher) Query(db *gorm.DB) { + var ( + err error + cacheKey string + duration time.Duration + ) + duration = c.getDuration(db) + if duration <= 0 { + c.rawQuery(db) + return + } + callbacks.BuildQuerySQL(db) + cacheKey = c.buildCacheKey(db) + if err = c.tryLoad(cacheKey, db); err == nil { + return + } + c.rawQuery(db) + if db.Error == nil { + //store cache + if err = c.storeCache(cacheKey, db, duration); err != nil { + _ = db.AddError(err) + } + } +} + +func New() *Cacher { + return &Cacher{} +} diff --git a/plugins/identity/identified.go b/plugins/identity/identified.go new file mode 100644 index 0000000..c26b806 --- /dev/null +++ b/plugins/identity/identified.go @@ -0,0 +1,57 @@ +package identity + +import ( + "github.com/rs/xid" + "gorm.io/gorm" + "gorm.io/gorm/schema" + "reflect" +) + +type Identify struct { +} + +func (identity *Identify) Name() string { + return "gorm:identity" +} + +func (identity *Identify) Initialize(db *gorm.DB) (err error) { + err = db.Callback().Create().Before("gorm:create").Register("auto_identified", identity.Grant) + return +} + +func (identity *Identify) NextID() string { + return xid.New().String() +} + +func (identity *Identify) Grant(db *gorm.DB) { + var ( + err error + field *schema.Field + ) + if db.Statement.Schema == nil { + return + } + if field = db.Statement.Schema.LookUpField("ID"); field == nil { + return + } + if field.DataType != schema.String { + return + } + if db.Statement.ReflectValue.Kind() == reflect.Array || db.Statement.ReflectValue.Kind() == reflect.Slice { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue.Index(i)); zero { + if err = field.Set(db.Statement.Context, db.Statement.ReflectValue.Index(i), identity.NextID()); err != nil { + _ = db.AddError(err) + } + } + } + } else { + if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); zero { + db.Statement.SetColumn("ID", identity.NextID()) + } + } +} + +func New() *Identify { + return &Identify{} +} diff --git a/plugins/sharding/README.md b/plugins/sharding/README.md new file mode 100644 index 0000000..409de7a --- /dev/null +++ b/plugins/sharding/README.md @@ -0,0 +1,33 @@ +# 分表实现 + +首先定义一个`gorm`的模型,然后实现`shadring.Model`接口,比如如下示例 + +```go +// ShardingTable 返回增删改时候操作的数据表 +func (model *CdrLog) ShardingTable(scene string) string { + return model.TableName() +} + +// ShardingTables 返回查询时候一个范围内的表 +func (model *CdrLog) ShardingTables(ctx *sharding.Context) []string { + var ( + timestamp int64 + ) + timeRange := make([]int64, 0) + values := ctx.FindColumnValues("start_stamp") + if len(values) == 0 { + values = ctx.FindColumnValues("create_stamp") + } + if len(values) > 0 { + for _, v := range values { + timestamp, _ = strconv.ParseInt(fmt.Sprint(v), 10, 64) + timeRange = append(timeRange, timestamp) + } + } + return shard.DateTableNames(ctx.Context(), "cdr_logs", shard.ShardTypeDateMonth, timeRange) +} +``` + +`ShardingTable`方法是操作增删改的是回调具体表名的方法 + +`ShardingTables`方法是操作查询的时候,通过查询条件返回的表名的方法 \ No newline at end of file diff --git a/plugins/sharding/condition.go b/plugins/sharding/condition.go new file mode 100644 index 0000000..4b6c00f --- /dev/null +++ b/plugins/sharding/condition.go @@ -0,0 +1,120 @@ +package sharding + +const ( + ValueOperaEqual = iota + 0x10 + ValueOperaGreater + ValueOperaLess + ValueOperaRange + + ValueTypeString = iota + 0x30 + ValueTypeNumber + ValueTypeBoolean + ValueTypeNull + ValueTypeAny +) + +type ( + ColumnCondition struct { + Name string `json:"name"` + Value CondValue `json:"value"` + } + + CondValue interface { + Type() int + Opera() int + Value() any + } + + equalValue struct { + vType int + vData any + } + + rangeValue struct { + vType int + vData any + } + + greaterValue struct { + vType int + vData any + } + + lessValue struct { + vType int + vData any + } +) + +func (v *lessValue) Type() int { + return v.vType +} + +func (v *lessValue) Opera() int { + return ValueOperaLess +} + +func (v *lessValue) Value() any { + return v.vData +} + +func (v *greaterValue) Type() int { + return v.vType +} + +func (v *greaterValue) Opera() int { + return ValueOperaGreater +} + +func (v *greaterValue) Value() any { + return v.vData +} + +func (v *rangeValue) Type() int { + return v.vType +} + +func (v *rangeValue) Opera() int { + return ValueOperaRange +} + +func (v *rangeValue) Value() any { + return v.vData +} + +func (v *equalValue) Type() int { + return v.vType +} + +func (v *equalValue) Opera() int { + return ValueOperaEqual +} + +func (v *equalValue) Value() any { + return v.vData +} + +func newCondValue(vType int, op int, value any) CondValue { + switch op { + case ValueOperaGreater: + return &greaterValue{ + vType: vType, + vData: value, + } + case ValueOperaLess: + return &lessValue{ + vType: vType, + vData: value, + } + case ValueOperaRange: + return &rangeValue{ + vType: vType, + vData: value, + } + default: + return &equalValue{ + vType: vType, + vData: value, + } + } +} diff --git a/plugins/sharding/scope.go b/plugins/sharding/scope.go new file mode 100644 index 0000000..fe1bc53 --- /dev/null +++ b/plugins/sharding/scope.go @@ -0,0 +1,287 @@ +package sharding + +import ( + "context" + "github.com/longbridgeapp/sqlparser" + sqlparserX "github.com/uole/sqlparser" + "gorm.io/gorm" + "strconv" +) + +type Scope struct { + db *gorm.DB + stmt *sqlparser.SelectStatement + stmtX *sqlparserX.Select +} + +func (scope *Scope) findValue(express sqlparser.Expr) (int, any) { + var ( + vType int + vData any + ) + switch expr := express.(type) { + case *sqlparser.BindExpr: + vType = ValueTypeAny + if len(scope.db.Statement.Vars) > expr.Pos { + vData = scope.db.Statement.Vars[expr.Pos] + } else { + vType = ValueTypeNull + } + case *sqlparser.NumberLit: + vType = ValueTypeNumber + vData = expr.Value + case *sqlparser.StringLit: + vType = ValueTypeString + vData = expr.Value + case *sqlparser.BoolLit: + vType = ValueTypeBoolean + vData = expr.Value + case *sqlparser.BlobLit: + vType = ValueTypeString + vData = expr.Value + case *sqlparser.NullLit: + vType = ValueTypeNull + case *sqlparser.Range: + arr := make([]any, 2) + vType, arr[0] = scope.findValue(expr.X) + vType, arr[1] = scope.findValue(expr.Y) + vData = arr + } + return vType, vData +} + +func (scope *Scope) findValueX(express *sqlparserX.SQLVal) (int, any) { + var ( + vType int + vData any + ) + switch express.Type { + case sqlparserX.IntVal: + vType = ValueTypeNumber + vData, _ = strconv.Atoi(string(express.Val)) + case sqlparserX.FloatVal: + vType = ValueTypeNumber + vData, _ = strconv.ParseFloat(string(express.Val), 64) + case sqlparserX.ValArg: + vType = ValueTypeAny + pos, _ := strconv.Atoi(string(express.Val[2:])) + if pos > 0 { + pos = pos - 1 + if len(scope.db.Statement.Vars) > pos { + vData = scope.db.Statement.Vars[pos] + } else { + vType = ValueTypeNull + } + } else { + vType = ValueTypeNull + } + default: + vType = ValueTypeString + vData = string(express.Val) + } + return vType, vData +} + +func (scope *Scope) recursiveFindX(expr sqlparserX.Expr, column string) (conditions []*ColumnCondition) { + var ( + ok bool + andExpr *sqlparserX.AndExpr + orExpr *sqlparserX.OrExpr + parentExpr *sqlparserX.ParenExpr + comparisonExpr *sqlparserX.ComparisonExpr + rangeExpr *sqlparserX.RangeCond + coumnExpr *sqlparserX.ColName + valueExpr *sqlparserX.SQLVal + ) + conditions = make([]*ColumnCondition, 0, 2) + if comparisonExpr, ok = expr.(*sqlparserX.ComparisonExpr); ok { + if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok { + return + } + if valueExpr, ok = comparisonExpr.Right.(*sqlparserX.SQLVal); !ok { + return + } + if coumnExpr.Name.EqualString(column) { + cond := &ColumnCondition{ + Name: coumnExpr.Name.String(), + } + vType, vData := scope.findValueX(valueExpr) + switch comparisonExpr.Operator { + case sqlparserX.LessThanStr, sqlparserX.LessEqualStr: + cond.Value = newCondValue(vType, ValueOperaLess, vData) + case sqlparserX.GreaterThanStr, sqlparserX.GreaterEqualStr: + cond.Value = newCondValue(vType, ValueOperaGreater, vData) + case sqlparserX.EqualStr: + cond.Value = newCondValue(vType, ValueOperaEqual, vData) + } + if cond.Value != nil { + conditions = append(conditions, cond) + } + } + } + + if rangeExpr, ok = expr.(*sqlparserX.RangeCond); ok { + if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok { + return + } + if coumnExpr.Name.EqualString(column) { + vType := 0 + arr := make([]any, 2) + if valueExpr, ok = rangeExpr.From.(*sqlparserX.SQLVal); ok { + vType, arr[0] = scope.findValueX(valueExpr) + } + if valueExpr, ok = rangeExpr.To.(*sqlparserX.SQLVal); ok { + vType, arr[1] = scope.findValueX(valueExpr) + } + conditions = append(conditions, &ColumnCondition{ + Name: coumnExpr.Name.String(), + Value: newCondValue(vType, ValueOperaRange, arr), + }) + } + } + + if andExpr, ok = expr.(*sqlparserX.AndExpr); ok { + if andExpr.Left != nil { + conditions = append(conditions, scope.recursiveFindX(andExpr.Left, column)...) + } + if andExpr.Right != nil { + conditions = append(conditions, scope.recursiveFindX(andExpr.Right, column)...) + } + } + if orExpr, ok = expr.(*sqlparserX.OrExpr); ok { + if orExpr.Left != nil { + conditions = append(conditions, scope.recursiveFindX(orExpr.Left, column)...) + } + if orExpr.Right != nil { + conditions = append(conditions, scope.recursiveFindX(orExpr.Right, column)...) + } + } + if parentExpr, ok = expr.(*sqlparserX.ParenExpr); ok { + if parentExpr.Expr != nil { + conditions = append(conditions, scope.recursiveFindX(parentExpr.Expr, column)...) + } + } + return conditions +} + +func (scope *Scope) recursiveFind(expr sqlparser.Expr, column string) []*ColumnCondition { + var ( + ok bool + identExpr *sqlparser.Ident + binaryExpr *sqlparser.BinaryExpr + parentExpr *sqlparser.ParenExpr + conditions []*ColumnCondition + ) + conditions = make([]*ColumnCondition, 0, 2) + if parentExpr, ok = expr.(*sqlparser.ParenExpr); ok { + if parentExpr.X != nil { + if _, ok = parentExpr.X.(*sqlparser.BinaryExpr); ok { + conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...) + } + if _, ok = parentExpr.X.(*sqlparser.ParenExpr); ok { + conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...) + } + } + } + + if binaryExpr, ok = expr.(*sqlparser.BinaryExpr); ok { + if binaryExpr.X != nil { + if identExpr, ok = binaryExpr.X.(*sqlparser.Ident); ok { + if identExpr.Name == column { + cond := &ColumnCondition{ + Name: identExpr.Name, + } + vType, vData := scope.findValue(binaryExpr.Y) + switch binaryExpr.Op { + case sqlparser.LT, sqlparser.LE: + cond.Value = newCondValue(vType, ValueOperaLess, vData) + case sqlparser.GT, sqlparser.GE: + cond.Value = newCondValue(vType, ValueOperaGreater, vData) + case sqlparser.RANGE, sqlparser.BETWEEN: + cond.Value = newCondValue(vType, ValueOperaRange, vData) + case sqlparser.EQ: + cond.Value = newCondValue(vType, ValueOperaEqual, vData) + } + if cond.Value != nil { + conditions = append(conditions, cond) + } + } + } else { + if _, ok = binaryExpr.X.(*sqlparser.BinaryExpr); ok { + conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...) + } + if _, ok = binaryExpr.X.(*sqlparser.ParenExpr); ok { + conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...) + } + } + } + + if binaryExpr.Y != nil { + if _, ok = binaryExpr.Y.(*sqlparser.BinaryExpr); ok { + conditions = append(conditions, scope.recursiveFind(binaryExpr.Y, column)...) + } + } + } + + return conditions +} + +func (scope *Scope) DB() *gorm.DB { + return scope.db +} + +func (scope *Scope) Context() context.Context { + return scope.db.Statement.Context +} + +func (scope *Scope) FindCondition(column string) []*ColumnCondition { + if scope.stmtX != nil { + if scope.stmtX.Where == nil { + return []*ColumnCondition{} + } + return scope.recursiveFindX(scope.stmtX.Where.Expr, column) + } + return scope.recursiveFind(scope.stmt.Condition, column) +} + +func (scope *Scope) FindColumnValues(column string) []any { + result := make([]any, 0) + conditions := scope.FindCondition(column) + if len(conditions) == 0 { + return result + } + for _, cond := range conditions { + if cond.Value.Opera() == ValueOperaGreater { + if len(result) == 0 { + result = make([]any, 2) + } + result[0] = cond.Value.Value() + } + + if cond.Value.Opera() == ValueOperaLess { + if len(result) == 0 { + result = make([]any, 2) + } + result[1] = cond.Value.Value() + } + + if cond.Value.Opera() == ValueOperaEqual { + result = append(result, cond.Value.Value()) + break + } + + if cond.Value.Opera() == ValueOperaRange { + if vs, ok := cond.Value.Value().([]any); ok { + if len(vs) == 2 { + if len(result) == 0 { + result = make([]any, 2) + } + result[0] = vs[0] + result[1] = vs[1] + } + } + break + } + } + return result +} diff --git a/plugins/sharding/sharding.go b/plugins/sharding/sharding.go new file mode 100644 index 0000000..3d3ac70 --- /dev/null +++ b/plugins/sharding/sharding.go @@ -0,0 +1,476 @@ +package sharding + +import ( + "github.com/longbridgeapp/sqlparser" + sqlparserX "github.com/uole/sqlparser" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "reflect" + "strings" +) + +type Sharding struct { + UnionAll bool + QuoteChar byte +} + +func (plugin *Sharding) Name() string { + return "gorm:sharding" +} + +func (plugin *Sharding) Initialize(db *gorm.DB) (err error) { + if err = db.Callback().Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil { + return + } + if err = db.Callback().Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil { + return + } + if err = db.Callback().Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil { + return + } + if err = db.Callback().Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX); err != nil { + return err + } + return +} + +func (plugin *Sharding) Create(db *gorm.DB) { + var ( + ok bool + scopeModel Model + refValue reflect.Value + modelValue any + ) + if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array { + if db.Statement.ReflectValue.Len() > 0 { + refValue = db.Statement.ReflectValue.Index(0) + modelValue = refValue.Interface() + } + } else { + if db.Statement.Model != nil { + modelValue = db.Statement.Model + } else { + modelValue = db.Statement.ReflectValue.Interface() + } + } + if modelValue != nil { + if scopeModel, ok = modelValue.(Model); ok { + db.Table(scopeModel.ShardingTable(sceneCreate)) + } + } +} + +func (plugin *Sharding) Update(db *gorm.DB) { + var ( + ok bool + scopeModel Model + refValue reflect.Value + modelValue any + ) + if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array { + if db.Statement.ReflectValue.Len() > 0 { + refValue = db.Statement.ReflectValue.Index(0) + modelValue = refValue.Interface() + } + } else { + if db.Statement.Model != nil { + modelValue = db.Statement.Model + } else { + modelValue = db.Statement.ReflectValue.Interface() + } + } + if modelValue != nil { + if scopeModel, ok = modelValue.(Model); ok { + db.Table(scopeModel.ShardingTable(sceneUpdate)) + } + } +} + +func (plugin *Sharding) Delete(db *gorm.DB) { + var ( + ok bool + scopeModel Model + refValue reflect.Value + modelValue any + ) + if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array { + if db.Statement.ReflectValue.Len() > 0 { + refValue = db.Statement.ReflectValue.Index(0) + modelValue = refValue.Interface() + } + } else { + if db.Statement.Model != nil { + modelValue = db.Statement.Model + } else { + modelValue = db.Statement.ReflectValue.Interface() + } + } + if modelValue != nil { + if scopeModel, ok = modelValue.(Model); ok { + db.Table(scopeModel.ShardingTable(sceneDelete)) + } + } +} + +func (plugin *Sharding) Query(db *gorm.DB) { + var ( + err error + ok bool + shardingModel Model + modelValue any + tables []string + rawVars []any + refValue reflect.Value + tableName *sqlparser.TableName + selectStmt *sqlparser.SelectStatement + stmt sqlparser.Statement + parser *sqlparser.Parser + numOfTable int + orderByExpr []*sqlparser.OrderingTerm + limitExpr sqlparser.Expr + offsetExpr sqlparser.Expr + groupingExpr []sqlparser.Expr + havingExpr sqlparser.Expr + isCountStatement bool + countField string + ) + if db.Statement.Model != nil { + refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type()) + } else { + refValue = reflect.New(db.Statement.ReflectValue.Type()) + } + if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct { + refValue = reflect.Indirect(refValue) + } + if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice { + elemType := refValue.Type().Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + modelValue = reflect.New(elemType).Interface() + } else { + modelValue = refValue.Interface() + } + if shardingModel, ok = modelValue.(Model); !ok { + return + } + if db.Statement.SQL.Len() == 0 { + callbacks.BuildQuerySQL(db) + } + parser = sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String())) + if stmt, err = parser.ParseStatement(); err != nil { + return + } + if selectStmt, ok = stmt.(*sqlparser.SelectStatement); !ok { + return + } + tables = shardingModel.ShardingTables(&Scope{ + db: db, + stmt: selectStmt, + }) + numOfTable = len(tables) + if numOfTable <= 1 { + return + } + rawVars = make([]any, 0, len(db.Statement.Vars)) + for _, v := range db.Statement.Vars { + rawVars = append(rawVars, v) + } + //是否是查询count语句 + //如果不是count的语句,添加order和group的支持 + if v := db.Statement.Context.Value("@sql_count_statement"); v != nil { + if v == true { + isCountStatement = true + } + } + if !isCountStatement && len(*selectStmt.Columns) == 1 { + for _, column := range *selectStmt.Columns { + if expr, ok := column.Expr.(*sqlparser.Call); ok { + if expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword { + isCountStatement = true + countField = expr.String() + break + } + } + } + } + + if len(selectStmt.OrderBy) > 0 { + orderByExpr = make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy)) + for _, row := range selectStmt.OrderBy { + orderByExpr = append(orderByExpr, row) + } + selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0) + } + if len(selectStmt.GroupingElements) > 0 { + groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements)) + for _, row := range selectStmt.GroupingElements { + groupingExpr = append(groupingExpr, row) + } + if selectStmt.HavingCondition != nil { + havingExpr = selectStmt.HavingCondition + } + } + if selectStmt.Limit != nil { + limitExpr = selectStmt.Limit + selectStmt.Limit = nil + } + if selectStmt.Offset != nil { + offsetExpr = selectStmt.Offset + selectStmt.Offset = nil + } + db.Statement.SQL.Reset() + if isCountStatement { + db.Statement.SQL.WriteString("SELECT SUM(") + db.Statement.SQL.WriteByte(plugin.QuoteChar) + db.Statement.SQL.WriteString(strings.Trim(countField, "`")) + db.Statement.SQL.WriteByte(plugin.QuoteChar) + db.Statement.SQL.WriteString(") FROM (") + } else { + db.Statement.SQL.WriteString("SELECT * FROM (") + } + for i, name := range tables { + db.Statement.SQL.WriteByte('(') + if tableName, ok = selectStmt.FromItems.(*sqlparser.TableName); ok { + tableName.Name.Name = name + } + db.Statement.SQL.WriteString(selectStmt.String()) + db.Statement.SQL.WriteByte(')') + if i < numOfTable-1 { + if plugin.UnionAll { + db.Statement.SQL.WriteString(" UNION ALL ") + } else { + db.Statement.SQL.WriteString(" UNION ") + } + } + if i > 0 { + //copy vars + db.Statement.Vars = append(db.Statement.Vars, rawVars...) + } + } + db.Statement.SQL.WriteString(") tbl ") + if !isCountStatement { + if len(groupingExpr) > 0 { + db.Statement.SQL.WriteString(" GROUP BY ") + for i, expr := range groupingExpr { + if i != 0 { + db.Statement.SQL.WriteString(", ") + } + db.Statement.SQL.WriteString(expr.String()) + } + + if havingExpr != nil { + db.Statement.SQL.WriteString(" HAVING ") + db.Statement.SQL.WriteString(havingExpr.String()) + } + } + + if orderByExpr != nil && len(orderByExpr) > 0 { + db.Statement.SQL.WriteString(" ORDER BY ") + for i, term := range orderByExpr { + if i != 0 { + db.Statement.SQL.WriteString(", ") + } + db.Statement.SQL.WriteString(term.String()) + } + } + if limitExpr != nil { + db.Statement.SQL.WriteString(" LIMIT ") + db.Statement.SQL.WriteString(limitExpr.String()) + if offsetExpr != nil { + db.Statement.SQL.WriteString(" OFFSET ") + db.Statement.SQL.WriteString(offsetExpr.String()) + } + } + } + return +} + +func (plugin *Sharding) QueryX(db *gorm.DB) { + var ( + err error + ok bool + shardingModel Model + modelValue any + tables []string + rawVars []any + refValue reflect.Value + selectStmt *sqlparserX.Select + stmt sqlparserX.Statement + numOfTable int + isCountStatement bool + isPureCountStatement bool + trackedBuffer *sqlparserX.TrackedBuffer + funcExpr *sqlparserX.FuncExpr + aliasedExpr *sqlparserX.AliasedExpr + orderByExpr sqlparserX.OrderBy + groupByExpr sqlparserX.GroupBy + havingExpr *sqlparserX.Where + limitExpr *sqlparserX.Limit + ) + if db.Statement.Model != nil { + refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type()) + } else { + if !db.Statement.ReflectValue.IsValid() { + return + } + refValue = reflect.New(db.Statement.ReflectValue.Type()) + } + if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct { + refValue = reflect.Indirect(refValue) + } + if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice { + elemType := refValue.Type().Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + modelValue = reflect.New(elemType).Interface() + } else { + modelValue = refValue.Interface() + } + if shardingModel, ok = modelValue.(Model); !ok { + return + } + if db.Statement.SQL.Len() == 0 { + callbacks.BuildQuerySQL(db) + } + if stmt, err = sqlparserX.Parse(db.Statement.SQL.String()); err != nil { + return + } + if selectStmt, ok = stmt.(*sqlparserX.Select); !ok { + return + } + tables = shardingModel.ShardingTables(&Scope{ + db: db, + stmtX: selectStmt, + }) + numOfTable = len(tables) + if numOfTable <= 1 { + return + } + // 保存值 + rawVars = make([]any, 0, len(db.Statement.Vars)) + for _, v := range db.Statement.Vars { + rawVars = append(rawVars, v) + } + // 替换语句 + if selectStmt.OrderBy != nil { + orderByExpr = selectStmt.OrderBy + selectStmt.OrderBy = nil + } + if selectStmt.GroupBy != nil { + groupByExpr = selectStmt.GroupBy + //selectStmt.GroupBy = nil + } + if selectStmt.Having != nil { + havingExpr = selectStmt.Having + //selectStmt.Having = nil + } + if selectStmt.Limit != nil { + limitExpr = selectStmt.Limit + selectStmt.Limit = nil + } + // 检查是否为COUNT语句 + //如果不是count的语句,添加order和group的支持 + if v := db.Statement.Context.Value("@sql_count_statement"); v != nil { + //这里处理的是报表的情况,再报表里面需要重写COUNT语句才能获取到正确的数量 + if v == true { + isCountStatement = true + isPureCountStatement = true + } + } + //常规的COUNT逻辑 + if !isCountStatement && len(selectStmt.SelectExprs) == 1 { + for _, expr := range selectStmt.SelectExprs { + if aliasedExpr, ok = expr.(*sqlparserX.AliasedExpr); ok { + if funcExpr, ok = aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok { + if funcExpr.Name.EqualString("count") { + isCountStatement = true + break + } + } + } + } + } + // 重写SQL + db.Statement.SQL.Reset() + trackedBuffer = sqlparserX.NewTrackedBuffer(nil) + + if isCountStatement { + if isPureCountStatement { + db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (") + } else { + db.Statement.SQL.WriteString("SELECT SUM(") + db.Statement.SQL.WriteByte(plugin.QuoteChar) + db.Statement.SQL.WriteString(strings.Trim("count(*)", "`")) + db.Statement.SQL.WriteByte(plugin.QuoteChar) + db.Statement.SQL.WriteString(") FROM (") + } + } else { + if bs, ok := modelValue.(SelectBuilder); ok { + columns := bs.BuildSelect(db.Statement.Context, selectStmt.SelectExprs) + if len(columns) > 0 { + db.Statement.SQL.WriteString("SELECT " + strings.Join(columns, ",") + " FROM (") + } else { + db.Statement.SQL.WriteString("SELECT * FROM (") + } + } else { + db.Statement.SQL.WriteString("SELECT * FROM (") + } + } + for i, name := range tables { + trackedBuffer.Reset() + db.Statement.SQL.WriteByte('(') + //赋值新的表名称 + selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{ + Expr: sqlparserX.TableName{ + Name: sqlparserX.NewTableIdent(name), + }, + }} + selectStmt.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + db.Statement.SQL.WriteByte(')') + if i < numOfTable-1 { + if plugin.UnionAll { + db.Statement.SQL.WriteString(" UNION ALL ") + } else { + db.Statement.SQL.WriteString(" UNION ") + } + } + if i > 0 { + //copy vars + db.Statement.Vars = append(db.Statement.Vars, rawVars...) + } + } + db.Statement.SQL.WriteString(") tbl ") + if !isCountStatement { + //node.GroupBy, node.Having, node.OrderBy, node.Limit + if groupByExpr != nil { + trackedBuffer.Reset() + groupByExpr.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + } + if havingExpr != nil { + trackedBuffer.Reset() + havingExpr.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + } + if orderByExpr != nil { + trackedBuffer.Reset() + orderByExpr.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + } + if limitExpr != nil { + trackedBuffer.Reset() + limitExpr.Format(trackedBuffer) + db.Statement.SQL.WriteString(trackedBuffer.String()) + } + } +} + +func New() *Sharding { + return &Sharding{ + UnionAll: true, + QuoteChar: '`', + } +} diff --git a/plugins/sharding/sharding_test.go b/plugins/sharding/sharding_test.go new file mode 100644 index 0000000..a26bbc6 --- /dev/null +++ b/plugins/sharding/sharding_test.go @@ -0,0 +1,61 @@ +package sharding + +import ( + "fmt" + "github.com/uole/sqlparser" + "testing" +) + +func TestFormat(t *testing.T) { + t.Log(fmt.Sprintf("%.2f%%", 0.5851888*100)) +} + +func TestSharding_Query(t *testing.T) { + //sql := "SELECT COUNT(*) FROM aaa" + sql := "SELECT uid,SUM(IF(direction='inbound',1,0)) AS inbound_times,SUM(IF(direction='inbound',IF(answer_duration>0,1,0),0)) AS inbound_answer_times,SUM(IF(direction='outbound',1,0)) AS outbound_times,SUM(IF(direction='outbound',IF(answer_duration>0,1,0),0)) AS outbound_answer_times FROM `cdr_logs` WHERE ((`domain` = 'test.cc.echo.me' OR `domain` = 'default') AND `name` <> '') AND (`create_stamp` BETWEEN 1712505600 AND 1712505608) AND `name` IN ('a','b','c') GROUP BY `uid` having uid != '' ORDER BY create_stamp DESC LIMIT 15" + stmt, err := sqlparser.Parse(sql) + if err != nil { + t.Error(err) + return + } + selectStmt, ok := stmt.(*sqlparser.Select) + if !ok { + t.Error("not select stmt") + return + } + buf := sqlparser.NewTrackedBuffer(nil) + sqlparser.NewTableIdent("test").Format(buf) + buf.Reset() + selectStmt.Format(buf) + t.Log("SQL") + t.Log(buf.String()) + buf.Reset() + t.Log("SELECT") + selectStmt.SelectExprs.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("FROM") + selectStmt.From.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("WHERE") + selectStmt.Where.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("ORDER BY") + selectStmt.OrderBy.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("GROUP BY") + selectStmt.GroupBy.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("LIMIT") + selectStmt.Limit.Format(buf) + t.Log(buf.String()) + buf.Reset() + t.Log("Having") + selectStmt.Having.Format(buf) + t.Log(buf.String()) + buf.Reset() +} diff --git a/plugins/sharding/types.go b/plugins/sharding/types.go new file mode 100644 index 0000000..1374c29 --- /dev/null +++ b/plugins/sharding/types.go @@ -0,0 +1,45 @@ +package sharding + +import ( + "context" + sqlparserX "github.com/uole/sqlparser" +) + +const ( + TypeDatetime = "datetime" + TypeHash = "hash" +) + +const ( + DateTypeYear = iota + 5 + DateTypeMonth + DateTypeWeek + DateTypeDay +) + +const ( + sceneCreate = "create" + sceneUpdate = "update" + sceneDelete = "delete" +) + +const ( + stmtCountKeyword = "count" +) + +type ( + Rule struct { + Type string `json:"type"` + Args int `json:"args"` + } + + Model interface { + ShardingRule() Rule + ShardingTable(scene string) string + ShardingTables(scope *Scope) []string + } + + SelectBuilder interface { + BuildSelect(ctx context.Context, expr sqlparserX.SelectExprs) []string + } +) diff --git a/plugins/validate/types.go b/plugins/validate/types.go new file mode 100644 index 0000000..c357a09 --- /dev/null +++ b/plugins/validate/types.go @@ -0,0 +1,107 @@ +package validate + +import ( + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "reflect" + "regexp" + "strconv" +) + +const ( + SkipValidations = "validations:skip_validations" +) + +var ( + scopeCtxKey = &validateScope{} + telephoneRegex = regexp.MustCompile("^\\d{5,20}$") +) + +type ( + validateScope struct { + DB *gorm.DB + Column string + Domain string + MultiDomain bool + Model interface{} + } + + StructError struct { + Tag string `json:"tag"` + Column string `json:"column"` + Message string `json:"message"` + } + + validateRule struct { + Rule string + Value string + Valid bool + } +) + +func (err *StructError) Error() string { + return err.Message +} + +func newRule(ss ...string) *validateRule { + v := &validateRule{ + Valid: true, + } + if len(ss) == 1 { + v.Rule = ss[0] + } else if len(ss) >= 2 { + v.Rule = ss[0] + v.Value = ss[1] + } + return v +} + +// formatError 格式化错误消息 +func formatError(rule types.Rule, scm *types.Schema, tag string) string { + var s string + switch tag { + case "db_unique": + s = scm.Label + "值已经存在." + break + case "required": + s = scm.Label + "值不能为空." + case "max": + if scm.Type == "string" { + s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max) + } else { + s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max) + } + case "min": + if scm.Type == "string" { + s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max) + } else { + s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max) + } + } + return s +} + +// isEmpty 判断值是否为空 +func isEmpty(val any) bool { + if val == nil { + return true + } + v := reflect.ValueOf(val) + switch v.Kind() { + case reflect.String, reflect.Array: + return v.Len() == 0 + case reflect.Map, reflect.Slice: + return v.IsNil() || v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface()) +} diff --git a/plugins/validate/validation.go b/plugins/validate/validation.go new file mode 100644 index 0000000..59d83db --- /dev/null +++ b/plugins/validate/validation.go @@ -0,0 +1,275 @@ +package validate + +import ( + "context" + "fmt" + "git.nobla.cn/golang/rest" + "git.nobla.cn/golang/rest/types" + validator "github.com/go-playground/validator/v10" + "gorm.io/gorm" + "gorm.io/gorm/schema" + "reflect" + "strconv" + "strings" +) + +type Validate struct { + validator *validator.Validate +} + +func (validate *Validate) Name() string { + return "gorm:validate" +} + +func (validate *Validate) telephoneValidate(ctx context.Context, fl validator.FieldLevel) bool { + val := fmt.Sprint(fl.Field().Interface()) + return telephoneRegex.MatchString(val) +} + +func (validate *Validate) uniqueValidate(ctx context.Context, fl validator.FieldLevel) bool { + var ( + scope *validateScope + ok bool + count int64 + field *schema.Field + primaryKeyValue reflect.Value + ) + val := fl.Field().Interface() + if scope, ok = ctx.Value(scopeCtxKey).(*validateScope); !ok { + return true + } + if len(scope.DB.Statement.Schema.PrimaryFields) > 0 { + field = scope.DB.Statement.Schema.PrimaryFields[0] + primaryKeyValue = reflect.Indirect(reflect.ValueOf(scope.Model)) + for _, n := range field.BindNames { + primaryKeyValue = primaryKeyValue.FieldByName(n) + } + } + sess := scope.DB.Session(&gorm.Session{NewDB: true}) + if primaryKeyValue.IsValid() && !primaryKeyValue.IsZero() && field != nil { + //多域校验 + if scope.MultiDomain && scope.Domain != "" { + sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ? AND domain = ?", val, primaryKeyValue.Interface(), scope.Domain).Count(&count) + } else { + sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ?", val, primaryKeyValue.Interface()).Count(&count) + } + } else { + if scope.MultiDomain && scope.Domain != "" { + sess.Model(scope.Model).Where(scope.Column+"=? AND domain = ?", val, scope.Domain).Count(&count) + } else { + sess.Model(scope.Model).Where(scope.Column+"=?", val).Count(&count) + } + } + if count > 0 { + return false + } + return true +} + +func (validate *Validate) grantRules(scm *types.Schema, scenario string, rule types.Rule) []*validateRule { + rules := make([]*validateRule, 0, 5) + if rule.Min != 0 { + rules = append(rules, newRule("min", strconv.Itoa(rule.Min))) + } + if rule.Max != 0 { + rules = append(rules, newRule("max", strconv.Itoa(rule.Max))) + } + //主键不做唯一判断 + if rule.Unique && !scm.Attribute.PrimaryKey { + rules = append(rules, newRule("db_unique")) + } + if rule.Type != "" { + rules = append(rules, newRule(rule.Type)) + } + if rule.Required != nil && len(rule.Required) > 0 { + for _, v := range rule.Required { + if v == scenario { + rules = append(rules, newRule("required")) + break + } + } + } + return rules +} + +func (validate *Validate) buildRules(rs []*validateRule) string { + var sb strings.Builder + for _, r := range rs { + if !r.Valid { + continue + } + if sb.Len() > 0 { + sb.WriteString(",") + } + if r.Value == "" { + sb.WriteString(r.Rule) + } else { + sb.WriteString(r.Rule + "=" + r.Value) + } + } + return sb.String() +} + +func (validate *Validate) findRule(name string, rules []*validateRule) *validateRule { + for _, r := range rules { + if r.Rule == name { + return r + } + } + return nil +} + +func (validate *Validate) Initialize(db *gorm.DB) (err error) { + validate.validator = validator.New() + if err = db.Callback().Create().Before("gorm:before_create").Register("model_validate", validate.Validate); err != nil { + return + } + if err = db.Callback().Create().Before("gorm:before_update").Register("model_validate", validate.Validate); err != nil { + return + } + if err = validate.validator.RegisterValidationCtx("telephone", validate.telephoneValidate); err != nil { + return + } + if err = validate.validator.RegisterValidationCtx("db_unique", validate.uniqueValidate); err != nil { + return + } + return +} + +func (validate *Validate) inArray(v any, vs []any) bool { + sv := fmt.Sprint(v) + for _, s := range vs { + if fmt.Sprint(s) == sv { + return true + } + } + return false +} + +// isVisible 判断字段是否需要显示 +func (validate *Validate) isVisible(stmt *gorm.Statement, scm *types.Schema) bool { + if len(scm.Attribute.Visible) <= 0 { + return true + } + for _, row := range scm.Attribute.Visible { + if len(row.Values) == 0 { + continue + } + targetField := stmt.Schema.LookUpField(row.Column) + if targetField == nil { + continue + } + targetValue := stmt.ReflectValue.FieldByName(targetField.Name) + if !targetValue.IsValid() { + return false + } + if !validate.inArray(targetValue.Interface(), row.Values) { + return false + } + } + return true +} + +// Validate 校验字段 +func (validate *Validate) Validate(db *gorm.DB) { + var ( + ok bool + err error + rules []*validateRule + stmt *gorm.Statement + skipValidate bool + multiDomain bool + value reflect.Value + runtimeScope *types.RuntimeScope + ) + if result, ok := db.Get(SkipValidations); ok && result.(bool) { + return + } + stmt = db.Statement + if stmt.Model == nil { + return + } + if db.Statement.Context == nil { + return + } + if val := db.Statement.Context.Value(rest.RuntimeScopeKey); val != nil { + if runtimeScope, ok = val.(*types.RuntimeScope); !ok { + return + } + } else { + return + } + if runtimeScope.Schemas == nil { + return + } + if stmt.Schema.LookUpField("domain") != nil { + multiDomain = true + } + for _, row := range runtimeScope.Schemas { + //如果字段隐藏,那么就不进行校验 + if !validate.isVisible(stmt, row) { + continue + } + if rules = validate.grantRules(row, runtimeScope.Scenario, row.Rule); len(rules) <= 0 { + continue + } + field := stmt.Schema.LookUpField(row.Column) + if field == nil { + continue + } + value = stmt.ReflectValue.FieldByName(field.Name) + if !value.IsValid() { + continue + } + skipValidate = false + if r := validate.findRule("required", rules); r != nil { + if value.Interface() != nil { + vType := reflect.ValueOf(value.Interface()) + switch vType.Kind() { + case reflect.Bool: + skipValidate = true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + skipValidate = true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + skipValidate = true + case reflect.Float32, reflect.Float64: + skipValidate = true + default: + skipValidate = false + } + } + if skipValidate { + r.Valid = false + } + } else { + if isEmpty(value.Interface()) { + continue + } + } + ctx := context.WithValue(db.Statement.Context, scopeCtxKey, &validateScope{ + DB: db, + Column: row.Column, + Model: stmt.Model, + Domain: runtimeScope.Domain, + MultiDomain: multiDomain, + }) + if err = validate.validator.VarCtx(ctx, value.Interface(), validate.buildRules(rules)); err != nil { + if errs, ok := err.(validator.ValidationErrors); ok { + for _, e := range errs { + _ = db.AddError(&StructError{ + Tag: e.Tag(), + Column: row.Column, + Message: formatError(row.Rule, row, e.Tag()), + }) + } + } else { + _ = db.AddError(err) + } + break + } + } +} + +func New() *Validate { + return &Validate{} +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..b82267d --- /dev/null +++ b/query.go @@ -0,0 +1,405 @@ +package rest + +import ( + "context" + "fmt" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "reflect" + "strconv" + "strings" +) + +type ( + Query struct { + db *gorm.DB + condition string + fields []string + params []interface{} + table string + joins []join + orderBy []string + groupBy []string + modelValue any + limit int + offset int + } + + condition struct { + Field string `json:"field"` + Value interface{} `json:"value"` + Expr string `json:"expr"` + } + + join struct { + Table string + Direction string + Conditions []*condition + } +) + +func (cond *condition) WithExpr(v string) *condition { + cond.Expr = v + return cond +} + +func (query *Query) Model() any { + return query.modelValue +} + +func (query *Query) compile() (*gorm.DB, error) { + db := query.db + if query.condition != "" { + db = db.Where(query.condition, query.params...) + } + if query.fields != nil { + db = db.Select(strings.Join(query.fields, ",")) + } + if query.table != "" { + db = db.Table(query.table) + } + if query.joins != nil && len(query.joins) > 0 { + for _, joinEntity := range query.joins { + cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...) + db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...) + } + } + if query.orderBy != nil && len(query.orderBy) > 0 { + db = db.Order(strings.Join(query.orderBy, ",")) + } + if query.groupBy != nil && len(query.groupBy) > 0 { + db = db.Group(strings.Join(query.groupBy, ",")) + } + if query.offset > 0 { + db = db.Offset(query.offset) + } + if query.limit > 0 { + db = db.Limit(query.limit) + } + return db, nil +} + +func (query *Query) decodeValue(v any) string { + refVal := reflect.Indirect(reflect.ValueOf(v)) + switch refVal.Kind() { + case reflect.Bool: + if refVal.Bool() { + return "1" + } else { + return "0" + } + case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64: + return strconv.FormatInt(refVal.Int(), 10) + case reflect.Float32, reflect.Float64: + return strconv.FormatFloat(refVal.Float(), 'f', -1, 64) + case reflect.String: + return "'" + refVal.String() + "'" + default: + return fmt.Sprint(v) + } +} + +func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) { + var ( + sb strings.Builder + ) + params = make([]interface{}, 0) + for _, cond := range conditions { + if filter { + if isEmpty(cond.Value) { + continue + } + } + if cond.Expr == "" { + cond.Expr = "=" + } + switch strings.ToUpper(cond.Expr) { + case "=", "<>", ">", "<", ">=", "<=", "!=": + if sb.Len() > 0 { + sb.WriteString(" " + operator + " ") + } + if cond.Expr == "=" && cond.Value == nil { + sb.WriteString("`" + cond.Field + "` IS NULL") + } else { + sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?") + params = append(params, cond.Value) + } + case "LIKE": + if sb.Len() > 0 { + sb.WriteString(" " + operator + " ") + } + cond.Value = fmt.Sprintf("%%%s%%", cond.Value) + sb.WriteString("`" + cond.Field + "` LIKE ?") + params = append(params, cond.Value) + case "IN": + if sb.Len() > 0 { + sb.WriteString(" " + operator + " ") + } + refVal := reflect.Indirect(reflect.ValueOf(cond.Value)) + switch refVal.Kind() { + case reflect.Slice, reflect.Array: + ss := make([]string, refVal.Len()) + for i := 0; i < refVal.Len(); i++ { + ss[i] = query.decodeValue(refVal.Index(i)) + } + sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")") + case reflect.String: + sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")") + default: + } + case "BETWEEN": + refVal := reflect.ValueOf(cond.Value) + if refVal.Kind() == reflect.Slice && refVal.Len() == 2 { + sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?") + params = append(params, refVal.Index(0), refVal.Index(1)) + } + } + } + str = sb.String() + return +} + +func (query *Query) Select(fields ...string) *Query { + if query.fields == nil { + query.fields = fields + } else { + query.fields = append(query.fields, fields...) + } + return query +} + +func (query *Query) From(table string) *Query { + query.table = table + return query +} + +func (query *Query) LeftJoin(table string, conditions ...*condition) *Query { + query.joins = append(query.joins, join{ + Table: table, + Direction: "LEFT", + Conditions: conditions, + }) + return query +} + +func (query *Query) RightJoin(table string, conditions ...*condition) *Query { + query.joins = append(query.joins, join{ + Table: table, + Direction: "RIGHT", + Conditions: conditions, + }) + return query +} + +func (query *Query) InnerJoin(table string, conditions ...*condition) *Query { + query.joins = append(query.joins, join{ + Table: table, + Direction: "INNER", + Conditions: conditions, + }) + return query +} + +func (query *Query) AddCondition(expr, column string, val any) { + if expr == "" { + expr = "=" + } + query.AndWhere(newConditionWithOperator(expr, column, val)) +} + +func (query *Query) AndWhere(conditions ...*condition) *Query { + length := len(conditions) + if length == 0 { + return query + } + cs, ps := query.buildConditions("AND", false, conditions...) + if cs == "" { + return query + } + query.params = append(query.params, ps...) + if query.condition == "" { + query.condition = cs + } else { + query.condition += " AND (" + cs + ")" + } + return query +} + +func (query *Query) AndFilterWhere(conditions ...*condition) *Query { + length := len(conditions) + if length == 0 { + return query + } + cs, ps := query.buildConditions("AND", true, conditions...) + if cs == "" { + return query + } + query.params = append(query.params, ps...) + if query.condition == "" { + query.condition = cs + } else { + query.condition += " AND " + cs + } + return query +} + +func (query *Query) OrWhere(conditions ...*condition) *Query { + length := len(conditions) + if length == 0 { + return query + } + cs, ps := query.buildConditions("OR", false, conditions...) + if cs == "" { + return query + } + query.params = append(query.params, ps...) + if query.condition == "" { + query.condition = cs + } else { + query.condition += " AND (" + cs + ")" + } + return query +} + +func (query *Query) OrFilterWhere(conditions ...*condition) *Query { + length := len(conditions) + if length == 0 { + return query + } + cs, ps := query.buildConditions("OR", true, conditions...) + if cs == "" { + return query + } + query.params = append(query.params, ps...) + if query.condition == "" { + query.condition = cs + } else { + query.condition += " AND (" + cs + ")" + } + return query +} + +func (query *Query) GroupBy(cols ...string) *Query { + query.groupBy = append(query.groupBy, cols...) + return query +} + +func (query *Query) OrderBy(col, direction string) *Query { + direction = strings.ToUpper(direction) + if direction == "" || !(direction == "ASC" || direction == "DESC") { + direction = "ASC" + } + query.orderBy = append(query.orderBy, col+" "+direction) + return query +} + +func (query *Query) Offset(i int) *Query { + query.offset = i + return query +} + +func (query *Query) Limit(i int) *Query { + query.limit = i + return query +} + +func (query *Query) ResetSelect() *Query { + query.fields = make([]string, 0) + return query +} + +func (query *Query) Count(v interface{}) (i int64) { + var ( + db *gorm.DB + err error + ) + if db, err = query.compile(); err != nil { + return + } else { + if v != nil { + refVal := reflect.ValueOf(v) + switch refVal.Kind() { + case reflect.String: + if query.table == "" { + err = db.Table(refVal.String()).Count(&i).Error + } else { + err = db.Table(query.table).Count(&i).Error + } + default: + //如果是报表的模型,这的话手动构建一条SQL语句 + if reporter, ok := v.(types.Reporter); ok { + sqlRes := &sqlCountResponse{} + childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true) + db.WithContext(childCtx).Model(reporter).First(sqlRes) + i = sqlRes.Count + } else { + err = db.Model(v).Count(&i).Error + } + } + } + } + return +} + +func (query *Query) One(v interface{}) (err error) { + var ( + db *gorm.DB + ) + if db, err = query.compile(); err != nil { + return + } else { + err = db.First(v).Error + } + return +} + +func (query *Query) All(v interface{}) (err error) { + var ( + db *gorm.DB + ) + if db, err = query.compile(); err != nil { + return + } else { + err = db.Find(v).Error + } + return +} + +func NewCondition(column, opera string, value any) *condition { + if opera == "" { + opera = "=" + } + return &condition{ + Field: column, + Value: value, + Expr: opera, + } +} + +func newCondition(field string, value interface{}) *condition { + return &condition{ + Field: field, + Value: value, + Expr: "=", + } +} + +func newConditionWithOperator(operator, field string, value interface{}) *condition { + cond := &condition{ + Field: field, + Value: value, + Expr: operator, + } + return cond +} + +func NewQuery(db *gorm.DB, model any) *Query { + return &Query{ + db: db, + modelValue: model, + params: make([]interface{}, 0), + orderBy: make([]string, 0), + groupBy: make([]string, 0), + joins: make([]join, 0), + } +} diff --git a/rest.go b/rest.go new file mode 100644 index 0000000..7ef9e92 --- /dev/null +++ b/rest.go @@ -0,0 +1,753 @@ +package rest + +import ( + "context" + "errors" + "fmt" + "git.nobla.cn/golang/kos/util/arrays" + "git.nobla.cn/golang/kos/util/reflection" + "git.nobla.cn/golang/rest/inflector" + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "net/http" + "reflect" + "strconv" + "strings" + "time" +) + +var ( + modelEntities []*Model + httpRouter types.HttpRouter + hookMgr *hookManager + timeKind = reflect.TypeOf(time.Time{}).Kind() + timePtrKind = reflect.TypeOf(&time.Time{}).Kind() + + matchEnums = []string{types.MatchExactly, types.MatchFuzzy} +) + +var ( + allowScenario = []string{types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} +) + +func init() { + hookMgr = &hookManager{} + modelEntities = make([]*Model, 0) +} + +// cloneStmt 从指定的db克隆一个 Statement 对象 +func cloneStmt(db *gorm.DB) *gorm.Statement { + return &gorm.Statement{ + DB: db, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + } +} + +// dataTypeOf 推断数据的类型 +func dataTypeOf(field *schema.Field) string { + var dataType string + reflectType := field.FieldType + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + if dataType = field.Tag.Get("type"); dataType != "" { + return dataType + } + dataValue := reflect.Indirect(reflect.New(reflectType)) + switch dataValue.Kind() { + case reflect.Bool: + dataType = types.TypeBoolean + case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + dataType = types.TypeInteger + case reflect.Float32, reflect.Float64: + dataType = types.TypeFloat + default: + dataType = types.TypeString + } + return dataType +} + +// dataFormatOf 推断数据的格式 +func dataFormatOf(field *schema.Field) string { + var format string + format = field.Tag.Get("format") + if format != "" { + return format + } + //如果有枚举值,直接设置为下拉类型 + enum := field.Tag.Get("enum") + if enum != "" { + return types.FormatDropdown + } + reflectType := field.FieldType + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + //时间处理 + dataValue := reflect.Indirect(reflect.New(reflectType)) + if field.Name == "CreatedAt" || field.Name == "UpdatedAt" || field.Name == "DeletedAt" { + switch dataValue.Kind() { + case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return types.FormatTimestamp + default: + return types.FormatDatetime + } + } + if strings.Contains(strings.ToLower(field.Name), "pass") { + return types.FormatPassword + } + switch dataValue.Kind() { + case timeKind, timePtrKind: + format = types.FormatDatetime + case reflect.Bool: + format = types.FormatBoolean + case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + format = types.FormatInteger + case reflect.Float32, reflect.Float64: + format = types.FormatFloat + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + format = types.FormatDatetime + } + default: + if field.Size >= 1024 { + format = types.FormatText + } else { + format = types.FormatString + } + } + return format +} + +// fieldName 生成字段名称 +func fieldName(name string) string { + tokens := strings.Split(name, "_") + for i, s := range tokens { + tokens[i] = strings.Title(s) + } + return strings.Join(tokens, " ") +} + +// fieldNative 判断是否为原始字段 +func fieldNative(field *schema.Field) uint8 { + if _, ok := field.Tag.Lookup("virtual"); ok { + return 0 + } + return 1 +} + +// fieldRule 返回字段规则 +func fieldRule(field *schema.Field) types.Rule { + r := types.Rule{ + Required: []string{}, + } + if field.GORMDataType == schema.String { + r.Max = field.Size + } + if field.GORMDataType == schema.Int || field.GORMDataType == schema.Float || field.GORMDataType == schema.Uint { + r.Max = field.Scale + } + + rs := field.Tag.Get("rule") + if rs != "" { + ss := strings.Split(rs, ";") + for _, s := range ss { + vs := strings.SplitN(s, ":", 2) + ls := len(vs) + if ls == 0 { + continue + } + switch vs[0] { + case "required", "require": + if ls > 1 { + bs := strings.Split(vs[1], ",") + for _, i := range bs { + if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) { + r.Required = append(r.Required, i) + } + } + } else { + r.Required = []string{types.ScenarioCreate, types.ScenarioUpdate} + } + case "unique": + r.Unique = true + case "regexp": + if ls > 1 { + r.Regular = vs[1] + } + } + } + } + if field.PrimaryKey { + r.Unique = true + } + return r +} + +// fieldScenario 字段Scenarios +func fieldScenario(index int, field *schema.Field) types.Scenarios { + var ss types.Scenarios + if v, ok := field.Tag.Lookup("scenarios"); ok { + v = strings.TrimSpace(v) + if v != "" { + ss = strings.Split(v, ";") + } + } else { + if field.PrimaryKey { + ss = []string{types.ScenarioList, types.ScenarioView, types.ScenarioExport} + } else if field.Name == "CreatedAt" || field.Name == "UpdatedAt" { + ss = []string{types.ScenarioList} + } else if field.Name == "DeletedAt" || field.Name == "Namespace" { + //不添加任何显示场景 + ss = []string{} + } else { + if index < 10 { + //高级字段只配置一些简单的场景 + ss = []string{types.ScenarioSearch, types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} + } else { + //高级字段只配置一些简单的场景 + ss = []string{types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport} + } + } + } + return ss +} + +// fieldPosition 字段的排序位置 +func fieldPosition(field *schema.Field, i int) int { + s := field.Tag.Get("position") + n, _ := strconv.Atoi(s) + if n > 0 { + return n + } + return i + 100 +} + +// fieldAttribute 字段属性 +func fieldAttribute(field *schema.Field) types.Attribute { + attr := types.Attribute{ + Match: types.MatchFuzzy, + PrimaryKey: field.PrimaryKey, + DefaultValue: field.DefaultValue, + Readonly: []string{}, + Disable: []string{}, + Visible: make([]types.VisibleCondition, 0), + Values: make([]types.EnumValue, 0), + Live: types.LiveValue{}, + } + if field.Name == "CreatedAt" || field.Name == "UpdatedAt" { + attr.EndOfNow = true + } + //赋值属性 + props := field.Tag.Get("props") + if props != "" { + vs := strings.Split(props, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) != 2 { + continue + } + sv := strings.TrimSpace(kv[1]) + switch strings.ToLower(strings.TrimSpace(kv[0])) { + case "icon": + attr.Icon = sv + case "match": + if arrays.Exists(sv, matchEnums) { + attr.Match = sv + } + case "endofnow", "end_of_now": + if ok, _ := strconv.ParseBool(sv); ok { + attr.EndOfNow = true + } + case "invisible": + if ok, _ := strconv.ParseBool(sv); ok { + attr.Invisible = true + } + case "suffix": + attr.Suffix = sv + case "tag": + attr.Tag = sv + case "tooltip": + attr.Tooltip = sv + case "uploadurl", "uploaduri", "upload_url", "upload_uri": + attr.UploadUrl = sv + case "description": + attr.Description = sv + case "readonly": + bs := strings.Split(sv, ",") + for _, i := range bs { + if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) { + attr.Readonly = append(attr.Readonly, i) + } + } + } + } + } + //live的赋值 + live := field.Tag.Get("live") + if live != "" { + attr.Live.Enable = true + vs := strings.Split(live, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) != 2 { + continue + } + switch kv[0] { + case "method": + attr.Live.Method = kv[1] + case "type": + if kv[1] == types.LiveTypeDropdown || kv[1] == types.LiveTypeCascader { + attr.Live.Type = kv[1] + } else { + attr.Live.Type = types.LiveTypeDropdown + } + case "url", "uri": + attr.Live.Url = kv[1] + case "columns": + attr.Live.Columns = strings.Split(kv[1], ",") + } + } + } + + dropdown := field.Tag.Get("dropdown") + if dropdown != "" { + attr.DropdownOption = &types.DropdownOption{} + vs := strings.Split(dropdown, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) == 0 { + continue + } + switch kv[0] { + case "created": + attr.DropdownOption.Created = true + case "filterable": + attr.DropdownOption.Filterable = true + case "autocomplete": + attr.DropdownOption.Autocomplete = true + case "default_first": + attr.DropdownOption.DefaultFirst = true + } + } + } + + //显示条件 + conditions := field.Tag.Get("condition") + if conditions != "" { + vs := strings.Split(conditions, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) != 2 { + continue + } + cond := types.VisibleCondition{ + Column: kv[0], + Values: make([]any, 0), + } + vv := strings.Split(kv[1], ",") + for _, x := range vv { + x = strings.TrimSpace(x) + if x == "" { + continue + } + cond.Values = append(cond.Values, x) + } + attr.Visible = append(attr.Visible, cond) + } + } + //赋值枚举值 + enumns := field.Tag.Get("enum") + if enumns != "" { + vs := strings.Split(enumns, ";") + for _, str := range vs { + kv := strings.SplitN(str, ":", 2) + if len(kv) != 2 { + continue + } + fv := types.EnumValue{Value: kv[0]} + //颜色分隔符 + if pos := strings.IndexByte(kv[1], '#'); pos > -1 { + fv.Label = kv[1][:pos] + fv.Color = kv[1][pos:] + } else { + fv.Label = kv[1] + } + attr.Values = append(attr.Values, fv) + } + } + if !field.Creatable { + attr.Disable = append(attr.Disable, types.ScenarioCreate) + } + if !field.Updatable { + attr.Disable = append(attr.Disable, types.ScenarioUpdate) + } + attr.Tooltip = field.Comment + return attr +} + +// autoMigrate 自动合并字段 +func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (naming string, err error) { + var ( + pos int + columnName string + columnIsExists bool + columnLabel string + schemas []*types.Schema + models []*types.Schema + stmt *gorm.Statement + ) + + stmt = cloneStmt(db) + if err = stmt.Parse(model); err != nil { + return + } + if schemas, err = GetSchemas(ctx, db, defaultDomain, module, stmt.Table); err != nil { + return + } + if len(schemas) > 0 { + pos = len(schemas) + } + models = make([]*types.Schema, 0) + for index, field := range stmt.Schema.Fields { + columnName = field.DBName + if columnName == "-" { + continue + } + if columnName == "" { + columnName = field.Name + } + columnIsExists = false + for _, sm := range schemas { + if sm.Column == columnName { + columnIsExists = true + break + } + } + if columnIsExists { + continue + } + columnLabel = field.Tag.Get("comment") + if columnLabel == "" { + columnLabel = fieldName(field.DBName) + } + isPrimaryKey := uint8(0) + if field.PrimaryKey { + isPrimaryKey = 1 + } + schemaModel := &types.Schema{ + Domain: defaultDomain, + ModuleName: module, + TableName: stmt.Table, + Enable: 1, + Column: columnName, + Label: columnLabel, + Type: strings.ToLower(dataTypeOf(field)), + Format: strings.ToLower(dataFormatOf(field)), + Native: fieldNative(field), + IsPrimaryKey: isPrimaryKey, + Rule: fieldRule(field), + Scenarios: fieldScenario(index, field), + Attribute: fieldAttribute(field), + Position: fieldPosition(field, pos), + } + //如果启用了在线调取接口功能,那么设置一下字段的format格式 + if schemaModel.Attribute.Live.Enable { + if schemaModel.Attribute.Live.Type != "" { + schemaModel.Format = schemaModel.Attribute.Live.Type + } + } + models = append(models, schemaModel) + pos++ + } + if len(models) > 0 { + err = db.Create(models).Error + } + naming = stmt.Table + return +} + +// SetHttpRouter 设置HTTP路由 +func SetHttpRouter(router types.HttpRouter) { + httpRouter = router +} + +// AutoMigrate 自动合并表的schema +func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) { + var ( + opts *Options + table string + router types.HttpRouter + ) + opts = &Options{} + for _, cb := range cbs { + cb(opts) + } + if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil { + return + } + //路由模块处理 + modelValue := newModel(model, db, types.Naming{ + Pluralize: inflector.Pluralize(table), + Singular: inflector.Singularize(table), + ModuleName: opts.moduleName, + TableName: table, + }) + modelValue.hookMgr = hookMgr + modelValue.schemaLookup = VisibleSchemas + if opts.router != nil { + router = opts.router + } + if router == nil && httpRouter != nil { + router = httpRouter + } + if opts.urlPrefix != "" { + modelValue.urlPrefix = opts.urlPrefix + } + //路由绑定操作 + if router != nil { + if modelValue.hasScenario(types.ScenarioList) { + router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioList), modelValue.Search) + } + if modelValue.hasScenario(types.ScenarioCreate) { + router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioCreate), modelValue.Create) + } + if modelValue.hasScenario(types.ScenarioUpdate) { + router.Handle(http.MethodPut, modelValue.Uri(types.ScenarioUpdate), modelValue.Update) + } + if modelValue.hasScenario(types.ScenarioDelete) { + router.Handle(http.MethodDelete, modelValue.Uri(types.ScenarioDelete), modelValue.Delete) + } + if modelValue.hasScenario(types.ScenarioView) { + router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioView), modelValue.View) + } + if modelValue.hasScenario(types.ScenarioExport) { + router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioExport), modelValue.Export) + } + if modelValue.hasScenario(types.ScenarioImport) { + router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioImport), modelValue.Import) + router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioImport), modelValue.Import) + } + } + if opts.writer != nil { + modelValue.response = opts.writer + } + if opts.formatter != nil { + modelValue.formatter = opts.formatter + } + modelValue.disableDomain = opts.disableDomain + modelEntities = append(modelEntities, modelValue) + return +} + +// CloneSchemas 克隆schemas +func CloneSchemas(ctx context.Context, db *gorm.DB, domain string) (err error) { + var ( + values []*types.Schema + schemas []*types.Schema + models []*types.Schema + ) + tx := db.WithContext(ctx) + if err = tx.Where("domain=?", defaultDomain).Find(&values).Error; err != nil { + return fmt.Errorf("schema not found") + } + tx.Where("domain=?", domain).Find(&schemas) + hasSchemaFunc := func(values []*types.Schema, hack *types.Schema) bool { + for _, row := range values { + if row.ModuleName == hack.ModuleName && row.TableName == hack.TableName && row.Column == hack.Column { + return true + } + } + return false + } + models = make([]*types.Schema, 0) + for _, row := range values { + if !hasSchemaFunc(schemas, row) { + row.Id = 0 + row.CreatedAt = 0 + row.UpdatedAt = 0 + row.Domain = domain + models = append(models, row) + } + } + if len(models) > 0 { + err = tx.Save(models).Error + } + return +} + +// GetSchemas 获取schemas +func GetSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName string) ([]*types.Schema, error) { + var ( + err error + values []*types.Schema + tx *gorm.DB + ) + values = make([]*types.Schema, 0) + if domain == "" { + domain = defaultDomain + } + if moduleName == "" || tableName == "" { + return nil, gorm.ErrInvalidField + } + if ctx != nil { + tx = db.WithContext(ctx) + } else { + tx = db + } + err = tx.Where("domain=? AND module_name=? AND table_name=?", domain, moduleName, tableName).Order("position ASC").Find(&values).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + err = nil + } + return values, err +} + +// VisibleSchemas 获取某个场景下面的schema +func VisibleSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName, scenario string) ([]*types.Schema, error) { + schemas, err := GetSchemas(ctx, db, domain, moduleName, tableName) + if err != nil { + return nil, err + } + result := make([]*types.Schema, 0, len(schemas)) + for _, row := range schemas { + if row.IsPrimaryKey == 1 { + result = append(result, row) + continue + } + if row.Scenarios.Has(scenario) { + result = append(result, row) + } + } + return result, nil +} + +// ModelTypes 查询指定模型的类型 +func ModelTypes(ctx context.Context, db *gorm.DB, model any, domainName, labelColumn, valueColumn string) (values []*types.TypeValue) { + tx := db.WithContext(ctx) + result := make([]map[string]any, 0, 10) + tx.Model(model).Select(labelColumn, valueColumn).Where("domain=?", domainName).Scan(&result) + values = make([]*types.TypeValue, 0, len(result)) + for _, pairs := range result { + feed := &types.TypeValue{} + for k, v := range pairs { + if k == labelColumn { + feed.Label = v + } + if k == valueColumn { + feed.Value = v + } + } + values = append(values, feed) + } + return values +} + +// GetFieldValue 获取模型某个字段的值 +func GetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string) interface{} { + var ( + idx = -1 + ) + refVal := reflect.Indirect(refValue) + for i, field := range stmt.Schema.Fields { + if field.DBName == column || field.Name == column { + idx = i + break + } + } + if idx > -1 { + return refVal.Field(idx).Interface() + } + return nil +} + +// SetFieldValue 设置模型某个字段的值 +func SetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) { + var ( + idx = -1 + ) + refVal := reflect.Indirect(refValue) + for i, field := range stmt.Schema.Fields { + if field.DBName == column || field.Name == column { + idx = i + break + } + } + if idx > -1 { + refVal.Field(idx).Set(reflect.ValueOf(value)) + } +} + +func SafeSetFileValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) { + var ( + idx = -1 + ) + refVal := reflect.Indirect(refValue) + for i, field := range stmt.Schema.Fields { + if field.DBName == column || field.Name == column { + idx = i + break + } + } + if idx > -1 { + _ = reflection.Assign(refVal.Field(idx), value) + } +} + +// GetModels 获取所有注册了的模块 +func GetModels() []*Model { + return modelEntities +} + +// OnBeforeCreate 创建前的回调 +func OnBeforeCreate(cb BeforeCreate) { + hookMgr.BeforeCreate(cb) +} + +// OnAfterCreate 创建后的回调 +func OnAfterCreate(cb AfterCreate) { + hookMgr.AfterCreate(cb) +} + +// OnBeforeUpdate 更新前的回调 +func OnBeforeUpdate(cb BeforeUpdate) { + hookMgr.BeforeUpdate(cb) +} + +// OnAfterUpdate 更新后的回调 +func OnAfterUpdate(cb AfterUpdate) { + hookMgr.AfterUpdate(cb) +} + +// OnBeforeSave 保存前的回调 +func OnBeforeSave(cb BeforeSave) { + hookMgr.BeforeSave(cb) +} + +// OnAfterSave 保存后的回调 +func OnAfterSave(cb AfterSave) { + hookMgr.AfterSave(cb) +} + +// OnBeforeDelete 删除前的回调 +func OnBeforeDelete(cb BeforeDelete) { + hookMgr.BeforeDelete(cb) +} + +// OnAfterDelete 删除后的回调 +func OnAfterDelete(cb AfterDelete) { + hookMgr.AfterDelete(cb) +} + +// OnAfterExport 导出后的回调 +func OnAfterExport(cb AfterExport) { + hookMgr.AfterExport(cb) +} + +// OnAfterImport 导入后的回调 +func OnAfterImport(cb AfterImport) { + hookMgr.AfterImport(cb) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..903f188 --- /dev/null +++ b/types.go @@ -0,0 +1,82 @@ +package rest + +import ( + "context" + "encoding/json" + "gorm.io/gorm" + "net/http" +) + +const ( + defaultPageSize = 15 + defaultDomain = "localhost" +) + +type ( + httpWriter struct { + } + + multiValue struct { + Text any `json:"label"` + Value any `json:"value"` + } + + stdResponse struct { + Code int `json:"code"` + Reason string `json:"reason,omitempty"` + Data any `json:"data,omitempty"` + } + + tableNamer interface { + HttpTableName(req *http.Request) string + } + + //创建后的回调,这个回调不在事物内 + afterCreated interface { + AfterCreated(ctx context.Context, tx *gorm.DB) + } + //更新后的回调,这个回调不在事物内 + afterUpdated interface { + AfterUpdated(ctx context.Context, tx *gorm.DB) + } + //保存后的回调,这个回调不在事物内 + afterSaved interface { + AfterSaved(ctx context.Context, tx *gorm.DB) + } + + sqlCountResponse struct { + Count int64 `json:"count"` + } +) + +func (h *httpWriter) Success(w http.ResponseWriter, data any) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(stdResponse{ + Code: 0, + Data: data, + }); err != nil { + w.Write([]byte(err.Error())) + } +} + +func (h *httpWriter) Failure(w http.ResponseWriter, reason int, message string, data any) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(stdResponse{ + Code: reason, + Reason: message, + Data: data, + }); err != nil { + w.Write([]byte(err.Error())) + } +} + +func defaultValueLookup(field string, w http.ResponseWriter, r *http.Request) string { + var ( + domainName string + ) + domainName = r.Header.Get(field) + if domainName == "" { + return defaultDomain + } + return domainName +} diff --git a/types/attribute.go b/types/attribute.go new file mode 100644 index 0000000..c7809b5 --- /dev/null +++ b/types/attribute.go @@ -0,0 +1,48 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" +) + +type ( + Attribute struct { + Match string `json:"match"` //匹配模式 + Tag string `json:"tag,omitempty"` //字段标签 + PrimaryKey bool `json:"primary_key"` //是否为主键 + DefaultValue string `json:"default_value"` //默认值 + Readonly []string `json:"readonly"` //只读场景 + Disable []string `json:"disable"` //禁用场景 + Visible []VisibleCondition `json:"visible"` //可见条 + Invisible bool `json:"invisible"` //不可见的字段,表示在UI界面看不到这个字段,但是这个字段需要 + EndOfNow bool `json:"end_of_now"` //最大时间为当前时间 + Values []EnumValue `json:"values"` //值 + Live LiveValue `json:"live"` //延时加载配置 + UploadUrl string `json:"upload_url,omitempty"` //上传地址 + Icon string `json:"icon,omitempty"` //显示图标 + Sort bool `json:"sort"` //是否允许排序 + Suffix string `json:"suffix,omitempty"` //追加内容 + Tooltip string `json:"tooltip,omitempty"` //字段提示信息 + Description string `json:"description,omitempty"` //字段说明信息 + DropdownOption *DropdownOption `json:"dropdown,omitempty"` //下拉选项 + } +) + +// Scan implements the Scanner interface. +func (n *Attribute) Scan(value interface{}) error { + if value == nil { + return nil + } + switch s := value.(type) { + case string: + return json.Unmarshal([]byte(s), n) + case []byte: + return json.Unmarshal(s, n) + } + return ErrDbTypeUnsupported +} + +// Value implements the driver Valuer interface. +func (n Attribute) Value() (driver.Value, error) { + return json.Marshal(n) +} diff --git a/types/error.go b/types/error.go new file mode 100644 index 0000000..94a9f8a --- /dev/null +++ b/types/error.go @@ -0,0 +1,7 @@ +package types + +import "errors" + +var ( + ErrDbTypeUnsupported = errors.New("database type unsupported") +) diff --git a/types/rule.go b/types/rule.go new file mode 100644 index 0000000..db253b8 --- /dev/null +++ b/types/rule.go @@ -0,0 +1,39 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" +) + +type Rule struct { + Min int `json:"min"` + Max int `json:"max"` + Type string `json:"type"` + Unique bool `json:"unique"` + Required []string `json:"required"` + Regular string `json:"regular"` +} + +// Value implements the driver Valuer interface. +func (n Scenarios) Value() (driver.Value, error) { + return json.Marshal(n) +} + +// Scan implements the Scanner interface. +func (n *Rule) Scan(value interface{}) error { + if value == nil { + return nil + } + switch s := value.(type) { + case string: + return json.Unmarshal([]byte(s), n) + case []byte: + return json.Unmarshal(s, n) + } + return ErrDbTypeUnsupported +} + +// Value implements the driver Valuer interface. +func (n Rule) Value() (driver.Value, error) { + return json.Marshal(n) +} diff --git a/types/schema.go b/types/schema.go new file mode 100644 index 0000000..2866135 --- /dev/null +++ b/types/schema.go @@ -0,0 +1,81 @@ +package types + +import ( + "encoding/json" +) + +type ( + LiveValue struct { + Enable bool `json:"enable"` + Type string `json:"type"` + Url string `json:"url"` + Method string `json:"method"` + Body string `json:"body"` + ContentType string `json:"content_type"` + Columns []string `json:"columns"` + } + + EnumValue struct { + Label string `json:"label"` + Value string `json:"value"` + Color string `json:"color"` + } + + VisibleCondition struct { + Column string `json:"column"` + Values []interface{} `json:"values"` + } + + DropdownOption struct { + Created bool `json:"created,omitempty"` + Filterable bool `json:"filterable,omitempty"` + Autocomplete bool `json:"autocomplete,omitempty"` + DefaultFirst bool `json:"default_first,omitempty"` + } + + Scenarios []string + + Schema struct { + Id uint64 `json:"id" gorm:"primary_key"` + CreatedAt int64 `json:"created_at" gorm:"autoCreateTime"` //创建时间 + UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime"` //更新时间 + Domain string `json:"domain" gorm:"column:domain;type:char(60);index"` //域 + ModuleName string `json:"module_name" gorm:"column:module_name;type:varchar(60);index"` //模块名称 + TableName string `json:"table_name" gorm:"column:table_name;type:varchar(120);index"` //表名称 + Enable uint8 `json:"enable" gorm:"column:enable;type:int(1)"` //是否启用 + Column string `json:"column" gorm:"type:varchar(120)"` //字段名称 + Label string `json:"label" gorm:"type:varchar(120)"` //显示名称 + Type string `json:"type" gorm:"type:varchar(120)"` //字段类型 + Format string `json:"format" gorm:"type:varchar(120)"` //字段格式 + Native uint8 `json:"native" gorm:"type:int(1)"` //是否为原生字段 + IsPrimaryKey uint8 `json:"is_primary_key" gorm:"type:int(1)"` //是否为主键 + Expression string `json:"expression" gorm:"type:varchar(526)"` //计算规则 + Scenarios Scenarios `json:"scenarios" gorm:"type:varchar(120)"` //场景 + Rule Rule `json:"rule" gorm:"type:varchar(2048)"` //字段规则 + Attribute Attribute `json:"attribute" gorm:"type:varchar(4096)"` //字段属性 + Position int `json:"position"` //字段排序位置 + } +) + +func (n Scenarios) Has(str string) bool { + for _, v := range n { + if v == str { + return true + } + } + return false +} + +// Scan implements the Scanner interface. +func (n *Scenarios) Scan(value interface{}) error { + if value == nil { + return nil + } + switch s := value.(type) { + case string: + return json.Unmarshal([]byte(s), n) + case []byte: + return json.Unmarshal(s, n) + } + return ErrDbTypeUnsupported +} diff --git a/types/types.go b/types/types.go new file mode 100644 index 0000000..ff2d9d2 --- /dev/null +++ b/types/types.go @@ -0,0 +1,231 @@ +package types + +import ( + "context" + "gorm.io/gorm" + "net/http" + "time" +) + +const ( + SceneCreate = "create" + SceneUpdate = "update" + SceneDelete = "delete" +) + +const ( + ScenarioCreate = "create" + ScenarioUpdate = "update" + ScenarioDelete = "delete" + ScenarioSearch = "search" + ScenarioExport = "export" + ScenarioImport = "import" + ScenarioList = "list" + ScenarioView = "view" + ScenarioMapping = "mapping" +) + +const ( + MatchExactly = "exactly" //精确匹配 + MatchFuzzy = "fuzzy" //模糊匹配 +) + +const ( + LiveTypeDropdown = "dropdown" + LiveTypeCascader = "cascader" +) + +const ( + TypeInteger = "integer" + TypeFloat = "float" + TypeBoolean = "boolean" + TypeString = "string" +) + +const ( + FormatInteger = "integer" + FormatFloat = "float" + FormatBoolean = "boolean" + FormatString = "string" + FormatText = "text" + FormatDropdown = "dropdown" + FormatDatetime = "datetime" + FormatDate = "date" + FormatTime = "time" + FormatTimestamp = "timestamp" + FormatPassword = "password" +) + +const ( + OperatorEqual = "eq" + OperatorGreaterThan = "gt" + OperatorGreaterEqual = "ge" + OperatorLessThan = "lt" + OperatorLessEqual = "le" + OperatorLike = "like" + OperatorBetween = "between" +) + +const ( + ErrImportFileNotExists = 1004 + ErrImportFileUnavailable = 1005 + ErrImportColumnNotMatch = 1006 +) + +const ( + defaultPageSize = 15 + defaultDomain = "localhost" +) + +const ( + RequestDenied = 8005 + RequestRecordNotFound = 8004 + RequestPayloadInvalid = 8006 + RequestCreateFailure = 8007 + RequestUpdateFailure = 8008 + RequestDeleteFailure = 8009 +) + +const ( + FormatRaw = "raw" + FormatBoth = "both" +) + +const ( + FieldDomain = "domain" +) + +type ( + // HttpWriter http 响应接口 + HttpWriter interface { + Success(w http.ResponseWriter, data any) + Failure(w http.ResponseWriter, reason int, message string, data any) + } + + // HttpRouter http 路由管理工具 + HttpRouter interface { + Handle(method string, uri string, handler http.HandlerFunc) + } + + // TypeValue 键值对数据 + TypeValue struct { + Label any `json:"label"` + Value any `json:"value"` + } + + // NestedValue 层级数据 + NestedValue struct { + Label any `json:"label"` + Value any `json:"value"` + Children []*NestedValue `json:"children,omitempty"` + } + + //ValueLookupFunc 查找域的函数 + ValueLookupFunc func(column string, w http.ResponseWriter, r *http.Request) string + + //SchemaLookupFunc 查找schema的回调函数 + SchemaLookupFunc func(ctx context.Context, db *gorm.DB, domain, module, table, scenario string) ([]*Schema, error) +) + +type ( + multiValue struct { + Text any `json:"label"` + Value any `json:"value"` + } + + Tabler interface { + TableName() string + } + + Reporter interface { + TableName() string + RealTable() string + QuoteColumn(ctx context.Context, column string) string + GroupBy(ctx context.Context) []string + } + + FormatModel interface { + Format(ctx context.Context, scene string, schemas []*Schema) + } + + SelectColumn struct { + Name string `json:"name"` + Expr string `json:"expr"` + Native bool `json:"native"` + Callback string `json:"callback"` + } + + DiffAttr struct { + Column string `json:"column"` + Label string `json:"label"` + OldValue any `json:"old_value"` + NewValue any `json:"new_value"` + } + + Condition struct { + Column string `json:"column"` + Expr string `json:"expr"` + Value any `json:"value,omitempty"` + Values []any `json:"values,omitempty"` + } + + Naming struct { + Pluralize string + Singular string + ModuleName string + TableName string + } + + RuntimeScope struct { + Domain string //域 + User string //用户 + ModuleName string //模块名称 + TableName string //表名称 + Scenario string //场景名称 + Schemas []*Schema //字段schema + Request *http.Request //HTTP请求结构 + PrimaryKeyValue any //主键 + } + + ImportResult struct { + Code int `json:"code"` + TotalCount int `json:"total_count"` + SuccessCount int `json:"success_count"` + UploadFile string `json:"upload_file"` + FailureFile string `json:"failure_file"` + Duration time.Duration `json:"duration"` + } + + sqlCountResponse struct { + Count int64 `json:"count"` + } +) + +type ( + ListResponse struct { + Page int `json:"page"` + PageSize int `json:"pagesize"` + TotalCount int64 `json:"totalCount"` + Data any `json:"data"` + } + + CreateResponse struct { + ID any `json:"id"` + Status string `json:"status"` + } + + UpdateResponse struct { + ID any `json:"id"` + Status string `json:"status"` + } + + DeleteResponse struct { + ID any `json:"id"` + Status string `json:"status"` + } + + ImportResponse struct { + UID string `json:"uid"` + Status string `json:"status"` + } +) diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..4a5fc30 --- /dev/null +++ b/utils.go @@ -0,0 +1,45 @@ +package rest + +import ( + "git.nobla.cn/golang/kos/util/arrays" + "reflect" + "strings" +) + +func hasToken(hack string, need string) bool { + if len(need) == 0 || len(hack) == 0 { + return false + } + char := []byte{',', ';'} + for _, c := range char { + if strings.IndexByte(need, c) > -1 { + return arrays.Exists(hack, strings.Split(need, string(c))) + } + } + return false +} + +func isEmpty(val any) bool { + if val == nil { + return true + } + v := reflect.ValueOf(val) + switch v.Kind() { + case reflect.String, reflect.Array: + return v.Len() == 0 + case reflect.Map, reflect.Slice: + return v.IsNil() || v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + default: + return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface()) + } +}