diff --git a/README.md b/README.md index 5c64670..1383f94 100644 --- a/README.md +++ b/README.md @@ -4,3 +4,5 @@ +# 快速开始 + diff --git a/go.mod b/go.mod index 204c7c2..58cb5ab 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.23.0 toolchain go1.23.9 require ( + github.com/envoyproxy/protoc-gen-validate v1.2.1 github.com/gin-gonic/gin v1.10.1 + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 github.com/mattn/go-runewidth v0.0.16 github.com/peterh/liner v1.2.2 @@ -15,6 +17,7 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb google.golang.org/grpc v1.72.2 google.golang.org/protobuf v1.36.5 + gorm.io/gorm v1.30.0 ) require ( @@ -33,6 +36,8 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index 8c9ea05..471df25 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= +github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -36,6 +38,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -45,6 +49,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= @@ -177,5 +185,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= +gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/metadata/metadata.go b/metadata/metadata.go index 10d8d51..37faf87 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -2,6 +2,7 @@ package metadata import ( "context" + "iter" "maps" "strings" ) @@ -11,35 +12,78 @@ type metadataKey struct{} // Metadata is our way of representing request headers internally. // They're used at the RPC level and translate back and forth // from Transport headers. -type Metadata map[string]string + +type Metadata struct { + teeReader TeeReader + teeWriter TeeWriter + variables map[string]string +} func canonicalMetadataKey(key string) string { return strings.ToLower(key) } -func (md Metadata) Has(key string) bool { - _, ok := md[canonicalMetadataKey(key)] +// TeeReader sets the tee reader. +func (m *Metadata) TeeReader(r TeeReader) { + m.teeReader = r +} + +// TeeWriter sets the tee writer. +func (m *Metadata) TeeWriter(w TeeWriter) { + m.teeWriter = w +} + +// Has returns true if the metadata contains the given key. +func (m *Metadata) Has(key string) bool { + _, ok := m.Get(key) return ok } -func (md Metadata) Get(key string) (string, bool) { - val, ok := md[canonicalMetadataKey(key)] +// Get returns the first value associated with the given key. +func (m *Metadata) Get(key string) (string, bool) { + key = canonicalMetadataKey(key) + val, ok := m.variables[key] + if !ok && m.teeReader != nil { + if val = m.teeReader.Get(key); val != "" { + ok = true + } + } return val, ok } -func (md Metadata) Set(key, val string) { - md[canonicalMetadataKey(key)] = val +// Set sets a metadata key/value pair. +func (m *Metadata) Set(key, val string) { + if m.variables == nil { + m.variables = make(map[string]string) + } + key = canonicalMetadataKey(key) + m.variables[key] = val + if m.teeWriter != nil { + m.teeWriter.Set(key, val) + } } -func (md Metadata) Delete(key string) { - delete(md, canonicalMetadataKey(key)) +// Delete removes a key from the metadata. +func (m *Metadata) Delete(key string) { + + key = canonicalMetadataKey(key) + if m.variables != nil { + delete(m.variables, key) + } + if m.teeWriter != nil { + m.teeWriter.Set(key, "") + } } -// Copy makes a copy of the metadata. -func Copy(md Metadata) Metadata { - cmd := make(Metadata, len(md)) - maps.Copy(cmd, md) - return cmd +// Keys returns a sequence of the metadata keys. +func (m *Metadata) Keys() iter.Seq[string] { + return func(yield func(string) bool) { + for k := range m.variables { + if !yield(k) { + return + } + } + } } // Delete key from metadata. @@ -49,67 +93,60 @@ func Delete(ctx context.Context, k string) context.Context { // Set add key with val to metadata. func Set(ctx context.Context, k, v string) context.Context { - md, ok := FromContext(ctx) - if !ok { - md = make(Metadata) - } + md := FromContext(ctx) k = canonicalMetadataKey(k) if v == "" { - delete(md, k) + md.Delete(k) } else { - md[k] = v + md.Set(k, v) } return context.WithValue(ctx, metadataKey{}, md) } // Get returns a single value from metadata in the context. func Get(ctx context.Context, key string) (string, bool) { - md, ok := FromContext(ctx) - if !ok { - return "", ok - } + md := FromContext(ctx) key = canonicalMetadataKey(key) - val, ok := md[canonicalMetadataKey(key)] + val, ok := md.Get(key) return val, ok } // FromContext returns metadata from the given context. -func FromContext(ctx context.Context) (Metadata, bool) { - md, ok := ctx.Value(metadataKey{}).(Metadata) +func FromContext(ctx context.Context) *Metadata { + md, ok := ctx.Value(metadataKey{}).(*Metadata) if !ok { - return nil, ok + return New() } - - // capitalise all values - newMD := make(Metadata, len(md)) - for k, v := range md { - newMD[canonicalMetadataKey(k)] = v - } - - return newMD, ok + return md } // NewContext creates a new context with the given metadata. -func NewContext(ctx context.Context, md Metadata) context.Context { +func NewContext(ctx context.Context, md *Metadata) context.Context { return context.WithValue(ctx, metadataKey{}, md) } // MergeContext merges metadata to existing metadata, overwriting if specified. -func MergeContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context { +func MergeContext(ctx context.Context, patchMd *Metadata, overwrite bool) context.Context { if ctx == nil { ctx = context.Background() } md, _ := ctx.Value(metadataKey{}).(Metadata) - cmd := make(Metadata, len(md)) - maps.Copy(cmd, md) - for k, v := range patchMd { - if _, ok := cmd[k]; ok && !overwrite { + cmd := New() + maps.Copy(cmd.variables, md.variables) + for k, v := range patchMd.variables { + if _, ok := cmd.variables[k]; ok && !overwrite { // skip } else if v != "" { - cmd[k] = v + cmd.variables[k] = v } else { - delete(cmd, k) + delete(cmd.variables, k) } } return context.WithValue(ctx, metadataKey{}, cmd) } + +func New() *Metadata { + return &Metadata{ + variables: make(map[string]string, 16), + } +} diff --git a/metadata/types.go b/metadata/types.go index a0015c2..033f5a2 100644 --- a/metadata/types.go +++ b/metadata/types.go @@ -5,3 +5,13 @@ const ( RequestPathKey = "X-AEUS-Request-Path" RequestProtocolKey = "X-AEUS-Request-Protocol" ) + +type ( + TeeReader interface { + Get(string) string + } + + TeeWriter interface { + Set(string, string) + } +) diff --git a/middleware/auth/jwt.go b/middleware/auth/jwt.go index 8832b06..582af3d 100644 --- a/middleware/auth/jwt.go +++ b/middleware/auth/jwt.go @@ -1 +1,136 @@ package auth + +import ( + "context" + "strings" + + "git.nobla.cn/golang/aeus/metadata" + "git.nobla.cn/golang/aeus/middleware" + "git.nobla.cn/golang/aeus/pkg/errors" + jwt "github.com/golang-jwt/jwt/v5" +) + +type authKey struct{} + +const ( + + // bearerWord the bearer key word for authorization + bearerWord string = "Bearer" + + // bearerFormat authorization token format + bearerFormat string = "Bearer %s" + + // authorizationKey holds the key used to store the JWT Token in the request tokenHeader. + authorizationKey string = "Authorization" + + // reason holds the error reason. + reason string = "UNAUTHORIZED" +) + +type Option func(*options) + +// Parser is a jwt parser +type options struct { + allows []string + claims func() jwt.Claims +} + +// WithClaims with customer claim +// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems +// If you use it in Client, f only needs to return a single object to provide performance +func WithClaims(f func() jwt.Claims) Option { + return func(o *options) { + o.claims = f + } +} + +// WithAllow with allow path +func WithAllow(path string) Option { + return func(o *options) { + if o.allows == nil { + o.allows = make([]string, 0, 16) + } + o.allows = append(o.allows, path) + } +} + +// isAllowed check if the path is allowed +func isAllowed(uripath string, allows []string) bool { + for _, str := range allows { + n := len(str) + if n == 0 { + continue + } + if n > 1 && str[n-1] == '*' { + if strings.HasPrefix(uripath, str[:n-1]) { + return true + } + } + if str == uripath { + return true + } + } + return true +} + +// JWT auth middleware +func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware { + opts := options{} + for _, cb := range cbs { + cb(&opts) + } + return func(next middleware.Handler) middleware.Handler { + return func(ctx context.Context) (err error) { + md := metadata.FromContext(ctx) + authorizationValue, ok := md.Get(authorizationKey) + if !ok { + return errors.ErrAccessDenied + } + if len(opts.allows) > 0 { + requestPath, ok := md.Get(metadata.RequestPathKey) + if ok { + if isAllowed(requestPath, opts.allows) { + return next(ctx) + } + } + } + if !strings.HasPrefix(authorizationValue, bearerWord) { + return errors.ErrAccessDenied + } + var ( + ti *jwt.Token + ) + authorizationToken := strings.TrimSpace(strings.TrimPrefix(authorizationValue, bearerWord)) + if opts.claims != nil { + ti, err = jwt.ParseWithClaims(authorizationToken, opts.claims(), keyFunc) + } else { + ti, err = jwt.Parse(authorizationToken, keyFunc) + } + if err != nil { + if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) { + return errors.ErrAccessDenied + } + if errors.Is(err, jwt.ErrTokenNotValidYet) || errors.Is(err, jwt.ErrTokenExpired) { + return errors.ErrTokenExpired + } + return errors.ErrPermissionDenied + } + if !ti.Valid { + return errors.ErrPermissionDenied + } + ctx = NewContext(ctx, ti.Claims) + return next(ctx) + } + } +} + +// NewContext put auth info into context +func NewContext(ctx context.Context, info jwt.Claims) context.Context { + return context.WithValue(ctx, authKey{}, info) +} + +// FromContext extract auth info from context +func FromContext(ctx context.Context) (token jwt.Claims, ok bool) { + token, ok = ctx.Value(authKey{}).(jwt.Claims) + return +} diff --git a/pkg/bs/safe.go b/pkg/bs/safe.go new file mode 100644 index 0000000..1d8415a --- /dev/null +++ b/pkg/bs/safe.go @@ -0,0 +1,12 @@ +//go:build appengine +// +build appengine + +package bs + +func BytesToString(b []byte) string { + return string(b) +} + +func StringToBytes(s string) []byte { + return []byte(s) +} diff --git a/pkg/bs/unsafe.go b/pkg/bs/unsafe.go new file mode 100644 index 0000000..5a724aa --- /dev/null +++ b/pkg/bs/unsafe.go @@ -0,0 +1,23 @@ +//go:build !appengine +// +build !appengine + +package bs + +import ( + "unsafe" +) + +// BytesToString converts byte slice to string. +func BytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// StringToBytes converts string to byte slice. +func StringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/pkg/errors/const.go b/pkg/errors/const.go index 7e0ef19..18be489 100644 --- a/pkg/errors/const.go +++ b/pkg/errors/const.go @@ -1,24 +1,32 @@ package errors const ( - OK = 0 //success - Exit = 1000 //normal exit - Invalid = 1001 //payload invalid - Timeout = 1002 //timeout - Expired = 1003 //expired - AccessDenied = 4005 //access denied - PermissionDenied = 4003 //permission denied - NotFound = 4004 //not found - Unavailable = 5000 //service unavailable + OK = 0 //success + Exit = 1000 //normal exit + Invalid = 1001 //payload invalid + Exists = 1002 //already exists + Unavailable = 1003 //service unavailable + Timeout = 2001 //timeout + Expired = 2002 //expired + TokenExpired = 4002 //token expired + NotFound = 4004 //not found + PermissionDenied = 4003 //permission denied + AccessDenied = 4005 //access denied + NetworkUnreachable = 5001 //network unreachable + ConnectionRefused = 5002 //connection refused ) var ( - ErrExit = New(Exit, "normal exit") - ErrTimeout = New(Timeout, "timeout") - ErrExpired = New(Expired, "expired") - ErrValidate = New(Invalid, "invalid payload") - ErrNotFound = New(NotFound, "not found") - ErrAccessDenied = New(AccessDenied, "access denied") - ErrPermissionDenied = New(PermissionDenied, "permission denied") - ErrUnavailable = New(Unavailable, "service unavailable") + ErrExit = New(Exit, "normal exit") + ErrTimeout = New(Timeout, "timeout") + ErrExists = New(Exists, "already exists") + ErrExpired = New(Expired, "expired") + ErrInvalid = New(Invalid, "invalid payload") + ErrNotFound = New(NotFound, "not found") + ErrAccessDenied = New(AccessDenied, "access denied") + ErrPermissionDenied = New(PermissionDenied, "permission denied") + ErrTokenExpired = New(TokenExpired, "token expired") + ErrUnavailable = New(Unavailable, "service unavailable") + ErrNetworkUnreachable = New(NetworkUnreachable, "network unreachable") + ErrConnectionRefused = New(ConnectionRefused, "connection refused") ) diff --git a/pkg/errors/error.go b/pkg/errors/error.go index bfa01b3..37c5392 100644 --- a/pkg/errors/error.go +++ b/pkg/errors/error.go @@ -14,8 +14,15 @@ func (e *Error) Error() string { return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message) } -func Format(code int, msg string, args ...any) Error { - return Error{ +func Warp(code int, err error) error { + return &Error{ + Code: code, + Message: err.Error(), + } +} + +func Format(code int, msg string, args ...any) *Error { + return &Error{ Code: code, Message: fmt.Sprintf(msg, args...), } diff --git a/pkg/proto/rest/types.go b/pkg/proto/rest/types.go new file mode 100644 index 0000000..0062e0c --- /dev/null +++ b/pkg/proto/rest/types.go @@ -0,0 +1 @@ +package rest diff --git a/pkg/reflection/reflection.go b/pkg/reflection/reflection.go new file mode 100644 index 0000000..6c4a6ab --- /dev/null +++ b/pkg/reflection/reflection.go @@ -0,0 +1,350 @@ +package reflection + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" +) + +var ( + allowTags = []string{"json", "yaml", "xml", "name"} +) + +var ( + ErrValueAssociated = errors.New("value cannot be associated") +) + +func findField(v reflect.Value, field string) reflect.Value { + var ( + pos int + tagValue string + refType reflect.Type + fieldType reflect.StructField + ) + refType = v.Type() + for i := range refType.NumField() { + fieldType = refType.Field(i) + for _, tagName := range allowTags { + tagValue = fieldType.Tag.Get(tagName) + if tagValue == "" { + continue + } + if pos = strings.IndexByte(tagValue, ','); pos != -1 { + tagValue = tagValue[:pos] + } + if tagValue == field { + return v.Field(i) + } + } + } + return v.FieldByName(field) +} + +func safeAssignment(variable reflect.Value, value any) (err error) { + var ( + n int64 + un uint64 + fn float64 + kind reflect.Kind + ) + rv := reflect.ValueOf(value) + kind = variable.Kind() + if kind != reflect.Slice && kind != reflect.Array && kind != reflect.Map && kind == rv.Kind() { + variable.Set(rv) + return + } + switch kind { + case reflect.Bool: + switch rv.Kind() { + case reflect.Bool: + variable.SetBool(rv.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if rv.Int() != 0 { + variable.SetBool(true) + } else { + variable.SetBool(false) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if rv.Uint() != 0 { + variable.SetBool(true) + } else { + variable.SetBool(false) + } + case reflect.Float32, reflect.Float64: + if rv.Float() != 0 { + variable.SetBool(true) + } else { + variable.SetBool(false) + } + case reflect.String: + var tv bool + tv, err = strconv.ParseBool(rv.String()) + if err == nil { + variable.SetBool(tv) + } + default: + err = fmt.Errorf("boolean value can not assign %s", rv.Kind()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + variable.SetInt(1) + } else { + variable.SetInt(0) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + variable.SetInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + variable.SetInt(int64(rv.Uint())) + case reflect.Float32, reflect.Float64: + variable.SetInt(int64(rv.Float())) + case reflect.String: + if n, err = strconv.ParseInt(rv.String(), 10, 64); err == nil { + variable.SetInt(n) + } + default: + err = fmt.Errorf("integer value can not assign %s", rv.Kind()) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + variable.SetUint(1) + } else { + variable.SetUint(0) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + variable.SetUint(uint64(rv.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + variable.SetUint(rv.Uint()) + case reflect.Float32, reflect.Float64: + variable.SetUint(uint64(rv.Float())) + case reflect.String: + if un, err = strconv.ParseUint(rv.String(), 10, 64); err == nil { + variable.SetUint(un) + } + default: + err = fmt.Errorf("unsigned integer value can not assign %s", rv.Kind()) + } + case reflect.Float32, reflect.Float64: + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + variable.SetFloat(1) + } else { + variable.SetFloat(0) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + variable.SetFloat(float64(rv.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + variable.SetFloat(float64(rv.Uint())) + case reflect.Float32, reflect.Float64: + variable.SetFloat(rv.Float()) + case reflect.String: + if fn, err = strconv.ParseFloat(rv.String(), 64); err == nil { + variable.SetFloat(fn) + } + default: + err = fmt.Errorf("decimal value can not assign %s", rv.Kind()) + } + case reflect.String: + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + variable.SetString("true") + } else { + variable.SetString("false") + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + variable.SetString(strconv.FormatInt(rv.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + variable.SetString(strconv.FormatUint(rv.Uint(), 10)) + case reflect.Float32, reflect.Float64: + variable.SetString(strconv.FormatFloat(rv.Float(), 'f', -1, 64)) + case reflect.String: + variable.SetString(rv.String()) + default: + variable.SetString(fmt.Sprint(value)) + } + case reflect.Interface: + variable.Set(rv) + default: + err = fmt.Errorf("unsupported kind %s", kind) + } + return +} + +func Set(hacky any, field string, value any) (err error) { + var ( + n int + refField reflect.Value + ) + refVal := reflect.ValueOf(hacky) + if refVal.Kind() == reflect.Ptr { + refVal = reflect.Indirect(refVal) + } + if refVal.Kind() != reflect.Struct { + return fmt.Errorf("%s kind is %v", refVal.Type().String(), refField.Kind()) + } + refField = findField(refVal, field) + if !refField.IsValid() { + return fmt.Errorf("%s field `%s` not found", refVal.Type(), field) + } + rv := reflect.ValueOf(value) + fieldKind := refField.Kind() + if fieldKind != reflect.Slice && fieldKind != reflect.Array && fieldKind != reflect.Map && fieldKind == rv.Kind() { + refField.Set(rv) + return + } + switch fieldKind { + case reflect.Struct: + if rv.Kind() != reflect.Map { + return ErrValueAssociated + } + keys := rv.MapKeys() + subVal := reflect.New(refField.Type()) + for _, key := range keys { + pv := rv.MapIndex(key) + if key.Kind() == reflect.String { + if err = Set(subVal.Interface(), key.String(), pv.Interface()); err != nil { + return err + } + } + } + refField.Set(subVal.Elem()) + case reflect.Ptr: + elemType := refField.Type() + if elemType.Elem().Kind() != reflect.Struct { + return ErrValueAssociated + } else { + if rv.Kind() != reflect.Map { + return ErrValueAssociated + } + keys := rv.MapKeys() + subVal := reflect.New(elemType.Elem()) + for _, key := range keys { + pv := rv.MapIndex(key) + if key.Kind() == reflect.String { + if err = Set(subVal.Interface(), key.String(), pv.Interface()); err != nil { + return err + } + } + } + refField.Set(subVal) + } + case reflect.Map: + if rv.Kind() != reflect.Map { + return ErrValueAssociated + } + targetValue := reflect.MakeMap(refField.Type()) + keys := rv.MapKeys() + for _, key := range keys { + pv := rv.MapIndex(key) + kVal := reflect.New(refField.Type().Key()) + eVal := reflect.New(refField.Type().Elem()) + if err = safeAssignment(kVal.Elem(), key.Interface()); err != nil { + return ErrValueAssociated + } + if refField.Type().Elem().Kind() == reflect.Struct { + if pv.Elem().Kind() != reflect.Map { + return ErrValueAssociated + } + subKeys := pv.Elem().MapKeys() + for _, subKey := range subKeys { + subVal := pv.Elem().MapIndex(subKey) + if subKey.Kind() == reflect.String { + if err = Set(eVal.Interface(), subKey.String(), subVal.Interface()); err != nil { + return err + } + } + } + targetValue.SetMapIndex(kVal.Elem(), eVal.Elem()) + } else { + if err = safeAssignment(eVal.Elem(), pv.Interface()); err != nil { + return ErrValueAssociated + } + targetValue.SetMapIndex(kVal.Elem(), eVal.Elem()) + } + } + refField.Set(targetValue) + case reflect.Array, reflect.Slice: + n = 0 + innerType := refField.Type().Elem() + if rv.Kind() == reflect.Array || rv.Kind() == reflect.Slice { + if innerType.Kind() == reflect.Struct { + sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len()) + for i := 0; i < rv.Len(); i++ { + srcVal := rv.Index(i) + if srcVal.Kind() != reflect.Map { + return ErrValueAssociated + } + dstVal := reflect.New(innerType) + keys := srcVal.MapKeys() + for _, key := range keys { + kv := srcVal.MapIndex(key) + if key.Kind() == reflect.String { + if err = Set(dstVal.Interface(), key.String(), kv.Interface()); err != nil { + return + } + } + } + sliceVar.Index(n).Set(dstVal.Elem()) + n++ + } + refField.Set(sliceVar.Slice(0, n)) + } else if innerType.Kind() == reflect.Ptr { + sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len()) + for i := 0; i < rv.Len(); i++ { + srcVal := rv.Index(i) + if srcVal.Kind() != reflect.Map { + return ErrValueAssociated + } + dstVal := reflect.New(innerType.Elem()) + keys := srcVal.MapKeys() + for _, key := range keys { + kv := srcVal.MapIndex(key) + if key.Kind() == reflect.String { + if err = Set(dstVal.Interface(), key.String(), kv.Interface()); err != nil { + return + } + } + } + sliceVar.Index(n).Set(dstVal) + n++ + } + refField.Set(sliceVar.Slice(0, n)) + } else { + sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len()) + for i := range rv.Len() { + srcVal := rv.Index(i) + dstVal := reflect.New(innerType).Elem() + if err = safeAssignment(dstVal, srcVal.Interface()); err != nil { + return + } + sliceVar.Index(n).Set(dstVal) + n++ + } + refField.Set(sliceVar.Slice(0, n)) + } + } + default: + err = safeAssignment(refField, value) + } + return +} + +func Assign(variable reflect.Value, value any) (err error) { + return safeAssignment(variable, value) +} + +func Setter[T string | int | int64 | float64 | any](hacky any, variables map[string]T) (err error) { + for k, v := range variables { + if err = Set(hacky, k, v); err != nil { + return err + } + } + return +} diff --git a/transport/cli/server.go b/transport/cli/server.go index 90318a8..2a1926d 100644 --- a/transport/cli/server.go +++ b/transport/cli/server.go @@ -12,6 +12,8 @@ import ( "sync/atomic" "time" + "git.nobla.cn/golang/aeus/metadata" + "git.nobla.cn/golang/aeus/middleware" "git.nobla.cn/golang/aeus/pkg/errors" "git.nobla.cn/golang/aeus/pkg/logger" netutil "git.nobla.cn/golang/aeus/pkg/net" @@ -27,6 +29,11 @@ type Server struct { ctxMap sync.Map uri *url.URL exitFlag int32 + middleware []middleware.Middleware +} + +func (svr *Server) Use(middlewares ...middleware.Middleware) { + svr.middleware = append(svr.middleware, middlewares...) } func (svr *Server) Handle(pathname string, desc string, cb HandleFunc) { @@ -90,13 +97,13 @@ func (s *Server) execute(ctx *Context, frame *Frame) (err error) { } if r, args, err = s.router.Lookup(tokens); err != nil { if errors.Is(err, ErrNotFound) { - err = ctx.Error(errNotFound, fmt.Sprintf("Command %s not found", cmd)) + err = ctx.Error(errors.NotFound, fmt.Sprintf("Command %s not found", cmd)) } else { - err = ctx.Error(errExecuteFailed, err.Error()) + err = ctx.Error(errors.Unavailable, err.Error()) } } else { if len(r.params) > len(args) { - err = ctx.Error(errExecuteFailed, r.Usage()) + err = ctx.Error(errors.Unavailable, r.Usage()) return } if len(r.params) > 0 { @@ -107,7 +114,17 @@ func (s *Server) execute(ctx *Context, frame *Frame) (err error) { } ctx.setArgs(args) ctx.setParam(params) - err = r.command.Handle(ctx) + h := func(c context.Context) error { + return r.command.Handle(ctx) + } + next := middleware.Chain(s.middleware...)(h) + md := metadata.FromContext(ctx.ctx) + md.Set(metadata.RequestPathKey, r.command.Path) + md.Set(metadata.RequestProtocolKey, Protocol) + md.TeeReader(&cliMetadataReader{ctx: ctx}) + md.TeeWriter(&cliMetadataWriter{ctx: ctx}) + ctx.ctx = metadata.NewContext(ctx.ctx, md) + err = next(ctx.ctx) } return } diff --git a/transport/cli/types.go b/transport/cli/types.go index a295bf2..cf335fa 100644 --- a/transport/cli/types.go +++ b/transport/cli/types.go @@ -16,8 +16,7 @@ var ( ) const ( - errNotFound = 4004 - errExecuteFailed = 4005 + Protocol = "cli" ) var ( @@ -84,6 +83,13 @@ type ( ServerTime time.Time `json:"server_time"` RemoteAddr string `json:"remote_addr"` } + + cliMetadataReader struct { + ctx *Context + } + cliMetadataWriter struct { + ctx *Context + } ) func WithAddress(addr string) Option { @@ -109,3 +115,11 @@ func WithContext(ctx context.Context) Option { o.context = ctx } } + +func (r *cliMetadataReader) Get(key string) string { + return r.ctx.Param(key) +} + +func (r *cliMetadataWriter) Set(key string, value string) { + r.ctx.SetValue(key, value) +} diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 7ffbba9..7610f2a 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -41,24 +41,59 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { return } h := middleware.Chain(s.middlewares...)(next) - md := make(metadata.Metadata) + md := metadata.FromContext(ctx) + grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx) + if ok { + md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata}) + } + grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx) + if !ok { + grpcOutgoingMetadata = make(grpcmd.MD) + } + md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata}) if !md.Has(metadata.RequestIDKey) { md.Set(metadata.RequestIDKey, uuid.New().String()) } md.Set(metadata.RequestPathKey, info.FullMethod) md.Set(metadata.RequestProtocolKey, Protocol) - if gmd, ok := grpcmd.FromIncomingContext(ctx); ok { - for k, v := range gmd { - if len(v) > 0 { - md.Set(k, v[0]) - } - } - } - ctx = metadata.MergeContext(ctx, md, true) + ctx = metadata.NewContext(ctx, md) ctx = context.WithValue(ctx, requestValueContextKey{}, req) err = h(ctx) - // grpcmd.AppendToOutgoingContext(ctx, grpcmd.New(metadata.FromContext(ctx))) - // grpc.SetHeader() + if grpcOutgoingMetadata.Len() > 0 { + grpc.SetHeader(ctx, grpcOutgoingMetadata) + } + return + } +} + +func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor { + return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + ctx := ss.Context() + next := func(ctx context.Context) (err error) { + err = handler(srv, ss) + return + } + h := middleware.Chain(s.middlewares...)(next) + md := metadata.FromContext(ctx) + grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx) + if ok { + md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata}) + } + grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx) + if !ok { + grpcOutgoingMetadata = make(grpcmd.MD) + } + md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata}) + if !md.Has(metadata.RequestIDKey) { + md.Set(metadata.RequestIDKey, uuid.New().String()) + } + md.Set(metadata.RequestPathKey, info.FullMethod) + md.Set(metadata.RequestProtocolKey, Protocol) + ctx = metadata.NewContext(ctx, md) + err = h(ctx) + if grpcOutgoingMetadata.Len() > 0 { + grpc.SetHeader(ctx, grpcOutgoingMetadata) + } return } } @@ -113,6 +148,7 @@ func New(cbs ...Option) *Server { cb(svr.opts) } svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor())) + svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainStreamInterceptor(svr.streamServerInterceptor())) svr.serve = grpc.NewServer(svr.opts.grpcOpts...) return svr } diff --git a/transport/grpc/types.go b/transport/grpc/types.go index c3ee296..3aa4f1e 100644 --- a/transport/grpc/types.go +++ b/transport/grpc/types.go @@ -7,6 +7,7 @@ import ( "git.nobla.cn/golang/aeus/pkg/logger" "git.nobla.cn/golang/aeus/registry" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) const ( @@ -32,6 +33,14 @@ type ( ClientOption func(*clientOptions) + grpcMetadataReader struct { + md metadata.MD + } + + grpcMetadataWriter struct { + md metadata.MD + } + requestValueContextKey struct{} ) @@ -83,3 +92,21 @@ func GetRequestValueFromContext(ctx context.Context) any { } return ctx.Value(requestValueContextKey{}) } + +func (m *grpcMetadataReader) Get(key string) string { + if m.md == nil { + return "" + } + vs := m.md.Get(key) + if len(vs) > 0 { + return vs[0] + } + return "" +} + +func (m *grpcMetadataWriter) Set(key string, value string) { + if m.md == nil { + return + } + m.md.Set(key, value) +} diff --git a/transport/http/server.go b/transport/http/server.go index b8b8be7..49d0ba1 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -98,18 +98,20 @@ func (s *Server) requestInterceptor() gin.HandlerFunc { return nil } handler := middleware.Chain(s.middlewares...)(next) - md := make(metadata.Metadata) - for k, v := range ginCtx.Request.Header { - if len(v) > 0 { - md.Set(k, v[0]) - } - } + md := metadata.FromContext(ctx) + md.TeeReader(&httpMetadataReader{ + hd: ginCtx.Request.Header, + }) + md.TeeWriter(&httpMetadataWriter{ + w: ginCtx.Writer, + }) if !md.Has(metadata.RequestIDKey) { md.Set(metadata.RequestIDKey, uuid.New().String()) } md.Set(metadata.RequestProtocolKey, Protocol) md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path) - ctx = metadata.MergeContext(ctx, md, true) + ctx = metadata.NewContext(ctx, md) + ginCtx.Request = ginCtx.Request.WithContext(ctx) if err := handler(ctx); err != nil { if se, ok := err.(*errors.Error); ok { ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil)) diff --git a/transport/http/types.go b/transport/http/types.go index 8e4fd8f..b9421f3 100644 --- a/transport/http/types.go +++ b/transport/http/types.go @@ -30,6 +30,14 @@ type ( HandleFunc func(ctx *Context) (err error) Middleware func(http.Handler) http.Handler + + httpMetadataReader struct { + hd http.Header + } + + httpMetadataWriter struct { + w http.ResponseWriter + } ) func WithNetwork(network string) Option { @@ -85,3 +93,17 @@ func WithGinOptions(opts ...gin.OptionFunc) Option { o.ginOptions = opts } } + +func (m *httpMetadataReader) Get(key string) string { + if m.hd == nil { + return "" + } + return m.hd.Get(key) +} + +func (m *httpMetadataWriter) Set(key string, value string) { + if m.w == nil { + return + } + m.w.Header().Set(key, value) +}