// rest/plugins/sharding/scope.go
//
// 288 lines
// 7.6 KiB
// Go
// Raw Permalink Normal View History
//
// 2024-12-11 17:29:01 +08:00

package sharding
import (
"context"
"github.com/longbridgeapp/sqlparser"
sqlparserX "github.com/uole/sqlparser"
"gorm.io/gorm"
"strconv"
)
// Scope holds what is needed to extract sharding-key predicates from one
// query: the gorm handle (bind variables and context) and the parsed SELECT.
//
// The query is held in one of two parser representations. FindCondition
// searches stmtX when it is non-nil and otherwise falls back to stmt.
type Scope struct {
	// db supplies Statement.Vars (bind values) and Statement.Context.
	db *gorm.DB
	// stmt is the github.com/longbridgeapp/sqlparser AST; used when stmtX is nil.
	stmt *sqlparser.SelectStatement
	// stmtX is the github.com/uole/sqlparser AST; preferred when set.
	stmtX *sqlparserX.Select
}
// findValue converts a value expression from the longbridgeapp/sqlparser AST
// into a (value type, value) pair:
//
//   - bind parameter: the db.Statement.Vars entry at the bind position
//     (ValueTypeAny), or ValueTypeNull with a nil value when the position is
//     past the end of Vars;
//   - number, string, blob or bool literal: the literal's Value field;
//   - NULL: ValueTypeNull with a nil value;
//   - Range: []any{low, high}, typed by whatever the upper bound resolved to.
//
// Any other expression yields (0, nil).
func (scope *Scope) findValue(express sqlparser.Expr) (int, any) {
	switch node := express.(type) {
	case *sqlparser.BindExpr:
		if vars := scope.db.Statement.Vars; node.Pos < len(vars) {
			return ValueTypeAny, vars[node.Pos]
		}
		return ValueTypeNull, nil
	case *sqlparser.NumberLit:
		return ValueTypeNumber, node.Value
	case *sqlparser.StringLit:
		return ValueTypeString, node.Value
	case *sqlparser.BlobLit:
		return ValueTypeString, node.Value
	case *sqlparser.BoolLit:
		return ValueTypeBoolean, node.Value
	case *sqlparser.NullLit:
		return ValueTypeNull, nil
	case *sqlparser.Range:
		// The lower bound's type is discarded; the range is typed by its upper bound.
		_, low := scope.findValue(node.X)
		vType, high := scope.findValue(node.Y)
		return vType, []any{low, high}
	default:
		return 0, nil
	}
}
// findValueX converts a literal or bind argument from the uole/sqlparser AST
// into a (value type, value) pair:
//
//   - IntVal: ValueTypeNumber with an int (strconv.Atoi);
//   - FloatVal: ValueTypeNumber with a float64 (strconv.ParseFloat);
//   - ValArg: ValueTypeAny with the referenced db.Statement.Vars entry, or
//     ValueTypeNull with a nil value when the argument cannot be resolved;
//   - anything else: ValueTypeString with the literal text.
//
// A numeric literal that strconv rejects (for example an integer outside the
// int range) is returned as ValueTypeString with its literal text. This keeps
// the exact value instead of a clamped or zero number that would select the
// wrong shard.
func (scope *Scope) findValueX(express *sqlparserX.SQLVal) (int, any) {
	raw := string(express.Val)
	switch express.Type {
	case sqlparserX.IntVal:
		if n, err := strconv.Atoi(raw); err == nil {
			return ValueTypeNumber, n
		}
	case sqlparserX.FloatVal:
		if f, err := strconv.ParseFloat(raw, 64); err == nil {
			return ValueTypeNumber, f
		}
	case sqlparserX.ValArg:
		// The placeholder text is treated as a two-character prefix followed by
		// a 1-based position (presumably ":v1", ":v2", ... as the parser emits
		// for `?` — TODO confirm). Anything that does not resolve to an
		// existing bind variable is reported as NULL.
		if len(raw) > 2 {
			vars := scope.db.Statement.Vars
			if pos, err := strconv.Atoi(raw[2:]); err == nil && pos >= 1 && pos <= len(vars) {
				return ValueTypeAny, vars[pos-1]
			}
		}
		return ValueTypeNull, nil
	}
	return ValueTypeString, raw
}
// recursiveFindX walks a WHERE expression parsed by github.com/uole/sqlparser
// and collects the predicates that constrain the given column.
//
// Recognised predicates:
//   - column <op> value, where value is a *sqlparserX.SQLVal (literal or bind
//     argument) and <op> is < or <= (ValueOperaLess), > or >=
//     (ValueOperaGreater), or = (ValueOperaEqual). Other operators are ignored.
//   - column BETWEEN from AND to, reported as ValueOperaRange with a
//     two-element []any{from, to}; a bound that is not a SQLVal stays nil.
//     NOT BETWEEN is ignored because it does not bound the column.
//
// AND, OR and parenthesised sub-expressions are searched recursively. OR
// branches are flattened together with AND branches, so the result does not
// record which connective joined the predicates. Column names are matched
// with ColIdent.EqualString.
func (scope *Scope) recursiveFindX(expr sqlparserX.Expr, column string) (conditions []*ColumnCondition) {
	conditions = make([]*ColumnCondition, 0, 2)
	switch e := expr.(type) {
	case *sqlparserX.ComparisonExpr:
		col, ok := e.Left.(*sqlparserX.ColName)
		if !ok {
			return conditions
		}
		val, ok := e.Right.(*sqlparserX.SQLVal)
		if !ok || !col.Name.EqualString(column) {
			return conditions
		}
		cond := &ColumnCondition{Name: col.Name.String()}
		vType, vData := scope.findValueX(val)
		switch e.Operator {
		case sqlparserX.LessThanStr, sqlparserX.LessEqualStr:
			cond.Value = newCondValue(vType, ValueOperaLess, vData)
		case sqlparserX.GreaterThanStr, sqlparserX.GreaterEqualStr:
			cond.Value = newCondValue(vType, ValueOperaGreater, vData)
		case sqlparserX.EqualStr:
			cond.Value = newCondValue(vType, ValueOperaEqual, vData)
		}
		if cond.Value != nil {
			conditions = append(conditions, cond)
		}
	case *sqlparserX.RangeCond:
		if e.Operator == sqlparserX.NotBetweenStr {
			return conditions
		}
		// The column lives on the RangeCond itself; there is no comparison
		// expression in this branch.
		col, ok := e.Left.(*sqlparserX.ColName)
		if !ok || !col.Name.EqualString(column) {
			return conditions
		}
		vType := 0
		bounds := make([]any, 2)
		if from, ok := e.From.(*sqlparserX.SQLVal); ok {
			vType, bounds[0] = scope.findValueX(from)
		}
		if to, ok := e.To.(*sqlparserX.SQLVal); ok {
			vType, bounds[1] = scope.findValueX(to)
		}
		conditions = append(conditions, &ColumnCondition{
			Name:  col.Name.String(),
			Value: newCondValue(vType, ValueOperaRange, bounds),
		})
	case *sqlparserX.AndExpr:
		// A nil operand falls through the type switch and contributes nothing.
		conditions = append(conditions, scope.recursiveFindX(e.Left, column)...)
		conditions = append(conditions, scope.recursiveFindX(e.Right, column)...)
	case *sqlparserX.OrExpr:
		conditions = append(conditions, scope.recursiveFindX(e.Left, column)...)
		conditions = append(conditions, scope.recursiveFindX(e.Right, column)...)
	case *sqlparserX.ParenExpr:
		conditions = append(conditions, scope.recursiveFindX(e.Expr, column)...)
	}
	return conditions
}
// recursiveFind walks a WHERE expression parsed by
// github.com/longbridgeapp/sqlparser and collects the predicates that
// constrain the given column.
//
// A predicate is recognised when a *sqlparser.BinaryExpr has an identifier
// whose Name equals column (exact, case-sensitive) on its left-hand side. The
// operator is mapped as follows:
//   - LT, LE       -> ValueOperaLess
//   - GT, GE       -> ValueOperaGreater
//   - RANGE, BETWEEN -> ValueOperaRange
//   - EQ           -> ValueOperaEqual
//
// The value comes from findValue applied to the right-hand side. Other
// operators produce no condition.
//
// Any other binary expression (e.g. AND / OR) and parenthesised
// sub-expressions are searched recursively on both sides. OR branches are
// flattened together with AND branches.
func (scope *Scope) recursiveFind(expr sqlparser.Expr, column string) []*ColumnCondition {
	conditions := make([]*ColumnCondition, 0, 2)
	switch e := expr.(type) {
	case *sqlparser.ParenExpr:
		conditions = append(conditions, scope.recursiveFind(e.X, column)...)
	case *sqlparser.BinaryExpr:
		if ident, ok := e.X.(*sqlparser.Ident); ok {
			if ident.Name == column {
				cond := &ColumnCondition{Name: ident.Name}
				vType, vData := scope.findValue(e.Y)
				switch e.Op {
				case sqlparser.LT, sqlparser.LE:
					cond.Value = newCondValue(vType, ValueOperaLess, vData)
				case sqlparser.GT, sqlparser.GE:
					cond.Value = newCondValue(vType, ValueOperaGreater, vData)
				case sqlparser.RANGE, sqlparser.BETWEEN:
					cond.Value = newCondValue(vType, ValueOperaRange, vData)
				case sqlparser.EQ:
					cond.Value = newCondValue(vType, ValueOperaEqual, vData)
				}
				if cond.Value != nil {
					conditions = append(conditions, cond)
				}
			}
		} else {
			conditions = append(conditions, scope.recursiveFind(e.X, column)...)
		}
		// Search the right operand as well, including a parenthesised group
		// such as `a = 1 AND (id = 2)`. Literal and other leaf operands fall
		// through the type switch and contribute nothing.
		conditions = append(conditions, scope.recursiveFind(e.Y, column)...)
	}
	return conditions
}
// DB exposes the *gorm.DB whose statement this Scope inspects.
func (scope *Scope) DB() *gorm.DB { return scope.db }

// Context returns the context.Context stored on the gorm statement
// (db.Statement.Context).
func (scope *Scope) Context() context.Context {
	stmt := scope.DB().Statement
	return stmt.Context
}
// FindCondition returns the WHERE-clause predicates that constrain column.
//
// The uole/sqlparser statement (stmtX) is searched when it is set; otherwise
// the longbridgeapp/sqlparser statement (stmt) is searched. A missing WHERE
// clause, or a Scope holding neither statement, yields an empty, non-nil slice.
func (scope *Scope) FindCondition(column string) []*ColumnCondition {
	switch {
	case scope.stmtX != nil:
		if scope.stmtX.Where == nil {
			return []*ColumnCondition{}
		}
		return scope.recursiveFindX(scope.stmtX.Where.Expr, column)
	case scope.stmt != nil:
		return scope.recursiveFind(scope.stmt.Condition, column)
	default:
		// No parsed statement: report "no conditions" rather than
		// dereferencing a nil stmt.
		return []*ColumnCondition{}
	}
}
// FindColumnValues reduces the predicates that FindCondition reports for
// column into the values used to choose target shards.
//
// The returned slice has one of these shapes:
//   - empty: no usable predicate on column was found;
//   - [v]: an equality (column = v) was found. It takes precedence over any
//     bounds collected from earlier predicates;
//   - [lower, upper]: bounds gathered from >/>= (lower), </<= (upper) or a
//     range predicate (both). A bound that was never seen is left nil, and a
//     later bound of the same kind overwrites an earlier one.
//
// Conditions are consumed in the order FindCondition returns them. Scanning
// stops at the first equality or range predicate.
func (scope *Scope) FindColumnValues(column string) []any {
	result := make([]any, 0)
	for _, cond := range scope.FindCondition(column) {
		switch cond.Value.Opera() {
		case ValueOperaEqual:
			// Discard any partial bounds so callers never receive a mixed
			// [lower, upper, value] slice.
			return []any{cond.Value.Value()}
		case ValueOperaRange:
			if vs, ok := cond.Value.Value().([]any); ok && len(vs) == 2 {
				return []any{vs[0], vs[1]}
			}
			// Malformed range value: stop scanning and keep what we have.
			return result
		case ValueOperaGreater:
			if len(result) == 0 {
				result = make([]any, 2)
			}
			result[0] = cond.Value.Value()
		case ValueOperaLess:
			if len(result) == 0 {
				result = make([]any, 2)
			}
			result[1] = cond.Value.Value()
		}
	}
	return result
}