406 lines
8.9 KiB
Go
406 lines
8.9 KiB
Go
|
package rest
|
|||
|
|
|||
|
import (
|
|||
|
"context"
|
|||
|
"fmt"
|
|||
|
"git.nobla.cn/golang/rest/types"
|
|||
|
"gorm.io/gorm"
|
|||
|
"reflect"
|
|||
|
"strconv"
|
|||
|
"strings"
|
|||
|
)
|
|||
|
|
|||
|
type (
|
|||
|
Query struct {
|
|||
|
db *gorm.DB
|
|||
|
condition string
|
|||
|
fields []string
|
|||
|
params []interface{}
|
|||
|
table string
|
|||
|
joins []join
|
|||
|
orderBy []string
|
|||
|
groupBy []string
|
|||
|
modelValue any
|
|||
|
limit int
|
|||
|
offset int
|
|||
|
}
|
|||
|
|
|||
|
condition struct {
|
|||
|
Field string `json:"field"`
|
|||
|
Value interface{} `json:"value"`
|
|||
|
Expr string `json:"expr"`
|
|||
|
}
|
|||
|
|
|||
|
join struct {
|
|||
|
Table string
|
|||
|
Direction string
|
|||
|
Conditions []*condition
|
|||
|
}
|
|||
|
)
|
|||
|
|
|||
|
func (cond *condition) WithExpr(v string) *condition {
|
|||
|
cond.Expr = v
|
|||
|
return cond
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) Model() any {
|
|||
|
return query.modelValue
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) compile() (*gorm.DB, error) {
|
|||
|
db := query.db
|
|||
|
if query.condition != "" {
|
|||
|
db = db.Where(query.condition, query.params...)
|
|||
|
}
|
|||
|
if query.fields != nil {
|
|||
|
db = db.Select(strings.Join(query.fields, ","))
|
|||
|
}
|
|||
|
if query.table != "" {
|
|||
|
db = db.Table(query.table)
|
|||
|
}
|
|||
|
if query.joins != nil && len(query.joins) > 0 {
|
|||
|
for _, joinEntity := range query.joins {
|
|||
|
cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...)
|
|||
|
db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...)
|
|||
|
}
|
|||
|
}
|
|||
|
if query.orderBy != nil && len(query.orderBy) > 0 {
|
|||
|
db = db.Order(strings.Join(query.orderBy, ","))
|
|||
|
}
|
|||
|
if query.groupBy != nil && len(query.groupBy) > 0 {
|
|||
|
db = db.Group(strings.Join(query.groupBy, ","))
|
|||
|
}
|
|||
|
if query.offset > 0 {
|
|||
|
db = db.Offset(query.offset)
|
|||
|
}
|
|||
|
if query.limit > 0 {
|
|||
|
db = db.Limit(query.limit)
|
|||
|
}
|
|||
|
return db, nil
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) decodeValue(v any) string {
|
|||
|
refVal := reflect.Indirect(reflect.ValueOf(v))
|
|||
|
switch refVal.Kind() {
|
|||
|
case reflect.Bool:
|
|||
|
if refVal.Bool() {
|
|||
|
return "1"
|
|||
|
} else {
|
|||
|
return "0"
|
|||
|
}
|
|||
|
case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
|||
|
return strconv.FormatInt(refVal.Int(), 10)
|
|||
|
case reflect.Float32, reflect.Float64:
|
|||
|
return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
|
|||
|
case reflect.String:
|
|||
|
return "'" + refVal.String() + "'"
|
|||
|
default:
|
|||
|
return fmt.Sprint(v)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) {
|
|||
|
var (
|
|||
|
sb strings.Builder
|
|||
|
)
|
|||
|
params = make([]interface{}, 0)
|
|||
|
for _, cond := range conditions {
|
|||
|
if filter {
|
|||
|
if isEmpty(cond.Value) {
|
|||
|
continue
|
|||
|
}
|
|||
|
}
|
|||
|
if cond.Expr == "" {
|
|||
|
cond.Expr = "="
|
|||
|
}
|
|||
|
switch strings.ToUpper(cond.Expr) {
|
|||
|
case "=", "<>", ">", "<", ">=", "<=", "!=":
|
|||
|
if sb.Len() > 0 {
|
|||
|
sb.WriteString(" " + operator + " ")
|
|||
|
}
|
|||
|
if cond.Expr == "=" && cond.Value == nil {
|
|||
|
sb.WriteString("`" + cond.Field + "` IS NULL")
|
|||
|
} else {
|
|||
|
sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?")
|
|||
|
params = append(params, cond.Value)
|
|||
|
}
|
|||
|
case "LIKE":
|
|||
|
if sb.Len() > 0 {
|
|||
|
sb.WriteString(" " + operator + " ")
|
|||
|
}
|
|||
|
cond.Value = fmt.Sprintf("%%%s%%", cond.Value)
|
|||
|
sb.WriteString("`" + cond.Field + "` LIKE ?")
|
|||
|
params = append(params, cond.Value)
|
|||
|
case "IN":
|
|||
|
if sb.Len() > 0 {
|
|||
|
sb.WriteString(" " + operator + " ")
|
|||
|
}
|
|||
|
refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
|
|||
|
switch refVal.Kind() {
|
|||
|
case reflect.Slice, reflect.Array:
|
|||
|
ss := make([]string, refVal.Len())
|
|||
|
for i := 0; i < refVal.Len(); i++ {
|
|||
|
ss[i] = query.decodeValue(refVal.Index(i))
|
|||
|
}
|
|||
|
sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")")
|
|||
|
case reflect.String:
|
|||
|
sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")")
|
|||
|
default:
|
|||
|
}
|
|||
|
case "BETWEEN":
|
|||
|
refVal := reflect.ValueOf(cond.Value)
|
|||
|
if refVal.Kind() == reflect.Slice && refVal.Len() == 2 {
|
|||
|
sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?")
|
|||
|
params = append(params, refVal.Index(0), refVal.Index(1))
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
str = sb.String()
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) Select(fields ...string) *Query {
|
|||
|
if query.fields == nil {
|
|||
|
query.fields = fields
|
|||
|
} else {
|
|||
|
query.fields = append(query.fields, fields...)
|
|||
|
}
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) From(table string) *Query {
|
|||
|
query.table = table
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) LeftJoin(table string, conditions ...*condition) *Query {
|
|||
|
query.joins = append(query.joins, join{
|
|||
|
Table: table,
|
|||
|
Direction: "LEFT",
|
|||
|
Conditions: conditions,
|
|||
|
})
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) RightJoin(table string, conditions ...*condition) *Query {
|
|||
|
query.joins = append(query.joins, join{
|
|||
|
Table: table,
|
|||
|
Direction: "RIGHT",
|
|||
|
Conditions: conditions,
|
|||
|
})
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) InnerJoin(table string, conditions ...*condition) *Query {
|
|||
|
query.joins = append(query.joins, join{
|
|||
|
Table: table,
|
|||
|
Direction: "INNER",
|
|||
|
Conditions: conditions,
|
|||
|
})
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) AddCondition(expr, column string, val any) {
|
|||
|
if expr == "" {
|
|||
|
expr = "="
|
|||
|
}
|
|||
|
query.AndWhere(newConditionWithOperator(expr, column, val))
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) AndWhere(conditions ...*condition) *Query {
|
|||
|
length := len(conditions)
|
|||
|
if length == 0 {
|
|||
|
return query
|
|||
|
}
|
|||
|
cs, ps := query.buildConditions("AND", false, conditions...)
|
|||
|
if cs == "" {
|
|||
|
return query
|
|||
|
}
|
|||
|
query.params = append(query.params, ps...)
|
|||
|
if query.condition == "" {
|
|||
|
query.condition = cs
|
|||
|
} else {
|
|||
|
query.condition += " AND (" + cs + ")"
|
|||
|
}
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) AndFilterWhere(conditions ...*condition) *Query {
|
|||
|
length := len(conditions)
|
|||
|
if length == 0 {
|
|||
|
return query
|
|||
|
}
|
|||
|
cs, ps := query.buildConditions("AND", true, conditions...)
|
|||
|
if cs == "" {
|
|||
|
return query
|
|||
|
}
|
|||
|
query.params = append(query.params, ps...)
|
|||
|
if query.condition == "" {
|
|||
|
query.condition = cs
|
|||
|
} else {
|
|||
|
query.condition += " AND " + cs
|
|||
|
}
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) OrWhere(conditions ...*condition) *Query {
|
|||
|
length := len(conditions)
|
|||
|
if length == 0 {
|
|||
|
return query
|
|||
|
}
|
|||
|
cs, ps := query.buildConditions("OR", false, conditions...)
|
|||
|
if cs == "" {
|
|||
|
return query
|
|||
|
}
|
|||
|
query.params = append(query.params, ps...)
|
|||
|
if query.condition == "" {
|
|||
|
query.condition = cs
|
|||
|
} else {
|
|||
|
query.condition += " AND (" + cs + ")"
|
|||
|
}
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) OrFilterWhere(conditions ...*condition) *Query {
|
|||
|
length := len(conditions)
|
|||
|
if length == 0 {
|
|||
|
return query
|
|||
|
}
|
|||
|
cs, ps := query.buildConditions("OR", true, conditions...)
|
|||
|
if cs == "" {
|
|||
|
return query
|
|||
|
}
|
|||
|
query.params = append(query.params, ps...)
|
|||
|
if query.condition == "" {
|
|||
|
query.condition = cs
|
|||
|
} else {
|
|||
|
query.condition += " AND (" + cs + ")"
|
|||
|
}
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) GroupBy(cols ...string) *Query {
|
|||
|
query.groupBy = append(query.groupBy, cols...)
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) OrderBy(col, direction string) *Query {
|
|||
|
direction = strings.ToUpper(direction)
|
|||
|
if direction == "" || !(direction == "ASC" || direction == "DESC") {
|
|||
|
direction = "ASC"
|
|||
|
}
|
|||
|
query.orderBy = append(query.orderBy, col+" "+direction)
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) Offset(i int) *Query {
|
|||
|
query.offset = i
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) Limit(i int) *Query {
|
|||
|
query.limit = i
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) ResetSelect() *Query {
|
|||
|
query.fields = make([]string, 0)
|
|||
|
return query
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) Count(v interface{}) (i int64) {
|
|||
|
var (
|
|||
|
db *gorm.DB
|
|||
|
err error
|
|||
|
)
|
|||
|
if db, err = query.compile(); err != nil {
|
|||
|
return
|
|||
|
} else {
|
|||
|
if v != nil {
|
|||
|
refVal := reflect.ValueOf(v)
|
|||
|
switch refVal.Kind() {
|
|||
|
case reflect.String:
|
|||
|
if query.table == "" {
|
|||
|
err = db.Table(refVal.String()).Count(&i).Error
|
|||
|
} else {
|
|||
|
err = db.Table(query.table).Count(&i).Error
|
|||
|
}
|
|||
|
default:
|
|||
|
//如果是报表的模型,这的话手动构建一条SQL语句
|
|||
|
if reporter, ok := v.(types.Reporter); ok {
|
|||
|
sqlRes := &sqlCountResponse{}
|
|||
|
childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true)
|
|||
|
db.WithContext(childCtx).Model(reporter).First(sqlRes)
|
|||
|
i = sqlRes.Count
|
|||
|
} else {
|
|||
|
err = db.Model(v).Count(&i).Error
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) One(v interface{}) (err error) {
|
|||
|
var (
|
|||
|
db *gorm.DB
|
|||
|
)
|
|||
|
if db, err = query.compile(); err != nil {
|
|||
|
return
|
|||
|
} else {
|
|||
|
err = db.First(v).Error
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
func (query *Query) All(v interface{}) (err error) {
|
|||
|
var (
|
|||
|
db *gorm.DB
|
|||
|
)
|
|||
|
if db, err = query.compile(); err != nil {
|
|||
|
return
|
|||
|
} else {
|
|||
|
err = db.Find(v).Error
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
func NewCondition(column, opera string, value any) *condition {
|
|||
|
if opera == "" {
|
|||
|
opera = "="
|
|||
|
}
|
|||
|
return &condition{
|
|||
|
Field: column,
|
|||
|
Value: value,
|
|||
|
Expr: opera,
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func newCondition(field string, value interface{}) *condition {
|
|||
|
return &condition{
|
|||
|
Field: field,
|
|||
|
Value: value,
|
|||
|
Expr: "=",
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func newConditionWithOperator(operator, field string, value interface{}) *condition {
|
|||
|
cond := &condition{
|
|||
|
Field: field,
|
|||
|
Value: value,
|
|||
|
Expr: operator,
|
|||
|
}
|
|||
|
return cond
|
|||
|
}
|
|||
|
|
|||
|
func NewQuery(db *gorm.DB, model any) *Query {
|
|||
|
return &Query{
|
|||
|
db: db,
|
|||
|
modelValue: model,
|
|||
|
params: make([]interface{}, 0),
|
|||
|
orderBy: make([]string, 0),
|
|||
|
groupBy: make([]string, 0),
|
|||
|
joins: make([]join, 0),
|
|||
|
}
|
|||
|
}
|