修改搜索支持多选
This commit is contained in:
parent
72b0de9c26
commit
44e6e2b34f
16
condition.go
16
condition.go
|
@ -3,10 +3,11 @@ package rest
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"git.nobla.cn/golang/kos/util/arrays"
|
|
||||||
"git.nobla.cn/golang/rest/types"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"git.nobla.cn/golang/kos/util/arrays"
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func findCondition(schema *types.Schema, conditions []*types.Condition) *types.Condition {
|
func findCondition(schema *types.Schema, conditions []*types.Condition) *types.Condition {
|
||||||
|
@ -81,14 +82,25 @@ func BuildConditions(ctx context.Context, r *http.Request, query *Query, schemas
|
||||||
if row.Native == 0 {
|
if row.Native == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if row.Format == "multiSelect" {
|
||||||
|
columnName := row.Column + "[]"
|
||||||
|
if qs.Has(columnName) && len(qs[columnName]) > 0 {
|
||||||
|
query.AndFilterWhere(newConditionWithOperator("IN", row.Column, qs[columnName]))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
formValue = qs.Get(row.Column)
|
formValue = qs.Get(row.Column)
|
||||||
switch row.Format {
|
switch row.Format {
|
||||||
case types.FormatString, types.FormatText:
|
case types.FormatString, types.FormatText:
|
||||||
|
if len(qs[row.Column]) > 0 {
|
||||||
|
query.AndFilterWhere(newConditionWithOperator("IN", row.Column, qs[row.Column]))
|
||||||
|
} else {
|
||||||
if row.Attribute.Match == types.MatchExactly {
|
if row.Attribute.Match == types.MatchExactly {
|
||||||
query.AndFilterWhere(newCondition(row.Column, formValue))
|
query.AndFilterWhere(newCondition(row.Column, formValue))
|
||||||
} else {
|
} else {
|
||||||
query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE"))
|
query.AndFilterWhere(newCondition(row.Column, formValue).WithExpr("LIKE"))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
case types.FormatTime, types.FormatDate, types.FormatDatetime, types.FormatTimestamp:
|
case types.FormatTime, types.FormatDate, types.FormatDatetime, types.FormatTimestamp:
|
||||||
var sep string
|
var sep string
|
||||||
seps := []byte{',', '/'}
|
seps := []byte{',', '/'}
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -19,6 +19,7 @@ require (
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||||
golang.org/x/crypto v0.19.0 // indirect
|
golang.org/x/crypto v0.19.0 // indirect
|
||||||
golang.org/x/net v0.21.0 // indirect
|
golang.org/x/net v0.21.0 // indirect
|
||||||
golang.org/x/sys v0.17.0 // indirect
|
golang.org/x/sys v0.17.0 // indirect
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -24,6 +24,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
github.com/longbridgeapp/sqlparser v0.3.2 h1:FV0dgMiv8VcksT3p10hJeqfPs8bodoehmUJ7MhBds+Y=
|
github.com/longbridgeapp/sqlparser v0.3.2 h1:FV0dgMiv8VcksT3p10hJeqfPs8bodoehmUJ7MhBds+Y=
|
||||||
github.com/longbridgeapp/sqlparser v0.3.2/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4=
|
github.com/longbridgeapp/sqlparser v0.3.2/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
|
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
|
||||||
|
|
102
query.go
102
query.go
|
@ -3,11 +3,13 @@ package rest
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"git.nobla.cn/golang/rest/types"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"git.nobla.cn/golang/rest/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -79,7 +81,35 @@ func (query *Query) compile() (*gorm.DB, error) {
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (query *Query) decodeValue(v any) string {
|
// quoteName 编码数据库字段
|
||||||
|
func (quote *Query) quoteName(name string) string {
|
||||||
|
if len(name) == 0 {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
name = strings.Map(func(r rune) rune {
|
||||||
|
if unicode.IsControl(r) {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if unicode.IsSpace(r) {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}, name)
|
||||||
|
isQuoted := len(name) >= 2 &&
|
||||||
|
name[0] == '`' &&
|
||||||
|
name[len(name)-1] == '`'
|
||||||
|
if !isQuoted {
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(name) + 2)
|
||||||
|
b.WriteByte('`')
|
||||||
|
b.WriteString(name)
|
||||||
|
b.WriteByte('`')
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query) quoteValue(v any) string {
|
||||||
refVal := reflect.Indirect(reflect.ValueOf(v))
|
refVal := reflect.Indirect(reflect.ValueOf(v))
|
||||||
switch refVal.Kind() {
|
switch refVal.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
|
@ -88,13 +118,18 @@ func (query *Query) decodeValue(v any) string {
|
||||||
} else {
|
} else {
|
||||||
return "0"
|
return "0"
|
||||||
}
|
}
|
||||||
case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||||
|
return strconv.FormatUint(refVal.Uint(), 10)
|
||||||
|
case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64:
|
||||||
return strconv.FormatInt(refVal.Int(), 10)
|
return strconv.FormatInt(refVal.Int(), 10)
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
|
return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
return "'" + refVal.String() + "'"
|
return strconv.Quote(refVal.String())
|
||||||
default:
|
default:
|
||||||
|
if v == nil {
|
||||||
|
return "IS NULL"
|
||||||
|
}
|
||||||
return fmt.Sprint(v)
|
return fmt.Sprint(v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -103,7 +138,7 @@ func (query *Query) buildConditions(operator string, filter bool, conditions ...
|
||||||
var (
|
var (
|
||||||
sb strings.Builder
|
sb strings.Builder
|
||||||
)
|
)
|
||||||
params = make([]interface{}, 0)
|
params = make([]any, 0)
|
||||||
for _, cond := range conditions {
|
for _, cond := range conditions {
|
||||||
if filter {
|
if filter {
|
||||||
if isEmpty(cond.Value) {
|
if isEmpty(cond.Value) {
|
||||||
|
@ -116,42 +151,65 @@ func (query *Query) buildConditions(operator string, filter bool, conditions ...
|
||||||
switch strings.ToUpper(cond.Expr) {
|
switch strings.ToUpper(cond.Expr) {
|
||||||
case "=", "<>", ">", "<", ">=", "<=", "!=":
|
case "=", "<>", ">", "<", ">=", "<=", "!=":
|
||||||
if sb.Len() > 0 {
|
if sb.Len() > 0 {
|
||||||
sb.WriteString(" " + operator + " ")
|
sb.WriteString(" ")
|
||||||
|
sb.WriteString(operator)
|
||||||
|
sb.WriteString(" ")
|
||||||
}
|
}
|
||||||
if cond.Expr == "=" && cond.Value == nil {
|
if cond.Expr == "=" && cond.Value == nil {
|
||||||
sb.WriteString("`" + cond.Field + "` IS NULL")
|
sb.WriteString(query.quoteName(cond.Field))
|
||||||
|
sb.WriteString(" IS NULL")
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?")
|
sb.WriteString(query.quoteName(cond.Field))
|
||||||
|
sb.WriteString(" ")
|
||||||
|
sb.WriteString(cond.Expr)
|
||||||
|
sb.WriteString(" ?")
|
||||||
params = append(params, cond.Value)
|
params = append(params, cond.Value)
|
||||||
}
|
}
|
||||||
case "LIKE":
|
case "LIKE":
|
||||||
if sb.Len() > 0 {
|
if sb.Len() > 0 {
|
||||||
sb.WriteString(" " + operator + " ")
|
sb.WriteString(" ")
|
||||||
|
sb.WriteString(operator)
|
||||||
|
sb.WriteString(" ")
|
||||||
}
|
}
|
||||||
cond.Value = fmt.Sprintf("%%%s%%", cond.Value)
|
cond.Value = fmt.Sprintf("%%%s%%", cond.Value)
|
||||||
sb.WriteString("`" + cond.Field + "` LIKE ?")
|
sb.WriteString(query.quoteName(cond.Field))
|
||||||
|
sb.WriteString(" LIKE ?")
|
||||||
params = append(params, cond.Value)
|
params = append(params, cond.Value)
|
||||||
case "IN":
|
case "IN":
|
||||||
if sb.Len() > 0 {
|
if sb.Len() > 0 {
|
||||||
sb.WriteString(" " + operator + " ")
|
sb.WriteString(" ")
|
||||||
|
sb.WriteString(operator)
|
||||||
|
sb.WriteString(" ")
|
||||||
}
|
}
|
||||||
refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
|
refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
|
||||||
switch refVal.Kind() {
|
switch refVal.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
ss := make([]string, refVal.Len())
|
ss := make([]string, refVal.Len())
|
||||||
for i := 0; i < refVal.Len(); i++ {
|
for i := range refVal.Len() {
|
||||||
ss[i] = query.decodeValue(refVal.Index(i))
|
ss[i] = query.quoteValue(refVal.Index(i).Interface())
|
||||||
}
|
}
|
||||||
sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")")
|
sb.WriteString(query.quoteName(cond.Field))
|
||||||
|
sb.WriteString(" IN (")
|
||||||
|
sb.WriteString(strings.Join(ss, ","))
|
||||||
|
sb.WriteString(") ")
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")")
|
sb.WriteString(query.quoteName(cond.Field))
|
||||||
|
sb.WriteString(" IN (")
|
||||||
|
sb.WriteString(refVal.String())
|
||||||
|
sb.WriteString(")")
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
case "BETWEEN":
|
case "BETWEEN":
|
||||||
|
if sb.Len() > 0 {
|
||||||
|
sb.WriteString(" ")
|
||||||
|
sb.WriteString(operator)
|
||||||
|
sb.WriteString(" ")
|
||||||
|
}
|
||||||
refVal := reflect.ValueOf(cond.Value)
|
refVal := reflect.ValueOf(cond.Value)
|
||||||
if refVal.Kind() == reflect.Slice && refVal.Len() == 2 {
|
if refVal.Kind() == reflect.Slice && refVal.Len() == 2 {
|
||||||
sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?")
|
sb.WriteString(query.quoteName(cond.Field))
|
||||||
params = append(params, refVal.Index(0), refVal.Index(1))
|
sb.WriteString(" BETWEEN ? AND ?")
|
||||||
|
params = append(params, refVal.Index(0).Interface(), refVal.Index(1).Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -341,7 +399,7 @@ func (query *Query) Count(v interface{}) (i int64) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (query *Query) One(v interface{}) (err error) {
|
func (query *Query) One(v any) (err error) {
|
||||||
var (
|
var (
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
)
|
)
|
||||||
|
@ -353,7 +411,7 @@ func (query *Query) One(v interface{}) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (query *Query) All(v interface{}) (err error) {
|
func (query *Query) All(v any) (err error) {
|
||||||
var (
|
var (
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
)
|
)
|
||||||
|
@ -376,7 +434,7 @@ func NewCondition(column, opera string, value any) *condition {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCondition(field string, value interface{}) *condition {
|
func newCondition(field string, value any) *condition {
|
||||||
return &condition{
|
return &condition{
|
||||||
Field: field,
|
Field: field,
|
||||||
Value: value,
|
Value: value,
|
||||||
|
@ -384,7 +442,7 @@ func newCondition(field string, value interface{}) *condition {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConditionWithOperator(operator, field string, value interface{}) *condition {
|
func newConditionWithOperator(operator, field string, value any) *condition {
|
||||||
cond := &condition{
|
cond := &condition{
|
||||||
Field: field,
|
Field: field,
|
||||||
Value: value,
|
Value: value,
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQuoteName(t *testing.T) {
|
||||||
|
q := &Query{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"Empty string", "", ""},
|
||||||
|
{"Control chars", "a\x00b\nc", "`abc`"}, // \x00 and \n should be filtered
|
||||||
|
{"Spaces", " test name ", "`testname`"}, // Spaces should be filtered
|
||||||
|
{"Properly quoted", "`valid`", "`valid`"}, // Already quoted
|
||||||
|
{"Left quote only", "`invalid", "``invalid`"}, // Add missing right quote
|
||||||
|
{"Normal unquoted", "normal", "`normal`"}, // Add quotes
|
||||||
|
{"All filtered", "\t\r\n", "``"}, // Filter all characters
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := q.quoteName(tt.input)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("quoteName(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery_quoteValue(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
input any
|
||||||
|
expected string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
// Boolean values
|
||||||
|
{"bool true", true, "1"},
|
||||||
|
{"bool false", false, "0"},
|
||||||
|
{"bool pointer", new(bool), "0"}, // *bool with zero value
|
||||||
|
|
||||||
|
// Integer family
|
||||||
|
{"int", 42, "42"},
|
||||||
|
{"int8", int8(127), "127"},
|
||||||
|
{"uint", uint(100), "100"},
|
||||||
|
{"uint64", uint64(1<<64 - 1), "18446744073709551615"},
|
||||||
|
|
||||||
|
// Floating points
|
||||||
|
{"float64", 3.14, "3.14"},
|
||||||
|
{"float64 scientific", 1e10, "10000000000"},
|
||||||
|
{"float32", float32(1.5), "1.5"},
|
||||||
|
|
||||||
|
// Strings
|
||||||
|
{"simple string", "hello", `"hello"`},
|
||||||
|
{"string with quotes", `"quoted"`, `"\"quoted\""`},
|
||||||
|
{"string with newline", "line\nbreak", `"line\nbreak"`},
|
||||||
|
|
||||||
|
// Default cases
|
||||||
|
{"struct", struct{}{}, "{}"},
|
||||||
|
{"slice", []int{1, 2}, "[1 2]"},
|
||||||
|
{"nil", nil, "IS NULL"},
|
||||||
|
}
|
||||||
|
|
||||||
|
q := &Query{}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := q.quoteValue(tt.input)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("quoteValue(%v) = %v, want %v", tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue