406 lines
8.9 KiB
Go
406 lines
8.9 KiB
Go
package rest
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"git.nobla.cn/golang/rest/types"
|
||
"gorm.io/gorm"
|
||
"reflect"
|
||
"strconv"
|
||
"strings"
|
||
)
|
||
|
||
type (
|
||
Query struct {
|
||
db *gorm.DB
|
||
condition string
|
||
fields []string
|
||
params []interface{}
|
||
table string
|
||
joins []join
|
||
orderBy []string
|
||
groupBy []string
|
||
modelValue any
|
||
limit int
|
||
offset int
|
||
}
|
||
|
||
condition struct {
|
||
Field string `json:"field"`
|
||
Value interface{} `json:"value"`
|
||
Expr string `json:"expr"`
|
||
}
|
||
|
||
join struct {
|
||
Table string
|
||
Direction string
|
||
Conditions []*condition
|
||
}
|
||
)
|
||
|
||
func (cond *condition) WithExpr(v string) *condition {
|
||
cond.Expr = v
|
||
return cond
|
||
}
|
||
|
||
func (query *Query) Model() any {
|
||
return query.modelValue
|
||
}
|
||
|
||
func (query *Query) compile() (*gorm.DB, error) {
|
||
db := query.db
|
||
if query.condition != "" {
|
||
db = db.Where(query.condition, query.params...)
|
||
}
|
||
if query.fields != nil {
|
||
db = db.Select(strings.Join(query.fields, ","))
|
||
}
|
||
if query.table != "" {
|
||
db = db.Table(query.table)
|
||
}
|
||
if query.joins != nil && len(query.joins) > 0 {
|
||
for _, joinEntity := range query.joins {
|
||
cs, ps := query.buildConditions("OR", false, joinEntity.Conditions...)
|
||
db = db.Joins(joinEntity.Direction+" JOIN "+joinEntity.Table+" ON "+cs, ps...)
|
||
}
|
||
}
|
||
if query.orderBy != nil && len(query.orderBy) > 0 {
|
||
db = db.Order(strings.Join(query.orderBy, ","))
|
||
}
|
||
if query.groupBy != nil && len(query.groupBy) > 0 {
|
||
db = db.Group(strings.Join(query.groupBy, ","))
|
||
}
|
||
if query.offset > 0 {
|
||
db = db.Offset(query.offset)
|
||
}
|
||
if query.limit > 0 {
|
||
db = db.Limit(query.limit)
|
||
}
|
||
return db, nil
|
||
}
|
||
|
||
func (query *Query) decodeValue(v any) string {
|
||
refVal := reflect.Indirect(reflect.ValueOf(v))
|
||
switch refVal.Kind() {
|
||
case reflect.Bool:
|
||
if refVal.Bool() {
|
||
return "1"
|
||
} else {
|
||
return "0"
|
||
}
|
||
case reflect.Int8, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||
return strconv.FormatInt(refVal.Int(), 10)
|
||
case reflect.Float32, reflect.Float64:
|
||
return strconv.FormatFloat(refVal.Float(), 'f', -1, 64)
|
||
case reflect.String:
|
||
return "'" + refVal.String() + "'"
|
||
default:
|
||
return fmt.Sprint(v)
|
||
}
|
||
}
|
||
|
||
func (query *Query) buildConditions(operator string, filter bool, conditions ...*condition) (str string, params []interface{}) {
|
||
var (
|
||
sb strings.Builder
|
||
)
|
||
params = make([]interface{}, 0)
|
||
for _, cond := range conditions {
|
||
if filter {
|
||
if isEmpty(cond.Value) {
|
||
continue
|
||
}
|
||
}
|
||
if cond.Expr == "" {
|
||
cond.Expr = "="
|
||
}
|
||
switch strings.ToUpper(cond.Expr) {
|
||
case "=", "<>", ">", "<", ">=", "<=", "!=":
|
||
if sb.Len() > 0 {
|
||
sb.WriteString(" " + operator + " ")
|
||
}
|
||
if cond.Expr == "=" && cond.Value == nil {
|
||
sb.WriteString("`" + cond.Field + "` IS NULL")
|
||
} else {
|
||
sb.WriteString("`" + cond.Field + "` " + cond.Expr + " ?")
|
||
params = append(params, cond.Value)
|
||
}
|
||
case "LIKE":
|
||
if sb.Len() > 0 {
|
||
sb.WriteString(" " + operator + " ")
|
||
}
|
||
cond.Value = fmt.Sprintf("%%%s%%", cond.Value)
|
||
sb.WriteString("`" + cond.Field + "` LIKE ?")
|
||
params = append(params, cond.Value)
|
||
case "IN":
|
||
if sb.Len() > 0 {
|
||
sb.WriteString(" " + operator + " ")
|
||
}
|
||
refVal := reflect.Indirect(reflect.ValueOf(cond.Value))
|
||
switch refVal.Kind() {
|
||
case reflect.Slice, reflect.Array:
|
||
ss := make([]string, refVal.Len())
|
||
for i := 0; i < refVal.Len(); i++ {
|
||
ss[i] = query.decodeValue(refVal.Index(i))
|
||
}
|
||
sb.WriteString("`" + cond.Field + "` IN (" + strings.Join(ss, ",") + ")")
|
||
case reflect.String:
|
||
sb.WriteString("`" + cond.Field + "` IN (" + refVal.String() + ")")
|
||
default:
|
||
}
|
||
case "BETWEEN":
|
||
refVal := reflect.ValueOf(cond.Value)
|
||
if refVal.Kind() == reflect.Slice && refVal.Len() == 2 {
|
||
sb.WriteString("`" + cond.Field + "` BETWEEN ? AND ?")
|
||
params = append(params, refVal.Index(0), refVal.Index(1))
|
||
}
|
||
}
|
||
}
|
||
str = sb.String()
|
||
return
|
||
}
|
||
|
||
func (query *Query) Select(fields ...string) *Query {
|
||
if query.fields == nil {
|
||
query.fields = fields
|
||
} else {
|
||
query.fields = append(query.fields, fields...)
|
||
}
|
||
return query
|
||
}
|
||
|
||
func (query *Query) From(table string) *Query {
|
||
query.table = table
|
||
return query
|
||
}
|
||
|
||
func (query *Query) LeftJoin(table string, conditions ...*condition) *Query {
|
||
query.joins = append(query.joins, join{
|
||
Table: table,
|
||
Direction: "LEFT",
|
||
Conditions: conditions,
|
||
})
|
||
return query
|
||
}
|
||
|
||
func (query *Query) RightJoin(table string, conditions ...*condition) *Query {
|
||
query.joins = append(query.joins, join{
|
||
Table: table,
|
||
Direction: "RIGHT",
|
||
Conditions: conditions,
|
||
})
|
||
return query
|
||
}
|
||
|
||
func (query *Query) InnerJoin(table string, conditions ...*condition) *Query {
|
||
query.joins = append(query.joins, join{
|
||
Table: table,
|
||
Direction: "INNER",
|
||
Conditions: conditions,
|
||
})
|
||
return query
|
||
}
|
||
|
||
func (query *Query) AddCondition(expr, column string, val any) {
|
||
if expr == "" {
|
||
expr = "="
|
||
}
|
||
query.AndWhere(newConditionWithOperator(expr, column, val))
|
||
}
|
||
|
||
func (query *Query) AndWhere(conditions ...*condition) *Query {
|
||
length := len(conditions)
|
||
if length == 0 {
|
||
return query
|
||
}
|
||
cs, ps := query.buildConditions("AND", false, conditions...)
|
||
if cs == "" {
|
||
return query
|
||
}
|
||
query.params = append(query.params, ps...)
|
||
if query.condition == "" {
|
||
query.condition = cs
|
||
} else {
|
||
query.condition += " AND (" + cs + ")"
|
||
}
|
||
return query
|
||
}
|
||
|
||
func (query *Query) AndFilterWhere(conditions ...*condition) *Query {
|
||
length := len(conditions)
|
||
if length == 0 {
|
||
return query
|
||
}
|
||
cs, ps := query.buildConditions("AND", true, conditions...)
|
||
if cs == "" {
|
||
return query
|
||
}
|
||
query.params = append(query.params, ps...)
|
||
if query.condition == "" {
|
||
query.condition = cs
|
||
} else {
|
||
query.condition += " AND " + cs
|
||
}
|
||
return query
|
||
}
|
||
|
||
func (query *Query) OrWhere(conditions ...*condition) *Query {
|
||
length := len(conditions)
|
||
if length == 0 {
|
||
return query
|
||
}
|
||
cs, ps := query.buildConditions("OR", false, conditions...)
|
||
if cs == "" {
|
||
return query
|
||
}
|
||
query.params = append(query.params, ps...)
|
||
if query.condition == "" {
|
||
query.condition = cs
|
||
} else {
|
||
query.condition += " AND (" + cs + ")"
|
||
}
|
||
return query
|
||
}
|
||
|
||
func (query *Query) OrFilterWhere(conditions ...*condition) *Query {
|
||
length := len(conditions)
|
||
if length == 0 {
|
||
return query
|
||
}
|
||
cs, ps := query.buildConditions("OR", true, conditions...)
|
||
if cs == "" {
|
||
return query
|
||
}
|
||
query.params = append(query.params, ps...)
|
||
if query.condition == "" {
|
||
query.condition = cs
|
||
} else {
|
||
query.condition += " AND (" + cs + ")"
|
||
}
|
||
return query
|
||
}
|
||
|
||
func (query *Query) GroupBy(cols ...string) *Query {
|
||
query.groupBy = append(query.groupBy, cols...)
|
||
return query
|
||
}
|
||
|
||
func (query *Query) OrderBy(col, direction string) *Query {
|
||
direction = strings.ToUpper(direction)
|
||
if direction == "" || !(direction == "ASC" || direction == "DESC") {
|
||
direction = "ASC"
|
||
}
|
||
query.orderBy = append(query.orderBy, col+" "+direction)
|
||
return query
|
||
}
|
||
|
||
func (query *Query) Offset(i int) *Query {
|
||
query.offset = i
|
||
return query
|
||
}
|
||
|
||
func (query *Query) Limit(i int) *Query {
|
||
query.limit = i
|
||
return query
|
||
}
|
||
|
||
func (query *Query) ResetSelect() *Query {
|
||
query.fields = make([]string, 0)
|
||
return query
|
||
}
|
||
|
||
func (query *Query) Count(v interface{}) (i int64) {
|
||
var (
|
||
db *gorm.DB
|
||
err error
|
||
)
|
||
if db, err = query.compile(); err != nil {
|
||
return
|
||
} else {
|
||
if v != nil {
|
||
refVal := reflect.ValueOf(v)
|
||
switch refVal.Kind() {
|
||
case reflect.String:
|
||
if query.table == "" {
|
||
err = db.Table(refVal.String()).Count(&i).Error
|
||
} else {
|
||
err = db.Table(query.table).Count(&i).Error
|
||
}
|
||
default:
|
||
//如果是报表的模型,这的话手动构建一条SQL语句
|
||
if reporter, ok := v.(types.Reporter); ok {
|
||
sqlRes := &sqlCountResponse{}
|
||
childCtx := context.WithValue(db.Statement.Context, "@sql_count_statement", true)
|
||
db.WithContext(childCtx).Model(reporter).First(sqlRes)
|
||
i = sqlRes.Count
|
||
} else {
|
||
err = db.Model(v).Count(&i).Error
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
func (query *Query) One(v interface{}) (err error) {
|
||
var (
|
||
db *gorm.DB
|
||
)
|
||
if db, err = query.compile(); err != nil {
|
||
return
|
||
} else {
|
||
err = db.First(v).Error
|
||
}
|
||
return
|
||
}
|
||
|
||
func (query *Query) All(v interface{}) (err error) {
|
||
var (
|
||
db *gorm.DB
|
||
)
|
||
if db, err = query.compile(); err != nil {
|
||
return
|
||
} else {
|
||
err = db.Find(v).Error
|
||
}
|
||
return
|
||
}
|
||
|
||
func NewCondition(column, opera string, value any) *condition {
|
||
if opera == "" {
|
||
opera = "="
|
||
}
|
||
return &condition{
|
||
Field: column,
|
||
Value: value,
|
||
Expr: opera,
|
||
}
|
||
}
|
||
|
||
func newCondition(field string, value interface{}) *condition {
|
||
return &condition{
|
||
Field: field,
|
||
Value: value,
|
||
Expr: "=",
|
||
}
|
||
}
|
||
|
||
func newConditionWithOperator(operator, field string, value interface{}) *condition {
|
||
cond := &condition{
|
||
Field: field,
|
||
Value: value,
|
||
Expr: operator,
|
||
}
|
||
return cond
|
||
}
|
||
|
||
func NewQuery(db *gorm.DB, model any) *Query {
|
||
return &Query{
|
||
db: db,
|
||
modelValue: model,
|
||
params: make([]interface{}, 0),
|
||
orderBy: make([]string, 0),
|
||
groupBy: make([]string, 0),
|
||
joins: make([]join, 0),
|
||
}
|
||
}
|