rest/query.go

406 lines
8.9 KiB
Go
Raw Permalink Normal View History

2024-12-11 17:29:01 +08:00
package rest
import (
"context"
"fmt"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"reflect"
"strconv"
"strings"
)
// Query accumulates the parts of a SELECT statement and applies them to a
// *gorm.DB in compile.
type Query struct {
	db         *gorm.DB      // base handle that compile chains clauses onto
	condition  string        // WHERE fragment with "?" placeholders
	fields     []string      // select list, comma-joined by compile
	params     []interface{} // bind values for the placeholders in condition
	table      string        // explicit table name, applied when non-empty
	joins      []join        // JOIN clauses, rendered in insertion order
	orderBy    []string      // "col DIR" entries, comma-joined by compile
	groupBy    []string      // GROUP BY columns, comma-joined by compile
	modelValue any           // value returned by Model
	limit      int           // applied by compile only when > 0
	offset     int           // applied by compile only when > 0
}

// condition is a single "Field Expr Value" predicate rendered by
// buildConditions.
type condition struct {
	Field string      `json:"field"`
	Value interface{} `json:"value"`
	Expr  string      `json:"expr"`
}

// join describes one JOIN clause: Direction (e.g. "LEFT") is placed before
// "JOIN", and Conditions are rendered into the ON clause.
type join struct {
	Table      string
	Direction  string
	Conditions []*condition
}
// WithExpr replaces the comparison operator of cond with v and returns the
// same condition so the call can be chained onto a constructor.
func (cond *condition) WithExpr(v string) *condition {
	updated := cond
	updated.Expr = v
	return updated
}
// Model returns the model value the query was created with (see NewQuery).
func (query *Query) Model() any {
	model := query.modelValue
	return model
}
// compile applies every accumulated clause to query.db and returns the
// resulting *gorm.DB without executing anything.
//
// Empty clauses are skipped, as are an offset or limit that is not positive.
// The error result is currently always nil.
func (query *Query) compile() (*gorm.DB, error) {
	db := query.db
	if query.condition != "" {
		db = db.Where(query.condition, query.params...)
	}
	// Check the length, not nil-ness: an empty but non-nil select list (as
	// left by ResetSelect) would otherwise become Select(""), i.e. an empty
	// column list.
	if len(query.fields) > 0 {
		db = db.Select(strings.Join(query.fields, ","))
	}
	if query.table != "" {
		db = db.Table(query.table)
	}
	for _, joinEntity := range query.joins {
		// NOTE(review): ON predicates are combined with OR; confirm this is
		// intended, since join conditions are usually AND-ed.
		cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...)
		clause := joinEntity.Direction + " JOIN " + joinEntity.Table
		if cs != "" {
			// Only add ON when something rendered; a bare "ON " is invalid SQL.
			clause += " ON " + cs
		}
		db = db.Joins(clause, ps...)
	}
	if len(query.orderBy) > 0 {
		db = db.Order(strings.Join(query.orderBy, ","))
	}
	if len(query.groupBy) > 0 {
		db = db.Group(strings.Join(query.groupBy, ","))
	}
	if query.offset > 0 {
		db = db.Offset(query.offset)
	}
	if query.limit > 0 {
		db = db.Limit(query.limit)
	}
	return db, nil
}
// decodeValue renders v as a SQL literal suitable for inlining into a
// statement. Prefer bind parameters ("?") wherever possible.
//
// v may be a plain value, a pointer to one, or a reflect.Value (for example
// an element obtained with Value.Index). Pointers and interface elements are
// unwrapped, and nil becomes NULL. Booleans become 1/0 and numbers their
// decimal form. Strings are single-quoted with backslashes and quotes
// escaped. Any other kind falls back to fmt.Sprint(v).
func (query *Query) decodeValue(v any) string {
	refVal, isValue := v.(reflect.Value)
	if !isValue {
		refVal = reflect.ValueOf(v)
	}
	// Unwrap pointers and interface-typed elements (e.g. items of a []any).
	for refVal.Kind() == reflect.Ptr || refVal.Kind() == reflect.Interface {
		if refVal.IsNil() {
			return "NULL"
		}
		refVal = refVal.Elem()
	}
	switch refVal.Kind() {
	case reflect.Invalid:
		return "NULL"
	case reflect.Bool:
		if refVal.Bool() {
			return "1"
		}
		return "0"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(refVal.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		// Value.Int panics on unsigned kinds; Uint also keeps values above
		// math.MaxInt64 exact.
		return strconv.FormatUint(refVal.Uint(), 10)
	case reflect.Float32:
		// bitSize 32 yields the shortest float32 form ("0.1", not
		// "0.10000000149011612").
		return strconv.FormatFloat(refVal.Float(), 'f', -1, 32)
	case reflect.Float64:
		return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
	case reflect.String:
		// Escape backslashes and double single quotes so the value cannot
		// terminate the literal early. NOTE(review): this assumes a MySQL-style
		// dialect; with NO_BACKSLASH_ESCAPES it is still safe, but a backslash
		// is then stored twice.
		s := strings.ReplaceAll(refVal.String(), `\`, `\\`)
		s = strings.ReplaceAll(s, "'", "''")
		return "'" + s + "'"
	default:
		return fmt.Sprint(v)
	}
}
// buildConditions renders conditions as one SQL fragment whose predicates
// are joined by operator ("AND" or "OR"). It returns that fragment together
// with the bind values for its "?" placeholders, in order.
//
// When filter is true, conditions whose Value isEmpty are skipped.
//
// Expr is case-insensitive, and "" means "=":
//   - =, <>, !=, >, <, >=, <=: "`field` op ?". A nil Value (or nil pointer)
//     becomes IS NULL for "=" and IS NOT NULL for "<>"/"!=".
//   - LIKE: "`field` LIKE ?" bound to "%value%".
//   - IN: a slice/array is bound element-wise as "(?,?,...)"; an empty one
//     becomes "(NULL)", which matches no row. A string is inlined verbatim.
//   - BETWEEN: a 2-element slice/array is bound as "BETWEEN ? AND ?".
//
// Nil conditions, unknown operators and values of the wrong shape are
// skipped without leaving a dangling operator. The conditions are only read,
// never modified.
func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) {
	var sb strings.Builder
	params = make([]interface{}, 0)
	// add writes one predicate, preceded by the joining operator unless it is
	// the first one, and records its bind values.
	add := func(predicate string, args ...interface{}) {
		if sb.Len() > 0 {
			sb.WriteString(" " + operator + " ")
		}
		sb.WriteString(predicate)
		params = append(params, args...)
	}
	for _, cond := range conditions {
		if cond == nil || (filter && isEmpty(cond.Value)) {
			continue
		}
		// Use locals instead of writing back into cond. compile() rebuilds
		// join conditions on every call, so mutating the LIKE value would
		// wrap it in "%" again each time.
		expr := strings.ToUpper(cond.Expr)
		if expr == "" {
			expr = "="
		}
		column := "`" + cond.Field + "`"
		switch expr {
		case "=", "<>", "!=", ">", "<", ">=", "<=":
			isNull := cond.Value == nil
			if rv := reflect.ValueOf(cond.Value); rv.Kind() == reflect.Ptr && rv.IsNil() {
				isNull = true
			}
			switch {
			case isNull && expr == "=":
				add(column + " IS NULL")
			case isNull && (expr == "<>" || expr == "!="):
				add(column + " IS NOT NULL")
			default:
				add(column+" "+expr+" ?", cond.Value)
			}
		case "LIKE":
			// %v rather than %s so non-string values are not rendered as "%!s(...)".
			add(column+" LIKE ?", fmt.Sprintf("%%%v%%", cond.Value))
		case "IN":
			refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
			switch refVal.Kind() {
			case reflect.Slice, reflect.Array:
				if refVal.Len() == 0 {
					// "IN ()" is a syntax error; "IN (NULL)" is valid and matches nothing.
					add(column + " IN (NULL)")
				} else {
					// Bind each element so the driver escapes it.
					args := make([]interface{}, refVal.Len())
					for i := range args {
						args[i] = refVal.Index(i).Interface()
					}
					placeholders := strings.TrimSuffix(strings.Repeat("?,", len(args)), ",")
					add(column+" IN ("+placeholders+")", args...)
				}
			case reflect.String:
				// NOTE(review): raw SQL (e.g. "1,2,3" or a sub-query) is inlined
				// here unescaped; never route untrusted input through this path.
				add(column + " IN (" + refVal.String() + ")")
			}
		case "BETWEEN":
			refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
			if (refVal.Kind() == reflect.Slice || refVal.Kind() == reflect.Array) && refVal.Len() == 2 {
				// Bind the element values, not the reflect.Value wrappers.
				add(column+" BETWEEN ? AND ?", refVal.Index(0).Interface(), refVal.Index(1).Interface())
			}
		}
	}
	str = sb.String()
	return
}
// Select appends columns to the query's select list; repeated calls
// accumulate. The columns are copied into storage owned by the query, so a
// slice passed as Select(cols...) is never aliased or written to.
func (query *Query) Select(fields ...string) *Query {
	// append to a nil slice allocates fresh storage, unlike the former
	// "query.fields = fields", which shared the caller's backing array.
	query.fields = append(query.fields, fields...)
	return query
}
// From sets the table name the query reads from, replacing any earlier
// value, and returns the query for chaining.
func (query *Query) From(table string) *Query {
	q := query
	q.table = table
	return q
}
// LeftJoin records a "LEFT JOIN table" whose ON clause is rendered from
// conditions when the query is compiled.
func (query *Query) LeftJoin(table string, conditions ...*condition) *Query {
	return query.appendJoin("LEFT", table, conditions)
}

// RightJoin records a "RIGHT JOIN table" whose ON clause is rendered from
// conditions when the query is compiled.
func (query *Query) RightJoin(table string, conditions ...*condition) *Query {
	return query.appendJoin("RIGHT", table, conditions)
}

// InnerJoin records an "INNER JOIN table" whose ON clause is rendered from
// conditions when the query is compiled.
func (query *Query) InnerJoin(table string, conditions ...*condition) *Query {
	return query.appendJoin("INNER", table, conditions)
}

// appendJoin adds one join entry with the given direction and returns the
// query for chaining.
func (query *Query) appendJoin(direction, table string, conditions []*condition) *Query {
	entry := join{Table: table, Direction: direction, Conditions: conditions}
	query.joins = append(query.joins, entry)
	return query
}
// AddCondition adds a single "column expr val" predicate through AndWhere.
// An empty expr is treated as "=".
func (query *Query) AddCondition(expr, column string, val any) {
	op := expr
	if op == "" {
		op = "="
	}
	cond := newConditionWithOperator(op, column, val)
	query.AndWhere(cond)
}
// AndWhere ANDs the given conditions together and appends the group to the
// query's WHERE clause. The group is added bare when it is the first clause,
// and as " AND (group)" otherwise. If nothing renders, the query is left
// unchanged.
func (query *Query) AndWhere(conditions ...*condition) *Query {
	if len(conditions) == 0 {
		return query
	}
	rendered, args := query.buildConditions("AND", false, conditions...)
	switch {
	case rendered == "":
		// nothing to add
	case query.condition == "":
		query.params = append(query.params, args...)
		query.condition = rendered
	default:
		query.params = append(query.params, args...)
		query.condition += " AND (" + rendered + ")"
	}
	return query
}
// AndFilterWhere works like AndWhere, except that conditions whose Value is
// empty (per isEmpty) are skipped, so optional filters can be passed
// unconditionally. The remaining conditions are ANDed together and appended
// to the WHERE clause. The group is added bare when it is the first clause,
// and as " AND (group)" otherwise. If nothing remains, the query is unchanged.
func (query *Query) AndFilterWhere(conditions ...*condition) *Query {
	if len(conditions) == 0 {
		return query
	}
	cs, ps := query.buildConditions("AND", true, conditions...)
	if cs == "" {
		return query
	}
	query.params = append(query.params, ps...)
	if query.condition == "" {
		query.condition = cs
	} else {
		// Parenthesise like AndWhere/OrWhere so every appended group has the
		// same shape.
		query.condition += " AND (" + cs + ")"
	}
	return query
}
// OrWhere ORs the given conditions together and ANDs the resulting group onto
// the query's WHERE clause. If nothing renders, the query is unchanged.
//
// The group is always parenthesised, even when it is the first clause.
// Otherwise a later AND group would bind tighter than the OR:
// "a OR b AND (c)" parses as "a OR (b AND c)".
func (query *Query) OrWhere(conditions ...*condition) *Query {
	if len(conditions) == 0 {
		return query
	}
	cs, ps := query.buildConditions("OR", false, conditions...)
	if cs == "" {
		return query
	}
	query.params = append(query.params, ps...)
	group := "(" + cs + ")"
	if query.condition == "" {
		query.condition = group
	} else {
		query.condition += " AND " + group
	}
	return query
}
// OrFilterWhere works like OrWhere, except that conditions whose Value is
// empty (per isEmpty) are skipped. The remaining conditions are ORed together
// and the group is ANDed onto the WHERE clause. If nothing remains, the query
// is unchanged.
//
// The group is always parenthesised, even when it is the first clause.
// Otherwise a later AND group would bind tighter than the OR:
// "a OR b AND (c)" parses as "a OR (b AND c)".
func (query *Query) OrFilterWhere(conditions ...*condition) *Query {
	if len(conditions) == 0 {
		return query
	}
	cs, ps := query.buildConditions("OR", true, conditions...)
	if cs == "" {
		return query
	}
	query.params = append(query.params, ps...)
	group := "(" + cs + ")"
	if query.condition == "" {
		query.condition = group
	} else {
		query.condition += " AND " + group
	}
	return query
}
// GroupBy appends columns to the GROUP BY list; repeated calls accumulate.
func (query *Query) GroupBy(cols ...string) *Query {
	for _, col := range cols {
		query.groupBy = append(query.groupBy, col)
	}
	return query
}
// OrderBy appends "col DIRECTION" to the ORDER BY list. direction is
// case-insensitive; any value other than "desc" results in "ASC".
// NOTE(review): col is inserted verbatim, so never pass an untrusted column name.
func (query *Query) OrderBy(col, direction string) *Query {
	dir := "ASC"
	if strings.ToUpper(direction) == "DESC" {
		dir = "DESC"
	}
	query.orderBy = append(query.orderBy, col+" "+dir)
	return query
}
// Offset sets the number of rows to skip, replacing any earlier value.
// compile only applies it when it is positive.
func (query *Query) Offset(i int) *Query {
	q := query
	q.offset = i
	return q
}
// Limit sets the maximum number of rows to return, replacing any earlier
// value. compile only applies it when it is positive.
func (query *Query) Limit(i int) *Query {
	q := query
	q.limit = i
	return q
}
// ResetSelect clears the select list so the compiled query no longer
// restricts the selected columns.
//
// The list is set to nil rather than to an empty slice. compile decides
// whether to call Select based on the list, and an empty non-nil slice could
// otherwise turn into Select(""), i.e. an empty column list.
func (query *Query) ResetSelect() *Query {
	query.fields = nil
	return query
}
// Count returns the number of rows the query matches. v selects the target:
//   - a value of string kind: counted against the table set with From, or,
//     if none was set, against the table named by v;
//   - a types.Reporter: a row is fetched into sqlCountResponse with the
//     "@sql_count_statement" context flag set, and its Count is returned.
//     Presumably a hook elsewhere turns that into a custom COUNT statement;
//     verify.
//   - anything else: counted with v as the gorm model.
//
// A nil v, or a compile error, yields 0.
//
// Offset and Limit are cancelled for the count. They page the row listing,
// and an OFFSET on an ungrouped COUNT(*) skips its only result row, so the
// total would read 0 on every page after the first.
//
// Database errors are not reported, because the signature has no error result.
func (query *Query) Count(v interface{}) (i int64) {
	db, err := query.compile()
	if err != nil || v == nil {
		return 0
	}
	// gorm treats Limit(-1)/Offset(-1) as "no limit"/"no offset".
	db = db.Limit(-1).Offset(-1)

	if refVal := reflect.ValueOf(v); refVal.Kind() == reflect.String {
		table := query.table
		if table == "" {
			table = refVal.String()
		}
		db.Table(table).Count(&i)
		return i
	}
	// For report models, an SQL statement is built by hand (signalled through
	// the context).
	if reporter, ok := v.(types.Reporter); ok {
		sqlRes := &sqlCountResponse{}
		// NOTE(review): the key is a plain string; it presumably has to match
		// whatever reads it elsewhere, so it is kept as is.
		childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true)
		db.WithContext(childCtx).Model(reporter).First(sqlRes)
		return sqlRes.Count
	}
	db.Model(v).Count(&i)
	return i
}
// One compiles the query and loads the first matching row into v using
// gorm's First. It returns the compile error or gorm's error.
func (query *Query) One(v interface{}) (err error) {
	db, compileErr := query.compile()
	if compileErr != nil {
		return compileErr
	}
	return db.First(v).Error
}
// All compiles the query and loads every matching row into v using gorm's
// Find. It returns the compile error or gorm's error.
func (query *Query) All(v interface{}) (err error) {
	db, compileErr := query.compile()
	if compileErr != nil {
		return compileErr
	}
	return db.Find(v).Error
}
// NewCondition builds a predicate comparing column to value with operator
// opera. An empty opera defaults to "=".
func NewCondition(column, opera string, value any) *condition {
	cond := &condition{Field: column, Value: value, Expr: "="}
	if opera != "" {
		cond.Expr = opera
	}
	return cond
}
// newCondition builds an equality predicate "field = value".
func newCondition(field string, value interface{}) *condition {
	c := new(condition)
	c.Field, c.Value, c.Expr = field, value, "="
	return c
}
// newConditionWithOperator builds a predicate with the given operator,
// field and value. Unlike NewCondition, it does not default an empty operator.
func newConditionWithOperator(operator, field string, value interface{}) *condition {
	return &condition{Expr: operator, Field: field, Value: value}
}
// NewQuery creates a query over db for the given model. The params, orderBy,
// groupBy and joins lists start out empty and non-nil.
func NewQuery(db *gorm.DB, model any) *Query {
	q := new(Query)
	q.db = db
	q.modelValue = model
	q.params = []interface{}{}
	q.orderBy = []string{}
	q.groupBy = []string{}
	q.joins = []join{}
	return q
}