rest/plugins/sharding/sharding.go

477 lines
13 KiB
Go
Raw Permalink Normal View History

2024-12-11 17:29:01 +08:00
package sharding
import (
"github.com/longbridgeapp/sqlparser"
sqlparserX "github.com/uole/sqlparser"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"reflect"
"strings"
)
type Sharding struct {
UnionAll bool
QuoteChar byte
}
// Name returns the identifier under which gorm registers this plugin.
func (plugin *Sharding) Name() string {
	const pluginName = "gorm:sharding"
	return pluginName
}
// Initialize registers the sharding callbacks on db so they run before
// gorm's own create, update, delete and query callbacks.
//
// Create, Update and Delete select the shard table for write statements.
// QueryX (not Query) rewrites multi-table SELECTs. Registration happens in
// that order and stops at the first error, which is returned.
func (plugin *Sharding) Initialize(db *gorm.DB) (err error) {
	registry := db.Callback()
	steps := []func() error{
		func() error {
			return registry.Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create)
		},
		func() error {
			return registry.Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update)
		},
		func() error {
			return registry.Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete)
		},
		func() error {
			return registry.Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX)
		},
	}
	for _, step := range steps {
		if err = step(); err != nil {
			return err
		}
	}
	return nil
}
func (plugin *Sharding) Create(db *gorm.DB) {
var (
ok bool
scopeModel Model
refValue reflect.Value
modelValue any
)
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
if db.Statement.ReflectValue.Len() > 0 {
refValue = db.Statement.ReflectValue.Index(0)
modelValue = refValue.Interface()
}
} else {
if db.Statement.Model != nil {
modelValue = db.Statement.Model
} else {
modelValue = db.Statement.ReflectValue.Interface()
}
}
if modelValue != nil {
if scopeModel, ok = modelValue.(Model); ok {
db.Table(scopeModel.ShardingTable(sceneCreate))
}
}
}
func (plugin *Sharding) Update(db *gorm.DB) {
var (
ok bool
scopeModel Model
refValue reflect.Value
modelValue any
)
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
if db.Statement.ReflectValue.Len() > 0 {
refValue = db.Statement.ReflectValue.Index(0)
modelValue = refValue.Interface()
}
} else {
if db.Statement.Model != nil {
modelValue = db.Statement.Model
} else {
modelValue = db.Statement.ReflectValue.Interface()
}
}
if modelValue != nil {
if scopeModel, ok = modelValue.(Model); ok {
db.Table(scopeModel.ShardingTable(sceneUpdate))
}
}
}
func (plugin *Sharding) Delete(db *gorm.DB) {
var (
ok bool
scopeModel Model
refValue reflect.Value
modelValue any
)
if db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Array {
if db.Statement.ReflectValue.Len() > 0 {
refValue = db.Statement.ReflectValue.Index(0)
modelValue = refValue.Interface()
}
} else {
if db.Statement.Model != nil {
modelValue = db.Statement.Model
} else {
modelValue = db.Statement.ReflectValue.Interface()
}
}
if modelValue != nil {
if scopeModel, ok = modelValue.(Model); ok {
db.Table(scopeModel.ShardingTable(sceneDelete))
}
}
}
// Query is the earlier sharded-query rewrite built on
// github.com/longbridgeapp/sqlparser. Initialize registers QueryX, not Query,
// as the query callback.
//
// When the statement's model implements Model and ShardingTables returns more
// than one table, the SELECT is rewritten into
//
//	SELECT ... FROM ((<select on t1>) UNION [ALL] (<select on t2>) ...) tbl
//
// ORDER BY, LIMIT and OFFSET move to the outer query. GROUP BY / HAVING stay
// on each per-table SELECT and are repeated on the outer query.
// Statement.Vars is repeated once per additional table.
//
// COUNT handling:
//   - context value "@sql_count_statement" == true: the outer query is
//     COUNT(*) over the union of the full per-table SELECTs;
//   - a single count(*) column: the per-table counts are summed. This always
//     uses UNION ALL so shards reporting the same count are not collapsed.
//
// Parse failures, non-SELECT statements, non-sharding models, a single target
// table, and FROM clauses other than one plain table leave the statement
// untouched.
func (plugin *Sharding) Query(db *gorm.DB) {
	var (
		err                  error
		ok                   bool
		shardingModel        Model
		modelValue           any
		tables               []string
		rawVars              []any
		refValue             reflect.Value
		tableName            *sqlparser.TableName
		selectStmt           *sqlparser.SelectStatement
		stmt                 sqlparser.Statement
		parser               *sqlparser.Parser
		numOfTable           int
		orderByExpr          []*sqlparser.OrderingTerm
		limitExpr            sqlparser.Expr
		offsetExpr           sqlparser.Expr
		groupingExpr         []sqlparser.Expr
		havingExpr           sqlparser.Expr
		isCountStatement     bool
		isPureCountStatement bool
		countField           string
	)
	if db.Statement.Model != nil {
		refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
	} else {
		// Without a model or destination there is nothing to shard, and
		// Type() would panic on the zero reflect.Value.
		if !db.Statement.ReflectValue.IsValid() {
			return
		}
		refValue = reflect.New(db.Statement.ReflectValue.Type())
	}
	// reflect.New yields *T; for non-struct T (e.g. a slice destination) work with T itself.
	if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
		refValue = reflect.Indirect(refValue)
	}
	if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
		elemType := refValue.Type().Elem()
		if elemType.Kind() == reflect.Ptr {
			elemType = elemType.Elem()
		}
		modelValue = reflect.New(elemType).Interface()
	} else {
		modelValue = refValue.Interface()
	}
	if shardingModel, ok = modelValue.(Model); !ok {
		return
	}
	if db.Statement.SQL.Len() == 0 {
		callbacks.BuildQuerySQL(db)
	}
	parser = sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String()))
	if stmt, err = parser.ParseStatement(); err != nil {
		return
	}
	if selectStmt, ok = stmt.(*sqlparser.SelectStatement); !ok {
		return
	}
	tables = shardingModel.ShardingTables(&Scope{
		db:   db,
		stmt: selectStmt,
	})
	numOfTable = len(tables)
	if numOfTable <= 1 {
		return
	}
	// Retargeting works by renaming a single FROM table. Any other FROM shape
	// would be repeated unchanged for every table, duplicating rows, so leave
	// such statements alone.
	if tableName, ok = selectStmt.FromItems.(*sqlparser.TableName); !ok {
		return
	}
	// Arguments of one per-table SELECT; repeated for every additional table below.
	rawVars = append(make([]any, 0, len(db.Statement.Vars)), db.Statement.Vars...)
	// A true "@sql_count_statement" asks for the number of rows the whole
	// query returns across all tables.
	if v := db.Statement.Context.Value("@sql_count_statement"); v == true {
		isCountStatement = true
		isPureCountStatement = true
	}
	// Otherwise a lone count(*) column is a count query whose per-table results are summed.
	if !isCountStatement && len(*selectStmt.Columns) == 1 {
		for _, column := range *selectStmt.Columns {
			if expr, ok := column.Expr.(*sqlparser.Call); ok {
				if expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword {
					isCountStatement = true
					countField = expr.String()
					break
				}
			}
		}
	}
	// ORDER BY, LIMIT and OFFSET are lifted out of the per-table SELECTs so
	// they apply across all tables. GROUP BY / HAVING are only copied.
	if len(selectStmt.OrderBy) > 0 {
		orderByExpr = append(make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy)), selectStmt.OrderBy...)
		selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0)
	}
	if len(selectStmt.GroupingElements) > 0 {
		groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements))
		for _, row := range selectStmt.GroupingElements {
			groupingExpr = append(groupingExpr, row)
		}
		if selectStmt.HavingCondition != nil {
			havingExpr = selectStmt.HavingCondition
		}
	}
	if selectStmt.Limit != nil {
		limitExpr = selectStmt.Limit
		selectStmt.Limit = nil
	}
	if selectStmt.Offset != nil {
		offsetExpr = selectStmt.Offset
		selectStmt.Offset = nil
	}
	db.Statement.SQL.Reset()
	switch {
	case isPureCountStatement:
		db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (")
	case isCountStatement:
		db.Statement.SQL.WriteString("SELECT SUM(")
		db.Statement.SQL.WriteByte(plugin.QuoteChar)
		db.Statement.SQL.WriteString(strings.Trim(countField, "`"))
		db.Statement.SQL.WriteByte(plugin.QuoteChar)
		db.Statement.SQL.WriteString(") FROM (")
	default:
		db.Statement.SQL.WriteString("SELECT * FROM (")
	}
	unionKeyword := " UNION "
	if plugin.UnionAll || (isCountStatement && !isPureCountStatement) {
		// Summed counts need every per-table row; UNION would collapse
		// tables that happen to report the same count.
		unionKeyword = " UNION ALL "
	}
	for i, name := range tables {
		tableName.Name.Name = name
		db.Statement.SQL.WriteByte('(')
		db.Statement.SQL.WriteString(selectStmt.String())
		db.Statement.SQL.WriteByte(')')
		if i < numOfTable-1 {
			db.Statement.SQL.WriteString(unionKeyword)
		}
		if i > 0 {
			// Each extra copy of the SELECT needs its own copy of the arguments.
			db.Statement.Vars = append(db.Statement.Vars, rawVars...)
		}
	}
	db.Statement.SQL.WriteString(") tbl ")
	if isCountStatement {
		return
	}
	if len(groupingExpr) > 0 {
		db.Statement.SQL.WriteString(" GROUP BY ")
		for i, expr := range groupingExpr {
			if i != 0 {
				db.Statement.SQL.WriteString(", ")
			}
			db.Statement.SQL.WriteString(expr.String())
		}
		if havingExpr != nil {
			db.Statement.SQL.WriteString(" HAVING ")
			db.Statement.SQL.WriteString(havingExpr.String())
		}
	}
	if len(orderByExpr) > 0 {
		db.Statement.SQL.WriteString(" ORDER BY ")
		for i, term := range orderByExpr {
			if i != 0 {
				db.Statement.SQL.WriteString(", ")
			}
			db.Statement.SQL.WriteString(term.String())
		}
	}
	if limitExpr != nil {
		db.Statement.SQL.WriteString(" LIMIT ")
		db.Statement.SQL.WriteString(limitExpr.String())
		if offsetExpr != nil {
			db.Statement.SQL.WriteString(" OFFSET ")
			db.Statement.SQL.WriteString(offsetExpr.String())
		}
	}
}
// QueryX is the query callback registered by Initialize (before "gorm:query").
//
// When the statement's model implements Model and ShardingTables returns more
// than one table, gorm's SELECT is rewritten into
//
//	SELECT <cols> FROM ((<select on t1>) UNION [ALL] (<select on t2>) ...) tbl
//	  [GROUP BY ...] [HAVING ...] [ORDER BY ...] [LIMIT ...]
//
// ORDER BY and LIMIT are removed from the per-table SELECTs and applied to
// the outer query, so sorting and paging span all tables. GROUP BY and
// HAVING stay on each per-table SELECT and are also applied outside.
// Statement.Vars is repeated once per additional table.
//
// COUNT handling:
//   - context value "@sql_count_statement" == true: the outer query is
//     COUNT(*) over the union of the full per-table SELECTs;
//   - a single count(...) select expression: the per-table counts are summed.
//     This always uses UNION ALL, and the count column is referenced through
//     its alias (one is added when absent).
//
// Parse failures, non-SELECT statements, non-sharding models, or a single
// target table leave the statement untouched.
func (plugin *Sharding) QueryX(db *gorm.DB) {
	// Alias given to an unaliased per-table count(...) column so the outer
	// SUM can reference it by a known name, whatever form the count takes.
	const countAlias = "sharding_count"
	var (
		err                  error
		ok                   bool
		shardingModel        Model
		modelValue           any
		tables               []string
		rawVars              []any
		refValue             reflect.Value
		selectStmt           *sqlparserX.Select
		stmt                 sqlparserX.Statement
		numOfTable           int
		isCountStatement     bool
		isPureCountStatement bool
		countColumn          string
		trackedBuffer        *sqlparserX.TrackedBuffer
		funcExpr             *sqlparserX.FuncExpr
		aliasedExpr          *sqlparserX.AliasedExpr
		orderByExpr          sqlparserX.OrderBy
		groupByExpr          sqlparserX.GroupBy
		havingExpr           *sqlparserX.Where
		limitExpr            *sqlparserX.Limit
	)
	// Build a fresh value of the statement's model type so its Model methods can be consulted.
	if db.Statement.Model != nil {
		refValue = reflect.New(reflect.Indirect(reflect.ValueOf(db.Statement.Model)).Type())
	} else {
		if !db.Statement.ReflectValue.IsValid() {
			return
		}
		refValue = reflect.New(db.Statement.ReflectValue.Type())
	}
	// reflect.New yields *T; for non-struct T (e.g. a slice destination) work with T itself.
	if refValue.Kind() == reflect.Ptr && refValue.Elem().Kind() != reflect.Struct {
		refValue = reflect.Indirect(refValue)
	}
	if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice {
		elemType := refValue.Type().Elem()
		if elemType.Kind() == reflect.Ptr {
			elemType = elemType.Elem()
		}
		modelValue = reflect.New(elemType).Interface()
	} else {
		modelValue = refValue.Interface()
	}
	if shardingModel, ok = modelValue.(Model); !ok {
		return
	}
	if db.Statement.SQL.Len() == 0 {
		callbacks.BuildQuerySQL(db)
	}
	if stmt, err = sqlparserX.Parse(db.Statement.SQL.String()); err != nil {
		return
	}
	if selectStmt, ok = stmt.(*sqlparserX.Select); !ok {
		return
	}
	tables = shardingModel.ShardingTables(&Scope{
		db:    db,
		stmtX: selectStmt,
	})
	numOfTable = len(tables)
	if numOfTable <= 1 {
		return
	}
	// Arguments of one per-table SELECT; repeated for every additional table below.
	rawVars = append(make([]any, 0, len(db.Statement.Vars)), db.Statement.Vars...)
	// ORDER BY and LIMIT are lifted out of the per-table SELECTs; GROUP BY and
	// HAVING are only copied and remain on them.
	orderByExpr = selectStmt.OrderBy
	selectStmt.OrderBy = nil
	groupByExpr = selectStmt.GroupBy
	havingExpr = selectStmt.Having
	limitExpr = selectStmt.Limit
	selectStmt.Limit = nil
	// A true "@sql_count_statement" (used by reports) asks for the number of
	// rows the whole query returns across all tables.
	if v := db.Statement.Context.Value("@sql_count_statement"); v == true {
		isCountStatement = true
		isPureCountStatement = true
	}
	// Otherwise a lone count(...) expression is a count query whose
	// per-table results are summed.
	if !isCountStatement && len(selectStmt.SelectExprs) == 1 {
		if aliasedExpr, ok = selectStmt.SelectExprs[0].(*sqlparserX.AliasedExpr); ok {
			if funcExpr, ok = aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok && funcExpr.Name.EqualString("count") {
				isCountStatement = true
				if aliasedExpr.As.IsEmpty() {
					aliasedExpr.As = sqlparserX.NewColIdent(countAlias)
				}
				countColumn = aliasedExpr.As.String()
			}
		}
	}
	db.Statement.SQL.Reset()
	trackedBuffer = sqlparserX.NewTrackedBuffer(nil)
	switch {
	case isPureCountStatement:
		db.Statement.SQL.WriteString("SELECT COUNT(*) AS count FROM (")
	case isCountStatement:
		quote := string([]byte{plugin.QuoteChar})
		db.Statement.SQL.WriteString("SELECT SUM(")
		db.Statement.SQL.WriteString(quote)
		// Double any embedded quote character so the identifier stays well-formed.
		db.Statement.SQL.WriteString(strings.ReplaceAll(countColumn, quote, quote+quote))
		db.Statement.SQL.WriteString(quote)
		db.Statement.SQL.WriteString(") FROM (")
	default:
		var columns []string
		if bs, isBuilder := modelValue.(SelectBuilder); isBuilder {
			columns = bs.BuildSelect(db.Statement.Context, selectStmt.SelectExprs)
		}
		if len(columns) > 0 {
			db.Statement.SQL.WriteString("SELECT " + strings.Join(columns, ",") + " FROM (")
		} else {
			db.Statement.SQL.WriteString("SELECT * FROM (")
		}
	}
	unionKeyword := " UNION "
	if plugin.UnionAll || (isCountStatement && !isPureCountStatement) {
		// Summed counts need every per-table row; UNION would collapse
		// tables that happen to report the same count.
		unionKeyword = " UNION ALL "
	}
	for i, name := range tables {
		// NOTE(review): FROM is replaced wholesale by the shard table, so any
		// joins or table aliases in the original FROM clause are dropped.
		selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{
			Expr: sqlparserX.TableName{
				Name: sqlparserX.NewTableIdent(name),
			},
		}}
		trackedBuffer.Reset()
		selectStmt.Format(trackedBuffer)
		db.Statement.SQL.WriteByte('(')
		db.Statement.SQL.WriteString(trackedBuffer.String())
		db.Statement.SQL.WriteByte(')')
		if i < numOfTable-1 {
			db.Statement.SQL.WriteString(unionKeyword)
		}
		if i > 0 {
			// Each extra copy of the SELECT needs its own copy of the arguments.
			db.Statement.Vars = append(db.Statement.Vars, rawVars...)
		}
	}
	db.Statement.SQL.WriteString(") tbl ")
	if isCountStatement {
		return
	}
	// NOTE(review): bind placeholders inside the outer HAVING receive no
	// arguments of their own (Vars are only repeated per table). Confirm the
	// HAVING clauses used with sharded models carry no placeholders.
	if groupByExpr != nil {
		trackedBuffer.Reset()
		groupByExpr.Format(trackedBuffer)
		db.Statement.SQL.WriteString(trackedBuffer.String())
	}
	if havingExpr != nil {
		trackedBuffer.Reset()
		havingExpr.Format(trackedBuffer)
		db.Statement.SQL.WriteString(trackedBuffer.String())
	}
	if orderByExpr != nil {
		trackedBuffer.Reset()
		orderByExpr.Format(trackedBuffer)
		db.Statement.SQL.WriteString(trackedBuffer.String())
	}
	if limitExpr != nil {
		trackedBuffer.Reset()
		limitExpr.Format(trackedBuffer)
		db.Statement.SQL.WriteString(trackedBuffer.String())
	}
}
// New returns a Sharding plugin that joins per-table queries with UNION ALL
// and uses a backtick as the identifier quote character.
func New() *Sharding {
	plugin := new(Sharding)
	plugin.UnionAll = true
	plugin.QuoteChar = '`'
	return plugin
}