From f3b532ec67e9f9a52df4e014ff3a835f6bd2e6c6 Mon Sep 17 00:00:00 2001 From: fancl Date: Tue, 6 Jun 2023 10:59:13 +0800 Subject: [PATCH] add request lib --- entry/http/server.go | 2 +- pkg/cache/memcache.go | 2 +- pkg/request/auth.go | 27 +++++ pkg/request/client.go | 152 +++++++++++++++++++++++++++ pkg/request/request.go | 230 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 411 insertions(+), 2 deletions(-) create mode 100644 pkg/request/auth.go create mode 100644 pkg/request/client.go create mode 100644 pkg/request/request.go diff --git a/entry/http/server.go b/entry/http/server.go index bc8752e..3df9e24 100644 --- a/entry/http/server.go +++ b/entry/http/server.go @@ -64,7 +64,7 @@ func (svr *Server) Use(middleware ...Middleware) { } func (svr *Server) Any(prefix string, handle http.Handler) { - if !strings.HasSuffix(prefix, "/") { + if !strings.HasPrefix(prefix, "/") { prefix = "/" + prefix } svr.anyRequests[prefix] = handle diff --git a/pkg/cache/memcache.go b/pkg/cache/memcache.go index a22631e..a7f0c8d 100644 --- a/pkg/cache/memcache.go +++ b/pkg/cache/memcache.go @@ -28,6 +28,6 @@ func (cache *MemCache) Del(ctx context.Context, key string) { func NewMemCache() *MemCache { return &MemCache{ - engine: cache.New(time.Hour, time.Minute*90), + engine: cache.New(time.Hour, time.Minute*10), } } diff --git a/pkg/request/auth.go b/pkg/request/auth.go new file mode 100644 index 0000000..20069e7 --- /dev/null +++ b/pkg/request/auth.go @@ -0,0 +1,27 @@ +package request + +import ( + "encoding/base64" + "fmt" +) + +type Authorization interface { + Token() string +} + +type BasicAuth struct { + Username string + Password string +} + +type BearerAuth struct { + AccessToken string +} + +func (auth *BasicAuth) Token() string { + return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(auth.Username+":"+auth.Password))) +} + +func (auth *BearerAuth) Token() string { + return fmt.Sprintf("Bearer %s", auth.AccessToken) +} diff --git a/pkg/request/client.go b/pkg/request/client.go new file mode 100644 index 0000000..5d3c312 --- /dev/null +++ b/pkg/request/client.go @@ -0,0 +1,152 @@ +package request + +import ( + "bytes" + "io" + "net/http" + "net/http/cookiejar" + "strings" +) + +type ( + BeforeRequest func(req *http.Request) (err error) + AfterRequest func(req *http.Request, res *http.Response) (err error) + + Client struct { + baseUrl string + Authorization Authorization + client *http.Client + cookieJar *cookiejar.Jar + interceptorRequest []BeforeRequest + interceptorResponse []AfterRequest + } +) + +func (client *Client) stashUri(urlPath string) string { + var ( + pos int + ) + if len(urlPath) == 0 { + return client.baseUrl + } + if pos = strings.Index(urlPath, "//"); pos == -1 { + if client.baseUrl != "" { + if urlPath[0] != '/' { + urlPath = "/" + urlPath + } + return client.baseUrl + urlPath + } + } + return urlPath +} + +func (client *Client) BeforeRequest(cb BeforeRequest) *Client { + client.interceptorRequest = append(client.interceptorRequest, cb) + return client +} + +func (client *Client) AfterRequest(cb AfterRequest) *Client { + client.interceptorResponse = append(client.interceptorResponse, cb) + return client +} + +func (client *Client) SetBaseUrl(s string) *Client { + client.baseUrl = strings.TrimSuffix(s, "/") + return client +} + +func (client *Client) SetCookieJar(cookieJar *cookiejar.Jar) *Client { + client.client.Jar = cookieJar + return client +} + +func (client *Client) SetClient(httpClient *http.Client) *Client { + client.client = httpClient + if client.cookieJar != nil { + client.client.Jar = client.cookieJar + } + return client +} + +func (client *Client) SetTransport(transport http.RoundTripper) *Client { + client.client.Transport = transport + return client +} + +func (client *Client) Get(urlPath string) *Request { + return newRequest(http.MethodGet, client.stashUri(urlPath), client) +} + +func (client *Client) Put(urlPath string) *Request { + return newRequest(http.MethodPut, client.stashUri(urlPath), client) +} + +func (client *Client) Post(urlPath string) *Request { + return newRequest(http.MethodPost, client.stashUri(urlPath), client) +} + +func (client *Client) Delete(urlPath string) *Request { + return newRequest(http.MethodDelete, client.stashUri(urlPath), client) +} + +func (client *Client) execute(r *Request) (res *http.Response, err error) { + var ( + n int + buf []byte + reader io.Reader + ) + if r.contentType == "" && r.body != nil { + r.contentType = r.detectContentType(r.body) + } + if r.body != nil { + if buf, err = r.readRequestBody(r.contentType, r.body); err != nil { + return + } + reader = bytes.NewReader(buf) + } + if r.rawRequest, err = http.NewRequest(r.method, r.uri, reader); err != nil { + return + } + for k, vs := range r.header { + for _, v := range vs { + r.rawRequest.Header.Add(k, v) + } + } + if r.contentType != "" { + r.rawRequest.Header.Set("Content-Type", r.contentType) + } + if client.Authorization != nil { + r.rawRequest.Header.Set("Authorization", client.Authorization.Token()) + } + if r.context != nil { + r.rawRequest = r.rawRequest.WithContext(r.context) + } + n = len(client.interceptorRequest) + for i := n - 1; i >= 0; i-- { + if err = client.interceptorRequest[i](r.rawRequest); err != nil { + return + } + } + if r.rawResponse, err = client.client.Do(r.rawRequest); err != nil { + return nil, err + } + n = len(client.interceptorResponse) + for i := n - 1; i >= 0; i-- { + if err = client.interceptorResponse[i](r.rawRequest, r.rawResponse); err != nil { + _ = r.rawResponse.Body.Close() + return + } + } + return r.rawResponse, err +} + +func New() *Client { + client := &Client{ + client: http.DefaultClient, + interceptorRequest: make([]BeforeRequest, 0, 10), + interceptorResponse: make([]AfterRequest, 0, 10), + } + client.cookieJar, _ = cookiejar.New(nil) + client.client.Jar = client.cookieJar + return client +} diff --git a/pkg/request/request.go b/pkg/request/request.go new file mode 100644 index 0000000..253f903 --- /dev/null +++ b/pkg/request/request.go @@ -0,0 +1,230 @@ +package request + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "reflect" + "regexp" + "strings" +) + +const ( + JSON = "application/json" + XML = "application/xml" + + plainTextType = "text/plain; charset=utf-8" + jsonContentType = "application/json" + formContentType = "application/x-www-form-urlencoded" +) + +var ( + jsonCheck = regexp.MustCompile(`(?i:(application|text)/(json|.*\+json|json\-.*)(;|$))`) + xmlCheck = regexp.MustCompile(`(?i:(application|text)/(xml|.*\+xml)(;|$))`) +) + +type Request struct { + context context.Context + method string + uri string + url *url.URL + body any + query url.Values + formData url.Values + header http.Header + contentType string + authorization Authorization + client *Client + rawRequest *http.Request + rawResponse *http.Response +} + +func (r *Request) detectContentType(body interface{}) string { + contentType := plainTextType + kind := reflect.Indirect(reflect.ValueOf(body)).Type().Kind() + switch kind { + case reflect.Struct, reflect.Map: + contentType = jsonContentType + case reflect.String: + contentType = plainTextType + default: + if b, ok := body.([]byte); ok { + contentType = http.DetectContentType(b) + } else if kind == reflect.Slice { + contentType = jsonContentType + } + } + return contentType +} + +func (r *Request) readRequestBody(contentType string, body any) (buf []byte, err error) { + var ( + ok bool + s string + reader io.Reader + ) + kind := reflect.Indirect(reflect.ValueOf(body)).Type().Kind() + if reader, ok = r.body.(io.Reader); ok { + buf, err = io.ReadAll(reader) + goto __end + } + if buf, ok = r.body.([]byte); ok { + goto __end + } + if s, ok = r.body.(string); ok { + buf = []byte(s) + goto __end + } + if jsonCheck.MatchString(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) { + buf, err = json.Marshal(r.body) + goto __end + } + if xmlCheck.MatchString(contentType) && (kind == reflect.Struct) { + buf, err = xml.Marshal(r.body) + goto __end + } + err = fmt.Errorf("unmarshal content type %s", contentType) +__end: + return +} + +func (r *Request) SetContext(ctx context.Context) *Request { + r.context = ctx + return r +} + +func (r *Request) AddQuery(k, v string) *Request { + r.query.Add(k, v) + return r +} + +func (r *Request) SetQuery(vs map[string]string) *Request { + for k, v := range vs { + r.query.Set(k, v) + } + return r +} + +func (r *Request) AddFormData(k, v string) *Request { + r.contentType = formContentType + r.formData.Add(k, v) + return r +} + +func (r *Request) SetFormData(vs map[string]string) *Request { + r.contentType = formContentType + for k, v := range vs { + r.formData.Set(k, v) + } + return r +} + +func (r *Request) SetBody(v any) *Request { + r.body = v + return r +} + +func (r *Request) SetContentType(v string) *Request { + r.contentType = v + return r +} + +func (r *Request) AddHeader(k, v string) *Request { + r.header.Add(k, v) + return r +} + +func (r *Request) SetHeader(h http.Header) *Request { + r.header = h + return r +} + +func (r *Request) Do() (res *http.Response, err error) { + var s string + s = r.formData.Encode() + if len(s) > 0 { + r.body = s + } + r.url.RawQuery = r.query.Encode() + r.uri = r.url.String() + return r.client.execute(r) +} + +func (r *Request) Response(v any) (err error) { + var ( + res *http.Response + buf []byte + contentType string + ) + if res, err = r.Do(); err != nil { + return + } + defer func() { + _ = res.Body.Close() + }() + if res.StatusCode/100 != 2 { + if buf, err = io.ReadAll(res.Body); err == nil && len(buf) > 0 { + err = fmt.Errorf("http response %s(%d): %s", res.Status, res.StatusCode, string(buf)) + } else { + err = fmt.Errorf("http response %d: %s", res.StatusCode, res.Status) + } + return + } + contentType = strings.ToLower(res.Header.Get("Content-Type")) + extName := path.Ext(r.rawRequest.URL.String()) + if strings.Contains(contentType, JSON) || extName == ".json" { + err = json.NewDecoder(res.Body).Decode(v) + } else if strings.Contains(contentType, XML) || extName == ".xml" { + err = xml.NewDecoder(res.Body).Decode(v) + } else { + err = fmt.Errorf("unsupported content type: %s", contentType) + } + return +} + +func (r *Request) Download(s string) (err error) { + var ( + fp *os.File + res *http.Response + ) + if res, err = r.Do(); err != nil { + return + } + defer func() { + _ = res.Body.Close() + }() + if fp, err = os.OpenFile(s, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil { + return + } + defer func() { + _ = fp.Close() + }() + _, err = io.Copy(fp, res.Body) + return +} + +func newRequest(method string, uri string, client *Client) *Request { + var ( + err error + ) + r := &Request{ + context: context.Background(), + method: method, + uri: uri, + header: make(http.Header), + formData: make(url.Values), + client: client, + } + if r.url, err = url.Parse(uri); err == nil { + r.query = r.url.Query() + } else { + r.query = make(url.Values) + } + return r +}