// Package cache provides a GORM plugin that serves query results from the
// kos cache when the statement carries a positive cache duration.
package cache

import (
	"encoding/json"
	"fmt"
	"os"
	"strconv"
	"time"

	"git.nobla.cn/golang/kos/pkg/cache"
	xxhash "github.com/cespare/xxhash/v2"
	"gorm.io/gorm"
	"gorm.io/gorm/callbacks"
)

const (
	// DisableCache is not referenced in this file.
	// NOTE(review): presumably an environment/setting name consumed elsewhere — confirm.
	DisableCache = "DISABLE_CACHE"

	// DurationKey is the instance-setting key (read with db.InstanceGet) under
	// which a time.Duration cache lifetime is expected. A missing value, a
	// value of another type, or a non-positive duration disables caching for
	// the statement.
	DurationKey = "gorm:cache_duration"
)

// Cacher is a GORM plugin that replaces the "gorm:query" callback with a
// caching wrapper around the previously registered callback.
type Cacher struct {
	// rawQuery is the query callback that was registered before this plugin
	// took over "gorm:query"; it performs the real database read.
	rawQuery func(db *gorm.DB)
}

// Name returns the plugin name.
func (c *Cacher) Name() string {
	return "gorm:cache"
}

// Initialize remembers the currently registered "gorm:query" callback and
// replaces it with c.Query. It returns the error reported by Replace.
func (c *Cacher) Initialize(db *gorm.DB) error {
	c.rawQuery = db.Callback().Query().Get("gorm:query")
	if c.rawQuery == nil {
		// Get yields nil when no callback with that name is registered; fall
		// back to GORM's stock implementation so Query never calls a nil func.
		c.rawQuery = callbacks.Query
	}
	return db.Callback().Query().Replace("gorm:query", c.Query)
}

// buildCacheKey derives the cache key for the statement: the decimal form of
// the 64-bit xxhash of the built SQL text followed by the "%v" rendering of
// the bound variables.
func (c *Cacher) buildCacheKey(db *gorm.DB) string {
	payload := db.Statement.SQL.String() + fmt.Sprintf("%v", db.Statement.Vars)
	return strconv.FormatUint(xxhash.Sum64String(payload), 10)
}

// getDuration returns the cache lifetime stored under DurationKey, or 0 when
// it is absent or is not a time.Duration.
func (c *Cacher) getDuration(db *gorm.DB) time.Duration {
	v, ok := db.InstanceGet(DurationKey)
	if !ok {
		return 0
	}
	duration, ok := v.(time.Duration)
	if !ok {
		return 0
	}
	return duration
}

// tryLoad attempts to fill db.Statement.Dest from the cache entry stored
// under key. It returns os.ErrNotExist on a cache miss, or the JSON decoding
// error when the cached payload cannot be unmarshalled into Dest.
func (c *Cacher) tryLoad(key string, db *gorm.DB) error {
	buf, ok := cache.Get(db.Statement.Context, key)
	if !ok {
		return os.ErrNotExist
	}
	return json.Unmarshal(buf, db.Statement.Dest)
}

// storeCache JSON-encodes db.Statement.Dest and stores it under key for the
// given duration. It returns the encoding error, if any; nothing is stored
// in that case.
func (c *Cacher) storeCache(key string, db *gorm.DB, duration time.Duration) error {
	buf, err := json.Marshal(db.Statement.Dest)
	if err != nil {
		return err
	}
	cache.SetEx(db.Statement.Context, key, buf, duration)
	return nil
}

// Query is the replacement "gorm:query" callback.
//
// When the statement has no positive cache duration, already carries an
// error, or runs in dry-run mode, it delegates straight to the original
// callback without touching the cache. Otherwise it builds the SQL, tries to
// satisfy the query from the cache, and on a miss runs the original callback
// and stores the result.
//
// NOTE(review): on a cache hit only Statement.Dest is populated;
// db.RowsAffected is not set by this path.
func (c *Cacher) Query(db *gorm.DB) {
	duration := c.getDuration(db)
	// In dry-run mode the stock callback does not execute the statement and
	// leaves db.Error nil, so caching here would store an unpopulated Dest.
	// With a pre-existing error the statement must not be answered at all.
	if duration <= 0 || db.Error != nil || db.DryRun {
		c.rawQuery(db)
		return
	}

	// The SQL text and variables must exist before the key can be computed.
	callbacks.BuildQuerySQL(db)
	if db.Error != nil {
		// Building failed: the SQL may be incomplete, so skip the cache and
		// let the original callback deal with the errored statement.
		c.rawQuery(db)
		return
	}

	cacheKey := c.buildCacheKey(db)
	if err := c.tryLoad(cacheKey, db); err == nil {
		return
	}

	// Cache miss (or undecodable entry): fall back to the real query.
	c.rawQuery(db)
	if db.Error != nil {
		return
	}
	if err := c.storeCache(cacheKey, db, duration); err != nil {
		// AddError returns the accumulated db.Error, which the caller reads
		// from db itself, so the return value is intentionally dropped.
		_ = db.AddError(err)
	}
}

// New returns a Cacher ready to be registered as a GORM plugin.
func New() *Cacher {
	return &Cacher{}
}