rest/model.go

988 lines
29 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package rest
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"git.nobla.cn/golang/kos/util/arrays"
"git.nobla.cn/golang/kos/util/pool"
"git.nobla.cn/golang/rest/inflector"
"git.nobla.cn/golang/rest/types"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"path"
"reflect"
"strconv"
"strings"
"time"
)
// Model exposes one gorm model through a set of HTTP handlers (list, view,
// create, update, delete, import and export).
type Model struct {
	naming        types.Naming           // naming rules (module, table, singular and plural names)
	value         reflect.Value          // the registered model value, pointer-indirected
	db            *gorm.DB               // database handle
	primaryKey    string                 // DB column name of the primary key
	urlPrefix     string                 // optional prefix prepended to every route URI
	disableDomain bool                   // when true, handlers skip domain scoping
	schemaLookup  types.SchemaLookupFunc // resolves the schemas for a domain/module/table/scenario
	valueLookup   types.ValueLookupFunc  // resolves per-request values such as the domain and the user
	statement     *gorm.Statement        // parsed gorm statement; its Schema describes the model's fields
	formatter     *Formatter             // output formatter; DefaultFormatter is used when nil
	response      types.HttpWriter       // writes success/failure HTTP responses
	hookMgr       *hookManager           // lifecycle hook manager
	dirname       string                 // directory for export/import files; os.TempDir() when empty
}
var (
	// RuntimeScopeKey is the context key under which the handlers store the
	// *types.RuntimeScope describing the current request.
	RuntimeScopeKey = &types.RuntimeScope{}
)
// getDB returns the database handle the model was created with.
func (m *Model) getDB() (db *gorm.DB) {
	db = m.db
	return
}
// getFormatter returns the formatter configured on the model, falling back to
// DefaultFormatter when none was set.
func (m *Model) getFormatter() *Formatter {
	if m.formatter == nil {
		return DefaultFormatter
	}
	return m.formatter
}
// getHook returns the model's hook manager exactly as stored on the Model.
func (m *Model) getHook() (mgr *hookManager) {
	mgr = m.hookMgr
	return
}
// hasScenario reports whether the model supports the given scenario.
// Every scenario is currently accepted, so the argument is ignored.
func (m *Model) hasScenario(_ string) bool {
	return true
}
// setValue assigns value to the named column of refValue through
// SetFieldValue, using the model's parsed statement.
func (m *Model) setValue(refValue reflect.Value, column string, value any) {
	stmt := m.statement
	SetFieldValue(stmt, refValue, column, value)
}
// safeSetValue assigns value to the named column of refValue through
// SafeSetFileValue, using the model's parsed statement.
func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) {
	stmt := m.statement
	SafeSetFileValue(stmt, refValue, column, value)
}
// getValue reads the named column from refValue through GetFieldValue,
// using the model's parsed statement.
func (m *Model) getValue(refValue reflect.Value, column string) any {
	v := GetFieldValue(m.statement, refValue, column)
	return v
}
// hasColumn reports whether the model's schema has a field whose Go name or
// DB column name equals column.
func (m *Model) hasColumn(column string) bool {
	fields := m.statement.Schema.Fields
	for i := range fields {
		if f := fields[i]; f.Name == column || f.DBName == column {
			return true
		}
	}
	return false
}
// getFilename builds the path of a working file (export output, uploaded
// import source or import failure report) and makes sure its directory exists.
//
// The layout is <dirname>/<domain>/<spec>/<YYYYMMDD>/<name>, where dirname
// falls back to os.TempDir() when the model has no directory configured.
func (m *Model) getFilename(domain string, spec string, name string) string {
	// Resolve the fallback locally rather than writing m.dirname: this runs
	// from concurrent request handlers and from the background import
	// goroutine, and an unsynchronised write to the shared Model is a data race.
	dirname := m.dirname
	if dirname == "" {
		dirname = os.TempDir()
	}
	filename := path.Join(dirname, domain, spec, time.Now().Format("20060102"), name)
	// MkdirAll is a no-op when the directory already exists. Its error is
	// ignored on purpose: the caller's subsequent open of filename reports it.
	_ = os.MkdirAll(path.Dir(filename), 0755)
	return filename
}
// findPrimaryKey extracts the primary-key value from the request path.
//
// uri is a route template produced by Uri (e.g. "/module/item/:id"). The value
// is the part of r.URL.Path starting at the byte offset of the template's ':'
// placeholder, which relies on the request path sharing the template's literal
// prefix. A trailing slash is dropped. Returns "" when the template has no
// placeholder or the request path is too short to contain a value.
func (m *Model) findPrimaryKey(uri string, r *http.Request) string {
	pos := strings.IndexByte(uri, ':')
	urlPath := r.URL.Path
	// Guard the slice: a request path shorter than the template prefix (for
	// example behind a prefix-stripping router) would otherwise panic with
	// "slice bounds out of range".
	if pos <= 0 || pos >= len(urlPath) {
		return ""
	}
	return strings.TrimSuffix(urlPath[pos:], "/")
}
// parseReportColumn parses a `report` struct tag into a SelectColumn.
//
// name is the Go field name; its inflector.Camel2id form is the default
// column name. props is a ';'-separated list of "key" or "key:value" items
// (surrounding spaces are trimmed):
//
//	native      sets Native
//	name:<col>  overrides the column name
//	expr:<sql>  sets the SQL expression for the column
//
// Unknown keys and empty items are ignored.
func (m *Model) parseReportColumn(name, props string) *types.SelectColumn {
	col := &types.SelectColumn{
		Name:   inflector.Camel2id(name),
		Native: false,
	}
	for _, item := range strings.Split(props, ";") {
		k, v, _ := strings.Cut(item, ":")
		switch strings.TrimSpace(k) {
		case "native":
			col.Native = true
		case "name":
			col.Name = strings.TrimSpace(v)
		case "expr":
			col.Expr = strings.TrimSpace(v)
		}
	}
	return col
}
// reporterColumns collects the SELECT expressions of a report model.
//
// dest's type is dereferenced once when it is a pointer, slice or array. A
// struct field contributes a column when hasToken finds types.ScenarioList in
// its `scenarios` tag and its `report` tag marks it native; with primaryOnly
// the field must also carry `is_primary:"true"`. Columns with an expr are
// emitted as "<expr> AS <quoted name>", the others as the quoted name.
// A dest whose (dereferenced) type is not a struct yields no columns.
func (m *Model) reporterColumns(ctx context.Context, dest types.Reporter, primaryOnly bool) []string {
	modelType := reflect.ValueOf(dest).Type()
	switch modelType.Kind() {
	case reflect.Slice, reflect.Array, reflect.Ptr:
		modelType = modelType.Elem()
	}
	// NumField panics for non-struct kinds.
	if modelType.Kind() != reflect.Struct {
		return nil
	}
	// One spare slot for the COUNT(*) column appended by the count variant.
	columns := make([]string, 0, modelType.NumField()+1)
	for i := 0; i < modelType.NumField(); i++ {
		field := modelType.Field(i)
		if !hasToken(types.ScenarioList, field.Tag.Get("scenarios")) {
			continue
		}
		if primaryOnly && field.Tag.Get("is_primary") != "true" {
			continue
		}
		column := m.parseReportColumn(field.Name, field.Tag.Get("report"))
		if !column.Native {
			continue
		}
		if column.Expr == "" {
			columns = append(columns, dest.QuoteColumn(ctx, column.Name))
		} else {
			columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name)))
		}
	}
	return columns
}

// buildReporterCountColumns sets the select list used to count report rows:
// the native primary (`is_primary:"true"`) columns followed by
// "COUNT(*) AS count".
func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) {
	columns := m.reporterColumns(ctx, dest, true)
	columns = append(columns, "COUNT(*) AS count")
	query.Select(columns...)
}

// buildReporterQueryColumns sets the select list used to fetch report rows:
// every native column of the report model.
func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) {
	query.Select(m.reporterColumns(ctx, dest, false)...)
}
// buildCondition adds the SQL conditions derived from the request to query,
// using the given search schemas; it delegates to BuildConditions.
func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) error {
	err := BuildConditions(ctx, r, query, schemas)
	return err
}
// ModuleName returns the module name from the model's naming rules.
func (m *Model) ModuleName() (name string) {
	name = m.naming.ModuleName
	return
}
// TableName returns the table name from the model's naming rules, falling
// back to the module name when no table name is configured.
func (m *Model) TableName() string {
	// naming.TableName is the value the handlers put in RuntimeScope.TableName;
	// returning ModuleName here unconditionally contradicted the method's name.
	if m.naming.TableName != "" {
		return m.naming.TableName
	}
	return m.naming.ModuleName
}
// Fields returns the fields of the model's parsed gorm schema.
func (m *Model) Fields() []*schema.Field {
	parsed := m.statement.Schema
	return parsed.Fields
}
// Uri returns the route path for a scenario:
//
//	list:               /<prefix>/<module>/<plural>
//	view/update/delete: /<prefix>/<module>/<singular>/:id
//	create:             /<prefix>/<module>/<singular>
//	export:             /<prefix>/<module>/<singular>-export
//	import:             /<prefix>/<module>/<singular>-import
//
// The prefix segment is omitted when urlPrefix is empty. An unknown scenario
// yields the prefix alone ("/" without one). The result always starts with "/".
func (m *Model) Uri(scenario string) string {
	// Length 0 with room for the prefix plus three segments; only the
	// segments appended below take part in the join.
	ss := make([]string, 0, 4)
	if m.urlPrefix != "" {
		ss = append(ss, m.urlPrefix)
	}
	switch scenario {
	case types.ScenarioList:
		ss = append(ss, m.naming.ModuleName, m.naming.Pluralize)
	case types.ScenarioView:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
	case types.ScenarioCreate:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular)
	case types.ScenarioUpdate:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
	case types.ScenarioDelete:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id")
	case types.ScenarioExport:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export")
	case types.ScenarioImport:
		ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import")
	}
	uri := path.Join(ss...)
	if !strings.HasPrefix(uri, "/") {
		uri = "/" + uri
	}
	return uri
}
// Method returns the HTTP method bound to a scenario: POST for create, PUT
// for update, DELETE for delete and GET for everything else.
func (m *Model) Method(scenario string) string {
	switch scenario {
	case types.ScenarioCreate:
		return http.MethodPost
	case types.ScenarioUpdate:
		return http.MethodPut
	case types.ScenarioDelete:
		return http.MethodDelete
	default:
		return http.MethodGet
	}
}
// Search handles the HTTP list endpoint.
//
// Query parameters:
//   - "page": 1-based page number; 0 or missing selects the first page.
//   - "pagesize": defaults to defaultPageSize when missing or <= 0.
//   - "scenario": replaces the list scenario used for the output schemas
//     when it is one of allowScenario.
//   - "__format": passed through to the formatter.
//
// Conditions are built from the request with the search-scenario schemas.
// Results are restricted to the caller's domain unless domains are disabled
// or the model has no domain column. Models implementing tableNamer or
// types.Reporter read from their own table. Reporters are also grouped and get
// dedicated COUNT/SELECT column lists. The response is a types.ListResponse
// with the total count and the formatted rows of the requested page (an empty
// list when nothing matches).
func (m *Model) Search(w http.ResponseWriter, r *http.Request) {
	var (
		ok            bool
		err           error
		qs            url.Values
		page          int
		pageSize      int
		pageIndex     int
		query         *Query
		domainName    string
		modelSlices   reflect.Value
		modelValues   reflect.Value
		searchSchemas []*types.Schema
		listSchemas   []*types.Schema
		modelValue    reflect.Value
		scenario      string
		reporter      types.Reporter
		namerTable    tableNamer
	)
	qs = r.URL.Query()
	page, _ = strconv.Atoi(qs.Get("page"))
	pageSize, _ = strconv.Atoi(qs.Get("pagesize"))
	// NOTE(review): pageSize has no upper bound, so a client can request an arbitrarily large page.
	if pageSize <= 0 {
		pageSize = defaultPageSize
	}
	// Convert the 1-based page to the 0-based index used for OFFSET.
	// NOTE(review): a negative "page" is not clamped and yields a negative offset.
	pageIndex = page
	if pageIndex > 0 {
		pageIndex--
	}
	modelValue = reflect.New(m.value.Type())
	// Build a slice of pointers so the formatter can call pointer-receiver methods on each element.
	if m.value.Kind() != reflect.Ptr {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
	} else {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
	}
	modelValues = reflect.New(modelSlices.Type())
	modelValues.Elem().Set(modelSlices)
	query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface())
	domainName = m.valueLookup(types.FieldDomain, w, r)
	// The runtime scope always reports the list scenario, even when "scenario" overrides the output schemas.
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		Request:    r,
		User:       m.valueLookup("user", w, r),
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioList,
	})
	if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	scenario = types.ScenarioList
	if arrays.Exists(r.FormValue("scenario"), allowScenario) {
		scenario = r.FormValue("scenario")
	}
	if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	// Restrict results to the caller's domain.
	if !m.disableDomain {
		if m.hasColumn(types.FieldDomain) {
			query.AndWhere(newCondition(types.FieldDomain, domainName))
		}
	}
	if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
		return
	}
	// Models implementing tableNamer choose their table from the request.
	if namerTable, ok = query.Model().(tableNamer); ok {
		query.From(namerTable.HttpTableName(r))
	}
	// Report models read from their real table.
	if reporter, ok = modelValue.Interface().(types.Reporter); ok {
		query.From(reporter.RealTable())
	}
	res := &types.ListResponse{
		Page:     page,
		PageSize: pageSize,
	}
	if reporter == nil {
		res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
	} else {
		// Reports must set the COUNT select list explicitly for the count query to work.
		m.buildReporterCountColumns(childCtx, reporter, query)
		res.TotalCount = query.Limit(0).Offset(0).Count(query.Model())
		// Drop the counting select list before building the page query.
		query.ResetSelect()
		query.GroupBy(reporter.GroupBy(childCtx)...)
	}
	query.Offset(pageIndex * pageSize).Limit(pageSize)
	if res.TotalCount > 0 {
		if reporter != nil {
			m.buildReporterQueryColumns(childCtx, reporter, query)
		}
		if err = query.All(modelValues.Interface()); err != nil {
			m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
			return
		}
		// Format the rows for output; "__format" is passed through to the formatter.
		res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format"))
	} else {
		res.Data = make([]string, 0)
	}
	m.response.Success(w, res)
}
// Create handles the HTTP create endpoint.
//
// The JSON body is decoded into a new model value. The caller's domain is then
// written to its domain column, unless domains are disabled or the model has
// no such column. The record is inserted inside a transaction guarded by the
// beforeCreate and beforeSave hooks. After a successful commit, the model's own
// AfterCreated/AfterSaved methods (when implemented) and the
// afterCreate/afterSave hooks run. The response carries the new primary key
// with status "created".
func (m *Model) Create(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		schemas    []*types.Schema
		diffAttrs  []*types.DiffAttr
		domainName string
		modelValue reflect.Value
	)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil)
		return
	}
	domainName = m.valueLookup(types.FieldDomain, w, r)
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	// Overwrite any domain sent in the body with the caller's own.
	if !m.disableDomain && m.hasColumn(types.FieldDomain) {
		m.setValue(modelValue, types.FieldDomain, domainName)
	}
	diffAttrs = make([]*types.DiffAttr, 0, len(schemas))
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		User:       m.valueLookup("user", w, r),
		Request:    r,
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioCreate,
		Schemas:    schemas,
	})
	dbSess := m.getDB().WithContext(childCtx)
	if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
		if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil {
			return
		}
		if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
			return
		}
		// Insert with Create, not Save: Save issues an UPDATE when the body
		// carries a non-zero primary key, which would let a create request
		// overwrite an existing record (possibly one owned by another domain).
		if tabler, ok := model.(types.Tabler); ok {
			errTx = tx.Table(tabler.TableName()).Create(model).Error
		} else {
			errTx = tx.Create(model).Error
		}
		if errTx != nil {
			return
		}
		// Every create-schema column is recorded as a change from nil.
		for _, row := range schemas {
			diffAttrs = append(diffAttrs, &types.DiffAttr{
				Column:   row.Column,
				Label:    row.Label,
				OldValue: nil,
				NewValue: m.getValue(modelValue, row.Column),
			})
		}
		return
	}); err != nil {
		m.response.Failure(w, types.RequestCreateFailure, err.Error(), err)
		return
	}
	res := &types.CreateResponse{
		ID:     m.getValue(modelValue, m.primaryKey),
		Status: "created",
	}
	if creator, ok := model.(afterCreated); ok {
		creator.AfterCreated(childCtx, dbSess)
	}
	if preserver, ok := model.(afterSaved); ok {
		preserver.AfterSaved(childCtx, dbSess)
	}
	m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs)
	m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
	m.response.Success(w, res)
}
// Update handles the HTTP update endpoint.
//
// It loads the record identified by the ":id" URL segment. The lookup is
// restricted to the caller's domain unless domains are disabled or the model
// has no domain column. The JSON body is then decoded on top of the record.
// Update-schema columns whose value changed are written inside a transaction
// guarded by the beforeUpdate and beforeSave hooks. After a successful commit,
// the model's AfterUpdated/AfterSaved methods (when implemented) and the
// afterUpdate/afterSave hooks receive the list of changes.
func (m *Model) Update(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		schemas    []*types.Schema
		diffAttrs  []*types.DiffAttr
		domainName string
		modelValue reflect.Value
		oldValues  map[string]any
	)
	idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	domainName = m.valueLookup(types.FieldDomain, w, r)
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	scoped := !m.disableDomain && m.hasColumn(types.FieldDomain)
	conditions := map[string]any{
		m.primaryKey: idStr,
	}
	// Restrict the lookup to the caller's domain, as Search does, so a record
	// of another domain cannot be modified just by knowing its id.
	if scoped {
		conditions[types.FieldDomain] = domainName
	}
	if err = m.getDB().Where(conditions).First(model).Error; err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	oldValues = make(map[string]any, len(schemas))
	for _, row := range schemas {
		oldValues[row.Column] = m.getValue(modelValue, row.Column)
	}
	primaryValue := m.getValue(modelValue, m.primaryKey)
	if err = json.NewDecoder(r.Body).Decode(model); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
		return
	}
	// tx.Model(model).Updates derives its WHERE clause from the model's primary
	// key, so a body that changes it would update a different record.
	if !reflect.DeepEqual(primaryValue, m.getValue(modelValue, m.primaryKey)) {
		m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
		return
	}
	// Likewise, the body must not move the record to another domain.
	if scoped {
		m.setValue(modelValue, types.FieldDomain, domainName)
	}
	diffAttrs = make([]*types.DiffAttr, 0, len(schemas))
	updates := make(map[string]any)
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:          domainName,
		Request:         r,
		User:            m.valueLookup("user", w, r),
		ModuleName:      m.naming.ModuleName,
		TableName:       m.naming.TableName,
		Scenario:        types.ScenarioUpdate,
		Schemas:         schemas,
		PrimaryKeyValue: idStr,
	})
	dbSess := m.getDB().WithContext(childCtx)
	if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
		if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil {
			return
		}
		if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil {
			return
		}
		for _, row := range schemas {
			v := m.getValue(modelValue, row.Column)
			// DeepEqual rather than !=: comparing interface values that hold
			// uncomparable types (slices, maps) with != panics at run time.
			if !reflect.DeepEqual(oldValues[row.Column], v) {
				updates[row.Column] = v
				diffAttrs = append(diffAttrs, &types.DiffAttr{
					Column:   row.Column,
					Label:    row.Label,
					OldValue: oldValues[row.Column],
					NewValue: v,
				})
			}
		}
		if len(updates) > 0 {
			if tabler, ok := model.(types.Tabler); ok {
				errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error
			} else {
				errTx = tx.Model(model).Updates(updates).Error
			}
			if errTx != nil {
				return
			}
		}
		return
	}); err != nil {
		m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil)
		return
	}
	if updater, ok := model.(afterUpdated); ok {
		updater.AfterUpdated(childCtx, dbSess)
	}
	if preserver, ok := model.(afterSaved); ok {
		preserver.AfterSaved(childCtx, dbSess)
	}
	m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs)
	m.getHook().afterSave(childCtx, dbSess, model, diffAttrs)
	m.response.Success(w, types.UpdateResponse{
		ID:     idStr,
		Status: "updated",
	})
}
// Delete handles the HTTP delete endpoint.
//
// It loads the record identified by the ":id" URL segment. The lookup is
// restricted to the caller's domain unless domains are disabled or the model
// has no domain column. The record is deleted inside a transaction guarded by
// the beforeDelete hook. The afterDelete hook runs inside the same transaction
// once the delete succeeded.
func (m *Model) Delete(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		modelValue reflect.Value
	)
	idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r)
	domainName := m.valueLookup(types.FieldDomain, w, r)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	conditions := map[string]any{
		m.primaryKey: idStr,
	}
	// Restrict the lookup to the caller's domain, as Search does, so a record
	// of another domain cannot be deleted just by knowing its id.
	if !m.disableDomain && m.hasColumn(types.FieldDomain) {
		conditions[types.FieldDomain] = domainName
	}
	if err = m.getDB().Where(conditions).First(model).Error; err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:          domainName,
		User:            m.valueLookup("user", w, r),
		Request:         r,
		ModuleName:      m.naming.ModuleName,
		TableName:       m.naming.TableName,
		Scenario:        types.ScenarioDelete,
		PrimaryKeyValue: idStr,
	})
	dbSess := m.getDB().WithContext(childCtx)
	if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) {
		if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil {
			return
		}
		if tabler, ok := model.(types.Tabler); ok {
			errTx = tx.Table(tabler.TableName()).Delete(model).Error
		} else {
			errTx = tx.Delete(model).Error
		}
		if errTx != nil {
			return
		}
		m.getHook().afterDelete(childCtx, tx, model)
		return
	}); err != nil {
		m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil)
		return
	}
	m.response.Success(w, types.DeleteResponse{
		ID:     idStr,
		Status: "deleted",
	})
}
// View handles the HTTP detail endpoint.
//
// It loads the record identified by the ":id" URL segment. The lookup is
// restricted to the caller's domain unless domains are disabled or the model
// has no domain column. The record is formatted with the schemas of the
// "scenario" query parameter, or the view scenario when that is empty.
// "__format" is passed through to the formatter.
func (m *Model) View(w http.ResponseWriter, r *http.Request) {
	var (
		err        error
		model      any
		modelValue reflect.Value
		qs         url.Values
		schemas    []*types.Schema
		scenario   string
		domainName string
	)
	qs = r.URL.Query()
	idStr := m.findPrimaryKey(m.Uri(types.ScenarioView), r)
	modelValue = reflect.New(m.value.Type())
	model = modelValue.Interface()
	domainName = m.valueLookup(types.FieldDomain, w, r)
	conditions := map[string]any{
		m.primaryKey: idStr,
	}
	// Restrict the lookup to the caller's domain, as Search does, so records
	// of other domains cannot be read just by knowing their id.
	if !m.disableDomain && m.hasColumn(types.FieldDomain) {
		conditions[types.FieldDomain] = domainName
	}
	if err = m.getDB().Where(conditions).First(model).Error; err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	scenario = qs.Get("scenario")
	if scenario == "" {
		scenario = types.ScenarioView
	}
	// NOTE(review): unlike Search, the requested scenario is not checked against allowScenario.
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil)
		return
	}
	m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format")))
}
// Export handles the HTTP export endpoint.
//
// It runs the search query without paging. The conditions come from the
// search-scenario schemas, with the same domain restriction as Search. Every
// matching row is written as CSV (header = export schema labels) to a file
// under the export directory. The afterExport hook then runs with that file's
// path, and the file is served as an attachment named "<singular>.csv".
func (m *Model) Export(w http.ResponseWriter, r *http.Request) {
	var (
		err           error
		query         *Query
		modelSlices   reflect.Value
		modelValues   reflect.Value
		searchSchemas []*types.Schema
		exportSchemas []*types.Schema
		domainName    string
		fp            *os.File
		modelValue    reflect.Value
	)
	if !m.hasScenario(types.ScenarioList) {
		m.response.Failure(w, types.RequestDenied, "request denied", nil)
		return
	}
	domainName = m.valueLookup(types.FieldDomain, w, r)
	// NOTE(review): the name has one-second resolution, so two exports of the same
	// model in the same domain within a second share (and truncate) one file.
	filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
	if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil)
		return
	}
	defer func() {
		_ = fp.Close()
	}()
	modelValue = reflect.New(m.value.Type())
	// Build a slice of pointers so the formatter can call pointer-receiver methods on each element.
	if m.value.Kind() != reflect.Ptr {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0)
	} else {
		modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0)
	}
	modelValues = reflect.New(modelSlices.Type())
	modelValues.Elem().Set(modelSlices)
	query = NewQuery(m.getDB(), modelValue.Interface())
	childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		Request:    r,
		User:       m.valueLookup("user", w, r),
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioExport,
	})
	if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil)
		return
	}
	// Restrict results to the caller's domain.
	if !m.disableDomain {
		if m.hasColumn(types.FieldDomain) {
			query.AndWhere(newCondition(types.FieldDomain, domainName))
		}
	}
	// Models implementing tableNamer choose their table from the request.
	if namerTable, ok := query.Model().(tableNamer); ok {
		query.From(namerTable.HttpTableName(r))
	}
	// Report models read from their real table, grouped, with their own select list.
	if reporter, ok := modelValue.Interface().(types.Reporter); ok {
		query.From(reporter.RealTable())
		query.GroupBy(reporter.GroupBy(childCtx)...)
		m.buildReporterQueryColumns(childCtx, reporter, query)
	}
	// NOTE(review): no LIMIT is applied; every matching row is loaded into memory at once.
	if err = query.All(modelValues.Interface()); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	w.Header().Set("Content-Type", "text/csv")
	w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
	value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "")
	writer := csv.NewWriter(fp)
	// Header row of export labels; the same slice is reused as the buffer for every data row.
	rows := make([]string, len(exportSchemas))
	for i, field := range exportSchemas {
		rows[i] = field.Label
	}
	_ = writer.Write(rows)
	// Only a []any of map[string]any rows is written; any other shape yields a header-only file.
	if values, ok := value.([]any); ok {
		for _, val := range values {
			row, ok2 := val.(map[string]any)
			if !ok2 {
				continue
			}
			for i, field := range exportSchemas {
				// NOTE(review): fmt.Sprint renders a nil value as "<nil>", and cells starting with
				// '=', '+', '-' or '@' are written unescaped (spreadsheet formula injection) — confirm.
				if v, ok := row[field.Column]; ok {
					rows[i] = fmt.Sprint(v)
				} else {
					rows[i] = ""
				}
			}
			_ = writer.Write(rows)
		}
	}
	// NOTE(review): writer.Error() is not checked after Flush, so write failures go unnoticed.
	writer.Flush()
	m.getHook().afterExport(childCtx, filename)
	// ServeContent seeks fp itself (to the end for the size, then back to the
	// start), so the offset left by the CSV writer does not matter.
	http.ServeContent(w, r, path.Base(filename), time.Now(), fp)
}
// findSchema returns the first schema whose Label equals label, or nil when
// none matches.
func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema {
	for i := range schemas {
		if schemas[i].Label == label {
			return schemas[i]
		}
	}
	return nil
}
// importInternal imports the rows of the CSV file at filename and reports the
// outcome through the afterImport hook.
//
// The first row is a header of schema labels, each mapped to its column with
// findSchema; an unknown label aborts the import with
// types.ErrImportColumnNotMatch. Every following row becomes a new model value
// built in three steps:
//  1. its cells are assigned to the mapped columns;
//  2. the caller's domain is stamped on it, unless domains are disabled or the
//     model has no domain column;
//  3. extraFields are applied last.
//
// In fast mode the value is stored with a plain Save. Otherwise it is saved
// inside a transaction that runs the beforeCreate/beforeSave and
// afterCreate/afterSave hooks. Rows that are not stored are copied, with an
// extra error column, into a failure CSV. Its path is reported when at least
// one row did not succeed.
func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) {
	var (
		err         error
		rows        []string
		fp          *os.File
		fields      []string
		sess        *gorm.DB
		csvReader   *csv.Reader
		csvWriter   *csv.Writer
		failureFp   *os.File
		failureFile string
	)
	startedAt := time.Now()
	result := &types.ImportResult{}
	// Report the outcome on every path, including early failures. Registered
	// first so it runs last, after the files below are flushed and closed.
	defer func() {
		result.UploadFile = filename
		if result.TotalCount > result.SuccessCount {
			result.FailureFile = failureFile
		}
		result.Duration = time.Since(startedAt)
		m.getHook().afterImport(ctx, result)
	}()
	if fp, err = os.Open(filename); err != nil {
		result.Code = types.ErrImportFileNotExists
		return
	}
	defer func() {
		_ = fp.Close()
	}()
	// Skip a leading UTF-8 byte order mark (written e.g. by Excel's
	// "CSV UTF-8" export); otherwise the first header cell never matches.
	var bom [3]byte
	if n, _ := io.ReadFull(fp, bom[:]); n < len(bom) || string(bom[:]) != "\ufeff" {
		if _, err = fp.Seek(0, io.SeekStart); err != nil {
			result.Code = types.ErrImportFileUnavailable
			return
		}
	}
	csvReader = csv.NewReader(fp)
	// Accept any number of fields per row. A row whose length differs from the
	// header is reported in the failure file below; with the default setting
	// csv.ErrFieldCount would end the whole import at that row.
	csvReader.FieldsPerRecord = -1
	if rows, err = csvReader.Read(); err != nil {
		result.Code = types.ErrImportFileUnavailable
		return
	}
	fields = make([]string, 0, len(rows))
	for _, label := range rows {
		s := m.findSchema(label, schemas)
		if s == nil {
			result.Code = types.ErrImportColumnNotMatch
			return
		}
		fields = append(fields, s.Column)
	}
	sess = m.getDB().WithContext(ctx)
	failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix()))
	if failureFp, err = os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil {
		// Nothing was imported; still report the failure through the hook.
		result.Code = types.ErrImportFileUnavailable
		return
	}
	defer func() {
		_ = failureFp.Close()
	}()
	csvWriter = csv.NewWriter(failureFp)
	// Deferred after the Close above, so it runs before it.
	defer csvWriter.Flush()
	_ = csvWriter.Write(append(rows, "Error"))
	scoped := !m.disableDomain && m.hasColumn(types.FieldDomain)
	for {
		// io.EOF ends the import; a malformed record stops it as well.
		if rows, err = csvReader.Read(); err != nil {
			break
		}
		result.TotalCount++
		if len(rows) != len(fields) {
			_ = csvWriter.Write(append(rows, "column count mismatch"))
			continue
		}
		modelValue := reflect.New(m.value.Type())
		for idx, field := range fields {
			m.safeSetValue(modelValue, field, rows[idx])
		}
		// Stamp the caller's domain as Create does, so imported rows pass
		// the domain filter of Search; extraFields may still override it.
		if scoped {
			m.setValue(modelValue, types.FieldDomain, domainName)
		}
		for k, v := range extraFields {
			m.safeSetValue(modelValue, k, v)
		}
		entity := modelValue.Interface()
		if fast {
			// Fast mode: plain save without hooks or an explicit transaction.
			err = sess.Save(entity).Error
		} else {
			err = sess.Transaction(func(tx *gorm.DB) (errTx error) {
				if errTx = m.getHook().beforeCreate(ctx, tx, entity); errTx != nil {
					return
				}
				if errTx = m.getHook().beforeSave(ctx, tx, entity); errTx != nil {
					return
				}
				if tabler, ok := entity.(types.Tabler); ok {
					errTx = tx.Table(tabler.TableName()).Save(entity).Error
				} else {
					errTx = tx.Save(entity).Error
				}
				if errTx != nil {
					return
				}
				// A fresh slice per row: hooks may retain it, so it must not
				// be overwritten by the next row.
				diffAttrs := make([]*types.DiffAttr, len(schemas))
				for idx, row := range schemas {
					diffAttrs[idx] = &types.DiffAttr{
						Column:   row.Column,
						Label:    row.Label,
						NewValue: m.getValue(modelValue, row.Column),
					}
				}
				m.getHook().afterCreate(ctx, tx, entity, diffAttrs)
				m.getHook().afterSave(ctx, tx, entity, diffAttrs)
				return
			})
		}
		if err != nil {
			_ = csvWriter.Write(append(rows, err.Error()))
			continue
		}
		result.SuccessCount++
	}
}
// Import handles the HTTP import endpoint.
//
// GET returns a CSV template. Its single header row lists the labels of the
// create-scenario schemas, skipping primary-key columns.
//
// Any other method expects a multipart upload in the "file" form field. The
// upload is copied to the import directory and handed to importInternal in a
// background goroutine, and the handler immediately answers with status
// "committed". Query parameters prefixed with "_attr_" become extra fields
// (prefix removed) applied to every imported row. "__fast", parsed with
// strconv.ParseBool, selects importInternal's fast mode.
func (m *Model) Import(w http.ResponseWriter, r *http.Request) {
	var (
		err         error
		fast        bool
		schemas     []*types.Schema
		rows        []string
		domainName  string
		dst         *os.File
		fp          multipart.File
		csvWriter   *csv.Writer
		qs          url.Values
		extraFields map[string]string
	)
	domainName = m.valueLookup(types.FieldDomain, w, r)
	if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
		m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)
		return
	}
	// The import outlives this request, so its context is derived from
	// context.Background() rather than r.Context(), and carries no Request.
	childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{
		Domain:     domainName,
		User:       m.valueLookup("user", w, r),
		ModuleName: m.naming.ModuleName,
		TableName:  m.naming.TableName,
		Scenario:   types.ScenarioImport,
		Schemas:    schemas,
	})
	if r.Method == http.MethodGet {
		// Serve the import template.
		csvWriter = csv.NewWriter(w)
		rows = make([]string, 0, len(schemas))
		for _, row := range schemas {
			// Primary keys are not imported.
			if row.IsPrimaryKey == 1 {
				continue
			}
			rows = append(rows, row.Label)
		}
		w.Header().Set("Content-Type", "text/csv")
		w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition")
		w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular))
		// The body is already being written; a write error cannot be reported to the client.
		_ = csvWriter.Write(rows)
		csvWriter.Flush()
		return
	}
	filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix()))
	if fp, _, err = r.FormFile("file"); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil)
		return
	}
	defer func() {
		_ = fp.Close()
	}()
	if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil {
		m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil)
		return
	}
	buf := pool.GetBytes(32 * 1024)
	_, err = io.CopyBuffer(dst, fp, buf)
	pool.PutBytes(buf)
	if closeErr := dst.Close(); err == nil {
		err = closeErr
	}
	// Never import a truncated copy: drop the partial file and fail the request.
	if err != nil {
		_ = os.Remove(filename)
		m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil)
		return
	}
	qs = r.URL.Query()
	extraFields = make(map[string]string)
	for k := range qs {
		if strings.HasPrefix(k, "_attr_") {
			extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k)
		}
	}
	fast, _ = strconv.ParseBool(qs.Get("__fast"))
	// NOTE(review): this goroutine is neither tracked nor cancellable.
	go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields)
	m.response.Success(w, types.ImportResponse{
		UID:    m.valueLookup("user", w, r),
		Status: "committed",
	})
}
// newModel creates a Model for v backed by db with the given naming rules.
//
// v is parsed into a gorm statement, and the first primary-key column of the
// resulting schema (if any) becomes the model's primaryKey. A pointer v is
// dereferenced for Model.value.
// NOTE(review): a Parse error is silently ignored, leaving primaryKey empty;
// hasColumn and Fields dereference statement.Schema — confirm Parse cannot
// fail for registered models.
func newModel(v any, db *gorm.DB, naming types.Naming) *Model {
	m := &Model{
		db:          db,
		naming:      naming,
		response:    &httpWriter{},
		value:       reflect.Indirect(reflect.ValueOf(v)),
		valueLookup: defaultValueLookup,
	}
	stmt := &gorm.Statement{
		DB:       m.getDB(),
		ConnPool: m.getDB().ConnPool,
		Clauses:  map[string]clause.Clause{},
	}
	m.statement = stmt
	if err := stmt.Parse(v); err != nil {
		return m
	}
	if pks := stmt.Schema.PrimaryFieldDBNames; len(pks) > 0 {
		m.primaryKey = pks[0]
	}
	return m
}