From 44e6e2b34f56c66e1c9c2f5b6e487d7a05bd01a7 Mon Sep 17 00:00:00 2001 From: Yavolte Date: Fri, 11 Apr 2025 14:12:23 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=90=9C=E7=B4=A2=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=A4=9A=E9=80=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- condition.go | 22 ++++++++--- go.mod | 1 + go.sum | 2 + query.go | 102 +++++++++++++++++++++++++++++++++++++++----------- query_test.go | 79 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 179 insertions(+), 27 deletions(-) create mode 100644 query_test.go diff --git a/condition.go b/condition.go index 597c367..dd68330 100644 --- a/condition.go +++ b/condition.go @@ -3,10 +3,11 @@ package rest import ( "context" "encoding/json" - "git.nobla.cn/golang/kos/util/arrays" - "git.nobla.cn/golang/rest/types" "net/http" "strings" + + "git.nobla.cn/golang/kos/util/arrays" + "git.nobla.cn/golang/rest/types" ) func findCondition(schema *types.Schema, conditions []*types.Condition) *types.Condition { @@ -81,13 +82,24 @@ func BuildConditions(ctx context.Context, r *http.Request, query *Query, schemas if row.Native == 0 { continue } + if row.Format == "multiSelect" { + columnName := row.Column + "[]" + if qs.Has(columnName) && len(qs[columnName]) > 0 { + query.AndFilterWhere(newConditionWithOperator("IN", row.Column, qs[columnName])) + continue + } + } formValue = qs.Get(row.Column) switch row.Format { case types.FormatString, types.FormatText: - if row.Attribute.Match == types.MatchExactly { - query.AndFilterWhere(newCondition(row.Column, formValue)) + if len(qs[row.Column]) > 0 { + query.AndFilterWhere(newConditionWithOperator("IN", row.Column, qs[row.Column])) } else { - query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE")) + if row.Attribute.Match == types.MatchExactly { + query.AndFilterWhere(newCondition(row.Column, formValue)) + } else { + query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE")) + } } case types.FormatTime, types.FormatDate, types.FormatDatetime, types.FormatTimestamp: var sep string diff --git a/go.mod b/go.mod index f2bb69f..e50f9f3 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect golang.org/x/crypto v0.19.0 // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/sys v0.17.0 // indirect diff --git a/go.sum b/go.sum index 4446986..a9683b0 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/longbridgeapp/sqlparser v0.3.2 h1:FV0dgMiv8VcksT3p10hJeqfPs8bodoehmUJ7MhBds+Y= github.com/longbridgeapp/sqlparser v0.3.2/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= diff --git a/query.go b/query.go index b82267d..6171f62 100644 --- a/query.go +++ b/query.go @@ -3,11 +3,13 @@ package rest import ( "context" "fmt" - "git.nobla.cn/golang/rest/types" - "gorm.io/gorm" "reflect" "strconv" "strings" + "unicode" + + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" ) type ( @@ -79,7 +81,35 @@ func (query *Query) compile() (*gorm.DB, error) { return db, nil } -func (query *Query) decodeValue(v any) string { +// quoteName 编码数据库字段 +func (quote *Query) quoteName(name string) string { + if len(name) == 0 { + return name + } + name = strings.Map(func(r rune) rune { + if unicode.IsControl(r) { + return -1 + } + if unicode.IsSpace(r) { + return -1 + } + return r + }, name) + isQuoted := len(name) >= 2 && + name[0] == '`' && + name[len(name)-1] == '`' + if !isQuoted { + var b strings.Builder + b.Grow(len(name) + 2) + b.WriteByte('`') + b.WriteString(name) + b.WriteByte('`') + return b.String() + } + return name +} + +func (query *Query) quoteValue(v any) string { refVal := reflect.Indirect(reflect.ValueOf(v)) switch refVal.Kind() { case reflect.Bool: @@ -88,13 +118,18 @@ func (query *Query) decodeValue(v any) string { } else { return "0" } - case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64: + case reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(refVal.Uint(), 10) + case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64: return strconv.FormatInt(refVal.Int(), 10) case reflect.Float32, reflect.Float64: return strconv.FormatFloat(refVal.Float(), 'f', -1, 64) case reflect.String: - return "'" + refVal.String() + "'" + return strconv.Quote(refVal.String()) default: + if v == nil { + return "IS NULL" + } return fmt.Sprint(v) } } @@ -103,7 +138,7 @@ func (query *Query) buildConditions(operator string, filter bool, conditions ... var ( sb strings.Builder ) - params = make([]interface{}, 0) + params = make([]any, 0) for _, cond := range conditions { if filter { if isEmpty(cond.Value) { @@ -116,42 +151,65 @@ func (query *Query) buildConditions(operator string, filter bool, conditions ... switch strings.ToUpper(cond.Expr) { case "=", "<>", ">", "<", ">=", "<=", "!=": if sb.Len() > 0 { - sb.WriteString(" " + operator + " ") + sb.WriteString(" ") + sb.WriteString(operator) + sb.WriteString(" ") } if cond.Expr == "=" && cond.Value == nil { - sb.WriteString("`" + cond.Field + "` IS NULL") + sb.WriteString(query.quoteName(cond.Field)) + sb.WriteString(" IS NULL") } else { - sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?") + sb.WriteString(query.quoteName(cond.Field)) + sb.WriteString(" ") + sb.WriteString(cond.Expr) + sb.WriteString(" ?") params = append(params, cond.Value) } case "LIKE": if sb.Len() > 0 { - sb.WriteString(" " + operator + " ") + sb.WriteString(" ") + sb.WriteString(operator) + sb.WriteString(" ") } cond.Value = fmt.Sprintf("%%%s%%", cond.Value) - sb.WriteString("`" + cond.Field + "` LIKE ?") + sb.WriteString(query.quoteName(cond.Field)) + sb.WriteString(" LIKE ?") params = append(params, cond.Value) case "IN": if sb.Len() > 0 { - sb.WriteString(" " + operator + " ") + sb.WriteString(" ") + sb.WriteString(operator) + sb.WriteString(" ") } refVal := reflect.Indirect(reflect.ValueOf(cond.Value)) switch refVal.Kind() { case reflect.Slice, reflect.Array: ss := make([]string, refVal.Len()) - for i := 0; i < refVal.Len(); i++ { - ss[i] = query.decodeValue(refVal.Index(i)) + for i := range refVal.Len() { + ss[i] = query.quoteValue(refVal.Index(i).Interface()) } - sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")") + sb.WriteString(query.quoteName(cond.Field)) + sb.WriteString(" IN (") + sb.WriteString(strings.Join(ss, ",")) + sb.WriteString(") ") case reflect.String: - sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")") + sb.WriteString(query.quoteName(cond.Field)) + sb.WriteString(" IN (") + sb.WriteString(refVal.String()) + sb.WriteString(")") default: } case "BETWEEN": + if sb.Len() > 0 { + sb.WriteString(" ") + sb.WriteString(operator) + sb.WriteString(" ") + } refVal := reflect.ValueOf(cond.Value) if refVal.Kind() == reflect.Slice && refVal.Len() == 2 { - sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?") - params = append(params, refVal.Index(0), refVal.Index(1)) + sb.WriteString(query.quoteName(cond.Field)) + sb.WriteString(" BETWEEN ? AND ?") + params = append(params, refVal.Index(0).Interface(), refVal.Index(1).Interface()) } } } @@ -341,7 +399,7 @@ func (query *Query) Count(v interface{}) (i int64) { return } -func (query *Query) One(v interface{}) (err error) { +func (query *Query) One(v any) (err error) { var ( db *gorm.DB ) @@ -353,7 +411,7 @@ func (query *Query) One(v interface{}) (err error) { return } -func (query *Query) All(v interface{}) (err error) { +func (query *Query) All(v any) (err error) { var ( db *gorm.DB ) @@ -376,7 +434,7 @@ func NewCondition(column, opera string, value any) *condition { } } -func newCondition(field string, value interface{}) *condition { +func newCondition(field string, value any) *condition { return &condition{ Field: field, Value: value, @@ -384,7 +442,7 @@ func newCondition(field string, value interface{}) *condition { } } -func newConditionWithOperator(operator, field string, value interface{}) *condition { +func newConditionWithOperator(operator, field string, value any) *condition { cond := &condition{ Field: field, Value: value, diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..86ec113 --- /dev/null +++ b/query_test.go @@ -0,0 +1,79 @@ +package rest + +import ( + "testing" +) + +func TestQuoteName(t *testing.T) { + q := &Query{} + + tests := []struct { + name string + input string + expected string + }{ + {"Empty string", "", ""}, + {"Control chars", "a\x00b\nc", "`abc`"}, // \x00 and \n should be filtered + {"Spaces", " test name ", "`testname`"}, // Spaces should be filtered + {"Properly quoted", "`valid`", "`valid`"}, // Already quoted + {"Left quote only", "`invalid", "``invalid`"}, // Add missing right quote + {"Normal unquoted", "normal", "`normal`"}, // Add quotes + {"All filtered", "\t\r\n", "``"}, // Filter all characters + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := q.quoteName(tt.input) + if got != tt.expected { + t.Errorf("quoteName(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + +func TestQuery_quoteValue(t *testing.T) { + type testCase struct { + name string + input any + expected string + } + + tests := []testCase{ + // Boolean values + {"bool true", true, "1"}, + {"bool false", false, "0"}, + {"bool pointer", new(bool), "0"}, // *bool with zero value + + // Integer family + {"int", 42, "42"}, + {"int8", int8(127), "127"}, + {"uint", uint(100), "100"}, + {"uint64", uint64(1<<64 - 1), "18446744073709551615"}, + + // Floating points + {"float64", 3.14, "3.14"}, + {"float64 scientific", 1e10, "10000000000"}, + {"float32", float32(1.5), "1.5"}, + + // Strings + {"simple string", "hello", `"hello"`}, + {"string with quotes", `"quoted"`, `"\"quoted\""`}, + {"string with newline", "line\nbreak", `"line\nbreak"`}, + + // Default cases + {"struct", struct{}{}, "{}"}, + {"slice", []int{1, 2}, "[1 2]"}, + {"nil", nil, "IS NULL"}, + } + + q := &Query{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := q.quoteValue(tt.input) + if got != tt.expected { + t.Errorf("quoteValue(%v) = %v, want %v", tt.input, got, tt.expected) + } + }) + } +}