rest/plugins/sharding/sharding.go

477 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package sharding
import (
"github.com/longbridgeapp/sqlparser"
sqlparserX "github.com/uole/sqlparser"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"reflect"
"strings"
)
// Sharding is a GORM plugin that routes create, update and delete statements
// to the table chosen by the model's ShardingTable method, and rewrites
// queries whose model spans several shard tables into a union over them.
//
// Use New to obtain an instance with the default settings.
type Sharding struct {
	// UnionAll joins the per-table sub-queries with UNION ALL when true and
	// with UNION (which removes duplicate rows) when false.
	UnionAll bool

	// QuoteChar is the identifier quote byte written around the count column
	// when a COUNT query is rewritten, e.g. '`' for MySQL.
	QuoteChar byte
}

// Name returns the identifier this plugin is registered under with GORM.
func (plugin *Sharding) Name() string {
	const pluginName = "gorm:sharding"
	return pluginName
}
// Initialize hooks the sharding callbacks into GORM. Each callback is
// registered to run before the corresponding built-in callback
// (gorm:create, gorm:update, gorm:delete, gorm:query). Queries are routed
// through QueryX; Query is not registered here.
//
// It stops at, and returns, the first registration error.
func (plugin *Sharding) Initialize(db *gorm.DB) error {
	registry := db.Callback()
	if err := registry.Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil {
		return err
	}
	if err := registry.Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil {
		return err
	}
	if err := registry.Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil {
		return err
	}
	return registry.Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX)
}
// Create is registered by Initialize to run before GORM's "gorm:create"
// callback. When the record being inserted implements Model, the statement
// is pointed at the table returned by ShardingTable(sceneCreate).
func (plugin *Sharding) Create(db *gorm.DB) {
	if model := shardingWriteModel(db); model != nil {
		db.Table(model.ShardingTable(sceneCreate))
	}
}

// Update is registered by Initialize to run before GORM's "gorm:update"
// callback. When the statement's model implements Model, the statement is
// pointed at the table returned by ShardingTable(sceneUpdate).
func (plugin *Sharding) Update(db *gorm.DB) {
	if model := shardingWriteModel(db); model != nil {
		db.Table(model.ShardingTable(sceneUpdate))
	}
}

// Delete is registered by Initialize to run before GORM's "gorm:delete"
// callback. When the statement's model implements Model, the statement is
// pointed at the table returned by ShardingTable(sceneDelete).
func (plugin *Sharding) Delete(db *gorm.DB) {
	if model := shardingWriteModel(db); model != nil {
		db.Table(model.ShardingTable(sceneDelete))
	}
}

// shardingWriteModel picks the value that decides the shard table of a
// create/update/delete statement:
//   - for a slice or array destination, its first element (nil when empty);
//   - otherwise Statement.Model, when set;
//   - otherwise the reflected destination itself.
//
// It returns nil when the chosen value does not implement Model, and also
// when Statement.ReflectValue is the zero reflect.Value (no destination),
// where calling Interface() would panic.
func shardingWriteModel(db *gorm.DB) Model {
	stmt := db.Statement
	rv := stmt.ReflectValue
	// Kind() is safe on a zero Value (it reports reflect.Invalid).
	switch rv.Kind() {
	case reflect.Slice, reflect.Array:
		if rv.Len() == 0 {
			return nil
		}
		// The first element stands in for the whole batch.
		return reflectedShardingModel(rv.Index(0))
	}
	if stmt.Model != nil {
		model, _ := stmt.Model.(Model)
		return model
	}
	if !rv.IsValid() {
		return nil
	}
	return reflectedShardingModel(rv)
}

// reflectedShardingModel returns v as a Model. If v itself does not implement
// Model but is addressable, &v is tried too, so element types whose
// ShardingTable method has a pointer receiver (e.g. the elements of a
// []User) are still recognised.
func reflectedShardingModel(v reflect.Value) Model {
	if !v.IsValid() || !v.CanInterface() {
		return nil
	}
	if model, ok := v.Interface().(Model); ok {
		return model
	}
	if v.CanAddr() {
		if model, ok := v.Addr().Interface().(Model); ok {
			return model
		}
	}
	return nil
}
// Query is the earlier implementation of the sharded SELECT rewrite. It is
// built on github.com/longbridgeapp/sqlparser. Initialize registers QueryX
// instead, so this method only runs if a caller registers it explicitly.
//
// If the model behind the statement implements Model and its ShardingTables
// returns more than one table, the SELECT is rewritten to
//
//	SELECT * FROM ((<select on t1>) UNION ALL (<select on t2>) ...) tbl
//	    [GROUP BY ...] [HAVING ...] [ORDER BY ...] [LIMIT ... [OFFSET ...]]
//
// UNION replaces UNION ALL when plugin.UnionAll is false.
//
// ORDER BY, LIMIT and OFFSET are stripped from the per-table queries. GROUP
// BY and HAVING stay on them and are repeated on the outer query. OFFSET is
// only re-emitted together with LIMIT. The statement's bound variables are
// repeated once per table.
//
// Count statements get no outer GROUP BY, ORDER BY or LIMIT:
//   - if the context value "@sql_count_statement" is true, the outer query is
//     SELECT COUNT(*) AS count over the union;
//   - if the select list is a single star call whose lower-cased name equals
//     stmtCountKeyword (i.e. a COUNT(*)), the outer query SUMs that column
//     across the per-table results.
//
// Statements that fail to parse, are not SELECTs, or map to at most one
// table are left as GORM built them.
func (plugin *Sharding) Query(db *gorm.DB) {
	// Resolve the model type from the explicit Model, or else from the
	// reflected destination. Working with types avoids panicking on a
	// typed-nil Model or on a statement with no destination.
	var modelType reflect.Type
	if db.Statement.Model != nil {
		modelType = reflect.TypeOf(db.Statement.Model)
		if modelType.Kind() == reflect.Ptr {
			modelType = modelType.Elem()
		}
	} else {
		if !db.Statement.ReflectValue.IsValid() {
			return
		}
		modelType = db.Statement.ReflectValue.Type()
	}

	// Build a fresh value to ask for the shard tables. Structs give *T.
	// Slices and arrays (of T or *T) give *T. Anything else gives the zero T.
	var modelValue any
	switch modelType.Kind() {
	case reflect.Struct:
		modelValue = reflect.New(modelType).Interface()
	case reflect.Slice, reflect.Array:
		elemType := modelType.Elem()
		if elemType.Kind() == reflect.Ptr {
			elemType = elemType.Elem()
		}
		modelValue = reflect.New(elemType).Interface()
	default:
		modelValue = reflect.Zero(modelType).Interface()
	}
	shardingModel, ok := modelValue.(Model)
	if !ok {
		return
	}

	if db.Statement.SQL.Len() == 0 {
		callbacks.BuildQuerySQL(db)
	}
	stmt, err := sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String())).ParseStatement()
	if err != nil {
		// SQL this parser cannot handle is executed unchanged.
		return
	}
	selectStmt, ok := stmt.(*sqlparser.SelectStatement)
	if !ok {
		return
	}
	tables := shardingModel.ShardingTables(&Scope{
		db:   db,
		stmt: selectStmt,
	})
	numOfTable := len(tables)
	if numOfTable <= 1 {
		return
	}

	// Snapshot the bound variables before the per-table copies are appended.
	rawVars := append(make([]any, 0, len(db.Statement.Vars)), db.Statement.Vars...)

	// Decide whether this is a count query. Only non-count queries get the
	// outer GROUP BY / ORDER BY / LIMIT treatment.
	isPureCountStatement := db.Statement.Context.Value("@sql_count_statement") == true
	isCountStatement := isPureCountStatement
	var countField string
	if !isCountStatement && selectStmt.Columns != nil && len(*selectStmt.Columns) == 1 {
		for _, column := range *selectStmt.Columns {
			if expr, ok := column.Expr.(*sqlparser.Call); ok && expr.Star && strings.ToLower(expr.Name.Name) == stmtCountKeyword {
				isCountStatement = true
				countField = expr.String()
				break
			}
		}
	}

	var orderByExpr []*sqlparser.OrderingTerm
	if len(selectStmt.OrderBy) > 0 {
		orderByExpr = make([]*sqlparser.OrderingTerm, 0, len(selectStmt.OrderBy))
		for _, row := range selectStmt.OrderBy {
			orderByExpr = append(orderByExpr, row)
		}
		selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0)
	}
	// GROUP BY / HAVING are remembered for the outer query but deliberately
	// left on the per-table queries as well.
	var (
		groupingExpr []sqlparser.Expr
		havingExpr   sqlparser.Expr
	)
	if len(selectStmt.GroupingElements) > 0 {
		groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements))
		for _, row := range selectStmt.GroupingElements {
			groupingExpr = append(groupingExpr, row)
		}
		if selectStmt.HavingCondition != nil {
			havingExpr = selectStmt.HavingCondition
		}
	}
	var limitExpr, offsetExpr sqlparser.Expr
	if selectStmt.Limit != nil {
		limitExpr = selectStmt.Limit
		selectStmt.Limit = nil
	}
	if selectStmt.Offset != nil {
		offsetExpr = selectStmt.Offset
		selectStmt.Offset = nil
	}

	quote := plugin.QuoteChar
	if quote == 0 {
		// A zero-value Sharding{} would otherwise write a NUL byte into the SQL.
		quote = '`'
	}
	out := &db.Statement.SQL
	out.Reset()
	switch {
	case isPureCountStatement:
		// countField is empty here; count the rows of the union instead of
		// emitting the invalid SUM(``).
		out.WriteString("SELECT COUNT(*) AS count FROM (")
	case isCountStatement:
		out.WriteString("SELECT SUM(")
		out.WriteByte(quote)
		out.WriteString(strings.Trim(countField, "`"))
		out.WriteByte(quote)
		out.WriteString(") FROM (")
	default:
		out.WriteString("SELECT * FROM (")
	}

	unionKeyword := " UNION "
	if plugin.UnionAll {
		unionKeyword = " UNION ALL "
	}
	for i, name := range tables {
		if i > 0 {
			out.WriteString(unionKeyword)
			// Every extra sub-query repeats the original placeholders.
			db.Statement.Vars = append(db.Statement.Vars, rawVars...)
		}
		// NOTE(review): only a plain *TableName FROM is retargeted. For any
		// other FROM shape every sub-query reads the same table.
		if tableName, ok := selectStmt.FromItems.(*sqlparser.TableName); ok {
			tableName.Name.Name = name
		}
		out.WriteByte('(')
		out.WriteString(selectStmt.String())
		out.WriteByte(')')
	}
	out.WriteString(") tbl ")
	if isCountStatement {
		return
	}

	if len(groupingExpr) > 0 {
		out.WriteString(" GROUP BY ")
		for i, expr := range groupingExpr {
			if i > 0 {
				out.WriteString(", ")
			}
			out.WriteString(expr.String())
		}
		// NOTE(review): placeholders inside HAVING are emitted again here
		// without matching Vars being appended; confirm HAVING has no binds.
		if havingExpr != nil {
			out.WriteString(" HAVING ")
			out.WriteString(havingExpr.String())
		}
	}
	if len(orderByExpr) > 0 {
		out.WriteString(" ORDER BY ")
		for i, term := range orderByExpr {
			if i > 0 {
				out.WriteString(", ")
			}
			out.WriteString(term.String())
		}
	}
	if limitExpr != nil {
		out.WriteString(" LIMIT ")
		out.WriteString(limitExpr.String())
		if offsetExpr != nil {
			out.WriteString(" OFFSET ")
			out.WriteString(offsetExpr.String())
		}
	}
}
// QueryX is the query callback registered by Initialize. It parses the SELECT
// built by GORM with github.com/uole/sqlparser. If the model implements Model
// and its ShardingTables returns more than one table, the query is rewritten
// to
//
//	SELECT <cols> FROM ((<select on t1>) UNION ALL (<select on t2>) ...) tbl
//	    [group by ...] [having ...] [order by ...] [limit ...]
//
// UNION replaces UNION ALL when plugin.UnionAll is false.
//
// <cols> is "*" unless the model implements SelectBuilder and its
// BuildSelect returns column expressions. ORDER BY and LIMIT are removed from
// the per-table queries and applied to the outer query. GROUP BY and HAVING
// stay on the per-table queries and are also applied to the outer query.
// Bound variables are repeated once per table. When the original FROM is a
// single aliased table, each sub-query keeps that alias.
//
// Count queries get no outer GROUP BY, HAVING, ORDER BY or LIMIT:
//   - if the context value "@sql_count_statement" is true, the outer query is
//     SELECT COUNT(*) AS count over the union;
//   - if the select list is a single count(...) call, the outer query SUMs
//     that column. The column is named by its alias when present, otherwise
//     by the formatted call text (e.g. count(*)).
//
// Statements that fail to parse, are not SELECTs, or map to at most one
// table are left as GORM built them.
func (plugin *Sharding) QueryX(db *gorm.DB) {
	// Resolve the model type from the explicit Model, or else from the
	// reflected destination. Working with types avoids panicking on a
	// typed-nil Model or on a statement with no destination.
	var modelType reflect.Type
	if db.Statement.Model != nil {
		modelType = reflect.TypeOf(db.Statement.Model)
		if modelType.Kind() == reflect.Ptr {
			modelType = modelType.Elem()
		}
	} else {
		if !db.Statement.ReflectValue.IsValid() {
			return
		}
		modelType = db.Statement.ReflectValue.Type()
	}

	// Build a fresh value to ask for the shard tables. Structs give *T.
	// Slices and arrays (of T or *T) give *T. Anything else gives the zero T.
	var modelValue any
	switch modelType.Kind() {
	case reflect.Struct:
		modelValue = reflect.New(modelType).Interface()
	case reflect.Slice, reflect.Array:
		elemType := modelType.Elem()
		if elemType.Kind() == reflect.Ptr {
			elemType = elemType.Elem()
		}
		modelValue = reflect.New(elemType).Interface()
	default:
		modelValue = reflect.Zero(modelType).Interface()
	}
	shardingModel, ok := modelValue.(Model)
	if !ok {
		return
	}

	if db.Statement.SQL.Len() == 0 {
		callbacks.BuildQuerySQL(db)
	}
	stmt, err := sqlparserX.Parse(db.Statement.SQL.String())
	if err != nil {
		// SQL the parser cannot handle is executed unchanged.
		return
	}
	selectStmt, ok := stmt.(*sqlparserX.Select)
	if !ok {
		return
	}
	tables := shardingModel.ShardingTables(&Scope{
		db:    db,
		stmtX: selectStmt,
	})
	numOfTable := len(tables)
	if numOfTable <= 1 {
		return
	}

	// Snapshot the bound variables before the per-table copies are appended.
	rawVars := append(make([]any, 0, len(db.Statement.Vars)), db.Statement.Vars...)

	// ORDER BY and LIMIT only make sense on the combined result, so they are
	// moved to the outer query. GROUP BY and HAVING are remembered for the
	// outer query but intentionally kept on each per-table query too.
	orderByExpr := selectStmt.OrderBy
	selectStmt.OrderBy = nil
	limitExpr := selectStmt.Limit
	selectStmt.Limit = nil
	groupByExpr := selectStmt.GroupBy
	havingExpr := selectStmt.Having

	// Keep the alias of a single-table FROM so alias-qualified column
	// references in the per-table queries still resolve.
	var tableAlias sqlparserX.TableIdent
	if len(selectStmt.From) == 1 {
		if aliased, ok := selectStmt.From[0].(*sqlparserX.AliasedTableExpr); ok {
			tableAlias = aliased.As
		}
	}

	trackedBuffer := sqlparserX.NewTrackedBuffer(nil)

	// Detect count queries. The context flag marks reports, where the COUNT
	// must be rewritten over an arbitrary select to get the right total.
	// Otherwise a lone count(...) column is summed across tables.
	isPureCountStatement := db.Statement.Context.Value("@sql_count_statement") == true
	isCountStatement := isPureCountStatement
	var countColumn string
	if !isCountStatement && len(selectStmt.SelectExprs) == 1 {
		if aliasedExpr, ok := selectStmt.SelectExprs[0].(*sqlparserX.AliasedExpr); ok {
			if funcExpr, ok := aliasedExpr.Expr.(*sqlparserX.FuncExpr); ok && funcExpr.Name.EqualString("count") {
				isCountStatement = true
				// The derived-table column is named by the alias if present,
				// otherwise by the expression text the sub-queries will contain.
				if !aliasedExpr.As.IsEmpty() {
					countColumn = aliasedExpr.As.String()
				} else {
					trackedBuffer.Reset()
					funcExpr.Format(trackedBuffer)
					countColumn = trackedBuffer.String()
				}
			}
		}
	}

	quote := plugin.QuoteChar
	if quote == 0 {
		// A zero-value Sharding{} would otherwise write a NUL byte into the SQL.
		quote = '`'
	}

	// Rewrite the SQL.
	out := &db.Statement.SQL
	out.Reset()
	switch {
	case isPureCountStatement:
		out.WriteString("SELECT COUNT(*) AS count FROM (")
	case isCountStatement:
		// Quote the column name, doubling any embedded quote character.
		q := string(quote)
		out.WriteString("SELECT SUM(")
		out.WriteString(q + strings.ReplaceAll(countColumn, q, q+q) + q)
		out.WriteString(") FROM (")
	default:
		columns := "*"
		if builder, ok := modelValue.(SelectBuilder); ok {
			if names := builder.BuildSelect(db.Statement.Context, selectStmt.SelectExprs); len(names) > 0 {
				columns = strings.Join(names, ",")
			}
		}
		out.WriteString("SELECT " + columns + " FROM (")
	}

	unionKeyword := " UNION "
	if plugin.UnionAll {
		unionKeyword = " UNION ALL "
	}
	for i, name := range tables {
		if i > 0 {
			out.WriteString(unionKeyword)
			// Every extra sub-query repeats the original placeholders.
			db.Statement.Vars = append(db.Statement.Vars, rawVars...)
		}
		// Point this sub-query at the shard table. Any joins in the original
		// FROM are replaced as well.
		selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{
			Expr: sqlparserX.TableName{
				Name: sqlparserX.NewTableIdent(name),
			},
			As: tableAlias,
		}}
		trackedBuffer.Reset()
		selectStmt.Format(trackedBuffer)
		out.WriteByte('(')
		out.WriteString(trackedBuffer.String())
		out.WriteByte(')')
	}
	out.WriteString(") tbl ")
	if isCountStatement {
		return
	}

	// Each Format call writes its own leading keyword (group by / having /
	// order by / limit).
	if groupByExpr != nil {
		trackedBuffer.Reset()
		groupByExpr.Format(trackedBuffer)
		out.WriteString(trackedBuffer.String())
	}
	// NOTE(review): placeholders inside HAVING are emitted again here without
	// matching Vars being appended; confirm HAVING never carries binds.
	if havingExpr != nil {
		trackedBuffer.Reset()
		havingExpr.Format(trackedBuffer)
		out.WriteString(trackedBuffer.String())
	}
	if orderByExpr != nil {
		trackedBuffer.Reset()
		orderByExpr.Format(trackedBuffer)
		out.WriteString(trackedBuffer.String())
	}
	if limitExpr != nil {
		trackedBuffer.Reset()
		limitExpr.Format(trackedBuffer)
		out.WriteString(trackedBuffer.String())
	}
}
// New returns a Sharding plugin with its default settings: shard sub-queries
// are joined with UNION ALL, and the count column is quoted with '`'.
func New() *Sharding {
	plugin := new(Sharding)
	plugin.UnionAll = true
	plugin.QuoteChar = '`'
	return plugin
}