diff --git a/middleware/auth/jwt.go b/middleware/auth/jwt.go index 50d9e86..5a4ff40 100644 --- a/middleware/auth/jwt.go +++ b/middleware/auth/jwt.go @@ -2,6 +2,7 @@ package auth import ( "context" + "reflect" "strings" "git.nobla.cn/golang/aeus/metadata" @@ -32,16 +33,7 @@ type Option func(*options) // Parser is a jwt parser type options struct { allows []string - claims func() jwt.Claims -} - -// WithClaims with customer claim -// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems -// If you use it in Client, f only needs to return a single object to provide performance -func WithClaims(f func() jwt.Claims) Option { - return func(o *options) { - o.claims = f - } + claims reflect.Type } // WithAllow with allow path @@ -54,6 +46,12 @@ func WithAllow(paths ...string) Option { } } +func WithClaims(claims reflect.Type) Option { + return func(o *options) { + o.claims = claims + } +} + // isAllowed check if the path is allowed func isAllowed(uripath string, allows []string) bool { for _, str := range allows { @@ -70,7 +68,7 @@ func isAllowed(uripath string, allows []string) bool { return true } } - return true + return false } // JWT auth middleware @@ -90,21 +88,24 @@ func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware { } } } - authorizationValue, ok := md.Get(authorizationKey) + token, ok := md.Get(authorizationKey) if !ok { return errors.ErrAccessDenied } - if !strings.HasPrefix(authorizationValue, bearerWord) { - return errors.ErrAccessDenied + if strings.HasPrefix(token, bearerWord) { + token = strings.TrimPrefix(token, bearerWord) } var ( ti *jwt.Token ) - authorizationToken := strings.TrimSpace(strings.TrimPrefix(authorizationValue, bearerWord)) + token = strings.TrimSpace(token) if opts.claims != nil { - ti, err = jwt.ParseWithClaims(authorizationToken, opts.claims(), keyFunc) - } else { - ti, err = jwt.Parse(authorizationToken, keyFunc) + if claims, ok := reflect.New(opts.claims).Interface().(jwt.Claims); ok { + ti, err = jwt.ParseWithClaims(token, claims, keyFunc) + } + } + if ti == nil { + ti, err = jwt.Parse(token, keyFunc) } if err != nil { if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) { diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 08b9edb..c90207b 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -13,7 +13,7 @@ var ( type Cache interface { // Get gets a cached value by key. - Get(ctx context.Context, key string) (any, time.Time, error) + Get(ctx context.Context, key string) (any, error) // Put stores a key-value pair into cache. Put(ctx context.Context, key string, val any, d time.Duration) error // Delete removes a key from cache. @@ -22,8 +22,12 @@ type Cache interface { String() string } +func Default() Cache { + return std +} + // Get gets a cached value by key. -func Get(ctx context.Context, key string) (any, time.Time, error) { +func Get(ctx context.Context, key string) (any, error) { return std.Get(ctx, key) } diff --git a/pkg/cache/memory/cache.go b/pkg/cache/memory/cache.go index fc57169..d189c0b 100644 --- a/pkg/cache/memory/cache.go +++ b/pkg/cache/memory/cache.go @@ -15,22 +15,22 @@ type memCache struct { sync.RWMutex } -func (c *memCache) Get(ctx context.Context, key string) (interface{}, time.Time, error) { +func (c *memCache) Get(ctx context.Context, key string) (any, error) { c.RWMutex.RLock() defer c.RWMutex.RUnlock() item, found := c.items[key] if !found { - return nil, time.Time{}, errors.ErrNotFound + return nil, errors.ErrNotFound } if item.Expired() { - return nil, time.Time{}, errors.ErrExpired + return nil, errors.ErrExpired } - return item.Value, time.Unix(0, item.Expiration), nil + return item.Value, nil } -func (c *memCache) Put(ctx context.Context, key string, val interface{}, d time.Duration) error { +func (c *memCache) Put(ctx context.Context, key string, val any, d time.Duration) error { var e int64 if d == DefaultExpiration { d = c.opts.Expiration diff --git a/pkg/cache/memory/item.go b/pkg/cache/memory/item.go index c35aebf..3e339f8 100644 --- a/pkg/cache/memory/item.go +++ b/pkg/cache/memory/item.go @@ -4,7 +4,7 @@ import "time" // Item represents an item stored in the cache. type Item struct { - Value interface{} + Value any Expiration int64 } diff --git a/pkg/errors/const.go b/pkg/errors/const.go index 18be489..f71eed9 100644 --- a/pkg/errors/const.go +++ b/pkg/errors/const.go @@ -6,6 +6,7 @@ const ( Invalid = 1001 //payload invalid Exists = 1002 //already exists Unavailable = 1003 //service unavailable + Incompatible = 1004 //type incompatible Timeout = 2001 //timeout Expired = 2002 //expired TokenExpired = 4002 //token expired @@ -29,4 +30,5 @@ var ( ErrUnavailable = New(Unavailable, "service unavailable") ErrNetworkUnreachable = New(NetworkUnreachable, "network unreachable") ErrConnectionRefused = New(ConnectionRefused, "connection refused") + ErrIncompatible = New(Incompatible, "incompatible") ) diff --git a/pkg/proto/rest/rest.pb.go b/pkg/proto/rest/rest.pb.go index 1c8dcf8..eb610dc 100644 --- a/pkg/proto/rest/rest.pb.go +++ b/pkg/proto/rest/rest.pb.go @@ -31,6 +31,16 @@ type RestFieldOptions struct { Format string `protobuf:"bytes,5,opt,name=format,proto3" json:"format,omitempty"` Props string `protobuf:"bytes,6,opt,name=props,proto3" json:"props,omitempty"` Rule string `protobuf:"bytes,7,opt,name=rule,proto3" json:"rule,omitempty"` + Live string `protobuf:"bytes,8,opt,name=live,proto3" json:"live,omitempty"` + Dropdown string `protobuf:"bytes,9,opt,name=dropdown,proto3" json:"dropdown,omitempty"` + Enum string `protobuf:"bytes,10,opt,name=enum,proto3" json:"enum,omitempty"` + Match string `protobuf:"bytes,11,opt,name=match,proto3" json:"match,omitempty"` + Invisible string `protobuf:"bytes,12,opt,name=invisible,proto3" json:"invisible,omitempty"` + Tooltip string `protobuf:"bytes,13,opt,name=tooltip,proto3" json:"tooltip,omitempty"` + Uploaduri string `protobuf:"bytes,14,opt,name=uploaduri,proto3" json:"uploaduri,omitempty"` + Description string `protobuf:"bytes,15,opt,name=description,proto3" json:"description,omitempty"` + Readonly string `protobuf:"bytes,16,opt,name=readonly,proto3" json:"readonly,omitempty"` + Endofnow string `protobuf:"bytes,17,opt,name=endofnow,proto3" json:"endofnow,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -114,6 +124,76 @@ func (x *RestFieldOptions) GetRule() string { return "" } +func (x *RestFieldOptions) GetLive() string { + if x != nil { + return x.Live + } + return "" +} + +func (x *RestFieldOptions) GetDropdown() string { + if x != nil { + return x.Dropdown + } + return "" +} + +func (x *RestFieldOptions) GetEnum() string { + if x != nil { + return x.Enum + } + return "" +} + +func (x *RestFieldOptions) GetMatch() string { + if x != nil { + return x.Match + } + return "" +} + +func (x *RestFieldOptions) GetInvisible() string { + if x != nil { + return x.Invisible + } + return "" +} + +func (x *RestFieldOptions) GetTooltip() string { + if x != nil { + return x.Tooltip + } + return "" +} + +func (x *RestFieldOptions) GetUploaduri() string { + if x != nil { + return x.Uploaduri + } + return "" +} + +func (x *RestFieldOptions) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + +func (x *RestFieldOptions) GetReadonly() string { + if x != nil { + return x.Readonly + } + return "" +} + +func (x *RestFieldOptions) GetEndofnow() string { + if x != nil { + return x.Endofnow + } + return "" +} + type RestMessageOptions struct { state protoimpl.MessageState `protogen:"open.v1"` Table string `protobuf:"bytes,1,opt,name=table,proto3" json:"table,omitempty"` @@ -198,7 +278,7 @@ var File_rest_proto protoreflect.FileDescriptor const file_rest_proto_rawDesc = "" + "\n" + "\n" + - "rest.proto\x12\x04aeus\x1a google/protobuf/descriptor.proto\"\xbc\x01\n" + + "rest.proto\x12\x04aeus\x1a google/protobuf/descriptor.proto\"\xc6\x03\n" + "\x10RestFieldOptions\x12\x12\n" + "\x04gorm\x18\x01 \x01(\tR\x04gorm\x12\x18\n" + "\acomment\x18\x02 \x01(\tR\acomment\x12\x1c\n" + @@ -206,7 +286,18 @@ const file_rest_proto_rawDesc = "" + "\bposition\x18\x04 \x01(\tR\bposition\x12\x16\n" + "\x06format\x18\x05 \x01(\tR\x06format\x12\x14\n" + "\x05props\x18\x06 \x01(\tR\x05props\x12\x12\n" + - "\x04rule\x18\a \x01(\tR\x04rule\"*\n" + + "\x04rule\x18\a \x01(\tR\x04rule\x12\x12\n" + + "\x04live\x18\b \x01(\tR\x04live\x12\x1a\n" + + "\bdropdown\x18\t \x01(\tR\bdropdown\x12\x12\n" + + "\x04enum\x18\n" + + " \x01(\tR\x04enum\x12\x14\n" + + "\x05match\x18\v \x01(\tR\x05match\x12\x1c\n" + + "\tinvisible\x18\f \x01(\tR\tinvisible\x12\x18\n" + + "\atooltip\x18\r \x01(\tR\atooltip\x12\x1c\n" + + "\tuploaduri\x18\x0e \x01(\tR\tuploaduri\x12 \n" + + "\vdescription\x18\x0f \x01(\tR\vdescription\x12\x1a\n" + + "\breadonly\x18\x10 \x01(\tR\breadonly\x12\x1a\n" + + "\bendofnow\x18\x11 \x01(\tR\bendofnow\"*\n" + "\x12RestMessageOptions\x12\x14\n" + "\x05table\x18\x01 \x01(\tR\x05table:M\n" + "\x05field\x12\x1d.google.protobuf.FieldOptions\x18\x96\x97\x03 \x01(\v2\x16.aeus.RestFieldOptionsR\x05field:O\n" + diff --git a/pkg/proto/rest/rest.proto b/pkg/proto/rest/rest.proto index 4b031bf..91a2987 100644 --- a/pkg/proto/rest/rest.proto +++ b/pkg/proto/rest/rest.proto @@ -21,6 +21,16 @@ message RestFieldOptions { string format = 5; string props = 6; string rule= 7; + string live = 8; + string dropdown = 9; + string enum = 10; + string match = 11; + string invisible = 12; + string tooltip = 13; + string uploaduri = 14; + string description = 15; + string readonly = 16; + string endofnow = 17; } extend google.protobuf.MessageOptions { diff --git a/tools/gen/internal/generator/generator.go b/tools/gen/internal/generator/generator.go index ed6d8dc..a53b26c 100644 --- a/tools/gen/internal/generator/generator.go +++ b/tools/gen/internal/generator/generator.go @@ -104,7 +104,7 @@ func Geerate(app *types.Applicetion) (err error) { } writer.Reset() } - if err = writeFile(shortName+".go", []byte("package "+shortName)); err != nil { + if err = writeFile(path.Join(shortName, shortName+".go"), []byte("package "+shortName)); err != nil { return } err = scanDir(protoDir, "third_party", func(filename string) error { diff --git a/transport/http/server.go b/transport/http/server.go index ebe520d..750a07b 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -116,10 +116,39 @@ func (s *Server) notFoundHandle(ctx *gin.Context) { ctx.JSON(http.StatusNotFound, newResponse(errors.NotFound, "Not Found", nil)) } +func (s *Server) CORSInterceptor() gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request.Method == "OPTIONS" { + c.Writer.Header().Add("Vary", "Origin") + c.Writer.Header().Add("Vary", "Access-Control-Request-Method") + c.Writer.Header().Add("Vary", "Access-Control-Request-Headers") + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.Writer.Header().Set("Access-Control-Allow-Methods", "GET,HEAD,PUT,PATCH,POST,DELETE") + h := c.Request.Header.Get("Access-Control-Request-Headers") + if h != "" { + c.Writer.Header().Set("Access-Control-Allow-Headers", h) + } + c.AbortWithStatus(204) + return + } else { + c.Writer.Header().Add("Vary", "Origin") + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + h := c.Request.Header.Get("Access-Control-Request-Headers") + if h != "" { + c.Writer.Header().Set("Access-Control-Allow-Headers", h) + } + } + c.Next() + } +} + func (s *Server) requestInterceptor() gin.HandlerFunc { return func(ginCtx *gin.Context) { ctx := ginCtx.Request.Context() next := func(ctx context.Context) error { + ginCtx.Request = ginCtx.Request.WithContext(ctx) ginCtx.Next() if err := ginCtx.Errors.Last(); err != nil { return err.Err @@ -139,8 +168,8 @@ func (s *Server) requestInterceptor() gin.HandlerFunc { } md.Set(metadata.RequestProtocolKey, Protocol) md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path) + md.Set("method", ginCtx.Request.Method) ctx = metadata.NewContext(ctx, md) - ginCtx.Request = ginCtx.Request.WithContext(ctx) if err := handler(ctx); err != nil { if se, ok := err.(*errors.Error); ok { ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil)) @@ -223,6 +252,9 @@ func New(cbs ...Option) *Server { gin.SetMode(gin.ReleaseMode) } svr.engine = gin.New(svr.opts.ginOptions...) + if svr.opts.enableCORS { + svr.engine.Use(svr.CORSInterceptor()) + } svr.engine.Use(svr.requestInterceptor()) return svr } diff --git a/transport/http/types.go b/transport/http/types.go index b9421f3..c915494 100644 --- a/transport/http/types.go +++ b/transport/http/types.go @@ -25,6 +25,7 @@ type ( logger logger.Logger context context.Context ginOptions []gin.OptionFunc + enableCORS bool } HandleFunc func(ctx *Context) (err error) @@ -46,6 +47,12 @@ func WithNetwork(network string) Option { } } +func WithCORS() Option { + return func(o *options) { + o.enableCORS = true + } +} + func WithAddress(address string) Option { return func(o *options) { o.address = address