diff --git a/model.go b/model.go index 4abd4db..27e88c4 100644 --- a/model.go +++ b/model.go @@ -26,19 +26,21 @@ import ( ) type Model struct { - naming types.Naming //命名规则 - value reflect.Value //模块值 - db *gorm.DB //数据库 - primaryKey string //主键 - urlPrefix string //url前缀 - disableDomain bool //禁用域 - schemaLookup types.SchemaLookupFunc //获取schema的函数 - valueLookup types.ValueLookupFunc //查看域 - statement *gorm.Statement //字段声明 - formatter *Formatter //格式化 - response types.HttpWriter //HTTP响应 - hookMgr *hookManager //钩子管理器 - dirname string //存放文件目录 + naming types.Naming //命名规则 + value reflect.Value //模块值 + db *gorm.DB //数据库 + primaryKey string //主键 + urlPrefix string //url前缀 + disableDomain bool //禁用域 + permissionChecker types.PermissionChecker //权限检查 + schemaLookup types.SchemaLookupFunc //获取schema的函数 + valueLookup types.ValueLookupFunc //查看域 + statement *gorm.Statement //字段声明 + formatter *Formatter //格式化 + response types.HttpWriter //HTTP响应 + hookMgr *hookManager //钩子管理器 + dirname string //存放文件目录 + scenarios []string //场景 } var ( @@ -65,7 +67,18 @@ func (m *Model) getHook() *hookManager { // hasScenario 判断是否有该场景 func (m *Model) hasScenario(s string) bool { - return true + if len(m.scenarios) == 0 { + return true + } + return slices.Contains(m.scenarios, s) +} + +// hasPermission 判断是否有权限 +func (m *Model) hasPermission(ctx context.Context, s string) (err error) { + if m.permissionChecker != nil { + return m.permissionChecker.CheckPermission(ctx, m.Permission(s)) + } + return nil } // setValue 设置字段的值 @@ -158,7 +171,7 @@ func (m *Model) buildReporterCountColumns(ctx context.Context, dest types.Report modelType = modelType.Elem() } columns := make([]string, 0) - for i := 0; i < modelType.NumField(); i++ { + for i := range modelType.NumField() { field := modelType.Field(i) scenarios := field.Tag.Get("scenarios") if !hasToken(types.ScenarioList, scenarios) { @@ -227,6 +240,30 @@ func (m *Model) Fields() []*schema.Field { return m.statement.Schema.Fields } +// Permission 获取权限 +// 权限示例: "module.model.scenario" +// 比如: organize:user:list, organize:user:create +func (m *Model) Permission(scenario string) string { + ss := make([]string, 4) + switch scenario { + case types.ScenarioList: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, "list") + case types.ScenarioView: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, "get") + case types.ScenarioCreate: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, "create") + case types.ScenarioUpdate: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, "update") + case types.ScenarioDelete: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, "delete") + case types.ScenarioExport: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, "export") + case types.ScenarioImport: + ss = append(ss, m.naming.ModuleName, m.naming.Singular, "import") + } + return strings.Join(ss, ":") +} + // Uri 获取请求的uri func (m *Model) Uri(scenario string) string { ss := make([]string, 4) @@ -274,6 +311,7 @@ func (m *Model) Method(scenario string) string { // Search 实现通过HTTP方法查找数据 func (m *Model) Search(w http.ResponseWriter, r *http.Request) { + var ( ok bool err error @@ -292,6 +330,14 @@ func (m *Model) Search(w http.ResponseWriter, r *http.Request) { reporter types.Reporter namerTable tableNamer ) + if !m.hasScenario(types.ScenarioList) { + m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil) + return + } + if err = m.hasPermission(r.Context(), types.ScenarioList); err != nil { + m.response.Failure(w, types.RequestDenied, err.Error(), nil) + return + } qs = r.URL.Query() page, _ = strconv.Atoi(qs.Get("page")) pageSize, _ = strconv.Atoi(qs.Get("pagesize")) @@ -392,6 +438,14 @@ func (m *Model) Create(w http.ResponseWriter, r *http.Request) { domainName string modelValue reflect.Value ) + if !m.hasScenario(types.ScenarioCreate) { + m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil) + return + } + if err = m.hasPermission(r.Context(), types.ScenarioCreate); err != nil { + m.response.Failure(w, types.RequestDenied, err.Error(), nil) + return + } modelValue = reflect.New(m.value.Type()) model = modelValue.Interface() if err = json.NewDecoder(r.Body).Decode(modelValue.Interface()); err != nil { @@ -473,6 +527,14 @@ func (m *Model) Update(w http.ResponseWriter, r *http.Request) { modelValue reflect.Value oldValues map[string]any ) + if !m.hasScenario(types.ScenarioUpdate) { + m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil) + return + } + if err = m.hasPermission(r.Context(), types.ScenarioUpdate); err != nil { + m.response.Failure(w, types.RequestDenied, err.Error(), nil) + return + } idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r) modelValue = reflect.New(m.value.Type()) model = modelValue.Interface() @@ -564,6 +626,14 @@ func (m *Model) Delete(w http.ResponseWriter, r *http.Request) { model any modelValue reflect.Value ) + if !m.hasScenario(types.ScenarioDelete) { + m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil) + return + } + if err = m.hasPermission(r.Context(), types.ScenarioDelete); err != nil { + m.response.Failure(w, types.RequestDenied, err.Error(), nil) + return + } idStr := m.findPrimaryKey(m.Uri(types.ScenarioDelete), r) modelValue = reflect.New(m.value.Type()) model = modelValue.Interface() @@ -619,6 +689,14 @@ func (m *Model) View(w http.ResponseWriter, r *http.Request) { scenario string domainName string ) + if !m.hasScenario(types.ScenarioView) { + m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil) + return + } + if err = m.hasPermission(r.Context(), types.ScenarioView); err != nil { + m.response.Failure(w, types.RequestDenied, err.Error(), nil) + return + } qs = r.URL.Query() idStr := m.findPrimaryKey(m.Uri(types.ScenarioUpdate), r) modelValue = reflect.New(m.value.Type()) @@ -657,8 +735,12 @@ func (m *Model) Export(w http.ResponseWriter, r *http.Request) { fp *os.File modelValue reflect.Value ) - if !m.hasScenario(types.ScenarioList) { - m.response.Failure(w, types.RequestDenied, "request denied", nil) + if !m.hasScenario(types.ScenarioExport) { + m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil) + return + } + if err = m.hasPermission(r.Context(), types.ScenarioExport); err != nil { + m.response.Failure(w, types.RequestDenied, err.Error(), nil) return } domainName = m.valueLookup(types.FieldDomain, w, r) @@ -899,6 +981,14 @@ func (m *Model) Import(w http.ResponseWriter, r *http.Request) { qs url.Values extraFields map[string]string ) + if !m.hasScenario(types.ScenarioImport) { + m.response.Failure(w, types.RequestDenied, scenarioNotAllow, nil) + return + } + if err = m.hasPermission(r.Context(), types.ScenarioImport); err != nil { + m.response.Failure(w, types.RequestDenied, err.Error(), nil) + return + } domainName = m.valueLookup(types.FieldDomain, w, r) if schemas, err = m.schemaLookup(r.Context(), m.getDB(), domainName, m.naming.ModuleName, m.naming.TableName, types.ScenarioCreate); err != nil { m.response.Failure(w, types.RequestRecordNotFound, "record not found", nil) diff --git a/options.go b/options.go index ca0cf2f..d5dc976 100644 --- a/options.go +++ b/options.go @@ -6,17 +6,33 @@ import ( ) type options struct { - urlPrefix string - moduleName string - disableDomain bool - db *gorm.DB - router types.HttpRouter - writer types.HttpWriter - formatter *Formatter + urlPrefix string + moduleName string + disableDomain bool + db *gorm.DB + router types.HttpRouter + writer types.HttpWriter + permissionChecker types.PermissionChecker + formatter *Formatter + resourceDirectory string } type Option func(o *options) +func (o *options) Clone() *options { + return &options{ + urlPrefix: o.urlPrefix, + moduleName: o.moduleName, + disableDomain: o.disableDomain, + db: o.db, + router: o.router, + writer: o.writer, + permissionChecker: o.permissionChecker, + formatter: o.formatter, + resourceDirectory: o.resourceDirectory, + } +} + // WithDB 设置DB func WithDB(db *gorm.DB) Option { return func(o *options) { @@ -65,3 +81,17 @@ func WithFormatter(s *Formatter) Option { o.formatter = s } } + +// WithPermissionChecker 配置PermissionChecker +func WithPermissionChecker(s types.PermissionChecker) Option { + return func(o *options) { + o.permissionChecker = s + } +} + +// WithResourceDirectory 配置资源目录 +func WithResourceDirectory(s string) Option { + return func(o *options) { + o.resourceDirectory = s + } +} diff --git a/rest.go b/rest.go index d160539..f3a3ed2 100644 --- a/rest.go +++ b/rest.go @@ -492,21 +492,13 @@ func Init(cbs ...Option) (err error) { } // AutoMigrate 自动合并表的schema -func AutoMigrate(ctx context.Context, model any, cbs ...Option) (err error) { +func AutoMigrate(ctx context.Context, model any, cbs ...Option) (modelValue *Model, err error) { var ( opts *options table string router types.HttpRouter ) - opts = &options{ - db: globalOpts.db, - router: globalOpts.router, - writer: globalOpts.writer, - formatter: globalOpts.formatter, - moduleName: globalOpts.moduleName, - urlPrefix: globalOpts.urlPrefix, - disableDomain: globalOpts.disableDomain, - } + opts = globalOpts.Clone() for _, cb := range cbs { cb(opts) } @@ -514,14 +506,18 @@ func AutoMigrate(ctx context.Context, model any, cbs ...Option) (err error) { return } //路由模块处理 - modelValue := newModel(model, opts.db, types.Naming{ + modelValue = newModel(model, opts.db, types.Naming{ Pluralize: inflector.Pluralize(table), Singular: inflector.Singularize(table), ModuleName: opts.moduleName, TableName: table, }) + if scenarioModel, ok := model.(types.ScenarioModel); ok { + modelValue.scenarios = scenarioModel.Scenario() + } modelValue.hookMgr = hookMgr modelValue.schemaLookup = VisibleSchemas + modelValue.permissionChecker = opts.permissionChecker if opts.router != nil { router = opts.router } diff --git a/types.go b/types.go index 4d50a79..52f0319 100644 --- a/types.go +++ b/types.go @@ -21,6 +21,10 @@ var ( allowMethods = []string{http.MethodPut, http.MethodPost} ) +var ( + scenarioNotAllow = "request not allowed" +) + type ( httpWriter struct { } diff --git a/types/types.go b/types/types.go index ff2d9d2..47b165f 100644 --- a/types/types.go +++ b/types/types.go @@ -2,9 +2,10 @@ package types import ( "context" - "gorm.io/gorm" "net/http" "time" + + "gorm.io/gorm" ) const ( @@ -107,6 +108,16 @@ type ( Handle(method string, uri string, handler http.HandlerFunc) } + // 权限检测器 + PermissionChecker interface { + CheckPermission(ctx context.Context, permission string) error + } + + // 模型场景 + ScenarioModel interface { + Scenario() []string + } + // TypeValue 键值对数据 TypeValue struct { Label any `json:"label"`