rest/model.go

988 lines
29 KiB
Go
Raw Normal View History

2024-12-11 17:29:01 +08:00
package rest
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"git.nobla.cn/golang/kos/util/arrays"
"git.nobla.cn/golang/kos/util/pool"
"git.nobla.cn/golang/rest/inflector"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"path"
"reflect"
"strconv"
"strings"
"time"
)
// Model exposes a gorm model as a set of HTTP handlers (list/search, view,
// create, update, delete, CSV export and CSV import).
type Model struct {
	naming        types.Naming           // naming rules: module name, table name, singular/plural forms
	value         reflect.Value          // dereferenced value of the registered model, used as the type template for new instances
	db            *gorm.DB               // database handle
	primaryKey    string                 // DB column name of the first primary-key field (set by newModel)
	urlPrefix     string                 // optional prefix prepended to every route built by Uri
	disableDomain bool                   // when true, records are not scoped to / stamped with the caller's domain
	schemaLookup  types.SchemaLookupFunc // resolves the schemas (columns/labels) for a given scenario
	valueLookup   types.ValueLookupFunc  // resolves request-scoped values such as the domain and the user
	statement     *gorm.Statement        // gorm statement holding the parsed model schema
	formatter     *Formatter             // output formatter; DefaultFormatter is used when nil
	response      types.HttpWriter       // writes HTTP success/failure responses
	hookMgr       *hookManager           // lifecycle hook manager
	dirname       string                 // directory for export/import files; the OS temp dir is used when empty
}
// RuntimeScopeKey is the context key under which the handlers store the
// *types.RuntimeScope describing the request being served.
var RuntimeScopeKey = &types.RuntimeScope{}
// getDB returns the database handle the model was built with.
func (m *Model) getDB() *gorm.DB {
	db := m.db
	return db
}

// getFormatter returns the model-specific formatter, or DefaultFormatter
// when none has been configured.
func (m *Model) getFormatter() *Formatter {
	if m.formatter == nil {
		return DefaultFormatter
	}
	return m.formatter
}

// getHook returns the hook manager attached to the model.
func (m *Model) getHook() *hookManager {
	mgr := m.hookMgr
	return mgr
}

// hasScenario reports whether the model supports scenario s.
// Every scenario is currently accepted; s is not inspected.
func (m *Model) hasScenario(s string) bool {
	_ = s
	return true
}
// setValue assigns value to the field mapped to column on refValue by
// delegating to SetFieldValue with the model's parsed statement.
func (m *Model) setValue(refValue reflect.Value, column string, value any) {
	stmt := m.statement
	SetFieldValue(stmt, refValue, column, value)
}

// safeSetValue assigns value to the field mapped to column on refValue by
// delegating to SafeSetFileValue with the model's parsed statement.
func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) {
	stmt := m.statement
	SafeSetFileValue(stmt, refValue, column, value)
}

// getValue reads the field mapped to column from refValue via GetFieldValue.
func (m *Model) getValue(refValue reflect.Value, column string) interface{} {
	stmt := m.statement
	return GetFieldValue(stmt, refValue, column)
}
// hasColumn reports whether the model schema has a field whose DB column name
// or Go field name equals column.
func (m *Model) hasColumn(column string) bool {
	fields := m.statement.Schema.Fields
	for i := range fields {
		if fields[i].DBName == column {
			return true
		}
		if fields[i].Name == column {
			return true
		}
	}
	return false
}
// getFilename builds the on-disk path used for export/import files:
// <dir>/<domain>/<spec>/<YYYYMMDD>/<name>. The parent directory is created
// on a best-effort basis; a failure surfaces when the caller opens the file.
//
// When m.dirname is empty the OS temp directory is used. The fallback is kept
// in a local variable rather than stored into m.dirname. This method runs
// concurrently (HTTP handlers and the background import goroutine), so an
// unsynchronized write to the shared field would be a data race.
//
// NOTE(review): domain and spec are joined verbatim; if the domain value can
// be influenced by the client it should be validated against path traversal.
func (m *Model) getFilename(domain string, spec string, name string) string {
	dir := m.dirname
	if dir == "" {
		dir = os.TempDir()
	}
	filename := path.Join(dir, domain, spec, time.Now().Format("20060102"), name)
	// MkdirAll is a no-op when the directory already exists. Its error is
	// deliberately ignored here and reported by the caller's OpenFile.
	_ = os.MkdirAll(path.Dir(filename), 0755)
	return filename
}
// findPrimaryKey extracts the primary-key segment from the request path.
//
// uri is the route template (for example "/module/item/:id"). The id starts at
// the byte offset of the first ':' in uri, and the same offset is used to
// slice r.URL.Path. It returns "" when the template has no placeholder, or
// when the actual path is too short to contain an id. Previously that second
// case slice-panicked.
func (m *Model) findPrimaryKey(uri string, r *http.Request) string {
	pos := strings.IndexByte(uri, ':')
	if pos <= 0 {
		return ""
	}
	urlPath := r.URL.Path
	if pos >= len(urlPath) {
		return ""
	}
	return urlPath[pos:]
}
// parseReportColumn parses a `report` struct tag into a SelectColumn.
//
// name is the Go field name; its Camel2id form is the default column name.
// props is a ';'-separated list of "key" or "key:value" tokens. Recognized
// keys are:
//   - "native": include the column in the reporter SELECT list
//   - "name":   override the column name
//   - "expr":   SQL expression selected AS the column
//
// Unknown keys and empty tokens are ignored.
func (m *Model) parseReportColumn(name, props string) *types.SelectColumn {
	column := &types.SelectColumn{
		Name:   inflector.Camel2id(name),
		Native: false,
	}
	for _, token := range strings.Split(props, ";") {
		// Only the first ':' separates key from value; without a ':' the whole
		// token is the key and the value is empty.
		key, value, _ := strings.Cut(token, ":")
		key = strings.TrimSpace(key)
		value = strings.TrimSpace(value)
		switch key {
		case "native":
			column.Native = true
		case "name":
			column.Name = value
		case "expr":
			column.Expr = value
		}
	}
	return column
}
// reporterSelectColumns builds the SELECT expressions for a reporter model.
//
// A field is selected only if its `scenarios` tag contains the list scenario
// and its `report` tag marks it native. When primaryOnly is true, the field
// must also carry `is_primary:"true"`. Each column is emitted as the quoted
// column name, or as "<expr> AS <quoted name>" when an expr is configured.
// Pointers, slices and arrays around the struct type are unwrapped. A
// reporter whose underlying type is not a struct yields no columns instead of
// panicking in NumField.
func (m *Model) reporterSelectColumns(ctx context.Context, dest types.Reporter, primaryOnly bool) []string {
	columns := make([]string, 0)
	modelType := reflect.TypeOf(dest)
	for modelType != nil {
		kind := modelType.Kind()
		if kind != reflect.Ptr && kind != reflect.Slice && kind != reflect.Array {
			break
		}
		modelType = modelType.Elem()
	}
	if modelType == nil || modelType.Kind() != reflect.Struct {
		return columns
	}
	for i := 0; i < modelType.NumField(); i++ {
		field := modelType.Field(i)
		if !hasToken(types.ScenarioList, field.Tag.Get("scenarios")) {
			continue
		}
		if primaryOnly && field.Tag.Get("is_primary") != "true" {
			continue
		}
		column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
		if !column.Native {
			continue
		}
		if column.Expr == "" {
			columns = append(columns, dest.QuoteColumn(ctx, column.Name))
		} else {
			columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
		}
	}
	return columns
}

// buildReporterCountColumns selects the native `is_primary` reporter columns
// plus "COUNT(*) AS count" on query, for the reporter total-count query.
func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) {
	columns := m.reporterSelectColumns(ctx, dest, true)
	columns = append(columns, "COUNT(*) AS count")
	query.Select(columns...)
}

// buildReporterQueryColumns selects every native reporter column on query,
// for the reporter data query.
func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) {
	query.Select(m.reporterSelectColumns(ctx, dest, false)...)
}
// buildCondition applies request-derived conditions to query for the given
// search schemas; it delegates to BuildConditions and returns its error.
func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) {
	err = BuildConditions(ctx, r, query, schemas)
	return err
}

// ModuleName returns the module name from the model's naming rules.
func (m *Model) ModuleName() string {
	naming := m.naming
	return naming.ModuleName
}

// TableName returns the model's table name.
//
// NOTE(review): this returns naming.ModuleName, not naming.TableName, even
// though the handlers pass naming.TableName to schemaLookup/RuntimeScope as
// the table name. Confirm this is intentional.
func (m *Model) TableName() string {
	naming := m.naming
	return naming.ModuleName
}

// Fields returns the parsed gorm schema fields of the model.
func (m *Model) Fields() []*schema.Field {
	s := m.statement.Schema
	return s.Fields
}
// Uri returns the route path for scenario, always starting with "/".
//
// Layout: [urlPrefix]/<module>/<plural> for list, [urlPrefix]/<module>/<singular>
// for create, .../<singular>/:id for view/update/delete, and
// .../<singular>-export or -import for export/import. The ":id" placeholder
// marks where findPrimaryKey reads the primary key. Unknown scenarios yield
// only the prefix (or "/").
func (m *Model) Uri(scenario string) string {
	// Length 0, capacity only: path.Join would otherwise have to skip
	// pre-seeded empty elements.
	ss := make([]string, 0, 5)
	if m.urlPrefix != "" {
		ss = append(ss, m.urlPrefix)
	}
	switch scenario {
	case types.ScenarioList:
		ss = append(ss, m.naming.ModuleName, m.naming.Pluralize)
	case types.ScenarioView, types.ScenarioUpdate, types.ScenarioDelete:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
	case types.ScenarioCreate:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular)
	case types.ScenarioExport:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export")
	case types.ScenarioImport:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import")
	}
	uri := path.Join(ss...)
	if !strings.HasPrefix(uri, "/") {
		uri = "/" + uri
	}
	return uri
}
// Method returns the HTTP method bound to scenario: POST for create, PUT for
// update, DELETE for delete, and GET for every other scenario.
func (m *Model) Method(scenario string) string {
	switch scenario {
	case types.ScenarioCreate:
		return http.MethodPost
	case types.ScenarioUpdate:
		return http.MethodPut
	case types.ScenarioDelete:
		return http.MethodDelete
	default:
		return http.MethodGet
	}
}
// Search handles the list endpoint.
//
// Request parameters:
//   - "page": 1-based; 0 or absent selects the first page.
//   - "pagesize": defaultPageSize when missing or <= 0.
//   - "scenario": honoured only when found in allowScenario; otherwise
//     types.ScenarioList selects the list schemas.
//   - "__format": forwarded to the formatter.
//
// Conditions are built from the search-scenario schemas. Rows are limited to
// the caller's domain unless domain scoping is disabled or the model has no
// domain column. Models implementing tableNamer or types.Reporter override
// the FROM table; reporters also get explicit COUNT/SELECT columns and a
// GROUP BY. The response is a types.ListResponse whose Data is the formatted
// rows, or an empty slice when the total count is zero.
func (m *Model) Search(w http.ResponseWriter, r *http.Request) {
	var (
		ok            bool
		err           error
		qs            url.Values
		page          int
		pageSize      int
		pageIndex     int
		query         *Query
		domainName    string
		modelSlices   reflect.Value
		modelValues   reflect.Value
		searchSchemas []*types.Schema
		listSchemas   []*types.Schema
		modelValue    reflect.Value
		scenario      string
		reporter      types.Reporter
		namerTable    tableNamer
	)
	qs = r.URL.Query()
	page, _ = strconv.Atoi(qs.Get("page"))
	pageSize, _ = strconv.Atoi(qs.Get("pagesize"))
	if pageSize <= 0 {
		pageSize = defaultPageSize
	}
	// page is 1-based in the API; pageIndex is the 0-based page used for OFFSET.
	pageIndex = page
	if pageIndex > 0 {
		pageIndex--
	}
	modelValue = reflect.New(m.value.Type())
	// Build a slice of pointer elements so that pointer-receiver methods can be
	// called on each row during formatting.
	if m.value.Kind() != reflect.Ptr {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
	} else {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
	}
	modelValues = reflect.New(modelSlices.Type())
	modelValues.Elem().Set(modelSlices)
	query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface())
	domainName = m.valueLookup(types.FieldDomain, w, r)
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		Request:    r,
		User:       m.valueLookup("user", w, r),
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioList,
	})
	if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	scenario = types.ScenarioList
	if arrays.Exists(r.FormValue("scenario"), allowScenario) {
		scenario = r.FormValue("scenario")
	}
	if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	// Scope rows to the caller's domain when the model has a domain column.
	if !m.disableDomain {
		if m.hasColumn(types.FieldDomain) {
			query.AndWhere(newCondition(types.FieldDomain, domainName))
		}
	}
	if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
		return
	}
	// Table name resolution: a tableNamer model may choose the table per request.
	if namerTable, ok = query.Model().(tableNamer); ok {
		query.From(namerTable.HttpTableName(r))
	}
	// Reporter handling: reporters read from their own real table.
	if reporter, ok = modelValue.Interface().(types.Reporter); ok {
		query.From(reporter.RealTable())
	}
	res := &types.ListResponse{
		Page:     page,
		PageSize: pageSize,
	}
	if reporter == nil {
		res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
	} else {
		// Reporters need an explicit COUNT select list (the native is_primary
		// columns plus COUNT(*)) for the count query to work.
		m.buildReporterCountColumns(childCtx, reporter, query)
		res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
		// Drop the count select list before the data query is built.
		query.ResetSelect()
		query.GroupBy(reporter.GroupBy(childCtx)...)
	}
	query.Offset(pageIndex * pageSize).Limit(pageSize)
	if res.TotalCount > 0 {
		if reporter != nil {
			m.buildReporterQueryColumns(childCtx, reporter, query)
		}
		if err = query.All(modelValues.Interface()); err != nil {
			m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
			return
		}
		// Format rows with the list schemas; "__format" is forwarded to the formatter.
		res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format"))
	} else {
		res.Data = make([]string, 0)
	}
	m.response.Success(w, res)
}
// Create handles the create endpoint.
//
// It decodes the JSON body into a new model and stamps the caller's domain
// (unless domain scoping is disabled). It then runs the before-create and
// before-save hooks and inserts the row inside a transaction. On success it
// calls the model's AfterCreated/AfterSaved callbacks and the after-create and
// after-save hooks with per-column diffs, and responds with the new id.
//
// The row is written with gorm's Create rather than Save. Save performs an
// UPDATE (upsert) when the primary key is non-zero, so a client that put an
// existing id in the body could overwrite, and re-domain, another record
// through the create endpoint. Create always INSERTs, so a duplicate key fails.
func (m *Model) Create(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		schemas    []*types.Schema
		diffAttrs  []*types.DiffAttr
		domainName string
		modelValue reflect.Value
	)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	if err = json.NewDecoder(r.Body).Decode(model); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil)
		return
	}
	domainName = m.valueLookup(types.FieldDomain, w, r)
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	// Assigned after decoding so a domain supplied in the body cannot win.
	if !m.disableDomain && m.hasColumn(types.FieldDomain) {
		m.setValue(modelValue, types.FieldDomain, domainName)
	}
	diffAttrs = make([]*types.DiffAttr, 0, len(schemas))
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		User:       m.valueLookup("user", w, r),
		Request:    r,
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioCreate,
		Schemas:    schemas,
	})
	dbSess := m.getDB().WithContext(childCtx)
	err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
		if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil {
			return
		}
		if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
			return
		}
		if tabler, ok := model.(types.Tabler); ok {
			errTx = tx.Table(tabler.TableName()).Create(model).Error
		} else {
			errTx = tx.Create(model).Error
		}
		if errTx != nil {
			return
		}
		for _, row := range schemas {
			diffAttrs = append(diffAttrs, &types.DiffAttr{
				Column:   row.Column,
				Label:    row.Label,
				OldValue: nil,
				NewValue: m.getValue(modelValue, row.Column),
			})
		}
		return
	})
	if err != nil {
		m.response.Failure(w, types.RequestCreateFailure, err.Error(), err)
		return
	}
	res := &types.CreateResponse{
		ID:     m.getValue(modelValue, m.primaryKey),
		Status: "created",
	}
	if creator, ok := model.(afterCreated); ok {
		creator.AfterCreated(childCtx, dbSess)
	}
	if preserver, ok := model.(afterSaved); ok {
		preserver.AfterSaved(childCtx, dbSess)
	}
	m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs)
	m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
	m.response.Success(w, res)
}
// Update handles the update endpoint.
//
// The record is loaded by the primary key taken from the URL. Unless domain
// scoping is disabled or the model has no domain column, the load is also
// scoped to the caller's domain, as Search does, so a client cannot update
// another domain's row. The JSON body is then decoded over the loaded record.
//
// The primary key and domain are restored after decoding. gorm's
// Model(model).Updates uses the model's primary key in the WHERE clause, so an
// id in the body would otherwise redirect the UPDATE to a different row; a
// domain in the body would move the row to another domain.
//
// Only columns of the update-scenario schemas whose value changed are
// written, inside a transaction wrapped by the before-update/before-save
// hooks. Afterwards the model's AfterUpdated/AfterSaved callbacks and the
// after-update/after-save hooks receive the diffs.
func (m *Model) Update(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		schemas    []*types.Schema
		diffAttrs  []*types.DiffAttr
		domainName string
		modelValue reflect.Value
		oldValues  map[string]any
	)
	idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	domainName = m.valueLookup(types.FieldDomain, w, r)
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	domainScoped := !m.disableDomain && m.hasColumn(types.FieldDomain)
	conditions := map[string]any{
		m.primaryKey: idStr,
	}
	if domainScoped {
		conditions[types.FieldDomain] = domainName
	}
	if err = m.getDB().Where(conditions).First(model).Error; err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	oldValues = make(map[string]any, len(schemas))
	for _, row := range schemas {
		oldValues[row.Column] = m.getValue(modelValue, row.Column)
	}
	// Remember the identity of the loaded row before the body is applied.
	primaryValue := m.getValue(modelValue, m.primaryKey)
	if err = json.NewDecoder(r.Body).Decode(model); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
		return
	}
	m.setValue(modelValue, m.primaryKey, primaryValue)
	if domainScoped {
		m.setValue(modelValue, types.FieldDomain, domainName)
	}
	diffAttrs = make([]*types.DiffAttr, 0, len(schemas))
	updates := make(map[string]any)
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:          domainName,
		Request:         r,
		User:            m.valueLookup("user", w, r),
		ModuleName:      m.naming.ModuleName,
		TableName:       m.naming.TableName,
		Scenario:        types.ScenarioUpdate,
		Schemas:         schemas,
		PrimaryKeyValue: idStr,
	})
	dbSess := m.getDB().WithContext(childCtx)
	err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
		if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil {
			return
		}
		if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
			return
		}
		for _, row := range schemas {
			v := m.getValue(modelValue, row.Column)
			// reflect.DeepEqual instead of !=: comparing interface values
			// whose dynamic type is a slice or map panics at run time.
			if !reflect.DeepEqual(oldValues[row.Column], v) {
				updates[row.Column] = v
				diffAttrs = append(diffAttrs, &types.DiffAttr{
					Column:   row.Column,
					Label:    row.Label,
					OldValue: oldValues[row.Column],
					NewValue: v,
				})
			}
		}
		if len(updates) == 0 {
			return
		}
		if tabler, ok := model.(types.Tabler); ok {
			errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error
		} else {
			errTx = tx.Model(model).Updates(updates).Error
		}
		return
	})
	if err != nil {
		m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil)
		return
	}
	if updater, ok := model.(afterUpdated); ok {
		updater.AfterUpdated(childCtx, dbSess)
	}
	if preserver, ok := model.(afterSaved); ok {
		preserver.AfterSaved(childCtx, dbSess)
	}
	m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs)
	m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
	m.response.Success(w, types.UpdateResponse{
		ID:     idStr,
		Status: "updated",
	})
}
// Delete handles the delete endpoint.
//
// The record is loaded by the primary key from the URL. Unless domain scoping
// is disabled or the model has no domain column, the load is also scoped to
// the caller's domain, as Search does, so a client cannot delete another
// domain's record by id. The row is deleted inside a transaction: the
// before-delete hook runs first and the after-delete hook runs once the
// delete succeeds.
func (m *Model) Delete(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		modelValue reflect.Value
		domainName string
	)
	idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	domainName = m.valueLookup(types.FieldDomain, w, r)
	conditions := map[string]any{
		m.primaryKey: idStr,
	}
	if !m.disableDomain && m.hasColumn(types.FieldDomain) {
		conditions[types.FieldDomain] = domainName
	}
	if err = m.getDB().Where(conditions).First(model).Error; err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:          domainName,
		User:            m.valueLookup("user", w, r),
		Request:         r,
		ModuleName:      m.naming.ModuleName,
		TableName:       m.naming.TableName,
		Scenario:        types.ScenarioDelete,
		PrimaryKeyValue: idStr,
	})
	dbSess := m.getDB().WithContext(childCtx)
	err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
		if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil {
			return
		}
		if tabler, ok := model.(types.Tabler); ok {
			errTx = tx.Table(tabler.TableName()).Delete(model).Error
		} else {
			errTx = tx.Delete(model).Error
		}
		if errTx != nil {
			return
		}
		m.getHook().afterDelete(childCtx, tx, model)
		return
	})
	if err != nil {
		m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil)
		return
	}
	m.response.Success(w, types.DeleteResponse{
		ID:     idStr,
		Status: "deleted",
	})
}
// View handles the detail endpoint.
//
// The record is loaded by the primary key from the URL. Unless domain scoping
// is disabled or the model has no domain column, the load is also scoped to
// the caller's domain, as Search does. The record is formatted with the
// schemas of the "scenario" query parameter (types.ScenarioView when absent).
// The "__format" query parameter is forwarded to the formatter.
func (m *Model) View(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		modelValue reflect.Value
		qs         url.Values
		schemas    []*types.Schema
		scenario   string
		domainName string
	)
	qs = r.URL.Query()
	idStr := m.findPrimaryKey(m.Uri(types.ScenarioView), r)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	domainName = m.valueLookup(types.FieldDomain, w, r)
	conditions := map[string]any{
		m.primaryKey: idStr,
	}
	if !m.disableDomain && m.hasColumn(types.FieldDomain) {
		conditions[types.FieldDomain] = domainName
	}
	if err = m.getDB().Where(conditions).First(model).Error; err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	scenario = qs.Get("scenario")
	if scenario == "" {
		scenario = types.ScenarioView
	}
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil)
		return
	}
	m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format")))
}
// Export handles the CSV export endpoint.
//
// Rows are selected exactly as in Search: search-scenario conditions, domain
// scoping, and tableNamer/reporter table overrides. They are formatted with
// the export-scenario schemas and written to a dated CSV file under
// getFilename. The first row holds the schema labels; values missing from a
// formatted row become empty cells. The afterExport hook receives the file
// path, and the file is then served as an attachment named "<singular>.csv".
//
// CSV write errors are checked before anything is sent, so a truncated export
// is reported as a failure instead of being served as a successful download.
func (m *Model) Export(w http.ResponseWriter, r *http.Request) {
	var (
		err           error
		query         *Query
		modelSlices   reflect.Value
		modelValues   reflect.Value
		searchSchemas []*types.Schema
		exportSchemas []*types.Schema
		domainName    string
		fp            *os.File
		modelValue    reflect.Value
	)
	if !m.hasScenario(types.ScenarioList) {
		m.response.Failure(w, types.RequestDenied, "request denied", nil)
		return
	}
	domainName = m.valueLookup(types.FieldDomain, w, r)
	filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
	if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil)
		return
	}
	defer func() {
		_ = fp.Close()
	}()
	modelValue = reflect.New(m.value.Type())
	// Build a slice of pointer elements so that pointer-receiver methods can be
	// called on each row during formatting.
	if m.value.Kind() != reflect.Ptr {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
	} else {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
	}
	modelValues = reflect.New(modelSlices.Type())
	modelValues.Elem().Set(modelSlices)
	query = NewQuery(m.getDB(), modelValue.Interface())
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		Request:    r,
		User:       m.valueLookup("user", w, r),
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioExport,
	})
	if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
		return
	}
	if !m.disableDomain && m.hasColumn(types.FieldDomain) {
		query.AndWhere(newCondition(types.FieldDomain, domainName))
	}
	// Table name resolution: a tableNamer model may choose the table per request.
	if namerTable, ok := query.Model().(tableNamer); ok {
		query.From(namerTable.HttpTableName(r))
	}
	// Reporter handling: real table, GROUP BY and explicit SELECT columns.
	if reporter, ok := modelValue.Interface().(types.Reporter); ok {
		query.From(reporter.RealTable())
		query.GroupBy(reporter.GroupBy(childCtx)...)
		m.buildReporterQueryColumns(childCtx, reporter, query)
	}
	if err = query.All(modelValues.Interface()); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "")
	writer := csv.NewWriter(fp)
	// One buffer reused for every record; csv.Writer copies it on Write.
	record := make([]string, len(exportSchemas))
	for i, field := range exportSchemas {
		record[i] = field.Label
	}
	_ = writer.Write(record)
	if values, ok := value.([]any); ok {
		for _, val := range values {
			row, isMap := val.(map[string]any)
			if !isMap {
				continue
			}
			for i, field := range exportSchemas {
				if v, found := row[field.Column]; found {
					record[i] = fmt.Sprint(v)
				} else {
					record[i] = ""
				}
			}
			_ = writer.Write(record)
		}
	}
	// csv.Writer errors are sticky: Error reports any failure from a previous
	// Write or Flush.
	writer.Flush()
	if err = writer.Error(); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "write export file failed", nil)
		return
	}
	w.Header().Set("Content-Type", "text/csv")
	w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
	m.getHook().afterExport(childCtx, filename)
	// ServeContent seeks to the end to learn the size and rewinds before
	// copying, so the write offset left in fp does not matter.
	http.ServeContent(w, r, path.Base(filename), time.Now(), fp)
}
// findSchema returns the first schema whose Label equals label, or nil when
// no schema matches.
func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema {
	for i := range schemas {
		if schemas[i].Label != label {
			continue
		}
		return schemas[i]
	}
	return nil
}
// importInternal imports the CSV file at filename into the model table.
// Import runs it in a background goroutine, and the outcome is reported only
// through the afterImport hook, which now fires on every exit path.
//
// The header row must consist of schema labels (see findSchema); each later
// row becomes one record. A row is a failure if its field count differs from
// the header or if storing it fails. Failed rows are copied to a
// "<singular>-<unix>-fail.csv" file with an extra "Error" column, and
// result.FailureFile points at it when at least one row failed.
func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) {
	var (
		startedAt   = time.Now()
		result      = &types.ImportResult{}
		failureFile string
	)
	// Registered first, so it runs last: after the deferred closes below, which
	// means the hook only ever sees fully written and closed files.
	defer func() {
		result.UploadFile = filename
		if result.TotalCount > result.SuccessCount {
			result.FailureFile = failureFile
		}
		result.Duration = time.Since(startedAt)
		m.getHook().afterImport(ctx, result)
	}()

	fp, err := os.Open(filename)
	if err != nil {
		result.Code = types.ErrImportFileNotExists
		return
	}
	defer func() {
		_ = fp.Close()
	}()
	csvReader := csv.NewReader(fp)
	// Allow any number of fields per record. With the default (0), a data row
	// whose length differs from the header makes Read return csv.ErrFieldCount,
	// which would end the loop and drop every remaining row.
	csvReader.FieldsPerRecord = -1
	header, err := csvReader.Read()
	if err != nil {
		result.Code = types.ErrImportFileUnavailable
		return
	}
	fields := make([]string, 0, len(header))
	for _, label := range header {
		s := m.findSchema(label, schemas)
		if s == nil {
			result.Code = types.ErrImportColumnNotMatch
			return
		}
		fields = append(fields, s.Column)
	}
	sess := m.getDB().WithContext(ctx)
	failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix()))
	failureFp, err := os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		// Without a failure file, failed rows could not be reported.
		// NOTE(review): no dedicated result code is visible for this case.
		result.Code = types.ErrImportFileUnavailable
		return
	}
	defer func() {
		_ = failureFp.Close()
	}()
	csvWriter := csv.NewWriter(failureFp)
	_ = csvWriter.Write(append(header, "Error"))
	for {
		row, readErr := csvReader.Read()
		if readErr != nil {
			// io.EOF is the normal end of input; any other read error also
			// stops the import.
			break
		}
		result.TotalCount++
		if len(row) != len(fields) {
			_ = csvWriter.Write(append(row, "column count mismatch"))
			continue
		}
		if err = m.importRow(ctx, sess, domainName, schemas, fields, row, fast, extraFields); err != nil {
			_ = csvWriter.Write(append(row, err.Error()))
			continue
		}
		result.SuccessCount++
	}
	csvWriter.Flush()
}

// importRow builds one model from a CSV row, applies extraFields and the
// caller's domain, and stores it, returning the storage error if any.
//
// Cells are assigned through safeSetValue in header order, followed by
// extraFields. The domain is applied last (unless domain scoping is disabled)
// so neither the file nor extraFields can override it; Create does the same.
// In fast mode the row is saved directly, without hooks or a transaction.
// Otherwise it is saved in its own transaction wrapped by the create/save
// hooks.
//
// NOTE(review): rows are written with Save, which updates an existing row
// when the file supplies a non-zero primary key. Confirm this upsert
// behaviour is intended for imports.
func (m *Model) importRow(ctx context.Context, sess *gorm.DB, domainName string, schemas []*types.Schema, fields []string, row []string, fast bool, extraFields map[string]string) error {
	modelValue := reflect.New(m.value.Type())
	for idx, field := range fields {
		m.safeSetValue(modelValue, field, row[idx])
	}
	for k, v := range extraFields {
		m.safeSetValue(modelValue, k, v)
	}
	if !m.disableDomain && m.hasColumn(types.FieldDomain) {
		m.setValue(modelValue, types.FieldDomain, domainName)
	}
	entity := modelValue.Interface()
	// Both modes honour types.Tabler the same way Create does.
	save := func(db *gorm.DB) error {
		if tabler, ok := entity.(types.Tabler); ok {
			return db.Table(tabler.TableName()).Save(entity).Error
		}
		return db.Save(entity).Error
	}
	if fast {
		return save(sess)
	}
	return sess.Transaction(func(tx *gorm.DB) error {
		if err := m.getHook().beforeCreate(ctx, tx, entity); err != nil {
			return err
		}
		if err := m.getHook().beforeSave(ctx, tx, entity); err != nil {
			return err
		}
		if err := save(tx); err != nil {
			return err
		}
		// A fresh slice per row: hooks may retain the slice they receive.
		diffAttrs := make([]*types.DiffAttr, len(schemas))
		for idx, s := range schemas {
			diffAttrs[idx] = &types.DiffAttr{
				Column:   s.Column,
				Label:    s.Label,
				NewValue: m.getValue(modelValue, s.Column),
			}
		}
		m.getHook().afterCreate(ctx, tx, entity, diffAttrs)
		m.getHook().afterSave(ctx, tx, entity, diffAttrs)
		return nil
	})
}
// Import handles the CSV import endpoint.
//
// GET downloads the import template: a single header row of the
// create-scenario schema labels, excluding primary-key columns. Any other
// method expects a multipart "file" upload. The upload is copied to a dated
// file under getFilename and imported asynchronously by importInternal, and
// the handler responds with status "committed". Query parameters prefixed
// with "_attr_" become extra field values applied to every row (prefix
// stripped), and "__fast" enables fast mode.
//
// A failed or truncated copy of the upload is now reported to the client and
// the partial file removed, instead of importing an incomplete file.
func (m *Model) Import(w http.ResponseWriter, r *http.Request) {
	var (
		err         error
		fast        bool
		schemas     []*types.Schema
		rows        []string
		domainName  string
		dst         *os.File
		fp          multipart.File
		csvWriter   *csv.Writer
		qs          url.Values
		extraFields map[string]string
	)
	domainName = m.valueLookup(types.FieldDomain, w, r)
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	// Use a background context: the import goroutine outlives this handler,
	// and the request context is cancelled once ServeHTTP returns. For the
	// same reason the request itself is not stored in the scope.
	childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		User:       m.valueLookup("user", w, r),
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioImport,
		Schemas:    schemas,
	})
	if r.Method == http.MethodGet {
		// Download the import template.
		csvWriter = csv.NewWriter(w)
		rows = make([]string, 0, len(schemas))
		for _, row := range schemas {
			// Primary keys are not imported.
			if row.IsPrimaryKey == 1 {
				continue
			}
			rows = append(rows, row.Label)
		}
		w.Header().Set("Content-Type", "text/csv")
		w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
		w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
		// Best effort: a write failure here means the client connection failed.
		_ = csvWriter.Write(rows)
		csvWriter.Flush()
		return
	}
	filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
	if fp, _, err = r.FormFile("file"); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil)
		return
	}
	defer func() {
		_ = fp.Close()
	}()
	if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil)
		return
	}
	buf := pool.GetBytes(32 * 1024)
	_, err = io.CopyBuffer(dst, fp, buf)
	pool.PutBytes(buf)
	// Close can surface deferred write errors, so it is checked as well.
	if closeErr := dst.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(filename)
		m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil)
		return
	}
	qs = r.URL.Query()
	extraFields = make(map[string]string)
	for k := range qs {
		if strings.HasPrefix(k, "_attr_") {
			extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k)
		}
	}
	fast, _ = strconv.ParseBool(qs.Get("__fast"))
	go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields)
	m.response.Success(w, types.ImportResponse{
		UID:    m.valueLookup("user", w, r),
		Status: "committed",
	})
}
// newModel creates a Model for v backed by db.
//
// v may be a struct or a pointer to one. reflect.Indirect stores the
// pointed-to value as the type template for new instances. The gorm schema of
// v is parsed once, and the first primary-key DB column becomes
// model.primaryKey.
//
// NOTE(review): a Parse error is ignored here. primaryKey then stays empty,
// and statement.Schema may be unset for methods such as hasColumn/Fields that
// dereference it.
func newModel(v any, db *gorm.DB, naming types.Naming) *Model {
	model := &Model{
		db:          db,
		naming:      naming,
		response:    &httpWriter{},
		value:       reflect.Indirect(reflect.ValueOf(v)),
		valueLookup: defaultValueLookup,
	}
	model.statement = &gorm.Statement{
		DB:       model.getDB(),
		ConnPool: model.getDB().ConnPool,
		Clauses:  map[string]clause.Clause{},
	}
	if err := model.statement.Parse(v); err == nil {
		// len of a nil slice is 0, so no separate nil check is needed.
		if names := model.statement.Schema.PrimaryFieldDBNames; len(names) > 0 {
			model.primaryKey = names[0]
		}
	}
	return model
}