初始化仓库
This commit is contained in:
commit
61ffcec858
|
@ -0,0 +1,60 @@
|
|||
bin/
|
||||
|
||||
.svn/
|
||||
.godeps
|
||||
./build
|
||||
.cover/
|
||||
dist
|
||||
_site
|
||||
_posts
|
||||
*.dat
|
||||
.vscode
|
||||
vendor
|
||||
|
||||
# Go.gitignore
|
||||
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
storage
|
||||
.idea
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
*.local
|
||||
.DS_Store
|
||||
|
||||
profile
|
||||
|
||||
# vim stuff
|
||||
*.sw[op]
|
||||
|
||||
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
|
||||
node_modules
|
|
@ -0,0 +1,144 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"git.nobla.cn/golang/kos/util/arrays"
|
||||
"git.nobla.cn/golang/rest/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func findCondition(schema *types.Schema, conditions []*types.Condition) *types.Condition {
|
||||
for _, cond := range conditions {
|
||||
if cond.Column == schema.Column {
|
||||
return cond
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func BuildConditions(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) {
|
||||
var (
|
||||
ok bool
|
||||
skip bool
|
||||
formValue string
|
||||
activeQuery ActiveQuery
|
||||
)
|
||||
if activeQuery, ok = query.Model().(ActiveQuery); ok {
|
||||
if err = activeQuery.BeforeQuery(ctx, query); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if arrays.Exists(r.Method, []string{http.MethodPut, http.MethodPost}) {
|
||||
conditions := make([]*types.Condition, 0)
|
||||
if err = json.NewDecoder(r.Body).Decode(&conditions); err != nil {
|
||||
return
|
||||
}
|
||||
for _, row := range schemas {
|
||||
if row.Native == 0 {
|
||||
continue
|
||||
}
|
||||
cond := findCondition(row, conditions)
|
||||
if cond == nil {
|
||||
continue
|
||||
}
|
||||
switch row.Format {
|
||||
case types.FormatInteger, types.FormatFloat, types.FormatTimestamp, types.FormatDatetime, types.FormatDate, types.FormatTime:
|
||||
switch cond.Expr {
|
||||
case types.OperatorBetween:
|
||||
if len(cond.Values) == 2 {
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Values[0]).WithExpr(">="))
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Values[1]).WithExpr("<="))
|
||||
}
|
||||
case types.OperatorGreaterThan:
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">"))
|
||||
case types.OperatorGreaterEqual:
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">="))
|
||||
case types.OperatorLessThan:
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<"))
|
||||
case types.OperatorLessEqual:
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<="))
|
||||
default:
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Value))
|
||||
}
|
||||
default:
|
||||
switch cond.Expr {
|
||||
case types.OperatorLike:
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("LIKE"))
|
||||
default:
|
||||
query.AndFilterWhere(newCondition(row.Column, cond.Value))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
qs := r.URL.Query()
|
||||
for _, row := range schemas {
|
||||
skip = false
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
if row.Native == 0 {
|
||||
continue
|
||||
}
|
||||
formValue = qs.Get(row.Column)
|
||||
switch row.Format {
|
||||
case types.FormatString, types.FormatText:
|
||||
if row.Attribute.Match == types.MatchExactly {
|
||||
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||
} else {
|
||||
query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE"))
|
||||
}
|
||||
case types.FormatTime, types.FormatDate, types.FormatDatetime, types.FormatTimestamp:
|
||||
var sep string
|
||||
seps := []byte{',', '/'}
|
||||
for _, s := range seps {
|
||||
if strings.IndexByte(formValue, s) > -1 {
|
||||
sep = string(s)
|
||||
}
|
||||
}
|
||||
if ss := strings.Split(formValue, sep); len(ss) == 2 {
|
||||
query.AndFilterWhere(
|
||||
newCondition(row.Column, strings.TrimSpace(ss[0])).WithExpr(">="),
|
||||
newCondition(row.Column, strings.TrimSpace(ss[1])).WithExpr("<="),
|
||||
)
|
||||
} else {
|
||||
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||
}
|
||||
case types.FormatInteger, types.FormatFloat:
|
||||
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||
default:
|
||||
if row.Type == types.TypeString {
|
||||
if row.Attribute.Match == types.MatchExactly {
|
||||
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||
} else {
|
||||
query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE"))
|
||||
}
|
||||
} else {
|
||||
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
sortPar := r.FormValue("sort")
|
||||
if sortPar != "" {
|
||||
sorts := strings.Split(sortPar, ",")
|
||||
for _, s := range sorts {
|
||||
if s[0] == '-' {
|
||||
query.OrderBy(s[1:], "DESC")
|
||||
} else {
|
||||
if s[0] == '+' {
|
||||
query.OrderBy(s[1:], "ASC")
|
||||
} else {
|
||||
query.OrderBy(s, "ASC")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if activeQuery, ok = query.Model().(ActiveQuery); ok {
|
||||
if err = activeQuery.AfterQuery(ctx, query); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,240 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"git.nobla.cn/golang/rest/types"
|
||||
"gorm.io/gorm"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultFormatter = NewFormatter()
|
||||
|
||||
DefaultNullDisplay = ""
|
||||
)
|
||||
|
||||
func init() {
|
||||
DefaultFormatter.Register("string", stringFormat)
|
||||
DefaultFormatter.Register("integer", integerFormat)
|
||||
DefaultFormatter.Register("decimal", decimalFormat)
|
||||
DefaultFormatter.Register("date", dateFormat)
|
||||
DefaultFormatter.Register("time", timeFormat)
|
||||
DefaultFormatter.Register("datetime", datetimeFormat)
|
||||
DefaultFormatter.Register("duration", durationFormat)
|
||||
DefaultFormatter.Register("dropdown", dropdownFormat)
|
||||
DefaultFormatter.Register("timestamp", datetimeFormat)
|
||||
DefaultFormatter.Register("percentage", percentageFormat)
|
||||
}
|
||||
|
||||
type FormatFunc func(ctx context.Context, value any, model any, scm *types.Schema) any
|
||||
|
||||
type Formatter struct {
|
||||
callbacks sync.Map
|
||||
}
|
||||
|
||||
func (formatter *Formatter) Register(f string, fun FormatFunc) {
|
||||
formatter.callbacks.Store(f, fun)
|
||||
}
|
||||
|
||||
func (formatter *Formatter) Format(ctx context.Context, format string, value any, model any, scm *types.Schema) any {
|
||||
v, ok := formatter.callbacks.Load(format)
|
||||
if ok {
|
||||
return v.(FormatFunc)(ctx, value, model, scm)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (formatter *Formatter) getModelValue(refValue reflect.Value, schema *types.Schema, stmt *gorm.Statement) any {
|
||||
if stmt.Schema == nil {
|
||||
return nil
|
||||
}
|
||||
field := stmt.Schema.LookUpField(schema.Column)
|
||||
if field == nil {
|
||||
return nil
|
||||
}
|
||||
return refValue.FieldByName(field.Name).Interface()
|
||||
}
|
||||
|
||||
func (formatter *Formatter) formatModel(ctx context.Context, refValue reflect.Value, schemas []*types.Schema, stmt *gorm.Statement, format string) any {
|
||||
values := make(map[string]any)
|
||||
multiValues := make(map[string]multiValue)
|
||||
modelValue := refValue.Interface()
|
||||
refValue = reflect.Indirect(refValue)
|
||||
for _, scm := range schemas {
|
||||
switch format {
|
||||
case types.FormatRaw:
|
||||
values[scm.Column] = formatter.getModelValue(refValue, scm, stmt)
|
||||
case types.FormatBoth:
|
||||
v := multiValue{
|
||||
Value: formatter.getModelValue(refValue, scm, stmt),
|
||||
}
|
||||
v.Text = formatter.Format(ctx, scm.Format, v.Value, modelValue, scm)
|
||||
multiValues[scm.Column] = v
|
||||
default:
|
||||
values[scm.Column] = formatter.Format(ctx, scm.Format, formatter.getModelValue(refValue, scm, stmt), modelValue, scm)
|
||||
}
|
||||
}
|
||||
if format == types.FormatBoth {
|
||||
return multiValues
|
||||
} else {
|
||||
return values
|
||||
}
|
||||
}
|
||||
|
||||
func (formatter *Formatter) formatModels(ctx context.Context, val any, schemas []*types.Schema, stmt *gorm.Statement, format string) any {
|
||||
refValue := reflect.Indirect(reflect.ValueOf(val))
|
||||
if refValue.Kind() != reflect.Slice {
|
||||
return []any{}
|
||||
}
|
||||
length := refValue.Len()
|
||||
values := make([]any, length)
|
||||
for i := 0; i < length; i++ {
|
||||
rowValue := refValue.Index(i)
|
||||
modelValue := rowValue.Interface()
|
||||
if formatModel, ok := modelValue.(types.FormatModel); ok {
|
||||
formatModel.Format(ctx, format, schemas)
|
||||
}
|
||||
values[i] = formatter.formatModel(ctx, rowValue, schemas, stmt, format)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func stringFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
return fmt.Sprint(value)
|
||||
}
|
||||
|
||||
func integerFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
var (
|
||||
n int
|
||||
)
|
||||
switch value.(type) {
|
||||
case float32, float64:
|
||||
n = int(reflect.ValueOf(value).Float())
|
||||
case int, int8, int16, int32, int64:
|
||||
n = int(reflect.ValueOf(value).Int())
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
n = int(reflect.ValueOf(value).Uint())
|
||||
case string:
|
||||
n, _ = strconv.Atoi(reflect.ValueOf(value).String())
|
||||
case []byte:
|
||||
n, _ = strconv.Atoi(string(reflect.ValueOf(value).Bytes()))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func decimalFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
var (
|
||||
n float64
|
||||
)
|
||||
switch value.(type) {
|
||||
case float32, float64:
|
||||
n = reflect.ValueOf(value).Float()
|
||||
case int, int8, int16, int32, int64:
|
||||
n = float64(reflect.ValueOf(value).Int())
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
n = float64(reflect.ValueOf(value).Uint())
|
||||
case string:
|
||||
n, _ = strconv.ParseFloat(reflect.ValueOf(value).String(), 64)
|
||||
case []byte:
|
||||
n, _ = strconv.ParseFloat(string(reflect.ValueOf(value).Bytes()), 64)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func dateFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
if t, ok := value.(time.Time); ok {
|
||||
return t.Format("2006-01-02")
|
||||
}
|
||||
if t, ok := value.(*sql.NullTime); ok {
|
||||
if t != nil && t.Valid {
|
||||
return t.Time.Format("2006-01-02")
|
||||
}
|
||||
}
|
||||
if t, ok := value.(int64); ok {
|
||||
tm := time.Unix(t, 0)
|
||||
return tm.Format("2006-01-02")
|
||||
}
|
||||
return DefaultNullDisplay
|
||||
}
|
||||
|
||||
func timeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
if t, ok := value.(time.Time); ok {
|
||||
return t.Format("15:04:05")
|
||||
}
|
||||
if t, ok := value.(*sql.NullTime); ok {
|
||||
if t != nil && t.Valid {
|
||||
return t.Time.Format("15:04:05")
|
||||
}
|
||||
}
|
||||
if t, ok := value.(int64); ok {
|
||||
tm := time.Unix(t, 0)
|
||||
return tm.Format("15:04:05")
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func datetimeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
if t, ok := value.(time.Time); ok {
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if t, ok := value.(*sql.NullTime); ok {
|
||||
if t != nil && t.Valid {
|
||||
return t.Time.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
}
|
||||
if t, ok := value.(int64); ok {
|
||||
if t > 0 {
|
||||
tm := time.Unix(t, 0)
|
||||
return tm.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
}
|
||||
return DefaultNullDisplay
|
||||
}
|
||||
|
||||
func percentageFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
n := decimalFormat(ctx, value, model, schema).(float64)
|
||||
if n <= 1 {
|
||||
return fmt.Sprintf("%.2f%%", n*100)
|
||||
} else {
|
||||
return fmt.Sprintf("%.2f%%", n)
|
||||
}
|
||||
}
|
||||
|
||||
func durationFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
var (
|
||||
hour int
|
||||
minVal int
|
||||
sec int
|
||||
)
|
||||
n := integerFormat(ctx, value, model, schema).(int)
|
||||
hour = n / 3600
|
||||
minVal = (n - hour*3600) / 60
|
||||
sec = n - hour*3600 - minVal*60
|
||||
return fmt.Sprintf("%02d:%02d:%02d", hour, minVal, sec)
|
||||
}
|
||||
|
||||
func dropdownFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||
attributes := schema.Attribute
|
||||
if attributes.Values != nil {
|
||||
for _, v := range attributes.Values {
|
||||
if v.Value == value {
|
||||
return v.Label
|
||||
}
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func NewFormatter() *Formatter {
|
||||
formatter := &Formatter{}
|
||||
return formatter
|
||||
}
|
||||
|
||||
func RegisterFormat(f string, cb FormatFunc) {
|
||||
DefaultFormatter.Register(f, cb)
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
module git.nobla.cn/golang/rest
|
||||
|
||||
go 1.22.9
|
||||
|
||||
require (
|
||||
git.nobla.cn/golang/kos v0.1.32
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/go-playground/validator/v10 v10.23.0
|
||||
github.com/longbridgeapp/sqlparser v0.3.2
|
||||
github.com/rs/xid v1.6.0
|
||||
github.com/uole/sqlparser v0.0.1
|
||||
gorm.io/gorm v1.25.12
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
golang.org/x/crypto v0.19.0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sys v0.17.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
)
|
|
@ -0,0 +1,46 @@
|
|||
git.nobla.cn/golang/kos v0.1.32 h1:sFVCA7vKc8dPUd0cxzwExOSPX2mmMh2IuwL6cYS1pBc=
|
||||
git.nobla.cn/golang/kos v0.1.32/go.mod h1:35Z070+5oB39WcVrh5DDlnVeftL/Ccmscw2MZFe9fUg=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL8sThn8IHr/sO+o=
|
||||
github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M=
|
||||
github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/longbridgeapp/sqlparser v0.3.2 h1:FV0dgMiv8VcksT3p10hJeqfPs8bodoehmUJ7MhBds+Y=
|
||||
github.com/longbridgeapp/sqlparser v0.3.2/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/uole/sqlparser v0.0.1 h1:LLUklg6Ne5MypXQuo53QcJv/xKdxtEKM9iUuEBN/lt8=
|
||||
github.com/uole/sqlparser v0.0.1/go.mod h1:CRYFz2PTm9oHM0j9GFKi1VzPy70r6GsF0b9vpnwJ4yI=
|
||||
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
|
@ -0,0 +1,251 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"git.nobla.cn/golang/rest/types"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
beforeCreate = "beforeCreate"
|
||||
afterCreate = "afterCreate"
|
||||
beforeUpdate = "beforeUpdate"
|
||||
afterUpdate = "afterUpdate"
|
||||
beforeSave = "beforeSave"
|
||||
afterSave = "afterSave"
|
||||
beforeDelete = "beforeDelete"
|
||||
afterDelete = "afterDelete"
|
||||
afterExport = "afterExport"
|
||||
afterImport = "afterImport"
|
||||
)
|
||||
|
||||
type (
|
||||
BeforeCreate func(ctx context.Context, tx *gorm.DB, model any) (err error)
|
||||
AfterCreate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
|
||||
BeforeUpdate func(ctx context.Context, tx *gorm.DB, model any) (err error)
|
||||
AfterUpdate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
|
||||
BeforeSave func(ctx context.Context, tx *gorm.DB, model any) (err error)
|
||||
AfterSave func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
|
||||
BeforeDelete func(ctx context.Context, tx *gorm.DB, model any) (err error)
|
||||
AfterDelete func(ctx context.Context, tx *gorm.DB, model any)
|
||||
AfterExport func(ctx context.Context, filename string) //导出回调
|
||||
AfterImport func(ctx context.Context, result *types.ImportResult) //导入回调
|
||||
hookManager struct {
|
||||
callbacks map[string][]any
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
ActiveQuery interface {
|
||||
BeforeQuery(ctx context.Context, query *Query) (err error)
|
||||
AfterQuery(ctx context.Context, query *Query) (err error)
|
||||
}
|
||||
)
|
||||
|
||||
func (hook *hookManager) register(spec string, cb any) {
|
||||
if hook.callbacks == nil {
|
||||
hook.callbacks = make(map[string][]any)
|
||||
}
|
||||
if _, ok := hook.callbacks[spec]; !ok {
|
||||
hook.callbacks[spec] = make([]any, 0)
|
||||
}
|
||||
hook.callbacks[spec] = append(hook.callbacks[spec], cb)
|
||||
}
|
||||
|
||||
func (hook *hookManager) beforeCreate(ctx context.Context, tx *gorm.DB, model any) (err error) {
|
||||
callbacks, ok := hook.callbacks[beforeCreate]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(BeforeCreate); ok {
|
||||
if err = cb(ctx, tx, model); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) afterCreate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
|
||||
callbacks, ok := hook.callbacks[afterCreate]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(AfterCreate); ok {
|
||||
cb(ctx, tx, model, diff)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) beforeUpdate(ctx context.Context, tx *gorm.DB, model any) (err error) {
|
||||
callbacks, ok := hook.callbacks[beforeUpdate]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(BeforeUpdate); ok {
|
||||
if err = cb(ctx, tx, model); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) afterUpdate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
|
||||
callbacks, ok := hook.callbacks[afterUpdate]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(AfterUpdate); ok {
|
||||
cb(ctx, tx, model, diff)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) beforeSave(ctx context.Context, tx *gorm.DB, model any) (err error) {
|
||||
callbacks, ok := hook.callbacks[beforeSave]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(BeforeSave); ok {
|
||||
if err = cb(ctx, tx, model); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) afterSave(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
|
||||
callbacks, ok := hook.callbacks[afterSave]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(AfterSave); ok {
|
||||
cb(ctx, tx, model, diff)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) beforeDelete(ctx context.Context, tx *gorm.DB, model any) (err error) {
|
||||
callbacks, ok := hook.callbacks[beforeDelete]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(BeforeDelete); ok {
|
||||
if err = cb(ctx, tx, model); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) afterDelete(ctx context.Context, tx *gorm.DB, model any) {
|
||||
callbacks, ok := hook.callbacks[afterDelete]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(AfterDelete); ok {
|
||||
cb(ctx, tx, model)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) afterExport(ctx context.Context, filename string) {
|
||||
callbacks, ok := hook.callbacks[afterExport]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(AfterExport); ok {
|
||||
cb(ctx, filename)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) afterImport(ctx context.Context, ret *types.ImportResult) {
|
||||
callbacks, ok := hook.callbacks[afterImport]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, callback := range callbacks {
|
||||
if cb, ok := callback.(AfterImport); ok {
|
||||
cb(ctx, ret)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hook *hookManager) BeforeCreate(cb BeforeCreate) {
|
||||
if cb != nil {
|
||||
hook.register(beforeCreate, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) AfterCreate(cb AfterCreate) {
|
||||
if cb != nil {
|
||||
hook.register(afterCreate, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) BeforeUpdate(cb BeforeUpdate) {
|
||||
if cb != nil {
|
||||
hook.register(beforeUpdate, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) AfterUpdate(cb AfterUpdate) {
|
||||
if cb != nil {
|
||||
hook.register(afterUpdate, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) BeforeSave(cb BeforeSave) {
|
||||
if cb != nil {
|
||||
hook.register(beforeSave, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) AfterSave(cb AfterSave) {
|
||||
if cb != nil {
|
||||
hook.register(afterSave, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) BeforeDelete(cb BeforeDelete) {
|
||||
if cb != nil {
|
||||
hook.register(beforeDelete, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) AfterDelete(cb AfterDelete) {
|
||||
if cb != nil {
|
||||
hook.register(afterDelete, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) AfterExport(cb AfterExport) {
|
||||
if cb != nil {
|
||||
hook.register(afterExport, cb)
|
||||
}
|
||||
}
|
||||
|
||||
func (hook *hookManager) AfterImport(cb AfterImport) {
|
||||
if cb != nil {
|
||||
hook.register(afterImport, cb)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,408 @@
|
|||
package inflector
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Rule represents name of the inflector rule, can be
|
||||
// Plural or Singular
|
||||
type Rule int
|
||||
|
||||
const (
|
||||
Plural = iota
|
||||
Singular
|
||||
)
|
||||
|
||||
// InflectorRule represents inflector rule
|
||||
type InflectorRule struct {
|
||||
Rules []*ruleItem
|
||||
Irregular []*irregularItem
|
||||
Uninflected []string
|
||||
compiledIrregular *regexp.Regexp
|
||||
compiledUninflected *regexp.Regexp
|
||||
compiledRules []*compiledRule
|
||||
}
|
||||
|
||||
type ruleItem struct {
|
||||
pattern string
|
||||
replacement string
|
||||
}
|
||||
|
||||
type irregularItem struct {
|
||||
word string
|
||||
replacement string
|
||||
}
|
||||
|
||||
// compiledRule represents compiled version of Inflector.Rules.
|
||||
type compiledRule struct {
|
||||
replacement string
|
||||
*regexp.Regexp
|
||||
}
|
||||
|
||||
// threadsafe access to rules and caches
|
||||
var mutex sync.Mutex
|
||||
var rules = make(map[Rule]*InflectorRule)
|
||||
|
||||
// Words that should not be inflected
|
||||
var uninflected = []string{
|
||||
`Amoyese`, `bison`, `Borghese`, `bream`, `breeches`, `britches`, `buffalo`,
|
||||
`cantus`, `carp`, `chassis`, `clippers`, `cod`, `coitus`, `Congoese`,
|
||||
`contretemps`, `corps`, `debris`, `diabetes`, `djinn`, `eland`, `elk`,
|
||||
`equipment`, `Faroese`, `flounder`, `Foochowese`, `gallows`, `Genevese`,
|
||||
`Genoese`, `Gilbertese`, `graffiti`, `headquarters`, `herpes`, `hijinks`,
|
||||
`Hottentotese`, `information`, `innings`, `jackanapes`, `Kiplingese`,
|
||||
`Kongoese`, `Lucchese`, `mackerel`, `Maltese`, `.*?media`, `mews`, `moose`,
|
||||
`mumps`, `Nankingese`, `news`, `nexus`, `Niasese`, `Pekingese`,
|
||||
`Piedmontese`, `pincers`, `Pistoiese`, `pliers`, `Portuguese`, `proceedings`,
|
||||
`rabies`, `rice`, `rhinoceros`, `salmon`, `Sarawakese`, `scissors`,
|
||||
`sea[- ]bass`, `series`, `Shavese`, `shears`, `siemens`, `species`, `swine`,
|
||||
`testes`, `trousers`, `trout`, `tuna`, `Vermontese`, `Wenchowese`, `whiting`,
|
||||
`wildebeest`, `Yengeese`,
|
||||
}
|
||||
|
||||
// Plural words that should not be inflected
|
||||
var uninflectedPlurals = []string{
|
||||
`.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`,
|
||||
`people`,
|
||||
}
|
||||
|
||||
// Singular words that should not be inflected
|
||||
var uninflectedSingulars = []string{
|
||||
`.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`,
|
||||
`.*ss`,
|
||||
}
|
||||
|
||||
type cache map[string]string
|
||||
|
||||
// Inflected words that already cached for immediate retrieval from a given Rule
|
||||
var caches = make(map[Rule]cache)
|
||||
|
||||
// map of irregular words where its key is a word and its value is the replacement
|
||||
var irregularMaps = make(map[Rule]cache)
|
||||
|
||||
var (
|
||||
// https://github.com/golang/lint/blob/master/lint.go#L770
|
||||
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||
commonInitialismsReplacer *strings.Replacer
|
||||
)
|
||||
|
||||
func init() {
|
||||
rules[Plural] = &InflectorRule{
|
||||
Rules: []*ruleItem{
|
||||
{`(?i)(s)tatus$`, `${1}${2}tatuses`},
|
||||
{`(?i)(quiz)$`, `${1}zes`},
|
||||
{`(?i)^(ox)$`, `${1}${2}en`},
|
||||
{`(?i)([m|l])ouse$`, `${1}ice`},
|
||||
{`(?i)(matr|vert|ind)(ix|ex)$`, `${1}ices`},
|
||||
{`(?i)(x|ch|ss|sh)$`, `${1}es`},
|
||||
{`(?i)([^aeiouy]|qu)y$`, `${1}ies`},
|
||||
{`(?i)(hive)$`, `$1s`},
|
||||
{`(?i)(?:([^f])fe|([lre])f)$`, `${1}${2}ves`},
|
||||
{`(?i)sis$`, `ses`},
|
||||
{`(?i)([ti])um$`, `${1}a`},
|
||||
{`(?i)(p)erson$`, `${1}eople`},
|
||||
{`(?i)(m)an$`, `${1}en`},
|
||||
{`(?i)(c)hild$`, `${1}hildren`},
|
||||
{`(?i)(buffal|tomat)o$`, `${1}${2}oes`},
|
||||
{`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|vir)us$`, `${1}i`},
|
||||
{`(?i)us$`, `uses`},
|
||||
{`(?i)(alias)$`, `${1}es`},
|
||||
{`(?i)(ax|cris|test)is$`, `${1}es`},
|
||||
{`s$`, `s`},
|
||||
{`^$`, ``},
|
||||
{`$`, `s`},
|
||||
},
|
||||
Irregular: []*irregularItem{
|
||||
{`atlas`, `atlases`},
|
||||
{`beef`, `beefs`},
|
||||
{`brother`, `brothers`},
|
||||
{`cafe`, `cafes`},
|
||||
{`child`, `children`},
|
||||
{`cookie`, `cookies`},
|
||||
{`corpus`, `corpuses`},
|
||||
{`cow`, `cows`},
|
||||
{`ganglion`, `ganglions`},
|
||||
{`genie`, `genies`},
|
||||
{`genus`, `genera`},
|
||||
{`graffito`, `graffiti`},
|
||||
{`hoof`, `hoofs`},
|
||||
{`loaf`, `loaves`},
|
||||
{`man`, `men`},
|
||||
{`money`, `monies`},
|
||||
{`mongoose`, `mongooses`},
|
||||
{`move`, `moves`},
|
||||
{`mythos`, `mythoi`},
|
||||
{`niche`, `niches`},
|
||||
{`numen`, `numina`},
|
||||
{`occiput`, `occiputs`},
|
||||
{`octopus`, `octopuses`},
|
||||
{`opus`, `opuses`},
|
||||
{`ox`, `oxen`},
|
||||
{`penis`, `penises`},
|
||||
{`person`, `people`},
|
||||
{`sex`, `sexes`},
|
||||
{`soliloquy`, `soliloquies`},
|
||||
{`testis`, `testes`},
|
||||
{`trilby`, `trilbys`},
|
||||
{`turf`, `turfs`},
|
||||
{`potato`, `potatoes`},
|
||||
{`hero`, `heroes`},
|
||||
{`tooth`, `teeth`},
|
||||
{`goose`, `geese`},
|
||||
{`foot`, `feet`},
|
||||
},
|
||||
}
|
||||
prepare(Plural)
|
||||
|
||||
rules[Singular] = &InflectorRule{
|
||||
Rules: []*ruleItem{
|
||||
{`(?i)(s)tatuses$`, `${1}${2}tatus`},
|
||||
{`(?i)^(.*)(menu)s$`, `${1}${2}`},
|
||||
{`(?i)(quiz)zes$`, `$1`},
|
||||
{`(?i)(matr)ices$`, `${1}ix`},
|
||||
{`(?i)(vert|ind)ices$`, `${1}ex`},
|
||||
{`(?i)^(ox)en`, `$1`},
|
||||
{`(?i)(alias)(es)*$`, `$1`},
|
||||
{`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|viri?)i$`, `${1}us`},
|
||||
{`(?i)([ftw]ax)es`, `$1`},
|
||||
{`(?i)(cris|ax|test)es$`, `${1}is`},
|
||||
{`(?i)(shoe|slave)s$`, `$1`},
|
||||
{`(?i)(o)es$`, `$1`},
|
||||
{`ouses$`, `ouse`},
|
||||
{`([^a])uses$`, `${1}us`},
|
||||
{`(?i)([m|l])ice$`, `${1}ouse`},
|
||||
{`(?i)(x|ch|ss|sh)es$`, `$1`},
|
||||
{`(?i)(m)ovies$`, `${1}${2}ovie`},
|
||||
{`(?i)(s)eries$`, `${1}${2}eries`},
|
||||
{`(?i)([^aeiouy]|qu)ies$`, `${1}y`},
|
||||
{`(?i)(tive)s$`, `$1`},
|
||||
{`(?i)([lre])ves$`, `${1}f`},
|
||||
{`(?i)([^fo])ves$`, `${1}fe`},
|
||||
{`(?i)(hive)s$`, `$1`},
|
||||
{`(?i)(drive)s$`, `$1`},
|
||||
{`(?i)(^analy)ses$`, `${1}sis`},
|
||||
{`(?i)(analy|diagno|^ba|(p)arenthe|(p)rogno|(s)ynop|(t)he)ses$`, `${1}${2}sis`},
|
||||
{`(?i)([ti])a$`, `${1}um`},
|
||||
{`(?i)(p)eople$`, `${1}${2}erson`},
|
||||
{`(?i)(m)en$`, `${1}an`},
|
||||
{`(?i)(c)hildren$`, `${1}${2}hild`},
|
||||
{`(?i)(n)ews$`, `${1}${2}ews`},
|
||||
{`eaus$`, `eau`},
|
||||
{`^(.*us)$`, `$1`},
|
||||
{`(?i)s$`, ``},
|
||||
},
|
||||
Irregular: []*irregularItem{
|
||||
{`foes`, `foe`},
|
||||
{`waves`, `wave`},
|
||||
{`curves`, `curve`},
|
||||
{`atlases`, `atlas`},
|
||||
{`beefs`, `beef`},
|
||||
{`brothers`, `brother`},
|
||||
{`cafes`, `cafe`},
|
||||
{`children`, `child`},
|
||||
{`cookies`, `cookie`},
|
||||
{`corpuses`, `corpus`},
|
||||
{`cows`, `cow`},
|
||||
{`ganglions`, `ganglion`},
|
||||
{`genies`, `genie`},
|
||||
{`genera`, `genus`},
|
||||
{`graffiti`, `graffito`},
|
||||
{`hoofs`, `hoof`},
|
||||
{`loaves`, `loaf`},
|
||||
{`men`, `man`},
|
||||
{`monies`, `money`},
|
||||
{`mongooses`, `mongoose`},
|
||||
{`moves`, `move`},
|
||||
{`mythoi`, `mythos`},
|
||||
{`niches`, `niche`},
|
||||
{`numina`, `numen`},
|
||||
{`occiputs`, `occiput`},
|
||||
{`octopuses`, `octopus`},
|
||||
{`opuses`, `opus`},
|
||||
{`oxen`, `ox`},
|
||||
{`penises`, `penis`},
|
||||
{`people`, `person`},
|
||||
{`sexes`, `sex`},
|
||||
{`soliloquies`, `soliloquy`},
|
||||
{`testes`, `testis`},
|
||||
{`trilbys`, `trilby`},
|
||||
{`turfs`, `turf`},
|
||||
{`potatoes`, `potato`},
|
||||
{`heroes`, `hero`},
|
||||
{`teeth`, `tooth`},
|
||||
{`geese`, `goose`},
|
||||
{`feet`, `foot`},
|
||||
},
|
||||
}
|
||||
prepare(Singular)
|
||||
|
||||
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
|
||||
for _, initialism := range commonInitialisms {
|
||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||
}
|
||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||
}
|
||||
|
||||
// prepare rule, e.g., compile the pattern.
|
||||
func prepare(r Rule) error {
|
||||
var reString string
|
||||
|
||||
switch r {
|
||||
case Plural:
|
||||
// Merge global uninflected with singularsUninflected
|
||||
rules[r].Uninflected = merge(uninflected, uninflectedPlurals)
|
||||
case Singular:
|
||||
// Merge global uninflected with singularsUninflected
|
||||
rules[r].Uninflected = merge(uninflected, uninflectedSingulars)
|
||||
}
|
||||
|
||||
// Set InflectorRule.compiledUninflected by joining InflectorRule.Uninflected into
|
||||
// a single string then compile it.
|
||||
reString = fmt.Sprintf(`(?i)(^(?:%s))$`, strings.Join(rules[r].Uninflected, `|`))
|
||||
rules[r].compiledUninflected = regexp.MustCompile(reString)
|
||||
|
||||
// Prepare irregularMaps
|
||||
irregularMaps[r] = make(cache, len(rules[r].Irregular))
|
||||
|
||||
// Set InflectorRule.compiledIrregular by joining the irregularItem.word of Inflector.Irregular
|
||||
// into a single string then compile it.
|
||||
vIrregulars := make([]string, len(rules[r].Irregular))
|
||||
for i, item := range rules[r].Irregular {
|
||||
vIrregulars[i] = item.word
|
||||
irregularMaps[r][item.word] = item.replacement
|
||||
}
|
||||
reString = fmt.Sprintf(`(?i)(.*)\b((?:%s))$`, strings.Join(vIrregulars, `|`))
|
||||
rules[r].compiledIrregular = regexp.MustCompile(reString)
|
||||
|
||||
// Compile all patterns in InflectorRule.Rules
|
||||
rules[r].compiledRules = make([]*compiledRule, len(rules[r].Rules))
|
||||
for i, item := range rules[r].Rules {
|
||||
rules[r].compiledRules[i] = &compiledRule{item.replacement, regexp.MustCompile(item.pattern)}
|
||||
}
|
||||
|
||||
// Prepare caches
|
||||
caches[r] = make(cache)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// merge slice a and slice b
|
||||
func merge(a []string, b []string) []string {
|
||||
result := make([]string, len(a)+len(b))
|
||||
copy(result, a)
|
||||
copy(result[len(a):], b)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func getInflected(r Rule, s string) string {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
if v, ok := caches[r][s]; ok {
|
||||
return v
|
||||
}
|
||||
|
||||
// Check for irregular words
|
||||
if res := rules[r].compiledIrregular.FindStringSubmatch(s); len(res) >= 3 {
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.WriteString(res[1])
|
||||
buf.WriteString(s[0:1])
|
||||
buf.WriteString(irregularMaps[r][strings.ToLower(res[2])][1:])
|
||||
|
||||
// Cache it then returns
|
||||
caches[r][s] = buf.String()
|
||||
return caches[r][s]
|
||||
}
|
||||
|
||||
// Check for uninflected words
|
||||
if rules[r].compiledUninflected.MatchString(s) {
|
||||
caches[r][s] = s
|
||||
return caches[r][s]
|
||||
}
|
||||
|
||||
// Check each rule
|
||||
for _, re := range rules[r].compiledRules {
|
||||
if re.MatchString(s) {
|
||||
caches[r][s] = re.ReplaceAllString(s, re.replacement)
|
||||
return caches[r][s]
|
||||
}
|
||||
}
|
||||
|
||||
// Returns unaltered
|
||||
caches[r][s] = s
|
||||
return caches[r][s]
|
||||
}
|
||||
|
||||
// Pluralize returns string s in plural form.
|
||||
func Pluralize(s string) string {
|
||||
return getInflected(Plural, s)
|
||||
}
|
||||
|
||||
// Singularize returns string s in singular form.
|
||||
func Singularize(s string) string {
|
||||
return getInflected(Singular, s)
|
||||
}
|
||||
|
||||
var (
|
||||
camelizeReg = regexp.MustCompile(`[^A-Za-z0-9]+`)
|
||||
)
|
||||
|
||||
// Camelize Converts a word like "send_email" to "SendEmail"
|
||||
func Camelize(s string) string {
|
||||
s = camelizeReg.ReplaceAllString(s, " ")
|
||||
return strings.Replace(strings.Title(s), " ", "", -1)
|
||||
}
|
||||
|
||||
// Camel2id Converts a word like "SendEmail" to "send_email"
|
||||
func Camel2id(name string) string {
|
||||
var (
|
||||
value = commonInitialismsReplacer.Replace(name)
|
||||
buf strings.Builder
|
||||
lastCase, nextCase, nextNumber bool // upper case == true
|
||||
curCase = value[0] <= 'Z' && value[0] >= 'A'
|
||||
)
|
||||
|
||||
for i, v := range value[:len(value)-1] {
|
||||
nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
|
||||
nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
|
||||
|
||||
if curCase {
|
||||
if lastCase && (nextCase || nextNumber) {
|
||||
buf.WriteRune(v + 32)
|
||||
} else {
|
||||
if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
|
||||
buf.WriteByte('_')
|
||||
}
|
||||
buf.WriteRune(v + 32)
|
||||
}
|
||||
} else {
|
||||
buf.WriteRune(v)
|
||||
}
|
||||
|
||||
lastCase = curCase
|
||||
curCase = nextCase
|
||||
}
|
||||
|
||||
if curCase {
|
||||
if !lastCase && len(value) > 1 {
|
||||
buf.WriteByte('_')
|
||||
}
|
||||
buf.WriteByte(value[len(value)-1] + 32)
|
||||
} else {
|
||||
buf.WriteByte(value[len(value)-1])
|
||||
}
|
||||
ret := buf.String()
|
||||
return ret
|
||||
}
|
||||
|
||||
// Camel2words Converts a CamelCase name into space-separated words.
|
||||
// For example, 'send_email' will be converted to 'Send Email'.
|
||||
func Camel2words(s string) string {
|
||||
s = camelizeReg.ReplaceAllString(s, " ")
|
||||
return strings.Title(s)
|
||||
}
|
|
@ -0,0 +1,987 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"git.nobla.cn/golang/kos/util/arrays"
|
||||
"git.nobla.cn/golang/kos/util/pool"
|
||||
"git.nobla.cn/golang/rest/inflector"
|
||||
"git.nobla.cn/golang/rest/types"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
naming types.Naming //命名规则
|
||||
value reflect.Value //模块值
|
||||
db *gorm.DB //数据库
|
||||
primaryKey string //主键
|
||||
urlPrefix string //url前缀
|
||||
disableDomain bool //禁用域
|
||||
schemaLookup types.SchemaLookupFunc //获取schema的函数
|
||||
valueLookup types.ValueLookupFunc //查看域
|
||||
statement *gorm.Statement //字段声明
|
||||
formatter *Formatter //格式化
|
||||
response types.HttpWriter //HTTP响应
|
||||
hookMgr *hookManager //钩子管理器
|
||||
dirname string //存放文件目录
|
||||
}
|
||||
|
||||
var (
|
||||
RuntimeScopeKey = &types.RuntimeScope{}
|
||||
)
|
||||
|
||||
// getDB 获取数据库连接对象
|
||||
func (m *Model) getDB() *gorm.DB {
|
||||
return m.db
|
||||
}
|
||||
|
||||
// getFormatter 获取格式化组件
|
||||
func (m *Model) getFormatter() *Formatter {
|
||||
if m.formatter != nil {
|
||||
return m.formatter
|
||||
}
|
||||
return DefaultFormatter
|
||||
}
|
||||
|
||||
// getHook 获取钩子
|
||||
func (m *Model) getHook() *hookManager {
|
||||
return m.hookMgr
|
||||
}
|
||||
|
||||
// hasScenario 判断是否有该场景
|
||||
func (m *Model) hasScenario(s string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// setValue 设置字段的值
|
||||
func (m *Model) setValue(refValue reflect.Value, column string, value any) {
|
||||
SetFieldValue(m.statement, refValue, column, value)
|
||||
}
|
||||
|
||||
func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) {
|
||||
SafeSetFileValue(m.statement, refValue, column, value)
|
||||
}
|
||||
|
||||
// getValue 获取字段的值
|
||||
func (m *Model) getValue(refValue reflect.Value, column string) interface{} {
|
||||
return GetFieldValue(m.statement, refValue, column)
|
||||
}
|
||||
|
||||
// hasColumn 判断指定的列是否存在
|
||||
func (m *Model) hasColumn(column string) bool {
|
||||
for _, field := range m.statement.Schema.Fields {
|
||||
if field.DBName == column || field.Name == column {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getFilename 获取文件存放目录
|
||||
func (m *Model) getFilename(domain string, spec string, name string) string {
|
||||
if m.dirname == "" {
|
||||
m.dirname = os.TempDir()
|
||||
}
|
||||
filename := path.Join(m.dirname, domain, spec, time.Now().Format("20060102"), name)
|
||||
if _, err := os.Stat(path.Dir(filename)); err != nil {
|
||||
_ = os.MkdirAll(path.Dir(filename), 0755)
|
||||
}
|
||||
return filename
|
||||
}
|
||||
|
||||
// findPrimaryKey 查找主键的值
|
||||
func (m *Model) findPrimaryKey(uri string, r *http.Request) string {
|
||||
var (
|
||||
pos int
|
||||
)
|
||||
urlPath := r.URL.Path
|
||||
pos = strings.IndexByte(uri, ':')
|
||||
if pos > 0 {
|
||||
return urlPath[pos:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseReportColumn 解析报表的列
|
||||
func (m *Model) parseReportColumn(name, props string) *types.SelectColumn {
|
||||
var (
|
||||
key string
|
||||
value string
|
||||
)
|
||||
column := &types.SelectColumn{
|
||||
Name: inflector.Camel2id(name),
|
||||
Native: false,
|
||||
}
|
||||
tokens := strings.Split(props, ";")
|
||||
for _, token := range tokens {
|
||||
pair := strings.SplitN(token, ":", 2)
|
||||
if len(pair) == 0 {
|
||||
continue
|
||||
}
|
||||
if len(pair) == 1 {
|
||||
key = strings.TrimSpace(pair[0])
|
||||
value = ""
|
||||
} else {
|
||||
key = strings.TrimSpace(pair[0])
|
||||
value = strings.TrimSpace(pair[1])
|
||||
}
|
||||
switch key {
|
||||
case "native":
|
||||
column.Native = true
|
||||
case "name":
|
||||
column.Name = value
|
||||
case "expr":
|
||||
column.Expr = value
|
||||
}
|
||||
}
|
||||
return column
|
||||
}
|
||||
|
||||
func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
columns := make([]string, 0)
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
scenarios := field.Tag.Get("scenarios")
|
||||
if !hasToken(types.ScenarioList, scenarios) {
|
||||
continue
|
||||
}
|
||||
isPrimary := field.Tag.Get("is_primary")
|
||||
if isPrimary != "true" {
|
||||
continue
|
||||
}
|
||||
column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
|
||||
if !column.Native {
|
||||
continue
|
||||
}
|
||||
if column.Expr == "" {
|
||||
columns = append(columns, dest.QuoteColumn(ctx, column.Name))
|
||||
} else {
|
||||
columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
|
||||
}
|
||||
}
|
||||
columns = append(columns, "COUNT(*) AS count")
|
||||
query.Select(columns...)
|
||||
}
|
||||
|
||||
func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
columns := make([]string, 0)
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
scenarios := field.Tag.Get("scenarios")
|
||||
if !hasToken(types.ScenarioList, scenarios) {
|
||||
continue
|
||||
}
|
||||
column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
|
||||
if !column.Native {
|
||||
continue
|
||||
}
|
||||
if column.Expr == "" {
|
||||
columns = append(columns, dest.QuoteColumn(ctx, column.Name))
|
||||
} else {
|
||||
columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
|
||||
}
|
||||
}
|
||||
query.Select(columns...)
|
||||
}
|
||||
|
||||
// buildCondition 构建sql条件
|
||||
func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) {
|
||||
return BuildConditions(ctx, r, query, schemas)
|
||||
}
|
||||
|
||||
// ModuleName 模块名称
|
||||
func (m *Model) ModuleName() string {
|
||||
return m.naming.ModuleName
|
||||
}
|
||||
|
||||
// TableName 表的名称
|
||||
func (m *Model) TableName() string {
|
||||
return m.naming.ModuleName
|
||||
}
|
||||
|
||||
// Fields 返回搜索的模型的字段
|
||||
func (m *Model) Fields() []*schema.Field {
|
||||
return m.statement.Schema.Fields
|
||||
}
|
||||
|
||||
// Uri 获取请求的uri
|
||||
func (m *Model) Uri(scenario string) string {
|
||||
ss := make([]string, 4)
|
||||
if m.urlPrefix != "" {
|
||||
ss = append(ss, m.urlPrefix)
|
||||
}
|
||||
switch scenario {
|
||||
case types.ScenarioList:
|
||||
ss = append(ss, m.naming.ModuleName, m.naming.Pluralize)
|
||||
case types.ScenarioView:
|
||||
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||||
case types.ScenarioCreate:
|
||||
ss = append(ss, m.naming.ModuleName, m.naming.Singular)
|
||||
case types.ScenarioUpdate:
|
||||
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||||
case types.ScenarioDelete:
|
||||
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||||
case types.ScenarioExport:
|
||||
ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export")
|
||||
case types.ScenarioImport:
|
||||
ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import")
|
||||
}
|
||||
uri := path.Join(ss...)
|
||||
if !strings.HasPrefix(uri, "/") {
|
||||
uri = "/" + uri
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// Method 获取HTTP请求的方法
|
||||
func (m *Model) Method(scenario string) string {
|
||||
var (
|
||||
method = http.MethodGet
|
||||
)
|
||||
switch scenario {
|
||||
case types.ScenarioCreate:
|
||||
method = http.MethodPost
|
||||
case types.ScenarioUpdate:
|
||||
method = http.MethodPut
|
||||
case types.ScenarioDelete:
|
||||
method = http.MethodDelete
|
||||
}
|
||||
return method
|
||||
}
|
||||
|
||||
// Search 实现通过HTTP方法查找数据
|
||||
func (m *Model) Search(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ok bool
|
||||
err error
|
||||
qs url.Values
|
||||
page int
|
||||
pageSize int
|
||||
pageIndex int
|
||||
query *Query
|
||||
domainName string
|
||||
modelSlices reflect.Value
|
||||
modelValues reflect.Value
|
||||
searchSchemas []*types.Schema
|
||||
listSchemas []*types.Schema
|
||||
modelValue reflect.Value
|
||||
scenario string
|
||||
reporter types.Reporter
|
||||
namerTable tableNamer
|
||||
)
|
||||
qs = r.URL.Query()
|
||||
page, _ = strconv.Atoi(qs.Get("page"))
|
||||
pageSize, _ = strconv.Atoi(qs.Get("pagesize"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = defaultPageSize
|
||||
}
|
||||
pageIndex = page
|
||||
if pageIndex > 0 {
|
||||
pageIndex--
|
||||
}
|
||||
modelValue = reflect.New(m.value.Type())
|
||||
//这里创建指针类型,这样的话就能在format里面调用函数
|
||||
if m.value.Kind() != reflect.Ptr {
|
||||
modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
|
||||
} else {
|
||||
modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
|
||||
}
|
||||
modelValues = reflect.New(modelSlices.Type())
|
||||
modelValues.Elem().Set(modelSlices)
|
||||
query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface())
|
||||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||
Domain: domainName,
|
||||
Request: r,
|
||||
User: m.valueLookup("user", w, r),
|
||||
ModuleName: m.naming.ModuleName,
|
||||
TableName: m.naming.TableName,
|
||||
Scenario: types.ScenarioList,
|
||||
})
|
||||
if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
scenario = types.ScenarioList
|
||||
if arrays.Exists(r.FormValue("scenario"), allowScenario) {
|
||||
scenario = r.FormValue("scenario")
|
||||
}
|
||||
if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
if !m.disableDomain {
|
||||
if m.hasColumn(types.FieldDomain) {
|
||||
query.AndWhere(newCondition(types.FieldDomain, domainName))
|
||||
}
|
||||
}
|
||||
if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
|
||||
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
|
||||
return
|
||||
}
|
||||
// 处理表名逻辑
|
||||
if namerTable, ok = query.Model().(tableNamer); ok {
|
||||
query.From(namerTable.HttpTableName(r))
|
||||
}
|
||||
//处理报表逻辑
|
||||
if reporter, ok = modelValue.Interface().(types.Reporter); ok {
|
||||
query.From(reporter.RealTable())
|
||||
}
|
||||
res := &types.ListResponse{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
if reporter == nil {
|
||||
res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
|
||||
} else {
|
||||
//如果是报表的情况,需要手动指定COUNT的雨具逻辑才能生效
|
||||
m.buildReporterCountColumns(childCtx, reporter, query)
|
||||
res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
|
||||
|
||||
//这里需要重置一下选项,不然会出问题
|
||||
query.ResetSelect()
|
||||
query.GroupBy(reporter.GroupBy(childCtx)...)
|
||||
}
|
||||
query.Offset(pageIndex * pageSize).Limit(pageSize)
|
||||
if res.TotalCount > 0 {
|
||||
if reporter != nil {
|
||||
m.buildReporterQueryColumns(childCtx, reporter, query)
|
||||
}
|
||||
if err = query.All(modelValues.Interface()); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
// 不进行格式化输出
|
||||
res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format"))
|
||||
} else {
|
||||
res.Data = make([]string, 0)
|
||||
}
|
||||
m.response.Success(w, res)
|
||||
}
|
||||
|
||||
// Create 实现通过HTTP方法创建模型
|
||||
func (m *Model) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
err error
|
||||
model any
|
||||
schemas []*types.Schema
|
||||
diffAttrs []*types.DiffAttr
|
||||
domainName string
|
||||
modelValue reflect.Value
|
||||
)
|
||||
modelValue = reflect.New(m.value.Type())
|
||||
model = modelValue.Interface()
|
||||
if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil {
|
||||
m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
if !m.disableDomain {
|
||||
if m.hasColumn(types.FieldDomain) {
|
||||
m.setValue(modelValue, types.FieldDomain, domainName)
|
||||
}
|
||||
}
|
||||
diffAttrs = make([]*types.DiffAttr, 0, 10)
|
||||
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||
Domain: domainName,
|
||||
User: m.valueLookup("user", w, r),
|
||||
Request: r,
|
||||
ModuleName: m.naming.ModuleName,
|
||||
TableName: m.naming.TableName,
|
||||
Scenario: types.ScenarioCreate,
|
||||
Schemas: schemas,
|
||||
})
|
||||
dbSess := m.getDB().WithContext(childCtx)
|
||||
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||||
if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil {
|
||||
return
|
||||
}
|
||||
if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
|
||||
return
|
||||
}
|
||||
if tabler, ok := model.(types.Tabler); ok {
|
||||
errTx = tx.Table(tabler.TableName()).Save(model).Error
|
||||
} else {
|
||||
errTx = tx.Save(model).Error
|
||||
}
|
||||
if errTx != nil {
|
||||
return
|
||||
}
|
||||
for _, row := range schemas {
|
||||
diffAttrs = append(diffAttrs, &types.DiffAttr{
|
||||
Column: row.Column,
|
||||
Label: row.Label,
|
||||
OldValue: nil,
|
||||
NewValue: m.getValue(modelValue, row.Column),
|
||||
})
|
||||
}
|
||||
return
|
||||
}); err == nil {
|
||||
res := &types.CreateResponse{
|
||||
ID: m.getValue(modelValue, m.primaryKey),
|
||||
Status: "created",
|
||||
}
|
||||
if creator, ok := model.(afterCreated); ok {
|
||||
creator.AfterCreated(childCtx, dbSess)
|
||||
}
|
||||
if preserver, ok := model.(afterSaved); ok {
|
||||
preserver.AfterSaved(childCtx, dbSess)
|
||||
}
|
||||
m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs)
|
||||
m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
|
||||
m.response.Success(w, res)
|
||||
} else {
|
||||
m.response.Failure(w, types.RequestCreateFailure, err.Error(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update 实现通过HTTP方法更新模型
|
||||
func (m *Model) Update(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
err error
|
||||
model any
|
||||
schemas []*types.Schema
|
||||
diffAttrs []*types.DiffAttr
|
||||
domainName string
|
||||
modelValue reflect.Value
|
||||
oldValues map[string]any
|
||||
)
|
||||
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
|
||||
modelValue = reflect.New(m.value.Type())
|
||||
model = modelValue.Interface()
|
||||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
conditions := map[string]any{
|
||||
m.primaryKey: idStr,
|
||||
}
|
||||
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
oldValues = make(map[string]any)
|
||||
for _, row := range schemas {
|
||||
oldValues[row.Column] = m.getValue(modelValue, row.Column)
|
||||
}
|
||||
if err = json.NewDecoder(r.Body).Decode(model); err != nil {
|
||||
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
|
||||
return
|
||||
}
|
||||
diffAttrs = make([]*types.DiffAttr, 0, 10)
|
||||
updates := make(map[string]any)
|
||||
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||
Domain: domainName,
|
||||
Request: r,
|
||||
User: m.valueLookup("user", w, r),
|
||||
ModuleName: m.naming.ModuleName,
|
||||
TableName: m.naming.TableName,
|
||||
Scenario: types.ScenarioUpdate,
|
||||
Schemas: schemas,
|
||||
PrimaryKeyValue: idStr,
|
||||
})
|
||||
dbSess := m.getDB().WithContext(childCtx)
|
||||
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||||
if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil {
|
||||
return
|
||||
}
|
||||
if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
|
||||
return
|
||||
}
|
||||
for _, row := range schemas {
|
||||
v := m.getValue(modelValue, row.Column)
|
||||
if oldValues[row.Column] != v {
|
||||
updates[row.Column] = v
|
||||
diffAttrs = append(diffAttrs, &types.DiffAttr{
|
||||
Column: row.Column,
|
||||
Label: row.Label,
|
||||
OldValue: oldValues[row.Column],
|
||||
NewValue: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(updates) > 0 {
|
||||
if tabler, ok := model.(types.Tabler); ok {
|
||||
errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error
|
||||
} else {
|
||||
errTx = tx.Model(model).Updates(updates).Error
|
||||
}
|
||||
if errTx != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}); err == nil {
|
||||
if updater, ok := model.(afterUpdated); ok {
|
||||
updater.AfterUpdated(childCtx, dbSess)
|
||||
}
|
||||
if preserver, ok := model.(afterSaved); ok {
|
||||
preserver.AfterSaved(childCtx, dbSess)
|
||||
}
|
||||
m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs)
|
||||
m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
|
||||
m.response.Success(w, types.UpdateResponse{
|
||||
ID: idStr,
|
||||
Status: "updated",
|
||||
})
|
||||
} else {
|
||||
m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete 实现通过HTTP方法删除模型
|
||||
func (m *Model) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
err error
|
||||
model any
|
||||
modelValue reflect.Value
|
||||
)
|
||||
idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r)
|
||||
modelValue = reflect.New(m.value.Type())
|
||||
model = modelValue.Interface()
|
||||
conditions := map[string]any{
|
||||
m.primaryKey: idStr,
|
||||
}
|
||||
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||
Domain: m.valueLookup(types.FieldDomain, w, r),
|
||||
User: m.valueLookup("user", w, r),
|
||||
Request: r,
|
||||
ModuleName: m.naming.ModuleName,
|
||||
TableName: m.naming.TableName,
|
||||
Scenario: types.ScenarioDelete,
|
||||
PrimaryKeyValue: idStr,
|
||||
})
|
||||
dbSess := m.getDB().WithContext(childCtx)
|
||||
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||||
if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil {
|
||||
return
|
||||
}
|
||||
if tabler, ok := model.(types.Tabler); ok {
|
||||
errTx = tx.Table(tabler.TableName()).Delete(model).Error
|
||||
} else {
|
||||
errTx = tx.Delete(model).Error
|
||||
}
|
||||
if errTx != nil {
|
||||
return
|
||||
}
|
||||
m.getHook().afterDelete(childCtx, tx, model)
|
||||
return
|
||||
}); err == nil {
|
||||
m.response.Success(w, types.DeleteResponse{
|
||||
ID: idStr,
|
||||
Status: "deleted",
|
||||
})
|
||||
} else {
|
||||
m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// View 查看数据详情
|
||||
func (m *Model) View(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
err error
|
||||
model any
|
||||
modelValue reflect.Value
|
||||
qs url.Values
|
||||
schemas []*types.Schema
|
||||
scenario string
|
||||
domainName string
|
||||
)
|
||||
qs = r.URL.Query()
|
||||
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
|
||||
modelValue = reflect.New(m.value.Type())
|
||||
model = modelValue.Interface()
|
||||
conditions := map[string]any{
|
||||
m.primaryKey: idStr,
|
||||
}
|
||||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
scenario = qs.Get("scenario")
|
||||
if scenario == "" {
|
||||
schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioView)
|
||||
} else {
|
||||
schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario)
|
||||
}
|
||||
if err == nil {
|
||||
m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format")))
|
||||
} else {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Export 实现通过HTTP方法导出模型
|
||||
func (m *Model) Export(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
err error
|
||||
query *Query
|
||||
modelSlices reflect.Value
|
||||
modelValues reflect.Value
|
||||
searchSchemas []*types.Schema
|
||||
exportSchemas []*types.Schema
|
||||
domainName string
|
||||
fp *os.File
|
||||
modelValue reflect.Value
|
||||
)
|
||||
if !m.hasScenario(types.ScenarioList) {
|
||||
m.response.Failure(w, types.RequestDenied, "request denied", nil)
|
||||
return
|
||||
}
|
||||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||
filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
|
||||
if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil {
|
||||
m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = fp.Close()
|
||||
}()
|
||||
modelValue = reflect.New(m.value.Type())
|
||||
//这里创建指针类型,这样的话就能在format里面调用函数
|
||||
if m.value.Kind() != reflect.Ptr {
|
||||
modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
|
||||
} else {
|
||||
modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
|
||||
}
|
||||
modelValues = reflect.New(modelSlices.Type())
|
||||
modelValues.Elem().Set(modelSlices)
|
||||
query = NewQuery(m.getDB(), modelValue.Interface())
|
||||
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||
Domain: domainName,
|
||||
Request: r,
|
||||
User: m.valueLookup("user", w, r),
|
||||
ModuleName: m.naming.ModuleName,
|
||||
TableName: m.naming.TableName,
|
||||
Scenario: types.ScenarioExport,
|
||||
})
|
||||
if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
|
||||
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
|
||||
return
|
||||
}
|
||||
if !m.disableDomain {
|
||||
if m.hasColumn(types.FieldDomain) {
|
||||
query.AndWhere(newCondition(types.FieldDomain, domainName))
|
||||
}
|
||||
}
|
||||
// 处理表名逻辑
|
||||
if namerTable, ok := query.Model().(tableNamer); ok {
|
||||
query.From(namerTable.HttpTableName(r))
|
||||
}
|
||||
//处理报表逻辑
|
||||
if reporter, ok := modelValue.Interface().(types.Reporter); ok {
|
||||
query.From(reporter.RealTable())
|
||||
query.GroupBy(reporter.GroupBy(childCtx)...)
|
||||
m.buildReporterQueryColumns(childCtx, reporter, query)
|
||||
}
|
||||
if err = query.All(modelValues.Interface()); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
|
||||
value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "")
|
||||
writer := csv.NewWriter(fp)
|
||||
rows := make([]string, len(exportSchemas))
|
||||
for i, field := range exportSchemas {
|
||||
rows[i] = field.Label
|
||||
}
|
||||
_ = writer.Write(rows)
|
||||
if values, ok := value.([]any); ok {
|
||||
for _, val := range values {
|
||||
row, ok2 := val.(map[string]any)
|
||||
if !ok2 {
|
||||
continue
|
||||
}
|
||||
for i, field := range exportSchemas {
|
||||
if v, ok := row[field.Column]; ok {
|
||||
rows[i] = fmt.Sprint(v)
|
||||
} else {
|
||||
rows[i] = ""
|
||||
}
|
||||
}
|
||||
_ = writer.Write(rows)
|
||||
}
|
||||
}
|
||||
writer.Flush()
|
||||
m.getHook().afterExport(childCtx, filename)
|
||||
http.ServeContent(w, r, path.Base(filename), time.Now(), fp)
|
||||
}
|
||||
|
||||
// findSchema 查找指定的schema
|
||||
func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema {
|
||||
for _, row := range schemas {
|
||||
if row.Label == label {
|
||||
return row
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// importInternal 文件上传方法
|
||||
func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) {
|
||||
var (
|
||||
err error
|
||||
rows []string
|
||||
fp *os.File
|
||||
tm time.Time
|
||||
fields []string
|
||||
sess *gorm.DB
|
||||
csvReader *csv.Reader
|
||||
csvWriter *csv.Writer
|
||||
modelValue reflect.Value
|
||||
modelEntity any
|
||||
diffAttrs []*types.DiffAttr
|
||||
result *types.ImportResult
|
||||
failureFp *os.File
|
||||
failureFile string
|
||||
)
|
||||
tm = time.Now()
|
||||
result = &types.ImportResult{}
|
||||
if fp, err = os.Open(filename); err != nil {
|
||||
result.Code = types.ErrImportFileNotExists
|
||||
goto __end
|
||||
}
|
||||
defer func() {
|
||||
_ = fp.Close()
|
||||
}()
|
||||
csvReader = csv.NewReader(fp)
|
||||
if rows, err = csvReader.Read(); err != nil {
|
||||
result.Code = types.ErrImportFileUnavailable
|
||||
goto __end
|
||||
}
|
||||
fields = make([]string, 0, len(rows))
|
||||
for _, s := range rows {
|
||||
v := m.findSchema(s, schemas)
|
||||
if v == nil {
|
||||
result.Code = types.ErrImportColumnNotMatch
|
||||
goto __end
|
||||
}
|
||||
fields = append(fields, v.Column)
|
||||
}
|
||||
sess = m.getDB().WithContext(ctx)
|
||||
//失败文件指针
|
||||
failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix()))
|
||||
if failureFp, err = os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = failureFp.Close()
|
||||
}()
|
||||
csvWriter = csv.NewWriter(failureFp)
|
||||
rows = append(rows, "Error")
|
||||
_ = csvWriter.Write(rows)
|
||||
diffAttrs = make([]*types.DiffAttr, len(schemas))
|
||||
for {
|
||||
if rows, err = csvReader.Read(); err != nil {
|
||||
break
|
||||
}
|
||||
result.TotalCount++
|
||||
if len(rows) != len(fields) {
|
||||
continue
|
||||
}
|
||||
modelValue = reflect.New(m.value.Type())
|
||||
for idx, field := range fields {
|
||||
m.safeSetValue(modelValue, field, rows[idx])
|
||||
}
|
||||
if len(extraFields) > 0 {
|
||||
for k, v := range extraFields {
|
||||
m.safeSetValue(modelValue, k, v)
|
||||
}
|
||||
}
|
||||
modelEntity = modelValue.Interface()
|
||||
//写入数据
|
||||
if fast {
|
||||
//如果是快速模式,直接存储数据
|
||||
if err = sess.Save(modelEntity).Error; err == nil {
|
||||
result.SuccessCount++
|
||||
} else {
|
||||
rows = append(rows, err.Error())
|
||||
_ = csvWriter.Write(rows)
|
||||
}
|
||||
} else {
|
||||
if err = sess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||||
if errTx = m.getHook().beforeCreate(ctx, tx, modelEntity); errTx != nil {
|
||||
return
|
||||
}
|
||||
if errTx = m.getHook().beforeSave(ctx, tx, modelEntity); errTx != nil {
|
||||
return
|
||||
}
|
||||
if tabler, ok := modelEntity.(types.Tabler); ok {
|
||||
errTx = tx.Table(tabler.TableName()).Save(modelEntity).Error
|
||||
} else {
|
||||
errTx = tx.Save(modelEntity).Error
|
||||
}
|
||||
if errTx != nil {
|
||||
return
|
||||
}
|
||||
for idx, row := range schemas {
|
||||
diffAttrs[idx] = &types.DiffAttr{
|
||||
Column: row.Column,
|
||||
Label: row.Label,
|
||||
NewValue: m.getValue(modelValue, row.Column),
|
||||
}
|
||||
}
|
||||
m.getHook().afterCreate(ctx, tx, modelEntity, diffAttrs)
|
||||
m.getHook().afterSave(ctx, tx, modelEntity, diffAttrs)
|
||||
return
|
||||
}); err == nil {
|
||||
result.SuccessCount++
|
||||
} else {
|
||||
rows = append(rows, err.Error())
|
||||
_ = csvWriter.Write(rows)
|
||||
}
|
||||
}
|
||||
}
|
||||
csvWriter.Flush()
|
||||
__end:
|
||||
result.UploadFile = filename
|
||||
if result.TotalCount > result.SuccessCount {
|
||||
result.FailureFile = failureFile
|
||||
}
|
||||
result.Duration = time.Now().Sub(tm)
|
||||
m.getHook().afterImport(ctx, result)
|
||||
}
|
||||
|
||||
// Import 实现通过HTTP方法导入
|
||||
func (m *Model) Import(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
err error
|
||||
fast bool
|
||||
schemas []*types.Schema
|
||||
rows []string
|
||||
domainName string
|
||||
dst *os.File
|
||||
fp multipart.File
|
||||
csvWriter *csv.Writer
|
||||
qs url.Values
|
||||
extraFields map[string]string
|
||||
)
|
||||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
|
||||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||
return
|
||||
}
|
||||
//这里用background的context
|
||||
childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{
|
||||
Domain: domainName,
|
||||
User: m.valueLookup("user", w, r),
|
||||
ModuleName: m.naming.ModuleName,
|
||||
TableName: m.naming.TableName,
|
||||
Scenario: types.ScenarioImport,
|
||||
Schemas: schemas,
|
||||
})
|
||||
if r.Method == http.MethodGet {
|
||||
//下载导入模板
|
||||
csvWriter = csv.NewWriter(w)
|
||||
rows = make([]string, 0, len(schemas))
|
||||
for _, row := range schemas {
|
||||
//主键不需要导入
|
||||
if row.IsPrimaryKey == 1 {
|
||||
continue
|
||||
}
|
||||
rows = append(rows, row.Label)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
|
||||
err = csvWriter.Write(rows)
|
||||
csvWriter.Flush()
|
||||
return
|
||||
}
|
||||
filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
|
||||
if fp, _, err = r.FormFile("file"); err != nil {
|
||||
m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = fp.Close()
|
||||
}()
|
||||
if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err == nil {
|
||||
buf := pool.GetBytes(32 * 1024)
|
||||
_, err = io.CopyBuffer(dst, fp, buf)
|
||||
pool.PutBytes(buf)
|
||||
_ = dst.Close()
|
||||
} else {
|
||||
m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil)
|
||||
return
|
||||
}
|
||||
qs = r.URL.Query()
|
||||
if qs != nil {
|
||||
extraFields = make(map[string]string)
|
||||
for k, _ := range qs {
|
||||
if strings.HasPrefix(k, "_attr_") {
|
||||
extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k)
|
||||
}
|
||||
}
|
||||
}
|
||||
fast, _ = strconv.ParseBool(qs.Get("__fast"))
|
||||
go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields)
|
||||
m.response.Success(w, types.ImportResponse{
|
||||
UID: m.valueLookup("user", w, r),
|
||||
Status: "committed",
|
||||
})
|
||||
}
|
||||
|
||||
// newModel 创建一个模型
|
||||
func newModel(v any, db *gorm.DB, naming types.Naming) *Model {
|
||||
model := &Model{
|
||||
db: db,
|
||||
naming: naming,
|
||||
response: &httpWriter{},
|
||||
value: reflect.Indirect(reflect.ValueOf(v)),
|
||||
valueLookup: defaultValueLookup,
|
||||
}
|
||||
model.statement = &gorm.Statement{
|
||||
DB: model.getDB(),
|
||||
ConnPool: model.getDB().ConnPool,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
if err := model.statement.Parse(v); err == nil {
|
||||
if model.statement.Schema.PrimaryFieldDBNames != nil && len(model.statement.Schema.PrimaryFieldDBNames) > 0 {
|
||||
model.primaryKey = model.statement.Schema.PrimaryFieldDBNames[0]
|
||||
}
|
||||
}
|
||||
return model
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package rest
|
||||
|
||||
import "git.nobla.cn/golang/rest/types"
|
||||
|
||||
type Options struct {
|
||||
urlPrefix string
|
||||
moduleName string
|
||||
disableDomain bool
|
||||
router types.HttpRouter
|
||||
writer types.HttpWriter
|
||||
formatter *Formatter
|
||||
dirname string //文件目录
|
||||
}
|
||||
|
||||
type Option func(o *Options)
|
||||
|
||||
func WithUriPrefix(s string) Option {
|
||||
return func(o *Options) {
|
||||
o.urlPrefix = s
|
||||
}
|
||||
}
|
||||
|
||||
func WithModuleName(s string) Option {
|
||||
return func(o *Options) {
|
||||
o.moduleName = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithoutDomain 禁用域
|
||||
func WithoutDomain() Option {
|
||||
return func(o *Options) {
|
||||
o.disableDomain = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithHttpRouter(s types.HttpRouter) Option {
|
||||
return func(o *Options) {
|
||||
o.router = s
|
||||
}
|
||||
}
|
||||
|
||||
func WithHttpWriter(s types.HttpWriter) Option {
|
||||
return func(o *Options) {
|
||||
o.writer = s
|
||||
}
|
||||
}
|
||||
|
||||
func WithFormatter(s *Formatter) Option {
|
||||
return func(o *Options) {
|
||||
o.formatter = s
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"git.nobla.cn/golang/kos/pkg/cache"
|
||||
xxhash "github.com/cespare/xxhash/v2"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DisableCache = "DISABLE_CACHE"
|
||||
DurationKey = "gorm:cache_duration"
|
||||
)
|
||||
|
||||
type Cacher struct {
|
||||
rawQuery func(db *gorm.DB)
|
||||
}
|
||||
|
||||
func (c *Cacher) Name() string {
|
||||
return "gorm:cache"
|
||||
}
|
||||
|
||||
func (c *Cacher) Initialize(db *gorm.DB) (err error) {
|
||||
c.rawQuery = db.Callback().Query().Get("gorm:query")
|
||||
err = db.Callback().Query().Replace("gorm:query", c.Query)
|
||||
return
|
||||
}
|
||||
|
||||
// buildCacheKey 构建一个缓存的KEY
|
||||
func (c *Cacher) buildCacheKey(db *gorm.DB) string {
|
||||
s := strconv.FormatUint(xxhash.Sum64String(db.Statement.SQL.String()+fmt.Sprintf("%v", db.Statement.Vars)), 10)
|
||||
return s
|
||||
}
|
||||
|
||||
// getDuration 获取缓存时长
|
||||
func (c *Cacher) getDuration(db *gorm.DB) time.Duration {
|
||||
var (
|
||||
ok bool
|
||||
v any
|
||||
duration time.Duration
|
||||
)
|
||||
if v, ok = db.InstanceGet(DurationKey); !ok {
|
||||
return 0
|
||||
}
|
||||
if duration, ok = v.(time.Duration); !ok {
|
||||
return 0
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
// tryLoad 尝试从缓存读取数据
|
||||
func (c *Cacher) tryLoad(key string, db *gorm.DB) (err error) {
|
||||
var (
|
||||
ok bool
|
||||
buf []byte
|
||||
)
|
||||
if buf, ok = cache.Get(db.Statement.Context, key); ok {
|
||||
err = json.Unmarshal(buf, db.Statement.Dest)
|
||||
} else {
|
||||
err = os.ErrNotExist
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// storeCache 存储缓存数据
|
||||
func (c *Cacher) storeCache(key string, db *gorm.DB, duration time.Duration) (err error) {
|
||||
var (
|
||||
buf []byte
|
||||
)
|
||||
if buf, err = json.Marshal(db.Statement.Dest); err == nil {
|
||||
cache.SetEx(db.Statement.Context, key, buf, duration)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Cacher) Query(db *gorm.DB) {
|
||||
var (
|
||||
err error
|
||||
cacheKey string
|
||||
duration time.Duration
|
||||
)
|
||||
duration = c.getDuration(db)
|
||||
if duration <= 0 {
|
||||
c.rawQuery(db)
|
||||
return
|
||||
}
|
||||
callbacks.BuildQuerySQL(db)
|
||||
cacheKey = c.buildCacheKey(db)
|
||||
if err = c.tryLoad(cacheKey, db); err == nil {
|
||||
return
|
||||
}
|
||||
c.rawQuery(db)
|
||||
if db.Error == nil {
|
||||
//store cache
|
||||
if err = c.storeCache(cacheKey, db, duration); err != nil {
|
||||
_ = db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func New() *Cacher {
|
||||
return &Cacher{}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package identity
|
||||
|
||||
import (
|
||||
"github.com/rs/xid"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type Identify struct {
|
||||
}
|
||||
|
||||
func (identity *Identify) Name() string {
|
||||
return "gorm:identity"
|
||||
}
|
||||
|
||||
func (identity *Identify) Initialize(db *gorm.DB) (err error) {
|
||||
err = db.Callback().Create().Before("gorm:create").Register("auto_identified", identity.Grant)
|
||||
return
|
||||
}
|
||||
|
||||
func (identity *Identify) NextID() string {
|
||||
return xid.New().String()
|
||||
}
|
||||
|
||||
func (identity *Identify) Grant(db *gorm.DB) {
|
||||
var (
|
||||
err error
|
||||
field *schema.Field
|
||||
)
|
||||
if db.Statement.Schema == nil {
|
||||
return
|
||||
}
|
||||
if field = db.Statement.Schema.LookUpField("ID"); field == nil {
|
||||
return
|
||||
}
|
||||
if field.DataType != schema.String {
|
||||
return
|
||||
}
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Array || db.Statement.ReflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue.Index(i)); zero {
|
||||
if err = field.Set(db.Statement.Context, db.Statement.ReflectValue.Index(i), identity.NextID()); err != nil {
|
||||
_ = db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); zero {
|
||||
db.Statement.SetColumn("ID", identity.NextID())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func New() *Identify {
|
||||
return &Identify{}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
# 分表实现
|
||||
|
||||
首先定义一个`gorm`的模型,然后实现`shadring.Model`接口,比如如下示例
|
||||
|
||||
```go
|
||||
// ShardingTable 返回增删改时候操作的数据表
|
||||
func (model *CdrLog) ShardingTable(scene string) string {
|
||||
return model.TableName()
|
||||
}
|
||||
|
||||
// ShardingTables 返回查询时候一个范围内的表
|
||||
func (model *CdrLog) ShardingTables(ctx *sharding.Context) []string {
|
||||
var (
|
||||
timestamp int64
|
||||
)
|
||||
timeRange := make([]int64, 0)
|
||||
values := ctx.FindColumnValues("start_stamp")
|
||||
if len(values) == 0 {
|
||||
values = ctx.FindColumnValues("create_stamp")
|
||||
}
|
||||
if len(values) > 0 {
|
||||
for _, v := range values {
|
||||
timestamp, _ = strconv.ParseInt(fmt.Sprint(v), 10, 64)
|
||||
timeRange = append(timeRange, timestamp)
|
||||
}
|
||||
}
|
||||
return shard.DateTableNames(ctx.Context(), "cdr_logs", shard.ShardTypeDateMonth, timeRange)
|
||||
}
|
||||
```
|
||||
|
||||
`ShardingTable`方法是操作增删改的是回调具体表名的方法
|
||||
|
||||
`ShardingTables`方法是操作查询的时候,通过查询条件返回的表名的方法
|
|
@ -0,0 +1,120 @@
|
|||
package sharding
|
||||
|
||||
const (
|
||||
ValueOperaEqual = iota + 0x10
|
||||
ValueOperaGreater
|
||||
ValueOperaLess
|
||||
ValueOperaRange
|
||||
|
||||
ValueTypeString = iota + 0x30
|
||||
ValueTypeNumber
|
||||
ValueTypeBoolean
|
||||
ValueTypeNull
|
||||
ValueTypeAny
|
||||
)
|
||||
|
||||
type (
|
||||
ColumnCondition struct {
|
||||
Name string `json:"name"`
|
||||
Value CondValue `json:"value"`
|
||||
}
|
||||
|
||||
CondValue interface {
|
||||
Type() int
|
||||
Opera() int
|
||||
Value() any
|
||||
}
|
||||
|
||||
equalValue struct {
|
||||
vType int
|
||||
vData any
|
||||
}
|
||||
|
||||
rangeValue struct {
|
||||
vType int
|
||||
vData any
|
||||
}
|
||||
|
||||
greaterValue struct {
|
||||
vType int
|
||||
vData any
|
||||
}
|
||||
|
||||
lessValue struct {
|
||||
vType int
|
||||
vData any
|
||||
}
|
||||
)
|
||||
|
||||
func (v *lessValue) Type() int {
|
||||
return v.vType
|
||||
}
|
||||
|
||||
func (v *lessValue) Opera() int {
|
||||
return ValueOperaLess
|
||||
}
|
||||
|
||||
func (v *lessValue) Value() any {
|
||||
return v.vData
|
||||
}
|
||||
|
||||
func (v *greaterValue) Type() int {
|
||||
return v.vType
|
||||
}
|
||||
|
||||
func (v *greaterValue) Opera() int {
|
||||
return ValueOperaGreater
|
||||
}
|
||||
|
||||
func (v *greaterValue) Value() any {
|
||||
return v.vData
|
||||
}
|
||||
|
||||
func (v *rangeValue) Type() int {
|
||||
return v.vType
|
||||
}
|
||||
|
||||
func (v *rangeValue) Opera() int {
|
||||
return ValueOperaRange
|
||||
}
|
||||
|
||||
func (v *rangeValue) Value() any {
|
||||
return v.vData
|
||||
}
|
||||
|
||||
func (v *equalValue) Type() int {
|
||||
return v.vType
|
||||
}
|
||||
|
||||
func (v *equalValue) Opera() int {
|
||||
return ValueOperaEqual
|
||||
}
|
||||
|
||||
func (v *equalValue) Value() any {
|
||||
return v.vData
|
||||
}
|
||||
|
||||
func newCondValue(vType int, op int, value any) CondValue {
|
||||
switch op {
|
||||
case ValueOperaGreater:
|
||||
return &greaterValue{
|
||||
vType: vType,
|
||||
vData: value,
|
||||
}
|
||||
case ValueOperaLess:
|
||||
return &lessValue{
|
||||
vType: vType,
|
||||
vData: value,
|
||||
}
|
||||
case ValueOperaRange:
|
||||
return &rangeValue{
|
||||
vType: vType,
|
||||
vData: value,
|
||||
}
|
||||
default:
|
||||
return &equalValue{
|
||||
vType: vType,
|
||||
vData: value,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,287 @@
|
|||
package sharding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/longbridgeapp/sqlparser"
|
||||
sqlparserX "github.com/uole/sqlparser"
|
||||
"gorm.io/gorm"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Scope struct {
|
||||
db *gorm.DB
|
||||
stmt *sqlparser.SelectStatement
|
||||
stmtX *sqlparserX.Select
|
||||
}
|
||||
|
||||
func (scope *Scope) findValue(express sqlparser.Expr) (int, any) {
|
||||
var (
|
||||
vType int
|
||||
vData any
|
||||
)
|
||||
switch expr := express.(type) {
|
||||
case *sqlparser.BindExpr:
|
||||
vType = ValueTypeAny
|
||||
if len(scope.db.Statement.Vars) > expr.Pos {
|
||||
vData = scope.db.Statement.Vars[expr.Pos]
|
||||
} else {
|
||||
vType = ValueTypeNull
|
||||
}
|
||||
case *sqlparser.NumberLit:
|
||||
vType = ValueTypeNumber
|
||||
vData = expr.Value
|
||||
case *sqlparser.StringLit:
|
||||
vType = ValueTypeString
|
||||
vData = expr.Value
|
||||
case *sqlparser.BoolLit:
|
||||
vType = ValueTypeBoolean
|
||||
vData = expr.Value
|
||||
case *sqlparser.BlobLit:
|
||||
vType = ValueTypeString
|
||||
vData = expr.Value
|
||||
case *sqlparser.NullLit:
|
||||
vType = ValueTypeNull
|
||||
case *sqlparser.Range:
|
||||
arr := make([]any, 2)
|
||||
vType, arr[0] = scope.findValue(expr.X)
|
||||
vType, arr[1] = scope.findValue(expr.Y)
|
||||
vData = arr
|
||||
}
|
||||
return vType, vData
|
||||
}
|
||||
|
||||
func (scope *Scope) findValueX(express *sqlparserX.SQLVal) (int, any) {
|
||||
var (
|
||||
vType int
|
||||
vData any
|
||||
)
|
||||
switch express.Type {
|
||||
case sqlparserX.IntVal:
|
||||
vType = ValueTypeNumber
|
||||
vData, _ = strconv.Atoi(string(express.Val))
|
||||
case sqlparserX.FloatVal:
|
||||
vType = ValueTypeNumber
|
||||
vData, _ = strconv.ParseFloat(string(express.Val), 64)
|
||||
case sqlparserX.ValArg:
|
||||
vType = ValueTypeAny
|
||||
pos, _ := strconv.Atoi(string(express.Val[2:]))
|
||||
if pos > 0 {
|
||||
pos = pos - 1
|
||||
if len(scope.db.Statement.Vars) > pos {
|
||||
vData = scope.db.Statement.Vars[pos]
|
||||
} else {
|
||||
vType = ValueTypeNull
|
||||
}
|
||||
} else {
|
||||
vType = ValueTypeNull
|
||||
}
|
||||
default:
|
||||
vType = ValueTypeString
|
||||
vData = string(express.Val)
|
||||
}
|
||||
return vType, vData
|
||||
}
|
||||
|
||||
func (scope *Scope) recursiveFindX(expr sqlparserX.Expr, column string) (conditions []*ColumnCondition) {
|
||||
var (
|
||||
ok bool
|
||||
andExpr *sqlparserX.AndExpr
|
||||
orExpr *sqlparserX.OrExpr
|
||||
parentExpr *sqlparserX.ParenExpr
|
||||
comparisonExpr *sqlparserX.ComparisonExpr
|
||||
rangeExpr *sqlparserX.RangeCond
|
||||
coumnExpr *sqlparserX.ColName
|
||||
valueExpr *sqlparserX.SQLVal
|
||||
)
|
||||
conditions = make([]*ColumnCondition, 0, 2)
|
||||
if comparisonExpr, ok = expr.(*sqlparserX.ComparisonExpr); ok {
|
||||
if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok {
|
||||
return
|
||||
}
|
||||
if valueExpr, ok = comparisonExpr.Right.(*sqlparserX.SQLVal); !ok {
|
||||
return
|
||||
}
|
||||
if coumnExpr.Name.EqualString(column) {
|
||||
cond := &ColumnCondition{
|
||||
Name: coumnExpr.Name.String(),
|
||||
}
|
||||
vType, vData := scope.findValueX(valueExpr)
|
||||
switch comparisonExpr.Operator {
|
||||
case sqlparserX.LessThanStr, sqlparserX.LessEqualStr:
|
||||
cond.Value = newCondValue(vType, ValueOperaLess, vData)
|
||||
case sqlparserX.GreaterThanStr, sqlparserX.GreaterEqualStr:
|
||||
cond.Value = newCondValue(vType, ValueOperaGreater, vData)
|
||||
case sqlparserX.EqualStr:
|
||||
cond.Value = newCondValue(vType, ValueOperaEqual, vData)
|
||||
}
|
||||
if cond.Value != nil {
|
||||
conditions = append(conditions, cond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rangeExpr, ok = expr.(*sqlparserX.RangeCond); ok {
|
||||
if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok {
|
||||
return
|
||||
}
|
||||
if coumnExpr.Name.EqualString(column) {
|
||||
vType := 0
|
||||
arr := make([]any, 2)
|
||||
if valueExpr, ok = rangeExpr.From.(*sqlparserX.SQLVal); ok {
|
||||
vType, arr[0] = scope.findValueX(valueExpr)
|
||||
}
|
||||
if valueExpr, ok = rangeExpr.To.(*sqlparserX.SQLVal); ok {
|
||||
vType, arr[1] = scope.findValueX(valueExpr)
|
||||
}
|
||||
conditions = append(conditions, &ColumnCondition{
|
||||
Name: coumnExpr.Name.String(),
|
||||
Value: newCondValue(vType, ValueOperaRange, arr),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if andExpr, ok = expr.(*sqlparserX.AndExpr); ok {
|
||||
if andExpr.Left != nil {
|
||||
conditions = append(conditions, scope.recursiveFindX(andExpr.Left, column)...)
|
||||
}
|
||||
if andExpr.Right != nil {
|
||||
conditions = append(conditions, scope.recursiveFindX(andExpr.Right, column)...)
|
||||
}
|
||||
}
|
||||
if orExpr, ok = expr.(*sqlparserX.OrExpr); ok {
|
||||
if orExpr.Left != nil {
|
||||
conditions = append(conditions, scope.recursiveFindX(orExpr.Left, column)...)
|
||||
}
|
||||
if orExpr.Right != nil {
|
||||
conditions = append(conditions, scope.recursiveFindX(orExpr.Right, column)...)
|
||||
}
|
||||
}
|
||||
if parentExpr, ok = expr.(*sqlparserX.ParenExpr); ok {
|
||||
if parentExpr.Expr != nil {
|
||||
conditions = append(conditions, scope.recursiveFindX(parentExpr.Expr, column)...)
|
||||
}
|
||||
}
|
||||
return conditions
|
||||
}
|
||||
|
||||
func (scope *Scope) recursiveFind(expr sqlparser.Expr, column string) []*ColumnCondition {
|
||||
var (
|
||||
ok bool
|
||||
identExpr *sqlparser.Ident
|
||||
binaryExpr *sqlparser.BinaryExpr
|
||||
parentExpr *sqlparser.ParenExpr
|
||||
conditions []*ColumnCondition
|
||||
)
|
||||
conditions = make([]*ColumnCondition, 0, 2)
|
||||
if parentExpr, ok = expr.(*sqlparser.ParenExpr); ok {
|
||||
if parentExpr.X != nil {
|
||||
if _, ok = parentExpr.X.(*sqlparser.BinaryExpr); ok {
|
||||
conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...)
|
||||
}
|
||||
if _, ok = parentExpr.X.(*sqlparser.ParenExpr); ok {
|
||||
conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if binaryExpr, ok = expr.(*sqlparser.BinaryExpr); ok {
|
||||
if binaryExpr.X != nil {
|
||||
if identExpr, ok = binaryExpr.X.(*sqlparser.Ident); ok {
|
||||
if identExpr.Name == column {
|
||||
cond := &ColumnCondition{
|
||||
Name: identExpr.Name,
|
||||
}
|
||||
vType, vData := scope.findValue(binaryExpr.Y)
|
||||
switch binaryExpr.Op {
|
||||
case sqlparser.LT, sqlparser.LE:
|
||||
cond.Value = newCondValue(vType, ValueOperaLess, vData)
|
||||
case sqlparser.GT, sqlparser.GE:
|
||||
cond.Value = newCondValue(vType, ValueOperaGreater, vData)
|
||||
case sqlparser.RANGE, sqlparser.BETWEEN:
|
||||
cond.Value = newCondValue(vType, ValueOperaRange, vData)
|
||||
case sqlparser.EQ:
|
||||
cond.Value = newCondValue(vType, ValueOperaEqual, vData)
|
||||
}
|
||||
if cond.Value != nil {
|
||||
conditions = append(conditions, cond)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, ok = binaryExpr.X.(*sqlparser.BinaryExpr); ok {
|
||||
conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...)
|
||||
}
|
||||
if _, ok = binaryExpr.X.(*sqlparser.ParenExpr); ok {
|
||||
conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if binaryExpr.Y != nil {
|
||||
if _, ok = binaryExpr.Y.(*sqlparser.BinaryExpr); ok {
|
||||
conditions = append(conditions, scope.recursiveFind(binaryExpr.Y, column)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return conditions
|
||||
}
|
||||
|
||||
func (scope *Scope) DB() *gorm.DB {
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (scope *Scope) Context() context.Context {
|
||||
return scope.db.Statement.Context
|
||||
}
|
||||
|
||||
func (scope *Scope) FindCondition(column string) []*ColumnCondition {
|
||||
if scope.stmtX != nil {
|
||||
if scope.stmtX.Where == nil {
|
||||
return []*ColumnCondition{}
|
||||
}
|
||||
return scope.recursiveFindX(scope.stmtX.Where.Expr, column)
|
||||
}
|
||||
return scope.recursiveFind(scope.stmt.Condition, column)
|
||||
}
|
||||
|
||||
func (scope *Scope) FindColumnValues(column string) []any {
|
||||
result := make([]any, 0)
|
||||
conditions := scope.FindCondition(column)
|
||||
if len(conditions) == 0 {
|
||||
return result
|
||||
}
|
||||
for _, cond := range conditions {
|
||||
if cond.Value.Opera() == ValueOperaGreater {
|
||||
if len(result) == 0 {
|
||||
result = make([]any, 2)
|
||||
}
|
||||
result[0] = cond.Value.Value()
|
||||
}
|
||||
|
||||
if cond.Value.Opera() == ValueOperaLess {
|
||||
if len(result) == 0 {
|
||||
result = make([]any, 2)
|
||||
}
|
||||
result[1] = cond.Value.Value()
|
||||
}
|
||||
|
||||
if cond.Value.Opera() == ValueOperaEqual {
|
||||
result = append(result, cond.Value.Value())
|
||||
break
|
||||
}
|
||||
|
||||
if cond.Value.Opera() == ValueOperaRange {
|
||||
if vs, ok := cond.Value.Value().([]any); ok {
|
||||
if len(vs) == 2 {
|
||||
if len(result) == 0 {
|
||||
result = make([]any, 2)
|
||||
}
|
||||
result[0] = vs[0]
|
||||
result[1] = vs[1]
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
|
@ -0,0 +1,476 @@
|
|||
package sharding
|
||||
|
||||
import (
|
||||
"github.com/longbridgeapp/sqlparser"
|
||||
sqlparserX "github.com/uole/sqlparser"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Sharding struct {
|
||||
UnionAll bool
|
||||
QuoteChar byte
|
||||
}
|
||||
|
||||
func (plugin *Sharding) Name() string {
|
||||
return "gorm:sharding"
|
||||
}
|
||||
|
||||
func (plugin *Sharding) Initialize(db *gorm.DB) (err error) {
|
||||
if err = db.Callback().Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil {
|
||||
return
|
||||
}
|
||||
if err = db.Callback().Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil {
|
||||
return
|
||||
}
|
||||
if err = db.Callback().Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil {
|
||||
return
|
||||
}
|
||||
if err = db.Callback().Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX); err != nil {
|
||||
return err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (plugin *Sharding) Create(db *gorm.DB) {
|
||||
var (
|
||||
ok bool
|
||||
scopeModel Model
|
||||
refValue reflect.Value
|
||||
modelValue any
|
||||
)
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||||
if db.Statement.ReflectValue.Len() > 0 {
|
||||
refValue = db.Statement.ReflectValue.Index(0)
|
||||
modelValue = refValue.Interface()
|
||||
}
|
||||
} else {
|
||||
if db.Statement.Model != nil {
|
||||
modelValue = db.Statement.Model
|
||||
} else {
|
||||
modelValue = db.Statement.ReflectValue.Interface()
|
||||
}
|
||||
}
|
||||
if modelValue != nil {
|
||||
if scopeModel, ok = modelValue.(Model); ok {
|
||||
db.Table(scopeModel.ShardingTable(sceneCreate))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (plugin *Sharding) Update(db *gorm.DB) {
|
||||
var (
|
||||
ok bool
|
||||
scopeModel Model
|
||||
refValue reflect.Value
|
||||
modelValue any
|
||||
)
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||||
if db.Statement.ReflectValue.Len() > 0 {
|
||||
refValue = db.Statement.ReflectValue.Index(0)
|
||||
modelValue = refValue.Interface()
|
||||
}
|
||||
} else {
|
||||
if db.Statement.Model != nil {
|
||||
modelValue = db.Statement.Model
|
||||
} else {
|
||||
modelValue = db.Statement.ReflectValue.Interface()
|
||||
}
|
||||
}
|
||||
if modelValue != nil {
|
||||
if scopeModel, ok = modelValue.(Model); ok {
|
||||
db.Table(scopeModel.ShardingTable(sceneUpdate))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (plugin *Sharding) Delete(db *gorm.DB) {
|
||||
var (
|
||||
ok bool
|
||||
scopeModel Model
|
||||
refValue reflect.Value
|
||||
modelValue any
|
||||
)
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||||
if db.Statement.ReflectValue.Len() > 0 {
|
||||
refValue = db.Statement.ReflectValue.Index(0)
|
||||
modelValue = refValue.Interface()
|
||||
}
|
||||
} else {
|
||||
if db.Statement.Model != nil {
|
||||
modelValue = db.Statement.Model
|
||||
} else {
|
||||
modelValue = db.Statement.ReflectValue.Interface()
|
||||
}
|
||||
}
|
||||
if modelValue != nil {
|
||||
if scopeModel, ok = modelValue.(Model); ok {
|
||||
db.Table(scopeModel.ShardingTable(sceneDelete))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (plugin *Sharding) Query(db *gorm.DB) {
|
||||
var (
|
||||
err error
|
||||
ok bool
|
||||
shardingModel Model
|
||||
modelValue any
|
||||
tables []string
|
||||
rawVars []any
|
||||
refValue reflect.Value
|
||||
tableName *sqlparser.TableName
|
||||
selectStmt *sqlparser.SelectStatement
|
||||
stmt sqlparser.Statement
|
||||
parser *sqlparser.Parser
|
||||
numOfTable int
|
||||
orderByExpr []*sqlparser.OrderingTerm
|
||||
limitExpr sqlparser.Expr
|
||||
offsetExpr sqlparser.Expr
|
||||
groupingExpr []sqlparser.Expr
|
||||
havingExpr sqlparser.Expr
|
||||
isCountStatement bool
|
||||
countField string
|
||||
)
|
||||
if db.Statement.Model != nil {
|
||||
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
|
||||
} else {
|
||||
refValue = reflect.New(db.Statement.ReflectValue.Type())
|
||||
}
|
||||
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
|
||||
refValue = reflect.Indirect(refValue)
|
||||
}
|
||||
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
|
||||
elemType := refValue.Type().Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
modelValue = reflect.New(elemType).Interface()
|
||||
} else {
|
||||
modelValue = refValue.Interface()
|
||||
}
|
||||
if shardingModel, ok = modelValue.(Model); !ok {
|
||||
return
|
||||
}
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
callbacks.BuildQuerySQL(db)
|
||||
}
|
||||
parser = sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String()))
|
||||
if stmt, err = parser.ParseStatement(); err != nil {
|
||||
return
|
||||
}
|
||||
if selectStmt, ok = stmt.(*sqlparser.SelectStatement); !ok {
|
||||
return
|
||||
}
|
||||
tables = shardingModel.ShardingTables(&Scope{
|
||||
db: db,
|
||||
stmt: selectStmt,
|
||||
})
|
||||
numOfTable = len(tables)
|
||||
if numOfTable <= 1 {
|
||||
return
|
||||
}
|
||||
rawVars = make([]any, 0, len(db.Statement.Vars))
|
||||
for _, v := range db.Statement.Vars {
|
||||
rawVars = append(rawVars, v)
|
||||
}
|
||||
//是否是查询count语句
|
||||
//如果不是count的语句,添加order和group的支持
|
||||
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
|
||||
if v == true {
|
||||
isCountStatement = true
|
||||
}
|
||||
}
|
||||
if !isCountStatement && len(*selectStmt.Columns) == 1 {
|
||||
for _, column := range *selectStmt.Columns {
|
||||
if expr, ok := column.Expr.(*sqlparser.Call); ok {
|
||||
if expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword {
|
||||
isCountStatement = true
|
||||
countField = expr.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(selectStmt.OrderBy) > 0 {
|
||||
orderByExpr = make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy))
|
||||
for _, row := range selectStmt.OrderBy {
|
||||
orderByExpr = append(orderByExpr, row)
|
||||
}
|
||||
selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0)
|
||||
}
|
||||
if len(selectStmt.GroupingElements) > 0 {
|
||||
groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements))
|
||||
for _, row := range selectStmt.GroupingElements {
|
||||
groupingExpr = append(groupingExpr, row)
|
||||
}
|
||||
if selectStmt.HavingCondition != nil {
|
||||
havingExpr = selectStmt.HavingCondition
|
||||
}
|
||||
}
|
||||
if selectStmt.Limit != nil {
|
||||
limitExpr = selectStmt.Limit
|
||||
selectStmt.Limit = nil
|
||||
}
|
||||
if selectStmt.Offset != nil {
|
||||
offsetExpr = selectStmt.Offset
|
||||
selectStmt.Offset = nil
|
||||
}
|
||||
db.Statement.SQL.Reset()
|
||||
if isCountStatement {
|
||||
db.Statement.SQL.WriteString("SELECT SUM(")
|
||||
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||||
db.Statement.SQL.WriteString(strings.Trim(countField, "`"))
|
||||
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||||
db.Statement.SQL.WriteString(") FROM (")
|
||||
} else {
|
||||
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||||
}
|
||||
for i, name := range tables {
|
||||
db.Statement.SQL.WriteByte('(')
|
||||
if tableName, ok = selectStmt.FromItems.(*sqlparser.TableName); ok {
|
||||
tableName.Name.Name = name
|
||||
}
|
||||
db.Statement.SQL.WriteString(selectStmt.String())
|
||||
db.Statement.SQL.WriteByte(')')
|
||||
if i < numOfTable-1 {
|
||||
if plugin.UnionAll {
|
||||
db.Statement.SQL.WriteString(" UNION ALL ")
|
||||
} else {
|
||||
db.Statement.SQL.WriteString(" UNION ")
|
||||
}
|
||||
}
|
||||
if i > 0 {
|
||||
//copy vars
|
||||
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
|
||||
}
|
||||
}
|
||||
db.Statement.SQL.WriteString(") tbl ")
|
||||
if !isCountStatement {
|
||||
if len(groupingExpr) > 0 {
|
||||
db.Statement.SQL.WriteString(" GROUP BY ")
|
||||
for i, expr := range groupingExpr {
|
||||
if i != 0 {
|
||||
db.Statement.SQL.WriteString(", ")
|
||||
}
|
||||
db.Statement.SQL.WriteString(expr.String())
|
||||
}
|
||||
|
||||
if havingExpr != nil {
|
||||
db.Statement.SQL.WriteString(" HAVING ")
|
||||
db.Statement.SQL.WriteString(havingExpr.String())
|
||||
}
|
||||
}
|
||||
|
||||
if orderByExpr != nil && len(orderByExpr) > 0 {
|
||||
db.Statement.SQL.WriteString(" ORDER BY ")
|
||||
for i, term := range orderByExpr {
|
||||
if i != 0 {
|
||||
db.Statement.SQL.WriteString(", ")
|
||||
}
|
||||
db.Statement.SQL.WriteString(term.String())
|
||||
}
|
||||
}
|
||||
if limitExpr != nil {
|
||||
db.Statement.SQL.WriteString(" LIMIT ")
|
||||
db.Statement.SQL.WriteString(limitExpr.String())
|
||||
if offsetExpr != nil {
|
||||
db.Statement.SQL.WriteString(" OFFSET ")
|
||||
db.Statement.SQL.WriteString(offsetExpr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (plugin *Sharding) QueryX(db *gorm.DB) {
|
||||
var (
|
||||
err error
|
||||
ok bool
|
||||
shardingModel Model
|
||||
modelValue any
|
||||
tables []string
|
||||
rawVars []any
|
||||
refValue reflect.Value
|
||||
selectStmt *sqlparserX.Select
|
||||
stmt sqlparserX.Statement
|
||||
numOfTable int
|
||||
isCountStatement bool
|
||||
isPureCountStatement bool
|
||||
trackedBuffer *sqlparserX.TrackedBuffer
|
||||
funcExpr *sqlparserX.FuncExpr
|
||||
aliasedExpr *sqlparserX.AliasedExpr
|
||||
orderByExpr sqlparserX.OrderBy
|
||||
groupByExpr sqlparserX.GroupBy
|
||||
havingExpr *sqlparserX.Where
|
||||
limitExpr *sqlparserX.Limit
|
||||
)
|
||||
if db.Statement.Model != nil {
|
||||
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
|
||||
} else {
|
||||
if !db.Statement.ReflectValue.IsValid() {
|
||||
return
|
||||
}
|
||||
refValue = reflect.New(db.Statement.ReflectValue.Type())
|
||||
}
|
||||
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
|
||||
refValue = reflect.Indirect(refValue)
|
||||
}
|
||||
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
|
||||
elemType := refValue.Type().Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
modelValue = reflect.New(elemType).Interface()
|
||||
} else {
|
||||
modelValue = refValue.Interface()
|
||||
}
|
||||
if shardingModel, ok = modelValue.(Model); !ok {
|
||||
return
|
||||
}
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
callbacks.BuildQuerySQL(db)
|
||||
}
|
||||
if stmt, err = sqlparserX.Parse(db.Statement.SQL.String()); err != nil {
|
||||
return
|
||||
}
|
||||
if selectStmt, ok = stmt.(*sqlparserX.Select); !ok {
|
||||
return
|
||||
}
|
||||
tables = shardingModel.ShardingTables(&Scope{
|
||||
db: db,
|
||||
stmtX: selectStmt,
|
||||
})
|
||||
numOfTable = len(tables)
|
||||
if numOfTable <= 1 {
|
||||
return
|
||||
}
|
||||
// 保存值
|
||||
rawVars = make([]any, 0, len(db.Statement.Vars))
|
||||
for _, v := range db.Statement.Vars {
|
||||
rawVars = append(rawVars, v)
|
||||
}
|
||||
// 替换语句
|
||||
if selectStmt.OrderBy != nil {
|
||||
orderByExpr = selectStmt.OrderBy
|
||||
selectStmt.OrderBy = nil
|
||||
}
|
||||
if selectStmt.GroupBy != nil {
|
||||
groupByExpr = selectStmt.GroupBy
|
||||
//selectStmt.GroupBy = nil
|
||||
}
|
||||
if selectStmt.Having != nil {
|
||||
havingExpr = selectStmt.Having
|
||||
//selectStmt.Having = nil
|
||||
}
|
||||
if selectStmt.Limit != nil {
|
||||
limitExpr = selectStmt.Limit
|
||||
selectStmt.Limit = nil
|
||||
}
|
||||
// 检查是否为COUNT语句
|
||||
//如果不是count的语句,添加order和group的支持
|
||||
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
|
||||
//这里处理的是报表的情况,再报表里面需要重写COUNT语句才能获取到正确的数量
|
||||
if v == true {
|
||||
isCountStatement = true
|
||||
isPureCountStatement = true
|
||||
}
|
||||
}
|
||||
//常规的COUNT逻辑
|
||||
if !isCountStatement && len(selectStmt.SelectExprs) == 1 {
|
||||
for _, expr := range selectStmt.SelectExprs {
|
||||
if aliasedExpr, ok = expr.(*sqlparserX.AliasedExpr); ok {
|
||||
if funcExpr, ok = aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok {
|
||||
if funcExpr.Name.EqualString("count") {
|
||||
isCountStatement = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 重写SQL
|
||||
db.Statement.SQL.Reset()
|
||||
trackedBuffer = sqlparserX.NewTrackedBuffer(nil)
|
||||
|
||||
if isCountStatement {
|
||||
if isPureCountStatement {
|
||||
db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (")
|
||||
} else {
|
||||
db.Statement.SQL.WriteString("SELECT SUM(")
|
||||
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||||
db.Statement.SQL.WriteString(strings.Trim("count(*)", "`"))
|
||||
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||||
db.Statement.SQL.WriteString(") FROM (")
|
||||
}
|
||||
} else {
|
||||
if bs, ok := modelValue.(SelectBuilder); ok {
|
||||
columns := bs.BuildSelect(db.Statement.Context, selectStmt.SelectExprs)
|
||||
if len(columns) > 0 {
|
||||
db.Statement.SQL.WriteString("SELECT " + strings.Join(columns, ",") + " FROM (")
|
||||
} else {
|
||||
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||||
}
|
||||
} else {
|
||||
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||||
}
|
||||
}
|
||||
for i, name := range tables {
|
||||
trackedBuffer.Reset()
|
||||
db.Statement.SQL.WriteByte('(')
|
||||
//赋值新的表名称
|
||||
selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{
|
||||
Expr: sqlparserX.TableName{
|
||||
Name: sqlparserX.NewTableIdent(name),
|
||||
},
|
||||
}}
|
||||
selectStmt.Format(trackedBuffer)
|
||||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||
db.Statement.SQL.WriteByte(')')
|
||||
if i < numOfTable-1 {
|
||||
if plugin.UnionAll {
|
||||
db.Statement.SQL.WriteString(" UNION ALL ")
|
||||
} else {
|
||||
db.Statement.SQL.WriteString(" UNION ")
|
||||
}
|
||||
}
|
||||
if i > 0 {
|
||||
//copy vars
|
||||
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
|
||||
}
|
||||
}
|
||||
db.Statement.SQL.WriteString(") tbl ")
|
||||
if !isCountStatement {
|
||||
//node.GroupBy, node.Having, node.OrderBy, node.Limit
|
||||
if groupByExpr != nil {
|
||||
trackedBuffer.Reset()
|
||||
groupByExpr.Format(trackedBuffer)
|
||||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||
}
|
||||
if havingExpr != nil {
|
||||
trackedBuffer.Reset()
|
||||
havingExpr.Format(trackedBuffer)
|
||||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||
}
|
||||
if orderByExpr != nil {
|
||||
trackedBuffer.Reset()
|
||||
orderByExpr.Format(trackedBuffer)
|
||||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||
}
|
||||
if limitExpr != nil {
|
||||
trackedBuffer.Reset()
|
||||
limitExpr.Format(trackedBuffer)
|
||||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func New() *Sharding {
|
||||
return &Sharding{
|
||||
UnionAll: true,
|
||||
QuoteChar: '`',
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package sharding
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/uole/sqlparser"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFormat(t *testing.T) {
|
||||
t.Log(fmt.Sprintf("%.2f%%", 0.5851888*100))
|
||||
}
|
||||
|
||||
func TestSharding_Query(t *testing.T) {
|
||||
//sql := "SELECT COUNT(*) FROM aaa"
|
||||
sql := "SELECT uid,SUM(IF(direction='inbound',1,0)) AS inbound_times,SUM(IF(direction='inbound',IF(answer_duration>0,1,0),0)) AS inbound_answer_times,SUM(IF(direction='outbound',1,0)) AS outbound_times,SUM(IF(direction='outbound',IF(answer_duration>0,1,0),0)) AS outbound_answer_times FROM `cdr_logs` WHERE ((`domain` = 'test.cc.echo.me' OR `domain` = 'default') AND `name` <> '') AND (`create_stamp` BETWEEN 1712505600 AND 1712505608) AND `name` IN ('a','b','c') GROUP BY `uid` having uid != '' ORDER BY create_stamp DESC LIMIT 15"
|
||||
stmt, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
selectStmt, ok := stmt.(*sqlparser.Select)
|
||||
if !ok {
|
||||
t.Error("not select stmt")
|
||||
return
|
||||
}
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
sqlparser.NewTableIdent("test").Format(buf)
|
||||
buf.Reset()
|
||||
selectStmt.Format(buf)
|
||||
t.Log("SQL")
|
||||
t.Log(buf.String())
|
||||
buf.Reset()
|
||||
t.Log("SELECT")
|
||||
selectStmt.SelectExprs.Format(buf)
|
||||
t.Log(buf.String())
|
||||
buf.Reset()
|
||||
t.Log("FROM")
|
||||
selectStmt.From.Format(buf)
|
||||
t.Log(buf.String())
|
||||
buf.Reset()
|
||||
t.Log("WHERE")
|
||||
selectStmt.Where.Format(buf)
|
||||
t.Log(buf.String())
|
||||
buf.Reset()
|
||||
t.Log("ORDER BY")
|
||||
selectStmt.OrderBy.Format(buf)
|
||||
t.Log(buf.String())
|
||||
buf.Reset()
|
||||
t.Log("GROUP BY")
|
||||
selectStmt.GroupBy.Format(buf)
|
||||
t.Log(buf.String())
|
||||
buf.Reset()
|
||||
t.Log("LIMIT")
|
||||
selectStmt.Limit.Format(buf)
|
||||
t.Log(buf.String())
|
||||
buf.Reset()
|
||||
t.Log("Having")
|
||||
selectStmt.Having.Format(buf)
|
||||
t.Log(buf.String())
|
||||
buf.Reset()
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package sharding
|
||||
|
||||
import (
|
||||
"context"
|
||||
sqlparserX "github.com/uole/sqlparser"
|
||||
)
|
||||
|
||||
const (
|
||||
TypeDatetime = "datetime"
|
||||
TypeHash = "hash"
|
||||
)
|
||||
|
||||
const (
|
||||
DateTypeYear = iota + 5
|
||||
DateTypeMonth
|
||||
DateTypeWeek
|
||||
DateTypeDay
|
||||
)
|
||||
|
||||
const (
|
||||
sceneCreate = "create"
|
||||
sceneUpdate = "update"
|
||||
sceneDelete = "delete"
|
||||
)
|
||||
|
||||
const (
|
||||
stmtCountKeyword = "count"
|
||||
)
|
||||
|
||||
type (
|
||||
Rule struct {
|
||||
Type string `json:"type"`
|
||||
Args int `json:"args"`
|
||||
}
|
||||
|
||||
Model interface {
|
||||
ShardingRule() Rule
|
||||
ShardingTable(scene string) string
|
||||
ShardingTables(scope *Scope) []string
|
||||
}
|
||||
|
||||
SelectBuilder interface {
|
||||
BuildSelect(ctx context.Context, expr sqlparserX.SelectExprs) []string
|
||||
}
|
||||
)
|
|
@ -0,0 +1,107 @@
|
|||
package validate
|
||||
|
||||
import (
|
||||
"git.nobla.cn/golang/rest/types"
|
||||
"gorm.io/gorm"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
SkipValidations = "validations:skip_validations"
|
||||
)
|
||||
|
||||
var (
|
||||
scopeCtxKey = &validateScope{}
|
||||
telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
|
||||
)
|
||||
|
||||
type (
|
||||
validateScope struct {
|
||||
DB *gorm.DB
|
||||
Column string
|
||||
Domain string
|
||||
MultiDomain bool
|
||||
Model interface{}
|
||||
}
|
||||
|
||||
StructError struct {
|
||||
Tag string `json:"tag"`
|
||||
Column string `json:"column"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
validateRule struct {
|
||||
Rule string
|
||||
Value string
|
||||
Valid bool
|
||||
}
|
||||
)
|
||||
|
||||
func (err *StructError) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
func newRule(ss ...string) *validateRule {
|
||||
v := &validateRule{
|
||||
Valid: true,
|
||||
}
|
||||
if len(ss) == 1 {
|
||||
v.Rule = ss[0]
|
||||
} else if len(ss) >= 2 {
|
||||
v.Rule = ss[0]
|
||||
v.Value = ss[1]
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// formatError 格式化错误消息
|
||||
func formatError(rule types.Rule, scm *types.Schema, tag string) string {
|
||||
var s string
|
||||
switch tag {
|
||||
case "db_unique":
|
||||
s = scm.Label + "值已经存在."
|
||||
break
|
||||
case "required":
|
||||
s = scm.Label + "值不能为空."
|
||||
case "max":
|
||||
if scm.Type == "string" {
|
||||
s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
|
||||
} else {
|
||||
s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
|
||||
}
|
||||
case "min":
|
||||
if scm.Type == "string" {
|
||||
s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max)
|
||||
} else {
|
||||
s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// isEmpty 判断值是否为空
|
||||
func isEmpty(val any) bool {
|
||||
if val == nil {
|
||||
return true
|
||||
}
|
||||
v := reflect.ValueOf(val)
|
||||
switch v.Kind() {
|
||||
case reflect.String, reflect.Array:
|
||||
return v.Len() == 0
|
||||
case reflect.Map, reflect.Slice:
|
||||
return v.IsNil() || v.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return v.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return v.IsNil()
|
||||
}
|
||||
return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface())
|
||||
}
|
|
@ -0,0 +1,275 @@
|
|||
package validate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"git.nobla.cn/golang/rest"
|
||||
"git.nobla.cn/golang/rest/types"
|
||||
validator "github.com/go-playground/validator/v10"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Validate struct {
|
||||
validator *validator.Validate
|
||||
}
|
||||
|
||||
func (validate *Validate) Name() string {
|
||||
return "gorm:validate"
|
||||
}
|
||||
|
||||
func (validate *Validate) telephoneValidate(ctx context.Context, fl validator.FieldLevel) bool {
|
||||
val := fmt.Sprint(fl.Field().Interface())
|
||||
return telephoneRegex.MatchString(val)
|
||||
}
|
||||
|
||||
func (validate *Validate) uniqueValidate(ctx context.Context, fl validator.FieldLevel) bool {
|
||||
var (
|
||||
scope *validateScope
|
||||
ok bool
|
||||
count int64
|
||||
field *schema.Field
|
||||
primaryKeyValue reflect.Value
|
||||
)
|
||||
val := fl.Field().Interface()
|
||||
if scope, ok = ctx.Value(scopeCtxKey).(*validateScope); !ok {
|
||||
return true
|
||||
}
|
||||
if len(scope.DB.Statement.Schema.PrimaryFields) > 0 {
|
||||
field = scope.DB.Statement.Schema.PrimaryFields[0]
|
||||
primaryKeyValue = reflect.Indirect(reflect.ValueOf(scope.Model))
|
||||
for _, n := range field.BindNames {
|
||||
primaryKeyValue = primaryKeyValue.FieldByName(n)
|
||||
}
|
||||
}
|
||||
sess := scope.DB.Session(&gorm.Session{NewDB: true})
|
||||
if primaryKeyValue.IsValid() && !primaryKeyValue.IsZero() && field != nil {
|
||||
//多域校验
|
||||
if scope.MultiDomain && scope.Domain != "" {
|
||||
sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ? AND domain = ?", val, primaryKeyValue.Interface(), scope.Domain).Count(&count)
|
||||
} else {
|
||||
sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ?", val, primaryKeyValue.Interface()).Count(&count)
|
||||
}
|
||||
} else {
|
||||
if scope.MultiDomain && scope.Domain != "" {
|
||||
sess.Model(scope.Model).Where(scope.Column+"=? AND domain = ?", val, scope.Domain).Count(&count)
|
||||
} else {
|
||||
sess.Model(scope.Model).Where(scope.Column+"=?", val).Count(&count)
|
||||
}
|
||||
}
|
||||
if count > 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (validate *Validate) grantRules(scm *types.Schema, scenario string, rule types.Rule) []*validateRule {
|
||||
rules := make([]*validateRule, 0, 5)
|
||||
if rule.Min != 0 {
|
||||
rules = append(rules, newRule("min", strconv.Itoa(rule.Min)))
|
||||
}
|
||||
if rule.Max != 0 {
|
||||
rules = append(rules, newRule("max", strconv.Itoa(rule.Max)))
|
||||
}
|
||||
//主键不做唯一判断
|
||||
if rule.Unique && !scm.Attribute.PrimaryKey {
|
||||
rules = append(rules, newRule("db_unique"))
|
||||
}
|
||||
if rule.Type != "" {
|
||||
rules = append(rules, newRule(rule.Type))
|
||||
}
|
||||
if rule.Required != nil && len(rule.Required) > 0 {
|
||||
for _, v := range rule.Required {
|
||||
if v == scenario {
|
||||
rules = append(rules, newRule("required"))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
func (validate *Validate) buildRules(rs []*validateRule) string {
|
||||
var sb strings.Builder
|
||||
for _, r := range rs {
|
||||
if !r.Valid {
|
||||
continue
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
if r.Value == "" {
|
||||
sb.WriteString(r.Rule)
|
||||
} else {
|
||||
sb.WriteString(r.Rule + "=" + r.Value)
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (validate *Validate) findRule(name string, rules []*validateRule) *validateRule {
|
||||
for _, r := range rules {
|
||||
if r.Rule == name {
|
||||
return r
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validate *Validate) Initialize(db *gorm.DB) (err error) {
|
||||
validate.validator = validator.New()
|
||||
if err = db.Callback().Create().Before("gorm:before_create").Register("model_validate", validate.Validate); err != nil {
|
||||
return
|
||||
}
|
||||
if err = db.Callback().Create().Before("gorm:before_update").Register("model_validate", validate.Validate); err != nil {
|
||||
return
|
||||
}
|
||||
if err = validate.validator.RegisterValidationCtx("telephone", validate.telephoneValidate); err != nil {
|
||||
return
|
||||
}
|
||||
if err = validate.validator.RegisterValidationCtx("db_unique", validate.uniqueValidate); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (validate *Validate) inArray(v any, vs []any) bool {
|
||||
sv := fmt.Sprint(v)
|
||||
for _, s := range vs {
|
||||
if fmt.Sprint(s) == sv {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isVisible 判断字段是否需要显示
|
||||
func (validate *Validate) isVisible(stmt *gorm.Statement, scm *types.Schema) bool {
|
||||
if len(scm.Attribute.Visible) <= 0 {
|
||||
return true
|
||||
}
|
||||
for _, row := range scm.Attribute.Visible {
|
||||
if len(row.Values) == 0 {
|
||||
continue
|
||||
}
|
||||
targetField := stmt.Schema.LookUpField(row.Column)
|
||||
if targetField == nil {
|
||||
continue
|
||||
}
|
||||
targetValue := stmt.ReflectValue.FieldByName(targetField.Name)
|
||||
if !targetValue.IsValid() {
|
||||
return false
|
||||
}
|
||||
if !validate.inArray(targetValue.Interface(), row.Values) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Validate 校验字段
|
||||
func (validate *Validate) Validate(db *gorm.DB) {
|
||||
var (
|
||||
ok bool
|
||||
err error
|
||||
rules []*validateRule
|
||||
stmt *gorm.Statement
|
||||
skipValidate bool
|
||||
multiDomain bool
|
||||
value reflect.Value
|
||||
runtimeScope *types.RuntimeScope
|
||||
)
|
||||
if result, ok := db.Get(SkipValidations); ok && result.(bool) {
|
||||
return
|
||||
}
|
||||
stmt = db.Statement
|
||||
if stmt.Model == nil {
|
||||
return
|
||||
}
|
||||
if db.Statement.Context == nil {
|
||||
return
|
||||
}
|
||||
if val := db.Statement.Context.Value(rest.RuntimeScopeKey); val != nil {
|
||||
if runtimeScope, ok = val.(*types.RuntimeScope); !ok {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
return
|
||||
}
|
||||
if runtimeScope.Schemas == nil {
|
||||
return
|
||||
}
|
||||
if stmt.Schema.LookUpField("domain") != nil {
|
||||
multiDomain = true
|
||||
}
|
||||
for _, row := range runtimeScope.Schemas {
|
||||
//如果字段隐藏,那么就不进行校验
|
||||
if !validate.isVisible(stmt, row) {
|
||||
continue
|
||||
}
|
||||
if rules = validate.grantRules(row, runtimeScope.Scenario, row.Rule); len(rules) <= 0 {
|
||||
continue
|
||||
}
|
||||
field := stmt.Schema.LookUpField(row.Column)
|
||||
if field == nil {
|
||||
continue
|
||||
}
|
||||
value = stmt.ReflectValue.FieldByName(field.Name)
|
||||
if !value.IsValid() {
|
||||
continue
|
||||
}
|
||||
skipValidate = false
|
||||
if r := validate.findRule("required", rules); r != nil {
|
||||
if value.Interface() != nil {
|
||||
vType := reflect.ValueOf(value.Interface())
|
||||
switch vType.Kind() {
|
||||
case reflect.Bool:
|
||||
skipValidate = true
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
skipValidate = true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
skipValidate = true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
skipValidate = true
|
||||
default:
|
||||
skipValidate = false
|
||||
}
|
||||
}
|
||||
if skipValidate {
|
||||
r.Valid = false
|
||||
}
|
||||
} else {
|
||||
if isEmpty(value.Interface()) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
ctx := context.WithValue(db.Statement.Context, scopeCtxKey, &validateScope{
|
||||
DB: db,
|
||||
Column: row.Column,
|
||||
Model: stmt.Model,
|
||||
Domain: runtimeScope.Domain,
|
||||
MultiDomain: multiDomain,
|
||||
})
|
||||
if err = validate.validator.VarCtx(ctx, value.Interface(), validate.buildRules(rules)); err != nil {
|
||||
if errs, ok := err.(validator.ValidationErrors); ok {
|
||||
for _, e := range errs {
|
||||
_ = db.AddError(&StructError{
|
||||
Tag: e.Tag(),
|
||||
Column: row.Column,
|
||||
Message: formatError(row.Rule, row, e.Tag()),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
_ = db.AddError(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func New() *Validate {
|
||||
return &Validate{}
|
||||
}
|
|
@ -0,0 +1,405 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"git.nobla.cn/golang/rest/types"
|
||||
"gorm.io/gorm"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type (
|
||||
Query struct {
|
||||
db *gorm.DB
|
||||
condition string
|
||||
fields []string
|
||||
params []interface{}
|
||||
table string
|
||||
joins []join
|
||||
orderBy []string
|
||||
groupBy []string
|
||||
modelValue any
|
||||
limit int
|
||||
offset int
|
||||
}
|
||||
|
||||
condition struct {
|
||||
Field string `json:"field"`
|
||||
Value interface{} `json:"value"`
|
||||
Expr string `json:"expr"`
|
||||
}
|
||||
|
||||
join struct {
|
||||
Table string
|
||||
Direction string
|
||||
Conditions []*condition
|
||||
}
|
||||
)
|
||||
|
||||
func (cond *condition) WithExpr(v string) *condition {
|
||||
cond.Expr = v
|
||||
return cond
|
||||
}
|
||||
|
||||
func (query *Query) Model() any {
|
||||
return query.modelValue
|
||||
}
|
||||
|
||||
func (query *Query) compile() (*gorm.DB, error) {
|
||||
db := query.db
|
||||
if query.condition != "" {
|
||||
db = db.Where(query.condition, query.params...)
|
||||
}
|
||||
if query.fields != nil {
|
||||
db = db.Select(strings.Join(query.fields, ","))
|
||||
}
|
||||
if query.table != "" {
|
||||
db = db.Table(query.table)
|
||||
}
|
||||
if query.joins != nil && len(query.joins) > 0 {
|
||||
for _, joinEntity := range query.joins {
|
||||
cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...)
|
||||
db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...)
|
||||
}
|
||||
}
|
||||
if query.orderBy != nil && len(query.orderBy) > 0 {
|
||||
db = db.Order(strings.Join(query.orderBy, ","))
|
||||
}
|
||||
if query.groupBy != nil && len(query.groupBy) > 0 {
|
||||
db = db.Group(strings.Join(query.groupBy, ","))
|
||||
}
|
||||
if query.offset > 0 {
|
||||
db = db.Offset(query.offset)
|
||||
}
|
||||
if query.limit > 0 {
|
||||
db = db.Limit(query.limit)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (query *Query) decodeValue(v any) string {
|
||||
refVal := reflect.Indirect(reflect.ValueOf(v))
|
||||
switch refVal.Kind() {
|
||||
case reflect.Bool:
|
||||
if refVal.Bool() {
|
||||
return "1"
|
||||
} else {
|
||||
return "0"
|
||||
}
|
||||
case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
return strconv.FormatInt(refVal.Int(), 10)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
|
||||
case reflect.String:
|
||||
return "'" + refVal.String() + "'"
|
||||
default:
|
||||
return fmt.Sprint(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) {
|
||||
var (
|
||||
sb strings.Builder
|
||||
)
|
||||
params = make([]interface{}, 0)
|
||||
for _, cond := range conditions {
|
||||
if filter {
|
||||
if isEmpty(cond.Value) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if cond.Expr == "" {
|
||||
cond.Expr = "="
|
||||
}
|
||||
switch strings.ToUpper(cond.Expr) {
|
||||
case "=", "<>", ">", "<", ">=", "<=", "!=":
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString(" " + operator + " ")
|
||||
}
|
||||
if cond.Expr == "=" && cond.Value == nil {
|
||||
sb.WriteString("`" + cond.Field + "` IS NULL")
|
||||
} else {
|
||||
sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?")
|
||||
params = append(params, cond.Value)
|
||||
}
|
||||
case "LIKE":
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString(" " + operator + " ")
|
||||
}
|
||||
cond.Value = fmt.Sprintf("%%%s%%", cond.Value)
|
||||
sb.WriteString("`" + cond.Field + "` LIKE ?")
|
||||
params = append(params, cond.Value)
|
||||
case "IN":
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString(" " + operator + " ")
|
||||
}
|
||||
refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
|
||||
switch refVal.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
ss := make([]string, refVal.Len())
|
||||
for i := 0; i < refVal.Len(); i++ {
|
||||
ss[i] = query.decodeValue(refVal.Index(i))
|
||||
}
|
||||
sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")")
|
||||
case reflect.String:
|
||||
sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")")
|
||||
default:
|
||||
}
|
||||
case "BETWEEN":
|
||||
refVal := reflect.ValueOf(cond.Value)
|
||||
if refVal.Kind() == reflect.Slice && refVal.Len() == 2 {
|
||||
sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?")
|
||||
params = append(params, refVal.Index(0), refVal.Index(1))
|
||||
}
|
||||
}
|
||||
}
|
||||
str = sb.String()
|
||||
return
|
||||
}
|
||||
|
||||
func (query *Query) Select(fields ...string) *Query {
|
||||
if query.fields == nil {
|
||||
query.fields = fields
|
||||
} else {
|
||||
query.fields = append(query.fields, fields...)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) From(table string) *Query {
|
||||
query.table = table
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) LeftJoin(table string, conditions ...*condition) *Query {
|
||||
query.joins = append(query.joins, join{
|
||||
Table: table,
|
||||
Direction: "LEFT",
|
||||
Conditions: conditions,
|
||||
})
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) RightJoin(table string, conditions ...*condition) *Query {
|
||||
query.joins = append(query.joins, join{
|
||||
Table: table,
|
||||
Direction: "RIGHT",
|
||||
Conditions: conditions,
|
||||
})
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) InnerJoin(table string, conditions ...*condition) *Query {
|
||||
query.joins = append(query.joins, join{
|
||||
Table: table,
|
||||
Direction: "INNER",
|
||||
Conditions: conditions,
|
||||
})
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) AddCondition(expr, column string, val any) {
|
||||
if expr == "" {
|
||||
expr = "="
|
||||
}
|
||||
query.AndWhere(newConditionWithOperator(expr, column, val))
|
||||
}
|
||||
|
||||
func (query *Query) AndWhere(conditions ...*condition) *Query {
|
||||
length := len(conditions)
|
||||
if length == 0 {
|
||||
return query
|
||||
}
|
||||
cs, ps := query.buildConditions("AND", false, conditions...)
|
||||
if cs == "" {
|
||||
return query
|
||||
}
|
||||
query.params = append(query.params, ps...)
|
||||
if query.condition == "" {
|
||||
query.condition = cs
|
||||
} else {
|
||||
query.condition += " AND (" + cs + ")"
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) AndFilterWhere(conditions ...*condition) *Query {
|
||||
length := len(conditions)
|
||||
if length == 0 {
|
||||
return query
|
||||
}
|
||||
cs, ps := query.buildConditions("AND", true, conditions...)
|
||||
if cs == "" {
|
||||
return query
|
||||
}
|
||||
query.params = append(query.params, ps...)
|
||||
if query.condition == "" {
|
||||
query.condition = cs
|
||||
} else {
|
||||
query.condition += " AND " + cs
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) OrWhere(conditions ...*condition) *Query {
|
||||
length := len(conditions)
|
||||
if length == 0 {
|
||||
return query
|
||||
}
|
||||
cs, ps := query.buildConditions("OR", false, conditions...)
|
||||
if cs == "" {
|
||||
return query
|
||||
}
|
||||
query.params = append(query.params, ps...)
|
||||
if query.condition == "" {
|
||||
query.condition = cs
|
||||
} else {
|
||||
query.condition += " AND (" + cs + ")"
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) OrFilterWhere(conditions ...*condition) *Query {
|
||||
length := len(conditions)
|
||||
if length == 0 {
|
||||
return query
|
||||
}
|
||||
cs, ps := query.buildConditions("OR", true, conditions...)
|
||||
if cs == "" {
|
||||
return query
|
||||
}
|
||||
query.params = append(query.params, ps...)
|
||||
if query.condition == "" {
|
||||
query.condition = cs
|
||||
} else {
|
||||
query.condition += " AND (" + cs + ")"
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) GroupBy(cols ...string) *Query {
|
||||
query.groupBy = append(query.groupBy, cols...)
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) OrderBy(col, direction string) *Query {
|
||||
direction = strings.ToUpper(direction)
|
||||
if direction == "" || !(direction == "ASC" || direction == "DESC") {
|
||||
direction = "ASC"
|
||||
}
|
||||
query.orderBy = append(query.orderBy, col+" "+direction)
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) Offset(i int) *Query {
|
||||
query.offset = i
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) Limit(i int) *Query {
|
||||
query.limit = i
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) ResetSelect() *Query {
|
||||
query.fields = make([]string, 0)
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *Query) Count(v interface{}) (i int64) {
|
||||
var (
|
||||
db *gorm.DB
|
||||
err error
|
||||
)
|
||||
if db, err = query.compile(); err != nil {
|
||||
return
|
||||
} else {
|
||||
if v != nil {
|
||||
refVal := reflect.ValueOf(v)
|
||||
switch refVal.Kind() {
|
||||
case reflect.String:
|
||||
if query.table == "" {
|
||||
err = db.Table(refVal.String()).Count(&i).Error
|
||||
} else {
|
||||
err = db.Table(query.table).Count(&i).Error
|
||||
}
|
||||
default:
|
||||
//如果是报表的模型,这的话手动构建一条SQL语句
|
||||
if reporter, ok := v.(types.Reporter); ok {
|
||||
sqlRes := &sqlCountResponse{}
|
||||
childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true)
|
||||
db.WithContext(childCtx).Model(reporter).First(sqlRes)
|
||||
i = sqlRes.Count
|
||||
} else {
|
||||
err = db.Model(v).Count(&i).Error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (query *Query) One(v interface{}) (err error) {
|
||||
var (
|
||||
db *gorm.DB
|
||||
)
|
||||
if db, err = query.compile(); err != nil {
|
||||
return
|
||||
} else {
|
||||
err = db.First(v).Error
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (query *Query) All(v interface{}) (err error) {
|
||||
var (
|
||||
db *gorm.DB
|
||||
)
|
||||
if db, err = query.compile(); err != nil {
|
||||
return
|
||||
} else {
|
||||
err = db.Find(v).Error
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func NewCondition(column, opera string, value any) *condition {
|
||||
if opera == "" {
|
||||
opera = "="
|
||||
}
|
||||
return &condition{
|
||||
Field: column,
|
||||
Value: value,
|
||||
Expr: opera,
|
||||
}
|
||||
}
|
||||
|
||||
func newCondition(field string, value interface{}) *condition {
|
||||
return &condition{
|
||||
Field: field,
|
||||
Value: value,
|
||||
Expr: "=",
|
||||
}
|
||||
}
|
||||
|
||||
func newConditionWithOperator(operator, field string, value interface{}) *condition {
|
||||
cond := &condition{
|
||||
Field: field,
|
||||
Value: value,
|
||||
Expr: operator,
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func NewQuery(db *gorm.DB, model any) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
modelValue: model,
|
||||
params: make([]interface{}, 0),
|
||||
orderBy: make([]string, 0),
|
||||
groupBy: make([]string, 0),
|
||||
joins: make([]join, 0),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,753 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.nobla.cn/golang/kos/util/arrays"
|
||||
"git.nobla.cn/golang/kos/util/reflection"
|
||||
"git.nobla.cn/golang/rest/inflector"
|
||||
"git.nobla.cn/golang/rest/types"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
modelEntities []*Model
|
||||
httpRouter types.HttpRouter
|
||||
hookMgr *hookManager
|
||||
timeKind = reflect.TypeOf(time.Time{}).Kind()
|
||||
timePtrKind = reflect.TypeOf(&time.Time{}).Kind()
|
||||
|
||||
matchEnums = []string{types.MatchExactly, types.MatchFuzzy}
|
||||
)
|
||||
|
||||
var (
|
||||
allowScenario = []string{types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
|
||||
)
|
||||
|
||||
func init() {
|
||||
hookMgr = &hookManager{}
|
||||
modelEntities = make([]*Model, 0)
|
||||
}
|
||||
|
||||
// cloneStmt 从指定的db克隆一个 Statement 对象
|
||||
func cloneStmt(db *gorm.DB) *gorm.Statement {
|
||||
return &gorm.Statement{
|
||||
DB: db,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
}
|
||||
|
||||
// dataTypeOf 推断数据的类型
|
||||
func dataTypeOf(field *schema.Field) string {
|
||||
var dataType string
|
||||
reflectType := field.FieldType
|
||||
for reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
if dataType = field.Tag.Get("type"); dataType != "" {
|
||||
return dataType
|
||||
}
|
||||
dataValue := reflect.Indirect(reflect.New(reflectType))
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
dataType = types.TypeBoolean
|
||||
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
dataType = types.TypeInteger
|
||||
case reflect.Float32, reflect.Float64:
|
||||
dataType = types.TypeFloat
|
||||
default:
|
||||
dataType = types.TypeString
|
||||
}
|
||||
return dataType
|
||||
}
|
||||
|
||||
// dataFormatOf 推断数据的格式
|
||||
func dataFormatOf(field *schema.Field) string {
|
||||
var format string
|
||||
format = field.Tag.Get("format")
|
||||
if format != "" {
|
||||
return format
|
||||
}
|
||||
//如果有枚举值,直接设置为下拉类型
|
||||
enum := field.Tag.Get("enum")
|
||||
if enum != "" {
|
||||
return types.FormatDropdown
|
||||
}
|
||||
reflectType := field.FieldType
|
||||
for reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
//时间处理
|
||||
dataValue := reflect.Indirect(reflect.New(reflectType))
|
||||
if field.Name == "CreatedAt" || field.Name == "UpdatedAt" || field.Name == "DeletedAt" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return types.FormatTimestamp
|
||||
default:
|
||||
return types.FormatDatetime
|
||||
}
|
||||
}
|
||||
if strings.Contains(strings.ToLower(field.Name), "pass") {
|
||||
return types.FormatPassword
|
||||
}
|
||||
switch dataValue.Kind() {
|
||||
case timeKind, timePtrKind:
|
||||
format = types.FormatDatetime
|
||||
case reflect.Bool:
|
||||
format = types.FormatBoolean
|
||||
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
format = types.FormatInteger
|
||||
case reflect.Float32, reflect.Float64:
|
||||
format = types.FormatFloat
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
format = types.FormatDatetime
|
||||
}
|
||||
default:
|
||||
if field.Size >= 1024 {
|
||||
format = types.FormatText
|
||||
} else {
|
||||
format = types.FormatString
|
||||
}
|
||||
}
|
||||
return format
|
||||
}
|
||||
|
||||
// fieldName 生成字段名称
|
||||
func fieldName(name string) string {
|
||||
tokens := strings.Split(name, "_")
|
||||
for i, s := range tokens {
|
||||
tokens[i] = strings.Title(s)
|
||||
}
|
||||
return strings.Join(tokens, " ")
|
||||
}
|
||||
|
||||
// fieldNative 判断是否为原始字段
|
||||
func fieldNative(field *schema.Field) uint8 {
|
||||
if _, ok := field.Tag.Lookup("virtual"); ok {
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// fieldRule 返回字段规则
|
||||
func fieldRule(field *schema.Field) types.Rule {
|
||||
r := types.Rule{
|
||||
Required: []string{},
|
||||
}
|
||||
if field.GORMDataType == schema.String {
|
||||
r.Max = field.Size
|
||||
}
|
||||
if field.GORMDataType == schema.Int || field.GORMDataType == schema.Float || field.GORMDataType == schema.Uint {
|
||||
r.Max = field.Scale
|
||||
}
|
||||
|
||||
rs := field.Tag.Get("rule")
|
||||
if rs != "" {
|
||||
ss := strings.Split(rs, ";")
|
||||
for _, s := range ss {
|
||||
vs := strings.SplitN(s, ":", 2)
|
||||
ls := len(vs)
|
||||
if ls == 0 {
|
||||
continue
|
||||
}
|
||||
switch vs[0] {
|
||||
case "required", "require":
|
||||
if ls > 1 {
|
||||
bs := strings.Split(vs[1], ",")
|
||||
for _, i := range bs {
|
||||
if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
|
||||
r.Required = append(r.Required, i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
r.Required = []string{types.ScenarioCreate, types.ScenarioUpdate}
|
||||
}
|
||||
case "unique":
|
||||
r.Unique = true
|
||||
case "regexp":
|
||||
if ls > 1 {
|
||||
r.Regular = vs[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if field.PrimaryKey {
|
||||
r.Unique = true
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// fieldScenario 字段Scenarios
|
||||
func fieldScenario(index int, field *schema.Field) types.Scenarios {
|
||||
var ss types.Scenarios
|
||||
if v, ok := field.Tag.Lookup("scenarios"); ok {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
ss = strings.Split(v, ";")
|
||||
}
|
||||
} else {
|
||||
if field.PrimaryKey {
|
||||
ss = []string{types.ScenarioList, types.ScenarioView, types.ScenarioExport}
|
||||
} else if field.Name == "CreatedAt" || field.Name == "UpdatedAt" {
|
||||
ss = []string{types.ScenarioList}
|
||||
} else if field.Name == "DeletedAt" || field.Name == "Namespace" {
|
||||
//不添加任何显示场景
|
||||
ss = []string{}
|
||||
} else {
|
||||
if index < 10 {
|
||||
//高级字段只配置一些简单的场景
|
||||
ss = []string{types.ScenarioSearch, types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
|
||||
} else {
|
||||
//高级字段只配置一些简单的场景
|
||||
ss = []string{types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
// fieldPosition 字段的排序位置
|
||||
func fieldPosition(field *schema.Field, i int) int {
|
||||
s := field.Tag.Get("position")
|
||||
n, _ := strconv.Atoi(s)
|
||||
if n > 0 {
|
||||
return n
|
||||
}
|
||||
return i + 100
|
||||
}
|
||||
|
||||
// fieldAttribute 字段属性
|
||||
func fieldAttribute(field *schema.Field) types.Attribute {
|
||||
attr := types.Attribute{
|
||||
Match: types.MatchFuzzy,
|
||||
PrimaryKey: field.PrimaryKey,
|
||||
DefaultValue: field.DefaultValue,
|
||||
Readonly: []string{},
|
||||
Disable: []string{},
|
||||
Visible: make([]types.VisibleCondition, 0),
|
||||
Values: make([]types.EnumValue, 0),
|
||||
Live: types.LiveValue{},
|
||||
}
|
||||
if field.Name == "CreatedAt" || field.Name == "UpdatedAt" {
|
||||
attr.EndOfNow = true
|
||||
}
|
||||
//赋值属性
|
||||
props := field.Tag.Get("props")
|
||||
if props != "" {
|
||||
vs := strings.Split(props, ";")
|
||||
for _, str := range vs {
|
||||
kv := strings.SplitN(str, ":", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
sv := strings.TrimSpace(kv[1])
|
||||
switch strings.ToLower(strings.TrimSpace(kv[0])) {
|
||||
case "icon":
|
||||
attr.Icon = sv
|
||||
case "match":
|
||||
if arrays.Exists(sv, matchEnums) {
|
||||
attr.Match = sv
|
||||
}
|
||||
case "endofnow", "end_of_now":
|
||||
if ok, _ := strconv.ParseBool(sv); ok {
|
||||
attr.EndOfNow = true
|
||||
}
|
||||
case "invisible":
|
||||
if ok, _ := strconv.ParseBool(sv); ok {
|
||||
attr.Invisible = true
|
||||
}
|
||||
case "suffix":
|
||||
attr.Suffix = sv
|
||||
case "tag":
|
||||
attr.Tag = sv
|
||||
case "tooltip":
|
||||
attr.Tooltip = sv
|
||||
case "uploadurl", "uploaduri", "upload_url", "upload_uri":
|
||||
attr.UploadUrl = sv
|
||||
case "description":
|
||||
attr.Description = sv
|
||||
case "readonly":
|
||||
bs := strings.Split(sv, ",")
|
||||
for _, i := range bs {
|
||||
if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
|
||||
attr.Readonly = append(attr.Readonly, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//live的赋值
|
||||
live := field.Tag.Get("live")
|
||||
if live != "" {
|
||||
attr.Live.Enable = true
|
||||
vs := strings.Split(live, ";")
|
||||
for _, str := range vs {
|
||||
kv := strings.SplitN(str, ":", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
switch kv[0] {
|
||||
case "method":
|
||||
attr.Live.Method = kv[1]
|
||||
case "type":
|
||||
if kv[1] == types.LiveTypeDropdown || kv[1] == types.LiveTypeCascader {
|
||||
attr.Live.Type = kv[1]
|
||||
} else {
|
||||
attr.Live.Type = types.LiveTypeDropdown
|
||||
}
|
||||
case "url", "uri":
|
||||
attr.Live.Url = kv[1]
|
||||
case "columns":
|
||||
attr.Live.Columns = strings.Split(kv[1], ",")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dropdown := field.Tag.Get("dropdown")
|
||||
if dropdown != "" {
|
||||
attr.DropdownOption = &types.DropdownOption{}
|
||||
vs := strings.Split(dropdown, ";")
|
||||
for _, str := range vs {
|
||||
kv := strings.SplitN(str, ":", 2)
|
||||
if len(kv) == 0 {
|
||||
continue
|
||||
}
|
||||
switch kv[0] {
|
||||
case "created":
|
||||
attr.DropdownOption.Created = true
|
||||
case "filterable":
|
||||
attr.DropdownOption.Filterable = true
|
||||
case "autocomplete":
|
||||
attr.DropdownOption.Autocomplete = true
|
||||
case "default_first":
|
||||
attr.DropdownOption.DefaultFirst = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//显示条件
|
||||
conditions := field.Tag.Get("condition")
|
||||
if conditions != "" {
|
||||
vs := strings.Split(conditions, ";")
|
||||
for _, str := range vs {
|
||||
kv := strings.SplitN(str, ":", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
cond := types.VisibleCondition{
|
||||
Column: kv[0],
|
||||
Values: make([]any, 0),
|
||||
}
|
||||
vv := strings.Split(kv[1], ",")
|
||||
for _, x := range vv {
|
||||
x = strings.TrimSpace(x)
|
||||
if x == "" {
|
||||
continue
|
||||
}
|
||||
cond.Values = append(cond.Values, x)
|
||||
}
|
||||
attr.Visible = append(attr.Visible, cond)
|
||||
}
|
||||
}
|
||||
//赋值枚举值
|
||||
enumns := field.Tag.Get("enum")
|
||||
if enumns != "" {
|
||||
vs := strings.Split(enumns, ";")
|
||||
for _, str := range vs {
|
||||
kv := strings.SplitN(str, ":", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
fv := types.EnumValue{Value: kv[0]}
|
||||
//颜色分隔符
|
||||
if pos := strings.IndexByte(kv[1], '#'); pos > -1 {
|
||||
fv.Label = kv[1][:pos]
|
||||
fv.Color = kv[1][pos:]
|
||||
} else {
|
||||
fv.Label = kv[1]
|
||||
}
|
||||
attr.Values = append(attr.Values, fv)
|
||||
}
|
||||
}
|
||||
if !field.Creatable {
|
||||
attr.Disable = append(attr.Disable, types.ScenarioCreate)
|
||||
}
|
||||
if !field.Updatable {
|
||||
attr.Disable = append(attr.Disable, types.ScenarioUpdate)
|
||||
}
|
||||
attr.Tooltip = field.Comment
|
||||
return attr
|
||||
}
|
||||
|
||||
// autoMigrate 自动合并字段
|
||||
func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (naming string, err error) {
|
||||
var (
|
||||
pos int
|
||||
columnName string
|
||||
columnIsExists bool
|
||||
columnLabel string
|
||||
schemas []*types.Schema
|
||||
models []*types.Schema
|
||||
stmt *gorm.Statement
|
||||
)
|
||||
|
||||
stmt = cloneStmt(db)
|
||||
if err = stmt.Parse(model); err != nil {
|
||||
return
|
||||
}
|
||||
if schemas, err = GetSchemas(ctx, db, defaultDomain, module, stmt.Table); err != nil {
|
||||
return
|
||||
}
|
||||
if len(schemas) > 0 {
|
||||
pos = len(schemas)
|
||||
}
|
||||
models = make([]*types.Schema, 0)
|
||||
for index, field := range stmt.Schema.Fields {
|
||||
columnName = field.DBName
|
||||
if columnName == "-" {
|
||||
continue
|
||||
}
|
||||
if columnName == "" {
|
||||
columnName = field.Name
|
||||
}
|
||||
columnIsExists = false
|
||||
for _, sm := range schemas {
|
||||
if sm.Column == columnName {
|
||||
columnIsExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if columnIsExists {
|
||||
continue
|
||||
}
|
||||
columnLabel = field.Tag.Get("comment")
|
||||
if columnLabel == "" {
|
||||
columnLabel = fieldName(field.DBName)
|
||||
}
|
||||
isPrimaryKey := uint8(0)
|
||||
if field.PrimaryKey {
|
||||
isPrimaryKey = 1
|
||||
}
|
||||
schemaModel := &types.Schema{
|
||||
Domain: defaultDomain,
|
||||
ModuleName: module,
|
||||
TableName: stmt.Table,
|
||||
Enable: 1,
|
||||
Column: columnName,
|
||||
Label: columnLabel,
|
||||
Type: strings.ToLower(dataTypeOf(field)),
|
||||
Format: strings.ToLower(dataFormatOf(field)),
|
||||
Native: fieldNative(field),
|
||||
IsPrimaryKey: isPrimaryKey,
|
||||
Rule: fieldRule(field),
|
||||
Scenarios: fieldScenario(index, field),
|
||||
Attribute: fieldAttribute(field),
|
||||
Position: fieldPosition(field, pos),
|
||||
}
|
||||
//如果启用了在线调取接口功能,那么设置一下字段的format格式
|
||||
if schemaModel.Attribute.Live.Enable {
|
||||
if schemaModel.Attribute.Live.Type != "" {
|
||||
schemaModel.Format = schemaModel.Attribute.Live.Type
|
||||
}
|
||||
}
|
||||
models = append(models, schemaModel)
|
||||
pos++
|
||||
}
|
||||
if len(models) > 0 {
|
||||
err = db.Create(models).Error
|
||||
}
|
||||
naming = stmt.Table
|
||||
return
|
||||
}
|
||||
|
||||
// SetHttpRouter 设置HTTP路由
|
||||
func SetHttpRouter(router types.HttpRouter) {
|
||||
httpRouter = router
|
||||
}
|
||||
|
||||
// AutoMigrate 自动合并表的schema
|
||||
func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) {
|
||||
var (
|
||||
opts *Options
|
||||
table string
|
||||
router types.HttpRouter
|
||||
)
|
||||
opts = &Options{}
|
||||
for _, cb := range cbs {
|
||||
cb(opts)
|
||||
}
|
||||
if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil {
|
||||
return
|
||||
}
|
||||
//路由模块处理
|
||||
modelValue := newModel(model, db, types.Naming{
|
||||
Pluralize: inflector.Pluralize(table),
|
||||
Singular: inflector.Singularize(table),
|
||||
ModuleName: opts.moduleName,
|
||||
TableName: table,
|
||||
})
|
||||
modelValue.hookMgr = hookMgr
|
||||
modelValue.schemaLookup = VisibleSchemas
|
||||
if opts.router != nil {
|
||||
router = opts.router
|
||||
}
|
||||
if router == nil && httpRouter != nil {
|
||||
router = httpRouter
|
||||
}
|
||||
if opts.urlPrefix != "" {
|
||||
modelValue.urlPrefix = opts.urlPrefix
|
||||
}
|
||||
//路由绑定操作
|
||||
if router != nil {
|
||||
if modelValue.hasScenario(types.ScenarioList) {
|
||||
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioList), modelValue.Search)
|
||||
}
|
||||
if modelValue.hasScenario(types.ScenarioCreate) {
|
||||
router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioCreate), modelValue.Create)
|
||||
}
|
||||
if modelValue.hasScenario(types.ScenarioUpdate) {
|
||||
router.Handle(http.MethodPut, modelValue.Uri(types.ScenarioUpdate), modelValue.Update)
|
||||
}
|
||||
if modelValue.hasScenario(types.ScenarioDelete) {
|
||||
router.Handle(http.MethodDelete, modelValue.Uri(types.ScenarioDelete), modelValue.Delete)
|
||||
}
|
||||
if modelValue.hasScenario(types.ScenarioView) {
|
||||
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioView), modelValue.View)
|
||||
}
|
||||
if modelValue.hasScenario(types.ScenarioExport) {
|
||||
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioExport), modelValue.Export)
|
||||
}
|
||||
if modelValue.hasScenario(types.ScenarioImport) {
|
||||
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioImport), modelValue.Import)
|
||||
router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioImport), modelValue.Import)
|
||||
}
|
||||
}
|
||||
if opts.writer != nil {
|
||||
modelValue.response = opts.writer
|
||||
}
|
||||
if opts.formatter != nil {
|
||||
modelValue.formatter = opts.formatter
|
||||
}
|
||||
modelValue.disableDomain = opts.disableDomain
|
||||
modelEntities = append(modelEntities, modelValue)
|
||||
return
|
||||
}
|
||||
|
||||
// CloneSchemas 克隆schemas
|
||||
func CloneSchemas(ctx context.Context, db *gorm.DB, domain string) (err error) {
|
||||
var (
|
||||
values []*types.Schema
|
||||
schemas []*types.Schema
|
||||
models []*types.Schema
|
||||
)
|
||||
tx := db.WithContext(ctx)
|
||||
if err = tx.Where("domain=?", defaultDomain).Find(&values).Error; err != nil {
|
||||
return fmt.Errorf("schema not found")
|
||||
}
|
||||
tx.Where("domain=?", domain).Find(&schemas)
|
||||
hasSchemaFunc := func(values []*types.Schema, hack *types.Schema) bool {
|
||||
for _, row := range values {
|
||||
if row.ModuleName == hack.ModuleName && row.TableName == hack.TableName && row.Column == hack.Column {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
models = make([]*types.Schema, 0)
|
||||
for _, row := range values {
|
||||
if !hasSchemaFunc(schemas, row) {
|
||||
row.Id = 0
|
||||
row.CreatedAt = 0
|
||||
row.UpdatedAt = 0
|
||||
row.Domain = domain
|
||||
models = append(models, row)
|
||||
}
|
||||
}
|
||||
if len(models) > 0 {
|
||||
err = tx.Save(models).Error
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetSchemas 获取schemas
|
||||
func GetSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName string) ([]*types.Schema, error) {
|
||||
var (
|
||||
err error
|
||||
values []*types.Schema
|
||||
tx *gorm.DB
|
||||
)
|
||||
values = make([]*types.Schema, 0)
|
||||
if domain == "" {
|
||||
domain = defaultDomain
|
||||
}
|
||||
if moduleName == "" || tableName == "" {
|
||||
return nil, gorm.ErrInvalidField
|
||||
}
|
||||
if ctx != nil {
|
||||
tx = db.WithContext(ctx)
|
||||
} else {
|
||||
tx = db
|
||||
}
|
||||
err = tx.Where("domain=? AND module_name=? AND table_name=?", domain, moduleName, tableName).Order("position ASC").Find(&values).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
err = nil
|
||||
}
|
||||
return values, err
|
||||
}
|
||||
|
||||
// VisibleSchemas 获取某个场景下面的schema
|
||||
func VisibleSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName, scenario string) ([]*types.Schema, error) {
|
||||
schemas, err := GetSchemas(ctx, db, domain, moduleName, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]*types.Schema, 0, len(schemas))
|
||||
for _, row := range schemas {
|
||||
if row.IsPrimaryKey == 1 {
|
||||
result = append(result, row)
|
||||
continue
|
||||
}
|
||||
if row.Scenarios.Has(scenario) {
|
||||
result = append(result, row)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ModelTypes 查询指定模型的类型
|
||||
func ModelTypes(ctx context.Context, db *gorm.DB, model any, domainName, labelColumn, valueColumn string) (values []*types.TypeValue) {
|
||||
tx := db.WithContext(ctx)
|
||||
result := make([]map[string]any, 0, 10)
|
||||
tx.Model(model).Select(labelColumn, valueColumn).Where("domain=?", domainName).Scan(&result)
|
||||
values = make([]*types.TypeValue, 0, len(result))
|
||||
for _, pairs := range result {
|
||||
feed := &types.TypeValue{}
|
||||
for k, v := range pairs {
|
||||
if k == labelColumn {
|
||||
feed.Label = v
|
||||
}
|
||||
if k == valueColumn {
|
||||
feed.Value = v
|
||||
}
|
||||
}
|
||||
values = append(values, feed)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// GetFieldValue 获取模型某个字段的值
|
||||
func GetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string) interface{} {
|
||||
var (
|
||||
idx = -1
|
||||
)
|
||||
refVal := reflect.Indirect(refValue)
|
||||
for i, field := range stmt.Schema.Fields {
|
||||
if field.DBName == column || field.Name == column {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx > -1 {
|
||||
return refVal.Field(idx).Interface()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetFieldValue 设置模型某个字段的值
|
||||
func SetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) {
|
||||
var (
|
||||
idx = -1
|
||||
)
|
||||
refVal := reflect.Indirect(refValue)
|
||||
for i, field := range stmt.Schema.Fields {
|
||||
if field.DBName == column || field.Name == column {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx > -1 {
|
||||
refVal.Field(idx).Set(reflect.ValueOf(value))
|
||||
}
|
||||
}
|
||||
|
||||
func SafeSetFileValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) {
|
||||
var (
|
||||
idx = -1
|
||||
)
|
||||
refVal := reflect.Indirect(refValue)
|
||||
for i, field := range stmt.Schema.Fields {
|
||||
if field.DBName == column || field.Name == column {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx > -1 {
|
||||
_ = reflection.Assign(refVal.Field(idx), value)
|
||||
}
|
||||
}
|
||||
|
||||
// GetModels 获取所有注册了的模块
|
||||
func GetModels() []*Model {
|
||||
return modelEntities
|
||||
}
|
||||
|
||||
// OnBeforeCreate 创建前的回调
|
||||
func OnBeforeCreate(cb BeforeCreate) {
|
||||
hookMgr.BeforeCreate(cb)
|
||||
}
|
||||
|
||||
// OnAfterCreate 创建后的回调
|
||||
func OnAfterCreate(cb AfterCreate) {
|
||||
hookMgr.AfterCreate(cb)
|
||||
}
|
||||
|
||||
// OnBeforeUpdate 更新前的回调
|
||||
func OnBeforeUpdate(cb BeforeUpdate) {
|
||||
hookMgr.BeforeUpdate(cb)
|
||||
}
|
||||
|
||||
// OnAfterUpdate 更新后的回调
|
||||
func OnAfterUpdate(cb AfterUpdate) {
|
||||
hookMgr.AfterUpdate(cb)
|
||||
}
|
||||
|
||||
// OnBeforeSave 保存前的回调
|
||||
func OnBeforeSave(cb BeforeSave) {
|
||||
hookMgr.BeforeSave(cb)
|
||||
}
|
||||
|
||||
// OnAfterSave 保存后的回调
|
||||
func OnAfterSave(cb AfterSave) {
|
||||
hookMgr.AfterSave(cb)
|
||||
}
|
||||
|
||||
// OnBeforeDelete 删除前的回调
|
||||
func OnBeforeDelete(cb BeforeDelete) {
|
||||
hookMgr.BeforeDelete(cb)
|
||||
}
|
||||
|
||||
// OnAfterDelete 删除后的回调
|
||||
func OnAfterDelete(cb AfterDelete) {
|
||||
hookMgr.AfterDelete(cb)
|
||||
}
|
||||
|
||||
// OnAfterExport 导出后的回调
|
||||
func OnAfterExport(cb AfterExport) {
|
||||
hookMgr.AfterExport(cb)
|
||||
}
|
||||
|
||||
// OnAfterImport 导入后的回调
|
||||
func OnAfterImport(cb AfterImport) {
|
||||
hookMgr.AfterImport(cb)
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPageSize = 15
|
||||
defaultDomain = "localhost"
|
||||
)
|
||||
|
||||
type (
|
||||
httpWriter struct {
|
||||
}
|
||||
|
||||
multiValue struct {
|
||||
Text any `json:"label"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
stdResponse struct {
|
||||
Code int `json:"code"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
tableNamer interface {
|
||||
HttpTableName(req *http.Request) string
|
||||
}
|
||||
|
||||
//创建后的回调,这个回调不在事物内
|
||||
afterCreated interface {
|
||||
AfterCreated(ctx context.Context, tx *gorm.DB)
|
||||
}
|
||||
//更新后的回调,这个回调不在事物内
|
||||
afterUpdated interface {
|
||||
AfterUpdated(ctx context.Context, tx *gorm.DB)
|
||||
}
|
||||
//保存后的回调,这个回调不在事物内
|
||||
afterSaved interface {
|
||||
AfterSaved(ctx context.Context, tx *gorm.DB)
|
||||
}
|
||||
|
||||
sqlCountResponse struct {
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
)
|
||||
|
||||
func (h *httpWriter) Success(w http.ResponseWriter, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(stdResponse{
|
||||
Code: 0,
|
||||
Data: data,
|
||||
}); err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *httpWriter) Failure(w http.ResponseWriter, reason int, message string, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(stdResponse{
|
||||
Code: reason,
|
||||
Reason: message,
|
||||
Data: data,
|
||||
}); err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func defaultValueLookup(field string, w http.ResponseWriter, r *http.Request) string {
|
||||
var (
|
||||
domainName string
|
||||
)
|
||||
domainName = r.Header.Get(field)
|
||||
if domainName == "" {
|
||||
return defaultDomain
|
||||
}
|
||||
return domainName
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type (
|
||||
Attribute struct {
|
||||
Match string `json:"match"` //匹配模式
|
||||
Tag string `json:"tag,omitempty"` //字段标签
|
||||
PrimaryKey bool `json:"primary_key"` //是否为主键
|
||||
DefaultValue string `json:"default_value"` //默认值
|
||||
Readonly []string `json:"readonly"` //只读场景
|
||||
Disable []string `json:"disable"` //禁用场景
|
||||
Visible []VisibleCondition `json:"visible"` //可见条
|
||||
Invisible bool `json:"invisible"` //不可见的字段,表示在UI界面看不到这个字段,但是这个字段需要
|
||||
EndOfNow bool `json:"end_of_now"` //最大时间为当前时间
|
||||
Values []EnumValue `json:"values"` //值
|
||||
Live LiveValue `json:"live"` //延时加载配置
|
||||
UploadUrl string `json:"upload_url,omitempty"` //上传地址
|
||||
Icon string `json:"icon,omitempty"` //显示图标
|
||||
Sort bool `json:"sort"` //是否允许排序
|
||||
Suffix string `json:"suffix,omitempty"` //追加内容
|
||||
Tooltip string `json:"tooltip,omitempty"` //字段提示信息
|
||||
Description string `json:"description,omitempty"` //字段说明信息
|
||||
DropdownOption *DropdownOption `json:"dropdown,omitempty"` //下拉选项
|
||||
}
|
||||
)
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *Attribute) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
switch s := value.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(s), n)
|
||||
case []byte:
|
||||
return json.Unmarshal(s, n)
|
||||
}
|
||||
return ErrDbTypeUnsupported
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n Attribute) Value() (driver.Value, error) {
|
||||
return json.Marshal(n)
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
package types
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrDbTypeUnsupported = errors.New("database type unsupported")
|
||||
)
|
|
@ -0,0 +1,39 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Rule struct {
|
||||
Min int `json:"min"`
|
||||
Max int `json:"max"`
|
||||
Type string `json:"type"`
|
||||
Unique bool `json:"unique"`
|
||||
Required []string `json:"required"`
|
||||
Regular string `json:"regular"`
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n Scenarios) Value() (driver.Value, error) {
|
||||
return json.Marshal(n)
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *Rule) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
switch s := value.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(s), n)
|
||||
case []byte:
|
||||
return json.Unmarshal(s, n)
|
||||
}
|
||||
return ErrDbTypeUnsupported
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n Rule) Value() (driver.Value, error) {
|
||||
return json.Marshal(n)
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type (
|
||||
LiveValue struct {
|
||||
Enable bool `json:"enable"`
|
||||
Type string `json:"type"`
|
||||
Url string `json:"url"`
|
||||
Method string `json:"method"`
|
||||
Body string `json:"body"`
|
||||
ContentType string `json:"content_type"`
|
||||
Columns []string `json:"columns"`
|
||||
}
|
||||
|
||||
EnumValue struct {
|
||||
Label string `json:"label"`
|
||||
Value string `json:"value"`
|
||||
Color string `json:"color"`
|
||||
}
|
||||
|
||||
VisibleCondition struct {
|
||||
Column string `json:"column"`
|
||||
Values []interface{} `json:"values"`
|
||||
}
|
||||
|
||||
DropdownOption struct {
|
||||
Created bool `json:"created,omitempty"`
|
||||
Filterable bool `json:"filterable,omitempty"`
|
||||
Autocomplete bool `json:"autocomplete,omitempty"`
|
||||
DefaultFirst bool `json:"default_first,omitempty"`
|
||||
}
|
||||
|
||||
Scenarios []string
|
||||
|
||||
Schema struct {
|
||||
Id uint64 `json:"id" gorm:"primary_key"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"autoCreateTime"` //创建时间
|
||||
UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime"` //更新时间
|
||||
Domain string `json:"domain" gorm:"column:domain;type:char(60);index"` //域
|
||||
ModuleName string `json:"module_name" gorm:"column:module_name;type:varchar(60);index"` //模块名称
|
||||
TableName string `json:"table_name" gorm:"column:table_name;type:varchar(120);index"` //表名称
|
||||
Enable uint8 `json:"enable" gorm:"column:enable;type:int(1)"` //是否启用
|
||||
Column string `json:"column" gorm:"type:varchar(120)"` //字段名称
|
||||
Label string `json:"label" gorm:"type:varchar(120)"` //显示名称
|
||||
Type string `json:"type" gorm:"type:varchar(120)"` //字段类型
|
||||
Format string `json:"format" gorm:"type:varchar(120)"` //字段格式
|
||||
Native uint8 `json:"native" gorm:"type:int(1)"` //是否为原生字段
|
||||
IsPrimaryKey uint8 `json:"is_primary_key" gorm:"type:int(1)"` //是否为主键
|
||||
Expression string `json:"expression" gorm:"type:varchar(526)"` //计算规则
|
||||
Scenarios Scenarios `json:"scenarios" gorm:"type:varchar(120)"` //场景
|
||||
Rule Rule `json:"rule" gorm:"type:varchar(2048)"` //字段规则
|
||||
Attribute Attribute `json:"attribute" gorm:"type:varchar(4096)"` //字段属性
|
||||
Position int `json:"position"` //字段排序位置
|
||||
}
|
||||
)
|
||||
|
||||
func (n Scenarios) Has(str string) bool {
|
||||
for _, v := range n {
|
||||
if v == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *Scenarios) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
switch s := value.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(s), n)
|
||||
case []byte:
|
||||
return json.Unmarshal(s, n)
|
||||
}
|
||||
return ErrDbTypeUnsupported
|
||||
}
|
|
@ -0,0 +1,231 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
SceneCreate = "create"
|
||||
SceneUpdate = "update"
|
||||
SceneDelete = "delete"
|
||||
)
|
||||
|
||||
const (
|
||||
ScenarioCreate = "create"
|
||||
ScenarioUpdate = "update"
|
||||
ScenarioDelete = "delete"
|
||||
ScenarioSearch = "search"
|
||||
ScenarioExport = "export"
|
||||
ScenarioImport = "import"
|
||||
ScenarioList = "list"
|
||||
ScenarioView = "view"
|
||||
ScenarioMapping = "mapping"
|
||||
)
|
||||
|
||||
const (
|
||||
MatchExactly = "exactly" //精确匹配
|
||||
MatchFuzzy = "fuzzy" //模糊匹配
|
||||
)
|
||||
|
||||
const (
|
||||
LiveTypeDropdown = "dropdown"
|
||||
LiveTypeCascader = "cascader"
|
||||
)
|
||||
|
||||
const (
|
||||
TypeInteger = "integer"
|
||||
TypeFloat = "float"
|
||||
TypeBoolean = "boolean"
|
||||
TypeString = "string"
|
||||
)
|
||||
|
||||
const (
|
||||
FormatInteger = "integer"
|
||||
FormatFloat = "float"
|
||||
FormatBoolean = "boolean"
|
||||
FormatString = "string"
|
||||
FormatText = "text"
|
||||
FormatDropdown = "dropdown"
|
||||
FormatDatetime = "datetime"
|
||||
FormatDate = "date"
|
||||
FormatTime = "time"
|
||||
FormatTimestamp = "timestamp"
|
||||
FormatPassword = "password"
|
||||
)
|
||||
|
||||
const (
|
||||
OperatorEqual = "eq"
|
||||
OperatorGreaterThan = "gt"
|
||||
OperatorGreaterEqual = "ge"
|
||||
OperatorLessThan = "lt"
|
||||
OperatorLessEqual = "le"
|
||||
OperatorLike = "like"
|
||||
OperatorBetween = "between"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrImportFileNotExists = 1004
|
||||
ErrImportFileUnavailable = 1005
|
||||
ErrImportColumnNotMatch = 1006
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPageSize = 15
|
||||
defaultDomain = "localhost"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestDenied = 8005
|
||||
RequestRecordNotFound = 8004
|
||||
RequestPayloadInvalid = 8006
|
||||
RequestCreateFailure = 8007
|
||||
RequestUpdateFailure = 8008
|
||||
RequestDeleteFailure = 8009
|
||||
)
|
||||
|
||||
const (
|
||||
FormatRaw = "raw"
|
||||
FormatBoth = "both"
|
||||
)
|
||||
|
||||
const (
|
||||
FieldDomain = "domain"
|
||||
)
|
||||
|
||||
type (
|
||||
// HttpWriter http 响应接口
|
||||
HttpWriter interface {
|
||||
Success(w http.ResponseWriter, data any)
|
||||
Failure(w http.ResponseWriter, reason int, message string, data any)
|
||||
}
|
||||
|
||||
// HttpRouter http 路由管理工具
|
||||
HttpRouter interface {
|
||||
Handle(method string, uri string, handler http.HandlerFunc)
|
||||
}
|
||||
|
||||
// TypeValue 键值对数据
|
||||
TypeValue struct {
|
||||
Label any `json:"label"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
// NestedValue 层级数据
|
||||
NestedValue struct {
|
||||
Label any `json:"label"`
|
||||
Value any `json:"value"`
|
||||
Children []*NestedValue `json:"children,omitempty"`
|
||||
}
|
||||
|
||||
//ValueLookupFunc 查找域的函数
|
||||
ValueLookupFunc func(column string, w http.ResponseWriter, r *http.Request) string
|
||||
|
||||
//SchemaLookupFunc 查找schema的回调函数
|
||||
SchemaLookupFunc func(ctx context.Context, db *gorm.DB, domain, module, table, scenario string) ([]*Schema, error)
|
||||
)
|
||||
|
||||
type (
|
||||
multiValue struct {
|
||||
Text any `json:"label"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
Tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
Reporter interface {
|
||||
TableName() string
|
||||
RealTable() string
|
||||
QuoteColumn(ctx context.Context, column string) string
|
||||
GroupBy(ctx context.Context) []string
|
||||
}
|
||||
|
||||
FormatModel interface {
|
||||
Format(ctx context.Context, scene string, schemas []*Schema)
|
||||
}
|
||||
|
||||
SelectColumn struct {
|
||||
Name string `json:"name"`
|
||||
Expr string `json:"expr"`
|
||||
Native bool `json:"native"`
|
||||
Callback string `json:"callback"`
|
||||
}
|
||||
|
||||
DiffAttr struct {
|
||||
Column string `json:"column"`
|
||||
Label string `json:"label"`
|
||||
OldValue any `json:"old_value"`
|
||||
NewValue any `json:"new_value"`
|
||||
}
|
||||
|
||||
Condition struct {
|
||||
Column string `json:"column"`
|
||||
Expr string `json:"expr"`
|
||||
Value any `json:"value,omitempty"`
|
||||
Values []any `json:"values,omitempty"`
|
||||
}
|
||||
|
||||
Naming struct {
|
||||
Pluralize string
|
||||
Singular string
|
||||
ModuleName string
|
||||
TableName string
|
||||
}
|
||||
|
||||
RuntimeScope struct {
|
||||
Domain string //域
|
||||
User string //用户
|
||||
ModuleName string //模块名称
|
||||
TableName string //表名称
|
||||
Scenario string //场景名称
|
||||
Schemas []*Schema //字段schema
|
||||
Request *http.Request //HTTP请求结构
|
||||
PrimaryKeyValue any //主键
|
||||
}
|
||||
|
||||
ImportResult struct {
|
||||
Code int `json:"code"`
|
||||
TotalCount int `json:"total_count"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
UploadFile string `json:"upload_file"`
|
||||
FailureFile string `json:"failure_file"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
}
|
||||
|
||||
sqlCountResponse struct {
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
ListResponse struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"pagesize"`
|
||||
TotalCount int64 `json:"totalCount"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
CreateResponse struct {
|
||||
ID any `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
UpdateResponse struct {
|
||||
ID any `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
DeleteResponse struct {
|
||||
ID any `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
ImportResponse struct {
|
||||
UID string `json:"uid"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
)
|
|
@ -0,0 +1,45 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"git.nobla.cn/golang/kos/util/arrays"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func hasToken(hack string, need string) bool {
|
||||
if len(need) == 0 || len(hack) == 0 {
|
||||
return false
|
||||
}
|
||||
char := []byte{',', ';'}
|
||||
for _, c := range char {
|
||||
if strings.IndexByte(need, c) > -1 {
|
||||
return arrays.Exists(hack, strings.Split(need, string(c)))
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isEmpty(val any) bool {
|
||||
if val == nil {
|
||||
return true
|
||||
}
|
||||
v := reflect.ValueOf(val)
|
||||
switch v.Kind() {
|
||||
case reflect.String, reflect.Array:
|
||||
return v.Len() == 0
|
||||
case reflect.Map, reflect.Slice:
|
||||
return v.IsNil() || v.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return v.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return v.IsNil()
|
||||
default:
|
||||
return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue