From 5264ee005e5613b01338db060bfa2a540fc636ae Mon Sep 17 00:00:00 2001 From: Yavolte Date: Wed, 11 Jun 2025 11:36:56 +0800 Subject: [PATCH] optimization module init --- options.go | 35 +++++++++++++++++++++++++---------- rest.go | 37 +++++++++++++++++++++++++------------ 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/options.go b/options.go index 7746e20..ca0cf2f 100644 --- a/options.go +++ b/options.go @@ -1,52 +1,67 @@ package rest -import "git.nobla.cn/golang/rest/types" +import ( + "git.nobla.cn/golang/rest/types" + "gorm.io/gorm" +) -type Options struct { +type options struct { urlPrefix string moduleName string disableDomain bool + db *gorm.DB router types.HttpRouter writer types.HttpWriter formatter *Formatter - dirname string //文件目录 } -type Option func(o *Options) +type Option func(o *options) +// WithDB 设置DB +func WithDB(db *gorm.DB) Option { + return func(o *options) { + o.db = db + } +} + +// WithUriPrefix 模块前缀 func WithUriPrefix(s string) Option { - return func(o *Options) { + return func(o *options) { o.urlPrefix = s } } +// WithModuleName 模块名称 func WithModuleName(s string) Option { - return func(o *Options) { + return func(o *options) { o.moduleName = s } } // WithoutDomain 禁用域 func WithoutDomain() Option { - return func(o *Options) { + return func(o *options) { o.disableDomain = true } } +// WithHttpRouter 设置HttpRouter func WithHttpRouter(s types.HttpRouter) Option { - return func(o *Options) { + return func(o *options) { o.router = s } } +// WithHttpWriter 配置HttpWriter func WithHttpWriter(s types.HttpWriter) Option { - return func(o *Options) { + return func(o *options) { o.writer = s } } +// WithFormatter 配置Formatter func WithFormatter(s *Formatter) Option { - return func(o *Options) { + return func(o *options) { o.formatter = s } } diff --git a/rest.go b/rest.go index ad52819..d160539 100644 --- a/rest.go +++ b/rest.go @@ -21,7 +21,7 @@ import ( var ( modelEntities []*Model - httpRouter types.HttpRouter + globalOpts *options hookMgr *hookManager timeKind = reflect.TypeOf(time.Time{}).Kind() timePtrKind = reflect.TypeOf(&time.Time{}).Kind() @@ -34,6 +34,7 @@ var ( ) func init() { + globalOpts = &options{} hookMgr = &hookManager{} modelEntities = make([]*Model, 0) } @@ -479,27 +480,41 @@ func autoMigrate(ctx context.Context, db *gorm.DB, module string, model any) (na return } -// SetHttpRouter 设置HTTP路由 -func SetHttpRouter(router types.HttpRouter) { - httpRouter = router +// Init 初始化 +func Init(cbs ...Option) (err error) { + for _, cb := range cbs { + cb(globalOpts) + } + if globalOpts.db != nil { + err = globalOpts.db.AutoMigrate(&types.Schema{}) + } + return } // AutoMigrate 自动合并表的schema -func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (err error) { +func AutoMigrate(ctx context.Context, model any, cbs ...Option) (err error) { var ( - opts *Options + opts *options table string router types.HttpRouter ) - opts = &Options{} + opts = &options{ + db: globalOpts.db, + router: globalOpts.router, + writer: globalOpts.writer, + formatter: globalOpts.formatter, + moduleName: globalOpts.moduleName, + urlPrefix: globalOpts.urlPrefix, + disableDomain: globalOpts.disableDomain, + } for _, cb := range cbs { cb(opts) } - if table, err = autoMigrate(ctx, db, opts.moduleName, model); err != nil { + if table, err = autoMigrate(ctx, opts.db, opts.moduleName, model); err != nil { return } //路由模块处理 - modelValue := newModel(model, db, types.Naming{ + modelValue := newModel(model, opts.db, types.Naming{ Pluralize: inflector.Pluralize(table), Singular: inflector.Singularize(table), ModuleName: opts.moduleName, @@ -510,9 +525,6 @@ func AutoMigrate(ctx context.Context, db *gorm.DB, model any, cbs ...Option) (er if opts.router != nil { router = opts.router } - if router == nil && httpRouter != nil { - router = httpRouter - } if opts.urlPrefix != "" { modelValue.urlPrefix = opts.urlPrefix } @@ -699,6 +711,7 @@ func SetFieldValue(stmt *gorm.Statement, refValue reflect.Value, column string, } } +// SafeSetFileValue 安全设置模型某个字段的值 func SafeSetFileValue(stmt *gorm.Statement, refValue reflect.Value, column string, value any) { var ( idx = -1