package rest import ( "context" "encoding/csv" "encoding/json" "fmt" "git.nobla.cn/golang/kos/util/arrays" "git.nobla.cn/golang/kos/util/pool" "git.nobla.cn/golang/rest/inflector" "git.nobla.cn/golang/rest/types" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "io" "mime/multipart" "net/http" "net/url" "os" "path" "reflect" "strconv" "strings" "time" ) type Model struct { naming types.Naming //命名规则 value reflect.Value //模块值 db *gorm.DB //数据库 primaryKey string //主键 urlPrefix string //url前缀 disableDomain bool //禁用域 schemaLookup types.SchemaLookupFunc //获取schema的函数 valueLookup types.ValueLookupFunc //查看域 statement *gorm.Statement //字段声明 formatter *Formatter //格式化 response types.HttpWriter //HTTP响应 hookMgr *hookManager //钩子管理器 dirname string //存放文件目录 } var ( RuntimeScopeKey = &types.RuntimeScope{} ) // getDB 获取数据库连接对象 func (m *Model) getDB() *gorm.DB { return m.db } // getFormatter 获取格式化组件 func (m *Model) getFormatter() *Formatter { if m.formatter != nil { return m.formatter } return DefaultFormatter } // getHook 获取钩子 func (m *Model) getHook() *hookManager { return m.hookMgr } // hasScenario 判断是否有该场景 func (m *Model) hasScenario(s string) bool { return true } // setValue 设置字段的值 func (m *Model) setValue(refValue reflect.Value, column string, value any) { SetFieldValue(m.statement, refValue, column, value) } func (m *Model) safeSetValue(refValue reflect.Value, column string, value any) { SafeSetFileValue(m.statement, refValue, column, value) } // getValue 获取字段的值 func (m *Model) getValue(refValue reflect.Value, column string) interface{} { return GetFieldValue(m.statement, refValue, column) } // hasColumn 判断指定的列是否存在 func (m *Model) hasColumn(column string) bool { for _, field := range m.statement.Schema.Fields { if field.DBName == column || field.Name == column { return true } } return false } // getFilename 获取文件存放目录 func (m *Model) getFilename(domain string, spec string, name string) string { if m.dirname == "" { m.dirname = os.TempDir() } filename := path.Join(m.dirname, domain, spec, time.Now().Format("20060102"), name) if _, err := os.Stat(path.Dir(filename)); err != nil { _ = os.MkdirAll(path.Dir(filename), 0755) } return filename } // findPrimaryKey 查找主键的值 func (m *Model) findPrimaryKey(uri string, r *http.Request) string { var ( pos int ) urlPath := r.URL.Path pos = strings.IndexByte(uri, ':') if pos > 0 { return urlPath[pos:] } return "" } // parseReportColumn 解析报表的列 func (m *Model) parseReportColumn(name, props string) *types.SelectColumn { var ( key string value string ) column := &types.SelectColumn{ Name: inflector.Camel2id(name), Native: false, } tokens := strings.Split(props, ";") for _, token := range tokens { pair := strings.SplitN(token, ":", 2) if len(pair) == 0 { continue } if len(pair) == 1 { key = strings.TrimSpace(pair[0]) value = "" } else { key = strings.TrimSpace(pair[0]) value = strings.TrimSpace(pair[1]) } switch key { case "native": column.Native = true case "name": column.Name = value case "expr": column.Expr = value } } return column } func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Reporter, query *Query) { modelType := reflect.ValueOf(dest).Type() if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } columns := make([]string, 0) for i := 0; i < modelType.NumField(); i++ { field := modelType.Field(i) scenarios := field.Tag.Get("scenarios") if !hasToken(types.ScenarioList, scenarios) { continue } isPrimary := field.Tag.Get("is_primary") if isPrimary != "true" { continue } column := m.parseReportColumn(field.Name, field.Tag.Get("report")) if !column.Native { continue } if column.Expr == "" { columns = append(columns, dest.QuoteColumn(ctx, column.Name)) } else { columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name))) } } columns = append(columns, "COUNT(*) AS count") query.Select(columns...) } func (m *Model) buildReporterQueryColumns(ctx context.Context, dest types.Reporter, query *Query) { modelType := reflect.ValueOf(dest).Type() if modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } columns := make([]string, 0) for i := 0; i < modelType.NumField(); i++ { field := modelType.Field(i) scenarios := field.Tag.Get("scenarios") if !hasToken(types.ScenarioList, scenarios) { continue } column := m.parseReportColumn(field.Name, field.Tag.Get("report")) if !column.Native { continue } if column.Expr == "" { columns = append(columns, dest.QuoteColumn(ctx, column.Name)) } else { columns = append(columns, fmt.Sprintf("%s AS %s", column.Expr, dest.QuoteColumn(ctx, column.Name))) } } query.Select(columns...) } // buildCondition 构建sql条件 func (m *Model) buildCondition(ctx context.Context, r *http.Request, query *Query, schemas []*types.Schema) (err error) { return BuildConditions(ctx, r, query, schemas) } // ModuleName 模块名称 func (m *Model) ModuleName() string { return m.naming.ModuleName } // TableName 表的名称 func (m *Model) TableName() string { return m.naming.ModuleName } // Fields 返回搜索的模型的字段 func (m *Model) Fields() []*schema.Field { return m.statement.Schema.Fields } // Uri 获取请求的uri func (m *Model) Uri(scenario string) string { ss := make([]string, 4) if m.urlPrefix != "" { ss = append(ss, m.urlPrefix) } switch scenario { case types.ScenarioList: ss = append(ss, m.naming.ModuleName, m.naming.Pluralize) case types.ScenarioView: ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") case types.ScenarioCreate: ss = append(ss, m.naming.ModuleName, m.naming.Singular) case types.ScenarioUpdate: ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") case types.ScenarioDelete: ss = append(ss, m.naming.ModuleName, m.naming.Singular, ":id") case types.ScenarioExport: ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-export") case types.ScenarioImport: ss = append(ss, m.naming.ModuleName, m.naming.Singular+"-import") } uri := path.Join(ss...) if !strings.HasPrefix(uri, "/") { uri = "/" + uri } return uri } // Method 获取HTTP请求的方法 func (m *Model) Method(scenario string) string { var ( method = http.MethodGet ) switch scenario { case types.ScenarioCreate: method = http.MethodPost case types.ScenarioUpdate: method = http.MethodPut case types.ScenarioDelete: method = http.MethodDelete } return method } // Search 实现通过HTTP方法查找数据 func (m *Model) Search(w http.ResponseWriter, r *http.Request) { var ( ok bool err error qs url.Values page int pageSize int pageIndex int query *Query domainName string modelSlices reflect.Value modelValues reflect.Value searchSchemas []*types.Schema listSchemas []*types.Schema modelValue reflect.Value scenario string reporter types.Reporter namerTable tableNamer ) qs = r.URL.Query() page, _ = strconv.Atoi(qs.Get("page")) pageSize, _ = strconv.Atoi(qs.Get("pagesize")) if pageSize <= 0 { pageSize = defaultPageSize } pageIndex = page if pageIndex > 0 { pageIndex-- } modelValue = reflect.New(m.value.Type()) //这里创建指针类型,这样的话就能在format里面调用函数 if m.value.Kind() != reflect.Ptr { modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0) } else { modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0) } modelValues = reflect.New(modelSlices.Type()) modelValues.Elem().Set(modelSlices) query = NewQuery(m.getDB(), reflect.New(m.value.Type()).Interface()) domainName = m.valueLookup(types.FieldDomain, w, r) childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ Domain: domainName, Request: r, User: m.valueLookup("user", w, r), ModuleName: m.naming.ModuleName, TableName: m.naming.TableName, Scenario: types.ScenarioList, }) if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } scenario = types.ScenarioList if arrays.Exists(r.FormValue("scenario"), allowScenario) { scenario = r.FormValue("scenario") } if listSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } if !m.disableDomain { if m.hasColumn(types.FieldDomain) { query.AndWhere(newCondition(types.FieldDomain, domainName)) } } if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil { m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) return } // 处理表名逻辑 if namerTable, ok = query.Model().(tableNamer); ok { query.From(namerTable.HttpTableName(r)) } //处理报表逻辑 if reporter, ok = modelValue.Interface().(types.Reporter); ok { query.From(reporter.RealTable()) } res := &types.ListResponse{ Page: page, PageSize: pageSize, } if reporter == nil { res.TotalCount = query.Limit(0).Offset(0).Count(query.Model()) } else { //如果是报表的情况,需要手动指定COUNT的雨具逻辑才能生效 m.buildReporterCountColumns(childCtx, reporter, query) res.TotalCount = query.Limit(0).Offset(0).Count(query.Model()) //这里需要重置一下选项,不然会出问题 query.ResetSelect() query.GroupBy(reporter.GroupBy(childCtx)...) } query.Offset(pageIndex * pageSize).Limit(pageSize) if res.TotalCount > 0 { if reporter != nil { m.buildReporterQueryColumns(childCtx, reporter, query) } if err = query.All(modelValues.Interface()); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } // 不进行格式化输出 res.Data = m.getFormatter().formatModels(childCtx, modelValues.Interface(), listSchemas, m.statement, qs.Get("__format")) } else { res.Data = make([]string, 0) } m.response.Success(w, res) } // Create 实现通过HTTP方法创建模型 func (m *Model) Create(w http.ResponseWriter, r *http.Request) { var ( err error model any schemas []*types.Schema diffAttrs []*types.DiffAttr domainName string modelValue reflect.Value ) modelValue = reflect.New(m.value.Type()) model = modelValue.Interface() if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil { m.response.Failure(w, types.RequestPayloadInvalid, err.Error(), nil) return } domainName = m.valueLookup(types.FieldDomain, w, r) if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } if !m.disableDomain { if m.hasColumn(types.FieldDomain) { m.setValue(modelValue, types.FieldDomain, domainName) } } diffAttrs = make([]*types.DiffAttr, 0, 10) childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ Domain: domainName, User: m.valueLookup("user", w, r), Request: r, ModuleName: m.naming.ModuleName, TableName: m.naming.TableName, Scenario: types.ScenarioCreate, Schemas: schemas, }) dbSess := m.getDB().WithContext(childCtx) if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { if errTx = m.getHook().beforeCreate(childCtx, tx, model); errTx != nil { return } if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil { return } if tabler, ok := model.(types.Tabler); ok { errTx = tx.Table(tabler.TableName()).Save(model).Error } else { errTx = tx.Save(model).Error } if errTx != nil { return } for _, row := range schemas { diffAttrs = append(diffAttrs, &types.DiffAttr{ Column: row.Column, Label: row.Label, OldValue: nil, NewValue: m.getValue(modelValue, row.Column), }) } return }); err == nil { res := &types.CreateResponse{ ID: m.getValue(modelValue, m.primaryKey), Status: "created", } if creator, ok := model.(afterCreated); ok { creator.AfterCreated(childCtx, dbSess) } if preserver, ok := model.(afterSaved); ok { preserver.AfterSaved(childCtx, dbSess) } m.getHook().afterCreate(childCtx, dbSess, model, diffAttrs) m.getHook().afterSave(childCtx, dbSess, model, diffAttrs) m.response.Success(w, res) } else { m.response.Failure(w, types.RequestCreateFailure, err.Error(), err) } } // Update 实现通过HTTP方法更新模型 func (m *Model) Update(w http.ResponseWriter, r *http.Request) { var ( err error model any schemas []*types.Schema diffAttrs []*types.DiffAttr domainName string modelValue reflect.Value oldValues map[string]any ) idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r) modelValue = reflect.New(m.value.Type()) model = modelValue.Interface() domainName = m.valueLookup(types.FieldDomain, w, r) if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioUpdate); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } conditions := map[string]any{ m.primaryKey: idStr, } if err = m.getDB().Where(conditions).First(model).Error; err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } oldValues = make(map[string]any) for _, row := range schemas { oldValues[row.Column] = m.getValue(modelValue, row.Column) } if err = json.NewDecoder(r.Body).Decode(model); err != nil { m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) return } diffAttrs = make([]*types.DiffAttr, 0, 10) updates := make(map[string]any) childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ Domain: domainName, Request: r, User: m.valueLookup("user", w, r), ModuleName: m.naming.ModuleName, TableName: m.naming.TableName, Scenario: types.ScenarioUpdate, Schemas: schemas, PrimaryKeyValue: idStr, }) dbSess := m.getDB().WithContext(childCtx) if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { if errTx = m.getHook().beforeUpdate(childCtx, tx, model); errTx != nil { return } if errTx = m.getHook().beforeSave(childCtx, tx, model); errTx != nil { return } for _, row := range schemas { v := m.getValue(modelValue, row.Column) if oldValues[row.Column] != v { updates[row.Column] = v diffAttrs = append(diffAttrs, &types.DiffAttr{ Column: row.Column, Label: row.Label, OldValue: oldValues[row.Column], NewValue: v, }) } } if len(updates) > 0 { if tabler, ok := model.(types.Tabler); ok { errTx = tx.Model(model).Table(tabler.TableName()).Updates(updates).Error } else { errTx = tx.Model(model).Updates(updates).Error } if errTx != nil { return } } return }); err == nil { if updater, ok := model.(afterUpdated); ok { updater.AfterUpdated(childCtx, dbSess) } if preserver, ok := model.(afterSaved); ok { preserver.AfterSaved(childCtx, dbSess) } m.getHook().afterUpdate(childCtx, dbSess, model, diffAttrs) m.getHook().afterSave(childCtx, dbSess, model, diffAttrs) m.response.Success(w, types.UpdateResponse{ ID: idStr, Status: "updated", }) } else { m.response.Failure(w, types.RequestUpdateFailure, err.Error(), nil) } } // Delete 实现通过HTTP方法删除模型 func (m *Model) Delete(w http.ResponseWriter, r *http.Request) { var ( err error model any modelValue reflect.Value ) idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r) modelValue = reflect.New(m.value.Type()) model = modelValue.Interface() conditions := map[string]any{ m.primaryKey: idStr, } if err = m.getDB().Where(conditions).First(model).Error; err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ Domain: m.valueLookup(types.FieldDomain, w, r), User: m.valueLookup("user", w, r), Request: r, ModuleName: m.naming.ModuleName, TableName: m.naming.TableName, Scenario: types.ScenarioDelete, PrimaryKeyValue: idStr, }) dbSess := m.getDB().WithContext(childCtx) if err = dbSess.Transaction(func(tx *gorm.DB) (errTx error) { if errTx = m.getHook().beforeDelete(childCtx, tx, model); errTx != nil { return } if tabler, ok := model.(types.Tabler); ok { errTx = tx.Table(tabler.TableName()).Delete(model).Error } else { errTx = tx.Delete(model).Error } if errTx != nil { return } m.getHook().afterDelete(childCtx, tx, model) return }); err == nil { m.response.Success(w, types.DeleteResponse{ ID: idStr, Status: "deleted", }) } else { m.response.Failure(w, types.RequestDeleteFailure, err.Error(), nil) } } // View 查看数据详情 func (m *Model) View(w http.ResponseWriter, r *http.Request) { var ( err error model any modelValue reflect.Value qs url.Values schemas []*types.Schema scenario string domainName string ) qs = r.URL.Query() idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r) modelValue = reflect.New(m.value.Type()) model = modelValue.Interface() conditions := map[string]any{ m.primaryKey: idStr, } domainName = m.valueLookup(types.FieldDomain, w, r) if err = m.getDB().Where(conditions).First(model).Error; err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } scenario = qs.Get("scenario") if scenario == "" { schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioView) } else { schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, scenario) } if err == nil { m.response.Success(w, m.getFormatter().formatModel(r.Context(), modelValue, schemas, m.statement, qs.Get("__format"))) } else { m.response.Failure(w, types.RequestRecordNotFound, err.Error(), nil) } } // Export 实现通过HTTP方法导出模型 func (m *Model) Export(w http.ResponseWriter, r *http.Request) { var ( err error query *Query modelSlices reflect.Value modelValues reflect.Value searchSchemas []*types.Schema exportSchemas []*types.Schema domainName string fp *os.File modelValue reflect.Value ) if !m.hasScenario(types.ScenarioList) { m.response.Failure(w, types.RequestDenied, "request denied", nil) return } domainName = m.valueLookup(types.FieldDomain, w, r) filename := m.getFilename(domainName, "export", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix())) if fp, err = os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644); err != nil { m.response.Failure(w, types.RequestPayloadInvalid, "directory does not have permission", nil) return } defer func() { _ = fp.Close() }() modelValue = reflect.New(m.value.Type()) //这里创建指针类型,这样的话就能在format里面调用函数 if m.value.Kind() != reflect.Ptr { modelSlices = reflect.MakeSlice(reflect.SliceOf(modelValue.Type()), 0, 0) } else { modelSlices = reflect.MakeSlice(reflect.SliceOf(m.value.Type()), 0, 0) } modelValues = reflect.New(modelSlices.Type()) modelValues.Elem().Set(modelSlices) query = NewQuery(m.getDB(), modelValue.Interface()) childCtx := context.WithValue(r.Context(), RuntimeScopeKey, &types.RuntimeScope{ Domain: domainName, Request: r, User: m.valueLookup("user", w, r), ModuleName: m.naming.ModuleName, TableName: m.naming.TableName, Scenario: types.ScenarioExport, }) if searchSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioSearch); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } if exportSchemas, err = m.schemaLookup(childCtx, m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioExport); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } if err = m.buildCondition(childCtx, r, query, searchSchemas); err != nil { m.response.Failure(w, types.RequestPayloadInvalid, "payload invalid", nil) return } if !m.disableDomain { if m.hasColumn(types.FieldDomain) { query.AndWhere(newCondition(types.FieldDomain, domainName)) } } // 处理表名逻辑 if namerTable, ok := query.Model().(tableNamer); ok { query.From(namerTable.HttpTableName(r)) } //处理报表逻辑 if reporter, ok := modelValue.Interface().(types.Reporter); ok { query.From(reporter.RealTable()) query.GroupBy(reporter.GroupBy(childCtx)...) m.buildReporterQueryColumns(childCtx, reporter, query) } if err = query.All(modelValues.Interface()); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } w.Header().Set("Content-Type", "text/csv") w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition") w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular)) value := m.getFormatter().formatModels(childCtx, modelValues.Interface(), exportSchemas, m.statement, "") writer := csv.NewWriter(fp) rows := make([]string, len(exportSchemas)) for i, field := range exportSchemas { rows[i] = field.Label } _ = writer.Write(rows) if values, ok := value.([]any); ok { for _, val := range values { row, ok2 := val.(map[string]any) if !ok2 { continue } for i, field := range exportSchemas { if v, ok := row[field.Column]; ok { rows[i] = fmt.Sprint(v) } else { rows[i] = "" } } _ = writer.Write(rows) } } writer.Flush() m.getHook().afterExport(childCtx, filename) http.ServeContent(w, r, path.Base(filename), time.Now(), fp) } // findSchema 查找指定的schema func (m *Model) findSchema(label string, schemas []*types.Schema) *types.Schema { for _, row := range schemas { if row.Label == label { return row } } return nil } // importInternal 文件上传方法 func (m *Model) importInternal(ctx context.Context, domainName string, schemas []*types.Schema, filename string, fast bool, extraFields map[string]string) { var ( err error rows []string fp *os.File tm time.Time fields []string sess *gorm.DB csvReader *csv.Reader csvWriter *csv.Writer modelValue reflect.Value modelEntity any diffAttrs []*types.DiffAttr result *types.ImportResult failureFp *os.File failureFile string ) tm = time.Now() result = &types.ImportResult{} if fp, err = os.Open(filename); err != nil { result.Code = types.ErrImportFileNotExists goto __end } defer func() { _ = fp.Close() }() csvReader = csv.NewReader(fp) if rows, err = csvReader.Read(); err != nil { result.Code = types.ErrImportFileUnavailable goto __end } fields = make([]string, 0, len(rows)) for _, s := range rows { v := m.findSchema(s, schemas) if v == nil { result.Code = types.ErrImportColumnNotMatch goto __end } fields = append(fields, v.Column) } sess = m.getDB().WithContext(ctx) //失败文件指针 failureFile = m.getFilename(domainName, "import", fmt.Sprintf("%s-%d-fail.csv", m.naming.Singular, time.Now().Unix())) if failureFp, err = os.OpenFile(failureFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil { return } defer func() { _ = failureFp.Close() }() csvWriter = csv.NewWriter(failureFp) rows = append(rows, "Error") _ = csvWriter.Write(rows) diffAttrs = make([]*types.DiffAttr, len(schemas)) for { if rows, err = csvReader.Read(); err != nil { break } result.TotalCount++ if len(rows) != len(fields) { continue } modelValue = reflect.New(m.value.Type()) for idx, field := range fields { m.safeSetValue(modelValue, field, rows[idx]) } if len(extraFields) > 0 { for k, v := range extraFields { m.safeSetValue(modelValue, k, v) } } modelEntity = modelValue.Interface() //写入数据 if fast { //如果是快速模式,直接存储数据 if err = sess.Save(modelEntity).Error; err == nil { result.SuccessCount++ } else { rows = append(rows, err.Error()) _ = csvWriter.Write(rows) } } else { if err = sess.Transaction(func(tx *gorm.DB) (errTx error) { if errTx = m.getHook().beforeCreate(ctx, tx, modelEntity); errTx != nil { return } if errTx = m.getHook().beforeSave(ctx, tx, modelEntity); errTx != nil { return } if tabler, ok := modelEntity.(types.Tabler); ok { errTx = tx.Table(tabler.TableName()).Save(modelEntity).Error } else { errTx = tx.Save(modelEntity).Error } if errTx != nil { return } for idx, row := range schemas { diffAttrs[idx] = &types.DiffAttr{ Column: row.Column, Label: row.Label, NewValue: m.getValue(modelValue, row.Column), } } m.getHook().afterCreate(ctx, tx, modelEntity, diffAttrs) m.getHook().afterSave(ctx, tx, modelEntity, diffAttrs) return }); err == nil { result.SuccessCount++ } else { rows = append(rows, err.Error()) _ = csvWriter.Write(rows) } } } csvWriter.Flush() __end: result.UploadFile = filename if result.TotalCount > result.SuccessCount { result.FailureFile = failureFile } result.Duration = time.Now().Sub(tm) m.getHook().afterImport(ctx, result) } // Import 实现通过HTTP方法导入 func (m *Model) Import(w http.ResponseWriter, r *http.Request) { var ( err error fast bool schemas []*types.Schema rows []string domainName string dst *os.File fp multipart.File csvWriter *csv.Writer qs url.Values extraFields map[string]string ) domainName = m.valueLookup(types.FieldDomain, w, r) if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) return } //这里用background的context childCtx := context.WithValue(context.Background(), RuntimeScopeKey, &types.RuntimeScope{ Domain: domainName, User: m.valueLookup("user", w, r), ModuleName: m.naming.ModuleName, TableName: m.naming.TableName, Scenario: types.ScenarioImport, Schemas: schemas, }) if r.Method == http.MethodGet { //下载导入模板 csvWriter = csv.NewWriter(w) rows = make([]string, 0, len(schemas)) for _, row := range schemas { //主键不需要导入 if row.IsPrimaryKey == 1 { continue } rows = append(rows, row.Label) } w.Header().Set("Content-Type", "text/csv") w.Header().Set("Access-Control-Expose-Headers", "Content-Disposition") w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", m.naming.Singular)) err = csvWriter.Write(rows) csvWriter.Flush() return } filename := m.getFilename(domainName, "import", fmt.Sprintf("%s-%d.csv", m.naming.Singular, time.Now().Unix())) if fp, _, err = r.FormFile("file"); err != nil { m.response.Failure(w, types.RequestPayloadInvalid, "upload file not exists", nil) return } defer func() { _ = fp.Close() }() if dst, err = os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err == nil { buf := pool.GetBytes(32 * 1024) _, err = io.CopyBuffer(dst, fp, buf) pool.PutBytes(buf) _ = dst.Close() } else { m.response.Failure(w, types.RequestPayloadInvalid, "move upload file failed", nil) return } qs = r.URL.Query() if qs != nil { extraFields = make(map[string]string) for k, _ := range qs { if strings.HasPrefix(k, "_attr_") { extraFields[strings.TrimPrefix(k, "_attr_")] = qs.Get(k) } } } fast, _ = strconv.ParseBool(qs.Get("__fast")) go m.importInternal(childCtx, domainName, schemas, filename, fast, extraFields) m.response.Success(w, types.ImportResponse{ UID: m.valueLookup("user", w, r), Status: "committed", }) } // newModel 创建一个模型 func newModel(v any, db *gorm.DB, naming types.Naming) *Model { model := &Model{ db: db, naming: naming, response: &httpWriter{}, value: reflect.Indirect(reflect.ValueOf(v)), valueLookup: defaultValueLookup, } model.statement = &gorm.Statement{ DB: model.getDB(), ConnPool: model.getDB().ConnPool, Clauses: map[string]clause.Clause{}, } if err := model.statement.Parse(v); err == nil { if model.statement.Schema.PrimaryFieldDBNames != nil && len(model.statement.Schema.PrimaryFieldDBNames) > 0 { model.primaryKey = model.statement.Schema.PrimaryFieldDBNames[0] } } return model }