upgrade metadata

This commit is contained in:
Yavolte 2025-06-10 11:28:21 +08:00
parent c20c14227d
commit a694e40b13
18 changed files with 805 additions and 87 deletions

View File

@ -4,3 +4,5 @@
# 快速开始

5
go.mod
View File

@ -5,7 +5,9 @@ go 1.23.0
toolchain go1.23.9
require (
github.com/envoyproxy/protoc-gen-validate v1.2.1
github.com/gin-gonic/gin v1.10.1
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
github.com/mattn/go-runewidth v0.0.16
github.com/peterh/liner v1.2.2
@ -15,6 +17,7 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb
google.golang.org/grpc v1.72.2
google.golang.org/protobuf v1.36.5
gorm.io/gorm v1.30.0
)
require (
@ -33,6 +36,8 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect

10
go.sum
View File

@ -13,6 +13,8 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8=
github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@ -36,6 +38,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@ -45,6 +49,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@ -177,5 +185,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@ -2,6 +2,7 @@ package metadata
import (
"context"
"iter"
"maps"
"strings"
)
@ -11,35 +12,78 @@ type metadataKey struct{}
// Metadata is our way of representing request headers internally.
// They're used at the RPC level and translate back and forth
// from Transport headers.
type Metadata map[string]string
type Metadata struct {
teeReader TeeReader
teeWriter TeeWriter
variables map[string]string
}
func canonicalMetadataKey(key string) string {
return strings.ToLower(key)
}
func (md Metadata) Has(key string) bool {
_, ok := md[canonicalMetadataKey(key)]
// TeeReader sets the tee reader.
func (m *Metadata) TeeReader(r TeeReader) {
m.teeReader = r
}
// TeeWriter sets the tee writer.
func (m *Metadata) TeeWriter(w TeeWriter) {
m.teeWriter = w
}
// Has returns true if the metadata contains the given key.
func (m *Metadata) Has(key string) bool {
_, ok := m.Get(key)
return ok
}
func (md Metadata) Get(key string) (string, bool) {
val, ok := md[canonicalMetadataKey(key)]
// Get returns the first value associated with the given key.
func (m *Metadata) Get(key string) (string, bool) {
key = canonicalMetadataKey(key)
val, ok := m.variables[key]
if !ok && m.teeReader != nil {
if val = m.teeReader.Get(key); val != "" {
ok = true
}
}
return val, ok
}
func (md Metadata) Set(key, val string) {
md[canonicalMetadataKey(key)] = val
// Set sets a metadata key/value pair.
func (m *Metadata) Set(key, val string) {
if m.variables == nil {
m.variables = make(map[string]string)
}
key = canonicalMetadataKey(key)
m.variables[key] = val
if m.teeWriter != nil {
m.teeWriter.Set(key, val)
}
}
func (md Metadata) Delete(key string) {
delete(md, canonicalMetadataKey(key))
// Delete removes a key from the metadata.
func (m *Metadata) Delete(key string) {
key = canonicalMetadataKey(key)
if m.variables != nil {
delete(m.variables, key)
}
if m.teeWriter != nil {
m.teeWriter.Set(key, "")
}
}
// Copy makes a copy of the metadata.
func Copy(md Metadata) Metadata {
cmd := make(Metadata, len(md))
maps.Copy(cmd, md)
return cmd
// Keys returns a sequence of the metadata keys.
func (m *Metadata) Keys() iter.Seq[string] {
return func(yield func(string) bool) {
for k := range m.variables {
if !yield(k) {
return
}
}
}
}
// Delete key from metadata.
@ -49,67 +93,60 @@ func Delete(ctx context.Context, k string) context.Context {
// Set add key with val to metadata.
func Set(ctx context.Context, k, v string) context.Context {
md, ok := FromContext(ctx)
if !ok {
md = make(Metadata)
}
md := FromContext(ctx)
k = canonicalMetadataKey(k)
if v == "" {
delete(md, k)
md.Delete(k)
} else {
md[k] = v
md.Set(k, v)
}
return context.WithValue(ctx, metadataKey{}, md)
}
// Get returns a single value from metadata in the context.
func Get(ctx context.Context, key string) (string, bool) {
md, ok := FromContext(ctx)
if !ok {
return "", ok
}
md := FromContext(ctx)
key = canonicalMetadataKey(key)
val, ok := md[canonicalMetadataKey(key)]
val, ok := md.Get(key)
return val, ok
}
// FromContext returns metadata from the given context.
func FromContext(ctx context.Context) (Metadata, bool) {
md, ok := ctx.Value(metadataKey{}).(Metadata)
func FromContext(ctx context.Context) *Metadata {
md, ok := ctx.Value(metadataKey{}).(*Metadata)
if !ok {
return nil, ok
return New()
}
// capitalise all values
newMD := make(Metadata, len(md))
for k, v := range md {
newMD[canonicalMetadataKey(k)] = v
}
return newMD, ok
return md
}
// NewContext creates a new context with the given metadata.
func NewContext(ctx context.Context, md Metadata) context.Context {
func NewContext(ctx context.Context, md *Metadata) context.Context {
return context.WithValue(ctx, metadataKey{}, md)
}
// MergeContext merges metadata to existing metadata, overwriting if specified.
func MergeContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context {
func MergeContext(ctx context.Context, patchMd *Metadata, overwrite bool) context.Context {
if ctx == nil {
ctx = context.Background()
}
md, _ := ctx.Value(metadataKey{}).(Metadata)
cmd := make(Metadata, len(md))
maps.Copy(cmd, md)
for k, v := range patchMd {
if _, ok := cmd[k]; ok && !overwrite {
cmd := New()
maps.Copy(cmd.variables, md.variables)
for k, v := range patchMd.variables {
if _, ok := cmd.variables[k]; ok && !overwrite {
// skip
} else if v != "" {
cmd[k] = v
cmd.variables[k] = v
} else {
delete(cmd, k)
delete(cmd.variables, k)
}
}
return context.WithValue(ctx, metadataKey{}, cmd)
}
func New() *Metadata {
return &Metadata{
variables: make(map[string]string, 16),
}
}

View File

@ -5,3 +5,13 @@ const (
RequestPathKey = "X-AEUS-Request-Path"
RequestProtocolKey = "X-AEUS-Request-Protocol"
)
type (
TeeReader interface {
Get(string) string
}
TeeWriter interface {
Set(string, string)
}
)

View File

@ -1 +1,136 @@
package auth
import (
"context"
"strings"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/errors"
jwt "github.com/golang-jwt/jwt/v5"
)
type authKey struct{}
const (
// bearerWord the bearer key word for authorization
bearerWord string = "Bearer"
// bearerFormat authorization token format
bearerFormat string = "Bearer %s"
// authorizationKey holds the key used to store the JWT Token in the request tokenHeader.
authorizationKey string = "Authorization"
// reason holds the error reason.
reason string = "UNAUTHORIZED"
)
type Option func(*options)
// Parser is a jwt parser
type options struct {
allows []string
claims func() jwt.Claims
}
// WithClaims with customer claim
// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems
// If you use it in Client, f only needs to return a single object to provide performance
func WithClaims(f func() jwt.Claims) Option {
return func(o *options) {
o.claims = f
}
}
// WithAllow with allow path
func WithAllow(path string) Option {
return func(o *options) {
if o.allows == nil {
o.allows = make([]string, 0, 16)
}
o.allows = append(o.allows, path)
}
}
// isAllowed check if the path is allowed
func isAllowed(uripath string, allows []string) bool {
for _, str := range allows {
n := len(str)
if n == 0 {
continue
}
if n > 1 && str[n-1] == '*' {
if strings.HasPrefix(uripath, str[:n-1]) {
return true
}
}
if str == uripath {
return true
}
}
return true
}
// JWT auth middleware
func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware {
opts := options{}
for _, cb := range cbs {
cb(&opts)
}
return func(next middleware.Handler) middleware.Handler {
return func(ctx context.Context) (err error) {
md := metadata.FromContext(ctx)
authorizationValue, ok := md.Get(authorizationKey)
if !ok {
return errors.ErrAccessDenied
}
if len(opts.allows) > 0 {
requestPath, ok := md.Get(metadata.RequestPathKey)
if ok {
if isAllowed(requestPath, opts.allows) {
return next(ctx)
}
}
}
if !strings.HasPrefix(authorizationValue, bearerWord) {
return errors.ErrAccessDenied
}
var (
ti *jwt.Token
)
authorizationToken := strings.TrimSpace(strings.TrimPrefix(authorizationValue, bearerWord))
if opts.claims != nil {
ti, err = jwt.ParseWithClaims(authorizationToken, opts.claims(), keyFunc)
} else {
ti, err = jwt.Parse(authorizationToken, keyFunc)
}
if err != nil {
if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) {
return errors.ErrAccessDenied
}
if errors.Is(err, jwt.ErrTokenNotValidYet) || errors.Is(err, jwt.ErrTokenExpired) {
return errors.ErrTokenExpired
}
return errors.ErrPermissionDenied
}
if !ti.Valid {
return errors.ErrPermissionDenied
}
ctx = NewContext(ctx, ti.Claims)
return next(ctx)
}
}
}
// NewContext put auth info into context
func NewContext(ctx context.Context, info jwt.Claims) context.Context {
return context.WithValue(ctx, authKey{}, info)
}
// FromContext extract auth info from context
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) {
token, ok = ctx.Value(authKey{}).(jwt.Claims)
return
}

12
pkg/bs/safe.go 100644
View File

@ -0,0 +1,12 @@
//go:build appengine
// +build appengine
package bs
func BytesToString(b []byte) string {
return string(b)
}
func StringToBytes(s string) []byte {
return []byte(s)
}

23
pkg/bs/unsafe.go 100644
View File

@ -0,0 +1,23 @@
//go:build !appengine
// +build !appengine
package bs
import (
"unsafe"
)
// BytesToString converts byte slice to string.
func BytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// StringToBytes converts string to byte slice.
func StringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}

View File

@ -1,24 +1,32 @@
package errors
const (
OK = 0 //success
Exit = 1000 //normal exit
Invalid = 1001 //payload invalid
Timeout = 1002 //timeout
Expired = 1003 //expired
AccessDenied = 4005 //access denied
PermissionDenied = 4003 //permission denied
NotFound = 4004 //not found
Unavailable = 5000 //service unavailable
OK = 0 //success
Exit = 1000 //normal exit
Invalid = 1001 //payload invalid
Exists = 1002 //already exists
Unavailable = 1003 //service unavailable
Timeout = 2001 //timeout
Expired = 2002 //expired
TokenExpired = 4002 //token expired
NotFound = 4004 //not found
PermissionDenied = 4003 //permission denied
AccessDenied = 4005 //access denied
NetworkUnreachable = 5001 //network unreachable
ConnectionRefused = 5002 //connection refused
)
var (
ErrExit = New(Exit, "normal exit")
ErrTimeout = New(Timeout, "timeout")
ErrExpired = New(Expired, "expired")
ErrValidate = New(Invalid, "invalid payload")
ErrNotFound = New(NotFound, "not found")
ErrAccessDenied = New(AccessDenied, "access denied")
ErrPermissionDenied = New(PermissionDenied, "permission denied")
ErrUnavailable = New(Unavailable, "service unavailable")
ErrExit = New(Exit, "normal exit")
ErrTimeout = New(Timeout, "timeout")
ErrExists = New(Exists, "already exists")
ErrExpired = New(Expired, "expired")
ErrInvalid = New(Invalid, "invalid payload")
ErrNotFound = New(NotFound, "not found")
ErrAccessDenied = New(AccessDenied, "access denied")
ErrPermissionDenied = New(PermissionDenied, "permission denied")
ErrTokenExpired = New(TokenExpired, "token expired")
ErrUnavailable = New(Unavailable, "service unavailable")
ErrNetworkUnreachable = New(NetworkUnreachable, "network unreachable")
ErrConnectionRefused = New(ConnectionRefused, "connection refused")
)

View File

@ -14,8 +14,15 @@ func (e *Error) Error() string {
return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message)
}
func Format(code int, msg string, args ...any) Error {
return Error{
func Warp(code int, err error) error {
return &Error{
Code: code,
Message: err.Error(),
}
}
func Format(code int, msg string, args ...any) *Error {
return &Error{
Code: code,
Message: fmt.Sprintf(msg, args...),
}

View File

@ -0,0 +1 @@
package rest

View File

@ -0,0 +1,350 @@
package reflection
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
)
var (
allowTags = []string{"json", "yaml", "xml", "name"}
)
var (
ErrValueAssociated = errors.New("value cannot be associated")
)
func findField(v reflect.Value, field string) reflect.Value {
var (
pos int
tagValue string
refType reflect.Type
fieldType reflect.StructField
)
refType = v.Type()
for i := range refType.NumField() {
fieldType = refType.Field(i)
for _, tagName := range allowTags {
tagValue = fieldType.Tag.Get(tagName)
if tagValue == "" {
continue
}
if pos = strings.IndexByte(tagValue, ','); pos != -1 {
tagValue = tagValue[:pos]
}
if tagValue == field {
return v.Field(i)
}
}
}
return v.FieldByName(field)
}
func safeAssignment(variable reflect.Value, value any) (err error) {
var (
n int64
un uint64
fn float64
kind reflect.Kind
)
rv := reflect.ValueOf(value)
kind = variable.Kind()
if kind != reflect.Slice && kind != reflect.Array && kind != reflect.Map && kind == rv.Kind() {
variable.Set(rv)
return
}
switch kind {
case reflect.Bool:
switch rv.Kind() {
case reflect.Bool:
variable.SetBool(rv.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if rv.Int() != 0 {
variable.SetBool(true)
} else {
variable.SetBool(false)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if rv.Uint() != 0 {
variable.SetBool(true)
} else {
variable.SetBool(false)
}
case reflect.Float32, reflect.Float64:
if rv.Float() != 0 {
variable.SetBool(true)
} else {
variable.SetBool(false)
}
case reflect.String:
var tv bool
tv, err = strconv.ParseBool(rv.String())
if err == nil {
variable.SetBool(tv)
}
default:
err = fmt.Errorf("boolean value can not assign %s", rv.Kind())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch rv.Kind() {
case reflect.Bool:
if rv.Bool() {
variable.SetInt(1)
} else {
variable.SetInt(0)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
variable.SetInt(rv.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
variable.SetInt(int64(rv.Uint()))
case reflect.Float32, reflect.Float64:
variable.SetInt(int64(rv.Float()))
case reflect.String:
if n, err = strconv.ParseInt(rv.String(), 10, 64); err == nil {
variable.SetInt(n)
}
default:
err = fmt.Errorf("integer value can not assign %s", rv.Kind())
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
switch rv.Kind() {
case reflect.Bool:
if rv.Bool() {
variable.SetUint(1)
} else {
variable.SetUint(0)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
variable.SetUint(uint64(rv.Int()))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
variable.SetUint(rv.Uint())
case reflect.Float32, reflect.Float64:
variable.SetUint(uint64(rv.Float()))
case reflect.String:
if un, err = strconv.ParseUint(rv.String(), 10, 64); err == nil {
variable.SetUint(un)
}
default:
err = fmt.Errorf("unsigned integer value can not assign %s", rv.Kind())
}
case reflect.Float32, reflect.Float64:
switch rv.Kind() {
case reflect.Bool:
if rv.Bool() {
variable.SetFloat(1)
} else {
variable.SetFloat(0)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
variable.SetFloat(float64(rv.Int()))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
variable.SetFloat(float64(rv.Uint()))
case reflect.Float32, reflect.Float64:
variable.SetFloat(rv.Float())
case reflect.String:
if fn, err = strconv.ParseFloat(rv.String(), 64); err == nil {
variable.SetFloat(fn)
}
default:
err = fmt.Errorf("decimal value can not assign %s", rv.Kind())
}
case reflect.String:
switch rv.Kind() {
case reflect.Bool:
if rv.Bool() {
variable.SetString("true")
} else {
variable.SetString("false")
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
variable.SetString(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
variable.SetString(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32, reflect.Float64:
variable.SetString(strconv.FormatFloat(rv.Float(), 'f', -1, 64))
case reflect.String:
variable.SetString(rv.String())
default:
variable.SetString(fmt.Sprint(value))
}
case reflect.Interface:
variable.Set(rv)
default:
err = fmt.Errorf("unsupported kind %s", kind)
}
return
}
func Set(hacky any, field string, value any) (err error) {
var (
n int
refField reflect.Value
)
refVal := reflect.ValueOf(hacky)
if refVal.Kind() == reflect.Ptr {
refVal = reflect.Indirect(refVal)
}
if refVal.Kind() != reflect.Struct {
return fmt.Errorf("%s kind is %v", refVal.Type().String(), refField.Kind())
}
refField = findField(refVal, field)
if !refField.IsValid() {
return fmt.Errorf("%s field `%s` not found", refVal.Type(), field)
}
rv := reflect.ValueOf(value)
fieldKind := refField.Kind()
if fieldKind != reflect.Slice && fieldKind != reflect.Array && fieldKind != reflect.Map && fieldKind == rv.Kind() {
refField.Set(rv)
return
}
switch fieldKind {
case reflect.Struct:
if rv.Kind() != reflect.Map {
return ErrValueAssociated
}
keys := rv.MapKeys()
subVal := reflect.New(refField.Type())
for _, key := range keys {
pv := rv.MapIndex(key)
if key.Kind() == reflect.String {
if err = Set(subVal.Interface(), key.String(), pv.Interface()); err != nil {
return err
}
}
}
refField.Set(subVal.Elem())
case reflect.Ptr:
elemType := refField.Type()
if elemType.Elem().Kind() != reflect.Struct {
return ErrValueAssociated
} else {
if rv.Kind() != reflect.Map {
return ErrValueAssociated
}
keys := rv.MapKeys()
subVal := reflect.New(elemType.Elem())
for _, key := range keys {
pv := rv.MapIndex(key)
if key.Kind() == reflect.String {
if err = Set(subVal.Interface(), key.String(), pv.Interface()); err != nil {
return err
}
}
}
refField.Set(subVal)
}
case reflect.Map:
if rv.Kind() != reflect.Map {
return ErrValueAssociated
}
targetValue := reflect.MakeMap(refField.Type())
keys := rv.MapKeys()
for _, key := range keys {
pv := rv.MapIndex(key)
kVal := reflect.New(refField.Type().Key())
eVal := reflect.New(refField.Type().Elem())
if err = safeAssignment(kVal.Elem(), key.Interface()); err != nil {
return ErrValueAssociated
}
if refField.Type().Elem().Kind() == reflect.Struct {
if pv.Elem().Kind() != reflect.Map {
return ErrValueAssociated
}
subKeys := pv.Elem().MapKeys()
for _, subKey := range subKeys {
subVal := pv.Elem().MapIndex(subKey)
if subKey.Kind() == reflect.String {
if err = Set(eVal.Interface(), subKey.String(), subVal.Interface()); err != nil {
return err
}
}
}
targetValue.SetMapIndex(kVal.Elem(), eVal.Elem())
} else {
if err = safeAssignment(eVal.Elem(), pv.Interface()); err != nil {
return ErrValueAssociated
}
targetValue.SetMapIndex(kVal.Elem(), eVal.Elem())
}
}
refField.Set(targetValue)
case reflect.Array, reflect.Slice:
n = 0
innerType := refField.Type().Elem()
if rv.Kind() == reflect.Array || rv.Kind() == reflect.Slice {
if innerType.Kind() == reflect.Struct {
sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len())
for i := 0; i < rv.Len(); i++ {
srcVal := rv.Index(i)
if srcVal.Kind() != reflect.Map {
return ErrValueAssociated
}
dstVal := reflect.New(innerType)
keys := srcVal.MapKeys()
for _, key := range keys {
kv := srcVal.MapIndex(key)
if key.Kind() == reflect.String {
if err = Set(dstVal.Interface(), key.String(), kv.Interface()); err != nil {
return
}
}
}
sliceVar.Index(n).Set(dstVal.Elem())
n++
}
refField.Set(sliceVar.Slice(0, n))
} else if innerType.Kind() == reflect.Ptr {
sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len())
for i := 0; i < rv.Len(); i++ {
srcVal := rv.Index(i)
if srcVal.Kind() != reflect.Map {
return ErrValueAssociated
}
dstVal := reflect.New(innerType.Elem())
keys := srcVal.MapKeys()
for _, key := range keys {
kv := srcVal.MapIndex(key)
if key.Kind() == reflect.String {
if err = Set(dstVal.Interface(), key.String(), kv.Interface()); err != nil {
return
}
}
}
sliceVar.Index(n).Set(dstVal)
n++
}
refField.Set(sliceVar.Slice(0, n))
} else {
sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len())
for i := range rv.Len() {
srcVal := rv.Index(i)
dstVal := reflect.New(innerType).Elem()
if err = safeAssignment(dstVal, srcVal.Interface()); err != nil {
return
}
sliceVar.Index(n).Set(dstVal)
n++
}
refField.Set(sliceVar.Slice(0, n))
}
}
default:
err = safeAssignment(refField, value)
}
return
}
func Assign(variable reflect.Value, value any) (err error) {
return safeAssignment(variable, value)
}
func Setter[T string | int | int64 | float64 | any](hacky any, variables map[string]T) (err error) {
for k, v := range variables {
if err = Set(hacky, k, v); err != nil {
return err
}
}
return
}

View File

@ -12,6 +12,8 @@ import (
"sync/atomic"
"time"
"git.nobla.cn/golang/aeus/metadata"
"git.nobla.cn/golang/aeus/middleware"
"git.nobla.cn/golang/aeus/pkg/errors"
"git.nobla.cn/golang/aeus/pkg/logger"
netutil "git.nobla.cn/golang/aeus/pkg/net"
@ -27,6 +29,11 @@ type Server struct {
ctxMap sync.Map
uri *url.URL
exitFlag int32
middleware []middleware.Middleware
}
func (svr *Server) Use(middlewares ...middleware.Middleware) {
svr.middleware = append(svr.middleware, middlewares...)
}
func (svr *Server) Handle(pathname string, desc string, cb HandleFunc) {
@ -90,13 +97,13 @@ func (s *Server) execute(ctx *Context, frame *Frame) (err error) {
}
if r, args, err = s.router.Lookup(tokens); err != nil {
if errors.Is(err, ErrNotFound) {
err = ctx.Error(errNotFound, fmt.Sprintf("Command %s not found", cmd))
err = ctx.Error(errors.NotFound, fmt.Sprintf("Command %s not found", cmd))
} else {
err = ctx.Error(errExecuteFailed, err.Error())
err = ctx.Error(errors.Unavailable, err.Error())
}
} else {
if len(r.params) > len(args) {
err = ctx.Error(errExecuteFailed, r.Usage())
err = ctx.Error(errors.Unavailable, r.Usage())
return
}
if len(r.params) > 0 {
@ -107,7 +114,17 @@ func (s *Server) execute(ctx *Context, frame *Frame) (err error) {
}
ctx.setArgs(args)
ctx.setParam(params)
err = r.command.Handle(ctx)
h := func(c context.Context) error {
return r.command.Handle(ctx)
}
next := middleware.Chain(s.middleware...)(h)
md := metadata.FromContext(ctx.ctx)
md.Set(metadata.RequestPathKey, r.command.Path)
md.Set(metadata.RequestProtocolKey, Protocol)
md.TeeReader(&cliMetadataReader{ctx: ctx})
md.TeeWriter(&cliMetadataWriter{ctx: ctx})
ctx.ctx = metadata.NewContext(ctx.ctx, md)
err = next(ctx.ctx)
}
return
}

View File

@ -16,8 +16,7 @@ var (
)
const (
errNotFound = 4004
errExecuteFailed = 4005
Protocol = "cli"
)
var (
@ -84,6 +83,13 @@ type (
ServerTime time.Time `json:"server_time"`
RemoteAddr string `json:"remote_addr"`
}
cliMetadataReader struct {
ctx *Context
}
cliMetadataWriter struct {
ctx *Context
}
)
func WithAddress(addr string) Option {
@ -109,3 +115,11 @@ func WithContext(ctx context.Context) Option {
o.context = ctx
}
}
func (r *cliMetadataReader) Get(key string) string {
return r.ctx.Param(key)
}
func (r *cliMetadataWriter) Set(key string, value string) {
r.ctx.SetValue(key, value)
}

View File

@ -41,24 +41,59 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
return
}
h := middleware.Chain(s.middlewares...)(next)
md := make(metadata.Metadata)
md := metadata.FromContext(ctx)
grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx)
if ok {
md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata})
}
grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx)
if !ok {
grpcOutgoingMetadata = make(grpcmd.MD)
}
md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata})
if !md.Has(metadata.RequestIDKey) {
md.Set(metadata.RequestIDKey, uuid.New().String())
}
md.Set(metadata.RequestPathKey, info.FullMethod)
md.Set(metadata.RequestProtocolKey, Protocol)
if gmd, ok := grpcmd.FromIncomingContext(ctx); ok {
for k, v := range gmd {
if len(v) > 0 {
md.Set(k, v[0])
}
}
}
ctx = metadata.MergeContext(ctx, md, true)
ctx = metadata.NewContext(ctx, md)
ctx = context.WithValue(ctx, requestValueContextKey{}, req)
err = h(ctx)
// grpcmd.AppendToOutgoingContext(ctx, grpcmd.New(metadata.FromContext(ctx)))
// grpc.SetHeader()
if grpcOutgoingMetadata.Len() > 0 {
grpc.SetHeader(ctx, grpcOutgoingMetadata)
}
return
}
}
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
ctx := ss.Context()
next := func(ctx context.Context) (err error) {
err = handler(srv, ss)
return
}
h := middleware.Chain(s.middlewares...)(next)
md := metadata.FromContext(ctx)
grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx)
if ok {
md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata})
}
grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx)
if !ok {
grpcOutgoingMetadata = make(grpcmd.MD)
}
md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata})
if !md.Has(metadata.RequestIDKey) {
md.Set(metadata.RequestIDKey, uuid.New().String())
}
md.Set(metadata.RequestPathKey, info.FullMethod)
md.Set(metadata.RequestProtocolKey, Protocol)
ctx = metadata.NewContext(ctx, md)
err = h(ctx)
if grpcOutgoingMetadata.Len() > 0 {
grpc.SetHeader(ctx, grpcOutgoingMetadata)
}
return
}
}
@ -113,6 +148,7 @@ func New(cbs ...Option) *Server {
cb(svr.opts)
}
svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor()))
svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainStreamInterceptor(svr.streamServerInterceptor()))
svr.serve = grpc.NewServer(svr.opts.grpcOpts...)
return svr
}

View File

@ -7,6 +7,7 @@ import (
"git.nobla.cn/golang/aeus/pkg/logger"
"git.nobla.cn/golang/aeus/registry"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
const (
@ -32,6 +33,14 @@ type (
ClientOption func(*clientOptions)
grpcMetadataReader struct {
md metadata.MD
}
grpcMetadataWriter struct {
md metadata.MD
}
requestValueContextKey struct{}
)
@ -83,3 +92,21 @@ func GetRequestValueFromContext(ctx context.Context) any {
}
return ctx.Value(requestValueContextKey{})
}
func (m *grpcMetadataReader) Get(key string) string {
if m.md == nil {
return ""
}
vs := m.md.Get(key)
if len(vs) > 0 {
return vs[0]
}
return ""
}
func (m *grpcMetadataWriter) Set(key string, value string) {
if m.md == nil {
return
}
m.md.Set(key, value)
}

View File

@ -98,18 +98,20 @@ func (s *Server) requestInterceptor() gin.HandlerFunc {
return nil
}
handler := middleware.Chain(s.middlewares...)(next)
md := make(metadata.Metadata)
for k, v := range ginCtx.Request.Header {
if len(v) > 0 {
md.Set(k, v[0])
}
}
md := metadata.FromContext(ctx)
md.TeeReader(&httpMetadataReader{
hd: ginCtx.Request.Header,
})
md.TeeWriter(&httpMetadataWriter{
w: ginCtx.Writer,
})
if !md.Has(metadata.RequestIDKey) {
md.Set(metadata.RequestIDKey, uuid.New().String())
}
md.Set(metadata.RequestProtocolKey, Protocol)
md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path)
ctx = metadata.MergeContext(ctx, md, true)
ctx = metadata.NewContext(ctx, md)
ginCtx.Request = ginCtx.Request.WithContext(ctx)
if err := handler(ctx); err != nil {
if se, ok := err.(*errors.Error); ok {
ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil))

View File

@ -30,6 +30,14 @@ type (
HandleFunc func(ctx *Context) (err error)
Middleware func(http.Handler) http.Handler
httpMetadataReader struct {
hd http.Header
}
httpMetadataWriter struct {
w http.ResponseWriter
}
)
func WithNetwork(network string) Option {
@ -85,3 +93,17 @@ func WithGinOptions(opts ...gin.OptionFunc) Option {
o.ginOptions = opts
}
}
func (m *httpMetadataReader) Get(key string) string {
if m.hd == nil {
return ""
}
return m.hd.Get(key)
}
func (m *httpMetadataWriter) Set(key string, value string) {
if m.w == nil {
return
}
m.w.Header().Set(key, value)
}