add permission checker

This commit is contained in:
Yavolte 2025-06-11 17:43:22 +08:00
parent 5264ee005e
commit 43b48e1469
5 changed files with 167 additions and 36 deletions

View File

@ -32,6 +32,7 @@ type Model struct {
primaryKey string //主键
urlPrefix string //url前缀
disableDomain bool //禁用域
permissionChecker types.PermissionChecker //权限检查
schemaLookup types.SchemaLookupFunc //获取schema的函数
valueLookup types.ValueLookupFunc //查看域
statement *gorm.Statement //字段声明
@ -39,6 +40,7 @@ type Model struct {
response types.HttpWriter //HTTP响应
hookMgr *hookManager //钩子管理器
dirname string //存放文件目录
scenarios []string //场景
}
var (
@ -65,8 +67,19 @@ func (m *Model) getHook() *hookManager {
// hasScenario 判断是否有该场景
func (m *Model) hasScenario(s string) bool {
if len(m.scenarios) == 0 {
return true
}
return slices.Contains(m.scenarios, s)
}
// hasPermission 判断是否有权限
func (m *Model) hasPermission(ctx context.Context, s string) (err error) {
if m.permissionChecker != nil {
return m.permissionChecker.CheckPermission(ctx, m.Permission(s))
}
return nil
}
// setValue 设置字段的值
func (m *Model) setValue(refValue reflect.Value, column string, value any) {
@ -158,7 +171,7 @@ func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Report
modelType = modelType.Elem()
}
columns := make([]string, 0)
for i := 0; i < modelType.NumField(); i++ {
for i := range modelType.NumField() {
field := modelType.Field(i)
scenarios := field.Tag.Get("scenarios")
if !hasToken(types.ScenarioList, scenarios) {
@ -227,6 +240,30 @@ func (m *Model) Fields() []*schema.Field {
return m.statement.Schema.Fields
}
// Permission 获取权限
// 权限示例: "module.model.scenario"
// 比如: organize:user:list, organize:user:create
func (m *Model) Permission(scenario string) string {
ss := make([]string, 4)
switch scenario {
case types.ScenarioList:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, "list")
case types.ScenarioView:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, "get")
case types.ScenarioCreate:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, "create")
case types.ScenarioUpdate:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, "update")
case types.ScenarioDelete:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, "delete")
case types.ScenarioExport:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, "export")
case types.ScenarioImport:
ss = append(ss, m.naming.ModuleName, m.naming.Singular, "import")
}
return strings.Join(ss, ":")
}
// Uri 获取请求的uri
func (m *Model) Uri(scenario string) string {
ss := make([]string, 4)
@ -274,6 +311,7 @@ func (m *Model) Method(scenario string) string {
// Search 实现通过HTTP方法查找数据
func (m *Model) Search(w http.ResponseWriter, r *http.Request) {
var (
ok bool
err error
@ -292,6 +330,14 @@ func (m *Model) Search(w http.ResponseWriter, r *http.Request) {
reporter types.Reporter
namerTable tableNamer
)
if !m.hasScenario(types.ScenarioList) {
m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil)
return
}
if err = m.hasPermission(r.Context(), types.ScenarioList); err != nil {
m.response.Failure(w, types.RequestDenied, err.Error(), nil)
return
}
qs = r.URL.Query()
page, _ = strconv.Atoi(qs.Get("page"))
pageSize, _ = strconv.Atoi(qs.Get("pagesize"))
@ -392,6 +438,14 @@ func (m *Model) Create(w http.ResponseWriter, r *http.Request) {
domainName string
modelValue reflect.Value
)
if !m.hasScenario(types.ScenarioCreate) {
m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil)
return
}
if err = m.hasPermission(r.Context(), types.ScenarioCreate); err != nil {
m.response.Failure(w, types.RequestDenied, err.Error(), nil)
return
}
modelValue = reflect.New(m.value.Type())
model = modelValue.Interface()
if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil {
@ -473,6 +527,14 @@ func (m *Model) Update(w http.ResponseWriter, r *http.Request) {
modelValue reflect.Value
oldValues map[string]any
)
if !m.hasScenario(types.ScenarioUpdate) {
m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil)
return
}
if err = m.hasPermission(r.Context(), types.ScenarioUpdate); err != nil {
m.response.Failure(w, types.RequestDenied, err.Error(), nil)
return
}
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
modelValue = reflect.New(m.value.Type())
model = modelValue.Interface()
@ -564,6 +626,14 @@ func (m *Model) Delete(w http.ResponseWriter, r *http.Request) {
model any
modelValue reflect.Value
)
if !m.hasScenario(types.ScenarioDelete) {
m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil)
return
}
if err = m.hasPermission(r.Context(), types.ScenarioDelete); err != nil {
m.response.Failure(w, types.RequestDenied, err.Error(), nil)
return
}
idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r)
modelValue = reflect.New(m.value.Type())
model = modelValue.Interface()
@ -619,6 +689,14 @@ func (m *Model) View(w http.ResponseWriter, r *http.Request) {
scenario string
domainName string
)
if !m.hasScenario(types.ScenarioView) {
m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil)
return
}
if err = m.hasPermission(r.Context(), types.ScenarioView); err != nil {
m.response.Failure(w, types.RequestDenied, err.Error(), nil)
return
}
qs = r.URL.Query()
idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r)
modelValue = reflect.New(m.value.Type())
@ -657,8 +735,12 @@ func (m *Model) Export(w http.ResponseWriter, r *http.Request) {
fp *os.File
modelValue reflect.Value
)
if !m.hasScenario(types.ScenarioList) {
m.response.Failure(w, types.RequestDenied, "request denied", nil)
if !m.hasScenario(types.ScenarioExport) {
m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil)
return
}
if err = m.hasPermission(r.Context(), types.ScenarioExport); err != nil {
m.response.Failure(w, types.RequestDenied, err.Error(), nil)
return
}
domainName = m.valueLookup(types.FieldDomain, w, r)
@ -899,6 +981,14 @@ func (m *Model) Import(w http.ResponseWriter, r *http.Request) {
qs url.Values
extraFields map[string]string
)
if !m.hasScenario(types.ScenarioImport) {
m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil)
return
}
if err = m.hasPermission(r.Context(), types.ScenarioImport); err != nil {
m.response.Failure(w, types.RequestDenied, err.Error(), nil)
return
}
domainName = m.valueLookup(types.FieldDomain, w, r)
if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil {
m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil)

View File

@ -12,11 +12,27 @@ type options struct {
db *gorm.DB
router types.HttpRouter
writer types.HttpWriter
permissionChecker types.PermissionChecker
formatter *Formatter
resourceDirectory string
}
type Option func(o *options)
func (o *options) Clone() *options {
return &options{
urlPrefix: o.urlPrefix,
moduleName: o.moduleName,
disableDomain: o.disableDomain,
db: o.db,
router: o.router,
writer: o.writer,
permissionChecker: o.permissionChecker,
formatter: o.formatter,
resourceDirectory: o.resourceDirectory,
}
}
// WithDB 设置DB
func WithDB(db *gorm.DB) Option {
return func(o *options) {
@ -65,3 +81,17 @@ func WithFormatter(s *Formatter) Option {
o.formatter = s
}
}
// WithPermissionChecker 配置PermissionChecker
func WithPermissionChecker(s types.PermissionChecker) Option {
return func(o *options) {
o.permissionChecker = s
}
}
// WithResourceDirectory 配置资源目录
func WithResourceDirectory(s string) Option {
return func(o *options) {
o.resourceDirectory = s
}
}

18
rest.go
View File

@ -492,21 +492,13 @@ func Init(cbs ...Option) (err error) {
}
// AutoMigrate 自动合并表的schema
func AutoMigrate(ctx context.Context, model any, cbs ...Option) (err error) {
func AutoMigrate(ctx context.Context, model any, cbs ...Option) (modelValue *Model, err error) {
var (
opts *options
table string
router types.HttpRouter
)
opts = &options{
db: globalOpts.db,
router: globalOpts.router,
writer: globalOpts.writer,
formatter: globalOpts.formatter,
moduleName: globalOpts.moduleName,
urlPrefix: globalOpts.urlPrefix,
disableDomain: globalOpts.disableDomain,
}
opts = globalOpts.Clone()
for _, cb := range cbs {
cb(opts)
}
@ -514,14 +506,18 @@ func AutoMigrate(ctx context.Context, model any, cbs ...Option) (err error) {
return
}
//路由模块处理
modelValue := newModel(model, opts.db, types.Naming{
modelValue = newModel(model, opts.db, types.Naming{
Pluralize: inflector.Pluralize(table),
Singular: inflector.Singularize(table),
ModuleName: opts.moduleName,
TableName: table,
})
if scenarioModel, ok := model.(types.ScenarioModel); ok {
modelValue.scenarios = scenarioModel.Scenario()
}
modelValue.hookMgr = hookMgr
modelValue.schemaLookup = VisibleSchemas
modelValue.permissionChecker = opts.permissionChecker
if opts.router != nil {
router = opts.router
}

View File

@ -21,6 +21,10 @@ var (
allowMethods = []string{http.MethodPut, http.MethodPost}
)
var (
scenarioNotAllow = "request not allowed"
)
type (
httpWriter struct {
}

View File

@ -2,9 +2,10 @@ package types
import (
"context"
"gorm.io/gorm"
"net/http"
"time"
"gorm.io/gorm"
)
const (
@ -107,6 +108,16 @@ type (
Handle(method string, uri string, handler http.HandlerFunc)
}
// 权限检测器
PermissionChecker interface {
CheckPermission(ctx context.Context, permission string) error
}
// 模型场景
ScenarioModel interface {
Scenario() []string
}
// TypeValue 键值对数据
TypeValue struct {
Label any `json:"label"`