初始化仓库

This commit is contained in:
fancl 2024-12-11 17:29:01 +08:00
commit 61ffcec858
29 changed files with 5475 additions and 0 deletions

60
.gitignore vendored 100644
View File

@ -0,0 +1,60 @@
bin/
.svn/
.godeps
./build
.cover/
dist
_site
_posts
*.dat
.vscode
vendor
# Go.gitignore
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
storage
.idea
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.local
.DS_Store
profile
# vim stuff
*.sw[op]
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
.vscode/*
!.vscode/extensions.json
node_modules

1
README.md 100644
View File

@ -0,0 +1 @@
# 数据库组件

144
condition.go 100644
View File

@ -0,0 +1,144 @@
package rest
import (
"context"
"encoding/json"
"git.nobla.cn/golang/kos/util/arrays"
"git.nobla.cn/golang/rest/types"
"net/http"
"strings"
)
func findCondition(schema *types.Schema, conditions []*types.Condition) *types.Condition {
for _, cond := range conditions {
if cond.Column == schema.Column {
return cond
}
}
return nil
}
func BuildConditions(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) {
var (
ok bool
skip bool
formValue string
activeQuery ActiveQuery
)
if activeQuery, ok = query.Model().(ActiveQuery); ok {
if err = activeQuery.BeforeQuery(ctx, query); err != nil {
return
}
}
if arrays.Exists(r.Method, []string{http.MethodPut, http.MethodPost}) {
conditions := make([]*types.Condition, 0)
if err = json.NewDecoder(r.Body).Decode(&conditions); err != nil {
return
}
for _, row := range schemas {
if row.Native == 0 {
continue
}
cond := findCondition(row, conditions)
if cond == nil {
continue
}
switch row.Format {
case types.FormatInteger, types.FormatFloat, types.FormatTimestamp, types.FormatDatetime, types.FormatDate, types.FormatTime:
switch cond.Expr {
case types.OperatorBetween:
if len(cond.Values) == 2 {
query.AndFilterWhere(newCondition(row.Column, cond.Values[0]).WithExpr(">="))
query.AndFilterWhere(newCondition(row.Column, cond.Values[1]).WithExpr("<="))
}
case types.OperatorGreaterThan:
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">"))
case types.OperatorGreaterEqual:
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr(">="))
case types.OperatorLessThan:
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<"))
case types.OperatorLessEqual:
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("<="))
default:
query.AndFilterWhere(newCondition(row.Column, cond.Value))
}
default:
switch cond.Expr {
case types.OperatorLike:
query.AndFilterWhere(newCondition(row.Column, cond.Value).WithExpr("LIKE"))
default:
query.AndFilterWhere(newCondition(row.Column, cond.Value))
}
}
}
} else {
qs := r.URL.Query()
for _, row := range schemas {
skip = false
if skip {
continue
}
if row.Native == 0 {
continue
}
formValue = qs.Get(row.Column)
switch row.Format {
case types.FormatString, types.FormatText:
if row.Attribute.Match == types.MatchExactly {
query.AndFilterWhere(newCondition(row.Column, formValue))
} else {
query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE"))
}
case types.FormatTime, types.FormatDate, types.FormatDatetime, types.FormatTimestamp:
var sep string
seps := []byte{',', '/'}
for _, s := range seps {
if strings.IndexByte(formValue, s) > -1 {
sep = string(s)
}
}
if ss := strings.Split(formValue, sep); len(ss) == 2 {
query.AndFilterWhere(
newCondition(row.Column, strings.TrimSpace(ss[0])).WithExpr(">="),
newCondition(row.Column, strings.TrimSpace(ss[1])).WithExpr("<="),
)
} else {
query.AndFilterWhere(newCondition(row.Column, formValue))
}
case types.FormatInteger, types.FormatFloat:
query.AndFilterWhere(newCondition(row.Column, formValue))
default:
if row.Type == types.TypeString {
if row.Attribute.Match == types.MatchExactly {
query.AndFilterWhere(newCondition(row.Column, formValue))
} else {
query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE"))
}
} else {
query.AndFilterWhere(newCondition(row.Column, formValue))
}
}
}
}
sortPar := r.FormValue("sort")
if sortPar != "" {
sorts := strings.Split(sortPar, ",")
for _, s := range sorts {
if s[0] == '-' {
query.OrderBy(s[1:], "DESC")
} else {
if s[0] == '+' {
query.OrderBy(s[1:], "ASC")
} else {
query.OrderBy(s, "ASC")
}
}
}
}
if activeQuery, ok = query.Model().(ActiveQuery); ok {
if err = activeQuery.AfterQuery(ctx, query); err != nil {
return
}
}
return
}

240
formatter.go 100644
View File

@ -0,0 +1,240 @@
package rest
import (
"context"
"database/sql"
"fmt"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"reflect"
"strconv"
"sync"
"time"
)
var (
DefaultFormatter = NewFormatter()
DefaultNullDisplay = ""
)
func init() {
DefaultFormatter.Register("string", stringFormat)
DefaultFormatter.Register("integer", integerFormat)
DefaultFormatter.Register("decimal", decimalFormat)
DefaultFormatter.Register("date", dateFormat)
DefaultFormatter.Register("time", timeFormat)
DefaultFormatter.Register("datetime", datetimeFormat)
DefaultFormatter.Register("duration", durationFormat)
DefaultFormatter.Register("dropdown", dropdownFormat)
DefaultFormatter.Register("timestamp", datetimeFormat)
DefaultFormatter.Register("percentage", percentageFormat)
}
type FormatFunc func(ctx context.Context, value any, model any, scm *types.Schema) any
type Formatter struct {
callbacks sync.Map
}
func (formatter *Formatter) Register(f string, fun FormatFunc) {
formatter.callbacks.Store(f, fun)
}
func (formatter *Formatter) Format(ctx context.Context, format string, value any, model any, scm *types.Schema) any {
v, ok := formatter.callbacks.Load(format)
if ok {
return v.(FormatFunc)(ctx, value, model, scm)
}
return value
}
func (formatter *Formatter) getModelValue(refValue reflect.Value, schema *types.Schema, stmt *gorm.Statement) any {
if stmt.Schema == nil {
return nil
}
field := stmt.Schema.LookUpField(schema.Column)
if field == nil {
return nil
}
return refValue.FieldByName(field.Name).Interface()
}
func (formatter *Formatter) formatModel(ctx context.Context, refValue reflect.Value, schemas []*types.Schema, stmt *gorm.Statement, format string) any {
values := make(map[string]any)
multiValues := make(map[string]multiValue)
modelValue := refValue.Interface()
refValue = reflect.Indirect(refValue)
for _, scm := range schemas {
switch format {
case types.FormatRaw:
values[scm.Column] = formatter.getModelValue(refValue, scm, stmt)
case types.FormatBoth:
v := multiValue{
Value: formatter.getModelValue(refValue, scm, stmt),
}
v.Text = formatter.Format(ctx, scm.Format, v.Value, modelValue, scm)
multiValues[scm.Column] = v
default:
values[scm.Column] = formatter.Format(ctx, scm.Format, formatter.getModelValue(refValue, scm, stmt), modelValue, scm)
}
}
if format == types.FormatBoth {
return multiValues
} else {
return values
}
}
func (formatter *Formatter) formatModels(ctx context.Context, val any, schemas []*types.Schema, stmt *gorm.Statement, format string) any {
refValue := reflect.Indirect(reflect.ValueOf(val))
if refValue.Kind() != reflect.Slice {
return []any{}
}
length := refValue.Len()
values := make([]any, length)
for i := 0; i < length; i++ {
rowValue := refValue.Index(i)
modelValue := rowValue.Interface()
if formatModel, ok := modelValue.(types.FormatModel); ok {
formatModel.Format(ctx, format, schemas)
}
values[i] = formatter.formatModel(ctx, rowValue, schemas, stmt, format)
}
return values
}
func stringFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
return fmt.Sprint(value)
}
func integerFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
var (
n int
)
switch value.(type) {
case float32, float64:
n = int(reflect.ValueOf(value).Float())
case int, int8, int16, int32, int64:
n = int(reflect.ValueOf(value).Int())
case uint, uint8, uint16, uint32, uint64:
n = int(reflect.ValueOf(value).Uint())
case string:
n, _ = strconv.Atoi(reflect.ValueOf(value).String())
case []byte:
n, _ = strconv.Atoi(string(reflect.ValueOf(value).Bytes()))
}
return n
}
func decimalFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
var (
n float64
)
switch value.(type) {
case float32, float64:
n = reflect.ValueOf(value).Float()
case int, int8, int16, int32, int64:
n = float64(reflect.ValueOf(value).Int())
case uint, uint8, uint16, uint32, uint64:
n = float64(reflect.ValueOf(value).Uint())
case string:
n, _ = strconv.ParseFloat(reflect.ValueOf(value).String(), 64)
case []byte:
n, _ = strconv.ParseFloat(string(reflect.ValueOf(value).Bytes()), 64)
}
return n
}
func dateFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
if t, ok := value.(time.Time); ok {
return t.Format("2006-01-02")
}
if t, ok := value.(*sql.NullTime); ok {
if t != nil && t.Valid {
return t.Time.Format("2006-01-02")
}
}
if t, ok := value.(int64); ok {
tm := time.Unix(t, 0)
return tm.Format("2006-01-02")
}
return DefaultNullDisplay
}
func timeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
if t, ok := value.(time.Time); ok {
return t.Format("15:04:05")
}
if t, ok := value.(*sql.NullTime); ok {
if t != nil && t.Valid {
return t.Time.Format("15:04:05")
}
}
if t, ok := value.(int64); ok {
tm := time.Unix(t, 0)
return tm.Format("15:04:05")
}
return value
}
func datetimeFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
if t, ok := value.(time.Time); ok {
return t.Format("2006-01-02 15:04:05")
}
if t, ok := value.(*sql.NullTime); ok {
if t != nil && t.Valid {
return t.Time.Format("2006-01-02 15:04:05")
}
}
if t, ok := value.(int64); ok {
if t > 0 {
tm := time.Unix(t, 0)
return tm.Format("2006-01-02 15:04:05")
}
}
return DefaultNullDisplay
}
func percentageFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
n := decimalFormat(ctx, value, model, schema).(float64)
if n <= 1 {
return fmt.Sprintf("%.2f%%", n*100)
} else {
return fmt.Sprintf("%.2f%%", n)
}
}
func durationFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
var (
hour int
minVal int
sec int
)
n := integerFormat(ctx, value, model, schema).(int)
hour = n / 3600
minVal = (n - hour*3600) / 60
sec = n - hour*3600 - minVal*60
return fmt.Sprintf("%02d:%02d:%02d", hour, minVal, sec)
}
func dropdownFormat(ctx context.Context, value interface{}, model any, schema *types.Schema) interface{} {
attributes := schema.Attribute
if attributes.Values != nil {
for _, v := range attributes.Values {
if v.Value == value {
return v.Label
}
}
}
return value
}
func NewFormatter() *Formatter {
formatter := &Formatter{}
return formatter
}
func RegisterFormat(f string, cb FormatFunc) {
DefaultFormatter.Register(f, cb)
}

26
go.mod 100644
View File

@ -0,0 +1,26 @@
module git.nobla.cn/golang/rest
go 1.22.9
require (
git.nobla.cn/golang/kos v0.1.32
github.com/cespare/xxhash/v2 v2.3.0
github.com/go-playground/validator/v10 v10.23.0
github.com/longbridgeapp/sqlparser v0.3.2
github.com/rs/xid v1.6.0
github.com/uole/sqlparser v0.0.1
gorm.io/gorm v1.25.12
)
require (
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
golang.org/x/crypto v0.19.0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
)

46
go.sum 100644
View File

@ -0,0 +1,46 @@
git.nobla.cn/golang/kos v0.1.32 h1:sFVCA7vKc8dPUd0cxzwExOSPX2mmMh2IuwL6cYS1pBc=
git.nobla.cn/golang/kos v0.1.32/go.mod h1:35Z070+5oB39WcVrh5DDlnVeftL/Ccmscw2MZFe9fUg=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL8sThn8IHr/sO+o=
github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M=
github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/longbridgeapp/sqlparser v0.3.2 h1:FV0dgMiv8VcksT3p10hJeqfPs8bodoehmUJ7MhBds+Y=
github.com/longbridgeapp/sqlparser v0.3.2/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/uole/sqlparser v0.0.1 h1:LLUklg6Ne5MypXQuo53QcJv/xKdxtEKM9iUuEBN/lt8=
github.com/uole/sqlparser v0.0.1/go.mod h1:CRYFz2PTm9oHM0j9GFKi1VzPy70r6GsF0b9vpnwJ4yI=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=

251
hook.go 100644
View File

@ -0,0 +1,251 @@
package rest
import (
"context"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
)
const (
beforeCreate = "beforeCreate"
afterCreate = "afterCreate"
beforeUpdate = "beforeUpdate"
afterUpdate = "afterUpdate"
beforeSave = "beforeSave"
afterSave = "afterSave"
beforeDelete = "beforeDelete"
afterDelete = "afterDelete"
afterExport = "afterExport"
afterImport = "afterImport"
)
type (
BeforeCreate func(ctx context.Context, tx *gorm.DB, model any) (err error)
AfterCreate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
BeforeUpdate func(ctx context.Context, tx *gorm.DB, model any) (err error)
AfterUpdate func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
BeforeSave func(ctx context.Context, tx *gorm.DB, model any) (err error)
AfterSave func(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr)
BeforeDelete func(ctx context.Context, tx *gorm.DB, model any) (err error)
AfterDelete func(ctx context.Context, tx *gorm.DB, model any)
AfterExport func(ctx context.Context, filename string) //导出回调
AfterImport func(ctx context.Context, result *types.ImportResult) //导入回调
hookManager struct {
callbacks map[string][]any
}
)
type (
ActiveQuery interface {
BeforeQuery(ctx context.Context, query *Query) (err error)
AfterQuery(ctx context.Context, query *Query) (err error)
}
)
func (hook *hookManager) register(spec string, cb any) {
if hook.callbacks == nil {
hook.callbacks = make(map[string][]any)
}
if _, ok := hook.callbacks[spec]; !ok {
hook.callbacks[spec] = make([]any, 0)
}
hook.callbacks[spec] = append(hook.callbacks[spec], cb)
}
func (hook *hookManager) beforeCreate(ctx context.Context, tx *gorm.DB, model any) (err error) {
callbacks, ok := hook.callbacks[beforeCreate]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(BeforeCreate); ok {
if err = cb(ctx, tx, model); err != nil {
return err
}
}
}
return
}
func (hook *hookManager) afterCreate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
callbacks, ok := hook.callbacks[afterCreate]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(AfterCreate); ok {
cb(ctx, tx, model, diff)
}
}
return
}
func (hook *hookManager) beforeUpdate(ctx context.Context, tx *gorm.DB, model any) (err error) {
callbacks, ok := hook.callbacks[beforeUpdate]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(BeforeUpdate); ok {
if err = cb(ctx, tx, model); err != nil {
return err
}
}
}
return
}
func (hook *hookManager) afterUpdate(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
callbacks, ok := hook.callbacks[afterUpdate]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(AfterUpdate); ok {
cb(ctx, tx, model, diff)
}
}
return
}
func (hook *hookManager) beforeSave(ctx context.Context, tx *gorm.DB, model any) (err error) {
callbacks, ok := hook.callbacks[beforeSave]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(BeforeSave); ok {
if err = cb(ctx, tx, model); err != nil {
return err
}
}
}
return
}
func (hook *hookManager) afterSave(ctx context.Context, tx *gorm.DB, model any, diff []*types.DiffAttr) {
callbacks, ok := hook.callbacks[afterSave]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(AfterSave); ok {
cb(ctx, tx, model, diff)
}
}
return
}
func (hook *hookManager) beforeDelete(ctx context.Context, tx *gorm.DB, model any) (err error) {
callbacks, ok := hook.callbacks[beforeDelete]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(BeforeDelete); ok {
if err = cb(ctx, tx, model); err != nil {
return err
}
}
}
return
}
func (hook *hookManager) afterDelete(ctx context.Context, tx *gorm.DB, model any) {
callbacks, ok := hook.callbacks[afterDelete]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(AfterDelete); ok {
cb(ctx, tx, model)
}
}
return
}
func (hook *hookManager) afterExport(ctx context.Context, filename string) {
callbacks, ok := hook.callbacks[afterExport]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(AfterExport); ok {
cb(ctx, filename)
}
}
return
}
func (hook *hookManager) afterImport(ctx context.Context, ret *types.ImportResult) {
callbacks, ok := hook.callbacks[afterImport]
if !ok {
return
}
for _, callback := range callbacks {
if cb, ok := callback.(AfterImport); ok {
cb(ctx, ret)
}
}
return
}
func (hook *hookManager) BeforeCreate(cb BeforeCreate) {
if cb != nil {
hook.register(beforeCreate, cb)
}
}
func (hook *hookManager) AfterCreate(cb AfterCreate) {
if cb != nil {
hook.register(afterCreate, cb)
}
}
func (hook *hookManager) BeforeUpdate(cb BeforeUpdate) {
if cb != nil {
hook.register(beforeUpdate, cb)
}
}
func (hook *hookManager) AfterUpdate(cb AfterUpdate) {
if cb != nil {
hook.register(afterUpdate, cb)
}
}
func (hook *hookManager) BeforeSave(cb BeforeSave) {
if cb != nil {
hook.register(beforeSave, cb)
}
}
func (hook *hookManager) AfterSave(cb AfterSave) {
if cb != nil {
hook.register(afterSave, cb)
}
}
func (hook *hookManager) BeforeDelete(cb BeforeDelete) {
if cb != nil {
hook.register(beforeDelete, cb)
}
}
func (hook *hookManager) AfterDelete(cb AfterDelete) {
if cb != nil {
hook.register(afterDelete, cb)
}
}
func (hook *hookManager) AfterExport(cb AfterExport) {
if cb != nil {
hook.register(afterExport, cb)
}
}
func (hook *hookManager) AfterImport(cb AfterImport) {
if cb != nil {
hook.register(afterImport, cb)
}
}

View File

@ -0,0 +1,408 @@
package inflector
import (
"bytes"
"fmt"
"regexp"
"strings"
"sync"
)
// Rule represents name of the inflector rule, can be
// Plural or Singular
type Rule int
const (
Plural = iota
Singular
)
// InflectorRule represents inflector rule
type InflectorRule struct {
Rules []*ruleItem
Irregular []*irregularItem
Uninflected []string
compiledIrregular *regexp.Regexp
compiledUninflected *regexp.Regexp
compiledRules []*compiledRule
}
type ruleItem struct {
pattern string
replacement string
}
type irregularItem struct {
word string
replacement string
}
// compiledRule represents compiled version of Inflector.Rules.
type compiledRule struct {
replacement string
*regexp.Regexp
}
// threadsafe access to rules and caches
var mutex sync.Mutex
var rules = make(map[Rule]*InflectorRule)
// Words that should not be inflected
var uninflected = []string{
`Amoyese`, `bison`, `Borghese`, `bream`, `breeches`, `britches`, `buffalo`,
`cantus`, `carp`, `chassis`, `clippers`, `cod`, `coitus`, `Congoese`,
`contretemps`, `corps`, `debris`, `diabetes`, `djinn`, `eland`, `elk`,
`equipment`, `Faroese`, `flounder`, `Foochowese`, `gallows`, `Genevese`,
`Genoese`, `Gilbertese`, `graffiti`, `headquarters`, `herpes`, `hijinks`,
`Hottentotese`, `information`, `innings`, `jackanapes`, `Kiplingese`,
`Kongoese`, `Lucchese`, `mackerel`, `Maltese`, `.*?media`, `mews`, `moose`,
`mumps`, `Nankingese`, `news`, `nexus`, `Niasese`, `Pekingese`,
`Piedmontese`, `pincers`, `Pistoiese`, `pliers`, `Portuguese`, `proceedings`,
`rabies`, `rice`, `rhinoceros`, `salmon`, `Sarawakese`, `scissors`,
`sea[- ]bass`, `series`, `Shavese`, `shears`, `siemens`, `species`, `swine`,
`testes`, `trousers`, `trout`, `tuna`, `Vermontese`, `Wenchowese`, `whiting`,
`wildebeest`, `Yengeese`,
}
// Plural words that should not be inflected
var uninflectedPlurals = []string{
`.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`,
`people`,
}
// Singular words that should not be inflected
var uninflectedSingulars = []string{
`.*[nrlm]ese`, `.*deer`, `.*fish`, `.*measles`, `.*ois`, `.*pox`, `.*sheep`,
`.*ss`,
}
type cache map[string]string
// Inflected words that already cached for immediate retrieval from a given Rule
var caches = make(map[Rule]cache)
// map of irregular words where its key is a word and its value is the replacement
var irregularMaps = make(map[Rule]cache)
var (
// https://github.com/golang/lint/blob/master/lint.go#L770
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
commonInitialismsReplacer *strings.Replacer
)
func init() {
rules[Plural] = &InflectorRule{
Rules: []*ruleItem{
{`(?i)(s)tatus$`, `${1}${2}tatuses`},
{`(?i)(quiz)$`, `${1}zes`},
{`(?i)^(ox)$`, `${1}${2}en`},
{`(?i)([m|l])ouse$`, `${1}ice`},
{`(?i)(matr|vert|ind)(ix|ex)$`, `${1}ices`},
{`(?i)(x|ch|ss|sh)$`, `${1}es`},
{`(?i)([^aeiouy]|qu)y$`, `${1}ies`},
{`(?i)(hive)$`, `$1s`},
{`(?i)(?:([^f])fe|([lre])f)$`, `${1}${2}ves`},
{`(?i)sis$`, `ses`},
{`(?i)([ti])um$`, `${1}a`},
{`(?i)(p)erson$`, `${1}eople`},
{`(?i)(m)an$`, `${1}en`},
{`(?i)(c)hild$`, `${1}hildren`},
{`(?i)(buffal|tomat)o$`, `${1}${2}oes`},
{`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|vir)us$`, `${1}i`},
{`(?i)us$`, `uses`},
{`(?i)(alias)$`, `${1}es`},
{`(?i)(ax|cris|test)is$`, `${1}es`},
{`s$`, `s`},
{`^$`, ``},
{`$`, `s`},
},
Irregular: []*irregularItem{
{`atlas`, `atlases`},
{`beef`, `beefs`},
{`brother`, `brothers`},
{`cafe`, `cafes`},
{`child`, `children`},
{`cookie`, `cookies`},
{`corpus`, `corpuses`},
{`cow`, `cows`},
{`ganglion`, `ganglions`},
{`genie`, `genies`},
{`genus`, `genera`},
{`graffito`, `graffiti`},
{`hoof`, `hoofs`},
{`loaf`, `loaves`},
{`man`, `men`},
{`money`, `monies`},
{`mongoose`, `mongooses`},
{`move`, `moves`},
{`mythos`, `mythoi`},
{`niche`, `niches`},
{`numen`, `numina`},
{`occiput`, `occiputs`},
{`octopus`, `octopuses`},
{`opus`, `opuses`},
{`ox`, `oxen`},
{`penis`, `penises`},
{`person`, `people`},
{`sex`, `sexes`},
{`soliloquy`, `soliloquies`},
{`testis`, `testes`},
{`trilby`, `trilbys`},
{`turf`, `turfs`},
{`potato`, `potatoes`},
{`hero`, `heroes`},
{`tooth`, `teeth`},
{`goose`, `geese`},
{`foot`, `feet`},
},
}
prepare(Plural)
rules[Singular] = &InflectorRule{
Rules: []*ruleItem{
{`(?i)(s)tatuses$`, `${1}${2}tatus`},
{`(?i)^(.*)(menu)s$`, `${1}${2}`},
{`(?i)(quiz)zes$`, `$1`},
{`(?i)(matr)ices$`, `${1}ix`},
{`(?i)(vert|ind)ices$`, `${1}ex`},
{`(?i)^(ox)en`, `$1`},
{`(?i)(alias)(es)*$`, `$1`},
{`(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|viri?)i$`, `${1}us`},
{`(?i)([ftw]ax)es`, `$1`},
{`(?i)(cris|ax|test)es$`, `${1}is`},
{`(?i)(shoe|slave)s$`, `$1`},
{`(?i)(o)es$`, `$1`},
{`ouses$`, `ouse`},
{`([^a])uses$`, `${1}us`},
{`(?i)([m|l])ice$`, `${1}ouse`},
{`(?i)(x|ch|ss|sh)es$`, `$1`},
{`(?i)(m)ovies$`, `${1}${2}ovie`},
{`(?i)(s)eries$`, `${1}${2}eries`},
{`(?i)([^aeiouy]|qu)ies$`, `${1}y`},
{`(?i)(tive)s$`, `$1`},
{`(?i)([lre])ves$`, `${1}f`},
{`(?i)([^fo])ves$`, `${1}fe`},
{`(?i)(hive)s$`, `$1`},
{`(?i)(drive)s$`, `$1`},
{`(?i)(^analy)ses$`, `${1}sis`},
{`(?i)(analy|diagno|^ba|(p)arenthe|(p)rogno|(s)ynop|(t)he)ses$`, `${1}${2}sis`},
{`(?i)([ti])a$`, `${1}um`},
{`(?i)(p)eople$`, `${1}${2}erson`},
{`(?i)(m)en$`, `${1}an`},
{`(?i)(c)hildren$`, `${1}${2}hild`},
{`(?i)(n)ews$`, `${1}${2}ews`},
{`eaus$`, `eau`},
{`^(.*us)$`, `$1`},
{`(?i)s$`, ``},
},
Irregular: []*irregularItem{
{`foes`, `foe`},
{`waves`, `wave`},
{`curves`, `curve`},
{`atlases`, `atlas`},
{`beefs`, `beef`},
{`brothers`, `brother`},
{`cafes`, `cafe`},
{`children`, `child`},
{`cookies`, `cookie`},
{`corpuses`, `corpus`},
{`cows`, `cow`},
{`ganglions`, `ganglion`},
{`genies`, `genie`},
{`genera`, `genus`},
{`graffiti`, `graffito`},
{`hoofs`, `hoof`},
{`loaves`, `loaf`},
{`men`, `man`},
{`monies`, `money`},
{`mongooses`, `mongoose`},
{`moves`, `move`},
{`mythoi`, `mythos`},
{`niches`, `niche`},
{`numina`, `numen`},
{`occiputs`, `occiput`},
{`octopuses`, `octopus`},
{`opuses`, `opus`},
{`oxen`, `ox`},
{`penises`, `penis`},
{`people`, `person`},
{`sexes`, `sex`},
{`soliloquies`, `soliloquy`},
{`testes`, `testis`},
{`trilbys`, `trilby`},
{`turfs`, `turf`},
{`potatoes`, `potato`},
{`heroes`, `hero`},
{`teeth`, `tooth`},
{`geese`, `goose`},
{`feet`, `foot`},
},
}
prepare(Singular)
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
// prepare rule, e.g., compile the pattern.
func prepare(r Rule) error {
var reString string
switch r {
case Plural:
// Merge global uninflected with singularsUninflected
rules[r].Uninflected = merge(uninflected, uninflectedPlurals)
case Singular:
// Merge global uninflected with singularsUninflected
rules[r].Uninflected = merge(uninflected, uninflectedSingulars)
}
// Set InflectorRule.compiledUninflected by joining InflectorRule.Uninflected into
// a single string then compile it.
reString = fmt.Sprintf(`(?i)(^(?:%s))$`, strings.Join(rules[r].Uninflected, `|`))
rules[r].compiledUninflected = regexp.MustCompile(reString)
// Prepare irregularMaps
irregularMaps[r] = make(cache, len(rules[r].Irregular))
// Set InflectorRule.compiledIrregular by joining the irregularItem.word of Inflector.Irregular
// into a single string then compile it.
vIrregulars := make([]string, len(rules[r].Irregular))
for i, item := range rules[r].Irregular {
vIrregulars[i] = item.word
irregularMaps[r][item.word] = item.replacement
}
reString = fmt.Sprintf(`(?i)(.*)\b((?:%s))$`, strings.Join(vIrregulars, `|`))
rules[r].compiledIrregular = regexp.MustCompile(reString)
// Compile all patterns in InflectorRule.Rules
rules[r].compiledRules = make([]*compiledRule, len(rules[r].Rules))
for i, item := range rules[r].Rules {
rules[r].compiledRules[i] = &compiledRule{item.replacement, regexp.MustCompile(item.pattern)}
}
// Prepare caches
caches[r] = make(cache)
return nil
}
// merge slice a and slice b
func merge(a []string, b []string) []string {
result := make([]string, len(a)+len(b))
copy(result, a)
copy(result[len(a):], b)
return result
}
func getInflected(r Rule, s string) string {
mutex.Lock()
defer mutex.Unlock()
if v, ok := caches[r][s]; ok {
return v
}
// Check for irregular words
if res := rules[r].compiledIrregular.FindStringSubmatch(s); len(res) >= 3 {
var buf bytes.Buffer
buf.WriteString(res[1])
buf.WriteString(s[0:1])
buf.WriteString(irregularMaps[r][strings.ToLower(res[2])][1:])
// Cache it then returns
caches[r][s] = buf.String()
return caches[r][s]
}
// Check for uninflected words
if rules[r].compiledUninflected.MatchString(s) {
caches[r][s] = s
return caches[r][s]
}
// Check each rule
for _, re := range rules[r].compiledRules {
if re.MatchString(s) {
caches[r][s] = re.ReplaceAllString(s, re.replacement)
return caches[r][s]
}
}
// Returns unaltered
caches[r][s] = s
return caches[r][s]
}
// Pluralize returns string s in plural form.
func Pluralize(s string) string {
return getInflected(Plural, s)
}
// Singularize returns string s in singular form.
func Singularize(s string) string {
return getInflected(Singular, s)
}
var (
camelizeReg = regexp.MustCompile(`[^A-Za-z0-9]+`)
)
// Camelize Converts a word like "send_email" to "SendEmail"
func Camelize(s string) string {
s = camelizeReg.ReplaceAllString(s, " ")
return strings.Replace(strings.Title(s), " ", "", -1)
}
// Camel2id Converts a word like "SendEmail" to "send_email"
func Camel2id(name string) string {
var (
value = commonInitialismsReplacer.Replace(name)
buf strings.Builder
lastCase, nextCase, nextNumber bool // upper case == true
curCase = value[0] <= 'Z' && value[0] >= 'A'
)
for i, v := range value[:len(value)-1] {
nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
if curCase {
if lastCase && (nextCase || nextNumber) {
buf.WriteRune(v + 32)
} else {
if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
buf.WriteByte('_')
}
buf.WriteRune(v + 32)
}
} else {
buf.WriteRune(v)
}
lastCase = curCase
curCase = nextCase
}
if curCase {
if !lastCase && len(value) > 1 {
buf.WriteByte('_')
}
buf.WriteByte(value[len(value)-1] + 32)
} else {
buf.WriteByte(value[len(value)-1])
}
ret := buf.String()
return ret
}
// Camel2words Converts a CamelCase name into space-separated words.
// For example, 'send_email' will be converted to 'Send Email'.
func Camel2words(s string) string {
s = camelizeReg.ReplaceAllString(s, " ")
return strings.Title(s)
}

987
model.go 100644
View File

@ -0,0 +1,987 @@
package rest
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"git.nobla.cn/golang/kos/util/arrays"
"git.nobla.cn/golang/kos/util/pool"
"git.nobla.cn/golang/rest/inflector"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"path"
"reflect"
"strconv"
"strings"
"time"
)
type Model struct {
naming types.Naming //命名规则
value reflect.Value //模块值
db *gorm.DB //数据库
primaryKey string //主键
urlPrefix string //url前缀
disableDomain bool //禁用域
schemaLookup types.SchemaLookupFunc //获取schema的函数
valueLookup types.ValueLookupFunc //查看域
statement *gorm.Statement //字段声明
formatter *Formatter //格式化
response types.HttpWriter //HTTP响应
hookMgr *hookManager //钩子管理器
dirname string //存放文件目录
}
var (
RuntimeScopeKey = &types.RuntimeScope{}
)
// getDB 获取数据库连接对象
func (m *Model) getDB() *gorm.DB {
return m.db
}
// getFormatter 获取格式化组件
func (m *Model) getFormatter() *Formatter {
if m.formatter != nil {
return m.formatter
}
return DefaultFormatter
}
// getHook 获取钩子
func (m *Model) getHook() *hookManager {
return m.hookMgr
}
// hasScenario 判断是否有该场景
func (m *Model) hasScenario(s string) bool {
return true
}
// setValue 设置字段的值
func (m *Model) setValue(refValue reflect.Value, column string, value any) {
SetFieldValue(m.statement, refValue, column, value)
}
func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) {
SafeSetFileValue(m.statement, refValue, column, value)
}
// getValue 获取字段的值
func (m *Model) getValue(refValue reflect.Value, column string) interface{} {
return GetFieldValue(m.statement, refValue, column)
}
// hasColumn 判断指定的列是否存在
func (m *Model) hasColumn(column string) bool {
for _, field := range m.statement.Schema.Fields {
if field.DBName == column || field.Name == column {
return true
}
}
return false
}
// getFilename 获取文件存放目录
func (m *Model) getFilename(domain string, spec string, name string) string {
if m.dirname == "" {
m.dirname = os.TempDir()
}
filename := path.Join(m.dirname, domain, spec, time.Now().Format("20060102"), name)
if _, err := os.Stat(path.Dir(filename)); err != nil {
_ = os.MkdirAll(path.Dir(filename), 0755)
}
return filename
}
// findPrimaryKey 查找主键的值
func (m *Model) findPrimaryKey(uri string, r *http.Request) string {
var (
pos int
)
urlPath := r.URL.Path
pos = strings.IndexByte(uri, ':')
if pos > 0 {
return urlPath[pos:]
}
return ""
}
// parseReportColumn 解析报表的列
func (m *Model) parseReportColumn(name, props string) *types.SelectColumn {
var (
key string
value string
)
column := &types.SelectColumn{
Name: inflector.Camel2id(name),
Native: false,
}
tokens := strings.Split(props, ";")
for _, token := range tokens {
pair := strings.SplitN(token, ":", 2)
if len(pair) == 0 {
continue
}
if len(pair) == 1 {
key = strings.TrimSpace(pair[0])
value = ""
} else {
key = strings.TrimSpace(pair[0])
value = strings.TrimSpace(pair[1])
}
switch key {
case "native":
column.Native = true
case "name":
column.Name = value
case "expr":
column.Expr = value
}
}
return column
}
func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) {
modelType := reflect.ValueOf(dest).Type()
if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
columns := make([]string, 0)
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
scenarios := field.Tag.Get("scenarios")
if !hasToken(types.ScenarioList, scenarios) {
continue
}
isPrimary := field.Tag.Get("is_primary")
if isPrimary != "true" {
continue
}
column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
if !column.Native {
continue
}
if column.Expr == "" {
columns = append(columns, dest.QuoteColumn(ctx, column.Name))
} else {
columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
}
}
columns = append(columns, "COUNT(*) AS count")
query.Select(columns...)
}
func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) {
modelType := reflect.ValueOf(dest).Type()
if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
columns := make([]string, 0)
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
scenarios := field.Tag.Get("scenarios")
if !hasToken(types.ScenarioList, scenarios) {
continue
}
column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
if !column.Native {
continue
}
if column.Expr == "" {
columns = append(columns, dest.QuoteColumn(ctx, column.Name))
} else {
columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
}
}
query.Select(columns...)
}
// buildCondition 构建sql条件
func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) {
return BuildConditions(ctx, r, query, schemas)
}
// ModuleName 模块名称
func (m *Model) ModuleName() string {
return m.naming.ModuleName
}
// TableName 表的名称
func (m *Model) TableName() string {
return m.naming.ModuleName
}
// Fields 返回搜索的模型的字段
func (m *Model) Fields() []*schema.Field {
return m.statement.Schema.Fields
}
// Uri 获取请求的uri
func (m *Model) Uri(scenario string) string {
ss := make([]string, 4)
if m.urlPrefix != "" {
ss = append(ss, m.urlPrefix)
}
switch scenario {
case types.ScenarioList:
ss = append(ss, m.naming.ModuleName, m.naming.Pluralize)
case types.ScenarioView:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
case types.ScenarioCreate:
ss = append(ss, m.naming.ModuleName, m.naming.Singular)
case types.ScenarioUpdate:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
case types.ScenarioDelete:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
case types.ScenarioExport:
ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export")
case types.ScenarioImport:
ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import")
}
uri := path.Join(ss...)
if !strings.HasPrefix(uri, "/") {
uri = "/" + uri
}
return uri
}
// Method 获取HTTP请求的方法
func (m *Model) Method(scenario string) string {
var (
method = http.MethodGet
)
switch scenario {
case types.ScenarioCreate:
method = http.MethodPost
case types.ScenarioUpdate:
method = http.MethodPut
case types.ScenarioDelete:
method = http.MethodDelete
}
return method
}
// Search 实现通过HTTP方法查找数据
func (m *Model) Search(w http.ResponseWriter, r *http.Request) {
var (
ok bool
err error
qs url.Values
page int
pageSize int
pageIndex int
query *Query
domainName string
modelSlices reflect.Value
modelValues reflect.Value
searchSchemas []*types.Schema
listSchemas []*types.Schema
modelValue reflect.Value
scenario string
reporter types.Reporter
namerTable tableNamer
)
qs = r.URL.Query()
page, _ = strconv.Atoi(qs.Get("page"))
pageSize, _ = strconv.Atoi(qs.Get("pagesize"))
if pageSize <= 0 {
pageSize = defaultPageSize
}
pageIndex = page
if pageIndex > 0 {
pageIndex--
}
modelValue = reflect.New(m.value.Type())
//这里创建指针类型这样的话就能在format里面调用函数
if m.value.Kind() != reflect.Ptr {
modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
} else {
modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
}
modelValues = reflect.New(modelSlices.Type())
modelValues.Elem().Set(modelSlices)
query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface())
domainName = m.valueLookup(types.FieldDomain, w, r)
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
Domain: domainName,
Request: r,
User: m.valueLookup("user", w, r),
ModuleName: m.naming.ModuleName,
TableName: m.naming.TableName,
Scenario: types.ScenarioList,
})
if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
scenario = types.ScenarioList
if arrays.Exists(r.FormValue("scenario"), allowScenario) {
scenario = r.FormValue("scenario")
}
if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
if !m.disableDomain {
if m.hasColumn(types.FieldDomain) {
query.AndWhere(newCondition(types.FieldDomain, domainName))
}
}
if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
return
}
// 处理表名逻辑
if namerTable, ok = query.Model().(tableNamer); ok {
query.From(namerTable.HttpTableName(r))
}
//处理报表逻辑
if reporter, ok = modelValue.Interface().(types.Reporter); ok {
query.From(reporter.RealTable())
}
res := &types.ListResponse{
Page: page,
PageSize: pageSize,
}
if reporter == nil {
res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
} else {
//如果是报表的情况需要手动指定COUNT的雨具逻辑才能生效
m.buildReporterCountColumns(childCtx, reporter, query)
res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
//这里需要重置一下选项,不然会出问题
query.ResetSelect()
query.GroupBy(reporter.GroupBy(childCtx)...)
}
query.Offset(pageIndex * pageSize).Limit(pageSize)
if res.TotalCount > 0 {
if reporter != nil {
m.buildReporterQueryColumns(childCtx, reporter, query)
}
if err = query.All(modelValues.Interface()); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
// 不进行格式化输出
res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format"))
} else {
res.Data = make([]string, 0)
}
m.response.Success(w, res)
}
// Create 实现通过HTTP方法创建模型
func (m *Model) Create(w http.ResponseWriter, r *http.Request) {
var (
err error
model any
schemas []*types.Schema
diffAttrs []*types.DiffAttr
domainName string
modelValue reflect.Value
)
modelValue = reflect.New(m.value.Type())
model = modelValue.Interface()
if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil {
m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil)
return
}
domainName = m.valueLookup(types.FieldDomain, w, r)
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
if !m.disableDomain {
if m.hasColumn(types.FieldDomain) {
m.setValue(modelValue, types.FieldDomain, domainName)
}
}
diffAttrs = make([]*types.DiffAttr, 0, 10)
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
Domain: domainName,
User: m.valueLookup("user", w, r),
Request: r,
ModuleName: m.naming.ModuleName,
TableName: m.naming.TableName,
Scenario: types.ScenarioCreate,
Schemas: schemas,
})
dbSess := m.getDB().WithContext(childCtx)
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil {
return
}
if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
return
}
if tabler, ok := model.(types.Tabler); ok {
errTx = tx.Table(tabler.TableName()).Save(model).Error
} else {
errTx = tx.Save(model).Error
}
if errTx != nil {
return
}
for _, row := range schemas {
diffAttrs = append(diffAttrs, &types.DiffAttr{
Column: row.Column,
Label: row.Label,
OldValue: nil,
NewValue: m.getValue(modelValue, row.Column),
})
}
return
}); err == nil {
res := &types.CreateResponse{
ID: m.getValue(modelValue, m.primaryKey),
Status: "created",
}
if creator, ok := model.(afterCreated); ok {
creator.AfterCreated(childCtx, dbSess)
}
if preserver, ok := model.(afterSaved); ok {
preserver.AfterSaved(childCtx, dbSess)
}
m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs)
m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
m.response.Success(w, res)
} else {
m.response.Failure(w, types.RequestCreateFailure, err.Error(), err)
}
}
// Update 实现通过HTTP方法更新模型
func (m *Model) Update(w http.ResponseWriter, r *http.Request) {
var (
err error
model any
schemas []*types.Schema
diffAttrs []*types.DiffAttr
domainName string
modelValue reflect.Value
oldValues map[string]any
)
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
modelValue = reflect.New(m.value.Type())
model = modelValue.Interface()
domainName = m.valueLookup(types.FieldDomain, w, r)
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
conditions := map[string]any{
m.primaryKey: idStr,
}
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
oldValues = make(map[string]any)
for _, row := range schemas {
oldValues[row.Column] = m.getValue(modelValue, row.Column)
}
if err = json.NewDecoder(r.Body).Decode(model); err != nil {
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
return
}
diffAttrs = make([]*types.DiffAttr, 0, 10)
updates := make(map[string]any)
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
Domain: domainName,
Request: r,
User: m.valueLookup("user", w, r),
ModuleName: m.naming.ModuleName,
TableName: m.naming.TableName,
Scenario: types.ScenarioUpdate,
Schemas: schemas,
PrimaryKeyValue: idStr,
})
dbSess := m.getDB().WithContext(childCtx)
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil {
return
}
if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
return
}
for _, row := range schemas {
v := m.getValue(modelValue, row.Column)
if oldValues[row.Column] != v {
updates[row.Column] = v
diffAttrs = append(diffAttrs, &types.DiffAttr{
Column: row.Column,
Label: row.Label,
OldValue: oldValues[row.Column],
NewValue: v,
})
}
}
if len(updates) > 0 {
if tabler, ok := model.(types.Tabler); ok {
errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error
} else {
errTx = tx.Model(model).Updates(updates).Error
}
if errTx != nil {
return
}
}
return
}); err == nil {
if updater, ok := model.(afterUpdated); ok {
updater.AfterUpdated(childCtx, dbSess)
}
if preserver, ok := model.(afterSaved); ok {
preserver.AfterSaved(childCtx, dbSess)
}
m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs)
m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
m.response.Success(w, types.UpdateResponse{
ID: idStr,
Status: "updated",
})
} else {
m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil)
}
}
// Delete 实现通过HTTP方法删除模型
func (m *Model) Delete(w http.ResponseWriter, r *http.Request) {
var (
err error
model any
modelValue reflect.Value
)
idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r)
modelValue = reflect.New(m.value.Type())
model = modelValue.Interface()
conditions := map[string]any{
m.primaryKey: idStr,
}
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
Domain: m.valueLookup(types.FieldDomain, w, r),
User: m.valueLookup("user", w, r),
Request: r,
ModuleName: m.naming.ModuleName,
TableName: m.naming.TableName,
Scenario: types.ScenarioDelete,
PrimaryKeyValue: idStr,
})
dbSess := m.getDB().WithContext(childCtx)
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil {
return
}
if tabler, ok := model.(types.Tabler); ok {
errTx = tx.Table(tabler.TableName()).Delete(model).Error
} else {
errTx = tx.Delete(model).Error
}
if errTx != nil {
return
}
m.getHook().afterDelete(childCtx, tx, model)
return
}); err == nil {
m.response.Success(w, types.DeleteResponse{
ID: idStr,
Status: "deleted",
})
} else {
m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil)
}
}
// View 查看数据详情
func (m *Model) View(w http.ResponseWriter, r *http.Request) {
var (
err error
model any
modelValue reflect.Value
qs url.Values
schemas []*types.Schema
scenario string
domainName string
)
qs = r.URL.Query()
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
modelValue = reflect.New(m.value.Type())
model = modelValue.Interface()
conditions := map[string]any{
m.primaryKey: idStr,
}
domainName = m.valueLookup(types.FieldDomain, w, r)
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
scenario = qs.Get("scenario")
if scenario == "" {
schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioView)
} else {
schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario)
}
if err == nil {
m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format")))
} else {
m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil)
}
}
// Export 实现通过HTTP方法导出模型
func (m *Model) Export(w http.ResponseWriter, r *http.Request) {
var (
err error
query *Query
modelSlices reflect.Value
modelValues reflect.Value
searchSchemas []*types.Schema
exportSchemas []*types.Schema
domainName string
fp *os.File
modelValue reflect.Value
)
if !m.hasScenario(types.ScenarioList) {
m.response.Failure(w, types.RequestDenied, "request denied", nil)
return
}
domainName = m.valueLookup(types.FieldDomain, w, r)
filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil {
m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil)
return
}
defer func() {
_ = fp.Close()
}()
modelValue = reflect.New(m.value.Type())
//这里创建指针类型这样的话就能在format里面调用函数
if m.value.Kind() != reflect.Ptr {
modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
} else {
modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
}
modelValues = reflect.New(modelSlices.Type())
modelValues.Elem().Set(modelSlices)
query = NewQuery(m.getDB(), modelValue.Interface())
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
Domain: domainName,
Request: r,
User: m.valueLookup("user", w, r),
ModuleName: m.naming.ModuleName,
TableName: m.naming.TableName,
Scenario: types.ScenarioExport,
})
if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
return
}
if !m.disableDomain {
if m.hasColumn(types.FieldDomain) {
query.AndWhere(newCondition(types.FieldDomain, domainName))
}
}
// 处理表名逻辑
if namerTable, ok := query.Model().(tableNamer); ok {
query.From(namerTable.HttpTableName(r))
}
//处理报表逻辑
if reporter, ok := modelValue.Interface().(types.Reporter); ok {
query.From(reporter.RealTable())
query.GroupBy(reporter.GroupBy(childCtx)...)
m.buildReporterQueryColumns(childCtx, reporter, query)
}
if err = query.All(modelValues.Interface()); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "")
writer := csv.NewWriter(fp)
rows := make([]string, len(exportSchemas))
for i, field := range exportSchemas {
rows[i] = field.Label
}
_ = writer.Write(rows)
if values, ok := value.([]any); ok {
for _, val := range values {
row, ok2 := val.(map[string]any)
if !ok2 {
continue
}
for i, field := range exportSchemas {
if v, ok := row[field.Column]; ok {
rows[i] = fmt.Sprint(v)
} else {
rows[i] = ""
}
}
_ = writer.Write(rows)
}
}
writer.Flush()
m.getHook().afterExport(childCtx, filename)
http.ServeContent(w, r, path.Base(filename), time.Now(), fp)
}
// findSchema 查找指定的schema
func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema {
for _, row := range schemas {
if row.Label == label {
return row
}
}
return nil
}
// importInternal 文件上传方法
func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) {
var (
err error
rows []string
fp *os.File
tm time.Time
fields []string
sess *gorm.DB
csvReader *csv.Reader
csvWriter *csv.Writer
modelValue reflect.Value
modelEntity any
diffAttrs []*types.DiffAttr
result *types.ImportResult
failureFp *os.File
failureFile string
)
tm = time.Now()
result = &types.ImportResult{}
if fp, err = os.Open(filename); err != nil {
result.Code = types.ErrImportFileNotExists
goto __end
}
defer func() {
_ = fp.Close()
}()
csvReader = csv.NewReader(fp)
if rows, err = csvReader.Read(); err != nil {
result.Code = types.ErrImportFileUnavailable
goto __end
}
fields = make([]string, 0, len(rows))
for _, s := range rows {
v := m.findSchema(s, schemas)
if v == nil {
result.Code = types.ErrImportColumnNotMatch
goto __end
}
fields = append(fields, v.Column)
}
sess = m.getDB().WithContext(ctx)
//失败文件指针
failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix()))
if failureFp, err = os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil {
return
}
defer func() {
_ = failureFp.Close()
}()
csvWriter = csv.NewWriter(failureFp)
rows = append(rows, "Error")
_ = csvWriter.Write(rows)
diffAttrs = make([]*types.DiffAttr, len(schemas))
for {
if rows, err = csvReader.Read(); err != nil {
break
}
result.TotalCount++
if len(rows) != len(fields) {
continue
}
modelValue = reflect.New(m.value.Type())
for idx, field := range fields {
m.safeSetValue(modelValue, field, rows[idx])
}
if len(extraFields) > 0 {
for k, v := range extraFields {
m.safeSetValue(modelValue, k, v)
}
}
modelEntity = modelValue.Interface()
//写入数据
if fast {
//如果是快速模式,直接存储数据
if err = sess.Save(modelEntity).Error; err == nil {
result.SuccessCount++
} else {
rows = append(rows, err.Error())
_ = csvWriter.Write(rows)
}
} else {
if err = sess.Transaction(func(tx *gorm.DB) (errTx error) {
if errTx = m.getHook().beforeCreate(ctx, tx, modelEntity); errTx != nil {
return
}
if errTx = m.getHook().beforeSave(ctx, tx, modelEntity); errTx != nil {
return
}
if tabler, ok := modelEntity.(types.Tabler); ok {
errTx = tx.Table(tabler.TableName()).Save(modelEntity).Error
} else {
errTx = tx.Save(modelEntity).Error
}
if errTx != nil {
return
}
for idx, row := range schemas {
diffAttrs[idx] = &types.DiffAttr{
Column: row.Column,
Label: row.Label,
NewValue: m.getValue(modelValue, row.Column),
}
}
m.getHook().afterCreate(ctx, tx, modelEntity, diffAttrs)
m.getHook().afterSave(ctx, tx, modelEntity, diffAttrs)
return
}); err == nil {
result.SuccessCount++
} else {
rows = append(rows, err.Error())
_ = csvWriter.Write(rows)
}
}
}
csvWriter.Flush()
__end:
result.UploadFile = filename
if result.TotalCount > result.SuccessCount {
result.FailureFile = failureFile
}
result.Duration = time.Now().Sub(tm)
m.getHook().afterImport(ctx, result)
}
// Import 实现通过HTTP方法导入
func (m *Model) Import(w http.ResponseWriter, r *http.Request) {
var (
err error
fast bool
schemas []*types.Schema
rows []string
domainName string
dst *os.File
fp multipart.File
csvWriter *csv.Writer
qs url.Values
extraFields map[string]string
)
domainName = m.valueLookup(types.FieldDomain, w, r)
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
return
}
//这里用background的context
childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{
Domain: domainName,
User: m.valueLookup("user", w, r),
ModuleName: m.naming.ModuleName,
TableName: m.naming.TableName,
Scenario: types.ScenarioImport,
Schemas: schemas,
})
if r.Method == http.MethodGet {
//下载导入模板
csvWriter = csv.NewWriter(w)
rows = make([]string, 0, len(schemas))
for _, row := range schemas {
//主键不需要导入
if row.IsPrimaryKey == 1 {
continue
}
rows = append(rows, row.Label)
}
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
err = csvWriter.Write(rows)
csvWriter.Flush()
return
}
filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
if fp, _, err = r.FormFile("file"); err != nil {
m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil)
return
}
defer func() {
_ = fp.Close()
}()
if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err == nil {
buf := pool.GetBytes(32 * 1024)
_, err = io.CopyBuffer(dst, fp, buf)
pool.PutBytes(buf)
_ = dst.Close()
} else {
m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil)
return
}
qs = r.URL.Query()
if qs != nil {
extraFields = make(map[string]string)
for k, _ := range qs {
if strings.HasPrefix(k, "_attr_") {
extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k)
}
}
}
fast, _ = strconv.ParseBool(qs.Get("__fast"))
go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields)
m.response.Success(w, types.ImportResponse{
UID: m.valueLookup("user", w, r),
Status: "committed",
})
}
// newModel 创建一个模型
func newModel(v any, db *gorm.DB, naming types.Naming) *Model {
model := &Model{
db: db,
naming: naming,
response: &httpWriter{},
value: reflect.Indirect(reflect.ValueOf(v)),
valueLookup: defaultValueLookup,
}
model.statement = &gorm.Statement{
DB: model.getDB(),
ConnPool: model.getDB().ConnPool,
Clauses: map[string]clause.Clause{},
}
if err := model.statement.Parse(v); err == nil {
if model.statement.Schema.PrimaryFieldDBNames != nil && len(model.statement.Schema.PrimaryFieldDBNames) > 0 {
model.primaryKey = model.statement.Schema.PrimaryFieldDBNames[0]
}
}
return model
}

52
options.go 100644
View File

@ -0,0 +1,52 @@
package rest
import "git.nobla.cn/golang/rest/types"
type Options struct {
urlPrefix string
moduleName string
disableDomain bool
router types.HttpRouter
writer types.HttpWriter
formatter *Formatter
dirname string //文件目录
}
type Option func(o *Options)
func WithUriPrefix(s string) Option {
return func(o *Options) {
o.urlPrefix = s
}
}
func WithModuleName(s string) Option {
return func(o *Options) {
o.moduleName = s
}
}
// WithoutDomain 禁用域
func WithoutDomain() Option {
return func(o *Options) {
o.disableDomain = true
}
}
func WithHttpRouter(s types.HttpRouter) Option {
return func(o *Options) {
o.router = s
}
}
func WithHttpWriter(s types.HttpWriter) Option {
return func(o *Options) {
o.writer = s
}
}
func WithFormatter(s *Formatter) Option {
return func(o *Options) {
o.formatter = s
}
}

108
plugins/cache/cache.go vendored 100644
View File

@ -0,0 +1,108 @@
package cache
import (
"encoding/json"
"fmt"
"git.nobla.cn/golang/kos/pkg/cache"
xxhash "github.com/cespare/xxhash/v2"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"os"
"strconv"
"time"
)
const (
DisableCache = "DISABLE_CACHE"
DurationKey = "gorm:cache_duration"
)
type Cacher struct {
rawQuery func(db *gorm.DB)
}
func (c *Cacher) Name() string {
return "gorm:cache"
}
func (c *Cacher) Initialize(db *gorm.DB) (err error) {
c.rawQuery = db.Callback().Query().Get("gorm:query")
err = db.Callback().Query().Replace("gorm:query", c.Query)
return
}
// buildCacheKey 构建一个缓存的KEY
func (c *Cacher) buildCacheKey(db *gorm.DB) string {
s := strconv.FormatUint(xxhash.Sum64String(db.Statement.SQL.String()+fmt.Sprintf("%v", db.Statement.Vars)), 10)
return s
}
// getDuration 获取缓存时长
func (c *Cacher) getDuration(db *gorm.DB) time.Duration {
var (
ok bool
v any
duration time.Duration
)
if v, ok = db.InstanceGet(DurationKey); !ok {
return 0
}
if duration, ok = v.(time.Duration); !ok {
return 0
}
return duration
}
// tryLoad 尝试从缓存读取数据
func (c *Cacher) tryLoad(key string, db *gorm.DB) (err error) {
var (
ok bool
buf []byte
)
if buf, ok = cache.Get(db.Statement.Context, key); ok {
err = json.Unmarshal(buf, db.Statement.Dest)
} else {
err = os.ErrNotExist
}
return
}
// storeCache 存储缓存数据
func (c *Cacher) storeCache(key string, db *gorm.DB, duration time.Duration) (err error) {
var (
buf []byte
)
if buf, err = json.Marshal(db.Statement.Dest); err == nil {
cache.SetEx(db.Statement.Context, key, buf, duration)
}
return
}
func (c *Cacher) Query(db *gorm.DB) {
var (
err error
cacheKey string
duration time.Duration
)
duration = c.getDuration(db)
if duration <= 0 {
c.rawQuery(db)
return
}
callbacks.BuildQuerySQL(db)
cacheKey = c.buildCacheKey(db)
if err = c.tryLoad(cacheKey, db); err == nil {
return
}
c.rawQuery(db)
if db.Error == nil {
//store cache
if err = c.storeCache(cacheKey, db, duration); err != nil {
_ = db.AddError(err)
}
}
}
func New() *Cacher {
return &Cacher{}
}

View File

@ -0,0 +1,57 @@
package identity
import (
"github.com/rs/xid"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"reflect"
)
type Identify struct {
}
func (identity *Identify) Name() string {
return "gorm:identity"
}
func (identity *Identify) Initialize(db *gorm.DB) (err error) {
err = db.Callback().Create().Before("gorm:create").Register("auto_identified", identity.Grant)
return
}
func (identity *Identify) NextID() string {
return xid.New().String()
}
func (identity *Identify) Grant(db *gorm.DB) {
var (
err error
field *schema.Field
)
if db.Statement.Schema == nil {
return
}
if field = db.Statement.Schema.LookUpField("ID"); field == nil {
return
}
if field.DataType != schema.String {
return
}
if db.Statement.ReflectValue.Kind() == reflect.Array || db.Statement.ReflectValue.Kind() == reflect.Slice {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue.Index(i)); zero {
if err = field.Set(db.Statement.Context, db.Statement.ReflectValue.Index(i), identity.NextID()); err != nil {
_ = db.AddError(err)
}
}
}
} else {
if _, zero := field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); zero {
db.Statement.SetColumn("ID", identity.NextID())
}
}
}
func New() *Identify {
return &Identify{}
}

View File

@ -0,0 +1,33 @@
# 分表实现
首先定义一个`gorm`的模型,然后实现`shadring.Model`接口,比如如下示例
```go
// ShardingTable 返回增删改时候操作的数据表
func (model *CdrLog) ShardingTable(scene string) string {
return model.TableName()
}
// ShardingTables 返回查询时候一个范围内的表
func (model *CdrLog) ShardingTables(ctx *sharding.Context) []string {
var (
timestamp int64
)
timeRange := make([]int64, 0)
values := ctx.FindColumnValues("start_stamp")
if len(values) == 0 {
values = ctx.FindColumnValues("create_stamp")
}
if len(values) > 0 {
for _, v := range values {
timestamp, _ = strconv.ParseInt(fmt.Sprint(v), 10, 64)
timeRange = append(timeRange, timestamp)
}
}
return shard.DateTableNames(ctx.Context(), "cdr_logs", shard.ShardTypeDateMonth, timeRange)
}
```
`ShardingTable`方法是操作增删改的是回调具体表名的方法
`ShardingTables`方法是操作查询的时候,通过查询条件返回的表名的方法

View File

@ -0,0 +1,120 @@
package sharding
const (
ValueOperaEqual = iota + 0x10
ValueOperaGreater
ValueOperaLess
ValueOperaRange
ValueTypeString = iota + 0x30
ValueTypeNumber
ValueTypeBoolean
ValueTypeNull
ValueTypeAny
)
type (
ColumnCondition struct {
Name string `json:"name"`
Value CondValue `json:"value"`
}
CondValue interface {
Type() int
Opera() int
Value() any
}
equalValue struct {
vType int
vData any
}
rangeValue struct {
vType int
vData any
}
greaterValue struct {
vType int
vData any
}
lessValue struct {
vType int
vData any
}
)
func (v *lessValue) Type() int {
return v.vType
}
func (v *lessValue) Opera() int {
return ValueOperaLess
}
func (v *lessValue) Value() any {
return v.vData
}
func (v *greaterValue) Type() int {
return v.vType
}
func (v *greaterValue) Opera() int {
return ValueOperaGreater
}
func (v *greaterValue) Value() any {
return v.vData
}
func (v *rangeValue) Type() int {
return v.vType
}
func (v *rangeValue) Opera() int {
return ValueOperaRange
}
func (v *rangeValue) Value() any {
return v.vData
}
func (v *equalValue) Type() int {
return v.vType
}
func (v *equalValue) Opera() int {
return ValueOperaEqual
}
func (v *equalValue) Value() any {
return v.vData
}
func newCondValue(vType int, op int, value any) CondValue {
switch op {
case ValueOperaGreater:
return &greaterValue{
vType: vType,
vData: value,
}
case ValueOperaLess:
return &lessValue{
vType: vType,
vData: value,
}
case ValueOperaRange:
return &rangeValue{
vType: vType,
vData: value,
}
default:
return &equalValue{
vType: vType,
vData: value,
}
}
}

View File

@ -0,0 +1,287 @@
package sharding
import (
"context"
"github.com/longbridgeapp/sqlparser"
sqlparserX "github.com/uole/sqlparser"
"gorm.io/gorm"
"strconv"
)
type Scope struct {
db *gorm.DB
stmt *sqlparser.SelectStatement
stmtX *sqlparserX.Select
}
func (scope *Scope) findValue(express sqlparser.Expr) (int, any) {
var (
vType int
vData any
)
switch expr := express.(type) {
case *sqlparser.BindExpr:
vType = ValueTypeAny
if len(scope.db.Statement.Vars) > expr.Pos {
vData = scope.db.Statement.Vars[expr.Pos]
} else {
vType = ValueTypeNull
}
case *sqlparser.NumberLit:
vType = ValueTypeNumber
vData = expr.Value
case *sqlparser.StringLit:
vType = ValueTypeString
vData = expr.Value
case *sqlparser.BoolLit:
vType = ValueTypeBoolean
vData = expr.Value
case *sqlparser.BlobLit:
vType = ValueTypeString
vData = expr.Value
case *sqlparser.NullLit:
vType = ValueTypeNull
case *sqlparser.Range:
arr := make([]any, 2)
vType, arr[0] = scope.findValue(expr.X)
vType, arr[1] = scope.findValue(expr.Y)
vData = arr
}
return vType, vData
}
func (scope *Scope) findValueX(express *sqlparserX.SQLVal) (int, any) {
var (
vType int
vData any
)
switch express.Type {
case sqlparserX.IntVal:
vType = ValueTypeNumber
vData, _ = strconv.Atoi(string(express.Val))
case sqlparserX.FloatVal:
vType = ValueTypeNumber
vData, _ = strconv.ParseFloat(string(express.Val), 64)
case sqlparserX.ValArg:
vType = ValueTypeAny
pos, _ := strconv.Atoi(string(express.Val[2:]))
if pos > 0 {
pos = pos - 1
if len(scope.db.Statement.Vars) > pos {
vData = scope.db.Statement.Vars[pos]
} else {
vType = ValueTypeNull
}
} else {
vType = ValueTypeNull
}
default:
vType = ValueTypeString
vData = string(express.Val)
}
return vType, vData
}
func (scope *Scope) recursiveFindX(expr sqlparserX.Expr, column string) (conditions []*ColumnCondition) {
var (
ok bool
andExpr *sqlparserX.AndExpr
orExpr *sqlparserX.OrExpr
parentExpr *sqlparserX.ParenExpr
comparisonExpr *sqlparserX.ComparisonExpr
rangeExpr *sqlparserX.RangeCond
coumnExpr *sqlparserX.ColName
valueExpr *sqlparserX.SQLVal
)
conditions = make([]*ColumnCondition, 0, 2)
if comparisonExpr, ok = expr.(*sqlparserX.ComparisonExpr); ok {
if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok {
return
}
if valueExpr, ok = comparisonExpr.Right.(*sqlparserX.SQLVal); !ok {
return
}
if coumnExpr.Name.EqualString(column) {
cond := &ColumnCondition{
Name: coumnExpr.Name.String(),
}
vType, vData := scope.findValueX(valueExpr)
switch comparisonExpr.Operator {
case sqlparserX.LessThanStr, sqlparserX.LessEqualStr:
cond.Value = newCondValue(vType, ValueOperaLess, vData)
case sqlparserX.GreaterThanStr, sqlparserX.GreaterEqualStr:
cond.Value = newCondValue(vType, ValueOperaGreater, vData)
case sqlparserX.EqualStr:
cond.Value = newCondValue(vType, ValueOperaEqual, vData)
}
if cond.Value != nil {
conditions = append(conditions, cond)
}
}
}
if rangeExpr, ok = expr.(*sqlparserX.RangeCond); ok {
if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok {
return
}
if coumnExpr.Name.EqualString(column) {
vType := 0
arr := make([]any, 2)
if valueExpr, ok = rangeExpr.From.(*sqlparserX.SQLVal); ok {
vType, arr[0] = scope.findValueX(valueExpr)
}
if valueExpr, ok = rangeExpr.To.(*sqlparserX.SQLVal); ok {
vType, arr[1] = scope.findValueX(valueExpr)
}
conditions = append(conditions, &ColumnCondition{
Name: coumnExpr.Name.String(),
Value: newCondValue(vType, ValueOperaRange, arr),
})
}
}
if andExpr, ok = expr.(*sqlparserX.AndExpr); ok {
if andExpr.Left != nil {
conditions = append(conditions, scope.recursiveFindX(andExpr.Left, column)...)
}
if andExpr.Right != nil {
conditions = append(conditions, scope.recursiveFindX(andExpr.Right, column)...)
}
}
if orExpr, ok = expr.(*sqlparserX.OrExpr); ok {
if orExpr.Left != nil {
conditions = append(conditions, scope.recursiveFindX(orExpr.Left, column)...)
}
if orExpr.Right != nil {
conditions = append(conditions, scope.recursiveFindX(orExpr.Right, column)...)
}
}
if parentExpr, ok = expr.(*sqlparserX.ParenExpr); ok {
if parentExpr.Expr != nil {
conditions = append(conditions, scope.recursiveFindX(parentExpr.Expr, column)...)
}
}
return conditions
}
func (scope *Scope) recursiveFind(expr sqlparser.Expr, column string) []*ColumnCondition {
var (
ok bool
identExpr *sqlparser.Ident
binaryExpr *sqlparser.BinaryExpr
parentExpr *sqlparser.ParenExpr
conditions []*ColumnCondition
)
conditions = make([]*ColumnCondition, 0, 2)
if parentExpr, ok = expr.(*sqlparser.ParenExpr); ok {
if parentExpr.X != nil {
if _, ok = parentExpr.X.(*sqlparser.BinaryExpr); ok {
conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...)
}
if _, ok = parentExpr.X.(*sqlparser.ParenExpr); ok {
conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...)
}
}
}
if binaryExpr, ok = expr.(*sqlparser.BinaryExpr); ok {
if binaryExpr.X != nil {
if identExpr, ok = binaryExpr.X.(*sqlparser.Ident); ok {
if identExpr.Name == column {
cond := &ColumnCondition{
Name: identExpr.Name,
}
vType, vData := scope.findValue(binaryExpr.Y)
switch binaryExpr.Op {
case sqlparser.LT, sqlparser.LE:
cond.Value = newCondValue(vType, ValueOperaLess, vData)
case sqlparser.GT, sqlparser.GE:
cond.Value = newCondValue(vType, ValueOperaGreater, vData)
case sqlparser.RANGE, sqlparser.BETWEEN:
cond.Value = newCondValue(vType, ValueOperaRange, vData)
case sqlparser.EQ:
cond.Value = newCondValue(vType, ValueOperaEqual, vData)
}
if cond.Value != nil {
conditions = append(conditions, cond)
}
}
} else {
if _, ok = binaryExpr.X.(*sqlparser.BinaryExpr); ok {
conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...)
}
if _, ok = binaryExpr.X.(*sqlparser.ParenExpr); ok {
conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...)
}
}
}
if binaryExpr.Y != nil {
if _, ok = binaryExpr.Y.(*sqlparser.BinaryExpr); ok {
conditions = append(conditions, scope.recursiveFind(binaryExpr.Y, column)...)
}
}
}
return conditions
}
func (scope *Scope) DB() *gorm.DB {
return scope.db
}
func (scope *Scope) Context() context.Context {
return scope.db.Statement.Context
}
func (scope *Scope) FindCondition(column string) []*ColumnCondition {
if scope.stmtX != nil {
if scope.stmtX.Where == nil {
return []*ColumnCondition{}
}
return scope.recursiveFindX(scope.stmtX.Where.Expr, column)
}
return scope.recursiveFind(scope.stmt.Condition, column)
}
func (scope *Scope) FindColumnValues(column string) []any {
result := make([]any, 0)
conditions := scope.FindCondition(column)
if len(conditions) == 0 {
return result
}
for _, cond := range conditions {
if cond.Value.Opera() == ValueOperaGreater {
if len(result) == 0 {
result = make([]any, 2)
}
result[0] = cond.Value.Value()
}
if cond.Value.Opera() == ValueOperaLess {
if len(result) == 0 {
result = make([]any, 2)
}
result[1] = cond.Value.Value()
}
if cond.Value.Opera() == ValueOperaEqual {
result = append(result, cond.Value.Value())
break
}
if cond.Value.Opera() == ValueOperaRange {
if vs, ok := cond.Value.Value().([]any); ok {
if len(vs) == 2 {
if len(result) == 0 {
result = make([]any, 2)
}
result[0] = vs[0]
result[1] = vs[1]
}
}
break
}
}
return result
}

View File

@ -0,0 +1,476 @@
package sharding
import (
"github.com/longbridgeapp/sqlparser"
sqlparserX "github.com/uole/sqlparser"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"reflect"
"strings"
)
type Sharding struct {
UnionAll bool
QuoteChar byte
}
func (plugin *Sharding) Name() string {
return "gorm:sharding"
}
func (plugin *Sharding) Initialize(db *gorm.DB) (err error) {
if err = db.Callback().Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil {
return
}
if err = db.Callback().Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil {
return
}
if err = db.Callback().Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil {
return
}
if err = db.Callback().Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX); err != nil {
return err
}
return
}
func (plugin *Sharding) Create(db *gorm.DB) {
var (
ok bool
scopeModel Model
refValue reflect.Value
modelValue any
)
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
if db.Statement.ReflectValue.Len() > 0 {
refValue = db.Statement.ReflectValue.Index(0)
modelValue = refValue.Interface()
}
} else {
if db.Statement.Model != nil {
modelValue = db.Statement.Model
} else {
modelValue = db.Statement.ReflectValue.Interface()
}
}
if modelValue != nil {
if scopeModel, ok = modelValue.(Model); ok {
db.Table(scopeModel.ShardingTable(sceneCreate))
}
}
}
func (plugin *Sharding) Update(db *gorm.DB) {
var (
ok bool
scopeModel Model
refValue reflect.Value
modelValue any
)
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
if db.Statement.ReflectValue.Len() > 0 {
refValue = db.Statement.ReflectValue.Index(0)
modelValue = refValue.Interface()
}
} else {
if db.Statement.Model != nil {
modelValue = db.Statement.Model
} else {
modelValue = db.Statement.ReflectValue.Interface()
}
}
if modelValue != nil {
if scopeModel, ok = modelValue.(Model); ok {
db.Table(scopeModel.ShardingTable(sceneUpdate))
}
}
}
func (plugin *Sharding) Delete(db *gorm.DB) {
var (
ok bool
scopeModel Model
refValue reflect.Value
modelValue any
)
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
if db.Statement.ReflectValue.Len() > 0 {
refValue = db.Statement.ReflectValue.Index(0)
modelValue = refValue.Interface()
}
} else {
if db.Statement.Model != nil {
modelValue = db.Statement.Model
} else {
modelValue = db.Statement.ReflectValue.Interface()
}
}
if modelValue != nil {
if scopeModel, ok = modelValue.(Model); ok {
db.Table(scopeModel.ShardingTable(sceneDelete))
}
}
}
func (plugin *Sharding) Query(db *gorm.DB) {
var (
err error
ok bool
shardingModel Model
modelValue any
tables []string
rawVars []any
refValue reflect.Value
tableName *sqlparser.TableName
selectStmt *sqlparser.SelectStatement
stmt sqlparser.Statement
parser *sqlparser.Parser
numOfTable int
orderByExpr []*sqlparser.OrderingTerm
limitExpr sqlparser.Expr
offsetExpr sqlparser.Expr
groupingExpr []sqlparser.Expr
havingExpr sqlparser.Expr
isCountStatement bool
countField string
)
if db.Statement.Model != nil {
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
} else {
refValue = reflect.New(db.Statement.ReflectValue.Type())
}
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
refValue = reflect.Indirect(refValue)
}
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
elemType := refValue.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
modelValue = reflect.New(elemType).Interface()
} else {
modelValue = refValue.Interface()
}
if shardingModel, ok = modelValue.(Model); !ok {
return
}
if db.Statement.SQL.Len() == 0 {
callbacks.BuildQuerySQL(db)
}
parser = sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String()))
if stmt, err = parser.ParseStatement(); err != nil {
return
}
if selectStmt, ok = stmt.(*sqlparser.SelectStatement); !ok {
return
}
tables = shardingModel.ShardingTables(&Scope{
db: db,
stmt: selectStmt,
})
numOfTable = len(tables)
if numOfTable <= 1 {
return
}
rawVars = make([]any, 0, len(db.Statement.Vars))
for _, v := range db.Statement.Vars {
rawVars = append(rawVars, v)
}
//是否是查询count语句
//如果不是count的语句添加order和group的支持
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
if v == true {
isCountStatement = true
}
}
if !isCountStatement && len(*selectStmt.Columns) == 1 {
for _, column := range *selectStmt.Columns {
if expr, ok := column.Expr.(*sqlparser.Call); ok {
if expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword {
isCountStatement = true
countField = expr.String()
break
}
}
}
}
if len(selectStmt.OrderBy) > 0 {
orderByExpr = make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy))
for _, row := range selectStmt.OrderBy {
orderByExpr = append(orderByExpr, row)
}
selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0)
}
if len(selectStmt.GroupingElements) > 0 {
groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements))
for _, row := range selectStmt.GroupingElements {
groupingExpr = append(groupingExpr, row)
}
if selectStmt.HavingCondition != nil {
havingExpr = selectStmt.HavingCondition
}
}
if selectStmt.Limit != nil {
limitExpr = selectStmt.Limit
selectStmt.Limit = nil
}
if selectStmt.Offset != nil {
offsetExpr = selectStmt.Offset
selectStmt.Offset = nil
}
db.Statement.SQL.Reset()
if isCountStatement {
db.Statement.SQL.WriteString("SELECT SUM(")
db.Statement.SQL.WriteByte(plugin.QuoteChar)
db.Statement.SQL.WriteString(strings.Trim(countField, "`"))
db.Statement.SQL.WriteByte(plugin.QuoteChar)
db.Statement.SQL.WriteString(") FROM (")
} else {
db.Statement.SQL.WriteString("SELECT * FROM (")
}
for i, name := range tables {
db.Statement.SQL.WriteByte('(')
if tableName, ok = selectStmt.FromItems.(*sqlparser.TableName); ok {
tableName.Name.Name = name
}
db.Statement.SQL.WriteString(selectStmt.String())
db.Statement.SQL.WriteByte(')')
if i < numOfTable-1 {
if plugin.UnionAll {
db.Statement.SQL.WriteString(" UNION ALL ")
} else {
db.Statement.SQL.WriteString(" UNION ")
}
}
if i > 0 {
//copy vars
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
}
}
db.Statement.SQL.WriteString(") tbl ")
if !isCountStatement {
if len(groupingExpr) > 0 {
db.Statement.SQL.WriteString(" GROUP BY ")
for i, expr := range groupingExpr {
if i != 0 {
db.Statement.SQL.WriteString(", ")
}
db.Statement.SQL.WriteString(expr.String())
}
if havingExpr != nil {
db.Statement.SQL.WriteString(" HAVING ")
db.Statement.SQL.WriteString(havingExpr.String())
}
}
if orderByExpr != nil && len(orderByExpr) > 0 {
db.Statement.SQL.WriteString(" ORDER BY ")
for i, term := range orderByExpr {
if i != 0 {
db.Statement.SQL.WriteString(", ")
}
db.Statement.SQL.WriteString(term.String())
}
}
if limitExpr != nil {
db.Statement.SQL.WriteString(" LIMIT ")
db.Statement.SQL.WriteString(limitExpr.String())
if offsetExpr != nil {
db.Statement.SQL.WriteString(" OFFSET ")
db.Statement.SQL.WriteString(offsetExpr.String())
}
}
}
return
}
func (plugin *Sharding) QueryX(db *gorm.DB) {
var (
err error
ok bool
shardingModel Model
modelValue any
tables []string
rawVars []any
refValue reflect.Value
selectStmt *sqlparserX.Select
stmt sqlparserX.Statement
numOfTable int
isCountStatement bool
isPureCountStatement bool
trackedBuffer *sqlparserX.TrackedBuffer
funcExpr *sqlparserX.FuncExpr
aliasedExpr *sqlparserX.AliasedExpr
orderByExpr sqlparserX.OrderBy
groupByExpr sqlparserX.GroupBy
havingExpr *sqlparserX.Where
limitExpr *sqlparserX.Limit
)
if db.Statement.Model != nil {
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
} else {
if !db.Statement.ReflectValue.IsValid() {
return
}
refValue = reflect.New(db.Statement.ReflectValue.Type())
}
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
refValue = reflect.Indirect(refValue)
}
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
elemType := refValue.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
modelValue = reflect.New(elemType).Interface()
} else {
modelValue = refValue.Interface()
}
if shardingModel, ok = modelValue.(Model); !ok {
return
}
if db.Statement.SQL.Len() == 0 {
callbacks.BuildQuerySQL(db)
}
if stmt, err = sqlparserX.Parse(db.Statement.SQL.String()); err != nil {
return
}
if selectStmt, ok = stmt.(*sqlparserX.Select); !ok {
return
}
tables = shardingModel.ShardingTables(&Scope{
db: db,
stmtX: selectStmt,
})
numOfTable = len(tables)
if numOfTable <= 1 {
return
}
// 保存值
rawVars = make([]any, 0, len(db.Statement.Vars))
for _, v := range db.Statement.Vars {
rawVars = append(rawVars, v)
}
// 替换语句
if selectStmt.OrderBy != nil {
orderByExpr = selectStmt.OrderBy
selectStmt.OrderBy = nil
}
if selectStmt.GroupBy != nil {
groupByExpr = selectStmt.GroupBy
//selectStmt.GroupBy = nil
}
if selectStmt.Having != nil {
havingExpr = selectStmt.Having
//selectStmt.Having = nil
}
if selectStmt.Limit != nil {
limitExpr = selectStmt.Limit
selectStmt.Limit = nil
}
// 检查是否为COUNT语句
//如果不是count的语句添加order和group的支持
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
//这里处理的是报表的情况再报表里面需要重写COUNT语句才能获取到正确的数量
if v == true {
isCountStatement = true
isPureCountStatement = true
}
}
//常规的COUNT逻辑
if !isCountStatement && len(selectStmt.SelectExprs) == 1 {
for _, expr := range selectStmt.SelectExprs {
if aliasedExpr, ok = expr.(*sqlparserX.AliasedExpr); ok {
if funcExpr, ok = aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok {
if funcExpr.Name.EqualString("count") {
isCountStatement = true
break
}
}
}
}
}
// 重写SQL
db.Statement.SQL.Reset()
trackedBuffer = sqlparserX.NewTrackedBuffer(nil)
if isCountStatement {
if isPureCountStatement {
db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (")
} else {
db.Statement.SQL.WriteString("SELECT SUM(")
db.Statement.SQL.WriteByte(plugin.QuoteChar)
db.Statement.SQL.WriteString(strings.Trim("count(*)", "`"))
db.Statement.SQL.WriteByte(plugin.QuoteChar)
db.Statement.SQL.WriteString(") FROM (")
}
} else {
if bs, ok := modelValue.(SelectBuilder); ok {
columns := bs.BuildSelect(db.Statement.Context, selectStmt.SelectExprs)
if len(columns) > 0 {
db.Statement.SQL.WriteString("SELECT " + strings.Join(columns, ",") + " FROM (")
} else {
db.Statement.SQL.WriteString("SELECT * FROM (")
}
} else {
db.Statement.SQL.WriteString("SELECT * FROM (")
}
}
for i, name := range tables {
trackedBuffer.Reset()
db.Statement.SQL.WriteByte('(')
//赋值新的表名称
selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{
Expr: sqlparserX.TableName{
Name: sqlparserX.NewTableIdent(name),
},
}}
selectStmt.Format(trackedBuffer)
db.Statement.SQL.WriteString(trackedBuffer.String())
db.Statement.SQL.WriteByte(')')
if i < numOfTable-1 {
if plugin.UnionAll {
db.Statement.SQL.WriteString(" UNION ALL ")
} else {
db.Statement.SQL.WriteString(" UNION ")
}
}
if i > 0 {
//copy vars
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
}
}
db.Statement.SQL.WriteString(") tbl ")
if !isCountStatement {
//node.GroupBy, node.Having, node.OrderBy, node.Limit
if groupByExpr != nil {
trackedBuffer.Reset()
groupByExpr.Format(trackedBuffer)
db.Statement.SQL.WriteString(trackedBuffer.String())
}
if havingExpr != nil {
trackedBuffer.Reset()
havingExpr.Format(trackedBuffer)
db.Statement.SQL.WriteString(trackedBuffer.String())
}
if orderByExpr != nil {
trackedBuffer.Reset()
orderByExpr.Format(trackedBuffer)
db.Statement.SQL.WriteString(trackedBuffer.String())
}
if limitExpr != nil {
trackedBuffer.Reset()
limitExpr.Format(trackedBuffer)
db.Statement.SQL.WriteString(trackedBuffer.String())
}
}
}
func New() *Sharding {
return &Sharding{
UnionAll: true,
QuoteChar: '`',
}
}

View File

@ -0,0 +1,61 @@
package sharding
import (
"fmt"
"github.com/uole/sqlparser"
"testing"
)
func TestFormat(t *testing.T) {
t.Log(fmt.Sprintf("%.2f%%", 0.5851888*100))
}
func TestSharding_Query(t *testing.T) {
//sql := "SELECT COUNT(*) FROM aaa"
sql := "SELECT uid,SUM(IF(direction='inbound',1,0)) AS inbound_times,SUM(IF(direction='inbound',IF(answer_duration>0,1,0),0)) AS inbound_answer_times,SUM(IF(direction='outbound',1,0)) AS outbound_times,SUM(IF(direction='outbound',IF(answer_duration>0,1,0),0)) AS outbound_answer_times FROM `cdr_logs` WHERE ((`domain` = 'test.cc.echo.me' OR `domain` = 'default') AND `name` <> '') AND (`create_stamp` BETWEEN 1712505600 AND 1712505608) AND `name` IN ('a','b','c') GROUP BY `uid` having uid != '' ORDER BY create_stamp DESC LIMIT 15"
stmt, err := sqlparser.Parse(sql)
if err != nil {
t.Error(err)
return
}
selectStmt, ok := stmt.(*sqlparser.Select)
if !ok {
t.Error("not select stmt")
return
}
buf := sqlparser.NewTrackedBuffer(nil)
sqlparser.NewTableIdent("test").Format(buf)
buf.Reset()
selectStmt.Format(buf)
t.Log("SQL")
t.Log(buf.String())
buf.Reset()
t.Log("SELECT")
selectStmt.SelectExprs.Format(buf)
t.Log(buf.String())
buf.Reset()
t.Log("FROM")
selectStmt.From.Format(buf)
t.Log(buf.String())
buf.Reset()
t.Log("WHERE")
selectStmt.Where.Format(buf)
t.Log(buf.String())
buf.Reset()
t.Log("ORDER BY")
selectStmt.OrderBy.Format(buf)
t.Log(buf.String())
buf.Reset()
t.Log("GROUP BY")
selectStmt.GroupBy.Format(buf)
t.Log(buf.String())
buf.Reset()
t.Log("LIMIT")
selectStmt.Limit.Format(buf)
t.Log(buf.String())
buf.Reset()
t.Log("Having")
selectStmt.Having.Format(buf)
t.Log(buf.String())
buf.Reset()
}

View File

@ -0,0 +1,45 @@
package sharding
import (
"context"
sqlparserX "github.com/uole/sqlparser"
)
const (
TypeDatetime = "datetime"
TypeHash = "hash"
)
const (
DateTypeYear = iota + 5
DateTypeMonth
DateTypeWeek
DateTypeDay
)
const (
sceneCreate = "create"
sceneUpdate = "update"
sceneDelete = "delete"
)
const (
stmtCountKeyword = "count"
)
type (
Rule struct {
Type string `json:"type"`
Args int `json:"args"`
}
Model interface {
ShardingRule() Rule
ShardingTable(scene string) string
ShardingTables(scope *Scope) []string
}
SelectBuilder interface {
BuildSelect(ctx context.Context, expr sqlparserX.SelectExprs) []string
}
)

View File

@ -0,0 +1,107 @@
package validate
import (
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"reflect"
"regexp"
"strconv"
)
const (
SkipValidations = "validations:skip_validations"
)
var (
scopeCtxKey = &validateScope{}
telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
)
type (
validateScope struct {
DB *gorm.DB
Column string
Domain string
MultiDomain bool
Model interface{}
}
StructError struct {
Tag string `json:"tag"`
Column string `json:"column"`
Message string `json:"message"`
}
validateRule struct {
Rule string
Value string
Valid bool
}
)
func (err *StructError) Error() string {
return err.Message
}
func newRule(ss ...string) *validateRule {
v := &validateRule{
Valid: true,
}
if len(ss) == 1 {
v.Rule = ss[0]
} else if len(ss) >= 2 {
v.Rule = ss[0]
v.Value = ss[1]
}
return v
}
// formatError 格式化错误消息
func formatError(rule types.Rule, scm *types.Schema, tag string) string {
var s string
switch tag {
case "db_unique":
s = scm.Label + "值已经存在."
break
case "required":
s = scm.Label + "值不能为空."
case "max":
if scm.Type == "string" {
s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
} else {
s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
}
case "min":
if scm.Type == "string" {
s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max)
} else {
s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max)
}
}
return s
}
// isEmpty 判断值是否为空
func isEmpty(val any) bool {
if val == nil {
return true
}
v := reflect.ValueOf(val)
switch v.Kind() {
case reflect.String, reflect.Array:
return v.Len() == 0
case reflect.Map, reflect.Slice:
return v.IsNil() || v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
}
return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface())
}

View File

@ -0,0 +1,275 @@
package validate
import (
"context"
"fmt"
"git.nobla.cn/golang/rest"
"git.nobla.cn/golang/rest/types"
validator "github.com/go-playground/validator/v10"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"reflect"
"strconv"
"strings"
)
type Validate struct {
validator *validator.Validate
}
func (validate *Validate) Name() string {
return "gorm:validate"
}
func (validate *Validate) telephoneValidate(ctx context.Context, fl validator.FieldLevel) bool {
val := fmt.Sprint(fl.Field().Interface())
return telephoneRegex.MatchString(val)
}
func (validate *Validate) uniqueValidate(ctx context.Context, fl validator.FieldLevel) bool {
var (
scope *validateScope
ok bool
count int64
field *schema.Field
primaryKeyValue reflect.Value
)
val := fl.Field().Interface()
if scope, ok = ctx.Value(scopeCtxKey).(*validateScope); !ok {
return true
}
if len(scope.DB.Statement.Schema.PrimaryFields) > 0 {
field = scope.DB.Statement.Schema.PrimaryFields[0]
primaryKeyValue = reflect.Indirect(reflect.ValueOf(scope.Model))
for _, n := range field.BindNames {
primaryKeyValue = primaryKeyValue.FieldByName(n)
}
}
sess := scope.DB.Session(&gorm.Session{NewDB: true})
if primaryKeyValue.IsValid() && !primaryKeyValue.IsZero() && field != nil {
//多域校验
if scope.MultiDomain && scope.Domain != "" {
sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ? AND domain = ?", val, primaryKeyValue.Interface(), scope.Domain).Count(&count)
} else {
sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ?", val, primaryKeyValue.Interface()).Count(&count)
}
} else {
if scope.MultiDomain && scope.Domain != "" {
sess.Model(scope.Model).Where(scope.Column+"=? AND domain = ?", val, scope.Domain).Count(&count)
} else {
sess.Model(scope.Model).Where(scope.Column+"=?", val).Count(&count)
}
}
if count > 0 {
return false
}
return true
}
func (validate *Validate) grantRules(scm *types.Schema, scenario string, rule types.Rule) []*validateRule {
rules := make([]*validateRule, 0, 5)
if rule.Min != 0 {
rules = append(rules, newRule("min", strconv.Itoa(rule.Min)))
}
if rule.Max != 0 {
rules = append(rules, newRule("max", strconv.Itoa(rule.Max)))
}
//主键不做唯一判断
if rule.Unique && !scm.Attribute.PrimaryKey {
rules = append(rules, newRule("db_unique"))
}
if rule.Type != "" {
rules = append(rules, newRule(rule.Type))
}
if rule.Required != nil && len(rule.Required) > 0 {
for _, v := range rule.Required {
if v == scenario {
rules = append(rules, newRule("required"))
break
}
}
}
return rules
}
func (validate *Validate) buildRules(rs []*validateRule) string {
var sb strings.Builder
for _, r := range rs {
if !r.Valid {
continue
}
if sb.Len() > 0 {
sb.WriteString(",")
}
if r.Value == "" {
sb.WriteString(r.Rule)
} else {
sb.WriteString(r.Rule + "=" + r.Value)
}
}
return sb.String()
}
func (validate *Validate) findRule(name string, rules []*validateRule) *validateRule {
for _, r := range rules {
if r.Rule == name {
return r
}
}
return nil
}
func (validate *Validate) Initialize(db *gorm.DB) (err error) {
validate.validator = validator.New()
if err = db.Callback().Create().Before("gorm:before_create").Register("model_validate", validate.Validate); err != nil {
return
}
if err = db.Callback().Create().Before("gorm:before_update").Register("model_validate", validate.Validate); err != nil {
return
}
if err = validate.validator.RegisterValidationCtx("telephone", validate.telephoneValidate); err != nil {
return
}
if err = validate.validator.RegisterValidationCtx("db_unique", validate.uniqueValidate); err != nil {
return
}
return
}
func (validate *Validate) inArray(v any, vs []any) bool {
sv := fmt.Sprint(v)
for _, s := range vs {
if fmt.Sprint(s) == sv {
return true
}
}
return false
}
// isVisible 判断字段是否需要显示
func (validate *Validate) isVisible(stmt *gorm.Statement, scm *types.Schema) bool {
if len(scm.Attribute.Visible) <= 0 {
return true
}
for _, row := range scm.Attribute.Visible {
if len(row.Values) == 0 {
continue
}
targetField := stmt.Schema.LookUpField(row.Column)
if targetField == nil {
continue
}
targetValue := stmt.ReflectValue.FieldByName(targetField.Name)
if !targetValue.IsValid() {
return false
}
if !validate.inArray(targetValue.Interface(), row.Values) {
return false
}
}
return true
}
// Validate 校验字段
func (validate *Validate) Validate(db *gorm.DB) {
var (
ok bool
err error
rules []*validateRule
stmt *gorm.Statement
skipValidate bool
multiDomain bool
value reflect.Value
runtimeScope *types.RuntimeScope
)
if result, ok := db.Get(SkipValidations); ok && result.(bool) {
return
}
stmt = db.Statement
if stmt.Model == nil {
return
}
if db.Statement.Context == nil {
return
}
if val := db.Statement.Context.Value(rest.RuntimeScopeKey); val != nil {
if runtimeScope, ok = val.(*types.RuntimeScope); !ok {
return
}
} else {
return
}
if runtimeScope.Schemas == nil {
return
}
if stmt.Schema.LookUpField("domain") != nil {
multiDomain = true
}
for _, row := range runtimeScope.Schemas {
//如果字段隐藏,那么就不进行校验
if !validate.isVisible(stmt, row) {
continue
}
if rules = validate.grantRules(row, runtimeScope.Scenario, row.Rule); len(rules) <= 0 {
continue
}
field := stmt.Schema.LookUpField(row.Column)
if field == nil {
continue
}
value = stmt.ReflectValue.FieldByName(field.Name)
if !value.IsValid() {
continue
}
skipValidate = false
if r := validate.findRule("required", rules); r != nil {
if value.Interface() != nil {
vType := reflect.ValueOf(value.Interface())
switch vType.Kind() {
case reflect.Bool:
skipValidate = true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
skipValidate = true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
skipValidate = true
case reflect.Float32, reflect.Float64:
skipValidate = true
default:
skipValidate = false
}
}
if skipValidate {
r.Valid = false
}
} else {
if isEmpty(value.Interface()) {
continue
}
}
ctx := context.WithValue(db.Statement.Context, scopeCtxKey, &validateScope{
DB: db,
Column: row.Column,
Model: stmt.Model,
Domain: runtimeScope.Domain,
MultiDomain: multiDomain,
})
if err = validate.validator.VarCtx(ctx, value.Interface(), validate.buildRules(rules)); err != nil {
if errs, ok := err.(validator.ValidationErrors); ok {
for _, e := range errs {
_ = db.AddError(&StructError{
Tag: e.Tag(),
Column: row.Column,
Message: formatError(row.Rule, row, e.Tag()),
})
}
} else {
_ = db.AddError(err)
}
break
}
}
}
func New() *Validate {
return &Validate{}
}

405
query.go 100644
View File

@ -0,0 +1,405 @@
package rest
import (
"context"
"fmt"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"reflect"
"strconv"
"strings"
)
type (
Query struct {
db *gorm.DB
condition string
fields []string
params []interface{}
table string
joins []join
orderBy []string
groupBy []string
modelValue any
limit int
offset int
}
condition struct {
Field string `json:"field"`
Value interface{} `json:"value"`
Expr string `json:"expr"`
}
join struct {
Table string
Direction string
Conditions []*condition
}
)
func (cond *condition) WithExpr(v string) *condition {
cond.Expr = v
return cond
}
func (query *Query) Model() any {
return query.modelValue
}
func (query *Query) compile() (*gorm.DB, error) {
db := query.db
if query.condition != "" {
db = db.Where(query.condition, query.params...)
}
if query.fields != nil {
db = db.Select(strings.Join(query.fields, ","))
}
if query.table != "" {
db = db.Table(query.table)
}
if query.joins != nil && len(query.joins) > 0 {
for _, joinEntity := range query.joins {
cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...)
db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...)
}
}
if query.orderBy != nil && len(query.orderBy) > 0 {
db = db.Order(strings.Join(query.orderBy, ","))
}
if query.groupBy != nil && len(query.groupBy) > 0 {
db = db.Group(strings.Join(query.groupBy, ","))
}
if query.offset > 0 {
db = db.Offset(query.offset)
}
if query.limit > 0 {
db = db.Limit(query.limit)
}
return db, nil
}
func (query *Query) decodeValue(v any) string {
refVal := reflect.Indirect(reflect.ValueOf(v))
switch refVal.Kind() {
case reflect.Bool:
if refVal.Bool() {
return "1"
} else {
return "0"
}
case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64:
return strconv.FormatInt(refVal.Int(), 10)
case reflect.Float32, reflect.Float64:
return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
case reflect.String:
return "'" + refVal.String() + "'"
default:
return fmt.Sprint(v)
}
}
func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) {
var (
sb strings.Builder
)
params = make([]interface{}, 0)
for _, cond := range conditions {
if filter {
if isEmpty(cond.Value) {
continue
}
}
if cond.Expr == "" {
cond.Expr = "="
}
switch strings.ToUpper(cond.Expr) {
case "=", "<>", ">", "<", ">=", "<=", "!=":
if sb.Len() > 0 {
sb.WriteString(" " + operator + " ")
}
if cond.Expr == "=" && cond.Value == nil {
sb.WriteString("`" + cond.Field + "` IS NULL")
} else {
sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?")
params = append(params, cond.Value)
}
case "LIKE":
if sb.Len() > 0 {
sb.WriteString(" " + operator + " ")
}
cond.Value = fmt.Sprintf("%%%s%%", cond.Value)
sb.WriteString("`" + cond.Field + "` LIKE ?")
params = append(params, cond.Value)
case "IN":
if sb.Len() > 0 {
sb.WriteString(" " + operator + " ")
}
refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
switch refVal.Kind() {
case reflect.Slice, reflect.Array:
ss := make([]string, refVal.Len())
for i := 0; i < refVal.Len(); i++ {
ss[i] = query.decodeValue(refVal.Index(i))
}
sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")")
case reflect.String:
sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")")
default:
}
case "BETWEEN":
refVal := reflect.ValueOf(cond.Value)
if refVal.Kind() == reflect.Slice && refVal.Len() == 2 {
sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?")
params = append(params, refVal.Index(0), refVal.Index(1))
}
}
}
str = sb.String()
return
}
func (query *Query) Select(fields ...string) *Query {
if query.fields == nil {
query.fields = fields
} else {
query.fields = append(query.fields, fields...)
}
return query
}
func (query *Query) From(table string) *Query {
query.table = table
return query
}
func (query *Query) LeftJoin(table string, conditions ...*condition) *Query {
query.joins = append(query.joins, join{
Table: table,
Direction: "LEFT",
Conditions: conditions,
})
return query
}
func (query *Query) RightJoin(table string, conditions ...*condition) *Query {
query.joins = append(query.joins, join{
Table: table,
Direction: "RIGHT",
Conditions: conditions,
})
return query
}
func (query *Query) InnerJoin(table string, conditions ...*condition) *Query {
query.joins = append(query.joins, join{
Table: table,
Direction: "INNER",
Conditions: conditions,
})
return query
}
func (query *Query) AddCondition(expr, column string, val any) {
if expr == "" {
expr = "="
}
query.AndWhere(newConditionWithOperator(expr, column, val))
}
func (query *Query) AndWhere(conditions ...*condition) *Query {
length := len(conditions)
if length == 0 {
return query
}
cs, ps := query.buildConditions("AND", false, conditions...)
if cs == "" {
return query
}
query.params = append(query.params, ps...)
if query.condition == "" {
query.condition = cs
} else {
query.condition += " AND (" + cs + ")"
}
return query
}
func (query *Query) AndFilterWhere(conditions ...*condition) *Query {
length := len(conditions)
if length == 0 {
return query
}
cs, ps := query.buildConditions("AND", true, conditions...)
if cs == "" {
return query
}
query.params = append(query.params, ps...)
if query.condition == "" {
query.condition = cs
} else {
query.condition += " AND " + cs
}
return query
}
func (query *Query) OrWhere(conditions ...*condition) *Query {
length := len(conditions)
if length == 0 {
return query
}
cs, ps := query.buildConditions("OR", false, conditions...)
if cs == "" {
return query
}
query.params = append(query.params, ps...)
if query.condition == "" {
query.condition = cs
} else {
query.condition += " AND (" + cs + ")"
}
return query
}
func (query *Query) OrFilterWhere(conditions ...*condition) *Query {
length := len(conditions)
if length == 0 {
return query
}
cs, ps := query.buildConditions("OR", true, conditions...)
if cs == "" {
return query
}
query.params = append(query.params, ps...)
if query.condition == "" {
query.condition = cs
} else {
query.condition += " AND (" + cs + ")"
}
return query
}
func (query *Query) GroupBy(cols ...string) *Query {
query.groupBy = append(query.groupBy, cols...)
return query
}
func (query *Query) OrderBy(col, direction string) *Query {
direction = strings.ToUpper(direction)
if direction == "" || !(direction == "ASC" || direction == "DESC") {
direction = "ASC"
}
query.orderBy = append(query.orderBy, col+" "+direction)
return query
}
func (query *Query) Offset(i int) *Query {
query.offset = i
return query
}
func (query *Query) Limit(i int) *Query {
query.limit = i
return query
}
func (query *Query) ResetSelect() *Query {
query.fields = make([]string, 0)
return query
}
func (query *Query) Count(v interface{}) (i int64) {
var (
db *gorm.DB
err error
)
if db, err = query.compile(); err != nil {
return
} else {
if v != nil {
refVal := reflect.ValueOf(v)
switch refVal.Kind() {
case reflect.String:
if query.table == "" {
err = db.Table(refVal.String()).Count(&i).Error
} else {
err = db.Table(query.table).Count(&i).Error
}
default:
//如果是报表的模型这的话手动构建一条SQL语句
if reporter, ok := v.(types.Reporter); ok {
sqlRes := &sqlCountResponse{}
childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true)
db.WithContext(childCtx).Model(reporter).First(sqlRes)
i = sqlRes.Count
} else {
err = db.Model(v).Count(&i).Error
}
}
}
}
return
}
func (query *Query) One(v interface{}) (err error) {
var (
db *gorm.DB
)
if db, err = query.compile(); err != nil {
return
} else {
err = db.First(v).Error
}
return
}
func (query *Query) All(v interface{}) (err error) {
var (
db *gorm.DB
)
if db, err = query.compile(); err != nil {
return
} else {
err = db.Find(v).Error
}
return
}
func NewCondition(column, opera string, value any) *condition {
if opera == "" {
opera = "="
}
return &condition{
Field: column,
Value: value,
Expr: opera,
}
}
func newCondition(field string, value interface{}) *condition {
return &condition{
Field: field,
Value: value,
Expr: "=",
}
}
func newConditionWithOperator(operator, field string, value interface{}) *condition {
cond := &condition{
Field: field,
Value: value,
Expr: operator,
}
return cond
}
func NewQuery(db *gorm.DB, model any) *Query {
return &Query{
db: db,
modelValue: model,
params: make([]interface{}, 0),
orderBy: make([]string, 0),
groupBy: make([]string, 0),
joins: make([]join, 0),
}
}

753
rest.go 100644
View File

@ -0,0 +1,753 @@
package rest
import (
"context"
"errors"
"fmt"
"git.nobla.cn/golang/kos/util/arrays"
"git.nobla.cn/golang/kos/util/reflection"
"git.nobla.cn/golang/rest/inflector"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"net/http"
"reflect"
"strconv"
"strings"
"time"
)
var (
modelEntities []*Model
httpRouter types.HttpRouter
hookMgr *hookManager
timeKind = reflect.TypeOf(time.Time{}).Kind()
timePtrKind = reflect.TypeOf(&time.Time{}).Kind()
matchEnums = []string{types.MatchExactly, types.MatchFuzzy}
)
var (
allowScenario = []string{types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
)
func init() {
hookMgr = &hookManager{}
modelEntities = make([]*Model, 0)
}
// cloneStmt 从指定的db克隆一个 Statement 对象
func cloneStmt(db *gorm.DB) *gorm.Statement {
return &gorm.Statement{
DB: db,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
}
}
// dataTypeOf 推断数据的类型
func dataTypeOf(field *schema.Field) string {
var dataType string
reflectType := field.FieldType
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
if dataType = field.Tag.Get("type"); dataType != "" {
return dataType
}
dataValue := reflect.Indirect(reflect.New(reflectType))
switch dataValue.Kind() {
case reflect.Bool:
dataType = types.TypeBoolean
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
dataType = types.TypeInteger
case reflect.Float32, reflect.Float64:
dataType = types.TypeFloat
default:
dataType = types.TypeString
}
return dataType
}
// dataFormatOf 推断数据的格式
func dataFormatOf(field *schema.Field) string {
var format string
format = field.Tag.Get("format")
if format != "" {
return format
}
//如果有枚举值,直接设置为下拉类型
enum := field.Tag.Get("enum")
if enum != "" {
return types.FormatDropdown
}
reflectType := field.FieldType
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
//时间处理
dataValue := reflect.Indirect(reflect.New(reflectType))
if field.Name == "CreatedAt" || field.Name == "UpdatedAt" || field.Name == "DeletedAt" {
switch dataValue.Kind() {
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return types.FormatTimestamp
default:
return types.FormatDatetime
}
}
if strings.Contains(strings.ToLower(field.Name), "pass") {
return types.FormatPassword
}
switch dataValue.Kind() {
case timeKind, timePtrKind:
format = types.FormatDatetime
case reflect.Bool:
format = types.FormatBoolean
case reflect.Int8, reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
format = types.FormatInteger
case reflect.Float32, reflect.Float64:
format = types.FormatFloat
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
format = types.FormatDatetime
}
default:
if field.Size >= 1024 {
format = types.FormatText
} else {
format = types.FormatString
}
}
return format
}
// fieldName 生成字段名称
func fieldName(name string) string {
tokens := strings.Split(name, "_")
for i, s := range tokens {
tokens[i] = strings.Title(s)
}
return strings.Join(tokens, " ")
}
// fieldNative 判断是否为原始字段
func fieldNative(field *schema.Field) uint8 {
if _, ok := field.Tag.Lookup("virtual"); ok {
return 0
}
return 1
}
// fieldRule 返回字段规则
func fieldRule(field *schema.Field) types.Rule {
r := types.Rule{
Required: []string{},
}
if field.GORMDataType == schema.String {
r.Max = field.Size
}
if field.GORMDataType == schema.Int || field.GORMDataType == schema.Float || field.GORMDataType == schema.Uint {
r.Max = field.Scale
}
rs := field.Tag.Get("rule")
if rs != "" {
ss := strings.Split(rs, ";")
for _, s := range ss {
vs := strings.SplitN(s, ":", 2)
ls := len(vs)
if ls == 0 {
continue
}
switch vs[0] {
case "required", "require":
if ls > 1 {
bs := strings.Split(vs[1], ",")
for _, i := range bs {
if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
r.Required = append(r.Required, i)
}
}
} else {
r.Required = []string{types.ScenarioCreate, types.ScenarioUpdate}
}
case "unique":
r.Unique = true
case "regexp":
if ls > 1 {
r.Regular = vs[1]
}
}
}
}
if field.PrimaryKey {
r.Unique = true
}
return r
}
// fieldScenario 字段Scenarios
func fieldScenario(index int, field *schema.Field) types.Scenarios {
var ss types.Scenarios
if v, ok := field.Tag.Lookup("scenarios"); ok {
v = strings.TrimSpace(v)
if v != "" {
ss = strings.Split(v, ";")
}
} else {
if field.PrimaryKey {
ss = []string{types.ScenarioList, types.ScenarioView, types.ScenarioExport}
} else if field.Name == "CreatedAt" || field.Name == "UpdatedAt" {
ss = []string{types.ScenarioList}
} else if field.Name == "DeletedAt" || field.Name == "Namespace" {
//不添加任何显示场景
ss = []string{}
} else {
if index < 10 {
//高级字段只配置一些简单的场景
ss = []string{types.ScenarioSearch, types.ScenarioList, types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
} else {
//高级字段只配置一些简单的场景
ss = []string{types.ScenarioCreate, types.ScenarioUpdate, types.ScenarioView, types.ScenarioExport}
}
}
}
return ss
}
// fieldPosition 字段的排序位置
func fieldPosition(field *schema.Field, i int) int {
s := field.Tag.Get("position")
n, _ := strconv.Atoi(s)
if n > 0 {
return n
}
return i + 100
}
// fieldAttribute 字段属性
func fieldAttribute(field *schema.Field) types.Attribute {
attr := types.Attribute{
Match: types.MatchFuzzy,
PrimaryKey: field.PrimaryKey,
DefaultValue: field.DefaultValue,
Readonly: []string{},
Disable: []string{},
Visible: make([]types.VisibleCondition, 0),
Values: make([]types.EnumValue, 0),
Live: types.LiveValue{},
}
if field.Name == "CreatedAt" || field.Name == "UpdatedAt" {
attr.EndOfNow = true
}
//赋值属性
props := field.Tag.Get("props")
if props != "" {
vs := strings.Split(props, ";")
for _, str := range vs {
kv := strings.SplitN(str, ":", 2)
if len(kv) != 2 {
continue
}
sv := strings.TrimSpace(kv[1])
switch strings.ToLower(strings.TrimSpace(kv[0])) {
case "icon":
attr.Icon = sv
case "match":
if arrays.Exists(sv, matchEnums) {
attr.Match = sv
}
case "endofnow", "end_of_now":
if ok, _ := strconv.ParseBool(sv); ok {
attr.EndOfNow = true
}
case "invisible":
if ok, _ := strconv.ParseBool(sv); ok {
attr.Invisible = true
}
case "suffix":
attr.Suffix = sv
case "tag":
attr.Tag = sv
case "tooltip":
attr.Tooltip = sv
case "uploadurl", "uploaduri", "upload_url", "upload_uri":
attr.UploadUrl = sv
case "description":
attr.Description = sv
case "readonly":
bs := strings.Split(sv, ",")
for _, i := range bs {
if arrays.Exists(i, []string{types.ScenarioCreate, types.ScenarioUpdate}) {
attr.Readonly = append(attr.Readonly, i)
}
}
}
}
}
//live的赋值
live := field.Tag.Get("live")
if live != "" {
attr.Live.Enable = true
vs := strings.Split(live, ";")
for _, str := range vs {
kv := strings.SplitN(str, ":", 2)
if len(kv) != 2 {
continue
}
switch kv[0] {
case "method":
attr.Live.Method = kv[1]
case "type":
if kv[1] == types.LiveTypeDropdown || kv[1] == types.LiveTypeCascader {
attr.Live.Type = kv[1]
} else {
attr.Live.Type = types.LiveTypeDropdown
}
case "url", "uri":
attr.Live.Url = kv[1]
case "columns":
attr.Live.Columns = strings.Split(kv[1], ",")
}
}
}
dropdown := field.Tag.Get("dropdown")
if dropdown != "" {
attr.DropdownOption = &types.DropdownOption{}
vs := strings.Split(dropdown, ";")
for _, str := range vs {
kv := strings.SplitN(str, ":", 2)
if len(kv) == 0 {
continue
}
switch kv[0] {
case "created":
attr.DropdownOption.Created = true
case "filterable":
attr.DropdownOption.Filterable = true
case "autocomplete":
attr.DropdownOption.Autocomplete = true
case "default_first":
attr.DropdownOption.DefaultFirst = true
}
}
}
//显示条件
conditions := field.Tag.Get("condition")
if conditions != "" {
vs := strings.Split(conditions, ";")
for _, str := range vs {
kv := strings.SplitN(str, ":", 2)
if len(kv) != 2 {
continue
}
cond := types.VisibleCondition{
Column: kv[0],
Values: make([]any, 0),
}
vv := strings.Split(kv[1], ",")
for _, x := range vv {
x = strings.TrimSpace(x)
if x == "" {
continue
}
cond.Values = append(cond.Values, x)
}
attr.Visible = append(attr.Visible, cond)
}
}
//赋值枚举值
enumns := field.Tag.Get("enum")
if enumns != "" {
vs := strings.Split(enumns, ";")
for _, str := range vs {
kv := strings.SplitN(str, ":", 2)
if len(kv) != 2 {
continue
}
fv := types.EnumValue{Value: kv[0]}
//颜色分隔符
if pos := strings.IndexByte(kv[1], '#'); pos > -1 {
fv.Label = kv[1][:pos]
fv.Color = kv[1][pos:]
} else {
fv.Label = kv[1]
}
attr.Values = append(attr.Values, fv)
}
}
if !field.Creatable {
attr.Disable = append(attr.Disable, types.ScenarioCreate)
}
if !field.Updatable {
attr.Disable = append(attr.Disable, types.ScenarioUpdate)
}
attr.Tooltip = field.Comment
return attr
}
// autoMigrate 自动合并字段
func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (naming string, err error) {
var (
pos int
columnName string
columnIsExists bool
columnLabel string
schemas []*types.Schema
models []*types.Schema
stmt *gorm.Statement
)
stmt = cloneStmt(db)
if err = stmt.Parse(model); err != nil {
return
}
if schemas, err = GetSchemas(ctx, db, defaultDomain, module, stmt.Table); err != nil {
return
}
if len(schemas) > 0 {
pos = len(schemas)
}
models = make([]*types.Schema, 0)
for index, field := range stmt.Schema.Fields {
columnName = field.DBName
if columnName == "-" {
continue
}
if columnName == "" {
columnName = field.Name
}
columnIsExists = false
for _, sm := range schemas {
if sm.Column == columnName {
columnIsExists = true
break
}
}
if columnIsExists {
continue
}
columnLabel = field.Tag.Get("comment")
if columnLabel == "" {
columnLabel = fieldName(field.DBName)
}
isPrimaryKey := uint8(0)
if field.PrimaryKey {
isPrimaryKey = 1
}
schemaModel := &types.Schema{
Domain: defaultDomain,
ModuleName: module,
TableName: stmt.Table,
Enable: 1,
Column: columnName,
Label: columnLabel,
Type: strings.ToLower(dataTypeOf(field)),
Format: strings.ToLower(dataFormatOf(field)),
Native: fieldNative(field),
IsPrimaryKey: isPrimaryKey,
Rule: fieldRule(field),
Scenarios: fieldScenario(index, field),
Attribute: fieldAttribute(field),
Position: fieldPosition(field, pos),
}
//如果启用了在线调取接口功能那么设置一下字段的format格式
if schemaModel.Attribute.Live.Enable {
if schemaModel.Attribute.Live.Type != "" {
schemaModel.Format = schemaModel.Attribute.Live.Type
}
}
models = append(models, schemaModel)
pos++
}
if len(models) > 0 {
err = db.Create(models).Error
}
naming = stmt.Table
return
}
// SetHttpRouter 设置HTTP路由
func SetHttpRouter(router types.HttpRouter) {
httpRouter = router
}
// AutoMigrate 自动合并表的schema
func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) {
var (
opts *Options
table string
router types.HttpRouter
)
opts = &Options{}
for _, cb := range cbs {
cb(opts)
}
if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil {
return
}
//路由模块处理
modelValue := newModel(model, db, types.Naming{
Pluralize: inflector.Pluralize(table),
Singular: inflector.Singularize(table),
ModuleName: opts.moduleName,
TableName: table,
})
modelValue.hookMgr = hookMgr
modelValue.schemaLookup = VisibleSchemas
if opts.router != nil {
router = opts.router
}
if router == nil && httpRouter != nil {
router = httpRouter
}
if opts.urlPrefix != "" {
modelValue.urlPrefix = opts.urlPrefix
}
//路由绑定操作
if router != nil {
if modelValue.hasScenario(types.ScenarioList) {
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioList), modelValue.Search)
}
if modelValue.hasScenario(types.ScenarioCreate) {
router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioCreate), modelValue.Create)
}
if modelValue.hasScenario(types.ScenarioUpdate) {
router.Handle(http.MethodPut, modelValue.Uri(types.ScenarioUpdate), modelValue.Update)
}
if modelValue.hasScenario(types.ScenarioDelete) {
router.Handle(http.MethodDelete, modelValue.Uri(types.ScenarioDelete), modelValue.Delete)
}
if modelValue.hasScenario(types.ScenarioView) {
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioView), modelValue.View)
}
if modelValue.hasScenario(types.ScenarioExport) {
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioExport), modelValue.Export)
}
if modelValue.hasScenario(types.ScenarioImport) {
router.Handle(http.MethodGet, modelValue.Uri(types.ScenarioImport), modelValue.Import)
router.Handle(http.MethodPost, modelValue.Uri(types.ScenarioImport), modelValue.Import)
}
}
if opts.writer != nil {
modelValue.response = opts.writer
}
if opts.formatter != nil {
modelValue.formatter = opts.formatter
}
modelValue.disableDomain = opts.disableDomain
modelEntities = append(modelEntities, modelValue)
return
}
// CloneSchemas 克隆schemas
func CloneSchemas(ctx context.Context, db *gorm.DB, domain string) (err error) {
var (
values []*types.Schema
schemas []*types.Schema
models []*types.Schema
)
tx := db.WithContext(ctx)
if err = tx.Where("domain=?", defaultDomain).Find(&values).Error; err != nil {
return fmt.Errorf("schema not found")
}
tx.Where("domain=?", domain).Find(&schemas)
hasSchemaFunc := func(values []*types.Schema, hack *types.Schema) bool {
for _, row := range values {
if row.ModuleName == hack.ModuleName && row.TableName == hack.TableName && row.Column == hack.Column {
return true
}
}
return false
}
models = make([]*types.Schema, 0)
for _, row := range values {
if !hasSchemaFunc(schemas, row) {
row.Id = 0
row.CreatedAt = 0
row.UpdatedAt = 0
row.Domain = domain
models = append(models, row)
}
}
if len(models) > 0 {
err = tx.Save(models).Error
}
return
}
// GetSchemas 获取schemas
func GetSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName string) ([]*types.Schema, error) {
var (
err error
values []*types.Schema
tx *gorm.DB
)
values = make([]*types.Schema, 0)
if domain == "" {
domain = defaultDomain
}
if moduleName == "" || tableName == "" {
return nil, gorm.ErrInvalidField
}
if ctx != nil {
tx = db.WithContext(ctx)
} else {
tx = db
}
err = tx.Where("domain=? AND module_name=? AND table_name=?", domain, moduleName, tableName).Order("position ASC").Find(&values).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
err = nil
}
return values, err
}
// VisibleSchemas 获取某个场景下面的schema
func VisibleSchemas(ctx context.Context, db *gorm.DB, domain, moduleName, tableName, scenario string) ([]*types.Schema, error) {
schemas, err := GetSchemas(ctx, db, domain, moduleName, tableName)
if err != nil {
return nil, err
}
result := make([]*types.Schema, 0, len(schemas))