288 lines
7.6 KiB
Go
288 lines
7.6 KiB
Go
|
package sharding
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"github.com/longbridgeapp/sqlparser"
|
||
|
sqlparserX "github.com/uole/sqlparser"
|
||
|
"gorm.io/gorm"
|
||
|
"strconv"
|
||
|
)
|
||
|
|
||
|
type Scope struct {
|
||
|
db *gorm.DB
|
||
|
stmt *sqlparser.SelectStatement
|
||
|
stmtX *sqlparserX.Select
|
||
|
}
|
||
|
|
||
|
func (scope *Scope) findValue(express sqlparser.Expr) (int, any) {
|
||
|
var (
|
||
|
vType int
|
||
|
vData any
|
||
|
)
|
||
|
switch expr := express.(type) {
|
||
|
case *sqlparser.BindExpr:
|
||
|
vType = ValueTypeAny
|
||
|
if len(scope.db.Statement.Vars) > expr.Pos {
|
||
|
vData = scope.db.Statement.Vars[expr.Pos]
|
||
|
} else {
|
||
|
vType = ValueTypeNull
|
||
|
}
|
||
|
case *sqlparser.NumberLit:
|
||
|
vType = ValueTypeNumber
|
||
|
vData = expr.Value
|
||
|
case *sqlparser.StringLit:
|
||
|
vType = ValueTypeString
|
||
|
vData = expr.Value
|
||
|
case *sqlparser.BoolLit:
|
||
|
vType = ValueTypeBoolean
|
||
|
vData = expr.Value
|
||
|
case *sqlparser.BlobLit:
|
||
|
vType = ValueTypeString
|
||
|
vData = expr.Value
|
||
|
case *sqlparser.NullLit:
|
||
|
vType = ValueTypeNull
|
||
|
case *sqlparser.Range:
|
||
|
arr := make([]any, 2)
|
||
|
vType, arr[0] = scope.findValue(expr.X)
|
||
|
vType, arr[1] = scope.findValue(expr.Y)
|
||
|
vData = arr
|
||
|
}
|
||
|
return vType, vData
|
||
|
}
|
||
|
|
||
|
func (scope *Scope) findValueX(express *sqlparserX.SQLVal) (int, any) {
|
||
|
var (
|
||
|
vType int
|
||
|
vData any
|
||
|
)
|
||
|
switch express.Type {
|
||
|
case sqlparserX.IntVal:
|
||
|
vType = ValueTypeNumber
|
||
|
vData, _ = strconv.Atoi(string(express.Val))
|
||
|
case sqlparserX.FloatVal:
|
||
|
vType = ValueTypeNumber
|
||
|
vData, _ = strconv.ParseFloat(string(express.Val), 64)
|
||
|
case sqlparserX.ValArg:
|
||
|
vType = ValueTypeAny
|
||
|
pos, _ := strconv.Atoi(string(express.Val[2:]))
|
||
|
if pos > 0 {
|
||
|
pos = pos - 1
|
||
|
if len(scope.db.Statement.Vars) > pos {
|
||
|
vData = scope.db.Statement.Vars[pos]
|
||
|
} else {
|
||
|
vType = ValueTypeNull
|
||
|
}
|
||
|
} else {
|
||
|
vType = ValueTypeNull
|
||
|
}
|
||
|
default:
|
||
|
vType = ValueTypeString
|
||
|
vData = string(express.Val)
|
||
|
}
|
||
|
return vType, vData
|
||
|
}
|
||
|
|
||
|
func (scope *Scope) recursiveFindX(expr sqlparserX.Expr, column string) (conditions []*ColumnCondition) {
|
||
|
var (
|
||
|
ok bool
|
||
|
andExpr *sqlparserX.AndExpr
|
||
|
orExpr *sqlparserX.OrExpr
|
||
|
parentExpr *sqlparserX.ParenExpr
|
||
|
comparisonExpr *sqlparserX.ComparisonExpr
|
||
|
rangeExpr *sqlparserX.RangeCond
|
||
|
coumnExpr *sqlparserX.ColName
|
||
|
valueExpr *sqlparserX.SQLVal
|
||
|
)
|
||
|
conditions = make([]*ColumnCondition, 0, 2)
|
||
|
if comparisonExpr, ok = expr.(*sqlparserX.ComparisonExpr); ok {
|
||
|
if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok {
|
||
|
return
|
||
|
}
|
||
|
if valueExpr, ok = comparisonExpr.Right.(*sqlparserX.SQLVal); !ok {
|
||
|
return
|
||
|
}
|
||
|
if coumnExpr.Name.EqualString(column) {
|
||
|
cond := &ColumnCondition{
|
||
|
Name: coumnExpr.Name.String(),
|
||
|
}
|
||
|
vType, vData := scope.findValueX(valueExpr)
|
||
|
switch comparisonExpr.Operator {
|
||
|
case sqlparserX.LessThanStr, sqlparserX.LessEqualStr:
|
||
|
cond.Value = newCondValue(vType, ValueOperaLess, vData)
|
||
|
case sqlparserX.GreaterThanStr, sqlparserX.GreaterEqualStr:
|
||
|
cond.Value = newCondValue(vType, ValueOperaGreater, vData)
|
||
|
case sqlparserX.EqualStr:
|
||
|
cond.Value = newCondValue(vType, ValueOperaEqual, vData)
|
||
|
}
|
||
|
if cond.Value != nil {
|
||
|
conditions = append(conditions, cond)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if rangeExpr, ok = expr.(*sqlparserX.RangeCond); ok {
|
||
|
if coumnExpr, ok = comparisonExpr.Left.(*sqlparserX.ColName); !ok {
|
||
|
return
|
||
|
}
|
||
|
if coumnExpr.Name.EqualString(column) {
|
||
|
vType := 0
|
||
|
arr := make([]any, 2)
|
||
|
if valueExpr, ok = rangeExpr.From.(*sqlparserX.SQLVal); ok {
|
||
|
vType, arr[0] = scope.findValueX(valueExpr)
|
||
|
}
|
||
|
if valueExpr, ok = rangeExpr.To.(*sqlparserX.SQLVal); ok {
|
||
|
vType, arr[1] = scope.findValueX(valueExpr)
|
||
|
}
|
||
|
conditions = append(conditions, &ColumnCondition{
|
||
|
Name: coumnExpr.Name.String(),
|
||
|
Value: newCondValue(vType, ValueOperaRange, arr),
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if andExpr, ok = expr.(*sqlparserX.AndExpr); ok {
|
||
|
if andExpr.Left != nil {
|
||
|
conditions = append(conditions, scope.recursiveFindX(andExpr.Left, column)...)
|
||
|
}
|
||
|
if andExpr.Right != nil {
|
||
|
conditions = append(conditions, scope.recursiveFindX(andExpr.Right, column)...)
|
||
|
}
|
||
|
}
|
||
|
if orExpr, ok = expr.(*sqlparserX.OrExpr); ok {
|
||
|
if orExpr.Left != nil {
|
||
|
conditions = append(conditions, scope.recursiveFindX(orExpr.Left, column)...)
|
||
|
}
|
||
|
if orExpr.Right != nil {
|
||
|
conditions = append(conditions, scope.recursiveFindX(orExpr.Right, column)...)
|
||
|
}
|
||
|
}
|
||
|
if parentExpr, ok = expr.(*sqlparserX.ParenExpr); ok {
|
||
|
if parentExpr.Expr != nil {
|
||
|
conditions = append(conditions, scope.recursiveFindX(parentExpr.Expr, column)...)
|
||
|
}
|
||
|
}
|
||
|
return conditions
|
||
|
}
|
||
|
|
||
|
func (scope *Scope) recursiveFind(expr sqlparser.Expr, column string) []*ColumnCondition {
|
||
|
var (
|
||
|
ok bool
|
||
|
identExpr *sqlparser.Ident
|
||
|
binaryExpr *sqlparser.BinaryExpr
|
||
|
parentExpr *sqlparser.ParenExpr
|
||
|
conditions []*ColumnCondition
|
||
|
)
|
||
|
conditions = make([]*ColumnCondition, 0, 2)
|
||
|
if parentExpr, ok = expr.(*sqlparser.ParenExpr); ok {
|
||
|
if parentExpr.X != nil {
|
||
|
if _, ok = parentExpr.X.(*sqlparser.BinaryExpr); ok {
|
||
|
conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...)
|
||
|
}
|
||
|
if _, ok = parentExpr.X.(*sqlparser.ParenExpr); ok {
|
||
|
conditions = append(conditions, scope.recursiveFind(parentExpr.X, column)...)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if binaryExpr, ok = expr.(*sqlparser.BinaryExpr); ok {
|
||
|
if binaryExpr.X != nil {
|
||
|
if identExpr, ok = binaryExpr.X.(*sqlparser.Ident); ok {
|
||
|
if identExpr.Name == column {
|
||
|
cond := &ColumnCondition{
|
||
|
Name: identExpr.Name,
|
||
|
}
|
||
|
vType, vData := scope.findValue(binaryExpr.Y)
|
||
|
switch binaryExpr.Op {
|
||
|
case sqlparser.LT, sqlparser.LE:
|
||
|
cond.Value = newCondValue(vType, ValueOperaLess, vData)
|
||
|
case sqlparser.GT, sqlparser.GE:
|
||
|
cond.Value = newCondValue(vType, ValueOperaGreater, vData)
|
||
|
case sqlparser.RANGE, sqlparser.BETWEEN:
|
||
|
cond.Value = newCondValue(vType, ValueOperaRange, vData)
|
||
|
case sqlparser.EQ:
|
||
|
cond.Value = newCondValue(vType, ValueOperaEqual, vData)
|
||
|
}
|
||
|
if cond.Value != nil {
|
||
|
conditions = append(conditions, cond)
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
if _, ok = binaryExpr.X.(*sqlparser.BinaryExpr); ok {
|
||
|
conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...)
|
||
|
}
|
||
|
if _, ok = binaryExpr.X.(*sqlparser.ParenExpr); ok {
|
||
|
conditions = append(conditions, scope.recursiveFind(binaryExpr.X, column)...)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if binaryExpr.Y != nil {
|
||
|
if _, ok = binaryExpr.Y.(*sqlparser.BinaryExpr); ok {
|
||
|
conditions = append(conditions, scope.recursiveFind(binaryExpr.Y, column)...)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return conditions
|
||
|
}
|
||
|
|
||
|
func (scope *Scope) DB() *gorm.DB {
|
||
|
return scope.db
|
||
|
}
|
||
|
|
||
|
func (scope *Scope) Context() context.Context {
|
||
|
return scope.db.Statement.Context
|
||
|
}
|
||
|
|
||
|
func (scope *Scope) FindCondition(column string) []*ColumnCondition {
|
||
|
if scope.stmtX != nil {
|
||
|
if scope.stmtX.Where == nil {
|
||
|
return []*ColumnCondition{}
|
||
|
}
|
||
|
return scope.recursiveFindX(scope.stmtX.Where.Expr, column)
|
||
|
}
|
||
|
return scope.recursiveFind(scope.stmt.Condition, column)
|
||
|
}
|
||
|
|
||
|
func (scope *Scope) FindColumnValues(column string) []any {
|
||
|
result := make([]any, 0)
|
||
|
conditions := scope.FindCondition(column)
|
||
|
if len(conditions) == 0 {
|
||
|
return result
|
||
|
}
|
||
|
for _, cond := range conditions {
|
||
|
if cond.Value.Opera() == ValueOperaGreater {
|
||
|
if len(result) == 0 {
|
||
|
result = make([]any, 2)
|
||
|
}
|
||
|
result[0] = cond.Value.Value()
|
||
|
}
|
||
|
|
||
|
if cond.Value.Opera() == ValueOperaLess {
|
||
|
if len(result) == 0 {
|
||
|
result = make([]any, 2)
|
||
|
}
|
||
|
result[1] = cond.Value.Value()
|
||
|
}
|
||
|
|
||
|
if cond.Value.Opera() == ValueOperaEqual {
|
||
|
result = append(result, cond.Value.Value())
|
||
|
break
|
||
|
}
|
||
|
|
||
|
if cond.Value.Opera() == ValueOperaRange {
|
||
|
if vs, ok := cond.Value.Value().([]any); ok {
|
||
|
if len(vs) == 2 {
|
||
|
if len(result) == 0 {
|
||
|
result = make([]any, 2)
|
||
|
}
|
||
|
result[0] = vs[0]
|
||
|
result[1] = vs[1]
|
||
|
}
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
return result
|
||
|
}
|