package sharding import ( "context" "github.com/longbridgeapp/sqlparser" sqlparserX "github.com/uole/sqlparser" "gorm.io/gorm" "strconv" ) type Scope struct { db *gorm.DB stmt *sqlparser.SelectStatement stmtX *sqlparserX.Select } func (scope *Scope) findValue(express sqlparser.Expr) (int, any) { var ( vType int vData any ) switch expr := express.(type) { case *sqlparser.BindExpr: vType = ValueTypeAny if len(scope.db.Statement.Vars) > expr.Pos { vData = scope.db.Statement.Vars[expr.Pos] } else { vType = ValueTypeNull } case *sqlparser.NumberLit: vType = ValueTypeNumber vData = expr.Value case *sqlparser.StringLit: vType = ValueTypeString vData = expr.Value case *sqlparser.BoolLit: vType = ValueTypeBoolean vData = expr.Value case *sqlparser.BlobLit: vType = ValueTypeString vData = expr.Value case *sqlparser.NullLit: vType = ValueTypeNull case *sqlparser.Range: arr := make([]any, 2) vType, arr[0] = scope.findValue(expr.X) vType, arr[1] = scope.findValue(expr.Y) vData = arr } return vType, vData } func (scope *Scope) findValueX(express *sqlparserX.SQLVal) (int, any) { var ( vType int vData any ) switch express.Type { case sqlparserX.IntVal: vType = ValueTypeNumber vData, _ = strconv.Atoi(string(express.Val)) case sqlparserX.FloatVal: vType = ValueTypeNumber vData, _ = strconv.ParseFloat(string(express.Val), 64) case sqlparserX.ValArg: vType = ValueTypeAny pos, _ := strconv.Atoi(string(express.Val[2:])) if pos > 0 { pos = pos - 1 if len(scope.db.Statement.Vars) > pos { vData = scope.db.Statement.Vars[pos] } else { vType = ValueTypeNull } } else { vType = ValueTypeNull } default: vType = ValueTypeString vData = string(express.Val) } return vType, vData } func (scope *Scope) recursiveFindX(expr sqlparserX.Expr, column string) (conditions []*ColumnCondition) { var ( ok bool andExpr *sqlparserX.AndExpr orExpr *sqlparserX.OrExpr parentExpr *sqlparserX.ParenExpr comparisonExpr *sqlparserX.ComparisonExpr rangeExpr *sqlparserX.RangeCond coumnExpr *sqlparserX.ColName valueExpr *sqlparserX.SQLVal ) conditions = make([]*ColumnCondition, 0, 2) if comparisonExpr, ok = expr.(*sqlparserX.ComparisonExpr); ok { if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok { return } if valueExpr, ok = comparisonExpr.Right.(*sqlparserX.SQLVal); !ok { return } if coumnExpr.Name.EqualString(column) { cond := &ColumnCondition{ Name: coumnExpr.Name.String(), } vType, vData := scope.findValueX(valueExpr) switch comparisonExpr.Operator { case sqlparserX.LessThanStr, sqlparserX.LessEqualStr: cond.Value = newCondValue(vType, ValueOperaLess, vData) case sqlparserX.GreaterThanStr, sqlparserX.GreaterEqualStr: cond.Value = newCondValue(vType, ValueOperaGreater, vData) case sqlparserX.EqualStr: cond.Value = newCondValue(vType, ValueOperaEqual, vData) } if cond.Value != nil { conditions = append(conditions, cond) } } } if rangeExpr, ok = expr.(*sqlparserX.RangeCond); ok { if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok { return } if coumnExpr.Name.EqualString(column) { vType := 0 arr := make([]any, 2) if valueExpr, ok = rangeExpr.From.(*sqlparserX.SQLVal); ok { vType, arr[0] = scope.findValueX(valueExpr) } if valueExpr, ok = rangeExpr.To.(*sqlparserX.SQLVal); ok { vType, arr[1] = scope.findValueX(valueExpr) } conditions = append(conditions, &ColumnCondition{ Name: coumnExpr.Name.String(), Value: newCondValue(vType, ValueOperaRange, arr), }) } } if andExpr, ok = expr.(*sqlparserX.AndExpr); ok { if andExpr.Left != nil { conditions = append(conditions, scope.recursiveFindX(andExpr.Left, column)...) } if andExpr.Right != nil { conditions = append(conditions, scope.recursiveFindX(andExpr.Right, column)...) } } if orExpr, ok = expr.(*sqlparserX.OrExpr); ok { if orExpr.Left != nil { conditions = append(conditions, scope.recursiveFindX(orExpr.Left, column)...) } if orExpr.Right != nil { conditions = append(conditions, scope.recursiveFindX(orExpr.Right, column)...) } } if parentExpr, ok = expr.(*sqlparserX.ParenExpr); ok { if parentExpr.Expr != nil { conditions = append(conditions, scope.recursiveFindX(parentExpr.Expr, column)...) } } return conditions } func (scope *Scope) recursiveFind(expr sqlparser.Expr, column string) []*ColumnCondition { var ( ok bool identExpr *sqlparser.Ident binaryExpr *sqlparser.BinaryExpr parentExpr *sqlparser.ParenExpr conditions []*ColumnCondition ) conditions = make([]*ColumnCondition, 0, 2) if parentExpr, ok = expr.(*sqlparser.ParenExpr); ok { if parentExpr.X != nil { if _, ok = parentExpr.X.(*sqlparser.BinaryExpr); ok { conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...) } if _, ok = parentExpr.X.(*sqlparser.ParenExpr); ok { conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...) } } } if binaryExpr, ok = expr.(*sqlparser.BinaryExpr); ok { if binaryExpr.X != nil { if identExpr, ok = binaryExpr.X.(*sqlparser.Ident); ok { if identExpr.Name == column { cond := &ColumnCondition{ Name: identExpr.Name, } vType, vData := scope.findValue(binaryExpr.Y) switch binaryExpr.Op { case sqlparser.LT, sqlparser.LE: cond.Value = newCondValue(vType, ValueOperaLess, vData) case sqlparser.GT, sqlparser.GE: cond.Value = newCondValue(vType, ValueOperaGreater, vData) case sqlparser.RANGE, sqlparser.BETWEEN: cond.Value = newCondValue(vType, ValueOperaRange, vData) case sqlparser.EQ: cond.Value = newCondValue(vType, ValueOperaEqual, vData) } if cond.Value != nil { conditions = append(conditions, cond) } } } else { if _, ok = binaryExpr.X.(*sqlparser.BinaryExpr); ok { conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...) } if _, ok = binaryExpr.X.(*sqlparser.ParenExpr); ok { conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...) } } } if binaryExpr.Y != nil { if _, ok = binaryExpr.Y.(*sqlparser.BinaryExpr); ok { conditions = append(conditions, scope.recursiveFind(binaryExpr.Y, column)...) } } } return conditions } func (scope *Scope) DB() *gorm.DB { return scope.db } func (scope *Scope) Context() context.Context { return scope.db.Statement.Context } func (scope *Scope) FindCondition(column string) []*ColumnCondition { if scope.stmtX != nil { if scope.stmtX.Where == nil { return []*ColumnCondition{} } return scope.recursiveFindX(scope.stmtX.Where.Expr, column) } return scope.recursiveFind(scope.stmt.Condition, column) } func (scope *Scope) FindColumnValues(column string) []any { result := make([]any, 0) conditions := scope.FindCondition(column) if len(conditions) == 0 { return result } for _, cond := range conditions { if cond.Value.Opera() == ValueOperaGreater { if len(result) == 0 { result = make([]any, 2) } result[0] = cond.Value.Value() } if cond.Value.Opera() == ValueOperaLess { if len(result) == 0 { result = make([]any, 2) } result[1] = cond.Value.Value() } if cond.Value.Opera() == ValueOperaEqual { result = append(result, cond.Value.Value()) break } if cond.Value.Opera() == ValueOperaRange { if vs, ok := cond.Value.Value().([]any); ok { if len(vs) == 2 { if len(result) == 0 { result = make([]any, 2) } result[0] = vs[0] result[1] = vs[1] } } break } } return result }