package rest import ( "context" "fmt" "git.nobla.cn/golang/rest/types" "gorm.io/gorm" "reflect" "strconv" "strings" ) type ( Query struct { db *gorm.DB condition string fields []string params []interface{} table string joins []join orderBy []string groupBy []string modelValue any limit int offset int } condition struct { Field string `json:"field"` Value interface{} `json:"value"` Expr string `json:"expr"` } join struct { Table string Direction string Conditions []*condition } ) func (cond *condition) WithExpr(v string) *condition { cond.Expr = v return cond } func (query *Query) Model() any { return query.modelValue } func (query *Query) compile() (*gorm.DB, error) { db := query.db if query.condition != "" { db = db.Where(query.condition, query.params...) } if query.fields != nil { db = db.Select(strings.Join(query.fields, ",")) } if query.table != "" { db = db.Table(query.table) } if query.joins != nil && len(query.joins) > 0 { for _, joinEntity := range query.joins { cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...) db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...) } } if query.orderBy != nil && len(query.orderBy) > 0 { db = db.Order(strings.Join(query.orderBy, ",")) } if query.groupBy != nil && len(query.groupBy) > 0 { db = db.Group(strings.Join(query.groupBy, ",")) } if query.offset > 0 { db = db.Offset(query.offset) } if query.limit > 0 { db = db.Limit(query.limit) } return db, nil } func (query *Query) decodeValue(v any) string { refVal := reflect.Indirect(reflect.ValueOf(v)) switch refVal.Kind() { case reflect.Bool: if refVal.Bool() { return "1" } else { return "0" } case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64: return strconv.FormatInt(refVal.Int(), 10) case reflect.Float32, reflect.Float64: return strconv.FormatFloat(refVal.Float(), 'f', -1, 64) case reflect.String: return "'" + refVal.String() + "'" default: return fmt.Sprint(v) } } func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) { var ( sb strings.Builder ) params = make([]interface{}, 0) for _, cond := range conditions { if filter { if isEmpty(cond.Value) { continue } } if cond.Expr == "" { cond.Expr = "=" } switch strings.ToUpper(cond.Expr) { case "=", "<>", ">", "<", ">=", "<=", "!=": if sb.Len() > 0 { sb.WriteString(" " + operator + " ") } if cond.Expr == "=" && cond.Value == nil { sb.WriteString("`" + cond.Field + "` IS NULL") } else { sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?") params = append(params, cond.Value) } case "LIKE": if sb.Len() > 0 { sb.WriteString(" " + operator + " ") } cond.Value = fmt.Sprintf("%%%s%%", cond.Value) sb.WriteString("`" + cond.Field + "` LIKE ?") params = append(params, cond.Value) case "IN": if sb.Len() > 0 { sb.WriteString(" " + operator + " ") } refVal := reflect.Indirect(reflect.ValueOf(cond.Value)) switch refVal.Kind() { case reflect.Slice, reflect.Array: ss := make([]string, refVal.Len()) for i := 0; i < refVal.Len(); i++ { ss[i] = query.decodeValue(refVal.Index(i)) } sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")") case reflect.String: sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")") default: } case "BETWEEN": refVal := reflect.ValueOf(cond.Value) if refVal.Kind() == reflect.Slice && refVal.Len() == 2 { sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?") params = append(params, refVal.Index(0), refVal.Index(1)) } } } str = sb.String() return } func (query *Query) Select(fields ...string) *Query { if query.fields == nil { query.fields = fields } else { query.fields = append(query.fields, fields...) } return query } func (query *Query) From(table string) *Query { query.table = table return query } func (query *Query) LeftJoin(table string, conditions ...*condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "LEFT", Conditions: conditions, }) return query } func (query *Query) RightJoin(table string, conditions ...*condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "RIGHT", Conditions: conditions, }) return query } func (query *Query) InnerJoin(table string, conditions ...*condition) *Query { query.joins = append(query.joins, join{ Table: table, Direction: "INNER", Conditions: conditions, }) return query } func (query *Query) AddCondition(expr, column string, val any) { if expr == "" { expr = "=" } query.AndWhere(newConditionWithOperator(expr, column, val)) } func (query *Query) AndWhere(conditions ...*condition) *Query { length := len(conditions) if length == 0 { return query } cs, ps := query.buildConditions("AND", false, conditions...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND (" + cs + ")" } return query } func (query *Query) AndFilterWhere(conditions ...*condition) *Query { length := len(conditions) if length == 0 { return query } cs, ps := query.buildConditions("AND", true, conditions...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND " + cs } return query } func (query *Query) OrWhere(conditions ...*condition) *Query { length := len(conditions) if length == 0 { return query } cs, ps := query.buildConditions("OR", false, conditions...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND (" + cs + ")" } return query } func (query *Query) OrFilterWhere(conditions ...*condition) *Query { length := len(conditions) if length == 0 { return query } cs, ps := query.buildConditions("OR", true, conditions...) if cs == "" { return query } query.params = append(query.params, ps...) if query.condition == "" { query.condition = cs } else { query.condition += " AND (" + cs + ")" } return query } func (query *Query) GroupBy(cols ...string) *Query { query.groupBy = append(query.groupBy, cols...) return query } func (query *Query) OrderBy(col, direction string) *Query { direction = strings.ToUpper(direction) if direction == "" || !(direction == "ASC" || direction == "DESC") { direction = "ASC" } query.orderBy = append(query.orderBy, col+" "+direction) return query } func (query *Query) Offset(i int) *Query { query.offset = i return query } func (query *Query) Limit(i int) *Query { query.limit = i return query } func (query *Query) ResetSelect() *Query { query.fields = make([]string, 0) return query } func (query *Query) Count(v interface{}) (i int64) { var ( db *gorm.DB err error ) if db, err = query.compile(); err != nil { return } else { if v != nil { refVal := reflect.ValueOf(v) switch refVal.Kind() { case reflect.String: if query.table == "" { err = db.Table(refVal.String()).Count(&i).Error } else { err = db.Table(query.table).Count(&i).Error } default: //如果是报表的模型,这的话手动构建一条SQL语句 if reporter, ok := v.(types.Reporter); ok { sqlRes := &sqlCountResponse{} childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true) db.WithContext(childCtx).Model(reporter).First(sqlRes) i = sqlRes.Count } else { err = db.Model(v).Count(&i).Error } } } } return } func (query *Query) One(v interface{}) (err error) { var ( db *gorm.DB ) if db, err = query.compile(); err != nil { return } else { err = db.First(v).Error } return } func (query *Query) All(v interface{}) (err error) { var ( db *gorm.DB ) if db, err = query.compile(); err != nil { return } else { err = db.Find(v).Error } return } func NewCondition(column, opera string, value any) *condition { if opera == "" { opera = "=" } return &condition{ Field: column, Value: value, Expr: opera, } } func newCondition(field string, value interface{}) *condition { return &condition{ Field: field, Value: value, Expr: "=", } } func newConditionWithOperator(operator, field string, value interface{}) *condition { cond := &condition{ Field: field, Value: value, Expr: operator, } return cond } func NewQuery(db *gorm.DB, model any) *Query { return &Query{ db: db, modelValue: model, params: make([]interface{}, 0), orderBy: make([]string, 0), groupBy: make([]string, 0), joins: make([]join, 0), } }