477 lines
13 KiB
Go
477 lines
13 KiB
Go
package sharding
|
||
|
||
import (
|
||
"github.com/longbridgeapp/sqlparser"
|
||
sqlparserX "github.com/uole/sqlparser"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/callbacks"
|
||
"reflect"
|
||
"strings"
|
||
)
|
||
|
||
// Sharding is a GORM plugin. Its write hooks switch the statement table to
// the shard table reported by the model, and its query hook rewrites a SELECT
// into a UNION over every shard table when the model reports more than one.
type Sharding struct {
	// UnionAll selects the operator that joins the per-shard sub-queries:
	// UNION ALL when true, UNION when false.
	UnionAll bool

	// QuoteChar is the identifier quote byte written around the aggregated
	// column name in rewritten COUNT queries (New sets it to a back-quote).
	QuoteChar byte
}
|
||
|
||
func (plugin *Sharding) Name() string {
|
||
return "gorm:sharding"
|
||
}
|
||
|
||
func (plugin *Sharding) Initialize(db *gorm.DB) (err error) {
|
||
if err = db.Callback().Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil {
|
||
return
|
||
}
|
||
if err = db.Callback().Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil {
|
||
return
|
||
}
|
||
if err = db.Callback().Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil {
|
||
return
|
||
}
|
||
if err = db.Callback().Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX); err != nil {
|
||
return err
|
||
}
|
||
return
|
||
}
|
||
|
||
func (plugin *Sharding) Create(db *gorm.DB) {
|
||
var (
|
||
ok bool
|
||
scopeModel Model
|
||
refValue reflect.Value
|
||
modelValue any
|
||
)
|
||
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||
if db.Statement.ReflectValue.Len() > 0 {
|
||
refValue = db.Statement.ReflectValue.Index(0)
|
||
modelValue = refValue.Interface()
|
||
}
|
||
} else {
|
||
if db.Statement.Model != nil {
|
||
modelValue = db.Statement.Model
|
||
} else {
|
||
modelValue = db.Statement.ReflectValue.Interface()
|
||
}
|
||
}
|
||
if modelValue != nil {
|
||
if scopeModel, ok = modelValue.(Model); ok {
|
||
db.Table(scopeModel.ShardingTable(sceneCreate))
|
||
}
|
||
}
|
||
}
|
||
|
||
func (plugin *Sharding) Update(db *gorm.DB) {
|
||
var (
|
||
ok bool
|
||
scopeModel Model
|
||
refValue reflect.Value
|
||
modelValue any
|
||
)
|
||
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||
if db.Statement.ReflectValue.Len() > 0 {
|
||
refValue = db.Statement.ReflectValue.Index(0)
|
||
modelValue = refValue.Interface()
|
||
}
|
||
} else {
|
||
if db.Statement.Model != nil {
|
||
modelValue = db.Statement.Model
|
||
} else {
|
||
modelValue = db.Statement.ReflectValue.Interface()
|
||
}
|
||
}
|
||
if modelValue != nil {
|
||
if scopeModel, ok = modelValue.(Model); ok {
|
||
db.Table(scopeModel.ShardingTable(sceneUpdate))
|
||
}
|
||
}
|
||
}
|
||
|
||
func (plugin *Sharding) Delete(db *gorm.DB) {
|
||
var (
|
||
ok bool
|
||
scopeModel Model
|
||
refValue reflect.Value
|
||
modelValue any
|
||
)
|
||
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
|
||
if db.Statement.ReflectValue.Len() > 0 {
|
||
refValue = db.Statement.ReflectValue.Index(0)
|
||
modelValue = refValue.Interface()
|
||
}
|
||
} else {
|
||
if db.Statement.Model != nil {
|
||
modelValue = db.Statement.Model
|
||
} else {
|
||
modelValue = db.Statement.ReflectValue.Interface()
|
||
}
|
||
}
|
||
if modelValue != nil {
|
||
if scopeModel, ok = modelValue.(Model); ok {
|
||
db.Table(scopeModel.ShardingTable(sceneDelete))
|
||
}
|
||
}
|
||
}
|
||
|
||
func (plugin *Sharding) Query(db *gorm.DB) {
|
||
var (
|
||
err error
|
||
ok bool
|
||
shardingModel Model
|
||
modelValue any
|
||
tables []string
|
||
rawVars []any
|
||
refValue reflect.Value
|
||
tableName *sqlparser.TableName
|
||
selectStmt *sqlparser.SelectStatement
|
||
stmt sqlparser.Statement
|
||
parser *sqlparser.Parser
|
||
numOfTable int
|
||
orderByExpr []*sqlparser.OrderingTerm
|
||
limitExpr sqlparser.Expr
|
||
offsetExpr sqlparser.Expr
|
||
groupingExpr []sqlparser.Expr
|
||
havingExpr sqlparser.Expr
|
||
isCountStatement bool
|
||
countField string
|
||
)
|
||
if db.Statement.Model != nil {
|
||
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
|
||
} else {
|
||
refValue = reflect.New(db.Statement.ReflectValue.Type())
|
||
}
|
||
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
|
||
refValue = reflect.Indirect(refValue)
|
||
}
|
||
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
|
||
elemType := refValue.Type().Elem()
|
||
if elemType.Kind() == reflect.Ptr {
|
||
elemType = elemType.Elem()
|
||
}
|
||
modelValue = reflect.New(elemType).Interface()
|
||
} else {
|
||
modelValue = refValue.Interface()
|
||
}
|
||
if shardingModel, ok = modelValue.(Model); !ok {
|
||
return
|
||
}
|
||
if db.Statement.SQL.Len() == 0 {
|
||
callbacks.BuildQuerySQL(db)
|
||
}
|
||
parser = sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String()))
|
||
if stmt, err = parser.ParseStatement(); err != nil {
|
||
return
|
||
}
|
||
if selectStmt, ok = stmt.(*sqlparser.SelectStatement); !ok {
|
||
return
|
||
}
|
||
tables = shardingModel.ShardingTables(&Scope{
|
||
db: db,
|
||
stmt: selectStmt,
|
||
})
|
||
numOfTable = len(tables)
|
||
if numOfTable <= 1 {
|
||
return
|
||
}
|
||
rawVars = make([]any, 0, len(db.Statement.Vars))
|
||
for _, v := range db.Statement.Vars {
|
||
rawVars = append(rawVars, v)
|
||
}
|
||
//是否是查询count语句
|
||
//如果不是count的语句,添加order和group的支持
|
||
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
|
||
if v == true {
|
||
isCountStatement = true
|
||
}
|
||
}
|
||
if !isCountStatement && len(*selectStmt.Columns) == 1 {
|
||
for _, column := range *selectStmt.Columns {
|
||
if expr, ok := column.Expr.(*sqlparser.Call); ok {
|
||
if expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword {
|
||
isCountStatement = true
|
||
countField = expr.String()
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if len(selectStmt.OrderBy) > 0 {
|
||
orderByExpr = make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy))
|
||
for _, row := range selectStmt.OrderBy {
|
||
orderByExpr = append(orderByExpr, row)
|
||
}
|
||
selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0)
|
||
}
|
||
if len(selectStmt.GroupingElements) > 0 {
|
||
groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements))
|
||
for _, row := range selectStmt.GroupingElements {
|
||
groupingExpr = append(groupingExpr, row)
|
||
}
|
||
if selectStmt.HavingCondition != nil {
|
||
havingExpr = selectStmt.HavingCondition
|
||
}
|
||
}
|
||
if selectStmt.Limit != nil {
|
||
limitExpr = selectStmt.Limit
|
||
selectStmt.Limit = nil
|
||
}
|
||
if selectStmt.Offset != nil {
|
||
offsetExpr = selectStmt.Offset
|
||
selectStmt.Offset = nil
|
||
}
|
||
db.Statement.SQL.Reset()
|
||
if isCountStatement {
|
||
db.Statement.SQL.WriteString("SELECT SUM(")
|
||
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||
db.Statement.SQL.WriteString(strings.Trim(countField, "`"))
|
||
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||
db.Statement.SQL.WriteString(") FROM (")
|
||
} else {
|
||
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||
}
|
||
for i, name := range tables {
|
||
db.Statement.SQL.WriteByte('(')
|
||
if tableName, ok = selectStmt.FromItems.(*sqlparser.TableName); ok {
|
||
tableName.Name.Name = name
|
||
}
|
||
db.Statement.SQL.WriteString(selectStmt.String())
|
||
db.Statement.SQL.WriteByte(')')
|
||
if i < numOfTable-1 {
|
||
if plugin.UnionAll {
|
||
db.Statement.SQL.WriteString(" UNION ALL ")
|
||
} else {
|
||
db.Statement.SQL.WriteString(" UNION ")
|
||
}
|
||
}
|
||
if i > 0 {
|
||
//copy vars
|
||
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
|
||
}
|
||
}
|
||
db.Statement.SQL.WriteString(") tbl ")
|
||
if !isCountStatement {
|
||
if len(groupingExpr) > 0 {
|
||
db.Statement.SQL.WriteString(" GROUP BY ")
|
||
for i, expr := range groupingExpr {
|
||
if i != 0 {
|
||
db.Statement.SQL.WriteString(", ")
|
||
}
|
||
db.Statement.SQL.WriteString(expr.String())
|
||
}
|
||
|
||
if havingExpr != nil {
|
||
db.Statement.SQL.WriteString(" HAVING ")
|
||
db.Statement.SQL.WriteString(havingExpr.String())
|
||
}
|
||
}
|
||
|
||
if orderByExpr != nil && len(orderByExpr) > 0 {
|
||
db.Statement.SQL.WriteString(" ORDER BY ")
|
||
for i, term := range orderByExpr {
|
||
if i != 0 {
|
||
db.Statement.SQL.WriteString(", ")
|
||
}
|
||
db.Statement.SQL.WriteString(term.String())
|
||
}
|
||
}
|
||
if limitExpr != nil {
|
||
db.Statement.SQL.WriteString(" LIMIT ")
|
||
db.Statement.SQL.WriteString(limitExpr.String())
|
||
if offsetExpr != nil {
|
||
db.Statement.SQL.WriteString(" OFFSET ")
|
||
db.Statement.SQL.WriteString(offsetExpr.String())
|
||
}
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
func (plugin *Sharding) QueryX(db *gorm.DB) {
|
||
var (
|
||
err error
|
||
ok bool
|
||
shardingModel Model
|
||
modelValue any
|
||
tables []string
|
||
rawVars []any
|
||
refValue reflect.Value
|
||
selectStmt *sqlparserX.Select
|
||
stmt sqlparserX.Statement
|
||
numOfTable int
|
||
isCountStatement bool
|
||
isPureCountStatement bool
|
||
trackedBuffer *sqlparserX.TrackedBuffer
|
||
funcExpr *sqlparserX.FuncExpr
|
||
aliasedExpr *sqlparserX.AliasedExpr
|
||
orderByExpr sqlparserX.OrderBy
|
||
groupByExpr sqlparserX.GroupBy
|
||
havingExpr *sqlparserX.Where
|
||
limitExpr *sqlparserX.Limit
|
||
)
|
||
if db.Statement.Model != nil {
|
||
refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
|
||
} else {
|
||
if !db.Statement.ReflectValue.IsValid() {
|
||
return
|
||
}
|
||
refValue = reflect.New(db.Statement.ReflectValue.Type())
|
||
}
|
||
if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
|
||
refValue = reflect.Indirect(refValue)
|
||
}
|
||
if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
|
||
elemType := refValue.Type().Elem()
|
||
if elemType.Kind() == reflect.Ptr {
|
||
elemType = elemType.Elem()
|
||
}
|
||
modelValue = reflect.New(elemType).Interface()
|
||
} else {
|
||
modelValue = refValue.Interface()
|
||
}
|
||
if shardingModel, ok = modelValue.(Model); !ok {
|
||
return
|
||
}
|
||
if db.Statement.SQL.Len() == 0 {
|
||
callbacks.BuildQuerySQL(db)
|
||
}
|
||
if stmt, err = sqlparserX.Parse(db.Statement.SQL.String()); err != nil {
|
||
return
|
||
}
|
||
if selectStmt, ok = stmt.(*sqlparserX.Select); !ok {
|
||
return
|
||
}
|
||
tables = shardingModel.ShardingTables(&Scope{
|
||
db: db,
|
||
stmtX: selectStmt,
|
||
})
|
||
numOfTable = len(tables)
|
||
if numOfTable <= 1 {
|
||
return
|
||
}
|
||
// 保存值
|
||
rawVars = make([]any, 0, len(db.Statement.Vars))
|
||
for _, v := range db.Statement.Vars {
|
||
rawVars = append(rawVars, v)
|
||
}
|
||
// 替换语句
|
||
if selectStmt.OrderBy != nil {
|
||
orderByExpr = selectStmt.OrderBy
|
||
selectStmt.OrderBy = nil
|
||
}
|
||
if selectStmt.GroupBy != nil {
|
||
groupByExpr = selectStmt.GroupBy
|
||
//selectStmt.GroupBy = nil
|
||
}
|
||
if selectStmt.Having != nil {
|
||
havingExpr = selectStmt.Having
|
||
//selectStmt.Having = nil
|
||
}
|
||
if selectStmt.Limit != nil {
|
||
limitExpr = selectStmt.Limit
|
||
selectStmt.Limit = nil
|
||
}
|
||
// 检查是否为COUNT语句
|
||
//如果不是count的语句,添加order和group的支持
|
||
if v := db.Statement.Context.Value("@sql_count_statement"); v != nil {
|
||
//这里处理的是报表的情况,再报表里面需要重写COUNT语句才能获取到正确的数量
|
||
if v == true {
|
||
isCountStatement = true
|
||
isPureCountStatement = true
|
||
}
|
||
}
|
||
//常规的COUNT逻辑
|
||
if !isCountStatement && len(selectStmt.SelectExprs) == 1 {
|
||
for _, expr := range selectStmt.SelectExprs {
|
||
if aliasedExpr, ok = expr.(*sqlparserX.AliasedExpr); ok {
|
||
if funcExpr, ok = aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok {
|
||
if funcExpr.Name.EqualString("count") {
|
||
isCountStatement = true
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
// 重写SQL
|
||
db.Statement.SQL.Reset()
|
||
trackedBuffer = sqlparserX.NewTrackedBuffer(nil)
|
||
|
||
if isCountStatement {
|
||
if isPureCountStatement {
|
||
db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (")
|
||
} else {
|
||
db.Statement.SQL.WriteString("SELECT SUM(")
|
||
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||
db.Statement.SQL.WriteString(strings.Trim("count(*)", "`"))
|
||
db.Statement.SQL.WriteByte(plugin.QuoteChar)
|
||
db.Statement.SQL.WriteString(") FROM (")
|
||
}
|
||
} else {
|
||
if bs, ok := modelValue.(SelectBuilder); ok {
|
||
columns := bs.BuildSelect(db.Statement.Context, selectStmt.SelectExprs)
|
||
if len(columns) > 0 {
|
||
db.Statement.SQL.WriteString("SELECT " + strings.Join(columns, ",") + " FROM (")
|
||
} else {
|
||
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||
}
|
||
} else {
|
||
db.Statement.SQL.WriteString("SELECT * FROM (")
|
||
}
|
||
}
|
||
for i, name := range tables {
|
||
trackedBuffer.Reset()
|
||
db.Statement.SQL.WriteByte('(')
|
||
//赋值新的表名称
|
||
selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{
|
||
Expr: sqlparserX.TableName{
|
||
Name: sqlparserX.NewTableIdent(name),
|
||
},
|
||
}}
|
||
selectStmt.Format(trackedBuffer)
|
||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||
db.Statement.SQL.WriteByte(')')
|
||
if i < numOfTable-1 {
|
||
if plugin.UnionAll {
|
||
db.Statement.SQL.WriteString(" UNION ALL ")
|
||
} else {
|
||
db.Statement.SQL.WriteString(" UNION ")
|
||
}
|
||
}
|
||
if i > 0 {
|
||
//copy vars
|
||
db.Statement.Vars = append(db.Statement.Vars, rawVars...)
|
||
}
|
||
}
|
||
db.Statement.SQL.WriteString(") tbl ")
|
||
if !isCountStatement {
|
||
//node.GroupBy, node.Having, node.OrderBy, node.Limit
|
||
if groupByExpr != nil {
|
||
trackedBuffer.Reset()
|
||
groupByExpr.Format(trackedBuffer)
|
||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||
}
|
||
if havingExpr != nil {
|
||
trackedBuffer.Reset()
|
||
havingExpr.Format(trackedBuffer)
|
||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||
}
|
||
if orderByExpr != nil {
|
||
trackedBuffer.Reset()
|
||
orderByExpr.Format(trackedBuffer)
|
||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||
}
|
||
if limitExpr != nil {
|
||
trackedBuffer.Reset()
|
||
limitExpr.Format(trackedBuffer)
|
||
db.Statement.SQL.WriteString(trackedBuffer.String())
|
||
}
|
||
}
|
||
}
|
||
|
||
func New() *Sharding {
|
||
return &Sharding{
|
||
UnionAll: true,
|
||
QuoteChar: '`',
|
||
}
|
||
}
|