diff --git a/formatter.go b/formatter.go index 818caf0..82f3873 100644 --- a/formatter.go +++ b/formatter.go @@ -54,7 +54,7 @@ func (f *Formatter) FormatRole(ctx context.Context, value, model any, scm *types } func (f *Formatter) FormatMenu(ctx context.Context, value, model any, scm *types.Schema) any { - if values, err := f.menu.GetLabels(ctx, utils.GetDomainFromContext(ctx)); err == nil { + if values, err := f.menu.GetLabels(ctx); err == nil { for _, row := range values { if row.Value == value { return row.Label diff --git a/internal/logic/menu.go b/internal/logic/menu.go index e0fb5a3..1c045b8 100644 --- a/internal/logic/menu.go +++ b/internal/logic/menu.go @@ -2,7 +2,6 @@ package logic import ( "context" - "fmt" "git.nobla.cn/golang/aeus-admin/models" "git.nobla.cn/golang/aeus-admin/pkg/dbcache" @@ -18,14 +17,10 @@ type Menu struct { sqlDependency *dbcache.SqlDependency } -func (u *Menu) GetMenus(ctx context.Context, domainName string) (values []*models.Menu, err error) { - return dbcache.TryCache(ctx, fmt.Sprintf("menus:%s", domainName), func(tx *gorm.DB) ([]*models.Menu, error) { +func (u *Menu) GetMenus(ctx context.Context) (values []*models.Menu, err error) { + return dbcache.TryCache(ctx, "menus", func(tx *gorm.DB) ([]*models.Menu, error) { var items []*models.Menu - if domainName == "" { - err = tx.Order("`position`,`id` ASC").Find(&items).Error - } else { - err = tx.Where("`domain`=?", domainName).Order("`position`,`id` ASC").Find(&items).Error - } + err = tx.Order("`position`,`id` ASC").Find(&items).Error return items, err }, dbcache.WithDB(u.db), @@ -34,9 +29,9 @@ func (u *Menu) GetMenus(ctx context.Context, domainName string) (values []*model ) } -func (u *Menu) GetLabels(ctx context.Context, domainName string) (values []*types.TypeValue[string], err error) { - return dbcache.TryCache(ctx, fmt.Sprintf("menu:labels:%s", domainName), func(tx *gorm.DB) ([]*types.TypeValue[string], error) { - return rest.ModelTypes[string](ctx, tx, &models.Menu{}, domainName, "label", "name") +func (u *Menu) GetLabels(ctx context.Context) (values []*types.TypeValue[string], err error) { + return dbcache.TryCache(ctx, "menu:labels", func(tx *gorm.DB) ([]*types.TypeValue[string], error) { + return rest.ModelTypes[string](ctx, tx, &models.Menu{}, "", "label", "name") }, dbcache.WithDB(u.db), dbcache.WithCache(u.cache), diff --git a/migrate/migrate.go b/migrate/migrate.go index 422d8a3..d6246b7 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -11,7 +11,7 @@ import ( func Menu(db *gorm.DB, datas ...*models.Menu) (err error) { tx := db.Begin() for _, model := range datas { - if err = tx.Where("name = ?", model.Name).First(model).Error; err != nil { + if err = tx.Where("`name` = ?", model.Name).First(model).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { if err = tx.Create(model).Error; err != nil { tx.Rollback() @@ -39,7 +39,7 @@ func Permission(db *gorm.DB, menuName string, permission string, label string) ( } // Default 合并初始化数据集 -func Default(db *gorm.DB) (err error) { +func Default(db *gorm.DB, domain string) (err error) { var ( n int64 ) @@ -49,6 +49,9 @@ func Default(db *gorm.DB) (err error) { } } if db.Model(&models.Role{}).Count(&n); n == 0 { + for i := range defaultRoles { + defaultRoles[i].Domain = domain + } db.Create(defaultRoles) permissions := make([]*models.Permission, 0) db.Find(&permissions) @@ -65,9 +68,15 @@ func Default(db *gorm.DB) (err error) { } if db.Model(&models.Department{}).Count(&n); n == 0 { + for i := range defaultDepartments { + defaultDepartments[i].Domain = domain + } db.Create(defaultDepartments) } if db.Model(&models.User{}).Count(&n); n == 0 { + for i := range defaultUsers { + defaultUsers[i].Domain = domain + } db.Create(defaultUsers) } return diff --git a/server.go b/server.go index 78fd200..26dd9b8 100644 --- a/server.go +++ b/server.go @@ -226,7 +226,7 @@ func initREST(ctx context.Context, o *options) (err error) { if o.apiPrefix != "" { opts = append(opts, rest.WithUriPrefix(o.apiPrefix)) } - if o.disableDomain { + if !o.enableDomain { opts = append(opts, rest.WithoutDomain()) } if err = rest.Init(opts...); err != nil { @@ -540,7 +540,7 @@ func Init(ctx context.Context, cbs ...Option) (err error) { registerRESTRoute(opts.domain, opts.db, opts.httpServer) } if !opts.disableDefault { - if err = migrate.Default(opts.db); err != nil { + if err = migrate.Default(opts.db, opts.domain); err != nil { return } } diff --git a/service/auth.go b/service/auth.go index f6a3e97..d4ca7c1 100644 --- a/service/auth.go +++ b/service/auth.go @@ -28,7 +28,7 @@ type ( turnstileValidateUrl string turnstileSiteKey string enableRefreshToken bool - disableDomain bool + enableDomain bool } turnstileRequest struct { @@ -61,9 +61,9 @@ func WithAuthDB(db *gorm.DB) AuthOption { } } -func WithoutDomain() AuthOption { +func WithAuthDomain() AuthOption { return func(o *authOptions) { - o.disableDomain = true + o.enableDomain = true } } @@ -73,7 +73,7 @@ func WithAuthCache(cache cache.Cache) AuthOption { } } -func WithRefreshToken() AuthOption { +func WithAuthRefreshToken() AuthOption { return func(o *authOptions) { o.enableRefreshToken = true } @@ -166,7 +166,7 @@ func (s *AuthService) Login(ctx context.Context, req *pb.LoginRequest) (res *pb. IssuedAt: time.Now().Unix(), ExpirationAt: time.Now().Add(time.Second * time.Duration(s.opts.ttl)).Unix(), } - if !s.opts.disableDomain { + if s.opts.enableDomain { claims.Domain = model.Domain } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) @@ -185,7 +185,7 @@ func (s *AuthService) Login(ctx context.Context, req *pb.LoginRequest) (res *pb. IssuedAt: time.Now().Unix(), ExpirationAt: time.Now().Add(time.Hour * 48).Unix(), } - if !s.opts.disableDomain { + if s.opts.enableDomain { refreshClaims.Domain = model.Domain } refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims) diff --git a/service/menu.go b/service/menu.go index 235dba9..98ae305 100644 --- a/service/menu.go +++ b/service/menu.go @@ -6,7 +6,6 @@ import ( "git.nobla.cn/golang/aeus-admin/internal/logic" "git.nobla.cn/golang/aeus-admin/models" "git.nobla.cn/golang/aeus-admin/pb" - "git.nobla.cn/golang/aeus-admin/utils" "git.nobla.cn/golang/aeus/pkg/cache" "gorm.io/gorm" ) @@ -74,8 +73,7 @@ func (s *MenuService) GetMenus(ctx context.Context, req *pb.GetMenuRequest) (res var ( items []*models.Menu ) - domainName := utils.GetDomainFromContext(ctx) - if items, err = s.logic.GetMenus(ctx, domainName); err != nil { + if items, err = s.logic.GetMenus(ctx); err != nil { return } res = &pb.GetMenuResponse{ @@ -88,8 +86,7 @@ func (s *MenuService) GetMenuLevelLabels(ctx context.Context, req *pb.GetMenuLev var ( items []*models.Menu ) - domainName := utils.GetDomainFromContext(ctx) - if items, err = s.logic.GetMenus(ctx, domainName); err != nil { + if items, err = s.logic.GetMenus(ctx); err != nil { return } res = &pb.GetMenuLevelLabelsResponse{ diff --git a/service/user.go b/service/user.go index 5323497..adf4492 100644 --- a/service/user.go +++ b/service/user.go @@ -123,7 +123,7 @@ func (s *UserService) GetMenus(ctx context.Context, req *pb.GetUserMenuRequest) if err = tx.Where("`uid`=? AND `domain`=?", uid, domainName).First(userModel).Error; err != nil { return nil, err } - if menus, err = s.menu.GetMenus(ctx, domainName); err != nil { + if menus, err = s.menu.GetMenus(ctx); err != nil { return nil, err } roleName := userModel.Role @@ -280,9 +280,15 @@ func (s *UserService) DepartmentUserNested(ctx context.Context) []*types.NestedV func (s *UserService) GetUserTags(ctx context.Context, req *pb.GetUserTagRequest) (res *pb.GetUserTagResponse, err error) { res = &pb.GetUserTagResponse{} - res.Data, err = dbcache.TryCache(ctx, fmt.Sprintf("user:tags"), func(tx *gorm.DB) ([]*pb.LabelValue, error) { + domainName := utils.GetDomainFromContext(ctx) + res.Data, err = dbcache.TryCache(ctx, fmt.Sprintf("user:tags:%s", domainName), func(tx *gorm.DB) ([]*pb.LabelValue, error) { values := make([]*models.User, 0) - if err = tx.Select("DISTINCT(`tag`) AS `tag`").Find(&values).Error; err == nil { + if domainName == "" { + err = tx.Select("DISTINCT(`tag`) AS `tag`").Find(&values).Error + } else { + err = tx.Select("DISTINCT(`tag`) AS `tag`").Where("`domain`=?", domainName).Find(&values).Error + } + if err == nil { items := make([]*pb.LabelValue, 0, len(values)) for _, v := range values { if v.Tag == "" { diff --git a/types.go b/types.go index fc382ab..4ab3a60 100644 --- a/types.go +++ b/types.go @@ -34,7 +34,7 @@ type ( disableModels bool httpServer *http.Server restOpts []rest.Option - disableDomain bool + enableDomain bool } Option func(*options) @@ -74,9 +74,10 @@ func WithCache(cache cache.Cache) Option { } } -func WithoutDomain() Option { +func WithDomain(domain string) Option { return func(o *options) { - o.disableDomain = true + o.enableDomain = true + o.domain = domain } }