package sharding

import (
	"reflect"
	"strings"

	"github.com/longbridgeapp/sqlparser"
	sqlparserX "github.com/uole/sqlparser"
	"gorm.io/gorm"
	"gorm.io/gorm/callbacks"
)

// Sharding is a GORM plugin that routes writes to the table chosen by a
// Model's ShardingTable and rewrites SELECTs spanning several tables (as
// reported by the Model's ShardingTables) into a UNION of per-table
// sub-selects.
type Sharding struct {
	// UnionAll joins the per-table sub-selects with " UNION ALL " when true
	// and with " UNION " when false.
	UnionAll bool
	// QuoteChar is the identifier quote written around the summed count
	// column in rewritten count queries (New uses a backtick).
	QuoteChar byte
}

// Name returns the plugin name, "gorm:sharding".
func (plugin *Sharding) Name() string {
	return "gorm:sharding"
}

// Initialize registers the sharding callbacks so they run before GORM's
// gorm:create, gorm:update, gorm:delete and gorm:query callbacks. Queries are
// rewritten by QueryX; the legacy Query rewriter is not registered here.
func (plugin *Sharding) Initialize(db *gorm.DB) error {
	cb := db.Callback()
	if err := cb.Create().Before("gorm:create").Register("gorm_sharding_create", plugin.Create); err != nil {
		return err
	}
	if err := cb.Update().Before("gorm:update").Register("gorm_sharding_update", plugin.Update); err != nil {
		return err
	}
	if err := cb.Delete().Before("gorm:delete").Register("gorm_sharding_delete", plugin.Delete); err != nil {
		return err
	}
	return cb.Query().Before("gorm:query").Register("gorm_sharding_query", plugin.QueryX)
}

// Create points the statement at the table returned by the model's
// ShardingTable(sceneCreate). For a slice/array batch only the first element
// is consulted, so the whole batch goes to that element's table.
func (plugin *Sharding) Create(db *gorm.DB) {
	if model, ok := resolveWriteModel(db); ok {
		db.Table(model.ShardingTable(sceneCreate))
	}
}

// Update points the statement at the table returned by the model's
// ShardingTable(sceneUpdate). For a slice/array value only the first element
// is consulted.
func (plugin *Sharding) Update(db *gorm.DB) {
	if model, ok := resolveWriteModel(db); ok {
		db.Table(model.ShardingTable(sceneUpdate))
	}
}

// Delete points the statement at the table returned by the model's
// ShardingTable(sceneDelete). For a slice/array value only the first element
// is consulted.
func (plugin *Sharding) Delete(db *gorm.DB) {
	if model, ok := resolveWriteModel(db); ok {
		db.Table(model.ShardingTable(sceneDelete))
	}
}

// resolveWriteModel returns the Model carried by a create/update/delete
// statement: the first element when ReflectValue is a slice or array,
// otherwise Statement.Model, otherwise ReflectValue itself. It reports false
// for an empty batch, an invalid ReflectValue, or a value that does not
// implement Model.
func resolveWriteModel(db *gorm.DB) (Model, bool) {
	rv := db.Statement.ReflectValue
	switch {
	case rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array:
		if rv.Len() == 0 {
			return nil, false
		}
		return modelFromValue(rv.Index(0))
	case db.Statement.Model != nil:
		model, ok := db.Statement.Model.(Model)
		return model, ok
	case rv.IsValid():
		return modelFromValue(rv)
	default:
		// Calling Interface() on an invalid reflect.Value would panic.
		return nil, false
	}
}

// modelFromValue asserts v to Model, falling back to v's address so that
// models whose methods use pointer receivers are found even when v holds the
// struct value itself.
func modelFromValue(v reflect.Value) (Model, bool) {
	if v.CanInterface() {
		if model, ok := v.Interface().(Model); ok {
			return model, true
		}
	}
	if v.CanAddr() {
		if ptr := v.Addr(); ptr.CanInterface() {
			if model, ok := ptr.Interface().(Model); ok {
				return model, true
			}
		}
	}
	return nil, false
}

// resolveQueryModel returns a new zero value (as a pointer) of the model type a
// query targets, taken from Statement.Model or, when that is nil, from
// ReflectValue. Pointer levels are stripped and slice/array types are reduced
// to their element type. Only the type is inspected, so a typed-nil model
// such as (*T)(nil) does not panic.
func resolveQueryModel(db *gorm.DB) (Model, bool) {
	var t reflect.Type
	switch {
	case db.Statement.Model != nil:
		t = reflect.TypeOf(db.Statement.Model)
	case db.Statement.ReflectValue.IsValid():
		t = db.Statement.ReflectValue.Type()
	default:
		return nil, false
	}
	t = derefType(t)
	if k := t.Kind(); k == reflect.Slice || k == reflect.Array {
		t = derefType(t.Elem())
	}
	model, ok := reflect.New(t).Interface().(Model)
	return model, ok
}

// derefType strips every pointer level from t.
func derefType(t reflect.Type) reflect.Type {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t
}

// countStatementRequested reports whether the boolean true is stored in the
// statement context under "@sql_count_statement"; the query rewriters then
// emit "SELECT COUNT(*) AS count" over the merged rows.
func countStatementRequested(db *gorm.DB) bool {
	flag, ok := db.Statement.Context.Value("@sql_count_statement").(bool)
	return ok && flag
}

// unionSeparator returns the set operator placed between per-table sub-selects.
func (plugin *Sharding) unionSeparator() string {
	if plugin.UnionAll {
		return " UNION ALL "
	}
	return " UNION "
}

// quoteIdent wraps name in QuoteChar, doubling any QuoteChar inside it.
func (plugin *Sharding) quoteIdent(name string) string {
	q := string([]byte{plugin.QuoteChar})
	return q + strings.ReplaceAll(name, q, q+q) + q
}

// Query is the legacy SELECT rewriter built on github.com/longbridgeapp/sqlparser.
// Initialize registers QueryX instead; Query remains for callers that register
// it themselves.
//
// When the model's ShardingTables returns more than one table, the statement
// SQL becomes
//
//	SELECT ... FROM ((sub-select on t1) UNION [ALL] (sub-select on t2) ...) tbl
//
// ORDER BY, LIMIT and OFFSET move to the outer query, GROUP BY/HAVING are kept
// in each sub-select and repeated outside, and the bind variables are appended
// once more per extra sub-select. SQL the parser rejects, non-SELECT
// statements and single-table results are left untouched.
func (plugin *Sharding) Query(db *gorm.DB) {
	shardingModel, ok := resolveQueryModel(db)
	if !ok {
		return
	}
	if db.Statement.SQL.Len() == 0 {
		callbacks.BuildQuerySQL(db)
	}
	stmt, err := sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String())).ParseStatement()
	if err != nil {
		return // best effort: the query runs un-sharded
	}
	selectStmt, ok := stmt.(*sqlparser.SelectStatement)
	if !ok {
		return
	}
	tables := shardingModel.ShardingTables(&Scope{
		db:   db,
		stmt: selectStmt,
	})
	if len(tables) <= 1 {
		return
	}
	// Snapshot the vars before they are replicated for the extra sub-selects.
	rawVars := append([]any(nil), db.Statement.Vars...)

	// A count is either requested through the context or detected as a single
	// COUNT(*) column; counts skip the outer GROUP BY/ORDER BY/LIMIT.
	isCountStatement := countStatementRequested(db)
	var countField string
	if !isCountStatement && selectStmt.Columns != nil && len(*selectStmt.Columns) == 1 {
		for _, column := range *selectStmt.Columns {
			call, isCall := column.Expr.(*sqlparser.Call)
			if isCall && call.Star && strings.ToLower(call.Name.Name) == stmtCountKeyword {
				isCountStatement = true
				countField = call.String()
				break
			}
		}
	}

	var orderByExpr []*sqlparser.OrderingTerm
	if len(selectStmt.OrderBy) > 0 {
		orderByExpr = selectStmt.OrderBy
		selectStmt.OrderBy = make([]*sqlparser.OrderingTerm, 0)
	}
	var (
		groupingExpr []sqlparser.Expr
		havingExpr   sqlparser.Expr
	)
	if len(selectStmt.GroupingElements) > 0 {
		groupingExpr = make([]sqlparser.Expr, 0, len(selectStmt.GroupingElements))
		for _, elem := range selectStmt.GroupingElements {
			groupingExpr = append(groupingExpr, elem)
		}
		if selectStmt.HavingCondition != nil {
			havingExpr = selectStmt.HavingCondition
		}
	}
	var limitExpr, offsetExpr sqlparser.Expr
	if selectStmt.Limit != nil {
		limitExpr = selectStmt.Limit
		selectStmt.Limit = nil
	}
	if selectStmt.Offset != nil {
		offsetExpr = selectStmt.Offset
		selectStmt.Offset = nil
	}

	out := &db.Statement.SQL
	out.Reset()
	switch {
	case isCountStatement && countField == "":
		// Count requested through the context: count the merged rows
		// (previously this produced the invalid "SUM(``)").
		out.WriteString("SELECT COUNT(*) AS count FROM (")
	case isCountStatement:
		// Each sub-select yields its table's count; add them up.
		out.WriteString("SELECT SUM(")
		out.WriteString(plugin.quoteIdent(strings.Trim(countField, "`")))
		out.WriteString(") FROM (")
	default:
		out.WriteString("SELECT * FROM (")
	}

	// Only a plain table source can be renamed per shard.
	tableName, renameTable := selectStmt.FromItems.(*sqlparser.TableName)
	for i, name := range tables {
		if i > 0 {
			out.WriteString(plugin.unionSeparator())
			db.Statement.Vars = append(db.Statement.Vars, rawVars...)
		}
		if renameTable {
			tableName.Name.Name = name
		}
		out.WriteByte('(')
		out.WriteString(selectStmt.String())
		out.WriteByte(')')
	}
	out.WriteString(") tbl ")
	if isCountStatement {
		return
	}

	if len(groupingExpr) > 0 {
		out.WriteString(" GROUP BY ")
		for i, expr := range groupingExpr {
			if i > 0 {
				out.WriteString(", ")
			}
			out.WriteString(expr.String())
		}
		if havingExpr != nil {
			// NOTE(review): no bind vars are appended for this outer HAVING, so
			// a HAVING with placeholders likely gets too few arguments — confirm.
			out.WriteString(" HAVING ")
			out.WriteString(havingExpr.String())
		}
	}
	if len(orderByExpr) > 0 {
		out.WriteString(" ORDER BY ")
		for i, term := range orderByExpr {
			if i > 0 {
				out.WriteString(", ")
			}
			out.WriteString(term.String())
		}
	}
	if limitExpr != nil {
		out.WriteString(" LIMIT ")
		out.WriteString(limitExpr.String())
		if offsetExpr != nil {
			out.WriteString(" OFFSET ")
			out.WriteString(offsetExpr.String())
		}
	}
}

// selectExprLabel returns the name an outer query can use to refer to the
// column produced by expr: its alias when present, otherwise the expression
// as the parser formats it (the same text written into the sub-selects).
// NOTE(review): relies on the database labelling an un-aliased expression
// column with its SQL text, as MySQL does — confirm for other dialects.
func selectExprLabel(expr *sqlparserX.AliasedExpr) string {
	if alias := expr.As.String(); alias != "" {
		return alias
	}
	buf := sqlparserX.NewTrackedBuffer(nil)
	expr.Expr.Format(buf)
	return buf.String()
}

// QueryX is the SELECT rewriter registered by Initialize; it is built on
// github.com/uole/sqlparser.
//
// When the model's ShardingTables returns more than one table, the SQL becomes
//
//	SELECT <cols> FROM ((sub-select on t1) UNION [ALL] (sub-select on t2) ...) tbl
//
// ORDER BY and LIMIT move to the outer query, GROUP BY/HAVING stay in each
// sub-select and are repeated outside, and the bind variables are appended
// once more per extra sub-select. <cols> is "*" unless the model implements
// SelectBuilder and returns columns. A single COUNT(...) column becomes a SUM
// over the per-table counts; a count requested through the context becomes
// COUNT(*) AS count over the merged rows. SQL that does not parse as a SELECT,
// and single-table results, are left untouched.
func (plugin *Sharding) QueryX(db *gorm.DB) {
	shardingModel, ok := resolveQueryModel(db)
	if !ok {
		return
	}
	if db.Statement.SQL.Len() == 0 {
		callbacks.BuildQuerySQL(db)
	}
	stmt, err := sqlparserX.Parse(db.Statement.SQL.String())
	if err != nil {
		return // best effort: the query runs un-sharded
	}
	selectStmt, ok := stmt.(*sqlparserX.Select)
	if !ok {
		return
	}
	tables := shardingModel.ShardingTables(&Scope{
		db:    db,
		stmtX: selectStmt,
	})
	if len(tables) <= 1 {
		return
	}
	// Snapshot the vars before they are replicated for the extra sub-selects.
	rawVars := append([]any(nil), db.Statement.Vars...)

	// ORDER BY and LIMIT only make sense on the merged result; GROUP BY and
	// HAVING are deliberately left on the sub-selects as well.
	orderByExpr, limitExpr := selectStmt.OrderBy, selectStmt.Limit
	selectStmt.OrderBy, selectStmt.Limit = nil, nil
	groupByExpr, havingExpr := selectStmt.GroupBy, selectStmt.Having

	// A context-requested count (used by reports) counts the merged rows;
	// otherwise a lone COUNT(...) column is summed across tables.
	isPureCount := countStatementRequested(db)
	isCountStatement := isPureCount
	var countColumn string
	if !isCountStatement && len(selectStmt.SelectExprs) == 1 {
		if aliased, isAliased := selectStmt.SelectExprs[0].(*sqlparserX.AliasedExpr); isAliased {
			if fn, isFunc := aliased.Expr.(*sqlparserX.FuncExpr); isFunc && fn.Name.EqualString("count") {
				isCountStatement = true
				countColumn = selectExprLabel(aliased)
			}
		}
	}

	out := &db.Statement.SQL
	out.Reset()
	switch {
	case isPureCount:
		out.WriteString("SELECT COUNT(*) AS count FROM (")
	case isCountStatement:
		out.WriteString("SELECT SUM(")
		out.WriteString(plugin.quoteIdent(countColumn))
		out.WriteString(") FROM (")
	default:
		columns := "*"
		if builder, isBuilder := shardingModel.(SelectBuilder); isBuilder {
			if built := builder.BuildSelect(db.Statement.Context, selectStmt.SelectExprs); len(built) > 0 {
				columns = strings.Join(built, ",")
			}
		}
		out.WriteString("SELECT " + columns + " FROM (")
	}

	buf := sqlparserX.NewTrackedBuffer(nil)
	for i, name := range tables {
		if i > 0 {
			out.WriteString(plugin.unionSeparator())
			db.Statement.Vars = append(db.Statement.Vars, rawVars...)
		}
		// Point the sub-select at this shard's table.
		selectStmt.From = sqlparserX.TableExprs{&sqlparserX.AliasedTableExpr{
			Expr: sqlparserX.TableName{
				Name: sqlparserX.NewTableIdent(name),
			},
		}}
		buf.Reset()
		selectStmt.Format(buf)
		out.WriteByte('(')
		out.WriteString(buf.String())
		out.WriteByte(')')
	}
	out.WriteString(") tbl ")
	if isCountStatement {
		return
	}

	// The parser's Format output for these clauses carries its own leading
	// keyword (e.g. " group by ...").
	if len(groupByExpr) > 0 {
		buf.Reset()
		groupByExpr.Format(buf)
		out.WriteString(buf.String())
	}
	if havingExpr != nil {
		// NOTE(review): no bind vars are appended for this outer HAVING, so a
		// HAVING with placeholders likely gets too few arguments — confirm.
		buf.Reset()
		havingExpr.Format(buf)
		out.WriteString(buf.String())
	}
	if len(orderByExpr) > 0 {
		buf.Reset()
		orderByExpr.Format(buf)
		out.WriteString(buf.String())
	}
	if limitExpr != nil {
		buf.Reset()
		limitExpr.Format(buf)
		out.WriteString(buf.String())
	}
}

// New returns a Sharding plugin that joins per-table sub-selects with
// UNION ALL and uses a backtick as QuoteChar.
func New() *Sharding {
	return &Sharding{
		UnionAll:  true,
		QuoteChar: '`',
	}
}