From ea7f75be62fed45dcd689e13c0ffbf49738a358a Mon Sep 17 00:00:00 2001 From: Yavolte Date: Thu, 19 Jun 2025 11:43:05 +0800 Subject: [PATCH] add custom gorm log --- pkg/dbdialer/dialer.go | 14 +++++-- pkg/dbdialer/logger.go | 83 ++++++++++++++++++++++++++++++++++++++++++ pkg/dbdialer/types.go | 38 +++++++++++++++++++ 3 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 pkg/dbdialer/logger.go create mode 100644 pkg/dbdialer/types.go diff --git a/pkg/dbdialer/dialer.go b/pkg/dbdialer/dialer.go index 8e72189..f4bc8e7 100644 --- a/pkg/dbdialer/dialer.go +++ b/pkg/dbdialer/dialer.go @@ -10,14 +10,20 @@ import ( ) // Dialer open database -func Dialer(ctx context.Context, driver string, dsn string) (db *gorm.DB, err error) { +func Dialer(ctx context.Context, driver string, dsn string, cbs ...Option) (db *gorm.DB, err error) { + opts := newOptions(cbs...) + + if opts.cfg == nil { + opts.cfg = &gorm.Config{} + } + opts.cfg.Logger = newLogger(opts.log) switch driver { case mysql.DefaultDriverName: - db, err = gorm.Open(mysql.Open(dsn)) + db, err = gorm.Open(mysql.Open(dsn), opts.cfg) case sqlite.DriverName: - db, err = gorm.Open(sqlite.Open(dsn)) + db, err = gorm.Open(sqlite.Open(dsn), opts.cfg) case "postgres": - db, err = gorm.Open(postgres.Open(dsn)) + db, err = gorm.Open(postgres.Open(dsn), opts.cfg) default: err = gorm.ErrNotImplemented } diff --git a/pkg/dbdialer/logger.go b/pkg/dbdialer/logger.go new file mode 100644 index 0000000..31d74d9 --- /dev/null +++ b/pkg/dbdialer/logger.go @@ -0,0 +1,83 @@ +package dbdialer + +import ( + "context" + "errors" + "fmt" + "time" + + logpkg "git.nobla.cn/golang/aeus/pkg/logger" + "gorm.io/gorm/logger" + "gorm.io/gorm/utils" +) + +type Logger struct { + LogLevel logger.LogLevel + SlowThreshold time.Duration + traceStr string + traceErrStr string + logger logpkg.Logger +} + +func (lg *Logger) LogMode(level logger.LogLevel) logger.Interface { + lg.LogLevel = level + return lg +} + +func (lg *Logger) Info(ctx context.Context, s string, i ...interface{}) { + if lg.LogLevel >= logger.Info { + lg.logger.Info(ctx, s, i...) + } +} + +func (lg *Logger) Warn(ctx context.Context, s string, i ...interface{}) { + if lg.LogLevel >= logger.Warn { + lg.logger.Warn(ctx, s, i...) + } +} + +func (lg *Logger) Error(ctx context.Context, s string, i ...interface{}) { + if lg.LogLevel >= logger.Error { + lg.logger.Error(ctx, s, i...) + } +} + +func (lg *Logger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if lg.LogLevel <= logger.Silent { + return + } + elapsed := time.Since(begin) + switch { + case err != nil && lg.LogLevel >= logger.Error && (!errors.Is(err, logger.ErrRecordNotFound)): + sql, rows := fc() + if rows == -1 { + lg.Warn(ctx, lg.traceErrStr, sql, err, float64(elapsed.Nanoseconds())/1e6, "-", utils.FileWithLineNum()) + } else { + lg.Warn(ctx, lg.traceErrStr, sql, err, float64(elapsed.Nanoseconds())/1e6, rows, utils.FileWithLineNum()) + } + case elapsed > lg.SlowThreshold && lg.SlowThreshold != 0 && lg.LogLevel >= logger.Warn: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", lg.SlowThreshold) + if rows == -1 { + lg.Warn(ctx, lg.traceErrStr, sql, slowLog, float64(elapsed.Nanoseconds())/1e6, "-", utils.FileWithLineNum()) + } else { + lg.Warn(ctx, lg.traceErrStr, sql, slowLog, float64(elapsed.Nanoseconds())/1e6, rows, utils.FileWithLineNum()) + } + case lg.LogLevel == logger.Info: + sql, rows := fc() + if rows == -1 { + lg.Info(ctx, lg.traceStr, sql, float64(elapsed.Nanoseconds())/1e6, "-", utils.FileWithLineNum()) + } else { + lg.Info(ctx, lg.traceStr, sql, float64(elapsed.Nanoseconds())/1e6, rows, utils.FileWithLineNum()) + } + } +} + +func newLogger(log logpkg.Logger) *Logger { + return &Logger{ + logger: log, + SlowThreshold: time.Second * 10, + traceStr: "%s [%.3fms] [rows:%v] in %s", + traceErrStr: "%s [%s] [%.3fms] [rows:%v] in %s", + } +} diff --git a/pkg/dbdialer/types.go b/pkg/dbdialer/types.go new file mode 100644 index 0000000..8b951b7 --- /dev/null +++ b/pkg/dbdialer/types.go @@ -0,0 +1,38 @@ +package dbdialer + +import ( + "git.nobla.cn/golang/aeus/pkg/logger" + "gorm.io/gorm" +) + +type ( + options struct { + log logger.Logger + cfg *gorm.Config + } + + Option func(o *options) +) + +func WithConfig(cfg *gorm.Config) Option { + return func(o *options) { + o.cfg = cfg + } +} + +func WithLogger(log logger.Logger) Option { + return func(o *options) { + o.log = log + } +} + +func newOptions(opts ...Option) *options { + o := &options{ + cfg: &gorm.Config{}, + log: logger.Default(), + } + for _, opt := range opts { + opt(o) + } + return o +}