初始化仓库
This commit is contained in:
commit
61ffcec858
|
@ -0,0 +1,60 @@
|
||||||
|
bin/
|
||||||
|
|
||||||
|
.svn/
|
||||||
|
.godeps
|
||||||
|
./build
|
||||||
|
.cover/
|
||||||
|
dist
|
||||||
|
_site
|
||||||
|
_posts
|
||||||
|
*.dat
|
||||||
|
.vscode
|
||||||
|
vendor
|
||||||
|
|
||||||
|
# Go.gitignore
|
||||||
|
|
||||||
|
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||||
|
*.o
|
||||||
|
*.a
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Folders
|
||||||
|
_obj
|
||||||
|
_test
|
||||||
|
storage
|
||||||
|
.idea
|
||||||
|
|
||||||
|
# Architecture specific extensions/prefixes
|
||||||
|
*.[568vq]
|
||||||
|
[568vq].out
|
||||||
|
|
||||||
|
*.cgo1.go
|
||||||
|
*.cgo2.c
|
||||||
|
_cgo_defun.c
|
||||||
|
_cgo_gotypes.go
|
||||||
|
_cgo_export.*
|
||||||
|
|
||||||
|
_testmain.go
|
||||||
|
|
||||||
|
*.exe
|
||||||
|
*.local
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
profile
|
||||||
|
|
||||||
|
# vim stuff
|
||||||
|
*.sw[op]
|
||||||
|
|
||||||
|
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
|
||||||
|
node_modules
|
|
@ -0,0 +1,144 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"git.nobla.cn/golang/kos/util/arrays"
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func findCondition(schema *types.Schema, conditions []*types.Condition) *types.Condition {
|
||||||
|
for _, cond := range conditions {
|
||||||
|
if cond.Column == schema.Column {
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildConditions(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
skip bool
|
||||||
|
formValue string
|
||||||
|
activeQuery ActiveQuery
|
||||||
|
)
|
||||||
|
if activeQuery, ok = query.Model().(ActiveQuery); ok {
|
||||||
|
if err = activeQuery.BeforeQuery(ctx, query); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if arrays.Exists(r.Method, []string{http.MethodPut, http.MethodPost}) {
|
||||||
|
conditions := make([]*types.Condition, 0)
|
||||||
|
if err = json.NewDecoder(r.Body).Decode(&conditions); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, row := range schemas {
|
||||||
|
if row.Native == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cond := findCondition(row, conditions)
|
||||||
|
if cond == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch row.Format {
|
||||||
|
case types.FormatInteger, types.FormatFloat, types.FormatTimestamp, types.FormatDatetime, types.FormatDate, types.FormatTime:
|
||||||
|
switch cond.Expr {
|
||||||
|
case types.OperatorBetween:
|
||||||
|
if len(cond.Values) == 2 {
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Values[0]).WithExpr(">="))
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Values[1]).WithExpr("<="))
|
||||||
|
}
|
||||||
|
case types.OperatorGreaterThan:
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">"))
|
||||||
|
case types.OperatorGreaterEqual:
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">="))
|
||||||
|
case types.OperatorLessThan:
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<"))
|
||||||
|
case types.OperatorLessEqual:
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<="))
|
||||||
|
default:
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Value))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
switch cond.Expr {
|
||||||
|
case types.OperatorLike:
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("LIKE"))
|
||||||
|
default:
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, cond.Value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
qs := r.URL.Query()
|
||||||
|
for _, row := range schemas {
|
||||||
|
skip = false
|
||||||
|
if skip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if row.Native == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
formValue = qs.Get(row.Column)
|
||||||
|
switch row.Format {
|
||||||
|
case types.FormatString, types.FormatText:
|
||||||
|
if row.Attribute.Match == types.MatchExactly {
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||||
|
} else {
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE"))
|
||||||
|
}
|
||||||
|
case types.FormatTime, types.FormatDate, types.FormatDatetime, types.FormatTimestamp:
|
||||||
|
var sep string
|
||||||
|
seps := []byte{',', '/'}
|
||||||
|
for _, s := range seps {
|
||||||
|
if strings.IndexByte(formValue, s) > -1 {
|
||||||
|
sep = string(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ss := strings.Split(formValue, sep); len(ss) == 2 {
|
||||||
|
query.AndFilterWhere(
|
||||||
|
newCondition(row.Column, strings.TrimSpace(ss[0])).WithExpr(">="),
|
||||||
|
newCondition(row.Column, strings.TrimSpace(ss[1])).WithExpr("<="),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||||
|
}
|
||||||
|
case types.FormatInteger, types.FormatFloat:
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||||
|
default:
|
||||||
|
if row.Type == types.TypeString {
|
||||||
|
if row.Attribute.Match == types.MatchExactly {
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||||
|
} else {
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sortPar := r.FormValue("sort")
|
||||||
|
if sortPar != "" {
|
||||||
|
sorts := strings.Split(sortPar, ",")
|
||||||
|
for _, s := range sorts {
|
||||||
|
if s[0] == '-' {
|
||||||
|
query.OrderBy(s[1:], "DESC")
|
||||||
|
} else {
|
||||||
|
if s[0] == '+' {
|
||||||
|
query.OrderBy(s[1:], "ASC")
|
||||||
|
} else {
|
||||||
|
query.OrderBy(s, "ASC")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if activeQuery, ok = query.Model().(ActiveQuery); ok {
|
||||||
|
if err = activeQuery.AfterQuery(ctx, query); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,240 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
DefaultFormatter = NewFormatter()
|
||||||
|
|
||||||
|
DefaultNullDisplay = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
DefaultFormatter.Register("string", stringFormat)
|
||||||
|
DefaultFormatter.Register("integer", integerFormat)
|
||||||
|
DefaultFormatter.Register("decimal", decimalFormat)
|
||||||
|
DefaultFormatter.Register("date", dateFormat)
|
||||||
|
DefaultFormatter.Register("time", timeFormat)
|
||||||
|
DefaultFormatter.Register("datetime", datetimeFormat)
|
||||||
|
DefaultFormatter.Register("duration", durationFormat)
|
||||||
|
DefaultFormatter.Register("dropdown", dropdownFormat)
|
||||||
|
DefaultFormatter.Register("timestamp", datetimeFormat)
|
||||||
|
DefaultFormatter.Register("percentage", percentageFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
type FormatFunc func(ctx context.Context, value any, model any, scm *types.Schema) any
|
||||||
|
|
||||||
|
type Formatter struct {
|
||||||
|
callbacks sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func (formatter *Formatter) Register(f string, fun FormatFunc) {
|
||||||
|
formatter.callbacks.Store(f, fun)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (formatter *Formatter) Format(ctx context.Context, format string, value any, model any, scm *types.Schema) any {
|
||||||
|
v, ok := formatter.callbacks.Load(format)
|
||||||
|
if ok {
|
||||||
|
return v.(FormatFunc)(ctx, value, model, scm)
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (formatter *Formatter) getModelValue(refValue reflect.Value, schema *types.Schema, stmt *gorm.Statement) any {
|
||||||
|
if stmt.Schema == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
field := stmt.Schema.LookUpField(schema.Column)
|
||||||
|
if field == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return refValue.FieldByName(field.Name).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (formatter *Formatter) formatModel(ctx context.Context, refValue reflect.Value, schemas []*types.Schema, stmt *gorm.Statement, format string) any {
|
||||||
|
values := make(map[string]any)
|
||||||
|
multiValues := make(map[string]multiValue)
|
||||||
|
modelValue := refValue.Interface()
|
||||||
|
refValue = reflect.Indirect(refValue)
|
||||||
|
for _, scm := range schemas {
|
||||||
|
switch format {
|
||||||
|
case types.FormatRaw:
|
||||||
|
values[scm.Column] = formatter.getModelValue(refValue, scm, stmt)
|
||||||
|
case types.FormatBoth:
|
||||||
|
v := multiValue{
|
||||||
|
Value: formatter.getModelValue(refValue, scm, stmt),
|
||||||
|
}
|
||||||
|
v.Text = formatter.Format(ctx, scm.Format, v.Value, modelValue, scm)
|
||||||
|
multiValues[scm.Column] = v
|
||||||
|
default:
|
||||||
|
values[scm.Column] = formatter.Format(ctx, scm.Format, formatter.getModelValue(refValue, scm, stmt), modelValue, scm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if format == types.FormatBoth {
|
||||||
|
return multiValues
|
||||||
|
} else {
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (formatter *Formatter) formatModels(ctx context.Context, val any, schemas []*types.Schema, stmt *gorm.Statement, format string) any {
|
||||||
|
refValue := reflect.Indirect(reflect.ValueOf(val))
|
||||||
|
if refValue.Kind() != reflect.Slice {
|
||||||
|
return []any{}
|
||||||
|
}
|
||||||
|
length := refValue.Len()
|
||||||
|
values := make([]any, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
rowValue := refValue.Index(i)
|
||||||
|
modelValue := rowValue.Interface()
|
||||||
|
if formatModel, ok := modelValue.(types.FormatModel); ok {
|
||||||
|
formatModel.Format(ctx, format, schemas)
|
||||||
|
}
|
||||||
|
values[i] = formatter.formatModel(ctx, rowValue, schemas, stmt, format)
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
return fmt.Sprint(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func integerFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
var (
|
||||||
|
n int
|
||||||
|
)
|
||||||
|
switch value.(type) {
|
||||||
|
case float32, float64:
|
||||||
|
n = int(reflect.ValueOf(value).Float())
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
n = int(reflect.ValueOf(value).Int())
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
n = int(reflect.ValueOf(value).Uint())
|
||||||
|
case string:
|
||||||
|
n, _ = strconv.Atoi(reflect.ValueOf(value).String())
|
||||||
|
case []byte:
|
||||||
|
n, _ = strconv.Atoi(string(reflect.ValueOf(value).Bytes()))
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func decimalFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
var (
|
||||||
|
n float64
|
||||||
|
)
|
||||||
|
switch value.(type) {
|
||||||
|
case float32, float64:
|
||||||
|
n = reflect.ValueOf(value).Float()
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
n = float64(reflect.ValueOf(value).Int())
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
n = float64(reflect.ValueOf(value).Uint())
|
||||||
|
case string:
|
||||||
|
n, _ = strconv.ParseFloat(reflect.ValueOf(value).String(), 64)
|
||||||
|
case []byte:
|
||||||
|
n, _ = strconv.ParseFloat(string(reflect.ValueOf(value).Bytes()), 64)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func dateFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
if t, ok := value.(time.Time); ok {
|
||||||
|
return t.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
if t, ok := value.(*sql.NullTime); ok {
|
||||||
|
if t != nil && t.Valid {
|
||||||
|
return t.Time.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if t, ok := value.(int64); ok {
|
||||||
|
tm := time.Unix(t, 0)
|
||||||
|
return tm.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
return DefaultNullDisplay
|
||||||
|
}
|
||||||
|
|
||||||
|
func timeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
if t, ok := value.(time.Time); ok {
|
||||||
|
return t.Format("15:04:05")
|
||||||
|
}
|
||||||
|
if t, ok := value.(*sql.NullTime); ok {
|
||||||
|
if t != nil && t.Valid {
|
||||||
|
return t.Time.Format("15:04:05")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if t, ok := value.(int64); ok {
|
||||||
|
tm := time.Unix(t, 0)
|
||||||
|
return tm.Format("15:04:05")
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func datetimeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
if t, ok := value.(time.Time); ok {
|
||||||
|
return t.Format("2006-01-02 15:04:05")
|
||||||
|
}
|
||||||
|
if t, ok := value.(*sql.NullTime); ok {
|
||||||
|
if t != nil && t.Valid {
|
||||||
|
return t.Time.Format("2006-01-02 15:04:05")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if t, ok := value.(int64); ok {
|
||||||
|
if t > 0 {
|
||||||
|
tm := time.Unix(t, 0)
|
||||||
|
return tm.Format("2006-01-02 15:04:05")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return DefaultNullDisplay
|
||||||
|
}
|
||||||
|
|
||||||
|
func percentageFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
n := decimalFormat(ctx, value, model, schema).(float64)
|
||||||
|
if n <= 1 {
|
||||||
|
return fmt.Sprintf("%.2f%%", n*100)
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("%.2f%%", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func durationFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
var (
|
||||||
|
hour int
|
||||||
|
minVal int
|
||||||
|
sec int
|
||||||
|
)
|
||||||
|
n := integerFormat(ctx, value, model, schema).(int)
|
||||||
|
hour = n / 3600
|
||||||
|
minVal = (n - hour*3600) / 60
|
||||||
|
sec = n - hour*3600 - minVal*60
|
||||||
|
return fmt.Sprintf("%02d:%02d:%02d", hour, minVal, sec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dropdownFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
|
||||||
|
attributes := schema.Attribute
|
||||||
|
if attributes.Values != nil {
|
||||||
|
for _, v := range attributes.Values {
|
||||||
|
if v.Value == value {
|
||||||
|
return v.Label
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFormatter() *Formatter {
|
||||||
|
formatter := &Formatter{}
|
||||||
|
return formatter
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterFormat(f string, cb FormatFunc) {
|
||||||
|
DefaultFormatter.Register(f, cb)
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
module git.nobla.cn/golang/rest
|
||||||
|
|
||||||
|
go 1.22.9
|
||||||
|
|
||||||
|
require (
|
||||||
|
git.nobla.cn/golang/kos v0.1.32
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0
|
||||||
|
github.com/go-playground/validator/v10 v10.23.0
|
||||||
|
github.com/longbridgeapp/sqlparser v0.3.2
|
||||||
|
github.com/rs/xid v1.6.0
|
||||||
|
github.com/uole/sqlparser v0.0.1
|
||||||
|
gorm.io/gorm v1.25.12
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
golang.org/x/crypto v0.19.0 // indirect
|
||||||
|
golang.org/x/net v0.21.0 // indirect
|
||||||
|
golang.org/x/sys v0.17.0 // indirect
|
||||||
|
golang.org/x/text v0.14.0 // indirect
|
||||||
|
)
|
|
@ -0,0 +1,46 @@
|
||||||
|
git.nobla.cn/golang/kos v0.1.32 h1:sFVCA7vKc8dPUd0cxzwExOSPX2mmMh2IuwL6cYS1pBc=
|
||||||
|
git.nobla.cn/golang/kos v0.1.32/go.mod h1:35Z070+5oB39WcVrh5DDlnVeftL/Ccmscw2MZFe9fUg=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
|
github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL8sThn8IHr/sO+o=
|
||||||
|
github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||||
|
github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M=
|
||||||
|
github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8=
|
||||||
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
|
github.com/longbridgeapp/sqlparser v0.3.2 h1:FV0dgMiv8VcksT3p10hJeqfPs8bodoehmUJ7MhBds+Y=
|
||||||
|
github.com/longbridgeapp/sqlparser v0.3.2/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/uole/sqlparser v0.0.1 h1:LLUklg6Ne5MypXQuo53QcJv/xKdxtEKM9iUuEBN/lt8=
|
||||||
|
github.com/uole/sqlparser v0.0.1/go.mod h1:CRYFz2PTm9oHM0j9GFKi1VzPy70r6GsF0b9vpnwJ4yI=
|
||||||
|
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||||
|
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||||
|
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||||
|
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||||
|
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||||
|
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||||
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||||
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
|
@ -0,0 +1,251 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
beforeCreate = "beforeCreate"
|
||||||
|
afterCreate = "afterCreate"
|
||||||
|
beforeUpdate = "beforeUpdate"
|
||||||
|
afterUpdate = "afterUpdate"
|
||||||
|
beforeSave = "beforeSave"
|
||||||
|
afterSave = "afterSave"
|
||||||
|
beforeDelete = "beforeDelete"
|
||||||
|
afterDelete = "afterDelete"
|
||||||
|
afterExport = "afterExport"
|
||||||
|
afterImport = "afterImport"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
BeforeCreate func(ctx context.Context, tx *gorm.DB, model any) (err error)
|
||||||
|
AfterCreate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
|
||||||
|
BeforeUpdate func(ctx context.Context, tx *gorm.DB, model any) (err error)
|
||||||
|
AfterUpdate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
|
||||||
|
BeforeSave func(ctx context.Context, tx *gorm.DB, model any) (err error)
|
||||||
|
AfterSave func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
|
||||||
|
BeforeDelete func(ctx context.Context, tx *gorm.DB, model any) (err error)
|
||||||
|
AfterDelete func(ctx context.Context, tx *gorm.DB, model any)
|
||||||
|
AfterExport func(ctx context.Context, filename string) //导出回调
|
||||||
|
AfterImport func(ctx context.Context, result *types.ImportResult) //导入回调
|
||||||
|
hookManager struct {
|
||||||
|
callbacks map[string][]any
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
ActiveQuery interface {
|
||||||
|
BeforeQuery(ctx context.Context, query *Query) (err error)
|
||||||
|
AfterQuery(ctx context.Context, query *Query) (err error)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (hook *hookManager) register(spec string, cb any) {
|
||||||
|
if hook.callbacks == nil {
|
||||||
|
hook.callbacks = make(map[string][]any)
|
||||||
|
}
|
||||||
|
if _, ok := hook.callbacks[spec]; !ok {
|
||||||
|
hook.callbacks[spec] = make([]any, 0)
|
||||||
|
}
|
||||||
|
hook.callbacks[spec] = append(hook.callbacks[spec], cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) beforeCreate(ctx context.Context, tx *gorm.DB, model any) (err error) {
|
||||||
|
callbacks, ok := hook.callbacks[beforeCreate]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(BeforeCreate); ok {
|
||||||
|
if err = cb(ctx, tx, model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) afterCreate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
|
||||||
|
callbacks, ok := hook.callbacks[afterCreate]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(AfterCreate); ok {
|
||||||
|
cb(ctx, tx, model, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) beforeUpdate(ctx context.Context, tx *gorm.DB, model any) (err error) {
|
||||||
|
callbacks, ok := hook.callbacks[beforeUpdate]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(BeforeUpdate); ok {
|
||||||
|
if err = cb(ctx, tx, model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) afterUpdate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
|
||||||
|
callbacks, ok := hook.callbacks[afterUpdate]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(AfterUpdate); ok {
|
||||||
|
cb(ctx, tx, model, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) beforeSave(ctx context.Context, tx *gorm.DB, model any) (err error) {
|
||||||
|
callbacks, ok := hook.callbacks[beforeSave]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(BeforeSave); ok {
|
||||||
|
if err = cb(ctx, tx, model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) afterSave(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
|
||||||
|
callbacks, ok := hook.callbacks[afterSave]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(AfterSave); ok {
|
||||||
|
cb(ctx, tx, model, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) beforeDelete(ctx context.Context, tx *gorm.DB, model any) (err error) {
|
||||||
|
callbacks, ok := hook.callbacks[beforeDelete]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(BeforeDelete); ok {
|
||||||
|
if err = cb(ctx, tx, model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) afterDelete(ctx context.Context, tx *gorm.DB, model any) {
|
||||||
|
callbacks, ok := hook.callbacks[afterDelete]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(AfterDelete); ok {
|
||||||
|
cb(ctx, tx, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) afterExport(ctx context.Context, filename string) {
|
||||||
|
callbacks, ok := hook.callbacks[afterExport]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(AfterExport); ok {
|
||||||
|
cb(ctx, filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) afterImport(ctx context.Context, ret *types.ImportResult) {
|
||||||
|
callbacks, ok := hook.callbacks[afterImport]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, callback := range callbacks {
|
||||||
|
if cb, ok := callback.(AfterImport); ok {
|
||||||
|
cb(ctx, ret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) BeforeCreate(cb BeforeCreate) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(beforeCreate, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) AfterCreate(cb AfterCreate) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(afterCreate, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) BeforeUpdate(cb BeforeUpdate) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(beforeUpdate, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) AfterUpdate(cb AfterUpdate) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(afterUpdate, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) BeforeSave(cb BeforeSave) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(beforeSave, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) AfterSave(cb AfterSave) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(afterSave, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) BeforeDelete(cb BeforeDelete) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(beforeDelete, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) AfterDelete(cb AfterDelete) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(afterDelete, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) AfterExport(cb AfterExport) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(afterExport, cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook *hookManager) AfterImport(cb AfterImport) {
|
||||||
|
if cb != nil {
|
||||||
|
hook.register(afterImport, cb)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,408 @@
|
||||||
|
package inflector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rule represents name of the inflector rule, can be
|
||||||
|
// Plural or Singular
|
||||||
|
type Rule int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Plural = iota
|
||||||
|
Singular
|
||||||
|
)
|
||||||
|
|
||||||
|
// InflectorRule represents inflector rule
|
||||||
|
type InflectorRule struct {
|
||||||
|
Rules []*ruleItem
|
||||||
|
Irregular []*irregularItem
|
||||||
|
Uninflected []string
|
||||||
|
compiledIrregular *regexp.Regexp
|
||||||
|
compiledUninflected *regexp.Regexp
|
||||||
|
compiledRules []*compiledRule
|
||||||
|
}
|
||||||
|
|
||||||
|
type ruleItem struct {
|
||||||
|
pattern string
|
||||||
|
replacement string
|
||||||
|
}
|
||||||
|
|
||||||
|
type irregularItem struct {
|
||||||
|
word string
|
||||||
|
replacement string
|
||||||
|
}
|
||||||
|
|
||||||
|
// compiledRule represents compiled version of Inflector.Rules.
|
||||||
|
type compiledRule struct {
|
||||||
|
replacement string
|
||||||
|
*regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
// threadsafe access to rules and caches
|
||||||
|
var mutex sync.Mutex
|
||||||
|
var rules = make(map[Rule]*InflectorRule)
|
||||||
|
|
||||||
|
// Words that should not be inflected
|
||||||
|
var uninflected = []string{
|
||||||
|
`Amoyese`, `bison`, `Borghese`, `bream`, `breeches`, `britches`, `buffalo`,
|
||||||
|
`cantus`, `carp`, `chassis`, `clippers`, `cod`, `coitus`, `Congoese`,
|
||||||
|
`contretemps`, `corps`, `debris`, `diabetes`, `djinn`, `eland`, `elk`,
|
||||||
|
`equipment`, `Faroese`, `flounder`, `Foochowese`, `gallows`, `Genevese`,
|
||||||
|
`Genoese`, `Gilbertese`, `graffiti`, `headquarters`, `herpes`, `hijinks`,
|
||||||
|
`Hottentotese`, `information`, `innings`, `jackanapes`, `Kiplingese`,
|
||||||
|
`Kongoese`, `Lucchese`, `mackerel`, `Maltese`, `.*?media`, `mews`, `moose`,
|
||||||
|
`mumps`, `Nankingese`, `news`, `nexus`, `Niasese`, `Pekingese`,
|
||||||
|
`Piedmontese`, `pincers`, `Pistoiese`, `pliers`, `Portuguese`, `proceedings`,
|
||||||
|
`rabies`, `rice`, `rhinoceros`, `salmon`, `Sarawakese`, `scissors`,
|
||||||
|
`sea[- ]bass`, `series`, `Shavese`, `shears`, `siemens`, `species`, `swine`,
|
||||||
|
`testes`, `trousers`, `trout`, `tuna`, `Vermontese`, `Wenchowese`, `whiting`,
|
||||||
|
`wildebeest`, `Yengeese`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plural words that should not be inflected
|
||||||
|
var uninflectedPlurals = []string{
|
||||||
|
`.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`,
|
||||||
|
`people`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Singular words that should not be inflected
|
||||||
|
var uninflectedSingulars = []string{
|
||||||
|
`.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`,
|
||||||
|
`.*ss`,
|
||||||
|
}
|
||||||
|
|
||||||
|
type cache map[string]string
|
||||||
|
|
||||||
|
// Inflected words that already cached for immediate retrieval from a given Rule
|
||||||
|
var caches = make(map[Rule]cache)
|
||||||
|
|
||||||
|
// map of irregular words where its key is a word and its value is the replacement
|
||||||
|
var irregularMaps = make(map[Rule]cache)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// https://github.com/golang/lint/blob/master/lint.go#L770
|
||||||
|
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
|
commonInitialismsReplacer *strings.Replacer
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rules[Plural] = &InflectorRule{
|
||||||
|
Rules: []*ruleItem{
|
||||||
|
{`(?i)(s)tatus$`, `${1}${2}tatuses`},
|
||||||
|
{`(?i)(quiz)$`, `${1}zes`},
|
||||||
|
{`(?i)^(ox)$`, `${1}${2}en`},
|
||||||
|
{`(?i)([m|l])ouse$`, `${1}ice`},
|
||||||
|
{`(?i)(matr|vert|ind)(ix|ex)$`, `${1}ices`},
|
||||||
|
{`(?i)(x|ch|ss|sh)$`, `${1}es`},
|
||||||
|
{`(?i)([^aeiouy]|qu)y$`, `${1}ies`},
|
||||||
|
{`(?i)(hive)$`, `$1s`},
|
||||||
|
{`(?i)(?:([^f])fe|([lre])f)$`, `${1}${2}ves`},
|
||||||
|
{`(?i)sis$`, `ses`},
|
||||||
|
{`(?i)([ti])um$`, `${1}a`},
|
||||||
|
{`(?i)(p)erson$`, `${1}eople`},
|
||||||
|
{`(?i)(m)an$`, `${1}en`},
|
||||||
|
{`(?i)(c)hild$`, `${1}hildren`},
|
||||||
|
{`(?i)(buffal|tomat)o$`, `${1}${2}oes`},
|
||||||
|
{`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|vir)us$`, `${1}i`},
|
||||||
|
{`(?i)us$`, `uses`},
|
||||||
|
{`(?i)(alias)$`, `${1}es`},
|
||||||
|
{`(?i)(ax|cris|test)is$`, `${1}es`},
|
||||||
|
{`s$`, `s`},
|
||||||
|
{`^$`, ``},
|
||||||
|
{`$`, `s`},
|
||||||
|
},
|
||||||
|
Irregular: []*irregularItem{
|
||||||
|
{`atlas`, `atlases`},
|
||||||
|
{`beef`, `beefs`},
|
||||||
|
{`brother`, `brothers`},
|
||||||
|
{`cafe`, `cafes`},
|
||||||
|
{`child`, `children`},
|
||||||
|
{`cookie`, `cookies`},
|
||||||
|
{`corpus`, `corpuses`},
|
||||||
|
{`cow`, `cows`},
|
||||||
|
{`ganglion`, `ganglions`},
|
||||||
|
{`genie`, `genies`},
|
||||||
|
{`genus`, `genera`},
|
||||||
|
{`graffito`, `graffiti`},
|
||||||
|
{`hoof`, `hoofs`},
|
||||||
|
{`loaf`, `loaves`},
|
||||||
|
{`man`, `men`},
|
||||||
|
{`money`, `monies`},
|
||||||
|
{`mongoose`, `mongooses`},
|
||||||
|
{`move`, `moves`},
|
||||||
|
{`mythos`, `mythoi`},
|
||||||
|
{`niche`, `niches`},
|
||||||
|
{`numen`, `numina`},
|
||||||
|
{`occiput`, `occiputs`},
|
||||||
|
{`octopus`, `octopuses`},
|
||||||
|
{`opus`, `opuses`},
|
||||||
|
{`ox`, `oxen`},
|
||||||
|
{`penis`, `penises`},
|
||||||
|
{`person`, `people`},
|
||||||
|
{`sex`, `sexes`},
|
||||||
|
{`soliloquy`, `soliloquies`},
|
||||||
|
{`testis`, `testes`},
|
||||||
|
{`trilby`, `trilbys`},
|
||||||
|
{`turf`, `turfs`},
|
||||||
|
{`potato`, `potatoes`},
|
||||||
|
{`hero`, `heroes`},
|
||||||
|
{`tooth`, `teeth`},
|
||||||
|
{`goose`, `geese`},
|
||||||
|
{`foot`, `feet`},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
prepare(Plural)
|
||||||
|
|
||||||
|
rules[Singular] = &InflectorRule{
|
||||||
|
Rules: []*ruleItem{
|
||||||
|
{`(?i)(s)tatuses$`, `${1}${2}tatus`},
|
||||||
|
{`(?i)^(.*)(menu)s$`, `${1}${2}`},
|
||||||
|
{`(?i)(quiz)zes$`, `$1`},
|
||||||
|
{`(?i)(matr)ices$`, `${1}ix`},
|
||||||
|
{`(?i)(vert|ind)ices$`, `${1}ex`},
|
||||||
|
{`(?i)^(ox)en`, `$1`},
|
||||||
|
{`(?i)(alias)(es)*$`, `$1`},
|
||||||
|
{`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|viri?)i$`, `${1}us`},
|
||||||
|
{`(?i)([ftw]ax)es`, `$1`},
|
||||||
|
{`(?i)(cris|ax|test)es$`, `${1}is`},
|
||||||
|
{`(?i)(shoe|slave)s$`, `$1`},
|
||||||
|
{`(?i)(o)es$`, `$1`},
|
||||||
|
{`ouses$`, `ouse`},
|
||||||
|
{`([^a])uses$`, `${1}us`},
|
||||||
|
{`(?i)([m|l])ice$`, `${1}ouse`},
|
||||||
|
{`(?i)(x|ch|ss|sh)es$`, `$1`},
|
||||||
|
{`(?i)(m)ovies$`, `${1}${2}ovie`},
|
||||||
|
{`(?i)(s)eries$`, `${1}${2}eries`},
|
||||||
|
{`(?i)([^aeiouy]|qu)ies$`, `${1}y`},
|
||||||
|
{`(?i)(tive)s$`, `$1`},
|
||||||
|
{`(?i)([lre])ves$`, `${1}f`},
|
||||||
|
{`(?i)([^fo])ves$`, `${1}fe`},
|
||||||
|
{`(?i)(hive)s$`, `$1`},
|
||||||
|
{`(?i)(drive)s$`, `$1`},
|
||||||
|
{`(?i)(^analy)ses$`, `${1}sis`},
|
||||||
|
{`(?i)(analy|diagno|^ba|(p)arenthe|(p)rogno|(s)ynop|(t)he)ses$`, `${1}${2}sis`},
|
||||||
|
{`(?i)([ti])a$`, `${1}um`},
|
||||||
|
{`(?i)(p)eople$`, `${1}${2}erson`},
|
||||||
|
{`(?i)(m)en$`, `${1}an`},
|
||||||
|
{`(?i)(c)hildren$`, `${1}${2}hild`},
|
||||||
|
{`(?i)(n)ews$`, `${1}${2}ews`},
|
||||||
|
{`eaus$`, `eau`},
|
||||||
|
{`^(.*us)$`, `$1`},
|
||||||
|
{`(?i)s$`, ``},
|
||||||
|
},
|
||||||
|
Irregular: []*irregularItem{
|
||||||
|
{`foes`, `foe`},
|
||||||
|
{`waves`, `wave`},
|
||||||
|
{`curves`, `curve`},
|
||||||
|
{`atlases`, `atlas`},
|
||||||
|
{`beefs`, `beef`},
|
||||||
|
{`brothers`, `brother`},
|
||||||
|
{`cafes`, `cafe`},
|
||||||
|
{`children`, `child`},
|
||||||
|
{`cookies`, `cookie`},
|
||||||
|
{`corpuses`, `corpus`},
|
||||||
|
{`cows`, `cow`},
|
||||||
|
{`ganglions`, `ganglion`},
|
||||||
|
{`genies`, `genie`},
|
||||||
|
{`genera`, `genus`},
|
||||||
|
{`graffiti`, `graffito`},
|
||||||
|
{`hoofs`, `hoof`},
|
||||||
|
{`loaves`, `loaf`},
|
||||||
|
{`men`, `man`},
|
||||||
|
{`monies`, `money`},
|
||||||
|
{`mongooses`, `mongoose`},
|
||||||
|
{`moves`, `move`},
|
||||||
|
{`mythoi`, `mythos`},
|
||||||
|
{`niches`, `niche`},
|
||||||
|
{`numina`, `numen`},
|
||||||
|
{`occiputs`, `occiput`},
|
||||||
|
{`octopuses`, `octopus`},
|
||||||
|
{`opuses`, `opus`},
|
||||||
|
{`oxen`, `ox`},
|
||||||
|
{`penises`, `penis`},
|
||||||
|
{`people`, `person`},
|
||||||
|
{`sexes`, `sex`},
|
||||||
|
{`soliloquies`, `soliloquy`},
|
||||||
|
{`testes`, `testis`},
|
||||||
|
{`trilbys`, `trilby`},
|
||||||
|
{`turfs`, `turf`},
|
||||||
|
{`potatoes`, `potato`},
|
||||||
|
{`heroes`, `hero`},
|
||||||
|
{`teeth`, `tooth`},
|
||||||
|
{`geese`, `goose`},
|
||||||
|
{`feet`, `foot`},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
prepare(Singular)
|
||||||
|
|
||||||
|
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
|
||||||
|
for _, initialism := range commonInitialisms {
|
||||||
|
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||||
|
}
|
||||||
|
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare rule, e.g., compile the pattern.
|
||||||
|
func prepare(r Rule) error {
|
||||||
|
var reString string
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case Plural:
|
||||||
|
// Merge global uninflected with singularsUninflected
|
||||||
|
rules[r].Uninflected = merge(uninflected, uninflectedPlurals)
|
||||||
|
case Singular:
|
||||||
|
// Merge global uninflected with singularsUninflected
|
||||||
|
rules[r].Uninflected = merge(uninflected, uninflectedSingulars)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set InflectorRule.compiledUninflected by joining InflectorRule.Uninflected into
|
||||||
|
// a single string then compile it.
|
||||||
|
reString = fmt.Sprintf(`(?i)(^(?:%s))$`, strings.Join(rules[r].Uninflected, `|`))
|
||||||
|
rules[r].compiledUninflected = regexp.MustCompile(reString)
|
||||||
|
|
||||||
|
// Prepare irregularMaps
|
||||||
|
irregularMaps[r] = make(cache, len(rules[r].Irregular))
|
||||||
|
|
||||||
|
// Set InflectorRule.compiledIrregular by joining the irregularItem.word of Inflector.Irregular
|
||||||
|
// into a single string then compile it.
|
||||||
|
vIrregulars := make([]string, len(rules[r].Irregular))
|
||||||
|
for i, item := range rules[r].Irregular {
|
||||||
|
vIrregulars[i] = item.word
|
||||||
|
irregularMaps[r][item.word] = item.replacement
|
||||||
|
}
|
||||||
|
reString = fmt.Sprintf(`(?i)(.*)\b((?:%s))$`, strings.Join(vIrregulars, `|`))
|
||||||
|
rules[r].compiledIrregular = regexp.MustCompile(reString)
|
||||||
|
|
||||||
|
// Compile all patterns in InflectorRule.Rules
|
||||||
|
rules[r].compiledRules = make([]*compiledRule, len(rules[r].Rules))
|
||||||
|
for i, item := range rules[r].Rules {
|
||||||
|
rules[r].compiledRules[i] = &compiledRule{item.replacement, regexp.MustCompile(item.pattern)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare caches
|
||||||
|
caches[r] = make(cache)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge slice a and slice b
|
||||||
|
func merge(a []string, b []string) []string {
|
||||||
|
result := make([]string, len(a)+len(b))
|
||||||
|
copy(result, a)
|
||||||
|
copy(result[len(a):], b)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInflected(r Rule, s string) string {
|
||||||
|
mutex.Lock()
|
||||||
|
defer mutex.Unlock()
|
||||||
|
if v, ok := caches[r][s]; ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for irregular words
|
||||||
|
if res := rules[r].compiledIrregular.FindStringSubmatch(s); len(res) >= 3 {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
buf.WriteString(res[1])
|
||||||
|
buf.WriteString(s[0:1])
|
||||||
|
buf.WriteString(irregularMaps[r][strings.ToLower(res[2])][1:])
|
||||||
|
|
||||||
|
// Cache it then returns
|
||||||
|
caches[r][s] = buf.String()
|
||||||
|
return caches[r][s]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for uninflected words
|
||||||
|
if rules[r].compiledUninflected.MatchString(s) {
|
||||||
|
caches[r][s] = s
|
||||||
|
return caches[r][s]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check each rule
|
||||||
|
for _, re := range rules[r].compiledRules {
|
||||||
|
if re.MatchString(s) {
|
||||||
|
caches[r][s] = re.ReplaceAllString(s, re.replacement)
|
||||||
|
return caches[r][s]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns unaltered
|
||||||
|
caches[r][s] = s
|
||||||
|
return caches[r][s]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pluralize returns string s in plural form.
|
||||||
|
func Pluralize(s string) string {
|
||||||
|
return getInflected(Plural, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Singularize returns string s in singular form.
|
||||||
|
func Singularize(s string) string {
|
||||||
|
return getInflected(Singular, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
camelizeReg = regexp.MustCompile(`[^A-Za-z0-9]+`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Camelize Converts a word like "send_email" to "SendEmail"
|
||||||
|
func Camelize(s string) string {
|
||||||
|
s = camelizeReg.ReplaceAllString(s, " ")
|
||||||
|
return strings.Replace(strings.Title(s), " ", "", -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Camel2id Converts a word like "SendEmail" to "send_email"
|
||||||
|
func Camel2id(name string) string {
|
||||||
|
var (
|
||||||
|
value = commonInitialismsReplacer.Replace(name)
|
||||||
|
buf strings.Builder
|
||||||
|
lastCase, nextCase, nextNumber bool // upper case == true
|
||||||
|
curCase = value[0] <= 'Z' && value[0] >= 'A'
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, v := range value[:len(value)-1] {
|
||||||
|
nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
|
||||||
|
nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
|
||||||
|
|
||||||
|
if curCase {
|
||||||
|
if lastCase && (nextCase || nextNumber) {
|
||||||
|
buf.WriteRune(v + 32)
|
||||||
|
} else {
|
||||||
|
if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
|
||||||
|
buf.WriteByte('_')
|
||||||
|
}
|
||||||
|
buf.WriteRune(v + 32)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
lastCase = curCase
|
||||||
|
curCase = nextCase
|
||||||
|
}
|
||||||
|
|
||||||
|
if curCase {
|
||||||
|
if !lastCase && len(value) > 1 {
|
||||||
|
buf.WriteByte('_')
|
||||||
|
}
|
||||||
|
buf.WriteByte(value[len(value)-1] + 32)
|
||||||
|
} else {
|
||||||
|
buf.WriteByte(value[len(value)-1])
|
||||||
|
}
|
||||||
|
ret := buf.String()
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// Camel2words Converts a CamelCase name into space-separated words.
|
||||||
|
// For example, 'send_email' will be converted to 'Send Email'.
|
||||||
|
func Camel2words(s string) string {
|
||||||
|
s = camelizeReg.ReplaceAllString(s, " ")
|
||||||
|
return strings.Title(s)
|
||||||
|
}
|
|
@ -0,0 +1,987 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/csv"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"git.nobla.cn/golang/kos/util/arrays"
|
||||||
|
"git.nobla.cn/golang/kos/util/pool"
|
||||||
|
"git.nobla.cn/golang/rest/inflector"
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
naming types.Naming //命名规则
|
||||||
|
value reflect.Value //模块值
|
||||||
|
db *gorm.DB //数据库
|
||||||
|
primaryKey string //主键
|
||||||
|
urlPrefix string //url前缀
|
||||||
|
disableDomain bool //禁用域
|
||||||
|
schemaLookup types.SchemaLookupFunc //获取schema的函数
|
||||||
|
valueLookup types.ValueLookupFunc //查看域
|
||||||
|
statement *gorm.Statement //字段声明
|
||||||
|
formatter *Formatter //格式化
|
||||||
|
response types.HttpWriter //HTTP响应
|
||||||
|
hookMgr *hookManager //钩子管理器
|
||||||
|
dirname string //存放文件目录
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
RuntimeScopeKey = &types.RuntimeScope{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// getDB 获取数据库连接对象
|
||||||
|
func (m *Model) getDB() *gorm.DB {
|
||||||
|
return m.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFormatter 获取格式化组件
|
||||||
|
func (m *Model) getFormatter() *Formatter {
|
||||||
|
if m.formatter != nil {
|
||||||
|
return m.formatter
|
||||||
|
}
|
||||||
|
return DefaultFormatter
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHook 获取钩子
|
||||||
|
func (m *Model) getHook() *hookManager {
|
||||||
|
return m.hookMgr
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasScenario 判断是否有该场景
|
||||||
|
func (m *Model) hasScenario(s string) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// setValue 设置字段的值
|
||||||
|
func (m *Model) setValue(refValue reflect.Value, column string, value any) {
|
||||||
|
SetFieldValue(m.statement, refValue, column, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) {
|
||||||
|
SafeSetFileValue(m.statement, refValue, column, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getValue 获取字段的值
|
||||||
|
func (m *Model) getValue(refValue reflect.Value, column string) interface{} {
|
||||||
|
return GetFieldValue(m.statement, refValue, column)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasColumn 判断指定的列是否存在
|
||||||
|
func (m *Model) hasColumn(column string) bool {
|
||||||
|
for _, field := range m.statement.Schema.Fields {
|
||||||
|
if field.DBName == column || field.Name == column {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFilename 获取文件存放目录
|
||||||
|
func (m *Model) getFilename(domain string, spec string, name string) string {
|
||||||
|
if m.dirname == "" {
|
||||||
|
m.dirname = os.TempDir()
|
||||||
|
}
|
||||||
|
filename := path.Join(m.dirname, domain, spec, time.Now().Format("20060102"), name)
|
||||||
|
if _, err := os.Stat(path.Dir(filename)); err != nil {
|
||||||
|
_ = os.MkdirAll(path.Dir(filename), 0755)
|
||||||
|
}
|
||||||
|
return filename
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPrimaryKey 查找主键的值
|
||||||
|
func (m *Model) findPrimaryKey(uri string, r *http.Request) string {
|
||||||
|
var (
|
||||||
|
pos int
|
||||||
|
)
|
||||||
|
urlPath := r.URL.Path
|
||||||
|
pos = strings.IndexByte(uri, ':')
|
||||||
|
if pos > 0 {
|
||||||
|
return urlPath[pos:]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseReportColumn 解析报表的列
|
||||||
|
func (m *Model) parseReportColumn(name, props string) *types.SelectColumn {
|
||||||
|
var (
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
)
|
||||||
|
column := &types.SelectColumn{
|
||||||
|
Name: inflector.Camel2id(name),
|
||||||
|
Native: false,
|
||||||
|
}
|
||||||
|
tokens := strings.Split(props, ";")
|
||||||
|
for _, token := range tokens {
|
||||||
|
pair := strings.SplitN(token, ":", 2)
|
||||||
|
if len(pair) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(pair) == 1 {
|
||||||
|
key = strings.TrimSpace(pair[0])
|
||||||
|
value = ""
|
||||||
|
} else {
|
||||||
|
key = strings.TrimSpace(pair[0])
|
||||||
|
value = strings.TrimSpace(pair[1])
|
||||||
|
}
|
||||||
|
switch key {
|
||||||
|
case "native":
|
||||||
|
column.Native = true
|
||||||
|
case "name":
|
||||||
|
column.Name = value
|
||||||
|
case "expr":
|
||||||
|
column.Expr = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return column
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) {
|
||||||
|
modelType := reflect.ValueOf(dest).Type()
|
||||||
|
if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
columns := make([]string, 0)
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
scenarios := field.Tag.Get("scenarios")
|
||||||
|
if !hasToken(types.ScenarioList, scenarios) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
isPrimary := field.Tag.Get("is_primary")
|
||||||
|
if isPrimary != "true" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
|
||||||
|
if !column.Native {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if column.Expr == "" {
|
||||||
|
columns = append(columns, dest.QuoteColumn(ctx, column.Name))
|
||||||
|
} else {
|
||||||
|
columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
columns = append(columns, "COUNT(*) AS count")
|
||||||
|
query.Select(columns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) {
|
||||||
|
modelType := reflect.ValueOf(dest).Type()
|
||||||
|
if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
columns := make([]string, 0)
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
scenarios := field.Tag.Get("scenarios")
|
||||||
|
if !hasToken(types.ScenarioList, scenarios) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
|
||||||
|
if !column.Native {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if column.Expr == "" {
|
||||||
|
columns = append(columns, dest.QuoteColumn(ctx, column.Name))
|
||||||
|
} else {
|
||||||
|
columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
query.Select(columns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildCondition 构建sql条件
|
||||||
|
func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) {
|
||||||
|
return BuildConditions(ctx, r, query, schemas)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModuleName 模块名称
|
||||||
|
func (m *Model) ModuleName() string {
|
||||||
|
return m.naming.ModuleName
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName 表的名称
|
||||||
|
func (m *Model) TableName() string {
|
||||||
|
return m.naming.ModuleName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields 返回搜索的模型的字段
|
||||||
|
func (m *Model) Fields() []*schema.Field {
|
||||||
|
return m.statement.Schema.Fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uri 获取请求的uri
|
||||||
|
func (m *Model) Uri(scenario string) string {
|
||||||
|
ss := make([]string, 4)
|
||||||
|
if m.urlPrefix != "" {
|
||||||
|
ss = append(ss, m.urlPrefix)
|
||||||
|
}
|
||||||
|
switch scenario {
|
||||||
|
case types.ScenarioList:
|
||||||
|
ss = append(ss, m.naming.ModuleName, m.naming.Pluralize)
|
||||||
|
case types.ScenarioView:
|
||||||
|
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||||||
|
case types.ScenarioCreate:
|
||||||
|
ss = append(ss, m.naming.ModuleName, m.naming.Singular)
|
||||||
|
case types.ScenarioUpdate:
|
||||||
|
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||||||
|
case types.ScenarioDelete:
|
||||||
|
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||||||
|
case types.ScenarioExport:
|
||||||
|
ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export")
|
||||||
|
case types.ScenarioImport:
|
||||||
|
ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import")
|
||||||
|
}
|
||||||
|
uri := path.Join(ss...)
|
||||||
|
if !strings.HasPrefix(uri, "/") {
|
||||||
|
uri = "/" + uri
|
||||||
|
}
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method 获取HTTP请求的方法
|
||||||
|
func (m *Model) Method(scenario string) string {
|
||||||
|
var (
|
||||||
|
method = http.MethodGet
|
||||||
|
)
|
||||||
|
switch scenario {
|
||||||
|
case types.ScenarioCreate:
|
||||||
|
method = http.MethodPost
|
||||||
|
case types.ScenarioUpdate:
|
||||||
|
method = http.MethodPut
|
||||||
|
case types.ScenarioDelete:
|
||||||
|
method = http.MethodDelete
|
||||||
|
}
|
||||||
|
return method
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search 实现通过HTTP方法查找数据
|
||||||
|
func (m *Model) Search(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
err error
|
||||||
|
qs url.Values
|
||||||
|
page int
|
||||||
|
pageSize int
|
||||||
|
pageIndex int
|
||||||
|
query *Query
|
||||||
|
domainName string
|
||||||
|
modelSlices reflect.Value
|
||||||
|
modelValues reflect.Value
|
||||||
|
searchSchemas []*types.Schema
|
||||||
|
listSchemas []*types.Schema
|
||||||
|
modelValue reflect.Value
|
||||||
|
scenario string
|
||||||
|
reporter types.Reporter
|
||||||
|
namerTable tableNamer
|
||||||
|
)
|
||||||
|
qs = r.URL.Query()
|
||||||
|
page, _ = strconv.Atoi(qs.Get("page"))
|
||||||
|
pageSize, _ = strconv.Atoi(qs.Get("pagesize"))
|
||||||
|
if pageSize <= 0 {
|
||||||
|
pageSize = defaultPageSize
|
||||||
|
}
|
||||||
|
pageIndex = page
|
||||||
|
if pageIndex > 0 {
|
||||||
|
pageIndex--
|
||||||
|
}
|
||||||
|
modelValue = reflect.New(m.value.Type())
|
||||||
|
//这里创建指针类型,这样的话就能在format里面调用函数
|
||||||
|
if m.value.Kind() != reflect.Ptr {
|
||||||
|
modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
|
||||||
|
} else {
|
||||||
|
modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
|
||||||
|
}
|
||||||
|
modelValues = reflect.New(modelSlices.Type())
|
||||||
|
modelValues.Elem().Set(modelSlices)
|
||||||
|
query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface())
|
||||||
|
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||||
|
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||||
|
Domain: domainName,
|
||||||
|
Request: r,
|
||||||
|
User: m.valueLookup("user", w, r),
|
||||||
|
ModuleName: m.naming.ModuleName,
|
||||||
|
TableName: m.naming.TableName,
|
||||||
|
Scenario: types.ScenarioList,
|
||||||
|
})
|
||||||
|
if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scenario = types.ScenarioList
|
||||||
|
if arrays.Exists(r.FormValue("scenario"), allowScenario) {
|
||||||
|
scenario = r.FormValue("scenario")
|
||||||
|
}
|
||||||
|
if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !m.disableDomain {
|
||||||
|
if m.hasColumn(types.FieldDomain) {
|
||||||
|
query.AndWhere(newCondition(types.FieldDomain, domainName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 处理表名逻辑
|
||||||
|
if namerTable, ok = query.Model().(tableNamer); ok {
|
||||||
|
query.From(namerTable.HttpTableName(r))
|
||||||
|
}
|
||||||
|
//处理报表逻辑
|
||||||
|
if reporter, ok = modelValue.Interface().(types.Reporter); ok {
|
||||||
|
query.From(reporter.RealTable())
|
||||||
|
}
|
||||||
|
res := &types.ListResponse{
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
}
|
||||||
|
if reporter == nil {
|
||||||
|
res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
|
||||||
|
} else {
|
||||||
|
//如果是报表的情况,需要手动指定COUNT的雨具逻辑才能生效
|
||||||
|
m.buildReporterCountColumns(childCtx, reporter, query)
|
||||||
|
res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
|
||||||
|
|
||||||
|
//这里需要重置一下选项,不然会出问题
|
||||||
|
query.ResetSelect()
|
||||||
|
query.GroupBy(reporter.GroupBy(childCtx)...)
|
||||||
|
}
|
||||||
|
query.Offset(pageIndex * pageSize).Limit(pageSize)
|
||||||
|
if res.TotalCount > 0 {
|
||||||
|
if reporter != nil {
|
||||||
|
m.buildReporterQueryColumns(childCtx, reporter, query)
|
||||||
|
}
|
||||||
|
if err = query.All(modelValues.Interface()); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 不进行格式化输出
|
||||||
|
res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format"))
|
||||||
|
} else {
|
||||||
|
res.Data = make([]string, 0)
|
||||||
|
}
|
||||||
|
m.response.Success(w, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 实现通过HTTP方法创建模型
|
||||||
|
func (m *Model) Create(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
model any
|
||||||
|
schemas []*types.Schema
|
||||||
|
diffAttrs []*types.DiffAttr
|
||||||
|
domainName string
|
||||||
|
modelValue reflect.Value
|
||||||
|
)
|
||||||
|
modelValue = reflect.New(m.value.Type())
|
||||||
|
model = modelValue.Interface()
|
||||||
|
if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||||
|
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !m.disableDomain {
|
||||||
|
if m.hasColumn(types.FieldDomain) {
|
||||||
|
m.setValue(modelValue, types.FieldDomain, domainName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
diffAttrs = make([]*types.DiffAttr, 0, 10)
|
||||||
|
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||||
|
Domain: domainName,
|
||||||
|
User: m.valueLookup("user", w, r),
|
||||||
|
Request: r,
|
||||||
|
ModuleName: m.naming.ModuleName,
|
||||||
|
TableName: m.naming.TableName,
|
||||||
|
Scenario: types.ScenarioCreate,
|
||||||
|
Schemas: schemas,
|
||||||
|
})
|
||||||
|
dbSess := m.getDB().WithContext(childCtx)
|
||||||
|
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||||||
|
if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tabler, ok := model.(types.Tabler); ok {
|
||||||
|
errTx = tx.Table(tabler.TableName()).Save(model).Error
|
||||||
|
} else {
|
||||||
|
errTx = tx.Save(model).Error
|
||||||
|
}
|
||||||
|
if errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, row := range schemas {
|
||||||
|
diffAttrs = append(diffAttrs, &types.DiffAttr{
|
||||||
|
Column: row.Column,
|
||||||
|
Label: row.Label,
|
||||||
|
OldValue: nil,
|
||||||
|
NewValue: m.getValue(modelValue, row.Column),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}); err == nil {
|
||||||
|
res := &types.CreateResponse{
|
||||||
|
ID: m.getValue(modelValue, m.primaryKey),
|
||||||
|
Status: "created",
|
||||||
|
}
|
||||||
|
if creator, ok := model.(afterCreated); ok {
|
||||||
|
creator.AfterCreated(childCtx, dbSess)
|
||||||
|
}
|
||||||
|
if preserver, ok := model.(afterSaved); ok {
|
||||||
|
preserver.AfterSaved(childCtx, dbSess)
|
||||||
|
}
|
||||||
|
m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs)
|
||||||
|
m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
|
||||||
|
m.response.Success(w, res)
|
||||||
|
} else {
|
||||||
|
m.response.Failure(w, types.RequestCreateFailure, err.Error(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 实现通过HTTP方法更新模型
|
||||||
|
func (m *Model) Update(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
model any
|
||||||
|
schemas []*types.Schema
|
||||||
|
diffAttrs []*types.DiffAttr
|
||||||
|
domainName string
|
||||||
|
modelValue reflect.Value
|
||||||
|
oldValues map[string]any
|
||||||
|
)
|
||||||
|
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
|
||||||
|
modelValue = reflect.New(m.value.Type())
|
||||||
|
model = modelValue.Interface()
|
||||||
|
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||||
|
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conditions := map[string]any{
|
||||||
|
m.primaryKey: idStr,
|
||||||
|
}
|
||||||
|
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oldValues = make(map[string]any)
|
||||||
|
for _, row := range schemas {
|
||||||
|
oldValues[row.Column] = m.getValue(modelValue, row.Column)
|
||||||
|
}
|
||||||
|
if err = json.NewDecoder(r.Body).Decode(model); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
diffAttrs = make([]*types.DiffAttr, 0, 10)
|
||||||
|
updates := make(map[string]any)
|
||||||
|
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||||
|
Domain: domainName,
|
||||||
|
Request: r,
|
||||||
|
User: m.valueLookup("user", w, r),
|
||||||
|
ModuleName: m.naming.ModuleName,
|
||||||
|
TableName: m.naming.TableName,
|
||||||
|
Scenario: types.ScenarioUpdate,
|
||||||
|
Schemas: schemas,
|
||||||
|
PrimaryKeyValue: idStr,
|
||||||
|
})
|
||||||
|
dbSess := m.getDB().WithContext(childCtx)
|
||||||
|
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||||||
|
if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, row := range schemas {
|
||||||
|
v := m.getValue(modelValue, row.Column)
|
||||||
|
if oldValues[row.Column] != v {
|
||||||
|
updates[row.Column] = v
|
||||||
|
diffAttrs = append(diffAttrs, &types.DiffAttr{
|
||||||
|
Column: row.Column,
|
||||||
|
Label: row.Label,
|
||||||
|
OldValue: oldValues[row.Column],
|
||||||
|
NewValue: v,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(updates) > 0 {
|
||||||
|
if tabler, ok := model.(types.Tabler); ok {
|
||||||
|
errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error
|
||||||
|
} else {
|
||||||
|
errTx = tx.Model(model).Updates(updates).Error
|
||||||
|
}
|
||||||
|
if errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}); err == nil {
|
||||||
|
if updater, ok := model.(afterUpdated); ok {
|
||||||
|
updater.AfterUpdated(childCtx, dbSess)
|
||||||
|
}
|
||||||
|
if preserver, ok := model.(afterSaved); ok {
|
||||||
|
preserver.AfterSaved(childCtx, dbSess)
|
||||||
|
}
|
||||||
|
m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs)
|
||||||
|
m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
|
||||||
|
m.response.Success(w, types.UpdateResponse{
|
||||||
|
ID: idStr,
|
||||||
|
Status: "updated",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 实现通过HTTP方法删除模型
|
||||||
|
func (m *Model) Delete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
model any
|
||||||
|
modelValue reflect.Value
|
||||||
|
)
|
||||||
|
idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r)
|
||||||
|
modelValue = reflect.New(m.value.Type())
|
||||||
|
model = modelValue.Interface()
|
||||||
|
conditions := map[string]any{
|
||||||
|
m.primaryKey: idStr,
|
||||||
|
}
|
||||||
|
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||||
|
Domain: m.valueLookup(types.FieldDomain, w, r),
|
||||||
|
User: m.valueLookup("user", w, r),
|
||||||
|
Request: r,
|
||||||
|
ModuleName: m.naming.ModuleName,
|
||||||
|
TableName: m.naming.TableName,
|
||||||
|
Scenario: types.ScenarioDelete,
|
||||||
|
PrimaryKeyValue: idStr,
|
||||||
|
})
|
||||||
|
dbSess := m.getDB().WithContext(childCtx)
|
||||||
|
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||||||
|
if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tabler, ok := model.(types.Tabler); ok {
|
||||||
|
errTx = tx.Table(tabler.TableName()).Delete(model).Error
|
||||||
|
} else {
|
||||||
|
errTx = tx.Delete(model).Error
|
||||||
|
}
|
||||||
|
if errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.getHook().afterDelete(childCtx, tx, model)
|
||||||
|
return
|
||||||
|
}); err == nil {
|
||||||
|
m.response.Success(w, types.DeleteResponse{
|
||||||
|
ID: idStr,
|
||||||
|
Status: "deleted",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// View 查看数据详情
|
||||||
|
func (m *Model) View(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
model any
|
||||||
|
modelValue reflect.Value
|
||||||
|
qs url.Values
|
||||||
|
schemas []*types.Schema
|
||||||
|
scenario string
|
||||||
|
domainName string
|
||||||
|
)
|
||||||
|
qs = r.URL.Query()
|
||||||
|
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
|
||||||
|
modelValue = reflect.New(m.value.Type())
|
||||||
|
model = modelValue.Interface()
|
||||||
|
conditions := map[string]any{
|
||||||
|
m.primaryKey: idStr,
|
||||||
|
}
|
||||||
|
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||||
|
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scenario = qs.Get("scenario")
|
||||||
|
if scenario == "" {
|
||||||
|
schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioView)
|
||||||
|
} else {
|
||||||
|
schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format")))
|
||||||
|
} else {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export 实现通过HTTP方法导出模型
|
||||||
|
func (m *Model) Export(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
query *Query
|
||||||
|
modelSlices reflect.Value
|
||||||
|
modelValues reflect.Value
|
||||||
|
searchSchemas []*types.Schema
|
||||||
|
exportSchemas []*types.Schema
|
||||||
|
domainName string
|
||||||
|
fp *os.File
|
||||||
|
modelValue reflect.Value
|
||||||
|
)
|
||||||
|
if !m.hasScenario(types.ScenarioList) {
|
||||||
|
m.response.Failure(w, types.RequestDenied, "request denied", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||||
|
filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
|
||||||
|
if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = fp.Close()
|
||||||
|
}()
|
||||||
|
modelValue = reflect.New(m.value.Type())
|
||||||
|
//这里创建指针类型,这样的话就能在format里面调用函数
|
||||||
|
if m.value.Kind() != reflect.Ptr {
|
||||||
|
modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
|
||||||
|
} else {
|
||||||
|
modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
|
||||||
|
}
|
||||||
|
modelValues = reflect.New(modelSlices.Type())
|
||||||
|
modelValues.Elem().Set(modelSlices)
|
||||||
|
query = NewQuery(m.getDB(), modelValue.Interface())
|
||||||
|
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||||||
|
Domain: domainName,
|
||||||
|
Request: r,
|
||||||
|
User: m.valueLookup("user", w, r),
|
||||||
|
ModuleName: m.naming.ModuleName,
|
||||||
|
TableName: m.naming.TableName,
|
||||||
|
Scenario: types.ScenarioExport,
|
||||||
|
})
|
||||||
|
if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !m.disableDomain {
|
||||||
|
if m.hasColumn(types.FieldDomain) {
|
||||||
|
query.AndWhere(newCondition(types.FieldDomain, domainName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 处理表名逻辑
|
||||||
|
if namerTable, ok := query.Model().(tableNamer); ok {
|
||||||
|
query.From(namerTable.HttpTableName(r))
|
||||||
|
}
|
||||||
|
//处理报表逻辑
|
||||||
|
if reporter, ok := modelValue.Interface().(types.Reporter); ok {
|
||||||
|
query.From(reporter.RealTable())
|
||||||
|
query.GroupBy(reporter.GroupBy(childCtx)...)
|
||||||
|
m.buildReporterQueryColumns(childCtx, reporter, query)
|
||||||
|
}
|
||||||
|
if err = query.All(modelValues.Interface()); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/csv")
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
|
||||||
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
|
||||||
|
value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "")
|
||||||
|
writer := csv.NewWriter(fp)
|
||||||
|
rows := make([]string, len(exportSchemas))
|
||||||
|
for i, field := range exportSchemas {
|
||||||
|
rows[i] = field.Label
|
||||||
|
}
|
||||||
|
_ = writer.Write(rows)
|
||||||
|
if values, ok := value.([]any); ok {
|
||||||
|
for _, val := range values {
|
||||||
|
row, ok2 := val.(map[string]any)
|
||||||
|
if !ok2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for i, field := range exportSchemas {
|
||||||
|
if v, ok := row[field.Column]; ok {
|
||||||
|
rows[i] = fmt.Sprint(v)
|
||||||
|
} else {
|
||||||
|
rows[i] = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = writer.Write(rows)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writer.Flush()
|
||||||
|
m.getHook().afterExport(childCtx, filename)
|
||||||
|
http.ServeContent(w, r, path.Base(filename), time.Now(), fp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findSchema 查找指定的schema
|
||||||
|
func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema {
|
||||||
|
for _, row := range schemas {
|
||||||
|
if row.Label == label {
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// importInternal 文件上传方法
|
||||||
|
func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
rows []string
|
||||||
|
fp *os.File
|
||||||
|
tm time.Time
|
||||||
|
fields []string
|
||||||
|
sess *gorm.DB
|
||||||
|
csvReader *csv.Reader
|
||||||
|
csvWriter *csv.Writer
|
||||||
|
modelValue reflect.Value
|
||||||
|
modelEntity any
|
||||||
|
diffAttrs []*types.DiffAttr
|
||||||
|
result *types.ImportResult
|
||||||
|
failureFp *os.File
|
||||||
|
failureFile string
|
||||||
|
)
|
||||||
|
tm = time.Now()
|
||||||
|
result = &types.ImportResult{}
|
||||||
|
if fp, err = os.Open(filename); err != nil {
|
||||||
|
result.Code = types.ErrImportFileNotExists
|
||||||
|
goto __end
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = fp.Close()
|
||||||
|
}()
|
||||||
|
csvReader = csv.NewReader(fp)
|
||||||
|
if rows, err = csvReader.Read(); err != nil {
|
||||||
|
result.Code = types.ErrImportFileUnavailable
|
||||||
|
goto __end
|
||||||
|
}
|
||||||
|
fields = make([]string, 0, len(rows))
|
||||||
|
for _, s := range rows {
|
||||||
|
v := m.findSchema(s, schemas)
|
||||||
|
if v == nil {
|
||||||
|
result.Code = types.ErrImportColumnNotMatch
|
||||||
|
goto __end
|
||||||
|
}
|
||||||
|
fields = append(fields, v.Column)
|
||||||
|
}
|
||||||
|
sess = m.getDB().WithContext(ctx)
|
||||||
|
//失败文件指针
|
||||||
|
failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix()))
|
||||||
|
if failureFp, err = os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = failureFp.Close()
|
||||||
|
}()
|
||||||
|
csvWriter = csv.NewWriter(failureFp)
|
||||||
|
rows = append(rows, "Error")
|
||||||
|
_ = csvWriter.Write(rows)
|
||||||
|
diffAttrs = make([]*types.DiffAttr, len(schemas))
|
||||||
|
for {
|
||||||
|
if rows, err = csvReader.Read(); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result.TotalCount++
|
||||||
|
if len(rows) != len(fields) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modelValue = reflect.New(m.value.Type())
|
||||||
|
for idx, field := range fields {
|
||||||
|
m.safeSetValue(modelValue, field, rows[idx])
|
||||||
|
}
|
||||||
|
if len(extraFields) > 0 {
|
||||||
|
for k, v := range extraFields {
|
||||||
|
m.safeSetValue(modelValue, k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelEntity = modelValue.Interface()
|
||||||
|
//写入数据
|
||||||
|
if fast {
|
||||||
|
//如果是快速模式,直接存储数据
|
||||||
|
if err = sess.Save(modelEntity).Error; err == nil {
|
||||||
|
result.SuccessCount++
|
||||||
|
} else {
|
||||||
|
rows = append(rows, err.Error())
|
||||||
|
_ = csvWriter.Write(rows)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err = sess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||||||
|
if errTx = m.getHook().beforeCreate(ctx, tx, modelEntity); errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errTx = m.getHook().beforeSave(ctx, tx, modelEntity); errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tabler, ok := modelEntity.(types.Tabler); ok {
|
||||||
|
errTx = tx.Table(tabler.TableName()).Save(modelEntity).Error
|
||||||
|
} else {
|
||||||
|
errTx = tx.Save(modelEntity).Error
|
||||||
|
}
|
||||||
|
if errTx != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for idx, row := range schemas {
|
||||||
|
diffAttrs[idx] = &types.DiffAttr{
|
||||||
|
Column: row.Column,
|
||||||
|
Label: row.Label,
|
||||||
|
NewValue: m.getValue(modelValue, row.Column),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.getHook().afterCreate(ctx, tx, modelEntity, diffAttrs)
|
||||||
|
m.getHook().afterSave(ctx, tx, modelEntity, diffAttrs)
|
||||||
|
return
|
||||||
|
}); err == nil {
|
||||||
|
result.SuccessCount++
|
||||||
|
} else {
|
||||||
|
rows = append(rows, err.Error())
|
||||||
|
_ = csvWriter.Write(rows)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
csvWriter.Flush()
|
||||||
|
__end:
|
||||||
|
result.UploadFile = filename
|
||||||
|
if result.TotalCount > result.SuccessCount {
|
||||||
|
result.FailureFile = failureFile
|
||||||
|
}
|
||||||
|
result.Duration = time.Now().Sub(tm)
|
||||||
|
m.getHook().afterImport(ctx, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import 实现通过HTTP方法导入
|
||||||
|
func (m *Model) Import(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
fast bool
|
||||||
|
schemas []*types.Schema
|
||||||
|
rows []string
|
||||||
|
domainName string
|
||||||
|
dst *os.File
|
||||||
|
fp multipart.File
|
||||||
|
csvWriter *csv.Writer
|
||||||
|
qs url.Values
|
||||||
|
extraFields map[string]string
|
||||||
|
)
|
||||||
|
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||||||
|
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//这里用background的context
|
||||||
|
childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{
|
||||||
|
Domain: domainName,
|
||||||
|
User: m.valueLookup("user", w, r),
|
||||||
|
ModuleName: m.naming.ModuleName,
|
||||||
|
TableName: m.naming.TableName,
|
||||||
|
Scenario: types.ScenarioImport,
|
||||||
|
Schemas: schemas,
|
||||||
|
})
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
|
//下载导入模板
|
||||||
|
csvWriter = csv.NewWriter(w)
|
||||||
|
rows = make([]string, 0, len(schemas))
|
||||||
|
for _, row := range schemas {
|
||||||
|
//主键不需要导入
|
||||||
|
if row.IsPrimaryKey == 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rows = append(rows, row.Label)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/csv")
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
|
||||||
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
|
||||||
|
err = csvWriter.Write(rows)
|
||||||
|
csvWriter.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
|
||||||
|
if fp, _, err = r.FormFile("file"); err != nil {
|
||||||
|
m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = fp.Close()
|
||||||
|
}()
|
||||||
|
if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err == nil {
|
||||||
|
buf := pool.GetBytes(32 * 1024)
|
||||||
|
_, err = io.CopyBuffer(dst, fp, buf)
|
||||||
|
pool.PutBytes(buf)
|
||||||
|
_ = dst.Close()
|
||||||
|
} else {
|
||||||
|
m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
qs = r.URL.Query()
|
||||||
|
if qs != nil {
|
||||||
|
extraFields = make(map[string]string)
|
||||||
|
for k, _ := range qs {
|
||||||
|
if strings.HasPrefix(k, "_attr_") {
|
||||||
|
extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fast, _ = strconv.ParseBool(qs.Get("__fast"))
|
||||||
|
go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields)
|
||||||
|
m.response.Success(w, types.ImportResponse{
|
||||||
|
UID: m.valueLookup("user", w, r),
|
||||||
|
Status: "committed",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// newModel 创建一个模型
|
||||||
|
func newModel(v any, db *gorm.DB, naming types.Naming) *Model {
|
||||||
|
model := &Model{
|
||||||
|
db: db,
|
||||||
|
naming: naming,
|
||||||
|
response: &httpWriter{},
|
||||||
|
value: reflect.Indirect(reflect.ValueOf(v)),
|
||||||
|
valueLookup: defaultValueLookup,
|
||||||
|
}
|
||||||
|
model.statement = &gorm.Statement{
|
||||||
|
DB: model.getDB(),
|
||||||
|
ConnPool: model.getDB().ConnPool,
|
||||||
|
Clauses: map[string]clause.Clause{},
|
||||||
|
}
|
||||||
|
if err := model.statement.Parse(v); err == nil {
|
||||||
|
if model.statement.Schema.PrimaryFieldDBNames != nil && len(model.statement.Schema.PrimaryFieldDBNames) > 0 {
|
||||||
|
model.primaryKey = model.statement.Schema.PrimaryFieldDBNames[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return model
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import "git.nobla.cn/golang/rest/types"
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
urlPrefix string
|
||||||
|
moduleName string
|
||||||
|
disableDomain bool
|
||||||
|
router types.HttpRouter
|
||||||
|
writer types.HttpWriter
|
||||||
|
formatter *Formatter
|
||||||
|
dirname string //文件目录
|
||||||
|
}
|
||||||
|
|
||||||
|
type Option func(o *Options)
|
||||||
|
|
||||||
|
func WithUriPrefix(s string) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.urlPrefix = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithModuleName(s string) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.moduleName = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithoutDomain 禁用域
|
||||||
|
func WithoutDomain() Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.disableDomain = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHttpRouter(s types.HttpRouter) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.router = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHttpWriter(s types.HttpWriter) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.writer = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithFormatter(s *Formatter) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.formatter = s
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,108 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"git.nobla.cn/golang/kos/pkg/cache"
|
||||||
|
xxhash "github.com/cespare/xxhash/v2"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/callbacks"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DisableCache = "DISABLE_CACHE"
|
||||||
|
DurationKey = "gorm:cache_duration"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Cacher struct {
|
||||||
|
rawQuery func(db *gorm.DB)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cacher) Name() string {
|
||||||
|
return "gorm:cache"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cacher) Initialize(db *gorm.DB) (err error) {
|
||||||
|
c.rawQuery = db.Callback().Query().Get("gorm:query")
|
||||||
|
err = db.Callback().Query().Replace("gorm:query", c.Query)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildCacheKey 构建一个缓存的KEY
|
||||||
|
func (c *Cacher) buildCacheKey(db *gorm.DB) string {
|
||||||
|
s := strconv.FormatUint(xxhash.Sum64String(db.Statement.SQL.String()+fmt.Sprintf("%v", db.Statement.Vars)), 10)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDuration 获取缓存时长
|
||||||
|
func (c *Cacher) getDuration(db *gorm.DB) time.Duration {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
v any
|
||||||
|
duration time.Duration
|
||||||
|
)
|
||||||
|
if v, ok = db.InstanceGet(DurationKey); !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if duration, ok = v.(time.Duration); !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryLoad 尝试从缓存读取数据
|
||||||
|
func (c *Cacher) tryLoad(key string, db *gorm.DB) (err error) {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
buf []byte
|
||||||
|
)
|
||||||
|
if buf, ok = cache.Get(db.Statement.Context, key); ok {
|
||||||
|
err = json.Unmarshal(buf, db.Statement.Dest)
|
||||||
|
} else {
|
||||||
|
err = os.ErrNotExist
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeCache 存储缓存数据
|
||||||
|
func (c *Cacher) storeCache(key string, db *gorm.DB, duration time.Duration) (err error) {
|
||||||
|
var (
|
||||||
|
buf []byte
|
||||||
|
)
|
||||||
|
if buf, err = json.Marshal(db.Statement.Dest); err == nil {
|
||||||
|
cache.SetEx(db.Statement.Context, key, buf, duration)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cacher) Query(db *gorm.DB) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
cacheKey string
|
||||||
|
duration time.Duration
|
||||||
|
)
|
||||||
|
duration = c.getDuration(db)
|
||||||
|
if duration <= 0 {
|
||||||
|
c.rawQuery(db)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callbacks.BuildQuerySQL(db)
|
||||||
|
cacheKey = c.buildCacheKey(db)
|
||||||
|
if err = c.tryLoad(cacheKey, db); err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.rawQuery(db)
|
||||||
|
if db.Error == nil {
|
||||||
|
//store cache
|
||||||
|
if err = c.storeCache(cacheKey, db, duration); err != nil {
|
||||||
|
_ = db.AddError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Cacher {
|
||||||
|
return &Cacher{}
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
package identity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/rs/xid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Identify struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (identity *Identify) Name() string {
|
||||||
|
return "gorm:identity"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (identity *Identify) Initialize(db *gorm.DB) (err error) {
|
||||||
|
err = db.Callback().Create().Before("gorm:create").Register("auto_identified", identity.Grant)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (identity *Identify) NextID() string {
|
||||||
|
return xid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (identity *Identify) Grant(db *gorm.DB) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
field *schema.Field
|
||||||
|
)
|
||||||
|
if db.Statement.Schema == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if field = db.Statement.Schema.LookUpField("ID"); field == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if field.DataType != schema.String {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if db.Statement.ReflectValue.Kind() == reflect.Array || db.Statement.ReflectValue.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
|
if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue.Index(i)); zero {
|
||||||
|
if err = field.Set(db.Statement.Context, db.Statement.ReflectValue.Index(i), identity.NextID()); err != nil {
|
||||||
|
_ = db.AddError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); zero {
|
||||||
|
db.Statement.SetColumn("ID", identity.NextID())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Identify {
|
||||||
|
return &Identify{}
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
# 分表实现
|
||||||
|
|
||||||
|
首先定义一个`gorm`的模型,然后实现`shadring.Model`接口,比如如下示例
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ShardingTable 返回增删改时候操作的数据表
|
||||||
|
func (model *CdrLog) ShardingTable(scene string) string {
|
||||||
|
return model.TableName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShardingTables 返回查询时候一个范围内的表
|
||||||
|
func (model *CdrLog) ShardingTables(ctx *sharding.Context) []string {
|
||||||
|
var (
|
||||||
|
timestamp int64
|
||||||
|
)
|
||||||
|
timeRange := make([]int64, 0)
|
||||||
|
values := ctx.FindColumnValues("start_stamp")
|
||||||
|
if len(values) == 0 {
|
||||||
|
values = ctx.FindColumnValues("create_stamp")
|
||||||
|
}
|
||||||
|
if len(values) > 0 {
|
||||||
|
for _, v := range values {
|
||||||
|
timestamp, _ = strconv.ParseInt(fmt.Sprint(v), 10, 64)
|
||||||
|
timeRange = append(timeRange, timestamp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return shard.DateTableNames(ctx.Context(), "cdr_logs", shard.ShardTypeDateMonth, timeRange)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`ShardingTable`方法是操作增删改的是回调具体表名的方法
|
||||||
|
|
||||||
|
`ShardingTables`方法是操作查询的时候,通过查询条件返回的表名的方法
|
|
@ -0,0 +1,120 @@
|
||||||
|
package sharding
|
||||||
|
|
||||||
|
const (
|
||||||
|
ValueOperaEqual = iota + 0x10
|
||||||
|
ValueOperaGreater
|
||||||
|
ValueOperaLess
|
||||||
|
ValueOperaRange
|
||||||
|
|
||||||
|
ValueTypeString = iota + 0x30
|
||||||
|
ValueTypeNumber
|
||||||
|
ValueTypeBoolean
|
||||||
|
ValueTypeNull
|
||||||
|
ValueTypeAny
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
ColumnCondition struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Value CondValue `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
CondValue interface {
|
||||||
|
Type() int
|
||||||
|
Opera() int
|
||||||
|
Value() any
|
||||||
|
}
|
||||||
|
|
||||||
|
equalValue struct {
|
||||||
|
vType int
|
||||||
|
vData any
|
||||||
|
}
|
||||||
|
|
||||||
|
rangeValue struct {
|
||||||
|
vType int
|
||||||
|
vData any
|
||||||
|
}
|
||||||
|
|
||||||
|
greaterValue struct {
|
||||||
|
vType int
|
||||||
|
vData any
|
||||||
|
}
|
||||||
|
|
||||||
|
lessValue struct {
|
||||||
|
vType int
|
||||||
|
vData any
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (v *lessValue) Type() int {
|
||||||
|
return v.vType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *lessValue) Opera() int {
|
||||||
|
return ValueOperaLess
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *lessValue) Value() any {
|
||||||
|
return v.vData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *greaterValue) Type() int {
|
||||||
|
return v.vType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *greaterValue) Opera() int {
|
||||||
|
return ValueOperaGreater
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *greaterValue) Value() any {
|
||||||
|
return v.vData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *rangeValue) Type() int {
|
||||||
|
return v.vType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *rangeValue) Opera() int {
|
||||||
|
return ValueOperaRange
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *rangeValue) Value() any {
|
||||||
|
return v.vData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *equalValue) Type() int {
|
||||||
|
return v.vType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *equalValue) Opera() int {
|
||||||
|
return ValueOperaEqual
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *equalValue) Value() any {
|
||||||
|
return v.vData
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCondValue(vType int, op int, value any) CondValue {
|
||||||
|
switch op {
|
||||||
|
case ValueOperaGreater:
|
||||||
|
return &greaterValue{
|
||||||
|
vType: vType,
|
||||||
|
vData: value,
|
||||||
|
}
|
||||||
|
case ValueOperaLess:
|
||||||
|
return &lessValue{
|
||||||
|
vType: vType,
|
||||||
|
vData: value,
|
||||||
|
}
|
||||||
|
case ValueOperaRange:
|
||||||
|
return &rangeValue{
|
||||||
|
vType: vType,
|
||||||
|
vData: value,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return &equalValue{
|
||||||
|
vType: vType,
|
||||||
|
vData: value,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,287 @@
|
||||||
|
package sharding
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/longbridgeapp/sqlparser"
|
||||||
|
sqlparserX "github.com/uole/sqlparser"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Scope struct {
|
||||||
|
db *gorm.DB
|
||||||
|
stmt *sqlparser.SelectStatement
|
||||||
|
stmtX *sqlparserX.Select
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) findValue(express sqlparser.Expr) (int, any) {
|
||||||
|
var (
|
||||||
|
vType int
|
||||||
|
vData any
|
||||||
|
)
|
||||||
|
switch expr := express.(type) {
|
||||||
|
case *sqlparser.BindExpr:
|
||||||
|
vType = ValueTypeAny
|
||||||
|
if len(scope.db.Statement.Vars) > expr.Pos {
|
||||||
|
vData = scope.db.Statement.Vars[expr.Pos]
|
||||||
|
} else {
|
||||||
|
vType = ValueTypeNull
|
||||||
|
}
|
||||||
|
case *sqlparser.NumberLit:
|
||||||
|
vType = ValueTypeNumber
|
||||||
|
vData = expr.Value
|
||||||
|
case *sqlparser.StringLit:
|
||||||
|
vType = ValueTypeString
|
||||||
|
vData = expr.Value
|
||||||
|
case *sqlparser.BoolLit:
|
||||||
|
vType = ValueTypeBoolean
|
||||||
|
vData = expr.Value
|
||||||
|
case *sqlparser.BlobLit:
|
||||||
|
vType = ValueTypeString
|
||||||
|
vData = expr.Value
|
||||||
|
case *sqlparser.NullLit:
|
||||||
|
vType = ValueTypeNull
|
||||||
|
case *sqlparser.Range:
|
||||||
|
arr := make([]any, 2)
|
||||||
|
vType, arr[0] = scope.findValue(expr.X)
|
||||||
|
vType, arr[1] = scope.findValue(expr.Y)
|
||||||
|
vData = arr
|
||||||
|
}
|
||||||
|
return vType, vData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) findValueX(express *sqlparserX.SQLVal) (int, any) {
|
||||||
|
var (
|
||||||
|
vType int
|
||||||
|
vData any
|
||||||
|
)
|
||||||
|
switch express.Type {
|
||||||
|
case sqlparserX.IntVal:
|
||||||
|
vType = ValueTypeNumber
|
||||||
|
vData, _ = strconv.Atoi(string(express.Val))
|
||||||
|
case sqlparserX.FloatVal:
|
||||||
|
vType = ValueTypeNumber
|
||||||
|
vData, _ = strconv.ParseFloat(string(express.Val), 64)
|
||||||
|
case sqlparserX.ValArg:
|
||||||
|
vType = ValueTypeAny
|
||||||
|
pos, _ := strconv.Atoi(string(express.Val[2:]))
|
||||||
|
if pos > 0 {
|
||||||
|
pos = pos - 1
|
||||||
|
if len(scope.db.Statement.Vars) > pos {
|
||||||
|
vData = scope.db.Statement.Vars[pos]
|
||||||
|
} else {
|
||||||
|
vType = ValueTypeNull
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vType = ValueTypeNull
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
vType = ValueTypeString
|
||||||
|
vData = string(express.Val)
|
||||||
|
}
|
||||||
|
return vType, vData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) recursiveFindX(expr sqlparserX.Expr, column string) (conditions []*ColumnCondition) {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
andExpr *sqlparserX.AndExpr
|
||||||
|
orExpr *sqlparserX.OrExpr
|
||||||
|
parentExpr *sqlparserX.ParenExpr
|
||||||
|
comparisonExpr *sqlparserX.ComparisonExpr
|
||||||
|
rangeExpr *sqlparserX.RangeCond
|
||||||
|
coumnExpr *sqlparserX.ColName
|
||||||
|
valueExpr *sqlparserX.SQLVal
|
||||||
|
)
|
||||||
|
conditions = make([]*ColumnCondition, 0, 2)
|
||||||
|
if comparisonExpr, ok = expr.(*sqlparserX.ComparisonExpr); ok {
|
||||||
|
if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if valueExpr, ok = comparisonExpr.Right.(*sqlparserX.SQLVal); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if coumnExpr.Name.EqualString(column) {
|
||||||
|
cond := &ColumnCondition{
|
||||||
|
Name: coumnExpr.Name.String(),
|
||||||
|
}
|
||||||
|
vType, vData := scope.findValueX(valueExpr)
|
||||||
|
switch comparisonExpr.Operator {
|
||||||
|
case sqlparserX.LessThanStr, sqlparserX.LessEqualStr:
|
||||||
|
cond.Value = newCondValue(vType, ValueOperaLess, vData)
|
||||||
|
case sqlparserX.GreaterThanStr, sqlparserX.GreaterEqualStr:
|
||||||
|
cond.Value = newCondValue(vType, ValueOperaGreater, vData)
|
||||||
|
case sqlparserX.EqualStr:
|
||||||
|
cond.Value = newCondValue(vType, ValueOperaEqual, vData)
|
||||||
|
}
|
||||||
|
if cond.Value != nil {
|
||||||
|
conditions = append(conditions, cond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rangeExpr, ok = expr.(*sqlparserX.RangeCond); ok {
|
||||||
|
if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if coumnExpr.Name.EqualString(column) {
|
||||||
|
vType := 0
|
||||||
|
arr := make([]any, 2)
|
||||||
|
if valueExpr, ok = rangeExpr.From.(*sqlparserX.SQLVal); ok {
|
||||||
|
vType, arr[0] = scope.findValueX(valueExpr)
|
||||||
|
}
|
||||||
|
if valueExpr, ok = rangeExpr.To.(*sqlparserX.SQLVal); ok {
|
||||||
|
vType, arr[1] = scope.findValueX(valueExpr)
|
||||||
|
}
|
||||||
|
conditions = append(conditions, &ColumnCondition{
|
||||||
|
Name: coumnExpr.Name.String(),
|
||||||
|
Value: newCondValue(vType, ValueOperaRange, arr),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if andExpr, ok = expr.(*sqlparserX.AndExpr); ok {
|
||||||
|
if andExpr.Left != nil {
|
||||||
|
conditions = append(conditions, scope.recursiveFindX(andExpr.Left, column)...)
|
||||||
|
}
|
||||||
|
if andExpr.Right != nil {
|
||||||
|
conditions = append(conditions, scope.recursiveFindX(andExpr.Right, column)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if orExpr, ok = expr.(*sqlparserX.OrExpr); ok {
|
||||||
|
if orExpr.Left != nil {
|
||||||
|
conditions = append(conditions, scope.recursiveFindX(orExpr.Left, column)...)
|
||||||
|
}
|
||||||
|
if orExpr.Right != nil {
|
||||||
|
conditions = append(conditions, scope.recursiveFindX(orExpr.Right, column)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if parentExpr, ok = expr.(*sqlparserX.ParenExpr); ok {
|
||||||
|
if parentExpr.Expr != nil {
|
||||||
|
conditions = append(conditions, scope.recursiveFindX(parentExpr.Expr, column)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conditions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) recursiveFind(expr sqlparser.Expr, column string) []*ColumnCondition {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
identExpr *sqlparser.Ident
|
||||||
|
binaryExpr *sqlparser.BinaryExpr
|
||||||
|
parentExpr *sqlparser.ParenExpr
|
||||||
|
conditions []*ColumnCondition
|
||||||
|
)
|
||||||
|
conditions = make([]*ColumnCondition, 0, 2)
|
||||||
|
if parentExpr, ok = expr.(*sqlparser.ParenExpr); ok {
|
||||||
|
if parentExpr.X != nil {
|
||||||
|
if _, ok = parentExpr.X.(*sqlparser.BinaryExpr); ok {
|
||||||
|
conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...)
|
||||||
|
}
|
||||||
|
if _, ok = parentExpr.X.(*sqlparser.ParenExpr); ok {
|
||||||
|
conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if binaryExpr, ok = expr.(*sqlparser.BinaryExpr); ok {
|
||||||
|
if binaryExpr.X != nil {
|
||||||
|
if identExpr, ok = binaryExpr.X.(*sqlparser.Ident); ok {
|
||||||
|
if identExpr.Name == column {
|
||||||
|
cond := &ColumnCondition{
|
||||||
|
Name: identExpr.Name,
|
||||||
|
}
|
||||||
|
vType, vData := scope.findValue(binaryExpr.Y)
|
||||||
|
switch binaryExpr.Op {
|
||||||
|
case sqlparser.LT, sqlparser.LE:
|
||||||
|
cond.Value = newCondValue(vType, ValueOperaLess, vData)
|
||||||
|
case sqlparser.GT, sqlparser.GE:
|
||||||
|
cond.Value = newCondValue(vType, ValueOperaGreater, vData)
|
||||||
|
case sqlparser.RANGE, sqlparser.BETWEEN:
|
||||||
|
cond.Value = newCondValue(vType, ValueOperaRange, vData)
|
||||||
|
case sqlparser.EQ:
|
||||||
|
cond.Value = newCondValue(vType, ValueOperaEqual, vData)
|
||||||
|
}
|
||||||
|
if cond.Value != nil {
|
||||||
|
conditions = append(conditions, cond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, ok = binaryExpr.X.(*sqlparser.BinaryExpr); ok {
|
||||||
|
conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...)
|
||||||
|
}
|
||||||
|
if _, ok = binaryExpr.X.(*sqlparser.ParenExpr); ok {
|
||||||
|
conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if binaryExpr.Y != nil {
|
||||||
|
if _, ok = binaryExpr.Y.(*sqlparser.BinaryExpr); ok {
|
||||||
|
conditions = append(conditions, scope.recursiveFind(binaryExpr.Y, column)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return conditions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) DB() *gorm.DB {
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) Context() context.Context {
|
||||||
|
return scope.db.Statement.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) FindCondition(column string) []*ColumnCondition {
|
||||||
|
if scope.stmtX != nil {
|
||||||
|
if scope.stmtX.Where == nil {
|
||||||
|
return []*ColumnCondition{}
|
||||||
|
}
|
||||||
|
return scope.recursiveFindX(scope.stmtX.Where.Expr, column)
|
||||||
|
}
|
||||||
|
return scope.recursiveFind(scope.stmt.Condition, column)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) FindColumnValues(column string) []any {
|
||||||
|
result := make([]any, 0)
|
||||||
|
conditions := scope.FindCondition(column)
|
||||||
|
if len(conditions) == 0 {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
for _, cond := range conditions {
|
||||||
|
if cond.Value.Opera() == ValueOperaGreater {
|
||||||
|
if len(result) == 0 {
|
||||||
|
result = make([]any, 2)
|
||||||
|
}
|
||||||
|
result[0] = cond.Value.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
if cond.Value.Opera() == ValueOperaLess {
|
||||||
|
if len(result) == 0 {
|
||||||
|
result = make([]any, 2)
|
||||||
|
}
|
||||||
|
result[1] = cond.Value.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
if cond.Value.Opera() == ValueOperaEqual {
|
||||||
|
result = append(result, cond.Value.Value())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if cond.Value.Opera() == ValueOperaRange {
|
||||||
|
if vs, ok := cond.Value.Value().([]any); ok {
|
||||||
|
if len(vs) == 2 {
|
||||||
|
if len(result) == 0 {
|
||||||
|
result = make([]any, 2)
|
||||||
|
}
|
||||||
|
result[0] = vs[0]
|
||||||
|
result[1] = vs[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
|
@ -0,0 +1,476 @@
|
||||||
|
package sharding
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/longbridgeapp/sqlparser"
|
||||||
|
sqlparserX "github.com/uole/sqlparser"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/callbacks"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sharding struct {
|
||||||
|
UnionAll bool
|
||||||
|
QuoteChar byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plugin *Sharding) Name() string {
|
||||||
|
return "gorm:sharding"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plugin *Sharding) Initialize(db *gorm.DB) (err error) {
|
||||||
|
if err = db.Callback().Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = db.Callback().Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = db.Callback().Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = db.Callback().Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plugin *Sharding) Create(db *gorm.DB) {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
scopeModel Model
|
||||||
|
refValue reflect.Value
|
||||||
|
modelValue any
|
||||||
|
)
|
||||||
|
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||||||
|
if db.Statement.ReflectValue.Len() > 0 {
|
||||||
|
refValue = db.Statement.ReflectValue.Index(0)
|
||||||
|
modelValue = refValue.Interface()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if db.Statement.Model != nil {
|
||||||
|
modelValue = db.Statement.Model
|
||||||
|
} else {
|
||||||
|
modelValue = db.Statement.ReflectValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if modelValue != nil {
|
||||||
|
if scopeModel, ok = modelValue.(Model); ok {
|
||||||
|
db.Table(scopeModel.ShardingTable(sceneCreate))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plugin *Sharding) Update(db *gorm.DB) {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
scopeModel Model
|
||||||
|
refValue reflect.Value
|
||||||
|
modelValue any
|
||||||
|
)
|
||||||
|
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||||||
|
if db.Statement.ReflectValue.Len() > 0 {
|
||||||
|
refValue = db.Statement.ReflectValue.Index(0)
|
||||||
|
modelValue = refValue.Interface()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if db.Statement.Model != nil {
|
||||||
|
modelValue = db.Statement.Model
|
||||||
|
} else {
|
||||||
|
modelValue = db.Statement.ReflectValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if modelValue != nil {
|
||||||
|
if scopeModel, ok = modelValue.(Model); ok {
|
||||||
|
db.Table(scopeModel.ShardingTable(sceneUpdate))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plugin *Sharding) Delete(db *gorm.DB) {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
scopeModel Model
|
||||||
|
refValue reflect.Value
|
||||||
|
modelValue any
|
||||||
|
)
|
||||||
|
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||||||
|
if db.Statement.ReflectValue.Len() > 0 {
|
||||||
|
refValue = db.Statement.ReflectValue.Index(0)
|
||||||
|
modelValue = refValue.Interface()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if db.Statement.Model != nil {
|
||||||
|
modelValue = db.Statement.Model
|
||||||
|
} else {
|
||||||
|
modelValue = db.Statement.ReflectValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if modelValue != nil {
|
||||||
|
if scopeModel, ok = modelValue.(Model); ok {
|
||||||
|
db.Table(scopeModel.ShardingTable(sceneDelete))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plugin *Sharding) Query(db *gorm.DB) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
ok bool
|
||||||
|
shardingModel Model
|
||||||
|
modelValue any
|
||||||
|
tables []string
|
||||||
|
rawVars []any
|
||||||
|
refValue reflect.Value
|
||||||
|
tableName *sqlparser.TableName
|
||||||
|
selectStmt *sqlparser.SelectStatement
|
||||||
|
stmt sqlparser.Statement
|
||||||
|
parser *sqlparser.Parser
|
||||||
|
numOfTable int
|
||||||
|
orderByExpr []*sqlparser.OrderingTerm
|
||||||
|
limitExpr sqlparser.Expr
|
||||||
|
offsetExpr sqlparser.Expr
|
||||||
|
groupingExpr []sqlparser.Expr
|
||||||
|
havingExpr sqlparser.Expr
|
||||||
|
isCountStatement bool
|
||||||
|
countField string
|
||||||
|
)
|
||||||
|
if db.Statement.Model != nil {
|
||||||
|
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
|
||||||
|
} else {
|
||||||
|
refValue = reflect.New(db.Statement.ReflectValue.Type())
|
||||||
|
}
|
||||||
|
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
|
||||||
|
refValue = reflect.Indirect(refValue)
|
||||||
|
}
|
||||||
|
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
|
||||||
|
elemType := refValue.Type().Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
modelValue = reflect.New(elemType).Interface()
|
||||||
|
} else {
|
||||||
|
modelValue = refValue.Interface()
|
||||||
|
}
|
||||||
|
if shardingModel, ok = modelValue.(Model); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if db.Statement.SQL.Len() == 0 {
|
||||||
|
callbacks.BuildQuerySQL(db)
|
||||||
|
}
|
||||||
|
parser = sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String()))
|
||||||
|
if stmt, err = parser.ParseStatement(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if selectStmt, ok = stmt.(*sqlparser.SelectStatement); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tables = shardingModel.ShardingTables(&Scope{
|
||||||
|
db: db,
|
||||||
|
stmt: selectStmt,
|
||||||
|
})
|
||||||
|
numOfTable = len(tables)
|
||||||
|
if numOfTable <= 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rawVars = make([]any, 0, len(db.Statement.Vars))
|
||||||
|
for _, v := range db.Statement.Vars {
|
||||||
|
rawVars = append(rawVars, v)
|
||||||
|
}
|
||||||
|
//是否是查询count语句
|
||||||
|
//如果不是count的语句,添加order和group的支持
|
||||||
|
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
|
||||||
|
if v == true {
|
||||||
|
isCountStatement = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isCountStatement && len(*selectStmt.Columns) == 1 {
|
||||||
|
for _, column := range *selectStmt.Columns {
|
||||||
|
if expr, ok := column.Expr.(*sqlparser.Call); ok {
|
||||||
|
if expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword {
|
||||||
|
isCountStatement = true
|
||||||
|
countField = expr.String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(selectStmt.OrderBy) > 0 {
|
||||||
|
orderByExpr = make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy))
|
||||||
|
for _, row := range selectStmt.OrderBy {
|
||||||
|
orderByExpr = append(orderByExpr, row)
|
||||||
|
}
|
||||||
|
selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0)
|
||||||
|
}
|
||||||
|
if len(selectStmt.GroupingElements) > 0 {
|
||||||
|
groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements))
|
||||||
|
for _, row := range selectStmt.GroupingElements {
|
||||||
|
groupingExpr = append(groupingExpr, row)
|
||||||
|
}
|
||||||
|
if selectStmt.HavingCondition != nil {
|
||||||
|
havingExpr = selectStmt.HavingCondition
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if selectStmt.Limit != nil {
|
||||||
|
limitExpr = selectStmt.Limit
|
||||||
|
selectStmt.Limit = nil
|
||||||
|
}
|
||||||
|
if selectStmt.Offset != nil {
|
||||||
|
offsetExpr = selectStmt.Offset
|
||||||
|
selectStmt.Offset = nil
|
||||||
|
}
|
||||||
|
db.Statement.SQL.Reset()
|
||||||
|
if isCountStatement {
|
||||||
|
db.Statement.SQL.WriteString("SELECT SUM(")
|
||||||
|
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||||||
|
db.Statement.SQL.WriteString(strings.Trim(countField, "`"))
|
||||||
|
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||||||
|
db.Statement.SQL.WriteString(") FROM (")
|
||||||
|
} else {
|
||||||
|
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||||||
|
}
|
||||||
|
for i, name := range tables {
|
||||||
|
db.Statement.SQL.WriteByte('(')
|
||||||
|
if tableName, ok = selectStmt.FromItems.(*sqlparser.TableName); ok {
|
||||||
|
tableName.Name.Name = name
|
||||||
|
}
|
||||||
|
db.Statement.SQL.WriteString(selectStmt.String())
|
||||||
|
db.Statement.SQL.WriteByte(')')
|
||||||
|
if i < numOfTable-1 {
|
||||||
|
if plugin.UnionAll {
|
||||||
|
db.Statement.SQL.WriteString(" UNION ALL ")
|
||||||
|
} else {
|
||||||
|
db.Statement.SQL.WriteString(" UNION ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i > 0 {
|
||||||
|
//copy vars
|
||||||
|
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db.Statement.SQL.WriteString(") tbl ")
|
||||||
|
if !isCountStatement {
|
||||||
|
if len(groupingExpr) > 0 {
|
||||||
|
db.Statement.SQL.WriteString(" GROUP BY ")
|
||||||
|
for i, expr := range groupingExpr {
|
||||||
|
if i != 0 {
|
||||||
|
db.Statement.SQL.WriteString(", ")
|
||||||
|
}
|
||||||
|
db.Statement.SQL.WriteString(expr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if havingExpr != nil {
|
||||||
|
db.Statement.SQL.WriteString(" HAVING ")
|
||||||
|
db.Statement.SQL.WriteString(havingExpr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if orderByExpr != nil && len(orderByExpr) > 0 {
|
||||||
|
db.Statement.SQL.WriteString(" ORDER BY ")
|
||||||
|
for i, term := range orderByExpr {
|
||||||
|
if i != 0 {
|
||||||
|
db.Statement.SQL.WriteString(", ")
|
||||||
|
}
|
||||||
|
db.Statement.SQL.WriteString(term.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if limitExpr != nil {
|
||||||
|
db.Statement.SQL.WriteString(" LIMIT ")
|
||||||
|
db.Statement.SQL.WriteString(limitExpr.String())
|
||||||
|
if offsetExpr != nil {
|
||||||
|
db.Statement.SQL.WriteString(" OFFSET ")
|
||||||
|
db.Statement.SQL.WriteString(offsetExpr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (plugin *Sharding) QueryX(db *gorm.DB) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
ok bool
|
||||||
|
shardingModel Model
|
||||||
|
modelValue any
|
||||||
|
tables []string
|
||||||
|
rawVars []any
|
||||||
|
refValue reflect.Value
|
||||||
|
selectStmt *sqlparserX.Select
|
||||||
|
stmt sqlparserX.Statement
|
||||||
|
numOfTable int
|
||||||
|
isCountStatement bool
|
||||||
|
isPureCountStatement bool
|
||||||
|
trackedBuffer *sqlparserX.TrackedBuffer
|
||||||
|
funcExpr *sqlparserX.FuncExpr
|
||||||
|
aliasedExpr *sqlparserX.AliasedExpr
|
||||||
|
orderByExpr sqlparserX.OrderBy
|
||||||
|
groupByExpr sqlparserX.GroupBy
|
||||||
|
havingExpr *sqlparserX.Where
|
||||||
|
limitExpr *sqlparserX.Limit
|
||||||
|
)
|
||||||
|
if db.Statement.Model != nil {
|
||||||
|
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
|
||||||
|
} else {
|
||||||
|
if !db.Statement.ReflectValue.IsValid() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
refValue = reflect.New(db.Statement.ReflectValue.Type())
|
||||||
|
}
|
||||||
|
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
|
||||||
|
refValue = reflect.Indirect(refValue)
|
||||||
|
}
|
||||||
|
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
|
||||||
|
elemType := refValue.Type().Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
modelValue = reflect.New(elemType).Interface()
|
||||||
|
} else {
|
||||||
|
modelValue = refValue.Interface()
|
||||||
|
}
|
||||||
|
if shardingModel, ok = modelValue.(Model); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if db.Statement.SQL.Len() == 0 {
|
||||||
|
callbacks.BuildQuerySQL(db)
|
||||||
|
}
|
||||||
|
if stmt, err = sqlparserX.Parse(db.Statement.SQL.String()); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if selectStmt, ok = stmt.(*sqlparserX.Select); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tables = shardingModel.ShardingTables(&Scope{
|
||||||
|
db: db,
|
||||||
|
stmtX: selectStmt,
|
||||||
|
})
|
||||||
|
numOfTable = len(tables)
|
||||||
|
if numOfTable <= 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 保存值
|
||||||
|
rawVars = make([]any, 0, len(db.Statement.Vars))
|
||||||
|
for _, v := range db.Statement.Vars {
|
||||||
|
rawVars = append(rawVars, v)
|
||||||
|
}
|
||||||
|
// 替换语句
|
||||||
|
if selectStmt.OrderBy != nil {
|
||||||
|
orderByExpr = selectStmt.OrderBy
|
||||||
|
selectStmt.OrderBy = nil
|
||||||
|
}
|
||||||
|
if selectStmt.GroupBy != nil {
|
||||||
|
groupByExpr = selectStmt.GroupBy
|
||||||
|
//selectStmt.GroupBy = nil
|
||||||
|
}
|
||||||
|
if selectStmt.Having != nil {
|
||||||
|
havingExpr = selectStmt.Having
|
||||||
|
//selectStmt.Having = nil
|
||||||
|
}
|
||||||
|
if selectStmt.Limit != nil {
|
||||||
|
limitExpr = selectStmt.Limit
|
||||||
|
selectStmt.Limit = nil
|
||||||
|
}
|
||||||
|
// 检查是否为COUNT语句
|
||||||
|
//如果不是count的语句,添加order和group的支持
|
||||||
|
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
|
||||||
|
//这里处理的是报表的情况,再报表里面需要重写COUNT语句才能获取到正确的数量
|
||||||
|
if v == true {
|
||||||
|
isCountStatement = true
|
||||||
|
isPureCountStatement = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//常规的COUNT逻辑
|
||||||
|
if !isCountStatement && len(selectStmt.SelectExprs) == 1 {
|
||||||
|
for _, expr := range selectStmt.SelectExprs {
|
||||||
|
if aliasedExpr, ok = expr.(*sqlparserX.AliasedExpr); ok {
|
||||||
|
if funcExpr, ok = aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok {
|
||||||
|
if funcExpr.Name.EqualString("count") {
|
||||||
|
isCountStatement = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 重写SQL
|
||||||
|
db.Statement.SQL.Reset()
|
||||||
|
trackedBuffer = sqlparserX.NewTrackedBuffer(nil)
|
||||||
|
|
||||||
|
if isCountStatement {
|
||||||
|
if isPureCountStatement {
|
||||||
|
db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (")
|
||||||
|
} else {
|
||||||
|
db.Statement.SQL.WriteString("SELECT SUM(")
|
||||||
|
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||||||
|
db.Statement.SQL.WriteString(strings.Trim("count(*)", "`"))
|
||||||
|
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||||||
|
db.Statement.SQL.WriteString(") FROM (")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if bs, ok := modelValue.(SelectBuilder); ok {
|
||||||
|
columns := bs.BuildSelect(db.Statement.Context, selectStmt.SelectExprs)
|
||||||
|
if len(columns) > 0 {
|
||||||
|
db.Statement.SQL.WriteString("SELECT " + strings.Join(columns, ",") + " FROM (")
|
||||||
|
} else {
|
||||||
|
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, name := range tables {
|
||||||
|
trackedBuffer.Reset()
|
||||||
|
db.Statement.SQL.WriteByte('(')
|
||||||
|
//赋值新的表名称
|
||||||
|
selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{
|
||||||
|
Expr: sqlparserX.TableName{
|
||||||
|
Name: sqlparserX.NewTableIdent(name),
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
selectStmt.Format(trackedBuffer)
|
||||||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||||
|
db.Statement.SQL.WriteByte(')')
|
||||||
|
if i < numOfTable-1 {
|
||||||
|
if plugin.UnionAll {
|
||||||
|
db.Statement.SQL.WriteString(" UNION ALL ")
|
||||||
|
} else {
|
||||||
|
db.Statement.SQL.WriteString(" UNION ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i > 0 {
|
||||||
|
//copy vars
|
||||||
|
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db.Statement.SQL.WriteString(") tbl ")
|
||||||
|
if !isCountStatement {
|
||||||
|
//node.GroupBy, node.Having, node.OrderBy, node.Limit
|
||||||
|
if groupByExpr != nil {
|
||||||
|
trackedBuffer.Reset()
|
||||||
|
groupByExpr.Format(trackedBuffer)
|
||||||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||||
|
}
|
||||||
|
if havingExpr != nil {
|
||||||
|
trackedBuffer.Reset()
|
||||||
|
havingExpr.Format(trackedBuffer)
|
||||||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||||
|
}
|
||||||
|
if orderByExpr != nil {
|
||||||
|
trackedBuffer.Reset()
|
||||||
|
orderByExpr.Format(trackedBuffer)
|
||||||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||||
|
}
|
||||||
|
if limitExpr != nil {
|
||||||
|
trackedBuffer.Reset()
|
||||||
|
limitExpr.Format(trackedBuffer)
|
||||||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Sharding {
|
||||||
|
return &Sharding{
|
||||||
|
UnionAll: true,
|
||||||
|
QuoteChar: '`',
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
package sharding
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/uole/sqlparser"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormat(t *testing.T) {
|
||||||
|
t.Log(fmt.Sprintf("%.2f%%", 0.5851888*100))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSharding_Query(t *testing.T) {
|
||||||
|
//sql := "SELECT COUNT(*) FROM aaa"
|
||||||
|
sql := "SELECT uid,SUM(IF(direction='inbound',1,0)) AS inbound_times,SUM(IF(direction='inbound',IF(answer_duration>0,1,0),0)) AS inbound_answer_times,SUM(IF(direction='outbound',1,0)) AS outbound_times,SUM(IF(direction='outbound',IF(answer_duration>0,1,0),0)) AS outbound_answer_times FROM `cdr_logs` WHERE ((`domain` = 'test.cc.echo.me' OR `domain` = 'default') AND `name` <> '') AND (`create_stamp` BETWEEN 1712505600 AND 1712505608) AND `name` IN ('a','b','c') GROUP BY `uid` having uid != '' ORDER BY create_stamp DESC LIMIT 15"
|
||||||
|
stmt, err := sqlparser.Parse(sql)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
selectStmt, ok := stmt.(*sqlparser.Select)
|
||||||
|
if !ok {
|
||||||
|
t.Error("not select stmt")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buf := sqlparser.NewTrackedBuffer(nil)
|
||||||
|
sqlparser.NewTableIdent("test").Format(buf)
|
||||||
|
buf.Reset()
|
||||||
|
selectStmt.Format(buf)
|
||||||
|
t.Log("SQL")
|
||||||
|
t.Log(buf.String())
|
||||||
|
buf.Reset()
|
||||||
|
t.Log("SELECT")
|
||||||
|
selectStmt.SelectExprs.Format(buf)
|
||||||
|
t.Log(buf.String())
|
||||||
|
buf.Reset()
|
||||||
|
t.Log("FROM")
|
||||||
|
selectStmt.From.Format(buf)
|
||||||
|
t.Log(buf.String())
|
||||||
|
buf.Reset()
|
||||||
|
t.Log("WHERE")
|
||||||
|
selectStmt.Where.Format(buf)
|
||||||
|
t.Log(buf.String())
|
||||||
|
buf.Reset()
|
||||||
|
t.Log("ORDER BY")
|
||||||
|
selectStmt.OrderBy.Format(buf)
|
||||||
|
t.Log(buf.String())
|
||||||
|
buf.Reset()
|
||||||
|
t.Log("GROUP BY")
|
||||||
|
selectStmt.GroupBy.Format(buf)
|
||||||
|
t.Log(buf.String())
|
||||||
|
buf.Reset()
|
||||||
|
t.Log("LIMIT")
|
||||||
|
selectStmt.Limit.Format(buf)
|
||||||
|
t.Log(buf.String())
|
||||||
|
buf.Reset()
|
||||||
|
t.Log("Having")
|
||||||
|
selectStmt.Having.Format(buf)
|
||||||
|
t.Log(buf.String())
|
||||||
|
buf.Reset()
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
package sharding
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
sqlparserX "github.com/uole/sqlparser"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeDatetime = "datetime"
|
||||||
|
TypeHash = "hash"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DateTypeYear = iota + 5
|
||||||
|
DateTypeMonth
|
||||||
|
DateTypeWeek
|
||||||
|
DateTypeDay
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sceneCreate = "create"
|
||||||
|
sceneUpdate = "update"
|
||||||
|
sceneDelete = "delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
stmtCountKeyword = "count"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Rule struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Args int `json:"args"`
|
||||||
|
}
|
||||||
|
|
||||||
|
Model interface {
|
||||||
|
ShardingRule() Rule
|
||||||
|
ShardingTable(scene string) string
|
||||||
|
ShardingTables(scope *Scope) []string
|
||||||
|
}
|
||||||
|
|
||||||
|
SelectBuilder interface {
|
||||||
|
BuildSelect(ctx context.Context, expr sqlparserX.SelectExprs) []string
|
||||||
|
}
|
||||||
|
)
|
|
@ -0,0 +1,107 @@
|
||||||
|
package validate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SkipValidations = "validations:skip_validations"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
scopeCtxKey = &validateScope{}
|
||||||
|
telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
validateScope struct {
|
||||||
|
DB *gorm.DB
|
||||||
|
Column string
|
||||||
|
Domain string
|
||||||
|
MultiDomain bool
|
||||||
|
Model interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
StructError struct {
|
||||||
|
Tag string `json:"tag"`
|
||||||
|
Column string `json:"column"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
validateRule struct {
|
||||||
|
Rule string
|
||||||
|
Value string
|
||||||
|
Valid bool
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (err *StructError) Error() string {
|
||||||
|
return err.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRule(ss ...string) *validateRule {
|
||||||
|
v := &validateRule{
|
||||||
|
Valid: true,
|
||||||
|
}
|
||||||
|
if len(ss) == 1 {
|
||||||
|
v.Rule = ss[0]
|
||||||
|
} else if len(ss) >= 2 {
|
||||||
|
v.Rule = ss[0]
|
||||||
|
v.Value = ss[1]
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatError 格式化错误消息
|
||||||
|
func formatError(rule types.Rule, scm *types.Schema, tag string) string {
|
||||||
|
var s string
|
||||||
|
switch tag {
|
||||||
|
case "db_unique":
|
||||||
|
s = scm.Label + "值已经存在."
|
||||||
|
break
|
||||||
|
case "required":
|
||||||
|
s = scm.Label + "值不能为空."
|
||||||
|
case "max":
|
||||||
|
if scm.Type == "string" {
|
||||||
|
s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
|
||||||
|
} else {
|
||||||
|
s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
|
||||||
|
}
|
||||||
|
case "min":
|
||||||
|
if scm.Type == "string" {
|
||||||
|
s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max)
|
||||||
|
} else {
|
||||||
|
s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// isEmpty 判断值是否为空
|
||||||
|
func isEmpty(val any) bool {
|
||||||
|
if val == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
v := reflect.ValueOf(val)
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.String, reflect.Array:
|
||||||
|
return v.Len() == 0
|
||||||
|
case reflect.Map, reflect.Slice:
|
||||||
|
return v.IsNil() || v.Len() == 0
|
||||||
|
case reflect.Bool:
|
||||||
|
return !v.Bool()
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return v.Int() == 0
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
return v.Uint() == 0
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return v.Float() == 0
|
||||||
|
case reflect.Interface, reflect.Ptr:
|
||||||
|
return v.IsNil()
|
||||||
|
}
|
||||||
|
return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface())
|
||||||
|
}
|
|
@ -0,0 +1,275 @@
|
||||||
|
package validate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"git.nobla.cn/golang/rest"
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
validator "github.com/go-playground/validator/v10"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Validate struct {
|
||||||
|
validator *validator.Validate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (validate *Validate) Name() string {
|
||||||
|
return "gorm:validate"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (validate *Validate) telephoneValidate(ctx context.Context, fl validator.FieldLevel) bool {
|
||||||
|
val := fmt.Sprint(fl.Field().Interface())
|
||||||
|
return telephoneRegex.MatchString(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (validate *Validate) uniqueValidate(ctx context.Context, fl validator.FieldLevel) bool {
|
||||||
|
var (
|
||||||
|
scope *validateScope
|
||||||
|
ok bool
|
||||||
|
count int64
|
||||||
|
field *schema.Field
|
||||||
|
primaryKeyValue reflect.Value
|
||||||
|
)
|
||||||
|
val := fl.Field().Interface()
|
||||||
|
if scope, ok = ctx.Value(scopeCtxKey).(*validateScope); !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(scope.DB.Statement.Schema.PrimaryFields) > 0 {
|
||||||
|
field = scope.DB.Statement.Schema.PrimaryFields[0]
|
||||||
|
primaryKeyValue = reflect.Indirect(reflect.ValueOf(scope.Model))
|
||||||
|
for _, n := range field.BindNames {
|
||||||
|
primaryKeyValue = primaryKeyValue.FieldByName(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sess := scope.DB.Session(&gorm.Session{NewDB: true})
|
||||||
|
if primaryKeyValue.IsValid() && !primaryKeyValue.IsZero() && field != nil {
|
||||||
|
//多域校验
|
||||||
|
if scope.MultiDomain && scope.Domain != "" {
|
||||||
|
sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ? AND domain = ?", val, primaryKeyValue.Interface(), scope.Domain).Count(&count)
|
||||||
|
} else {
|
||||||
|
sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ?", val, primaryKeyValue.Interface()).Count(&count)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if scope.MultiDomain && scope.Domain != "" {
|
||||||
|
sess.Model(scope.Model).Where(scope.Column+"=? AND domain = ?", val, scope.Domain).Count(&count)
|
||||||
|
} else {
|
||||||
|
sess.Model(scope.Model).Where(scope.Column+"=?", val).Count(&count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (validate *Validate) grantRules(scm *types.Schema, scenario string, rule types.Rule) []*validateRule {
|
||||||
|
rules := make([]*validateRule, 0, 5)
|
||||||
|
if rule.Min != 0 {
|
||||||
|
rules = append(rules, newRule("min", strconv.Itoa(rule.Min)))
|
||||||
|
}
|
||||||
|
if rule.Max != 0 {
|
||||||
|
rules = append(rules, newRule("max", strconv.Itoa(rule.Max)))
|
||||||
|
}
|
||||||
|
//主键不做唯一判断
|
||||||
|
if rule.Unique && !scm.Attribute.PrimaryKey {
|
||||||
|
rules = append(rules, newRule("db_unique"))
|
||||||
|
}
|
||||||
|
if rule.Type != "" {
|
||||||
|
rules = append(rules, newRule(rule.Type))
|
||||||
|
}
|
||||||
|
if rule.Required != nil && len(rule.Required) > 0 {
|
||||||
|
for _, v := range rule.Required {
|
||||||
|
if v == scenario {
|
||||||
|
rules = append(rules, newRule("required"))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
func (validate *Validate) buildRules(rs []*validateRule) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, r := range rs {
|
||||||
|
if !r.Valid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if sb.Len() > 0 {
|
||||||
|
sb.WriteString(",")
|
||||||
|
}
|
||||||
|
if r.Value == "" {
|
||||||
|
sb.WriteString(r.Rule)
|
||||||
|
} else {
|
||||||
|
sb.WriteString(r.Rule + "=" + r.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (validate *Validate) findRule(name string, rules []*validateRule) *validateRule {
|
||||||
|
for _, r := range rules {
|
||||||
|
if r.Rule == name {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (validate *Validate) Initialize(db *gorm.DB) (err error) {
|
||||||
|
validate.validator = validator.New()
|
||||||
|
if err = db.Callback().Create().Before("gorm:before_create").Register("model_validate", validate.Validate); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = db.Callback().Create().Before("gorm:before_update").Register("model_validate", validate.Validate); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = validate.validator.RegisterValidationCtx("telephone", validate.telephoneValidate); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = validate.validator.RegisterValidationCtx("db_unique", validate.uniqueValidate); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (validate *Validate) inArray(v any, vs []any) bool {
|
||||||
|
sv := fmt.Sprint(v)
|
||||||
|
for _, s := range vs {
|
||||||
|
if fmt.Sprint(s) == sv {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isVisible 判断字段是否需要显示
|
||||||
|
func (validate *Validate) isVisible(stmt *gorm.Statement, scm *types.Schema) bool {
|
||||||
|
if len(scm.Attribute.Visible) <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, row := range scm.Attribute.Visible {
|
||||||
|
if len(row.Values) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
targetField := stmt.Schema.LookUpField(row.Column)
|
||||||
|
if targetField == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
targetValue := stmt.ReflectValue.FieldByName(targetField.Name)
|
||||||
|
if !targetValue.IsValid() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !validate.inArray(targetValue.Interface(), row.Values) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate 校验字段
|
||||||
|
func (validate *Validate) Validate(db *gorm.DB) {
|
||||||
|
var (
|
||||||
|
ok bool
|
||||||
|
err error
|
||||||
|
rules []*validateRule
|
||||||
|
stmt *gorm.Statement
|
||||||
|
skipValidate bool
|
||||||
|
multiDomain bool
|
||||||
|
value reflect.Value
|
||||||
|
runtimeScope *types.RuntimeScope
|
||||||
|
)
|
||||||
|
if result, ok := db.Get(SkipValidations); ok && result.(bool) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stmt = db.Statement
|
||||||
|
if stmt.Model == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if db.Statement.Context == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if val := db.Statement.Context.Value(rest.RuntimeScopeKey); val != nil {
|
||||||
|
if runtimeScope, ok = val.(*types.RuntimeScope); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if runtimeScope.Schemas == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if stmt.Schema.LookUpField("domain") != nil {
|
||||||
|
multiDomain = true
|
||||||
|
}
|
||||||
|
for _, row := range runtimeScope.Schemas {
|
||||||
|
//如果字段隐藏,那么就不进行校验
|
||||||
|
if !validate.isVisible(stmt, row) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rules = validate.grantRules(row, runtimeScope.Scenario, row.Rule); len(rules) <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field := stmt.Schema.LookUpField(row.Column)
|
||||||
|
if field == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
value = stmt.ReflectValue.FieldByName(field.Name)
|
||||||
|
if !value.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
skipValidate = false
|
||||||
|
if r := validate.findRule("required", rules); r != nil {
|
||||||
|
if value.Interface() != nil {
|
||||||
|
vType := reflect.ValueOf(value.Interface())
|
||||||
|
switch vType.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
skipValidate = true
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
skipValidate = true
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
skipValidate = true
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
skipValidate = true
|
||||||
|
default:
|
||||||
|
skipValidate = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if skipValidate {
|
||||||
|
r.Valid = false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if isEmpty(value.Interface()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(db.Statement.Context, scopeCtxKey, &validateScope{
|
||||||
|
DB: db,
|
||||||
|
Column: row.Column,
|
||||||
|
Model: stmt.Model,
|
||||||
|
Domain: runtimeScope.Domain,
|
||||||
|
MultiDomain: multiDomain,
|
||||||
|
})
|
||||||
|
if err = validate.validator.VarCtx(ctx, value.Interface(), validate.buildRules(rules)); err != nil {
|
||||||
|
if errs, ok := err.(validator.ValidationErrors); ok {
|
||||||
|
for _, e := range errs {
|
||||||
|
_ = db.AddError(&StructError{
|
||||||
|
Tag: e.Tag(),
|
||||||
|
Column: row.Column,
|
||||||
|
Message: formatError(row.Rule, row, e.Tag()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_ = db.AddError(err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Validate {
|
||||||
|
return &Validate{}
|
||||||
|
}
|
|
@ -0,0 +1,405 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Query struct {
|
||||||
|
db *gorm.DB
|
||||||
|
condition string
|
||||||
|
fields []string
|
||||||
|
params []interface{}
|
||||||
|
table string
|
||||||
|
joins []join
|
||||||
|
orderBy []string
|
||||||
|
groupBy []string
|
||||||
|
modelValue any
|
||||||
|
limit int
|
||||||
|
offset int
|
||||||
|
}
|
||||||
|
|
||||||
|
condition struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Value interface{} `json:"value"`
|
||||||
|
Expr string `json:"expr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
join struct {
|
||||||
|
Table string
|
||||||
|
Direction string
|
||||||
|
Conditions []*condition
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (cond *condition) WithExpr(v string) *condition {
|
||||||
|
cond.Expr = v
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) Model() any {
|
||||||
|
return query.modelValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) compile() (*gorm.DB, error) {
|
||||||
|
db := query.db
|
||||||
|
if query.condition != "" {
|
||||||
|
db = db.Where(query.condition, query.params...)
|
||||||
|
}
|
||||||
|
if query.fields != nil {
|
||||||
|
db = db.Select(strings.Join(query.fields, ","))
|
||||||
|
}
|
||||||
|
if query.table != "" {
|
||||||
|
db = db.Table(query.table)
|
||||||
|
}
|
||||||
|
if query.joins != nil && len(query.joins) > 0 {
|
||||||
|
for _, joinEntity := range query.joins {
|
||||||
|
cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...)
|
||||||
|
db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if query.orderBy != nil && len(query.orderBy) > 0 {
|
||||||
|
db = db.Order(strings.Join(query.orderBy, ","))
|
||||||
|
}
|
||||||
|
if query.groupBy != nil && len(query.groupBy) > 0 {
|
||||||
|
db = db.Group(strings.Join(query.groupBy, ","))
|
||||||
|
}
|
||||||
|
if query.offset > 0 {
|
||||||
|
db = db.Offset(query.offset)
|
||||||
|
}
|
||||||
|
if query.limit > 0 {
|
||||||
|
db = db.Limit(query.limit)
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) decodeValue(v any) string {
|
||||||
|
refVal := reflect.Indirect(reflect.ValueOf(v))
|
||||||
|
switch refVal.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
if refVal.Bool() {
|
||||||
|
return "1"
|
||||||
|
} else {
|
||||||
|
return "0"
|
||||||
|
}
|
||||||
|
case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||||
|
return strconv.FormatInt(refVal.Int(), 10)
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
|
||||||
|
case reflect.String:
|
||||||
|
return "'" + refVal.String() + "'"
|
||||||
|
default:
|
||||||
|
return fmt.Sprint(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) {
|
||||||
|
var (
|
||||||
|
sb strings.Builder
|
||||||
|
)
|
||||||
|
params = make([]interface{}, 0)
|
||||||
|
for _, cond := range conditions {
|
||||||
|
if filter {
|
||||||
|
if isEmpty(cond.Value) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cond.Expr == "" {
|
||||||
|
cond.Expr = "="
|
||||||
|
}
|
||||||
|
switch strings.ToUpper(cond.Expr) {
|
||||||
|
case "=", "<>", ">", "<", ">=", "<=", "!=":
|
||||||
|
if sb.Len() > 0 {
|
||||||
|
sb.WriteString(" " + operator + " ")
|
||||||
|
}
|
||||||
|
if cond.Expr == "=" && cond.Value == nil {
|
||||||
|
sb.WriteString("`" + cond.Field + "` IS NULL")
|
||||||
|
} else {
|
||||||
|
sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?")
|
||||||
|
params = append(params, cond.Value)
|
||||||
|
}
|
||||||
|
case "LIKE":
|
||||||
|
if sb.Len() > 0 {
|
||||||
|
sb.WriteString(" " + operator + " ")
|
||||||
|
}
|
||||||
|
cond.Value = fmt.Sprintf("%%%s%%", cond.Value)
|
||||||
|
sb.WriteString("`" + cond.Field + "` LIKE ?")
|
||||||
|
params = append(params, cond.Value)
|
||||||
|
case "IN":
|
||||||
|
if sb.Len() > 0 {
|
||||||
|
sb.WriteString(" " + operator + " ")
|
||||||
|
}
|
||||||
|
refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
|
||||||
|
switch refVal.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
ss := make([]string, refVal.Len())
|
||||||
|
for i := 0; i < refVal.Len(); i++ {
|
||||||
|
ss[i] = query.decodeValue(refVal.Index(i))
|
||||||
|
}
|
||||||
|
sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")")
|
||||||
|
case reflect.String:
|
||||||
|
sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
case "BETWEEN":
|
||||||
|
refVal := reflect.ValueOf(cond.Value)
|
||||||
|
if refVal.Kind() == reflect.Slice && refVal.Len() == 2 {
|
||||||
|
sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?")
|
||||||
|
params = append(params, refVal.Index(0), refVal.Index(1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
str = sb.String()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) Select(fields ...string) *Query {
|
||||||
|
if query.fields == nil {
|
||||||
|
query.fields = fields
|
||||||
|
} else {
|
||||||
|
query.fields = append(query.fields, fields...)
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) From(table string) *Query {
|
||||||
|
query.table = table
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) LeftJoin(table string, conditions ...*condition) *Query {
|
||||||
|
query.joins = append(query.joins, join{
|
||||||
|
Table: table,
|
||||||
|
Direction: "LEFT",
|
||||||
|
Conditions: conditions,
|
||||||
|
})
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) RightJoin(table string, conditions ...*condition) *Query {
|
||||||
|
query.joins = append(query.joins, join{
|
||||||
|
Table: table,
|
||||||
|
Direction: "RIGHT",
|
||||||
|
Conditions: conditions,
|
||||||
|
})
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) InnerJoin(table string, conditions ...*condition) *Query {
|
||||||
|
query.joins = append(query.joins, join{
|
||||||
|
Table: table,
|
||||||
|
Direction: "INNER",
|
||||||
|
Conditions: conditions,
|
||||||
|
})
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) AddCondition(expr, column string, val any) {
|
||||||
|
if expr == "" {
|
||||||
|
expr = "="
|
||||||
|
}
|
||||||
|
query.AndWhere(newConditionWithOperator(expr, column, val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) AndWhere(conditions ...*condition) *Query {
|
||||||
|
length := len(conditions)
|
||||||
|
if length == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
cs, ps := query.buildConditions("AND", false, conditions...)
|
||||||
|
if cs == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
query.params = append(query.params, ps...)
|
||||||
|
if query.condition == "" {
|
||||||
|
query.condition = cs
|
||||||
|
} else {
|
||||||
|
query.condition += " AND (" + cs + ")"
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) AndFilterWhere(conditions ...*condition) *Query {
|
||||||
|
length := len(conditions)
|
||||||
|
if length == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
cs, ps := query.buildConditions("AND", true, conditions...)
|
||||||
|
if cs == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
query.params = append(query.params, ps...)
|
||||||
|
if query.condition == "" {
|
||||||
|
query.condition = cs
|
||||||
|
} else {
|
||||||
|
query.condition += " AND " + cs
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) OrWhere(conditions ...*condition) *Query {
|
||||||
|
length := len(conditions)
|
||||||
|
if length == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
cs, ps := query.buildConditions("OR", false, conditions...)
|
||||||
|
if cs == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
query.params = append(query.params, ps...)
|
||||||
|
if query.condition == "" {
|
||||||
|
query.condition = cs
|
||||||
|
} else {
|
||||||
|
query.condition += " AND (" + cs + ")"
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) OrFilterWhere(conditions ...*condition) *Query {
|
||||||
|
length := len(conditions)
|
||||||
|
if length == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
cs, ps := query.buildConditions("OR", true, conditions...)
|
||||||
|
if cs == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
query.params = append(query.params, ps...)
|
||||||
|
if query.condition == "" {
|
||||||
|
query.condition = cs
|
||||||
|
} else {
|
||||||
|
query.condition += " AND (" + cs + ")"
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) GroupBy(cols ...string) *Query {
|
||||||
|
query.groupBy = append(query.groupBy, cols...)
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) OrderBy(col, direction string) *Query {
|
||||||
|
direction = strings.ToUpper(direction)
|
||||||
|
if direction == "" || !(direction == "ASC" || direction == "DESC") {
|
||||||
|
direction = "ASC"
|
||||||
|
}
|
||||||
|
query.orderBy = append(query.orderBy, col+" "+direction)
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) Offset(i int) *Query {
|
||||||
|
query.offset = i
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) Limit(i int) *Query {
|
||||||
|
query.limit = i
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) ResetSelect() *Query {
|
||||||
|
query.fields = make([]string, 0)
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) Count(v interface{}) (i int64) {
|
||||||
|
var (
|
||||||
|
db *gorm.DB
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if db, err = query.compile(); err != nil {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
if v != nil {
|
||||||
|
refVal := reflect.ValueOf(v)
|
||||||
|
switch refVal.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
if query.table == "" {
|
||||||
|
err = db.Table(refVal.String()).Count(&i).Error
|
||||||
|
} else {
|
||||||
|
err = db.Table(query.table).Count(&i).Error
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
//如果是报表的模型,这的话手动构建一条SQL语句
|
||||||
|
if reporter, ok := v.(types.Reporter); ok {
|
||||||
|
sqlRes := &sqlCountResponse{}
|
||||||
|
childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true)
|
||||||
|
db.WithContext(childCtx).Model(reporter).First(sqlRes)
|
||||||
|
i = sqlRes.Count
|
||||||
|
} else {
|
||||||
|
err = db.Model(v).Count(&i).Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) One(v interface{}) (err error) {
|
||||||
|
var (
|
||||||
|
db *gorm.DB
|
||||||
|
)
|
||||||
|
if db, err = query.compile(); err != nil {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
err = db.First(v).Error
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) All(v interface{}) (err error) {
|
||||||
|
var (
|
||||||
|
db *gorm.DB
|
||||||
|
)
|
||||||
|
if db, err = query.compile(); err != nil {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
err = db.Find(v).Error
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCondition(column, opera string, value any) *condition {
|
||||||
|
if opera == "" {
|
||||||
|
opera = "="
|
||||||
|
}
|
||||||
|
return &condition{
|
||||||
|
Field: column,
|
||||||
|
Value: value,
|
||||||
|
Expr: opera,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCondition(field string, value interface{}) *condition {
|
||||||
|
return &condition{
|
||||||
|
Field: field,
|
||||||
|
Value: value,
|
||||||
|
Expr: "=",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConditionWithOperator(operator, field string, value interface{}) *condition {
|
||||||
|
cond := &condition{
|
||||||
|
Field: field,
|
||||||
|
Value: value,
|
||||||
|
Expr: operator,
|
||||||
|
}
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQuery(db *gorm.DB, model any) *Query {
|
||||||
|
return &Query{
|
||||||
|
db: db,
|
||||||
|
modelValue: model,
|
||||||
|
params: make([]interface{}, 0),
|
||||||
|
orderBy: make([]string, 0),
|
||||||
|
groupBy: make([]string, 0),
|
||||||
|
joins: make([]join, 0),
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,753 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"git.nobla.cn/golang/kos/util/arrays"
|
||||||
|
"git.nobla.cn/golang/kos/util/reflection"
|
||||||
|
"git.nobla.cn/golang/rest/inflector"
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
modelEntities []*Model
|
||||||
|
httpRouter types.HttpRouter
|
||||||
|
hookMgr *hookManager
|
||||||
|
timeKind = reflect.TypeOf(time.Time{}).Kind()
|
||||||
|
timePtrKind = reflect.TypeOf(&time.Time{}).Kind()
|
||||||
|
|
||||||
|
matchEnums = []string{types.MatchExactly, types.MatchFuzzy}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
allowScenario = []string{types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
hookMgr = &hookManager{}
|
||||||
|
modelEntities = make([]*Model, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneStmt 从指定的db克隆一个 Statement 对象
|
||||||
|
func cloneStmt(db *gorm.DB) *gorm.Statement {
|
||||||
|
return &gorm.Statement{
|
||||||
|
DB: db,
|
||||||
|
ConnPool: db.Statement.ConnPool,
|
||||||
|
Context: db.Statement.Context,
|
||||||
|
Clauses: map[string]clause.Clause{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dataTypeOf 推断数据的类型
|
||||||
|
func dataTypeOf(field *schema.Field) string {
|
||||||
|
var dataType string
|
||||||
|
reflectType := field.FieldType
|
||||||
|
for reflectType.Kind() == reflect.Ptr {
|
||||||
|
reflectType = reflectType.Elem()
|
||||||
|
}
|
||||||
|
if dataType = field.Tag.Get("type"); dataType != "" {
|
||||||
|
return dataType
|
||||||
|
}
|
||||||
|
dataValue := reflect.Indirect(reflect.New(reflectType))
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
dataType = types.TypeBoolean
|
||||||
|
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
dataType = types.TypeInteger
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
dataType = types.TypeFloat
|
||||||
|
default:
|
||||||
|
dataType = types.TypeString
|
||||||
|
}
|
||||||
|
return dataType
|
||||||
|
}
|
||||||
|
|
||||||
|
// dataFormatOf 推断数据的格式
|
||||||
|
func dataFormatOf(field *schema.Field) string {
|
||||||
|
var format string
|
||||||
|
format = field.Tag.Get("format")
|
||||||
|
if format != "" {
|
||||||
|
return format
|
||||||
|
}
|
||||||
|
//如果有枚举值,直接设置为下拉类型
|
||||||
|
enum := field.Tag.Get("enum")
|
||||||
|
if enum != "" {
|
||||||
|
return types.FormatDropdown
|
||||||
|
}
|
||||||
|
reflectType := field.FieldType
|
||||||
|
for reflectType.Kind() == reflect.Ptr {
|
||||||
|
reflectType = reflectType.Elem()
|
||||||
|
}
|
||||||
|
//时间处理
|
||||||
|
dataValue := reflect.Indirect(reflect.New(reflectType))
|
||||||
|
if field.Name == "CreatedAt" || field.Name == "UpdatedAt" || field.Name == "DeletedAt" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
return types.FormatTimestamp
|
||||||
|
default:
|
||||||
|
return types.FormatDatetime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToLower(field.Name), "pass") {
|
||||||
|
return types.FormatPassword
|
||||||
|
}
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case timeKind, timePtrKind:
|
||||||
|
format = types.FormatDatetime
|
||||||
|
case reflect.Bool:
|
||||||
|
format = types.FormatBoolean
|
||||||
|
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
format = types.FormatInteger
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
format = types.FormatFloat
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
format = types.FormatDatetime
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if field.Size >= 1024 {
|
||||||
|
format = types.FormatText
|
||||||
|
} else {
|
||||||
|
format = types.FormatString
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return format
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldName 生成字段名称
|
||||||
|
func fieldName(name string) string {
|
||||||
|
tokens := strings.Split(name, "_")
|
||||||
|
for i, s := range tokens {
|
||||||
|
tokens[i] = strings.Title(s)
|
||||||
|
}
|
||||||
|
return strings.Join(tokens, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldNative 判断是否为原始字段
|
||||||
|
func fieldNative(field *schema.Field) uint8 {
|
||||||
|
if _, ok := field.Tag.Lookup("virtual"); ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldRule 返回字段规则
|
||||||
|
func fieldRule(field *schema.Field) types.Rule {
|
||||||
|
r := types.Rule{
|
||||||
|
Required: []string{},
|
||||||
|
}
|
||||||
|
if field.GORMDataType == schema.String {
|
||||||
|
r.Max = field.Size
|
||||||
|
}
|
||||||
|
if field.GORMDataType == schema.Int || field.GORMDataType == schema.Float || field.GORMDataType == schema.Uint {
|
||||||
|
r.Max = field.Scale
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := field.Tag.Get("rule")
|
||||||
|
if rs != "" {
|
||||||
|
ss := strings.Split(rs, ";")
|
||||||
|
for _, s := range ss {
|
||||||
|
vs := strings.SplitN(s, ":", 2)
|
||||||
|
ls := len(vs)
|
||||||
|
if ls == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch vs[0] {
|
||||||
|
case "required", "require":
|
||||||
|
if ls > 1 {
|
||||||
|
bs := strings.Split(vs[1], ",")
|
||||||
|
for _, i := range bs {
|
||||||
|
if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
|
||||||
|
r.Required = append(r.Required, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
r.Required = []string{types.ScenarioCreate, types.ScenarioUpdate}
|
||||||
|
}
|
||||||
|
case "unique":
|
||||||
|
r.Unique = true
|
||||||
|
case "regexp":
|
||||||
|
if ls > 1 {
|
||||||
|
r.Regular = vs[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if field.PrimaryKey {
|
||||||
|
r.Unique = true
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldScenario 字段Scenarios
|
||||||
|
func fieldScenario(index int, field *schema.Field) types.Scenarios {
|
||||||
|
var ss types.Scenarios
|
||||||
|
if v, ok := field.Tag.Lookup("scenarios"); ok {
|
||||||
|
v = strings.TrimSpace(v)
|
||||||
|
if v != "" {
|
||||||
|
ss = strings.Split(v, ";")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if field.PrimaryKey {
|
||||||
|
ss = []string{types.ScenarioList, types.ScenarioView, types.ScenarioExport}
|
||||||
|
} else if field.Name == "CreatedAt" || field.Name == "UpdatedAt" {
|
||||||
|
ss = []string{types.ScenarioList}
|
||||||
|
} else if field.Name == "DeletedAt" || field.Name == "Namespace" {
|
||||||
|
//不添加任何显示场景
|
||||||
|
ss = []string{}
|
||||||
|
} else {
|
||||||
|
if index < 10 {
|
||||||
|
//高级字段只配置一些简单的场景
|
||||||
|
ss = []string{types.ScenarioSearch, types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
|
||||||
|
} else {
|
||||||
|
//高级字段只配置一些简单的场景
|
||||||
|
ss = []string{types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldPosition 字段的排序位置
|
||||||
|
func fieldPosition(field *schema.Field, i int) int {
|
||||||
|
s := field.Tag.Get("position")
|
||||||
|
n, _ := strconv.Atoi(s)
|
||||||
|
if n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
return i + 100
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldAttribute 字段属性
|
||||||
|
func fieldAttribute(field *schema.Field) types.Attribute {
|
||||||
|
attr := types.Attribute{
|
||||||
|
Match: types.MatchFuzzy,
|
||||||
|
PrimaryKey: field.PrimaryKey,
|
||||||
|
DefaultValue: field.DefaultValue,
|
||||||
|
Readonly: []string{},
|
||||||
|
Disable: []string{},
|
||||||
|
Visible: make([]types.VisibleCondition, 0),
|
||||||
|
Values: make([]types.EnumValue, 0),
|
||||||
|
Live: types.LiveValue{},
|
||||||
|
}
|
||||||
|
if field.Name == "CreatedAt" || field.Name == "UpdatedAt" {
|
||||||
|
attr.EndOfNow = true
|
||||||
|
}
|
||||||
|
//赋值属性
|
||||||
|
props := field.Tag.Get("props")
|
||||||
|
if props != "" {
|
||||||
|
vs := strings.Split(props, ";")
|
||||||
|
for _, str := range vs {
|
||||||
|
kv := strings.SplitN(str, ":", 2)
|
||||||
|
if len(kv) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sv := strings.TrimSpace(kv[1])
|
||||||
|
switch strings.ToLower(strings.TrimSpace(kv[0])) {
|
||||||
|
case "icon":
|
||||||
|
attr.Icon = sv
|
||||||
|
case "match":
|
||||||
|
if arrays.Exists(sv, matchEnums) {
|
||||||
|
attr.Match = sv
|
||||||
|
}
|
||||||
|
case "endofnow", "end_of_now":
|
||||||
|
if ok, _ := strconv.ParseBool(sv); ok {
|
||||||
|
attr.EndOfNow = true
|
||||||
|
}
|
||||||
|
case "invisible":
|
||||||
|
if ok, _ := strconv.ParseBool(sv); ok {
|
||||||
|
attr.Invisible = true
|
||||||
|
}
|
||||||
|
case "suffix":
|
||||||
|
attr.Suffix = sv
|
||||||
|
case "tag":
|
||||||
|
attr.Tag = sv
|
||||||
|
case "tooltip":
|
||||||
|
attr.Tooltip = sv
|
||||||
|
case "uploadurl", "uploaduri", "upload_url", "upload_uri":
|
||||||
|
attr.UploadUrl = sv
|
||||||
|
case "description":
|
||||||
|
attr.Description = sv
|
||||||
|
case "readonly":
|
||||||
|
bs := strings.Split(sv, ",")
|
||||||
|
for _, i := range bs {
|
||||||
|
if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
|
||||||
|
attr.Readonly = append(attr.Readonly, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//live的赋值
|
||||||
|
live := field.Tag.Get("live")
|
||||||
|
if live != "" {
|
||||||
|
attr.Live.Enable = true
|
||||||
|
vs := strings.Split(live, ";")
|
||||||
|
for _, str := range vs {
|
||||||
|
kv := strings.SplitN(str, ":", 2)
|
||||||
|
if len(kv) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch kv[0] {
|
||||||
|
case "method":
|
||||||
|
attr.Live.Method = kv[1]
|
||||||
|
case "type":
|
||||||
|
if kv[1] == types.LiveTypeDropdown || kv[1] == types.LiveTypeCascader {
|
||||||
|
attr.Live.Type = kv[1]
|
||||||
|
} else {
|
||||||
|
attr.Live.Type = types.LiveTypeDropdown
|
||||||
|
}
|
||||||
|
case "url", "uri":
|
||||||
|
attr.Live.Url = kv[1]
|
||||||
|
case "columns":
|
||||||
|
attr.Live.Columns = strings.Split(kv[1], ",")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dropdown := field.Tag.Get("dropdown")
|
||||||
|
if dropdown != "" {
|
||||||
|
attr.DropdownOption = &types.DropdownOption{}
|
||||||
|
vs := strings.Split(dropdown, ";")
|
||||||
|
for _, str := range vs {
|
||||||
|
kv := strings.SplitN(str, ":", 2)
|
||||||
|
if len(kv) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch kv[0] {
|
||||||
|
case "created":
|
||||||
|
attr.DropdownOption.Created = true
|
||||||
|
case "filterable":
|
||||||
|
attr.DropdownOption.Filterable = true
|
||||||
|
case "autocomplete":
|
||||||
|
attr.DropdownOption.Autocomplete = true
|
||||||
|
case "default_first":
|
||||||
|
attr.DropdownOption.DefaultFirst = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//显示条件
|
||||||
|
conditions := field.Tag.Get("condition")
|
||||||
|
if conditions != "" {
|
||||||
|
vs := strings.Split(conditions, ";")
|
||||||
|
for _, str := range vs {
|
||||||
|
kv := strings.SplitN(str, ":", 2)
|
||||||
|
if len(kv) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cond := types.VisibleCondition{
|
||||||
|
Column: kv[0],
|
||||||
|
Values: make([]any, 0),
|
||||||
|
}
|
||||||
|
vv := strings.Split(kv[1], ",")
|
||||||
|
for _, x := range vv {
|
||||||
|
x = strings.TrimSpace(x)
|
||||||
|
if x == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cond.Values = append(cond.Values, x)
|
||||||
|
}
|
||||||
|
attr.Visible = append(attr.Visible, cond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//赋值枚举值
|
||||||
|
enumns := field.Tag.Get("enum")
|
||||||
|
if enumns != "" {
|
||||||
|
vs := strings.Split(enumns, ";")
|
||||||
|
for _, str := range vs {
|
||||||
|
kv := strings.SplitN(str, ":", 2)
|
||||||
|
if len(kv) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fv := types.EnumValue{Value: kv[0]}
|
||||||
|
//颜色分隔符
|
||||||
|
if pos := strings.IndexByte(kv[1], '#'); pos > -1 {
|
||||||
|
fv.Label = kv[1][:pos]
|
||||||
|
fv.Color = kv[1][pos:]
|
||||||
|
} else {
|
||||||
|
fv.Label = kv[1]
|
||||||
|
}
|
||||||
|
attr.Values = append(attr.Values, fv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !field.Creatable {
|
||||||
|
attr.Disable = append(attr.Disable, types.ScenarioCreate)
|
||||||
|
}
|
||||||
|
if !field.Updatable {
|
||||||
|
attr.Disable = append(attr.Disable, types.ScenarioUpdate)
|
||||||
|
}
|
||||||
|
attr.Tooltip = field.Comment
|
||||||
|
return attr
|
||||||
|
}
|
||||||
|
|
||||||
|
// autoMigrate 自动合并字段
|
||||||
|
func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (naming string, err error) {
|
||||||
|
var (
|
||||||
|
pos int
|
||||||
|
columnName string
|
||||||
|
columnIsExists bool
|
||||||
|
columnLabel string
|
||||||
|
schemas []*types.Schema
|
||||||
|
models []*types.Schema
|
||||||
|
stmt *gorm.Statement
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt = cloneStmt(db)
|
||||||
|
if err = stmt.Parse(model); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if schemas, err = GetSchemas(ctx, db, defaultDomain, module, stmt.Table); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(schemas) > 0 {
|
||||||
|
pos = len(schemas)
|
||||||
|
}
|
||||||
|
models = make([]*types.Schema, 0)
|
||||||
|
for index, field := range stmt.Schema.Fields {
|
||||||
|
columnName = field.DBName
|
||||||
|
if columnName == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if columnName == "" {
|
||||||
|
columnName = field.Name
|
||||||
|
}
|
||||||
|
columnIsExists = false
|
||||||
|
for _, sm := range schemas {
|
||||||
|
if sm.Column == columnName {
|
||||||
|
columnIsExists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if columnIsExists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
columnLabel = field.Tag.Get("comment")
|
||||||
|
if columnLabel == "" {
|
||||||
|
columnLabel = fieldName(field.DBName)
|
||||||
|
}
|
||||||
|
isPrimaryKey := uint8(0)
|
||||||
|
if field.PrimaryKey {
|
||||||
|
isPrimaryKey = 1
|
||||||
|
}
|
||||||
|
schemaModel := &types.Schema{
|
||||||
|
Domain: defaultDomain,
|
||||||
|
ModuleName: module,
|
||||||
|
TableName: stmt.Table,
|
||||||
|
Enable: 1,
|
||||||
|
Column: columnName,
|
||||||
|
Label: columnLabel,
|
||||||
|
Type: strings.ToLower(dataTypeOf(field)),
|
||||||
|
Format: strings.ToLower(dataFormatOf(field)),
|
||||||
|
Native: fieldNative(field),
|
||||||
|
IsPrimaryKey: isPrimaryKey,
|
||||||
|
Rule: fieldRule(field),
|
||||||
|
Scenarios: fieldScenario(index, field),
|
||||||
|
Attribute: fieldAttribute(field),
|
||||||
|
Position: fieldPosition(field, pos),
|
||||||
|
}
|
||||||
|
//如果启用了在线调取接口功能,那么设置一下字段的format格式
|
||||||
|
if schemaModel.Attribute.Live.Enable {
|
||||||
|
if schemaModel.Attribute.Live.Type != "" {
|
||||||
|
schemaModel.Format = schemaModel.Attribute.Live.Type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
models = append(models, schemaModel)
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
if len(models) > 0 {
|
||||||
|
err = db.Create(models).Error
|
||||||
|
}
|
||||||
|
naming = stmt.Table
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHttpRouter 设置HTTP路由
|
||||||
|
func SetHttpRouter(router types.HttpRouter) {
|
||||||
|
httpRouter = router
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutoMigrate 自动合并表的schema
|
||||||
|
func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) {
|
||||||
|
var (
|
||||||
|
opts *Options
|
||||||
|
table string
|
||||||
|
router types.HttpRouter
|
||||||
|
)
|
||||||
|
opts = &Options{}
|
||||||
|
for _, cb := range cbs {
|
||||||
|
cb(opts)
|
||||||
|
}
|
||||||
|
if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//路由模块处理
|
||||||
|
modelValue := newModel(model, db, types.Naming{
|
||||||
|
Pluralize: inflector.Pluralize(table),
|
||||||
|
Singular: inflector.Singularize(table),
|
||||||
|
ModuleName: opts.moduleName,
|
||||||
|
TableName: table,
|
||||||
|
})
|
||||||
|
modelValue.hookMgr = hookMgr
|
||||||
|
modelValue.schemaLookup = VisibleSchemas
|
||||||
|
if opts.router != nil {
|
||||||
|
router = opts.router
|
||||||
|
}
|
||||||
|
if router == nil && httpRouter != nil {
|
||||||
|
router = httpRouter
|
||||||
|
}
|
||||||
|
if opts.urlPrefix != "" {
|
||||||
|
modelValue.urlPrefix = opts.urlPrefix
|
||||||
|
}
|
||||||
|
//路由绑定操作
|
||||||
|
if router != nil {
|
||||||
|
if modelValue.hasScenario(types.ScenarioList) {
|
||||||
|
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioList), modelValue.Search)
|
||||||
|
}
|
||||||
|
if modelValue.hasScenario(types.ScenarioCreate) {
|
||||||
|
router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioCreate), modelValue.Create)
|
||||||
|
}
|
||||||
|
if modelValue.hasScenario(types.ScenarioUpdate) {
|
||||||
|
router.Handle(http.MethodPut, modelValue.Uri(types.ScenarioUpdate), modelValue.Update)
|
||||||
|
}
|
||||||
|
if modelValue.hasScenario(types.ScenarioDelete) {
|
||||||
|
router.Handle(http.MethodDelete, modelValue.Uri(types.ScenarioDelete), modelValue.Delete)
|
||||||
|
}
|
||||||
|
if modelValue.hasScenario(types.ScenarioView) {
|
||||||
|
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioView), modelValue.View)
|
||||||
|
}
|
||||||
|
if modelValue.hasScenario(types.ScenarioExport) {
|
||||||
|
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioExport), modelValue.Export)
|
||||||
|
}
|
||||||
|
if modelValue.hasScenario(types.ScenarioImport) {
|
||||||
|
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioImport), modelValue.Import)
|
||||||
|
router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioImport), modelValue.Import)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if opts.writer != nil {
|
||||||
|
modelValue.response = opts.writer
|
||||||
|
}
|
||||||
|
if opts.formatter != nil {
|
||||||
|
modelValue.formatter = opts.formatter
|
||||||
|
}
|
||||||
|
modelValue.disableDomain = opts.disableDomain
|
||||||
|
modelEntities = append(modelEntities, modelValue)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneSchemas 克隆schemas
|
||||||
|
func CloneSchemas(ctx context.Context, db *gorm.DB, domain string) (err error) {
|
||||||
|
var (
|
||||||
|
values []*types.Schema
|
||||||
|
schemas []*types.Schema
|
||||||
|
models []*types.Schema
|
||||||
|
)
|
||||||
|
tx := db.WithContext(ctx)
|
||||||
|
if err = tx.Where("domain=?", defaultDomain).Find(&values).Error; err != nil {
|
||||||
|
return fmt.Errorf("schema not found")
|
||||||
|
}
|
||||||
|
tx.Where("domain=?", domain).Find(&schemas)
|
||||||
|
hasSchemaFunc := func(values []*types.Schema, hack *types.Schema) bool {
|
||||||
|
for _, row := range values {
|
||||||
|
if row.ModuleName == hack.ModuleName && row.TableName == hack.TableName && row.Column == hack.Column {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
models = make([]*types.Schema, 0)
|
||||||
|
for _, row := range values {
|
||||||
|
if !hasSchemaFunc(schemas, row) {
|
||||||
|
row.Id = 0
|
||||||
|
row.CreatedAt = 0
|
||||||
|
row.UpdatedAt = 0
|
||||||
|
row.Domain = domain
|
||||||
|
models = append(models, row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(models) > 0 {
|
||||||
|
err = tx.Save(models).Error
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSchemas 获取schemas
|
||||||
|
func GetSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName string) ([]*types.Schema, error) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
values []*types.Schema
|
||||||
|
tx *gorm.DB
|
||||||
|
)
|
||||||
|
values = make([]*types.Schema, 0)
|
||||||
|
if domain == "" {
|
||||||
|
domain = defaultDomain
|
||||||
|
}
|
||||||
|
if moduleName == "" || tableName == "" {
|
||||||
|
return nil, gorm.ErrInvalidField
|
||||||
|
}
|
||||||
|
if ctx != nil {
|
||||||
|
tx = db.WithContext(ctx)
|
||||||
|
} else {
|
||||||
|
tx = db
|
||||||
|
}
|
||||||
|
err = tx.Where("domain=? AND module_name=? AND table_name=?", domain, moduleName, tableName).Order("position ASC").Find(&values).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return values, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisibleSchemas 获取某个场景下面的schema
|
||||||
|
func VisibleSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName, scenario string) ([]*types.Schema, error) {
|
||||||
|
schemas, err := GetSchemas(ctx, db, domain, moduleName, tableName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := make([]*types.Schema, 0, len(schemas))
|
||||||
|
|