988 lines
29 KiB
Go
988 lines
29 KiB
Go
package rest
|
||
|
||
import (
|
||
"context"
|
||
"encoding/csv"
|
||
"encoding/json"
|
||
"fmt"
|
||
"git.nobla.cn/golang/kos/util/arrays"
|
||
"git.nobla.cn/golang/kos/util/pool"
|
||
"git.nobla.cn/golang/rest/inflector"
|
||
"git.nobla.cn/golang/rest/types"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/clause"
|
||
"gorm.io/gorm/schema"
|
||
"io"
|
||
"mime/multipart"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"path"
|
||
"reflect"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// Model exposes one GORM model through a set of HTTP handlers
// (Search, Create, Update, Delete, View, Export, Import).
type Model struct {
	naming        types.Naming           // naming rules (module, table, singular and plural names)
	value         reflect.Value          // reflected prototype of the model (dereferenced by newModel)
	db            *gorm.DB               // database handle
	primaryKey    string                 // DB column name of the primary key
	urlPrefix     string                 // URL prefix prepended to every route
	disableDomain bool                   // when true, domain filtering/assignment is skipped
	schemaLookup  types.SchemaLookupFunc // returns the schemas to use for a scenario
	valueLookup   types.ValueLookupFunc  // looks up request-scoped values such as the domain and the user
	statement     *gorm.Statement        // parsed GORM statement that holds the model schema
	formatter     *Formatter             // output formatter; DefaultFormatter is used when nil
	response      types.HttpWriter       // writes HTTP success/failure responses
	hookMgr       *hookManager           // hook manager
	dirname       string                 // directory where export/import files are stored
}
|
||
|
||
var (
	// RuntimeScopeKey is the context key under which the handlers store the
	// *types.RuntimeScope of the current request. It is a pointer, so it is
	// compared by identity when used as a context key.
	RuntimeScopeKey = &types.RuntimeScope{}
)
|
||
|
||
// getDB 获取数据库连接对象
|
||
func (m *Model) getDB() *gorm.DB {
|
||
return m.db
|
||
}
|
||
|
||
// getFormatter 获取格式化组件
|
||
func (m *Model) getFormatter() *Formatter {
|
||
if m.formatter != nil {
|
||
return m.formatter
|
||
}
|
||
return DefaultFormatter
|
||
}
|
||
|
||
// getHook 获取钩子
|
||
func (m *Model) getHook() *hookManager {
|
||
return m.hookMgr
|
||
}
|
||
|
||
// hasScenario 判断是否有该场景
|
||
func (m *Model) hasScenario(s string) bool {
|
||
return true
|
||
}
|
||
|
||
// setValue 设置字段的值
|
||
func (m *Model) setValue(refValue reflect.Value, column string, value any) {
|
||
SetFieldValue(m.statement, refValue, column, value)
|
||
}
|
||
|
||
func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) {
|
||
SafeSetFileValue(m.statement, refValue, column, value)
|
||
}
|
||
|
||
// getValue 获取字段的值
|
||
func (m *Model) getValue(refValue reflect.Value, column string) interface{} {
|
||
return GetFieldValue(m.statement, refValue, column)
|
||
}
|
||
|
||
// hasColumn 判断指定的列是否存在
|
||
func (m *Model) hasColumn(column string) bool {
|
||
for _, field := range m.statement.Schema.Fields {
|
||
if field.DBName == column || field.Name == column {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// getFilename 获取文件存放目录
|
||
func (m *Model) getFilename(domain string, spec string, name string) string {
|
||
if m.dirname == "" {
|
||
m.dirname = os.TempDir()
|
||
}
|
||
filename := path.Join(m.dirname, domain, spec, time.Now().Format("20060102"), name)
|
||
if _, err := os.Stat(path.Dir(filename)); err != nil {
|
||
_ = os.MkdirAll(path.Dir(filename), 0755)
|
||
}
|
||
return filename
|
||
}
|
||
|
||
// findPrimaryKey 查找主键的值
|
||
func (m *Model) findPrimaryKey(uri string, r *http.Request) string {
|
||
var (
|
||
pos int
|
||
)
|
||
urlPath := r.URL.Path
|
||
pos = strings.IndexByte(uri, ':')
|
||
if pos > 0 {
|
||
return urlPath[pos:]
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// parseReportColumn 解析报表的列
|
||
func (m *Model) parseReportColumn(name, props string) *types.SelectColumn {
|
||
var (
|
||
key string
|
||
value string
|
||
)
|
||
column := &types.SelectColumn{
|
||
Name: inflector.Camel2id(name),
|
||
Native: false,
|
||
}
|
||
tokens := strings.Split(props, ";")
|
||
for _, token := range tokens {
|
||
pair := strings.SplitN(token, ":", 2)
|
||
if len(pair) == 0 {
|
||
continue
|
||
}
|
||
if len(pair) == 1 {
|
||
key = strings.TrimSpace(pair[0])
|
||
value = ""
|
||
} else {
|
||
key = strings.TrimSpace(pair[0])
|
||
value = strings.TrimSpace(pair[1])
|
||
}
|
||
switch key {
|
||
case "native":
|
||
column.Native = true
|
||
case "name":
|
||
column.Name = value
|
||
case "expr":
|
||
column.Expr = value
|
||
}
|
||
}
|
||
return column
|
||
}
|
||
|
||
func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) {
|
||
modelType := reflect.ValueOf(dest).Type()
|
||
if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||
modelType = modelType.Elem()
|
||
}
|
||
columns := make([]string, 0)
|
||
for i := 0; i < modelType.NumField(); i++ {
|
||
field := modelType.Field(i)
|
||
scenarios := field.Tag.Get("scenarios")
|
||
if !hasToken(types.ScenarioList, scenarios) {
|
||
continue
|
||
}
|
||
isPrimary := field.Tag.Get("is_primary")
|
||
if isPrimary != "true" {
|
||
continue
|
||
}
|
||
column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
|
||
if !column.Native {
|
||
continue
|
||
}
|
||
if column.Expr == "" {
|
||
columns = append(columns, dest.QuoteColumn(ctx, column.Name))
|
||
} else {
|
||
columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
|
||
}
|
||
}
|
||
columns = append(columns, "COUNT(*) AS count")
|
||
query.Select(columns...)
|
||
}
|
||
|
||
func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) {
|
||
modelType := reflect.ValueOf(dest).Type()
|
||
if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||
modelType = modelType.Elem()
|
||
}
|
||
columns := make([]string, 0)
|
||
for i := 0; i < modelType.NumField(); i++ {
|
||
field := modelType.Field(i)
|
||
scenarios := field.Tag.Get("scenarios")
|
||
if !hasToken(types.ScenarioList, scenarios) {
|
||
continue
|
||
}
|
||
column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
|
||
if !column.Native {
|
||
continue
|
||
}
|
||
if column.Expr == "" {
|
||
columns = append(columns, dest.QuoteColumn(ctx, column.Name))
|
||
} else {
|
||
columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
|
||
}
|
||
}
|
||
query.Select(columns...)
|
||
}
|
||
|
||
// buildCondition 构建sql条件
|
||
func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) {
|
||
return BuildConditions(ctx, r, query, schemas)
|
||
}
|
||
|
||
// ModuleName 模块名称
|
||
func (m *Model) ModuleName() string {
|
||
return m.naming.ModuleName
|
||
}
|
||
|
||
// TableName 表的名称
|
||
func (m *Model) TableName() string {
|
||
return m.naming.ModuleName
|
||
}
|
||
|
||
// Fields 返回搜索的模型的字段
|
||
func (m *Model) Fields() []*schema.Field {
|
||
return m.statement.Schema.Fields
|
||
}
|
||
|
||
// Uri 获取请求的uri
|
||
func (m *Model) Uri(scenario string) string {
|
||
ss := make([]string, 4)
|
||
if m.urlPrefix != "" {
|
||
ss = append(ss, m.urlPrefix)
|
||
}
|
||
switch scenario {
|
||
case types.ScenarioList:
|
||
ss = append(ss, m.naming.ModuleName, m.naming.Pluralize)
|
||
case types.ScenarioView:
|
||
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||
case types.ScenarioCreate:
|
||
ss = append(ss, m.naming.ModuleName, m.naming.Singular)
|
||
case types.ScenarioUpdate:
|
||
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||
case types.ScenarioDelete:
|
||
ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
|
||
case types.ScenarioExport:
|
||
ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export")
|
||
case types.ScenarioImport:
|
||
ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import")
|
||
}
|
||
uri := path.Join(ss...)
|
||
if !strings.HasPrefix(uri, "/") {
|
||
uri = "/" + uri
|
||
}
|
||
return uri
|
||
}
|
||
|
||
// Method 获取HTTP请求的方法
|
||
func (m *Model) Method(scenario string) string {
|
||
var (
|
||
method = http.MethodGet
|
||
)
|
||
switch scenario {
|
||
case types.ScenarioCreate:
|
||
method = http.MethodPost
|
||
case types.ScenarioUpdate:
|
||
method = http.MethodPut
|
||
case types.ScenarioDelete:
|
||
method = http.MethodDelete
|
||
}
|
||
return method
|
||
}
|
||
|
||
// Search handles the HTTP list/search request for the model.
//
// Query parameters read here:
//   - "page": 1-based; 0 and 1 both select the first page.
//   - "pagesize": defaults to defaultPageSize when missing or <= 0.
//   - "scenario": schema set used for the output; honored only when it is
//     listed in allowScenario, otherwise the list scenario is used.
//   - "__format": forwarded to the formatter.
//
// Filter conditions are built from the search-scenario schemas. Results are
// restricted to the caller's domain unless domain filtering is disabled.
// Models implementing tableNamer or types.Reporter can redirect the query to
// another table; reporters additionally get explicit COUNT columns and a
// GROUP BY.
func (m *Model) Search(w http.ResponseWriter, r *http.Request) {
	var (
		ok            bool
		err           error
		qs            url.Values
		page          int
		pageSize      int
		pageIndex     int
		query         *Query
		domainName    string
		modelSlices   reflect.Value
		modelValues   reflect.Value
		searchSchemas []*types.Schema
		listSchemas   []*types.Schema
		modelValue    reflect.Value
		scenario      string
		reporter      types.Reporter
		namerTable    tableNamer
	)
	qs = r.URL.Query()
	page, _ = strconv.Atoi(qs.Get("page"))
	pageSize, _ = strconv.Atoi(qs.Get("pagesize"))
	if pageSize <= 0 {
		pageSize = defaultPageSize
	}
	// Convert the 1-based page number into a 0-based page index.
	// NOTE(review): a negative "page" is passed through unchanged and yields a
	// negative offset below; confirm Query.Offset tolerates that.
	pageIndex = page
	if pageIndex > 0 {
		pageIndex--
	}
	modelValue = reflect.New(m.value.Type())
	// Build a slice of pointers (*T) so that methods with pointer receivers
	// can be called on the rows during formatting.
	if m.value.Kind() != reflect.Ptr {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
	} else {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
	}
	modelValues = reflect.New(modelSlices.Type())
	modelValues.Elem().Set(modelSlices)
	query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface())
	domainName = m.valueLookup(types.FieldDomain, w, r)
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		Request:    r,
		User:       m.valueLookup("user", w, r),
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioList,
	})
	if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	scenario = types.ScenarioList
	if arrays.Exists(r.FormValue("scenario"), allowScenario) {
		scenario = r.FormValue("scenario")
	}
	if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	// Restrict the result to the caller's domain.
	if !m.disableDomain {
		if m.hasColumn(types.FieldDomain) {
			query.AndWhere(newCondition(types.FieldDomain, domainName))
		}
	}
	if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
		return
	}
	// Table name: models implementing tableNamer pick the table per request.
	if namerTable, ok = query.Model().(tableNamer); ok {
		query.From(namerTable.HttpTableName(r))
	}
	// Reports: models implementing types.Reporter are read from RealTable().
	if reporter, ok = modelValue.Interface().(types.Reporter); ok {
		query.From(reporter.RealTable())
	}
	res := &types.ListResponse{
		Page:     page,
		PageSize: pageSize,
	}
	if reporter == nil {
		res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
	} else {
		// For reports, the columns of the COUNT statement must be specified
		// explicitly for the count to work.
		// NOTE(review): the GROUP BY is only added after the count; confirm
		// Query.Count yields the intended number of report rows.
		m.buildReporterCountColumns(childCtx, reporter, query)
		res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())

		// Clear the COUNT selection before building the data query;
		// otherwise the data query misbehaves.
		query.ResetSelect()
		query.GroupBy(reporter.GroupBy(childCtx)...)
	}
	query.Offset(pageIndex * pageSize).Limit(pageSize)
	if res.TotalCount > 0 {
		if reporter != nil {
			m.buildReporterQueryColumns(childCtx, reporter, query)
		}
		if err = query.All(modelValues.Interface()); err != nil {
			m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
			return
		}
		// Format the rows for output; "__format" is forwarded to the formatter.
		res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format"))
	} else {
		// Return an empty JSON array rather than null when nothing matched.
		res.Data = make([]string, 0)
	}
	m.response.Success(w, res)
}
|
||
|
||
// Create handles the HTTP creation of a record.
//
// Processing steps:
//  1. Decode the JSON request body into a new model value.
//  2. When domain filtering is enabled and the model has a domain column,
//     overwrite that column with the caller's domain.
//  3. Save the value inside a transaction, preceded by the beforeCreate and
//     beforeSave hooks.
//  4. After a successful commit, run the model's AfterCreated/AfterSaved
//     callbacks (when implemented) and the afterCreate/afterSave hooks, with
//     the new values of the create-scenario columns.
//
// The response carries the new primary key.
func (m *Model) Create(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		schemas    []*types.Schema
		diffAttrs  []*types.DiffAttr
		domainName string
		modelValue reflect.Value
	)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil)
		return
	}
	domainName = m.valueLookup(types.FieldDomain, w, r)
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	// Force the domain column to the caller's domain; this runs after the
	// body was decoded, so a domain sent by the client is overridden.
	if !m.disableDomain {
		if m.hasColumn(types.FieldDomain) {
			m.setValue(modelValue, types.FieldDomain, domainName)
		}
	}
	diffAttrs = make([]*types.DiffAttr, 0, 10)
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		User:       m.valueLookup("user", w, r),
		Request:    r,
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioCreate,
		Schemas:    schemas,
	})
	dbSess := m.getDB().WithContext(childCtx)
	if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
		if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil {
			return
		}
		if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
			return
		}
		// Models implementing types.Tabler are saved into their own table.
		if tabler, ok := model.(types.Tabler); ok {
			errTx = tx.Table(tabler.TableName()).Save(model).Error
		} else {
			errTx = tx.Save(model).Error
		}
		if errTx != nil {
			return
		}
		// Record each create-scenario column as a diff without an old value.
		for _, row := range schemas {
			diffAttrs = append(diffAttrs, &types.DiffAttr{
				Column:   row.Column,
				Label:    row.Label,
				OldValue: nil,
				NewValue: m.getValue(modelValue, row.Column),
			})
		}
		return
	}); err == nil {
		res := &types.CreateResponse{
			ID:     m.getValue(modelValue, m.primaryKey),
			Status: "created",
		}
		// These callbacks and hooks run after the commit, on dbSess rather
		// than on the transaction.
		if creator, ok := model.(afterCreated); ok {
			creator.AfterCreated(childCtx, dbSess)
		}
		if preserver, ok := model.(afterSaved); ok {
			preserver.AfterSaved(childCtx, dbSess)
		}
		m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs)
		m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
		m.response.Success(w, res)
	} else {
		m.response.Failure(w, types.RequestCreateFailure, err.Error(), err)
	}
}
|
||
|
||
// Update 实现通过HTTP方法更新模型
|
||
func (m *Model) Update(w http.ResponseWriter, r *http.Request) {
|
||
var (
|
||
err error
|
||
model any
|
||
schemas []*types.Schema
|
||
diffAttrs []*types.DiffAttr
|
||
domainName string
|
||
modelValue reflect.Value
|
||
oldValues map[string]any
|
||
)
|
||
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
|
||
modelValue = reflect.New(m.value.Type())
|
||
model = modelValue.Interface()
|
||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil {
|
||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||
return
|
||
}
|
||
conditions := map[string]any{
|
||
m.primaryKey: idStr,
|
||
}
|
||
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||
return
|
||
}
|
||
oldValues = make(map[string]any)
|
||
for _, row := range schemas {
|
||
oldValues[row.Column] = m.getValue(modelValue, row.Column)
|
||
}
|
||
if err = json.NewDecoder(r.Body).Decode(model); err != nil {
|
||
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
|
||
return
|
||
}
|
||
diffAttrs = make([]*types.DiffAttr, 0, 10)
|
||
updates := make(map[string]any)
|
||
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||
Domain: domainName,
|
||
Request: r,
|
||
User: m.valueLookup("user", w, r),
|
||
ModuleName: m.naming.ModuleName,
|
||
TableName: m.naming.TableName,
|
||
Scenario: types.ScenarioUpdate,
|
||
Schemas: schemas,
|
||
PrimaryKeyValue: idStr,
|
||
})
|
||
dbSess := m.getDB().WithContext(childCtx)
|
||
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||
if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil {
|
||
return
|
||
}
|
||
if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
|
||
return
|
||
}
|
||
for _, row := range schemas {
|
||
v := m.getValue(modelValue, row.Column)
|
||
if oldValues[row.Column] != v {
|
||
updates[row.Column] = v
|
||
diffAttrs = append(diffAttrs, &types.DiffAttr{
|
||
Column: row.Column,
|
||
Label: row.Label,
|
||
OldValue: oldValues[row.Column],
|
||
NewValue: v,
|
||
})
|
||
}
|
||
}
|
||
if len(updates) > 0 {
|
||
if tabler, ok := model.(types.Tabler); ok {
|
||
errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error
|
||
} else {
|
||
errTx = tx.Model(model).Updates(updates).Error
|
||
}
|
||
if errTx != nil {
|
||
return
|
||
}
|
||
}
|
||
return
|
||
}); err == nil {
|
||
if updater, ok := model.(afterUpdated); ok {
|
||
updater.AfterUpdated(childCtx, dbSess)
|
||
}
|
||
if preserver, ok := model.(afterSaved); ok {
|
||
preserver.AfterSaved(childCtx, dbSess)
|
||
}
|
||
m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs)
|
||
m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
|
||
m.response.Success(w, types.UpdateResponse{
|
||
ID: idStr,
|
||
Status: "updated",
|
||
})
|
||
} else {
|
||
m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil)
|
||
}
|
||
}
|
||
|
||
// Delete 实现通过HTTP方法删除模型
|
||
func (m *Model) Delete(w http.ResponseWriter, r *http.Request) {
|
||
var (
|
||
err error
|
||
model any
|
||
modelValue reflect.Value
|
||
)
|
||
idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r)
|
||
modelValue = reflect.New(m.value.Type())
|
||
model = modelValue.Interface()
|
||
conditions := map[string]any{
|
||
m.primaryKey: idStr,
|
||
}
|
||
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||
return
|
||
}
|
||
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||
Domain: m.valueLookup(types.FieldDomain, w, r),
|
||
User: m.valueLookup("user", w, r),
|
||
Request: r,
|
||
ModuleName: m.naming.ModuleName,
|
||
TableName: m.naming.TableName,
|
||
Scenario: types.ScenarioDelete,
|
||
PrimaryKeyValue: idStr,
|
||
})
|
||
dbSess := m.getDB().WithContext(childCtx)
|
||
if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||
if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil {
|
||
return
|
||
}
|
||
if tabler, ok := model.(types.Tabler); ok {
|
||
errTx = tx.Table(tabler.TableName()).Delete(model).Error
|
||
} else {
|
||
errTx = tx.Delete(model).Error
|
||
}
|
||
if errTx != nil {
|
||
return
|
||
}
|
||
m.getHook().afterDelete(childCtx, tx, model)
|
||
return
|
||
}); err == nil {
|
||
m.response.Success(w, types.DeleteResponse{
|
||
ID: idStr,
|
||
Status: "deleted",
|
||
})
|
||
} else {
|
||
m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil)
|
||
}
|
||
}
|
||
|
||
// View 查看数据详情
|
||
func (m *Model) View(w http.ResponseWriter, r *http.Request) {
|
||
var (
|
||
err error
|
||
model any
|
||
modelValue reflect.Value
|
||
qs url.Values
|
||
schemas []*types.Schema
|
||
scenario string
|
||
domainName string
|
||
)
|
||
qs = r.URL.Query()
|
||
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
|
||
modelValue = reflect.New(m.value.Type())
|
||
model = modelValue.Interface()
|
||
conditions := map[string]any{
|
||
m.primaryKey: idStr,
|
||
}
|
||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||
if err = m.getDB().Where(conditions).First(model).Error; err != nil {
|
||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||
return
|
||
}
|
||
scenario = qs.Get("scenario")
|
||
if scenario == "" {
|
||
schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioView)
|
||
} else {
|
||
schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario)
|
||
}
|
||
if err == nil {
|
||
m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format")))
|
||
} else {
|
||
m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil)
|
||
}
|
||
}
|
||
|
||
// Export 实现通过HTTP方法导出模型
|
||
func (m *Model) Export(w http.ResponseWriter, r *http.Request) {
|
||
var (
|
||
err error
|
||
query *Query
|
||
modelSlices reflect.Value
|
||
modelValues reflect.Value
|
||
searchSchemas []*types.Schema
|
||
exportSchemas []*types.Schema
|
||
domainName string
|
||
fp *os.File
|
||
modelValue reflect.Value
|
||
)
|
||
if !m.hasScenario(types.ScenarioList) {
|
||
m.response.Failure(w, types.RequestDenied, "request denied", nil)
|
||
return
|
||
}
|
||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||
filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
|
||
if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil {
|
||
m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil)
|
||
return
|
||
}
|
||
defer func() {
|
||
_ = fp.Close()
|
||
}()
|
||
modelValue = reflect.New(m.value.Type())
|
||
//这里创建指针类型,这样的话就能在format里面调用函数
|
||
if m.value.Kind() != reflect.Ptr {
|
||
modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
|
||
} else {
|
||
modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
|
||
}
|
||
modelValues = reflect.New(modelSlices.Type())
|
||
modelValues.Elem().Set(modelSlices)
|
||
query = NewQuery(m.getDB(), modelValue.Interface())
|
||
childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
|
||
Domain: domainName,
|
||
Request: r,
|
||
User: m.valueLookup("user", w, r),
|
||
ModuleName: m.naming.ModuleName,
|
||
TableName: m.naming.TableName,
|
||
Scenario: types.ScenarioExport,
|
||
})
|
||
if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
|
||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||
return
|
||
}
|
||
if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil {
|
||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||
return
|
||
}
|
||
if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
|
||
m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
|
||
return
|
||
}
|
||
if !m.disableDomain {
|
||
if m.hasColumn(types.FieldDomain) {
|
||
query.AndWhere(newCondition(types.FieldDomain, domainName))
|
||
}
|
||
}
|
||
// 处理表名逻辑
|
||
if namerTable, ok := query.Model().(tableNamer); ok {
|
||
query.From(namerTable.HttpTableName(r))
|
||
}
|
||
//处理报表逻辑
|
||
if reporter, ok := modelValue.Interface().(types.Reporter); ok {
|
||
query.From(reporter.RealTable())
|
||
query.GroupBy(reporter.GroupBy(childCtx)...)
|
||
m.buildReporterQueryColumns(childCtx, reporter, query)
|
||
}
|
||
if err = query.All(modelValues.Interface()); err != nil {
|
||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||
return
|
||
}
|
||
w.Header().Set("Content-Type", "text/csv")
|
||
w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
|
||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
|
||
value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "")
|
||
writer := csv.NewWriter(fp)
|
||
rows := make([]string, len(exportSchemas))
|
||
for i, field := range exportSchemas {
|
||
rows[i] = field.Label
|
||
}
|
||
_ = writer.Write(rows)
|
||
if values, ok := value.([]any); ok {
|
||
for _, val := range values {
|
||
row, ok2 := val.(map[string]any)
|
||
if !ok2 {
|
||
continue
|
||
}
|
||
for i, field := range exportSchemas {
|
||
if v, ok := row[field.Column]; ok {
|
||
rows[i] = fmt.Sprint(v)
|
||
} else {
|
||
rows[i] = ""
|
||
}
|
||
}
|
||
_ = writer.Write(rows)
|
||
}
|
||
}
|
||
writer.Flush()
|
||
m.getHook().afterExport(childCtx, filename)
|
||
http.ServeContent(w, r, path.Base(filename), time.Now(), fp)
|
||
}
|
||
|
||
// findSchema 查找指定的schema
|
||
func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema {
|
||
for _, row := range schemas {
|
||
if row.Label == label {
|
||
return row
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// importInternal 文件上传方法
|
||
func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) {
|
||
var (
|
||
err error
|
||
rows []string
|
||
fp *os.File
|
||
tm time.Time
|
||
fields []string
|
||
sess *gorm.DB
|
||
csvReader *csv.Reader
|
||
csvWriter *csv.Writer
|
||
modelValue reflect.Value
|
||
modelEntity any
|
||
diffAttrs []*types.DiffAttr
|
||
result *types.ImportResult
|
||
failureFp *os.File
|
||
failureFile string
|
||
)
|
||
tm = time.Now()
|
||
result = &types.ImportResult{}
|
||
if fp, err = os.Open(filename); err != nil {
|
||
result.Code = types.ErrImportFileNotExists
|
||
goto __end
|
||
}
|
||
defer func() {
|
||
_ = fp.Close()
|
||
}()
|
||
csvReader = csv.NewReader(fp)
|
||
if rows, err = csvReader.Read(); err != nil {
|
||
result.Code = types.ErrImportFileUnavailable
|
||
goto __end
|
||
}
|
||
fields = make([]string, 0, len(rows))
|
||
for _, s := range rows {
|
||
v := m.findSchema(s, schemas)
|
||
if v == nil {
|
||
result.Code = types.ErrImportColumnNotMatch
|
||
goto __end
|
||
}
|
||
fields = append(fields, v.Column)
|
||
}
|
||
sess = m.getDB().WithContext(ctx)
|
||
//失败文件指针
|
||
failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix()))
|
||
if failureFp, err = os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil {
|
||
return
|
||
}
|
||
defer func() {
|
||
_ = failureFp.Close()
|
||
}()
|
||
csvWriter = csv.NewWriter(failureFp)
|
||
rows = append(rows, "Error")
|
||
_ = csvWriter.Write(rows)
|
||
diffAttrs = make([]*types.DiffAttr, len(schemas))
|
||
for {
|
||
if rows, err = csvReader.Read(); err != nil {
|
||
break
|
||
}
|
||
result.TotalCount++
|
||
if len(rows) != len(fields) {
|
||
continue
|
||
}
|
||
modelValue = reflect.New(m.value.Type())
|
||
for idx, field := range fields {
|
||
m.safeSetValue(modelValue, field, rows[idx])
|
||
}
|
||
if len(extraFields) > 0 {
|
||
for k, v := range extraFields {
|
||
m.safeSetValue(modelValue, k, v)
|
||
}
|
||
}
|
||
modelEntity = modelValue.Interface()
|
||
//写入数据
|
||
if fast {
|
||
//如果是快速模式,直接存储数据
|
||
if err = sess.Save(modelEntity).Error; err == nil {
|
||
result.SuccessCount++
|
||
} else {
|
||
rows = append(rows, err.Error())
|
||
_ = csvWriter.Write(rows)
|
||
}
|
||
} else {
|
||
if err = sess.Transaction(func(tx *gorm.DB) (errTx error) {
|
||
if errTx = m.getHook().beforeCreate(ctx, tx, modelEntity); errTx != nil {
|
||
return
|
||
}
|
||
if errTx = m.getHook().beforeSave(ctx, tx, modelEntity); errTx != nil {
|
||
return
|
||
}
|
||
if tabler, ok := modelEntity.(types.Tabler); ok {
|
||
errTx = tx.Table(tabler.TableName()).Save(modelEntity).Error
|
||
} else {
|
||
errTx = tx.Save(modelEntity).Error
|
||
}
|
||
if errTx != nil {
|
||
return
|
||
}
|
||
for idx, row := range schemas {
|
||
diffAttrs[idx] = &types.DiffAttr{
|
||
Column: row.Column,
|
||
Label: row.Label,
|
||
NewValue: m.getValue(modelValue, row.Column),
|
||
}
|
||
}
|
||
m.getHook().afterCreate(ctx, tx, modelEntity, diffAttrs)
|
||
m.getHook().afterSave(ctx, tx, modelEntity, diffAttrs)
|
||
return
|
||
}); err == nil {
|
||
result.SuccessCount++
|
||
} else {
|
||
rows = append(rows, err.Error())
|
||
_ = csvWriter.Write(rows)
|
||
}
|
||
}
|
||
}
|
||
csvWriter.Flush()
|
||
__end:
|
||
result.UploadFile = filename
|
||
if result.TotalCount > result.SuccessCount {
|
||
result.FailureFile = failureFile
|
||
}
|
||
result.Duration = time.Now().Sub(tm)
|
||
m.getHook().afterImport(ctx, result)
|
||
}
|
||
|
||
// Import 实现通过HTTP方法导入
|
||
func (m *Model) Import(w http.ResponseWriter, r *http.Request) {
|
||
var (
|
||
err error
|
||
fast bool
|
||
schemas []*types.Schema
|
||
rows []string
|
||
domainName string
|
||
dst *os.File
|
||
fp multipart.File
|
||
csvWriter *csv.Writer
|
||
qs url.Values
|
||
extraFields map[string]string
|
||
)
|
||
domainName = m.valueLookup(types.FieldDomain, w, r)
|
||
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
|
||
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
|
||
return
|
||
}
|
||
//这里用background的context
|
||
childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{
|
||
Domain: domainName,
|
||
User: m.valueLookup("user", w, r),
|
||
ModuleName: m.naming.ModuleName,
|
||
TableName: m.naming.TableName,
|
||
Scenario: types.ScenarioImport,
|
||
Schemas: schemas,
|
||
})
|
||
if r.Method == http.MethodGet {
|
||
//下载导入模板
|
||
csvWriter = csv.NewWriter(w)
|
||
rows = make([]string, 0, len(schemas))
|
||
for _, row := range schemas {
|
||
//主键不需要导入
|
||
if row.IsPrimaryKey == 1 {
|
||
continue
|
||
}
|
||
rows = append(rows, row.Label)
|
||
}
|
||
w.Header().Set("Content-Type", "text/csv")
|
||
w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
|
||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
|
||
err = csvWriter.Write(rows)
|
||
csvWriter.Flush()
|
||
return
|
||
}
|
||
filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
|
||
if fp, _, err = r.FormFile("file"); err != nil {
|
||
m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil)
|
||
return
|
||
}
|
||
defer func() {
|
||
_ = fp.Close()
|
||
}()
|
||
if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err == nil {
|
||
buf := pool.GetBytes(32 * 1024)
|
||
_, err = io.CopyBuffer(dst, fp, buf)
|
||
pool.PutBytes(buf)
|
||
_ = dst.Close()
|
||
} else {
|
||
m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil)
|
||
return
|
||
}
|
||
qs = r.URL.Query()
|
||
if qs != nil {
|
||
extraFields = make(map[string]string)
|
||
for k, _ := range qs {
|
||
if strings.HasPrefix(k, "_attr_") {
|
||
extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k)
|
||
}
|
||
}
|
||
}
|
||
fast, _ = strconv.ParseBool(qs.Get("__fast"))
|
||
go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields)
|
||
m.response.Success(w, types.ImportResponse{
|
||
UID: m.valueLookup("user", w, r),
|
||
Status: "committed",
|
||
})
|
||
}
|
||
|
||
// newModel 创建一个模型
|
||
func newModel(v any, db *gorm.DB, naming types.Naming) *Model {
|
||
model := &Model{
|
||
db: db,
|
||
naming: naming,
|
||
response: &httpWriter{},
|
||
value: reflect.Indirect(reflect.ValueOf(v)),
|
||
valueLookup: defaultValueLookup,
|
||
}
|
||
model.statement = &gorm.Statement{
|
||
DB: model.getDB(),
|
||
ConnPool: model.getDB().ConnPool,
|
||
Clauses: map[string]clause.Clause{},
|
||
}
|
||
if err := model.statement.Parse(v); err == nil {
|
||
if model.statement.Schema.PrimaryFieldDBNames != nil && len(model.statement.Schema.PrimaryFieldDBNames) > 0 {
|
||
model.primaryKey = model.statement.Schema.PrimaryFieldDBNames[0]
|
||
}
|
||
}
|
||
return model
|
||
}
|