upgrade metadata
This commit is contained in:
parent
c20c14227d
commit
a694e40b13
5
go.mod
5
go.mod
|
@ -5,7 +5,9 @@ go 1.23.0
|
|||
toolchain go1.23.9
|
||||
|
||||
require (
|
||||
github.com/envoyproxy/protoc-gen-validate v1.2.1
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/peterh/liner v1.2.2
|
||||
|
@ -15,6 +17,7 @@ require (
|
|||
google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb
|
||||
google.golang.org/grpc v1.72.2
|
||||
google.golang.org/protobuf v1.36.5
|
||||
gorm.io/gorm v1.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -33,6 +36,8 @@ require (
|
|||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
|
|
10
go.sum
10
go.sum
|
@ -13,6 +13,8 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
|
@ -36,6 +38,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
|||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
|
@ -45,6 +49,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
|
@ -177,5 +185,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
|||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
||||
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
|
|
@ -2,6 +2,7 @@ package metadata
|
|||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"maps"
|
||||
"strings"
|
||||
)
|
||||
|
@ -11,35 +12,78 @@ type metadataKey struct{}
|
|||
// Metadata is our way of representing request headers internally.
|
||||
// They're used at the RPC level and translate back and forth
|
||||
// from Transport headers.
|
||||
type Metadata map[string]string
|
||||
|
||||
type Metadata struct {
|
||||
teeReader TeeReader
|
||||
teeWriter TeeWriter
|
||||
variables map[string]string
|
||||
}
|
||||
|
||||
func canonicalMetadataKey(key string) string {
|
||||
return strings.ToLower(key)
|
||||
}
|
||||
|
||||
func (md Metadata) Has(key string) bool {
|
||||
_, ok := md[canonicalMetadataKey(key)]
|
||||
// TeeReader sets the tee reader.
|
||||
func (m *Metadata) TeeReader(r TeeReader) {
|
||||
m.teeReader = r
|
||||
}
|
||||
|
||||
// TeeWriter sets the tee writer.
|
||||
func (m *Metadata) TeeWriter(w TeeWriter) {
|
||||
m.teeWriter = w
|
||||
}
|
||||
|
||||
// Has returns true if the metadata contains the given key.
|
||||
func (m *Metadata) Has(key string) bool {
|
||||
_, ok := m.Get(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (md Metadata) Get(key string) (string, bool) {
|
||||
val, ok := md[canonicalMetadataKey(key)]
|
||||
// Get returns the first value associated with the given key.
|
||||
func (m *Metadata) Get(key string) (string, bool) {
|
||||
key = canonicalMetadataKey(key)
|
||||
val, ok := m.variables[key]
|
||||
if !ok && m.teeReader != nil {
|
||||
if val = m.teeReader.Get(key); val != "" {
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func (md Metadata) Set(key, val string) {
|
||||
md[canonicalMetadataKey(key)] = val
|
||||
// Set sets a metadata key/value pair.
|
||||
func (m *Metadata) Set(key, val string) {
|
||||
if m.variables == nil {
|
||||
m.variables = make(map[string]string)
|
||||
}
|
||||
key = canonicalMetadataKey(key)
|
||||
m.variables[key] = val
|
||||
if m.teeWriter != nil {
|
||||
m.teeWriter.Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
func (md Metadata) Delete(key string) {
|
||||
delete(md, canonicalMetadataKey(key))
|
||||
// Delete removes a key from the metadata.
|
||||
func (m *Metadata) Delete(key string) {
|
||||
|
||||
key = canonicalMetadataKey(key)
|
||||
if m.variables != nil {
|
||||
delete(m.variables, key)
|
||||
}
|
||||
if m.teeWriter != nil {
|
||||
m.teeWriter.Set(key, "")
|
||||
}
|
||||
}
|
||||
|
||||
// Copy makes a copy of the metadata.
|
||||
func Copy(md Metadata) Metadata {
|
||||
cmd := make(Metadata, len(md))
|
||||
maps.Copy(cmd, md)
|
||||
return cmd
|
||||
// Keys returns a sequence of the metadata keys.
|
||||
func (m *Metadata) Keys() iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
for k := range m.variables {
|
||||
if !yield(k) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete key from metadata.
|
||||
|
@ -49,67 +93,60 @@ func Delete(ctx context.Context, k string) context.Context {
|
|||
|
||||
// Set add key with val to metadata.
|
||||
func Set(ctx context.Context, k, v string) context.Context {
|
||||
md, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
md = make(Metadata)
|
||||
}
|
||||
md := FromContext(ctx)
|
||||
k = canonicalMetadataKey(k)
|
||||
if v == "" {
|
||||
delete(md, k)
|
||||
md.Delete(k)
|
||||
} else {
|
||||
md[k] = v
|
||||
md.Set(k, v)
|
||||
}
|
||||
return context.WithValue(ctx, metadataKey{}, md)
|
||||
}
|
||||
|
||||
// Get returns a single value from metadata in the context.
|
||||
func Get(ctx context.Context, key string) (string, bool) {
|
||||
md, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
return "", ok
|
||||
}
|
||||
md := FromContext(ctx)
|
||||
key = canonicalMetadataKey(key)
|
||||
val, ok := md[canonicalMetadataKey(key)]
|
||||
val, ok := md.Get(key)
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// FromContext returns metadata from the given context.
|
||||
func FromContext(ctx context.Context) (Metadata, bool) {
|
||||
md, ok := ctx.Value(metadataKey{}).(Metadata)
|
||||
func FromContext(ctx context.Context) *Metadata {
|
||||
md, ok := ctx.Value(metadataKey{}).(*Metadata)
|
||||
if !ok {
|
||||
return nil, ok
|
||||
return New()
|
||||
}
|
||||
|
||||
// capitalise all values
|
||||
newMD := make(Metadata, len(md))
|
||||
for k, v := range md {
|
||||
newMD[canonicalMetadataKey(k)] = v
|
||||
}
|
||||
|
||||
return newMD, ok
|
||||
return md
|
||||
}
|
||||
|
||||
// NewContext creates a new context with the given metadata.
|
||||
func NewContext(ctx context.Context, md Metadata) context.Context {
|
||||
func NewContext(ctx context.Context, md *Metadata) context.Context {
|
||||
return context.WithValue(ctx, metadataKey{}, md)
|
||||
}
|
||||
|
||||
// MergeContext merges metadata to existing metadata, overwriting if specified.
|
||||
func MergeContext(ctx context.Context, patchMd Metadata, overwrite bool) context.Context {
|
||||
func MergeContext(ctx context.Context, patchMd *Metadata, overwrite bool) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
md, _ := ctx.Value(metadataKey{}).(Metadata)
|
||||
cmd := make(Metadata, len(md))
|
||||
maps.Copy(cmd, md)
|
||||
for k, v := range patchMd {
|
||||
if _, ok := cmd[k]; ok && !overwrite {
|
||||
cmd := New()
|
||||
maps.Copy(cmd.variables, md.variables)
|
||||
for k, v := range patchMd.variables {
|
||||
if _, ok := cmd.variables[k]; ok && !overwrite {
|
||||
// skip
|
||||
} else if v != "" {
|
||||
cmd[k] = v
|
||||
cmd.variables[k] = v
|
||||
} else {
|
||||
delete(cmd, k)
|
||||
delete(cmd.variables, k)
|
||||
}
|
||||
}
|
||||
return context.WithValue(ctx, metadataKey{}, cmd)
|
||||
}
|
||||
|
||||
func New() *Metadata {
|
||||
return &Metadata{
|
||||
variables: make(map[string]string, 16),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,3 +5,13 @@ const (
|
|||
RequestPathKey = "X-AEUS-Request-Path"
|
||||
RequestProtocolKey = "X-AEUS-Request-Protocol"
|
||||
)
|
||||
|
||||
type (
|
||||
TeeReader interface {
|
||||
Get(string) string
|
||||
}
|
||||
|
||||
TeeWriter interface {
|
||||
Set(string, string)
|
||||
}
|
||||
)
|
||||
|
|
|
@ -1 +1,136 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"git.nobla.cn/golang/aeus/metadata"
|
||||
"git.nobla.cn/golang/aeus/middleware"
|
||||
"git.nobla.cn/golang/aeus/pkg/errors"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type authKey struct{}
|
||||
|
||||
const (
|
||||
|
||||
// bearerWord the bearer key word for authorization
|
||||
bearerWord string = "Bearer"
|
||||
|
||||
// bearerFormat authorization token format
|
||||
bearerFormat string = "Bearer %s"
|
||||
|
||||
// authorizationKey holds the key used to store the JWT Token in the request tokenHeader.
|
||||
authorizationKey string = "Authorization"
|
||||
|
||||
// reason holds the error reason.
|
||||
reason string = "UNAUTHORIZED"
|
||||
)
|
||||
|
||||
type Option func(*options)
|
||||
|
||||
// Parser is a jwt parser
|
||||
type options struct {
|
||||
allows []string
|
||||
claims func() jwt.Claims
|
||||
}
|
||||
|
||||
// WithClaims with customer claim
|
||||
// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems
|
||||
// If you use it in Client, f only needs to return a single object to provide performance
|
||||
func WithClaims(f func() jwt.Claims) Option {
|
||||
return func(o *options) {
|
||||
o.claims = f
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllow with allow path
|
||||
func WithAllow(path string) Option {
|
||||
return func(o *options) {
|
||||
if o.allows == nil {
|
||||
o.allows = make([]string, 0, 16)
|
||||
}
|
||||
o.allows = append(o.allows, path)
|
||||
}
|
||||
}
|
||||
|
||||
// isAllowed check if the path is allowed
|
||||
func isAllowed(uripath string, allows []string) bool {
|
||||
for _, str := range allows {
|
||||
n := len(str)
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
if n > 1 && str[n-1] == '*' {
|
||||
if strings.HasPrefix(uripath, str[:n-1]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if str == uripath {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// JWT auth middleware
|
||||
func JWT(keyFunc jwt.Keyfunc, cbs ...Option) middleware.Middleware {
|
||||
opts := options{}
|
||||
for _, cb := range cbs {
|
||||
cb(&opts)
|
||||
}
|
||||
return func(next middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context) (err error) {
|
||||
md := metadata.FromContext(ctx)
|
||||
authorizationValue, ok := md.Get(authorizationKey)
|
||||
if !ok {
|
||||
return errors.ErrAccessDenied
|
||||
}
|
||||
if len(opts.allows) > 0 {
|
||||
requestPath, ok := md.Get(metadata.RequestPathKey)
|
||||
if ok {
|
||||
if isAllowed(requestPath, opts.allows) {
|
||||
return next(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !strings.HasPrefix(authorizationValue, bearerWord) {
|
||||
return errors.ErrAccessDenied
|
||||
}
|
||||
var (
|
||||
ti *jwt.Token
|
||||
)
|
||||
authorizationToken := strings.TrimSpace(strings.TrimPrefix(authorizationValue, bearerWord))
|
||||
if opts.claims != nil {
|
||||
ti, err = jwt.ParseWithClaims(authorizationToken, opts.claims(), keyFunc)
|
||||
} else {
|
||||
ti, err = jwt.Parse(authorizationToken, keyFunc)
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenMalformed) || errors.Is(err, jwt.ErrTokenUnverifiable) {
|
||||
return errors.ErrAccessDenied
|
||||
}
|
||||
if errors.Is(err, jwt.ErrTokenNotValidYet) || errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return errors.ErrTokenExpired
|
||||
}
|
||||
return errors.ErrPermissionDenied
|
||||
}
|
||||
if !ti.Valid {
|
||||
return errors.ErrPermissionDenied
|
||||
}
|
||||
ctx = NewContext(ctx, ti.Claims)
|
||||
return next(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewContext put auth info into context
|
||||
func NewContext(ctx context.Context, info jwt.Claims) context.Context {
|
||||
return context.WithValue(ctx, authKey{}, info)
|
||||
}
|
||||
|
||||
// FromContext extract auth info from context
|
||||
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) {
|
||||
token, ok = ctx.Value(authKey{}).(jwt.Claims)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
//go:build appengine
|
||||
// +build appengine
|
||||
|
||||
package bs
|
||||
|
||||
func BytesToString(b []byte) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func StringToBytes(s string) []byte {
|
||||
return []byte(s)
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
//go:build !appengine
|
||||
// +build !appengine
|
||||
|
||||
package bs
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// BytesToString converts byte slice to string.
|
||||
func BytesToString(b []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&b))
|
||||
}
|
||||
|
||||
// StringToBytes converts string to byte slice.
|
||||
func StringToBytes(s string) []byte {
|
||||
return *(*[]byte)(unsafe.Pointer(
|
||||
&struct {
|
||||
string
|
||||
Cap int
|
||||
}{s, len(s)},
|
||||
))
|
||||
}
|
|
@ -1,24 +1,32 @@
|
|||
package errors
|
||||
|
||||
const (
|
||||
OK = 0 //success
|
||||
Exit = 1000 //normal exit
|
||||
Invalid = 1001 //payload invalid
|
||||
Timeout = 1002 //timeout
|
||||
Expired = 1003 //expired
|
||||
AccessDenied = 4005 //access denied
|
||||
PermissionDenied = 4003 //permission denied
|
||||
NotFound = 4004 //not found
|
||||
Unavailable = 5000 //service unavailable
|
||||
OK = 0 //success
|
||||
Exit = 1000 //normal exit
|
||||
Invalid = 1001 //payload invalid
|
||||
Exists = 1002 //already exists
|
||||
Unavailable = 1003 //service unavailable
|
||||
Timeout = 2001 //timeout
|
||||
Expired = 2002 //expired
|
||||
TokenExpired = 4002 //token expired
|
||||
NotFound = 4004 //not found
|
||||
PermissionDenied = 4003 //permission denied
|
||||
AccessDenied = 4005 //access denied
|
||||
NetworkUnreachable = 5001 //network unreachable
|
||||
ConnectionRefused = 5002 //connection refused
|
||||
)
|
||||
|
||||
var (
|
||||
ErrExit = New(Exit, "normal exit")
|
||||
ErrTimeout = New(Timeout, "timeout")
|
||||
ErrExpired = New(Expired, "expired")
|
||||
ErrValidate = New(Invalid, "invalid payload")
|
||||
ErrNotFound = New(NotFound, "not found")
|
||||
ErrAccessDenied = New(AccessDenied, "access denied")
|
||||
ErrPermissionDenied = New(PermissionDenied, "permission denied")
|
||||
ErrUnavailable = New(Unavailable, "service unavailable")
|
||||
ErrExit = New(Exit, "normal exit")
|
||||
ErrTimeout = New(Timeout, "timeout")
|
||||
ErrExists = New(Exists, "already exists")
|
||||
ErrExpired = New(Expired, "expired")
|
||||
ErrInvalid = New(Invalid, "invalid payload")
|
||||
ErrNotFound = New(NotFound, "not found")
|
||||
ErrAccessDenied = New(AccessDenied, "access denied")
|
||||
ErrPermissionDenied = New(PermissionDenied, "permission denied")
|
||||
ErrTokenExpired = New(TokenExpired, "token expired")
|
||||
ErrUnavailable = New(Unavailable, "service unavailable")
|
||||
ErrNetworkUnreachable = New(NetworkUnreachable, "network unreachable")
|
||||
ErrConnectionRefused = New(ConnectionRefused, "connection refused")
|
||||
)
|
||||
|
|
|
@ -14,8 +14,15 @@ func (e *Error) Error() string {
|
|||
return fmt.Sprintf("code: %d, message: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
func Format(code int, msg string, args ...any) Error {
|
||||
return Error{
|
||||
func Warp(code int, err error) error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
func Format(code int, msg string, args ...any) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: fmt.Sprintf(msg, args...),
|
||||
}
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
package rest
|
|
@ -0,0 +1,350 @@
|
|||
package reflection
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
allowTags = []string{"json", "yaml", "xml", "name"}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrValueAssociated = errors.New("value cannot be associated")
|
||||
)
|
||||
|
||||
func findField(v reflect.Value, field string) reflect.Value {
|
||||
var (
|
||||
pos int
|
||||
tagValue string
|
||||
refType reflect.Type
|
||||
fieldType reflect.StructField
|
||||
)
|
||||
refType = v.Type()
|
||||
for i := range refType.NumField() {
|
||||
fieldType = refType.Field(i)
|
||||
for _, tagName := range allowTags {
|
||||
tagValue = fieldType.Tag.Get(tagName)
|
||||
if tagValue == "" {
|
||||
continue
|
||||
}
|
||||
if pos = strings.IndexByte(tagValue, ','); pos != -1 {
|
||||
tagValue = tagValue[:pos]
|
||||
}
|
||||
if tagValue == field {
|
||||
return v.Field(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
return v.FieldByName(field)
|
||||
}
|
||||
|
||||
func safeAssignment(variable reflect.Value, value any) (err error) {
|
||||
var (
|
||||
n int64
|
||||
un uint64
|
||||
fn float64
|
||||
kind reflect.Kind
|
||||
)
|
||||
rv := reflect.ValueOf(value)
|
||||
kind = variable.Kind()
|
||||
if kind != reflect.Slice && kind != reflect.Array && kind != reflect.Map && kind == rv.Kind() {
|
||||
variable.Set(rv)
|
||||
return
|
||||
}
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
switch rv.Kind() {
|
||||
case reflect.Bool:
|
||||
variable.SetBool(rv.Bool())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if rv.Int() != 0 {
|
||||
variable.SetBool(true)
|
||||
} else {
|
||||
variable.SetBool(false)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if rv.Uint() != 0 {
|
||||
variable.SetBool(true)
|
||||
} else {
|
||||
variable.SetBool(false)
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if rv.Float() != 0 {
|
||||
variable.SetBool(true)
|
||||
} else {
|
||||
variable.SetBool(false)
|
||||
}
|
||||
case reflect.String:
|
||||
var tv bool
|
||||
tv, err = strconv.ParseBool(rv.String())
|
||||
if err == nil {
|
||||
variable.SetBool(tv)
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("boolean value can not assign %s", rv.Kind())
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
switch rv.Kind() {
|
||||
case reflect.Bool:
|
||||
if rv.Bool() {
|
||||
variable.SetInt(1)
|
||||
} else {
|
||||
variable.SetInt(0)
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
variable.SetInt(rv.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
variable.SetInt(int64(rv.Uint()))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
variable.SetInt(int64(rv.Float()))
|
||||
case reflect.String:
|
||||
if n, err = strconv.ParseInt(rv.String(), 10, 64); err == nil {
|
||||
variable.SetInt(n)
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("integer value can not assign %s", rv.Kind())
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
switch rv.Kind() {
|
||||
case reflect.Bool:
|
||||
if rv.Bool() {
|
||||
variable.SetUint(1)
|
||||
} else {
|
||||
variable.SetUint(0)
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
variable.SetUint(uint64(rv.Int()))
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
variable.SetUint(rv.Uint())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
variable.SetUint(uint64(rv.Float()))
|
||||
case reflect.String:
|
||||
if un, err = strconv.ParseUint(rv.String(), 10, 64); err == nil {
|
||||
variable.SetUint(un)
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("unsigned integer value can not assign %s", rv.Kind())
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
switch rv.Kind() {
|
||||
case reflect.Bool:
|
||||
if rv.Bool() {
|
||||
variable.SetFloat(1)
|
||||
} else {
|
||||
variable.SetFloat(0)
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
variable.SetFloat(float64(rv.Int()))
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
variable.SetFloat(float64(rv.Uint()))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
variable.SetFloat(rv.Float())
|
||||
case reflect.String:
|
||||
if fn, err = strconv.ParseFloat(rv.String(), 64); err == nil {
|
||||
variable.SetFloat(fn)
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("decimal value can not assign %s", rv.Kind())
|
||||
}
|
||||
case reflect.String:
|
||||
switch rv.Kind() {
|
||||
case reflect.Bool:
|
||||
if rv.Bool() {
|
||||
variable.SetString("true")
|
||||
} else {
|
||||
variable.SetString("false")
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
variable.SetString(strconv.FormatInt(rv.Int(), 10))
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
variable.SetString(strconv.FormatUint(rv.Uint(), 10))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
variable.SetString(strconv.FormatFloat(rv.Float(), 'f', -1, 64))
|
||||
case reflect.String:
|
||||
variable.SetString(rv.String())
|
||||
default:
|
||||
variable.SetString(fmt.Sprint(value))
|
||||
}
|
||||
case reflect.Interface:
|
||||
variable.Set(rv)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported kind %s", kind)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func Set(hacky any, field string, value any) (err error) {
|
||||
var (
|
||||
n int
|
||||
refField reflect.Value
|
||||
)
|
||||
refVal := reflect.ValueOf(hacky)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
refVal = reflect.Indirect(refVal)
|
||||
}
|
||||
if refVal.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("%s kind is %v", refVal.Type().String(), refField.Kind())
|
||||
}
|
||||
refField = findField(refVal, field)
|
||||
if !refField.IsValid() {
|
||||
return fmt.Errorf("%s field `%s` not found", refVal.Type(), field)
|
||||
}
|
||||
rv := reflect.ValueOf(value)
|
||||
fieldKind := refField.Kind()
|
||||
if fieldKind != reflect.Slice && fieldKind != reflect.Array && fieldKind != reflect.Map && fieldKind == rv.Kind() {
|
||||
refField.Set(rv)
|
||||
return
|
||||
}
|
||||
switch fieldKind {
|
||||
case reflect.Struct:
|
||||
if rv.Kind() != reflect.Map {
|
||||
return ErrValueAssociated
|
||||
}
|
||||
keys := rv.MapKeys()
|
||||
subVal := reflect.New(refField.Type())
|
||||
for _, key := range keys {
|
||||
pv := rv.MapIndex(key)
|
||||
if key.Kind() == reflect.String {
|
||||
if err = Set(subVal.Interface(), key.String(), pv.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
refField.Set(subVal.Elem())
|
||||
case reflect.Ptr:
|
||||
elemType := refField.Type()
|
||||
if elemType.Elem().Kind() != reflect.Struct {
|
||||
return ErrValueAssociated
|
||||
} else {
|
||||
if rv.Kind() != reflect.Map {
|
||||
return ErrValueAssociated
|
||||
}
|
||||
keys := rv.MapKeys()
|
||||
subVal := reflect.New(elemType.Elem())
|
||||
for _, key := range keys {
|
||||
pv := rv.MapIndex(key)
|
||||
if key.Kind() == reflect.String {
|
||||
if err = Set(subVal.Interface(), key.String(), pv.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
refField.Set(subVal)
|
||||
}
|
||||
case reflect.Map:
|
||||
if rv.Kind() != reflect.Map {
|
||||
return ErrValueAssociated
|
||||
}
|
||||
targetValue := reflect.MakeMap(refField.Type())
|
||||
keys := rv.MapKeys()
|
||||
for _, key := range keys {
|
||||
pv := rv.MapIndex(key)
|
||||
kVal := reflect.New(refField.Type().Key())
|
||||
eVal := reflect.New(refField.Type().Elem())
|
||||
if err = safeAssignment(kVal.Elem(), key.Interface()); err != nil {
|
||||
return ErrValueAssociated
|
||||
}
|
||||
if refField.Type().Elem().Kind() == reflect.Struct {
|
||||
if pv.Elem().Kind() != reflect.Map {
|
||||
return ErrValueAssociated
|
||||
}
|
||||
subKeys := pv.Elem().MapKeys()
|
||||
for _, subKey := range subKeys {
|
||||
subVal := pv.Elem().MapIndex(subKey)
|
||||
if subKey.Kind() == reflect.String {
|
||||
if err = Set(eVal.Interface(), subKey.String(), subVal.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
targetValue.SetMapIndex(kVal.Elem(), eVal.Elem())
|
||||
} else {
|
||||
if err = safeAssignment(eVal.Elem(), pv.Interface()); err != nil {
|
||||
return ErrValueAssociated
|
||||
}
|
||||
targetValue.SetMapIndex(kVal.Elem(), eVal.Elem())
|
||||
}
|
||||
}
|
||||
refField.Set(targetValue)
|
||||
case reflect.Array, reflect.Slice:
|
||||
n = 0
|
||||
innerType := refField.Type().Elem()
|
||||
if rv.Kind() == reflect.Array || rv.Kind() == reflect.Slice {
|
||||
if innerType.Kind() == reflect.Struct {
|
||||
sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
srcVal := rv.Index(i)
|
||||
if srcVal.Kind() != reflect.Map {
|
||||
return ErrValueAssociated
|
||||
}
|
||||
dstVal := reflect.New(innerType)
|
||||
keys := srcVal.MapKeys()
|
||||
for _, key := range keys {
|
||||
kv := srcVal.MapIndex(key)
|
||||
if key.Kind() == reflect.String {
|
||||
if err = Set(dstVal.Interface(), key.String(), kv.Interface()); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
sliceVar.Index(n).Set(dstVal.Elem())
|
||||
n++
|
||||
}
|
||||
refField.Set(sliceVar.Slice(0, n))
|
||||
} else if innerType.Kind() == reflect.Ptr {
|
||||
sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
srcVal := rv.Index(i)
|
||||
if srcVal.Kind() != reflect.Map {
|
||||
return ErrValueAssociated
|
||||
}
|
||||
dstVal := reflect.New(innerType.Elem())
|
||||
keys := srcVal.MapKeys()
|
||||
for _, key := range keys {
|
||||
kv := srcVal.MapIndex(key)
|
||||
if key.Kind() == reflect.String {
|
||||
if err = Set(dstVal.Interface(), key.String(), kv.Interface()); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
sliceVar.Index(n).Set(dstVal)
|
||||
n++
|
||||
}
|
||||
refField.Set(sliceVar.Slice(0, n))
|
||||
} else {
|
||||
sliceVar := reflect.MakeSlice(refField.Type(), rv.Len(), rv.Len())
|
||||
for i := range rv.Len() {
|
||||
srcVal := rv.Index(i)
|
||||
dstVal := reflect.New(innerType).Elem()
|
||||
if err = safeAssignment(dstVal, srcVal.Interface()); err != nil {
|
||||
return
|
||||
}
|
||||
sliceVar.Index(n).Set(dstVal)
|
||||
n++
|
||||
}
|
||||
refField.Set(sliceVar.Slice(0, n))
|
||||
}
|
||||
}
|
||||
default:
|
||||
err = safeAssignment(refField, value)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func Assign(variable reflect.Value, value any) (err error) {
|
||||
return safeAssignment(variable, value)
|
||||
}
|
||||
|
||||
func Setter[T string | int | int64 | float64 | any](hacky any, variables map[string]T) (err error) {
|
||||
for k, v := range variables {
|
||||
if err = Set(hacky, k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -12,6 +12,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.nobla.cn/golang/aeus/metadata"
|
||||
"git.nobla.cn/golang/aeus/middleware"
|
||||
"git.nobla.cn/golang/aeus/pkg/errors"
|
||||
"git.nobla.cn/golang/aeus/pkg/logger"
|
||||
netutil "git.nobla.cn/golang/aeus/pkg/net"
|
||||
|
@ -27,6 +29,11 @@ type Server struct {
|
|||
ctxMap sync.Map
|
||||
uri *url.URL
|
||||
exitFlag int32
|
||||
middleware []middleware.Middleware
|
||||
}
|
||||
|
||||
func (svr *Server) Use(middlewares ...middleware.Middleware) {
|
||||
svr.middleware = append(svr.middleware, middlewares...)
|
||||
}
|
||||
|
||||
func (svr *Server) Handle(pathname string, desc string, cb HandleFunc) {
|
||||
|
@ -90,13 +97,13 @@ func (s *Server) execute(ctx *Context, frame *Frame) (err error) {
|
|||
}
|
||||
if r, args, err = s.router.Lookup(tokens); err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
err = ctx.Error(errNotFound, fmt.Sprintf("Command %s not found", cmd))
|
||||
err = ctx.Error(errors.NotFound, fmt.Sprintf("Command %s not found", cmd))
|
||||
} else {
|
||||
err = ctx.Error(errExecuteFailed, err.Error())
|
||||
err = ctx.Error(errors.Unavailable, err.Error())
|
||||
}
|
||||
} else {
|
||||
if len(r.params) > len(args) {
|
||||
err = ctx.Error(errExecuteFailed, r.Usage())
|
||||
err = ctx.Error(errors.Unavailable, r.Usage())
|
||||
return
|
||||
}
|
||||
if len(r.params) > 0 {
|
||||
|
@ -107,7 +114,17 @@ func (s *Server) execute(ctx *Context, frame *Frame) (err error) {
|
|||
}
|
||||
ctx.setArgs(args)
|
||||
ctx.setParam(params)
|
||||
err = r.command.Handle(ctx)
|
||||
h := func(c context.Context) error {
|
||||
return r.command.Handle(ctx)
|
||||
}
|
||||
next := middleware.Chain(s.middleware...)(h)
|
||||
md := metadata.FromContext(ctx.ctx)
|
||||
md.Set(metadata.RequestPathKey, r.command.Path)
|
||||
md.Set(metadata.RequestProtocolKey, Protocol)
|
||||
md.TeeReader(&cliMetadataReader{ctx: ctx})
|
||||
md.TeeWriter(&cliMetadataWriter{ctx: ctx})
|
||||
ctx.ctx = metadata.NewContext(ctx.ctx, md)
|
||||
err = next(ctx.ctx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -16,8 +16,7 @@ var (
|
|||
)
|
||||
|
||||
const (
|
||||
errNotFound = 4004
|
||||
errExecuteFailed = 4005
|
||||
Protocol = "cli"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -84,6 +83,13 @@ type (
|
|||
ServerTime time.Time `json:"server_time"`
|
||||
RemoteAddr string `json:"remote_addr"`
|
||||
}
|
||||
|
||||
cliMetadataReader struct {
|
||||
ctx *Context
|
||||
}
|
||||
cliMetadataWriter struct {
|
||||
ctx *Context
|
||||
}
|
||||
)
|
||||
|
||||
func WithAddress(addr string) Option {
|
||||
|
@ -109,3 +115,11 @@ func WithContext(ctx context.Context) Option {
|
|||
o.context = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func (r *cliMetadataReader) Get(key string) string {
|
||||
return r.ctx.Param(key)
|
||||
}
|
||||
|
||||
func (r *cliMetadataWriter) Set(key string, value string) {
|
||||
r.ctx.SetValue(key, value)
|
||||
}
|
||||
|
|
|
@ -41,24 +41,59 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
|||
return
|
||||
}
|
||||
h := middleware.Chain(s.middlewares...)(next)
|
||||
md := make(metadata.Metadata)
|
||||
md := metadata.FromContext(ctx)
|
||||
grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx)
|
||||
if ok {
|
||||
md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata})
|
||||
}
|
||||
grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx)
|
||||
if !ok {
|
||||
grpcOutgoingMetadata = make(grpcmd.MD)
|
||||
}
|
||||
md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata})
|
||||
if !md.Has(metadata.RequestIDKey) {
|
||||
md.Set(metadata.RequestIDKey, uuid.New().String())
|
||||
}
|
||||
md.Set(metadata.RequestPathKey, info.FullMethod)
|
||||
md.Set(metadata.RequestProtocolKey, Protocol)
|
||||
if gmd, ok := grpcmd.FromIncomingContext(ctx); ok {
|
||||
for k, v := range gmd {
|
||||
if len(v) > 0 {
|
||||
md.Set(k, v[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx = metadata.MergeContext(ctx, md, true)
|
||||
ctx = metadata.NewContext(ctx, md)
|
||||
ctx = context.WithValue(ctx, requestValueContextKey{}, req)
|
||||
err = h(ctx)
|
||||
// grpcmd.AppendToOutgoingContext(ctx, grpcmd.New(metadata.FromContext(ctx)))
|
||||
// grpc.SetHeader()
|
||||
if grpcOutgoingMetadata.Len() > 0 {
|
||||
grpc.SetHeader(ctx, grpcOutgoingMetadata)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
|
||||
ctx := ss.Context()
|
||||
next := func(ctx context.Context) (err error) {
|
||||
err = handler(srv, ss)
|
||||
return
|
||||
}
|
||||
h := middleware.Chain(s.middlewares...)(next)
|
||||
md := metadata.FromContext(ctx)
|
||||
grpcIncommingMetadata, ok := grpcmd.FromIncomingContext(ctx)
|
||||
if ok {
|
||||
md.TeeReader(&grpcMetadataReader{grpcIncommingMetadata})
|
||||
}
|
||||
grpcOutgoingMetadata, ok := grpcmd.FromOutgoingContext(ctx)
|
||||
if !ok {
|
||||
grpcOutgoingMetadata = make(grpcmd.MD)
|
||||
}
|
||||
md.TeeWriter(&grpcMetadataWriter{grpcOutgoingMetadata})
|
||||
if !md.Has(metadata.RequestIDKey) {
|
||||
md.Set(metadata.RequestIDKey, uuid.New().String())
|
||||
}
|
||||
md.Set(metadata.RequestPathKey, info.FullMethod)
|
||||
md.Set(metadata.RequestProtocolKey, Protocol)
|
||||
ctx = metadata.NewContext(ctx, md)
|
||||
err = h(ctx)
|
||||
if grpcOutgoingMetadata.Len() > 0 {
|
||||
grpc.SetHeader(ctx, grpcOutgoingMetadata)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -113,6 +148,7 @@ func New(cbs ...Option) *Server {
|
|||
cb(svr.opts)
|
||||
}
|
||||
svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainUnaryInterceptor(svr.unaryServerInterceptor()))
|
||||
svr.opts.grpcOpts = append(svr.opts.grpcOpts, grpc.ChainStreamInterceptor(svr.streamServerInterceptor()))
|
||||
svr.serve = grpc.NewServer(svr.opts.grpcOpts...)
|
||||
return svr
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"git.nobla.cn/golang/aeus/pkg/logger"
|
||||
"git.nobla.cn/golang/aeus/registry"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -32,6 +33,14 @@ type (
|
|||
|
||||
ClientOption func(*clientOptions)
|
||||
|
||||
grpcMetadataReader struct {
|
||||
md metadata.MD
|
||||
}
|
||||
|
||||
grpcMetadataWriter struct {
|
||||
md metadata.MD
|
||||
}
|
||||
|
||||
requestValueContextKey struct{}
|
||||
)
|
||||
|
||||
|
@ -83,3 +92,21 @@ func GetRequestValueFromContext(ctx context.Context) any {
|
|||
}
|
||||
return ctx.Value(requestValueContextKey{})
|
||||
}
|
||||
|
||||
func (m *grpcMetadataReader) Get(key string) string {
|
||||
if m.md == nil {
|
||||
return ""
|
||||
}
|
||||
vs := m.md.Get(key)
|
||||
if len(vs) > 0 {
|
||||
return vs[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *grpcMetadataWriter) Set(key string, value string) {
|
||||
if m.md == nil {
|
||||
return
|
||||
}
|
||||
m.md.Set(key, value)
|
||||
}
|
||||
|
|
|
@ -98,18 +98,20 @@ func (s *Server) requestInterceptor() gin.HandlerFunc {
|
|||
return nil
|
||||
}
|
||||
handler := middleware.Chain(s.middlewares...)(next)
|
||||
md := make(metadata.Metadata)
|
||||
for k, v := range ginCtx.Request.Header {
|
||||
if len(v) > 0 {
|
||||
md.Set(k, v[0])
|
||||
}
|
||||
}
|
||||
md := metadata.FromContext(ctx)
|
||||
md.TeeReader(&httpMetadataReader{
|
||||
hd: ginCtx.Request.Header,
|
||||
})
|
||||
md.TeeWriter(&httpMetadataWriter{
|
||||
w: ginCtx.Writer,
|
||||
})
|
||||
if !md.Has(metadata.RequestIDKey) {
|
||||
md.Set(metadata.RequestIDKey, uuid.New().String())
|
||||
}
|
||||
md.Set(metadata.RequestProtocolKey, Protocol)
|
||||
md.Set(metadata.RequestPathKey, ginCtx.Request.URL.Path)
|
||||
ctx = metadata.MergeContext(ctx, md, true)
|
||||
ctx = metadata.NewContext(ctx, md)
|
||||
ginCtx.Request = ginCtx.Request.WithContext(ctx)
|
||||
if err := handler(ctx); err != nil {
|
||||
if se, ok := err.(*errors.Error); ok {
|
||||
ginCtx.AbortWithStatusJSON(http.StatusInternalServerError, newResponse(se.Code, se.Message, nil))
|
||||
|
|
|
@ -30,6 +30,14 @@ type (
|
|||
HandleFunc func(ctx *Context) (err error)
|
||||
|
||||
Middleware func(http.Handler) http.Handler
|
||||
|
||||
httpMetadataReader struct {
|
||||
hd http.Header
|
||||
}
|
||||
|
||||
httpMetadataWriter struct {
|
||||
w http.ResponseWriter
|
||||
}
|
||||
)
|
||||
|
||||
func WithNetwork(network string) Option {
|
||||
|
@ -85,3 +93,17 @@ func WithGinOptions(opts ...gin.OptionFunc) Option {
|
|||
o.ginOptions = opts
|
||||
}
|
||||
}
|
||||
|
||||
func (m *httpMetadataReader) Get(key string) string {
|
||||
if m.hd == nil {
|
||||
return ""
|
||||
}
|
||||
return m.hd.Get(key)
|
||||
}
|
||||
|
||||
func (m *httpMetadataWriter) Set(key string, value string) {
|
||||
if m.w == nil {
|
||||
return
|
||||
}
|
||||
m.w.Header().Set(key, value)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue