477 lines
13 KiB
Go
477 lines
13 KiB
Go
|
package sharding
|
|||
|
|
|||
|
import (
|
|||
|
"github.com/longbridgeapp/sqlparser"
|
|||
|
sqlparserX "github.com/uole/sqlparser"
|
|||
|
"gorm.io/gorm"
|
|||
|
"gorm.io/gorm/callbacks"
|
|||
|
"reflect"
|
|||
|
"strings"
|
|||
|
)
|
|||
|
|
|||
|
// Sharding is a gorm plugin that routes create/update/delete statements to a
// per-model sharding table and rewrites SELECTs that span several shard
// tables into a UNION over those tables.
type Sharding struct {
	// UnionAll selects the set operator used to join the per-shard SELECTs in
	// the query callbacks: " UNION ALL " when true, " UNION " when false.
	UnionAll bool
	// QuoteChar is the byte written on both sides of the count column name in
	// the "SELECT SUM(<quoted count column>) FROM (...)" wrapper generated for
	// COUNT queries. New sets it to '`'.
	QuoteChar byte
}
|
|||
|
|
|||
|
func (plugin *Sharding) Name() string {
|
|||
|
return "gorm:sharding"
|
|||
|
}
|
|||
|
|
|||
|
func (plugin *Sharding) Initialize(db *gorm.DB) (err error) {
|
|||
|
if err = db.Callback().Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
if err = db.Callback().Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
if err = db.Callback().Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
if err = db.Callback().Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX); err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
func (plugin *Sharding) Create(db *gorm.DB) {
|
|||
|
var (
|
|||
|
ok bool
|
|||
|
scopeModel Model
|
|||
|
refValue reflect.Value
|
|||
|
modelValue any
|
|||
|
)
|
|||
|
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
|||
|
if db.Statement.ReflectValue.Len() > 0 {
|
|||
|
refValue = db.Statement.ReflectValue.Index(0)
|
|||
|
modelValue = refValue.Interface()
|
|||
|
}
|
|||
|
} else {
|
|||
|
if db.Statement.Model != nil {
|
|||
|
modelValue = db.Statement.Model
|
|||
|
} else {
|
|||
|
modelValue = db.Statement.ReflectValue.Interface()
|
|||
|
}
|
|||
|
}
|
|||
|
if modelValue != nil {
|
|||
|
if scopeModel, ok = modelValue.(Model); ok {
|
|||
|
db.Table(scopeModel.ShardingTable(sceneCreate))
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func (plugin *Sharding) Update(db *gorm.DB) {
|
|||
|
var (
|
|||
|
ok bool
|
|||
|
scopeModel Model
|
|||
|
refValue reflect.Value
|
|||
|
modelValue any
|
|||
|
)
|
|||
|
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
|||
|
if db.Statement.ReflectValue.Len() > 0 {
|
|||
|
refValue = db.Statement.ReflectValue.Index(0)
|
|||
|
modelValue = refValue.Interface()
|
|||
|
}
|
|||
|
} else {
|
|||
|
if db.Statement.Model != nil {
|
|||
|
modelValue = db.Statement.Model
|
|||
|
} else {
|
|||
|
modelValue = db.Statement.ReflectValue.Interface()
|
|||
|
}
|
|||
|
}
|
|||
|
if modelValue != nil {
|
|||
|
if scopeModel, ok = modelValue.(Model); ok {
|
|||
|
db.Table(scopeModel.ShardingTable(sceneUpdate))
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func (plugin *Sharding) Delete(db *gorm.DB) {
|
|||
|
var (
|
|||
|
ok bool
|
|||
|
scopeModel Model
|
|||
|
refValue reflect.Value
|
|||
|
modelValue any
|
|||
|
)
|
|||
|
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
|||
|
if db.Statement.ReflectValue.Len() > 0 {
|
|||
|
refValue = db.Statement.ReflectValue.Index(0)
|
|||
|
modelValue = refValue.Interface()
|
|||
|
}
|
|||
|
} else {
|
|||
|
if db.Statement.Model != nil {
|
|||
|
modelValue = db.Statement.Model
|
|||
|
} else {
|
|||
|
modelValue = db.Statement.ReflectValue.Interface()
|
|||
|
}
|
|||
|
}
|
|||
|
if modelValue != nil {
|
|||
|
if scopeModel, ok = modelValue.(Model); ok {
|
|||
|
db.Table(scopeModel.ShardingTable(sceneDelete))
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func (plugin *Sharding) Query(db *gorm.DB) {
|
|||
|
var (
|
|||
|
err error
|
|||
|
ok bool
|
|||
|
shardingModel Model
|
|||
|
modelValue any
|
|||
|
tables []string
|
|||
|
rawVars []any
|
|||
|
refValue reflect.Value
|
|||
|
tableName *sqlparser.TableName
|
|||
|
selectStmt *sqlparser.SelectStatement
|
|||
|
stmt sqlparser.Statement
|
|||
|
parser *sqlparser.Parser
|
|||
|
numOfTable int
|
|||
|
orderByExpr []*sqlparser.OrderingTerm
|
|||
|
limitExpr sqlparser.Expr
|
|||
|
offsetExpr sqlparser.Expr
|
|||
|
groupingExpr []sqlparser.Expr
|
|||
|
havingExpr sqlparser.Expr
|
|||
|
isCountStatement bool
|
|||
|
countField string
|
|||
|
)
|
|||
|
if db.Statement.Model != nil {
|
|||
|
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
|
|||
|
} else {
|
|||
|
refValue = reflect.New(db.Statement.ReflectValue.Type())
|
|||
|
}
|
|||
|
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
|
|||
|
refValue = reflect.Indirect(refValue)
|
|||
|
}
|
|||
|
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
|
|||
|
elemType := refValue.Type().Elem()
|
|||
|
if elemType.Kind() == reflect.Ptr {
|
|||
|
elemType = elemType.Elem()
|
|||
|
}
|
|||
|
modelValue = reflect.New(elemType).Interface()
|
|||
|
} else {
|
|||
|
modelValue = refValue.Interface()
|
|||
|
}
|
|||
|
if shardingModel, ok = modelValue.(Model); !ok {
|
|||
|
return
|
|||
|
}
|
|||
|
if db.Statement.SQL.Len() == 0 {
|
|||
|
callbacks.BuildQuerySQL(db)
|
|||
|
}
|
|||
|
parser = sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String()))
|
|||
|
if stmt, err = parser.ParseStatement(); err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
if selectStmt, ok = stmt.(*sqlparser.SelectStatement); !ok {
|
|||
|
return
|
|||
|
}
|
|||
|
tables = shardingModel.ShardingTables(&Scope{
|
|||
|
db: db,
|
|||
|
stmt: selectStmt,
|
|||
|
})
|
|||
|
numOfTable = len(tables)
|
|||
|
if numOfTable <= 1 {
|
|||
|
return
|
|||
|
}
|
|||
|
rawVars = make([]any, 0, len(db.Statement.Vars))
|
|||
|
for _, v := range db.Statement.Vars {
|
|||
|
rawVars = append(rawVars, v)
|
|||
|
}
|
|||
|
//是否是查询count语句
|
|||
|
//如果不是count的语句,添加order和group的支持
|
|||
|
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
|
|||
|
if v == true {
|
|||
|
isCountStatement = true
|
|||
|
}
|
|||
|
}
|
|||
|
if !isCountStatement && len(*selectStmt.Columns) == 1 {
|
|||
|
for _, column := range *selectStmt.Columns {
|
|||
|
if expr, ok := column.Expr.(*sqlparser.Call); ok {
|
|||
|
if expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword {
|
|||
|
isCountStatement = true
|
|||
|
countField = expr.String()
|
|||
|
break
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
if len(selectStmt.OrderBy) > 0 {
|
|||
|
orderByExpr = make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy))
|
|||
|
for _, row := range selectStmt.OrderBy {
|
|||
|
orderByExpr = append(orderByExpr, row)
|
|||
|
}
|
|||
|
selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0)
|
|||
|
}
|
|||
|
if len(selectStmt.GroupingElements) > 0 {
|
|||
|
groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements))
|
|||
|
for _, row := range selectStmt.GroupingElements {
|
|||
|
groupingExpr = append(groupingExpr, row)
|
|||
|
}
|
|||
|
if selectStmt.HavingCondition != nil {
|
|||
|
havingExpr = selectStmt.HavingCondition
|
|||
|
}
|
|||
|
}
|
|||
|
if selectStmt.Limit != nil {
|
|||
|
limitExpr = selectStmt.Limit
|
|||
|
selectStmt.Limit = nil
|
|||
|
}
|
|||
|
if selectStmt.Offset != nil {
|
|||
|
offsetExpr = selectStmt.Offset
|
|||
|
selectStmt.Offset = nil
|
|||
|
}
|
|||
|
db.Statement.SQL.Reset()
|
|||
|
if isCountStatement {
|
|||
|
db.Statement.SQL.WriteString("SELECT SUM(")
|
|||
|
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
|||
|
db.Statement.SQL.WriteString(strings.Trim(countField, "`"))
|
|||
|
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
|||
|
db.Statement.SQL.WriteString(") FROM (")
|
|||
|
} else {
|
|||
|
db.Statement.SQL.WriteString("SELECT * FROM (")
|
|||
|
}
|
|||
|
for i, name := range tables {
|
|||
|
db.Statement.SQL.WriteByte('(')
|
|||
|
if tableName, ok = selectStmt.FromItems.(*sqlparser.TableName); ok {
|
|||
|
tableName.Name.Name = name
|
|||
|
}
|
|||
|
db.Statement.SQL.WriteString(selectStmt.String())
|
|||
|
db.Statement.SQL.WriteByte(')')
|
|||
|
if i < numOfTable-1 {
|
|||
|
if plugin.UnionAll {
|
|||
|
db.Statement.SQL.WriteString(" UNION ALL ")
|
|||
|
} else {
|
|||
|
db.Statement.SQL.WriteString(" UNION ")
|
|||
|
}
|
|||
|
}
|
|||
|
if i > 0 {
|
|||
|
//copy vars
|
|||
|
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
|
|||
|
}
|
|||
|
}
|
|||
|
db.Statement.SQL.WriteString(") tbl ")
|
|||
|
if !isCountStatement {
|
|||
|
if len(groupingExpr) > 0 {
|
|||
|
db.Statement.SQL.WriteString(" GROUP BY ")
|
|||
|
for i, expr := range groupingExpr {
|
|||
|
if i != 0 {
|
|||
|
db.Statement.SQL.WriteString(", ")
|
|||
|
}
|
|||
|
db.Statement.SQL.WriteString(expr.String())
|
|||
|
}
|
|||
|
|
|||
|
if havingExpr != nil {
|
|||
|
db.Statement.SQL.WriteString(" HAVING ")
|
|||
|
db.Statement.SQL.WriteString(havingExpr.String())
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
if orderByExpr != nil && len(orderByExpr) > 0 {
|
|||
|
db.Statement.SQL.WriteString(" ORDER BY ")
|
|||
|
for i, term := range orderByExpr {
|
|||
|
if i != 0 {
|
|||
|
db.Statement.SQL.WriteString(", ")
|
|||
|
}
|
|||
|
db.Statement.SQL.WriteString(term.String())
|
|||
|
}
|
|||
|
}
|
|||
|
if limitExpr != nil {
|
|||
|
db.Statement.SQL.WriteString(" LIMIT ")
|
|||
|
db.Statement.SQL.WriteString(limitExpr.String())
|
|||
|
if offsetExpr != nil {
|
|||
|
db.Statement.SQL.WriteString(" OFFSET ")
|
|||
|
db.Statement.SQL.WriteString(offsetExpr.String())
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
func (plugin *Sharding) QueryX(db *gorm.DB) {
|
|||
|
var (
|
|||
|
err error
|
|||
|
ok bool
|
|||
|
shardingModel Model
|
|||
|
modelValue any
|
|||
|
tables []string
|
|||
|
rawVars []any
|
|||
|
refValue reflect.Value
|
|||
|
selectStmt *sqlparserX.Select
|
|||
|
stmt sqlparserX.Statement
|
|||
|
numOfTable int
|
|||
|
isCountStatement bool
|
|||
|
isPureCountStatement bool
|
|||
|
trackedBuffer *sqlparserX.TrackedBuffer
|
|||
|
funcExpr *sqlparserX.FuncExpr
|
|||
|
aliasedExpr *sqlparserX.AliasedExpr
|
|||
|
orderByExpr sqlparserX.OrderBy
|
|||
|
groupByExpr sqlparserX.GroupBy
|
|||
|
havingExpr *sqlparserX.Where
|
|||
|
limitExpr *sqlparserX.Limit
|
|||
|
)
|
|||
|
if db.Statement.Model != nil {
|
|||
|
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
|
|||
|
} else {
|
|||
|
if !db.Statement.ReflectValue.IsValid() {
|
|||
|
return
|
|||
|
}
|
|||
|
refValue = reflect.New(db.Statement.ReflectValue.Type())
|
|||
|
}
|
|||
|
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
|
|||
|
refValue = reflect.Indirect(refValue)
|
|||
|
}
|
|||
|
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
|
|||
|
elemType := refValue.Type().Elem()
|
|||
|
if elemType.Kind() == reflect.Ptr {
|
|||
|
elemType = elemType.Elem()
|
|||
|
}
|
|||
|
modelValue = reflect.New(elemType).Interface()
|
|||
|
} else {
|
|||
|
modelValue = refValue.Interface()
|
|||
|
}
|
|||
|
if shardingModel, ok = modelValue.(Model); !ok {
|
|||
|
return
|
|||
|
}
|
|||
|
if db.Statement.SQL.Len() == 0 {
|
|||
|
callbacks.BuildQuerySQL(db)
|
|||
|
}
|
|||
|
if stmt, err = sqlparserX.Parse(db.Statement.SQL.String()); err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
if selectStmt, ok = stmt.(*sqlparserX.Select); !ok {
|
|||
|
return
|
|||
|
}
|
|||
|
tables = shardingModel.ShardingTables(&Scope{
|
|||
|
db: db,
|
|||
|
stmtX: selectStmt,
|
|||
|
})
|
|||
|
numOfTable = len(tables)
|
|||
|
if numOfTable <= 1 {
|
|||
|
return
|
|||
|
}
|
|||
|
// 保存值
|
|||
|
rawVars = make([]any, 0, len(db.Statement.Vars))
|
|||
|
for _, v := range db.Statement.Vars {
|
|||
|
rawVars = append(rawVars, v)
|
|||
|
}
|
|||
|
// 替换语句
|
|||
|
if selectStmt.OrderBy != nil {
|
|||
|
orderByExpr = selectStmt.OrderBy
|
|||
|
selectStmt.OrderBy = nil
|
|||
|
}
|
|||
|
if selectStmt.GroupBy != nil {
|
|||
|
groupByExpr = selectStmt.GroupBy
|
|||
|
//selectStmt.GroupBy = nil
|
|||
|
}
|
|||
|
if selectStmt.Having != nil {
|
|||
|
havingExpr = selectStmt.Having
|
|||
|
//selectStmt.Having = nil
|
|||
|
}
|
|||
|
if selectStmt.Limit != nil {
|
|||
|
limitExpr = selectStmt.Limit
|
|||
|
selectStmt.Limit = nil
|
|||
|
}
|
|||
|
// 检查是否为COUNT语句
|
|||
|
//如果不是count的语句,添加order和group的支持
|
|||
|
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
|
|||
|
//这里处理的是报表的情况,再报表里面需要重写COUNT语句才能获取到正确的数量
|
|||
|
if v == true {
|
|||
|
isCountStatement = true
|
|||
|
isPureCountStatement = true
|
|||
|
}
|
|||
|
}
|
|||
|
//常规的COUNT逻辑
|
|||
|
if !isCountStatement && len(selectStmt.SelectExprs) == 1 {
|
|||
|
for _, expr := range selectStmt.SelectExprs {
|
|||
|
if aliasedExpr, ok = expr.(*sqlparserX.AliasedExpr); ok {
|
|||
|
if funcExpr, ok = aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok {
|
|||
|
if funcExpr.Name.EqualString("count") {
|
|||
|
isCountStatement = true
|
|||
|
break
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
// 重写SQL
|
|||
|
db.Statement.SQL.Reset()
|
|||
|
trackedBuffer = sqlparserX.NewTrackedBuffer(nil)
|
|||
|
|
|||
|
if isCountStatement {
|
|||
|
if isPureCountStatement {
|
|||
|
db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (")
|
|||
|
} else {
|
|||
|
db.Statement.SQL.WriteString("SELECT SUM(")
|
|||
|
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
|||
|
db.Statement.SQL.WriteString(strings.Trim("count(*)", "`"))
|
|||
|
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
|||
|
db.Statement.SQL.WriteString(") FROM (")
|
|||
|
}
|
|||
|
} else {
|
|||
|
if bs, ok := modelValue.(SelectBuilder); ok {
|
|||
|
columns := bs.BuildSelect(db.Statement.Context, selectStmt.SelectExprs)
|
|||
|
if len(columns) > 0 {
|
|||
|
db.Statement.SQL.WriteString("SELECT " + strings.Join(columns, ",") + " FROM (")
|
|||
|
} else {
|
|||
|
db.Statement.SQL.WriteString("SELECT * FROM (")
|
|||
|
}
|
|||
|
} else {
|
|||
|
db.Statement.SQL.WriteString("SELECT * FROM (")
|
|||
|
}
|
|||
|
}
|
|||
|
for i, name := range tables {
|
|||
|
trackedBuffer.Reset()
|
|||
|
db.Statement.SQL.WriteByte('(')
|
|||
|
//赋值新的表名称
|
|||
|
selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{
|
|||
|
Expr: sqlparserX.TableName{
|
|||
|
Name: sqlparserX.NewTableIdent(name),
|
|||
|
},
|
|||
|
}}
|
|||
|
selectStmt.Format(trackedBuffer)
|
|||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
|||
|
db.Statement.SQL.WriteByte(')')
|
|||
|
if i < numOfTable-1 {
|
|||
|
if plugin.UnionAll {
|
|||
|
db.Statement.SQL.WriteString(" UNION ALL ")
|
|||
|
} else {
|
|||
|
db.Statement.SQL.WriteString(" UNION ")
|
|||
|
}
|
|||
|
}
|
|||
|
if i > 0 {
|
|||
|
//copy vars
|
|||
|
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
|
|||
|
}
|
|||
|
}
|
|||
|
db.Statement.SQL.WriteString(") tbl ")
|
|||
|
if !isCountStatement {
|
|||
|
//node.GroupBy, node.Having, node.OrderBy, node.Limit
|
|||
|
if groupByExpr != nil {
|
|||
|
trackedBuffer.Reset()
|
|||
|
groupByExpr.Format(trackedBuffer)
|
|||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
|||
|
}
|
|||
|
if havingExpr != nil {
|
|||
|
trackedBuffer.Reset()
|
|||
|
havingExpr.Format(trackedBuffer)
|
|||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
|||
|
}
|
|||
|
if orderByExpr != nil {
|
|||
|
trackedBuffer.Reset()
|
|||
|
orderByExpr.Format(trackedBuffer)
|
|||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
|||
|
}
|
|||
|
if limitExpr != nil {
|
|||
|
trackedBuffer.Reset()
|
|||
|
limitExpr.Format(trackedBuffer)
|
|||
|
db.Statement.SQL.WriteString(trackedBuffer.String())
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func New() *Sharding {
|
|||
|
return &Sharding{
|
|||
|
UnionAll: true,
|
|||
|
QuoteChar: '`',
|
|||
|
}
|
|||
|
}
|