rest/query.go

406 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package rest
import (
"context"
"fmt"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"reflect"
"strconv"
"strings"
)
// Query accumulates the parts of a SELECT statement (conditions, selected
// fields, table, joins, ordering, grouping and paging) and applies them to a
// *gorm.DB when the query is compiled.
type Query struct {
	db         *gorm.DB // base gorm handle every compiled query starts from
	condition  string   // WHERE clause text containing "?" placeholders
	fields     []string // columns handed to Select, joined with ","
	params     []any    // values bound to the placeholders in condition
	table      string   // explicit table name; empty means none was set
	joins      []join   // JOIN clauses in the order they were added
	orderBy    []string // "column DIRECTION" entries
	groupBy    []string // GROUP BY columns
	modelValue any      // model supplied to NewQuery, exposed via Model
	limit      int      // applied by compile only when > 0
	offset     int      // applied by compile only when > 0
}

// condition is a single "field <expr> value" predicate.
type condition struct {
	Field string `json:"field"`
	Value any    `json:"value"`
	Expr  string `json:"expr"`
}

// join describes one JOIN clause of a Query.
type join struct {
	Table      string       // table to join
	Direction  string       // join keyword prefix: "LEFT", "RIGHT" or "INNER"
	Conditions []*condition // predicates rendered into the ON clause
}
// WithExpr replaces the comparison expression of the condition with v and
// returns the same condition so that calls can be chained.
func (cond *condition) WithExpr(v string) *condition {
	cond.Expr = v

	return cond
}
// Model returns the model value stored on the query.
func (query *Query) Model() any {
	model := query.modelValue
	return model
}
// compile applies everything accumulated on the query (WHERE text and its
// parameters, selected fields, table, joins, ordering, grouping, offset and
// limit) to the base *gorm.DB and returns the resulting handle.
//
// It returns an error only when the query has no database handle; every other
// path returns a nil error.
func (query *Query) compile() (*gorm.DB, error) {
	db := query.db
	if db == nil {
		// Calling gorm chain methods on a nil handle would panic.
		return nil, fmt.Errorf("rest: query has no database handle")
	}
	if query.condition != "" {
		db = db.Where(query.condition, query.params...)
	}
	// Test the length, not nil-ness: ResetSelect leaves an empty but non-nil
	// slice, which would otherwise be passed to gorm as Select("").
	if len(query.fields) > 0 {
		db = db.Select(strings.Join(query.fields, ","))
	}
	if query.table != "" {
		db = db.Table(query.table)
	}
	// Ranging over a nil or empty slice is a no-op, so no guard is needed.
	for _, joinEntity := range query.joins {
		// The predicates of a single JOIN are combined with OR.
		cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...)
		db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...)
	}
	if len(query.orderBy) > 0 {
		db = db.Order(strings.Join(query.orderBy, ","))
	}
	if len(query.groupBy) > 0 {
		db = db.Group(strings.Join(query.groupBy, ","))
	}
	// Non-positive offset/limit values mean "not set" and are skipped.
	if query.offset > 0 {
		db = db.Offset(query.offset)
	}
	if query.limit > 0 {
		db = db.Limit(query.limit)
	}
	return db, nil
}
// decodeValue renders v as a literal that can be embedded in SQL text.
//
// Pointers and interfaces are followed to the value they hold; nil renders as
// NULL. Booleans become 1/0, integers and floats are printed in base 10, and
// strings are wrapped in single quotes with embedded quotes and backslashes
// escaped. Any other kind falls back to fmt.Sprint.
//
// v may also be a reflect.Value (for example an element obtained with
// Value.Index); it is unwrapped and formatted by its own kind.
//
// NOTE(review): the string escaping assumes a MySQL-style dialect in which
// backslash is an escape character (the file quotes identifiers with
// backticks) — confirm before using it against another database.
func (query *Query) decodeValue(v any) string {
	refVal, ok := v.(reflect.Value)
	if !ok {
		refVal = reflect.ValueOf(v)
	}
	for refVal.Kind() == reflect.Ptr || refVal.Kind() == reflect.Interface {
		if refVal.IsNil() {
			return "NULL"
		}
		refVal = refVal.Elem()
	}
	switch refVal.Kind() {
	case reflect.Invalid:
		// An untyped nil or a zero reflect.Value.
		return "NULL"
	case reflect.Bool:
		if refVal.Bool() {
			return "1"
		}
		return "0"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(refVal.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		// Value.Int panics on unsigned kinds; they must go through Value.Uint.
		return strconv.FormatUint(refVal.Uint(), 10)
	case reflect.Float32, reflect.Float64:
		// Format with the value's own bit size so a float32 such as 0.1 does
		// not pick up float64 rounding noise.
		return strconv.FormatFloat(refVal.Float(), 'f', -1, refVal.Type().Bits())
	case reflect.String:
		escaped := strings.NewReplacer(`\`, `\\`, `'`, `''`).Replace(refVal.String())
		return "'" + escaped + "'"
	default:
		return fmt.Sprint(v)
	}
}
// buildConditions renders conditions into a single SQL fragment whose
// predicates are joined by operator (for example "AND" or "OR"), and returns
// that fragment together with the values for its "?" placeholders, in order.
//
// When filter is true, conditions whose Value is reported empty by isEmpty are
// skipped. Nil conditions are always skipped. An empty Expr means "=".
// Supported expressions, matched case-insensitively:
//
//   - =, <>, !=, >, <, >=, <=: `field` <expr> ?; "=" with a nil Value becomes
//     IS NULL, and "<>"/"!=" with a nil Value becomes IS NOT NULL.
//   - LIKE: the value is formatted and wrapped in "%" on both sides.
//   - IN: a slice or array is bound element by element through placeholders
//     (an empty one renders as IN (NULL), which matches nothing); a string
//     Value is inserted verbatim as the list text.
//   - BETWEEN: requires a slice or array of exactly two elements.
//
// Conditions with an unsupported expression or an unusable value contribute
// nothing. The conditions themselves are never modified, so the same
// *condition can be rendered more than once.
func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) {
	var sb strings.Builder
	params = make([]interface{}, 0, len(conditions))

	// emit appends one predicate and its arguments. The joining operator is
	// written only here, so a skipped condition can never leave a dangling
	// operator behind.
	emit := func(sql string, args ...interface{}) {
		if sb.Len() > 0 {
			sb.WriteString(" " + operator + " ")
		}
		sb.WriteString(sql)
		params = append(params, args...)
	}

	for _, cond := range conditions {
		if cond == nil {
			continue
		}
		if filter && isEmpty(cond.Value) {
			continue
		}
		// Work on a local copy so the caller's condition stays untouched.
		expr := cond.Expr
		if expr == "" {
			expr = "="
		}
		column := "`" + cond.Field + "`"

		switch strings.ToUpper(expr) {
		case "=", "<>", ">", "<", ">=", "<=", "!=":
			switch {
			case cond.Value == nil && expr == "=":
				emit(column + " IS NULL")
			case cond.Value == nil && (expr == "<>" || expr == "!="):
				// A comparison against NULL is never true; use the NULL test.
				emit(column + " IS NOT NULL")
			default:
				emit(column+" "+expr+" ?", cond.Value)
			}
		case "LIKE":
			emit(column+" LIKE ?", fmt.Sprintf("%%%v%%", cond.Value))
		case "IN":
			refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
			switch refVal.Kind() {
			case reflect.Slice, reflect.Array:
				n := refVal.Len()
				if n == 0 {
					emit(column + " IN (NULL)")
					break
				}
				// Bind every element instead of formatting it into the SQL.
				args := make([]interface{}, n)
				for i := range args {
					args[i] = refVal.Index(i).Interface()
				}
				emit(column+" IN (?"+strings.Repeat(",?", n-1)+")", args...)
			case reflect.String:
				// NOTE(review): the string is inserted verbatim as the list
				// body and must never carry untrusted input.
				emit(column + " IN (" + refVal.String() + ")")
			}
		case "BETWEEN":
			refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
			if (refVal.Kind() == reflect.Slice || refVal.Kind() == reflect.Array) && refVal.Len() == 2 {
				// Bind the underlying values, not the reflect.Value wrappers.
				emit(column+" BETWEEN ? AND ?", refVal.Index(0).Interface(), refVal.Index(1).Interface())
			}
		}
	}
	return sb.String(), params
}
// Select adds fields to the list of selected columns and returns the query
// for chaining.
//
// The names are always copied into the query's own slice. Storing the
// caller's variadic slice directly would let a later Select call append into
// the caller's backing array.
func (query *Query) Select(fields ...string) *Query {
	query.fields = append(query.fields, fields...)
	return query
}
// From sets the table name of the query, replacing any earlier one, and
// returns the query for chaining.
func (query *Query) From(table string) *Query {
	query.table = table

	return query
}
// LeftJoin appends a LEFT JOIN on table with the given ON conditions and
// returns the query for chaining.
func (query *Query) LeftJoin(table string, conditions ...*condition) *Query {
	entry := join{Direction: "LEFT", Table: table, Conditions: conditions}
	query.joins = append(query.joins, entry)
	return query
}
// RightJoin appends a RIGHT JOIN on table with the given ON conditions and
// returns the query for chaining.
func (query *Query) RightJoin(table string, conditions ...*condition) *Query {
	entry := join{Direction: "RIGHT", Table: table, Conditions: conditions}
	query.joins = append(query.joins, entry)
	return query
}
// InnerJoin appends an INNER JOIN on table with the given ON conditions and
// returns the query for chaining.
func (query *Query) InnerJoin(table string, conditions ...*condition) *Query {
	entry := join{Direction: "INNER", Table: table, Conditions: conditions}
	query.joins = append(query.joins, entry)
	return query
}
// AddCondition builds a single condition from expr, column and val and hands
// it to AndWhere. An empty expr is treated as "=".
func (query *Query) AddCondition(expr, column string, val any) {
	operator := expr
	if operator == "" {
		operator = "="
	}
	cond := newConditionWithOperator(operator, column, val)
	query.AndWhere(cond)
}
// AndWhere combines conditions with AND and adds the result to the WHERE
// clause. The first fragment is stored as-is; later fragments are appended as
// " AND (<fragment>)". Calls that produce no SQL leave the query unchanged.
// It returns the query for chaining.
func (query *Query) AndWhere(conditions ...*condition) *Query {
	if len(conditions) == 0 {
		return query
	}
	clause, args := query.buildConditions("AND", false, conditions...)
	if clause == "" {
		return query
	}
	query.params = append(query.params, args...)
	if query.condition != "" {
		clause = query.condition + " AND (" + clause + ")"
	}
	query.condition = clause
	return query
}
// AndFilterWhere works like an AND-combined where, except that conditions are
// rendered in filter mode, so conditions with an empty value are skipped. The
// first fragment is stored as-is; later fragments are appended as
// " AND <fragment>". Calls that produce no SQL leave the query unchanged.
// It returns the query for chaining.
func (query *Query) AndFilterWhere(conditions ...*condition) *Query {
	if len(conditions) == 0 {
		return query
	}
	clause, args := query.buildConditions("AND", true, conditions...)
	if clause == "" {
		return query
	}
	query.params = append(query.params, args...)
	if query.condition != "" {
		clause = query.condition + " AND " + clause
	}
	query.condition = clause
	return query
}
// OrWhere combines conditions with OR into one parenthesized group and adds
// that group to the WHERE clause. Calls that produce no SQL leave the query
// unchanged. It returns the query for chaining.
//
// The group is always wrapped in parentheses, including when it is the first
// fragment of the clause: without them a later " AND (...)" fragment would
// bind tighter than the ORs and change the meaning of the filter.
//
// NOTE(review): the group is attached to existing WHERE text with AND, not
// OR; this matches the previous behavior — confirm it is the intended
// contract.
func (query *Query) OrWhere(conditions ...*condition) *Query {
	if len(conditions) == 0 {
		return query
	}
	cs, ps := query.buildConditions("OR", false, conditions...)
	if cs == "" {
		return query
	}
	query.params = append(query.params, ps...)
	group := "(" + cs + ")"
	if query.condition == "" {
		query.condition = group
	} else {
		query.condition += " AND " + group
	}
	return query
}
// OrFilterWhere combines conditions with OR into one parenthesized group and
// adds that group to the WHERE clause. Conditions are rendered in filter
// mode, so conditions with an empty value are skipped. Calls that produce no
// SQL leave the query unchanged. It returns the query for chaining.
//
// The group is always wrapped in parentheses, including when it is the first
// fragment of the clause: without them a later " AND ..." fragment would bind
// tighter than the ORs and change the meaning of the filter.
//
// NOTE(review): the group is attached to existing WHERE text with AND, not
// OR; this matches the previous behavior — confirm it is the intended
// contract.
func (query *Query) OrFilterWhere(conditions ...*condition) *Query {
	if len(conditions) == 0 {
		return query
	}
	cs, ps := query.buildConditions("OR", true, conditions...)
	if cs == "" {
		return query
	}
	query.params = append(query.params, ps...)
	group := "(" + cs + ")"
	if query.condition == "" {
		query.condition = group
	} else {
		query.condition += " AND " + group
	}
	return query
}
// GroupBy adds cols to the GROUP BY column list and returns the query for
// chaining.
func (query *Query) GroupBy(cols ...string) *Query {
	for _, col := range cols {
		query.groupBy = append(query.groupBy, col)
	}
	return query
}
// OrderBy adds an "<col> <direction>" entry to the ORDER BY list and returns
// the query for chaining. direction is matched case-insensitively; anything
// other than "DESC" (including the empty string) sorts ascending.
func (query *Query) OrderBy(col, direction string) *Query {
	order := "ASC"
	if strings.ToUpper(direction) == "DESC" {
		order = "DESC"
	}
	query.orderBy = append(query.orderBy, col+" "+order)
	return query
}
// Offset stores i as the row offset of the query and returns the query for
// chaining.
func (query *Query) Offset(i int) *Query {
	query.offset = i

	return query
}
// Limit stores i as the maximum row count of the query and returns the query
// for chaining.
func (query *Query) Limit(i int) *Query {
	query.limit = i

	return query
}
// ResetSelect discards every selected field, leaving an empty (non-nil) field
// list, and returns the query for chaining.
func (query *Query) ResetSelect() *Query {
	query.fields = []string{}
	return query
}
// Count compiles the query and returns the number of matching rows.
//
// v selects what is counted:
//   - nil: nothing is executed and 0 is returned;
//   - a string: rows of the query's table if one was set, otherwise of the
//     table named by v;
//   - a types.Reporter: the count is read from a sqlCountResponse loaded with
//     First, under a context that carries the "@sql_count_statement" flag;
//   - anything else: rows of the model v.
//
// Database errors are not surfaced, because the method has no error result;
// i simply keeps whatever value gorm left in it.
func (query *Query) Count(v interface{}) (i int64) {
	db, err := query.compile()
	if err != nil || v == nil {
		return 0
	}
	if name := reflect.ValueOf(v); name.Kind() == reflect.String {
		// A table configured on the query wins over the name passed in.
		table := query.table
		if table == "" {
			table = name.String()
		}
		db.Table(table).Count(&i)
		return i
	}
	reporter, ok := v.(types.Reporter)
	if !ok {
		db.Model(v).Count(&i)
		return i
	}
	// For a report model the counting SQL statement is built by hand.
	// NOTE(review): the context flag is presumably picked up by code outside
	// this file — confirm where it is read.
	sqlRes := &sqlCountResponse{}
	childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true)
	db.WithContext(childCtx).Model(reporter).First(sqlRes)
	return sqlRes.Count
}
// One compiles the query and loads a single record into v using gorm's
// First. It returns the compile error, or otherwise the error reported by
// gorm.
func (query *Query) One(v interface{}) (err error) {
	db, err := query.compile()
	if err != nil {
		return err
	}
	return db.First(v).Error
}
// All compiles the query and loads every matching record into v using gorm's
// Find. It returns the compile error, or otherwise the error reported by
// gorm.
func (query *Query) All(v interface{}) (err error) {
	db, err := query.compile()
	if err != nil {
		return err
	}
	return db.Find(v).Error
}
// NewCondition returns a condition comparing column with value using the
// expression opera. An empty opera is replaced by "=".
func NewCondition(column, opera string, value any) *condition {
	cond := &condition{Field: column, Value: value, Expr: "="}
	if opera != "" {
		cond.Expr = opera
	}
	return cond
}
// newCondition returns an equality ("=") condition on field with the given
// value.
func newCondition(field string, value interface{}) *condition {
	cond := new(condition)
	cond.Field, cond.Value, cond.Expr = field, value, "="
	return cond
}
// newConditionWithOperator returns a condition on field with the given value
// and comparison operator. The operator is stored verbatim; no default is
// applied here.
func newConditionWithOperator(operator, field string, value interface{}) *condition {
	return &condition{Expr: operator, Field: field, Value: value}
}
// NewQuery returns a Query bound to db that remembers model as its model
// value. The parameter, ordering, grouping and join lists start out empty but
// non-nil.
func NewQuery(db *gorm.DB, model any) *Query {
	query := new(Query)
	query.db = db
	query.modelValue = model
	query.params = []interface{}{}
	query.orderBy = []string{}
	query.groupBy = []string{}
	query.joins = []join{}
	return query
}