package dbdialer

import (
	"context"
	"errors"
	"fmt"
	"time"

	logpkg "git.nobla.cn/golang/aeus/pkg/logger"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/utils"
)

// Compile-time check that *Logger satisfies gorm's logger.Interface.
var _ logger.Interface = (*Logger)(nil)

// Logger adapts a logpkg.Logger to gorm's logger.Interface.
type Logger struct {
	// LogLevel gates output: Info, Warn and Error messages are emitted only
	// when LogLevel is at least the corresponding gorm level, and Trace emits
	// nothing when LogLevel is at or below logger.Silent.
	LogLevel logger.LogLevel
	// SlowThreshold is the duration above which a statement is reported as
	// slow SQL at warn level. Zero disables slow-statement reporting.
	SlowThreshold time.Duration

	// traceStr formats Info-level traces: sql, elapsed ms, rows, caller.
	traceStr string
	// traceErrStr formats error and slow traces: sql, error or slow note,
	// elapsed ms, rows, caller.
	traceErrStr string
	logger      logpkg.Logger
}

// LogMode returns a copy of lg that logs at the given level.
//
// The receiver itself is not modified, so per-session overrides made by gorm
// (for example db.Debug(), which calls LogMode(logger.Info)) do not leak into
// the shared logger. Callers must use the returned value.
func (lg *Logger) LogMode(level logger.LogLevel) logger.Interface {
	clone := *lg
	clone.LogLevel = level
	return &clone
}

// Info forwards s and i to the wrapped logger's Info when LogLevel >= logger.Info.
func (lg *Logger) Info(ctx context.Context, s string, i ...interface{}) {
	if lg.LogLevel >= logger.Info {
		lg.logger.Info(ctx, s, i...)
	}
}

// Warn forwards s and i to the wrapped logger's Warn when LogLevel >= logger.Warn.
func (lg *Logger) Warn(ctx context.Context, s string, i ...interface{}) {
	if lg.LogLevel >= logger.Warn {
		lg.logger.Warn(ctx, s, i...)
	}
}

// Error forwards s and i to the wrapped logger's Error when LogLevel >= logger.Error.
func (lg *Logger) Error(ctx context.Context, s string, i ...interface{}) {
	if lg.LogLevel >= logger.Error {
		lg.logger.Error(ctx, s, i...)
	}
}

// Trace logs one executed SQL statement. fc is only called when a line is
// actually going to be written.
//
// The first matching rule wins:
//   - err is non-nil and not logger.ErrRecordNotFound, and LogLevel >= Error:
//     logged at error level with the error text;
//   - SlowThreshold is non-zero, elapsed time exceeds it, and LogLevel >= Warn:
//     logged at warn level with a "SLOW SQL" note;
//   - LogLevel == Info: every statement is logged at info level.
func (lg *Logger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
	if lg.LogLevel <= logger.Silent {
		return
	}

	elapsed := time.Since(begin)
	elapsedMs := float64(elapsed.Nanoseconds()) / 1e6

	// NOTE: utils.FileWithLineNum() is deliberately called directly from Trace
	// in every branch. It inspects the call stack at a fixed depth, so moving
	// it into a helper would presumably report this file instead of the
	// application's call site.
	switch {
	case err != nil && lg.LogLevel >= logger.Error && !errors.Is(err, logger.ErrRecordNotFound):
		sql, rows := fc()
		// Error (not Warn) so the message is emitted at LogLevel == Error
		// and carries error severity.
		lg.Error(ctx, lg.traceErrStr, sql, err, elapsedMs, traceRowsArg(rows), utils.FileWithLineNum())
	case lg.SlowThreshold != 0 && elapsed > lg.SlowThreshold && lg.LogLevel >= logger.Warn:
		sql, rows := fc()
		slowLog := fmt.Sprintf("SLOW SQL >= %v", lg.SlowThreshold)
		lg.Warn(ctx, lg.traceErrStr, sql, slowLog, elapsedMs, traceRowsArg(rows), utils.FileWithLineNum())
	case lg.LogLevel == logger.Info:
		sql, rows := fc()
		lg.Info(ctx, lg.traceStr, sql, elapsedMs, traceRowsArg(rows), utils.FileWithLineNum())
	}
}

// traceRowsArg renders an affected-rows count for the trace format strings.
// It substitutes "-" for -1, which gorm presumably uses when the count is not
// available.
func traceRowsArg(rows int64) interface{} {
	if rows == -1 {
		return "-"
	}
	return rows
}

// newLogger wraps log in a Logger with a 10-second slow-statement threshold.
//
// LogLevel is left at its zero value, which is at or below logger.Silent, so
// nothing is emitted until a level is chosen via LogMode (use its return value).
func newLogger(log logpkg.Logger) *Logger {
	return &Logger{
		logger:        log,
		SlowThreshold: 10 * time.Second,
		traceStr:      "%s [%.3fms] [rows:%v] in %s",
		traceErrStr:   "%s [%s] [%.3fms] [rows:%v] in %s",
	}
}