// moto/common/db/db.go
//
// 128 lines
// 3.0 KiB
// Go

package db
import (
	"context"
	"fmt"
	"reflect"
	"time"

	"git.nobla.cn/golang/moto/config"
	"git.nobla.cn/golang/rest/types"
	"github.com/go-sql-driver/mysql"
	lru "github.com/hashicorp/golang-lru/v2"
	mysqlDriver "gorm.io/driver/mysql"
	"gorm.io/gorm"
)
// Package-level state shared by the helpers in this file.
var (
	// db is the gorm handle stored by Init and handed out by DB/WithContext.
	db *gorm.DB

	// cacheInstance backs TryCache: an LRU cache holding up to 64 entries.
	// lru.New is documented to fail only for a non-positive size, so the
	// error for this constant, positive size is intentionally discarded.
	cacheInstance, _ = lru.New[string, *cacheEntry](64)
)
// WithDepend returns a CacheOption that attaches a dependency query to a
// TryCache call. s is the SQL text and args are its bind arguments; TryCache
// runs the query and compares its result with the one stored alongside the
// cached value to decide whether that value is still current.
func WithDepend(s string, args ...any) CacheOption {
	apply := func(opts *CacheOptions) {
		opts.dependArgs = args
		opts.dependSQL = s
	}
	return apply
}
// TryCache returns the value cached under key, calling f to load it when no
// usable cached copy exists.
//
// Without a WithDepend option, a cached value is served as-is until the LRU
// evicts it. With WithDepend, the dependency query is re-run at most once per
// 500ms for a given entry. Its result is compared (reflect.DeepEqual) with the
// result recorded when the value was cached. On a mismatch the entry is
// dropped and reloaded through f.
//
// Errors returned by f are passed to the caller and nothing is cached.
// Errors from the dependency query are not returned: the loaded value is still
// returned, it is just not cached, so a later call retries.
func TryCache(ctx context.Context, key string, f CachingFunc, cbs ...CacheOption) (value any, err error) {
	// Entries validated more recently than this skip the dependency query.
	const dependCheckInterval = 500 * time.Millisecond

	opts := &CacheOptions{}
	for _, cb := range cbs {
		cb(opts)
	}
	hasDepend := opts.dependSQL != ""

	var (
		dependValue    any
		hasDependValue bool
	)

	// Serve from the cache when the entry is (still) valid.
	if entry, ok := cacheInstance.Get(key); ok {
		if !hasDepend || time.Since(entry.lastChecked) < dependCheckInterval {
			return entry.storeValue, nil
		}
		if WithContext(ctx).Raw(opts.dependSQL, opts.dependArgs...).Scan(&dependValue).Error == nil {
			hasDependValue = true
			if reflect.DeepEqual(entry.compareValue, dependValue) {
				// Store an updated copy instead of writing entry.lastChecked in
				// place: the entry pointer is shared with concurrent callers, so
				// an in-place write would be a data race.
				refreshed := *entry
				refreshed.lastChecked = time.Now()
				cacheInstance.Add(key, &refreshed)
				return entry.storeValue, nil
			}
			cacheInstance.Remove(key)
		}
	}

	// Snapshot the dependency BEFORE loading the data. Reading it afterwards
	// could pair stale data with a newer dependency value, and that stale data
	// would then pass every later check.
	if hasDepend && !hasDependValue {
		dependValue = nil
		hasDependValue = WithContext(ctx).Raw(opts.dependSQL, opts.dependArgs...).Scan(&dependValue).Error == nil
	}

	if value, err = f(WithContext(ctx)); err != nil {
		return nil, err
	}

	// Cache only when no freshness check is needed or one can be done later.
	if !hasDepend || hasDependValue {
		now := time.Now()
		cacheInstance.Add(key, &cacheEntry{
			compareValue: dependValue,
			storeValue:   value,
			createdAt:    now,
			lastChecked:  now,
		})
	}
	return value, nil
}
// Init sets up the package-level database handle. It connects to the MySQL
// server described by cfg, registers plugins on the gorm handle, and
// auto-migrates types.Schema. Only then does it store the handle for DB and
// WithContext.
//
// The stored handle is bound to ctx via WithContext, so ctx is the default
// context for queries issued through DB(). Presumably callers pass a
// process-lifetime context here; confirm at the call site.
//
// The package-level handle is assigned only after every step succeeds. On
// failure it keeps its previous value and the newly opened connection pool
// is closed, so a failed Init neither leaks connections nor publishes a
// half-initialized handle.
//
// NOTE(review): AllowOldPasswords enables MySQL's legacy password scheme;
// consider disabling it unless a legacy server requires it.
func Init(ctx context.Context, cfg config.Database, plugins ...gorm.Plugin) (err error) {
	dbCfg := &mysql.Config{
		Net:                  "tcp",
		Addr:                 cfg.Address,
		User:                 cfg.Username,
		Passwd:               cfg.Password,
		DBName:               cfg.Database,
		AllowNativePasswords: true,
		AllowOldPasswords:    true,
		Collation:            "utf8mb4_unicode_ci",
		Loc:                  time.Local,
		CheckConnLiveness:    true,
		Params:               map[string]string{"charset": "utf8mb4"},
		ParseTime:            true,
		MaxAllowedPacket:     4 << 20, // 4 MiB
	}

	// Build into a local so the global is never left pointing at a handle
	// whose setup failed part-way.
	var conn *gorm.DB
	if conn, err = gorm.Open(mysqlDriver.Open(dbCfg.FormatDSN())); err != nil {
		return fmt.Errorf("open database %q: %w", cfg.Address, err)
	}
	// Release the connection pool if any later step fails.
	defer func() {
		if err == nil {
			return
		}
		if sqlDB, dbErr := conn.DB(); dbErr == nil {
			_ = sqlDB.Close() // best effort; the setup error is what is reported
		}
	}()

	conn = conn.WithContext(ctx)
	for _, plugin := range plugins {
		if err = conn.Use(plugin); err != nil {
			return fmt.Errorf("register gorm plugin %q: %w", plugin.Name(), err)
		}
	}
	if err = conn.AutoMigrate(&types.Schema{}); err != nil {
		return fmt.Errorf("auto-migrate schema table: %w", err)
	}

	db = conn
	return nil
}
// DB returns the package-level gorm handle stored by Init.
//
// It panics when no handle has been stored, typically because Init has not
// been called yet.
func DB() *gorm.DB {
	if db != nil {
		return db
	}
	panic("database component not initialized")
}
// WithContext returns a session of the handle from DB bound to ctx, so queries
// issued through it carry ctx. It panics under the same condition as DB.
func WithContext(ctx context.Context) *gorm.DB {
	handle := DB()
	return handle.WithContext(ctx)
}