From 4adb23c247c164e8bda7d8bddd5d0c0437415762 Mon Sep 17 00:00:00 2001 From: joelei Date: Sat, 6 Dec 2025 17:36:29 +0800 Subject: [PATCH] add client --- rest/client.go | 558 ++++++++++++++++++++++++++---------- rest/client_test.go | 422 ++++++++++++++++++++++----- rest/transport/curl.go | 102 ++++--- rest/transport/transport.go | 57 +--- 4 files changed, 829 insertions(+), 310 deletions(-) diff --git a/rest/client.go b/rest/client.go index 06febac..2361045 100644 --- a/rest/client.go +++ b/rest/client.go @@ -1,275 +1,531 @@ +/* + * TencentBlueKing is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + package rest import ( "bytes" "context" "encoding/json/v2" - "encoding/xml" + "errors" "fmt" "io" "maps" + "net" "net/http" "net/url" + "path" "reflect" "strings" "time" - "golang.org/x/net/html" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + + "git.ifooth.com/common/pkg/rest/transport" + "git.ifooth.com/common/pkg/version" ) -type Client[T any] struct { - httpClient *http.Client - commonReq *http.Request - commonFormData url.Values +var ( + // ErrUnsupported is an error to check supported type or content-type + ErrUnsupported = errors.New("unsupported") + userAgent = fmt.Sprintf("pkg/%s", version.Version) +) + +// defaultTransport default transport, same as http.DefaultTransport but change MaxIdleConns / MaxIdleConnsPerHost +var defaultTransport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 200, // 最大空闲连接数 + MaxIdleConnsPerHost: 20, // 每个主机的最大空闲连接数 + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, +} + +// HTTPClient a http.Client or rest.Client Do interface +type HTTPClient interface { + Do(*http.Request) (*http.Response, error) +} + +// Client is a HTTP client with common url, header, params and transports. +type Client struct { + baseTransport http.RoundTripper + transports []func(http.RoundTripper) http.RoundTripper + baseURL *url.URL + timeout time.Duration + commonHeader http.Header commonQueryParams url.Values - commonCookies []*http.Cookie + err error } -func NewClient[T any]() *Client[T] { - r := &Client[T]{ - commonReq: &http.Request{Header: make(http.Header)}, - httpClient: &http.Client{}, - commonQueryParams: url.Values{}, +// NewClient creates a new client +func NewClient() *Client { + commonHeader := make(http.Header) + commonHeader.Set("User-Agent", userAgent) + + // 默认统一添加trace + baseTransport := otelhttp.NewTransport(defaultTransport) + + c := &Client{ + baseTransport: baseTransport, + commonQueryParams: make(url.Values), + commonHeader: commonHeader, } - return r -} -func (c *Client[T]) SetHeader(header, value string) *Client[T] { - c.commonReq.Header.Set(header, value) return c } -func (c *Client[T]) SetQueryParam(param, value string) *Client[T] { +func (c *Client) parseBaseURL(baseURL string) (*url.URL, error) { + if baseURL == "" { + return nil, fmt.Errorf("baseURL is required") + } + + u, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("baseURL %s not valid, err: %w", baseURL, err) + } + + if u.Host == "" || u.Scheme == "" { + return nil, fmt.Errorf("baseURL %s not valid, host/schema is required", baseURL) + } + + return u, nil +} + +// SetBaseURL sets the base URL for the client. +func (c *Client) SetBaseURL(baseURL string) *Client { + c.baseURL, c.err = c.parseBaseURL(baseURL) + return c +} + +// SetTimeout set a timeout for the client. +func (c *Client) SetTimeout(timeout time.Duration) *Client { + c.timeout = timeout + return c +} + +// SetHeader set the common header for the client. +func (c *Client) SetHeader(header, value string) *Client { + c.commonHeader.Set(header, value) + return c +} + +// SetQueryParam set the common query param for the client. +func (c *Client) SetQueryParam(param, value string) *Client { c.commonQueryParams.Set(param, value) return c } -func (c *Client[T]) SetTransport(transport http.RoundTripper) *Client[T] { +// SetTransport set the base transport for the client. +func (c *Client) SetTransport(transport http.RoundTripper) *Client { if transport != nil { - c.httpClient.Transport = transport + c.baseTransport = transport } return c } -func (c *Client[T]) R() *Request[T] { - req := &http.Request{ - Method: http.MethodGet, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - req = req.WithContext(context.Background()) +// GetTransport return the http.RoundTripper by chaining the transports for client. +func (c *Client) GetTransport() http.RoundTripper { + // freeze transports + next := c.chain() - r := &Request[T]{ - req: req, - client: c, - queryParams: url.Values{}, + rt := func(req *http.Request) (*http.Response, error) { + newReq, err := c.finalReq(req) + if err != nil { + return nil, err + } + + return next.RoundTrip(newReq) } - return r + + return transport.RoundTripFunc(rt) } -func (c *Client[T]) execute(req *http.Request) (*http.Response, error) { - // change header - maps.Copy(req.Header, c.commonReq.Header) +// With adds inline http.RoundTripper for an base RoundTripper for the client. +func (c *Client) With(rt ...func(http.RoundTripper) http.RoundTripper) *Client { + c.transports = append(c.transports, rt...) + return c +} - // change query - query := req.URL.Query() +func (c *Client) chain() http.RoundTripper { + rt := c.baseTransport + + for i := len(c.transports) - 1; i >= 0; i-- { + rt = c.transports[i](rt) + } + + return rt +} + +// mergeURL return the merged url with common header and params +func (c *Client) mergeURL(u *url.URL) (*url.URL, error) { + query := make(url.Values) + + if c.baseURL != nil { + if u.Host != "" || u.Scheme != "" { + return nil, fmt.Errorf("baseURL and request url host/schema can not be set at the same time") + } + + u.Scheme = c.baseURL.Scheme + u.Host = c.baseURL.Host + u.Path = path.Join(c.baseURL.Path, u.Path) + maps.Copy(query, c.baseURL.Query()) + } + + // change the url query maps.Copy(query, c.commonQueryParams) - req.URL.RawQuery = query.Encode() + maps.Copy(query, u.Query()) + u.RawQuery = query.Encode() - resp, err := c.httpClient.Do(req) + return u, nil +} + +func (c *Client) finalReq(req *http.Request) (*http.Request, error) { + if c.err != nil { + return nil, c.err + } + + newReq := req.Clone(req.Context()) + + u, err := c.mergeURL(newReq.URL) if err != nil { return nil, err } - return resp, nil + newReq.URL = u + + // change the req header + header := make(http.Header) + maps.Copy(header, c.commonHeader) + maps.Copy(header, newReq.Header) + newReq.Header = header + + return newReq, nil } +// Do implement HTTPClient interface, same as the http.Client +func (c *Client) Do(req *http.Request) (*http.Response, error) { + newReq, err := c.finalReq(req) + if err != nil { + return nil, err + } + + httpClient := &http.Client{Transport: c.chain(), Timeout: c.timeout} + return httpClient.Do(newReq) +} + +// Request is a generic HTTP request with response type T type Request[T any] struct { + client HTTPClient req *http.Request - client *Client[T] + transports []func(http.RoundTripper) http.RoundTripper queryParams url.Values + body any + timeout time.Duration } -func (r *Request[T]) Clone() *Request[T] { - newReq := r.client.R() - newReq.req = r.req.Clone(r.req.Context()) - newReq.client = r.client - newReq.queryParams = r.queryParams - return newReq +// R create a new request with the default client and response type T +func R[T any]() *Request[T] { + req := &http.Request{Header: make(http.Header)} + req = req.WithContext(context.Background()) + + r := &Request[T]{ + client: NewClient(), // default client and transport + req: req, + queryParams: make(url.Values), + } + + return r } +// SetClient set the rest.Client or http.Client for the request. +func (r *Request[T]) SetClient(client HTTPClient) *Request[T] { + r.client = client + return r +} + +// With adds inline http.RoundTripper for the request. +func (r *Request[T]) With(rt ...func(http.RoundTripper) http.RoundTripper) *Request[T] { + r.transports = append(r.transports, rt...) + return r +} + +// SetTimeout set a timeout for the current request. +func (r *Request[T]) SetTimeout(timeout time.Duration) *Request[T] { + r.timeout = timeout + return r +} + +// SetHeader set the request header for the request. func (r *Request[T]) SetHeader(header, value string) *Request[T] { r.req.Header.Set(header, value) return r } +// SetQueryParam set the request header for the request. func (r *Request[T]) SetQueryParam(param, value string) *Request[T] { r.queryParams.Set(param, value) return r } +// SetContext set the request header for the request. func (r *Request[T]) SetContext(ctx context.Context) *Request[T] { r.req = r.req.WithContext(ctx) return r } -func (r *Request[T]) Get(url string) (*T, error) { - return r.Execute(http.MethodGet, url) +// SetBody set the request body for the request. +func (r *Request[T]) SetBody(body any) *Request[T] { + r.body = body + return r } -func (r *Request[T]) Execute(method string, url string) (*T, error) { - out := new(T) - rt := reflect.TypeFor[T]() +// Get execute a get request +func (r *Request[T]) Get(url string, args ...any) (*T, error) { + return r.Execute(http.MethodGet, url, args...) +} - newReq, err := http.NewRequestWithContext(r.req.Context(), method, url, r.req.Body) +// Post execute a post request +func (r *Request[T]) Post(url string, args ...any) (*T, error) { + return r.Execute(http.MethodPost, url, args...) +} + +// Put execute a put request +func (r *Request[T]) Put(url string, args ...any) (*T, error) { + return r.Execute(http.MethodPut, url, args...) +} + +// Patch execute a patch request +func (r *Request[T]) Patch(url string, args ...any) (*T, error) { + return r.Execute(http.MethodPatch, url, args...) +} + +// Delete execute a delete request +func (r *Request[T]) Delete(url string, args ...any) (*T, error) { + return r.Execute(http.MethodDelete, url, args...) +} + +// getBody get body reader or marshal reader +func (r *Request[T]) getBody() (io.Reader, error) { + if r.body == nil { + return nil, nil + } + + switch body := r.body.(type) { + case []byte: + return bytes.NewReader(body), nil + case string: + return strings.NewReader(body), nil + case io.Reader: + return body, nil + default: + // continue check kind below + } + + rv := reflect.ValueOf(r.body) + if rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + + switch rv.Kind() { + case reflect.Map, reflect.Slice, reflect.Struct: + body, err := json.Marshal(r.body) + if err != nil { + return nil, err + } + r.SetHeader("Content-Type", "application/json") + + return bytes.NewReader(body), nil + default: + // continue check kind below + } + + return nil, fmt.Errorf("%w body type: %T", ErrUnsupported, r.body) +} + +// getURL get the final url for the request. +func (r *Request[T]) getURL(u string) (string, error) { + newURL, err := url.Parse(u) + if err != nil { + return "", err + } + + // change the url query + query := newURL.Query() + maps.Copy(query, r.queryParams) + + newURL.RawQuery = query.Encode() + + return newURL.String(), nil +} + +// withStdHTTPClient chain with http Client and request transports +func (r *Request[T]) withStdHTTPClient(c *http.Client) HTTPClient { + rt := c.Transport + if rt == nil { + rt = http.DefaultTransport + } + + for i := len(r.transports) - 1; i >= 0; i-- { + rt = r.transports[i](rt) + } + + newC := *c + newC.Transport = rt + return &newC +} + +// withRestClient chain with rest client and request transports +func (r *Request[T]) withRestClient(c *Client) HTTPClient { + newTransports := make([]func(http.RoundTripper) http.RoundTripper, len(c.transports)) + copy(newTransports, c.transports) + + newC := *c + newC.transports = newTransports + + // append request transports + newC.With(r.transports...) + + return &newC +} + +func (r *Request[T]) getHTTPClient() HTTPClient { + switch c := r.client.(type) { + case *Client: + return r.withRestClient(c) + case *http.Client: + return r.withStdHTTPClient(c) + default: + // will ignore request transport + return c + } +} + +// Execute execute a request with method and url +func (r *Request[T]) Execute(method string, url string, args ...any) (*T, error) { + body, err := r.getBody() if err != nil { return nil, err } - // change a req header + // fix non-constant format string in call + if len(args) > 0 { + url = fmt.Sprintf(url, args...) + } + + finalURL, err := r.getURL(url) + if err != nil { + return nil, err + } + + newReq, err := http.NewRequestWithContext(r.req.Context(), method, finalURL, body) + if err != nil { + return nil, err + } + + // change the req header maps.Copy(newReq.Header, r.req.Header) - // change a req query - query := newReq.URL.Query() - maps.Copy(query, r.queryParams) - newReq.URL.RawQuery = query.Encode() + if r.timeout > 0 { + ctx, cancel := context.WithTimeout(newReq.Context(), r.timeout) + defer cancel() - resp, err := r.client.execute(newReq) + newReq = newReq.WithContext(ctx) + } + + client := r.getHTTPClient() + resp, err := client.Do(newReq) if err != nil { return nil, err } + return r.handleResp(resp) +} + +func (r *Request[T]) handleResp(resp *http.Response) (*T, error) { + out := new(T) + // io.ReadCloser 需要自己调用Close switch any(out).(type) { case *http.Response: return any(resp).(*T), nil case *io.ReadCloser: return any(&resp.Body).(*T), nil - case *Response: // for addition trace info, reuse body an so on - r := Response{resp: resp, trace: &TraceInfo{}} + case *Response: // for additional trace info, reuse body and so on + r := Response{Response: resp, receivedAt: time.Now()} return any(&r).(*T), nil + default: + // continue check kind below } // body 自动close defer resp.Body.Close() + // 需要读取整个body switch any(out).(type) { case *[]byte: - // 需要读取整个body - body, rErr := io.ReadAll(resp.Body) - if rErr != nil { - return nil, rErr + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err } return any(&body).(*T), nil case *string: - // 需要读取整个body - body, rErr := io.ReadAll(resp.Body) - if rErr != nil { - return nil, rErr + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err } v := string(body) return any(&v).(*T), nil case *io.Reader: - // 需要读取整个body - body, rErr := io.ReadAll(resp.Body) - if rErr != nil { - return nil, rErr + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err } v := io.Reader(bytes.NewBuffer(body)) return any(&v).(*T), nil - case *html.Node: - node, rErr := html.Parse(resp.Body) - if rErr != nil { - return nil, rErr - } - return any(node).(*T), nil + default: + // continue check kind below } - if rt.Kind() != reflect.Struct { - return nil, fmt.Errorf("unsupported type: %v", rt) - } + rt := reflect.TypeFor[T]() - // 按类型反序列化 - contentType := resp.Header.Get("Content-Type") - if strings.HasPrefix(contentType, "application/json") { - if err = json.UnmarshalRead(resp.Body, out); err != nil { - return nil, err - } - return out, nil - } - if strings.HasPrefix(contentType, "application/xml") { - // 需要读取整个body - body, rErr := io.ReadAll(resp.Body) - if rErr != nil { - return nil, rErr - } - if err = xml.Unmarshal(body, out); err != nil { + // 其他只处理json, 自定义序列化请使用[]byte或者io.Reader + switch rt.Kind() { + case reflect.Map, reflect.Slice, reflect.Struct: + if err := json.UnmarshalRead(resp.Body, out); err != nil { return nil, err } return out, nil + default: + // continue check kind below } - return nil, fmt.Errorf("unsupported content-type: %s", contentType) -} - -// TraceInfo struct is used to provide request trace info such as DNS lookup -// duration, Connection obtain duration, Server processing duration, etc. -type TraceInfo struct { - // DNSLookup is the duration that transport took to perform - // DNS lookup. - DNSLookup time.Duration `json:"dns_lookup_time"` - - // ConnTime is the duration it took to obtain a successful connection. - ConnTime time.Duration `json:"connection_time"` - - // TCPConnTime is the duration it took to obtain the TCP connection. - TCPConnTime time.Duration `json:"tcp_connection_time"` - - // TLSHandshake is the duration of the TLS handshake. - TLSHandshake time.Duration `json:"tls_handshake_time"` - - // ServerTime is the server's duration for responding to the first byte. - ServerTime time.Duration `json:"server_time"` - - // ResponseTime is the duration since the first response byte from the server to - // request completion. - ResponseTime time.Duration `json:"response_time"` - - // TotalTime is the duration of the total time request taken end-to-end. - TotalTime time.Duration `json:"total_time"` - - // IsConnReused is whether this connection has been previously - // used for another HTTP request. - IsConnReused bool `json:"is_connection_reused"` - - // IsConnWasIdle is whether this connection was obtained from an - // idle pool. - IsConnWasIdle bool `json:"is_connection_was_idle"` - - // ConnIdleTime is the duration how long the connection that was previously - // idle, if IsConnWasIdle is true. - ConnIdleTime time.Duration `json:"connection_idle_time"` - - // RequestAttempt is to represent the request attempt made during a Resty - // request execution flow, including retry count. - RequestAttempt int `json:"request_attempt"` - - // RemoteAddr returns the remote network address. - RemoteAddr string `json:"remote_address"` + return nil, fmt.Errorf("%w type: %v", ErrUnsupported, rt) } +// Response struct holds response values of executed requests. type Response struct { - trace *TraceInfo - resp *http.Response + *http.Response receivedAt time.Time } -func (r *Response) Duration() time.Duration { - return r.trace.TotalTime -} - // ReceivedAt method returns the time we received a response from the server for the request. func (r *Response) ReceivedAt() time.Time { return r.receivedAt diff --git a/rest/client_test.go b/rest/client_test.go index a3b7fdf..d94cac7 100644 --- a/rest/client_test.go +++ b/rest/client_test.go @@ -1,116 +1,406 @@ +/* + * TencentBlueKing is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + package rest import ( - "encoding/json" + "bytes" + "context" + "encoding/json/v2" "fmt" "io" "net/http" + "net/http/httptest" + "net/http/httptrace" + "strconv" + "strings" + "sync" + "sync/atomic" "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "git.ifooth.com/common/pkg/rest/transport" - "github.com/stretchr/testify/assert" - "golang.org/x/net/html" ) -type Resp1 struct { +func testServer() *httptest.Server { + r := http.NewServeMux() + // 正常数据返回 + r.HandleFunc("GET /uuid", func(w http.ResponseWriter, r *http.Request) { + data := fmt.Sprintf(`{"uuid": "%s"}`, uuid.NewString()) + + // 原样返回query + for k, v := range r.URL.Query() { + w.Header().Set("query-"+k, strings.Join(v, ",")) + } + + w.Header().Set("content-type", "application/json") + w.Write([]byte(data)) + }) + + // 状态码测试 + r.HandleFunc("/status/{code}", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "text/html; charset=utf-8") + + code, err := strconv.ParseInt(r.PathValue("code"), 10, 64) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + w.WriteHeader(int(code)) + w.Write([]byte(r.PathValue("code"))) + + }) + + r.HandleFunc("/bkResult/ok", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + + w.WriteHeader(int(400)) + w.Write([]byte(`{"code": 0, "data": {"name": "ko"}}`)) + }) + + r.HandleFunc("/bkResult/failed", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + + w.WriteHeader(int(400)) + w.Write([]byte(`{"code": 100400, "data": null}`)) + + }) + + // 原样返回任何body数据 + r.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "text/html; charset=utf-8") + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // 原样返回query + for k, v := range r.URL.Query() { + w.Header().Set("query-"+k, strings.Join(v, ",")) + } + + w.Write(body) + }) + + ts := httptest.NewServer(r) + return ts +} + +type testResp struct { UUID string `json:"uuid"` } -type BkResult[T any] struct { - Code int - Data *T +type BkResp[T any] struct { + Code int `json:"code"` + Data *T `json:"data"` } -func getTitle(n *html.Node) string { - if n.Type == html.ElementNode && n.Data == "title" { - if n.FirstChild != nil { - return n.FirstChild.Data - } - } - for c := n.FirstChild; c != nil; c = c.NextSibling { - if title := getTitle(c); title != "" { - return title - } - } - return "" +type bkResult struct { + Name string `json:"name"` } -func TestGet(t *testing.T) { - resp, err := http.Get("http://127.0.0.1:38033/uuid") +type BkError struct { + Code int `json:"code"` +} + +func CustomC() *Client { + c := NewClient(). + SetHeader("abc", "abc"). + SetQueryParam("name", "ko"). + With(transport.CurlLog()) + + return c +} + +func TestClientUseAsDo(t *testing.T) { + svr := testServer() + defer svr.Close() + + req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata"))) + assert.NoError(t, err) + + // 用法1: Client使用Do处理请求 + client := NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set") + + resp, err := client.Do(req) + assert.NoError(t, err) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + + assert.Equal(t, []byte("testdata"), body) + assert.Equal(t, resp.Header.Get("query-ko"), "set") +} + +func TestClientUseAsTransport(t *testing.T) { + svr := testServer() + defer svr.Close() + + req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata"))) + assert.NoError(t, err) + + // 用法2: Client作为第三方http库的transport处理请求 + client := &http.Client{} + client.Transport = NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set").GetTransport() + + resp, err := client.Do(req) + assert.NoError(t, err) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + + assert.Equal(t, []byte("testdata"), body) + assert.Equal(t, resp.Header.Get("query-ko"), "set") +} + +func TestRequestUseHTTPClient(t *testing.T) { + svr := testServer() + defer svr.Close() + + resp, err := http.Get(svr.URL + "/uuid") assert.NoError(t, err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) assert.NoError(t, err) - fmt.Println(string(body)) - - r := new(Resp1) + r := new(testResp) err = json.Unmarshal(body, r) assert.NoError(t, err) - assert.True(t, len(r.UUID) == 36) + assert.Equal(t, len(r.UUID), 36) + + // 用法3: Request[T]使用原生HTTP Client + r1, err := R[string]().SetClient(http.DefaultClient).SetBody("testdata").Post(svr.URL + "/echo") + assert.NoError(t, err) + + assert.Equal(t, "testdata", *r1) + + r2, err := R[http.Response]().SetClient(http.DefaultClient).With(transport.CurlLog()).SetQueryParam("ko", "hi").SetBody("testdata").Post(svr.URL + "/echo") + assert.NoError(t, err) + body, err = io.ReadAll(r2.Body) + assert.NoError(t, err) + + assert.Equal(t, "testdata", string(body)) + assert.Equal(t, "hi", r2.Header.Get("query-ko")) } -func CustomC[T any]() *Client[T] { - c := NewClient[T]() - c.SetHeader("abc", "abc") - c.SetQueryParam("name", "ko") - return c -} +func TestRequestUseRestClient(t *testing.T) { + svr := testServer() + defer svr.Close() -func TestGetRest(t *testing.T) { - r, err := NewClient[Resp1]().R().Get("http://127.0.0.1:38033/uuid") + // 用法4: Request[T]使用rest.Client + r, err := R[testResp]().Get(svr.URL + "/uuid") + assert.NoError(t, err) + assert.Equal(t, len(r.UUID), 36) + + rb, err := R[[]byte]().Get(svr.URL + "/uuid") + require.NoError(t, err) + assert.Equal(t, len(*rb), 48) + + r3, err := R[io.Reader]().Get(svr.URL + "/uuid") assert.NoError(t, err) - fmt.Println(r) - assert.True(t, len(r.UUID) == 36) - - rb, err := NewClient[[]byte]().R().Get("http://127.0.0.1:38033/uuid") - assert.NoError(t, err) - - fmt.Println("abc", string(*rb)) - - r3, err := NewClient[io.Reader]().R().Get("http://127.0.0.1:38033/uuid") - assert.NoError(t, err) r31, err := io.ReadAll(*r3) assert.NoError(t, err) + assert.Equal(t, len(r31), 48) - fmt.Println("r31", string(r31)) - - r4, err := NewClient[io.ReadCloser]().R().Get("http://127.0.0.1:38033/uuid") + r4, err := R[io.ReadCloser]().Get(svr.URL + "/uuid") assert.NoError(t, err) r41, err := io.ReadAll(*r4) assert.NoError(t, err) (*r4).Close() + assert.Equal(t, len(r41), 48) - fmt.Println("r41", string(r41)) - - ts := transport.Chain(http.DefaultTransport, - transport.RequestID, - transport.CurlLogTransport(map[string]string{}), - ) - - r5, err := CustomC[http.Response]().SetTransport(ts).R().SetHeader("ko", "abc"). + r5, err := R[http.Response]().SetClient(CustomC()). + SetContext(t.Context()). + SetHeader("ko", "abc"). SetQueryParam("age", "12"). - Get("http://127.0.0.1:38033/uuid?dd=abc") + Get(svr.URL + "/uuid?dd=abc") assert.NoError(t, err) r51, err := io.ReadAll(r5.Body) assert.NoError(t, err) r5.Body.Close() - - fmt.Println("r51", string(r51)) - - r6, err := NewClient[html.Node]().R().Get("http://127.0.0.1:38033") - assert.NoError(t, err) - title := getTitle(r6) - assert.Equal(t, "httpbin.org", title) - + assert.Equal(t, len(r51), 48) + assert.Equal(t, "12", r5.Header.Get("query-age")) } -func TestGetRestError(t *testing.T) { - r, err := NewClient[int]().R().Get("http://127.0.0.1:38033/uuid") +func TestClientReuseConn(t *testing.T) { + svr := testServer() + defer svr.Close() + + t.Logf("svrAddr: %s", svr.URL) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*2) + defer cancel() + + var ( + wg sync.WaitGroup + reqCount atomic.Int64 + newConn atomic.Int64 + ) + + ctx = httptrace.WithClientTrace( + ctx, + &httptrace.ClientTrace{ConnectStart: func(_ string, _ string) { + newConn.Add(1) + }}, + ) + + concurrent := int64(10) + for range concurrent { + wg.Go(func() { + for { + reqCount.Add(1) + + r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid") + if err != nil { + assert.ErrorIs(t, err, context.DeadlineExceeded) + return + } + + assert.True(t, len(*r) > 10) + time.Sleep(time.Millisecond * 10) + } + }) + } + + wg.Wait() + + t.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load()) + assert.Less(t, newConn.Load(), concurrent*2) +} + +func BenchmarkClientReuseConn(b *testing.B) { + svr := testServer() + defer svr.Close() + + b.Logf("addr: %s", svr.URL) + + var ( + reqCount atomic.Int64 + newConn atomic.Int64 + ) + + ctx := httptrace.WithClientTrace( + b.Context(), + &httptrace.ClientTrace{ConnectStart: func(_ string, _ string) { + newConn.Add(1) + }}, + ) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + reqCount.Add(1) + + r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid") + if err != nil { + assert.ErrorIs(b, err, context.Canceled) + return + } + + assert.NoError(b, err) + assert.True(b, len(*r) > 10) + time.Sleep(time.Millisecond * 10) + } + }) + + b.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load()) + assert.Less(b, newConn.Load(), int64(20)) +} + +func TestClientBody(t *testing.T) { + svr := testServer() + defer svr.Close() + + r, err := R[string]().SetBody("test").Get(svr.URL + "/echo") + require.NoError(t, err) + assert.Equal(t, *r, "test") + + r, err = R[string]().SetBody([]byte("test")).Get(svr.URL + "/echo") + require.NoError(t, err) + assert.Equal(t, *r, "test") + + r, err = R[string]().SetBody(strings.NewReader("test")).Get(svr.URL + "/echo") + require.NoError(t, err) + assert.Equal(t, *r, "test") + + r, err = R[string]().SetBody(map[string]string{"name": "test"}).Get(svr.URL + "/echo") + require.NoError(t, err) + assert.Equal(t, *r, `{"name":"test"}`) + + r, err = R[string]().SetBody([]int{1, 2, 3}).Get(svr.URL + "/echo") + require.NoError(t, err) + assert.Equal(t, *r, `[1,2,3]`) + + r, err = R[string]().SetBody(testResp{UUID: "test"}).Get(svr.URL + "/echo") + require.NoError(t, err) + assert.Equal(t, *r, `{"uuid":"test"}`) + + r, err = R[string]().SetBody(&testResp{UUID: "test"}).Get(svr.URL + "/echo") + require.NoError(t, err) + assert.Equal(t, *r, `{"uuid":"test"}`) + + r, err = R[string]().SetBody(nil).Get(svr.URL + "/echo") + require.NoError(t, err) + assert.Equal(t, *r, "") + + r, err = R[string]().SetBody(fmt.Errorf("test")).Get(svr.URL + "/echo") assert.Error(t, err) - fmt.Println(err) + assert.Nil(t, r) + + r, err = R[string]().SetBody(123).Get(svr.URL + "/echo") + assert.ErrorIs(t, err, ErrUnsupported) assert.Nil(t, r) } + +func TestClientURL(t *testing.T) { + svr := testServer() + defer svr.Close() + + r, err := R[string]().SetClient(NewClient().SetBaseURL("")). + SetBody("test").Get("/echo") + assert.Error(t, err) + assert.Equal(t, "baseURL is required", err.Error()) + assert.Nil(t, r) + + r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL)). + SetBody("test").Get("/echo") + require.NoError(t, err) + assert.Equal(t, *r, "test") + + // send as json + r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL).With(transport.CurlLog())). + SetBody(map[string]string{"name": "ko"}).Get("/echo") + require.NoError(t, err) + assert.Equal(t, *r, "{\"name\":\"ko\"}") + + r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc")). + SetBody("test").Get(svr.URL + "/echo") + assert.Error(t, err) + assert.Nil(t, r) + + r2, err := R[http.Response]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc?name=ko").With(transport.CurlLog())). + SetBody("test").Get("/echo?abc=ddd") + require.NoError(t, err) + assert.Equal(t, r2.StatusCode, http.StatusNotFound) +} diff --git a/rest/transport/curl.go b/rest/transport/curl.go index 70d292c..07b387c 100644 --- a/rest/transport/curl.go +++ b/rest/transport/curl.go @@ -1,3 +1,19 @@ +/* + * TencentBlueKing is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + package transport import ( @@ -6,19 +22,34 @@ import ( "io" "log/slog" "net/http" + "strings" "time" ) +// CurlLog make a http curl log transport +func CurlLog(maskKeys ...string) func(http.RoundTripper) http.RoundTripper { + keys := make(map[string]struct{}, len(maskKeys)) + for _, key := range maskKeys { + keys[key] = struct{}{} + } + + return func(base http.RoundTripper) http.RoundTripper { + t := &curlLogTransport{baseTransport: base, maskKeys: keys} + return t + } +} + // reqToCurl curl 格式的请求日志 func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) { + headers := strings.Builder{} + // 过滤掉敏感信息, header 和 query - headers := "" for key, values := range r.Header { for _, value := range values { if _, ok := c.maskKeys[key]; ok { value = "***" } - headers += fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value)) + headers.WriteString(fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value))) } } @@ -31,7 +62,7 @@ func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) { } rawURL.RawQuery = queryValue.Encode() - reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers) + reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers.String()) if r.Body != nil { bodyBytes, err := io.ReadAll(r.Body) if err != nil { @@ -52,7 +83,7 @@ func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) { } // respToCurl 返回日志 -func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) { +func (c *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) { var ( bodyBytes []byte err error @@ -68,59 +99,50 @@ func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string } if len(bodyBytes) > 1024 { - respMsg := fmt.Sprintf("[%s] %s %s...(Total %s)", - resp.Status, time.Since(st), bodyBytes[:1024], len(bodyBytes)) + respMsg := fmt.Sprintf("[%d] %s %s...(Total %dB)", resp.StatusCode, time.Since(st), bodyBytes[:1024], + len(bodyBytes)) return respMsg, nil } if len(bodyBytes) > 0 { - respMsg := fmt.Sprintf("[%s] %s %s", resp.Status, time.Since(st), bodyBytes) + respMsg := fmt.Sprintf("[%d] %s %s", resp.StatusCode, time.Since(st), bodyBytes) return respMsg, nil } - respMsg := fmt.Sprintf("[%s] %s", resp.Status, time.Since(st)) + respMsg := fmt.Sprintf("[%d] %s", resp.StatusCode, time.Since(st)) return respMsg, nil } // curlLogTransport print curl log transport type curlLogTransport struct { - maskKeys map[string]string + baseTransport http.RoundTripper + maskKeys map[string]struct{} } // RoundTrip curlLog Transport -func (t *curlLogTransport) Transport(next http.RoundTripper) http.RoundTripper { - f := func(req *http.Request) (*http.Response, error) { - logger := slog.With("request_id", req.Header.Get("X-Request-Id")) - st := time.Now() +func (c *curlLogTransport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + st := time.Now() - // 记录请求 - rbody, err := t.reqToCurl(req) - if err != nil { - return nil, err - } - logger.Info("curl", "REQ", rbody) - - resp, err := next.RoundTrip(req) - if err != nil { - logger.Error("curl", "RESP: [err]", err) - return nil, err - } - - // 记录返回 - respBody, err := t.respToCurl(resp, st) - if err != nil { - return nil, err - } - logger.Info("curl", "RESP", respBody) - - return resp, nil + // 记录请求 + rbody, err := c.reqToCurl(req) + if err != nil { + return nil, err } - return RoundTripFunc(f) + slog.InfoContext(ctx, "curl REQ: "+rbody) -} + resp, err := c.baseTransport.RoundTrip(req) + if err != nil { + slog.ErrorContext(ctx, "curl RESP: [err] "+err.Error()) + return nil, err + } -// CurlLogTransport make a new curl log transport, default transport can be nil -func CurlLogTransport(maskkeys map[string]string) func(http.RoundTripper) http.RoundTripper { - t := curlLogTransport{map[string]string{}} - return t.Transport + // 记录返回 + respBody, err := c.respToCurl(resp, st) + if err != nil { + return nil, err + } + slog.InfoContext(ctx, "curl RESP: "+respBody) + + return resp, nil } diff --git a/rest/transport/transport.go b/rest/transport/transport.go index aace24c..523d220 100644 --- a/rest/transport/transport.go +++ b/rest/transport/transport.go @@ -1,62 +1,13 @@ +// Package transport provides common HTTP client middleware(aka transport) components. package transport import "net/http" -// RoundTripFunc, similar to http.HandlerFunc, is an adapter +// RoundTripFunc similar to http.HandlerFunc, is an adapter // to allow the use of ordinary functions as http.RoundTrippers. -type RoundTripFunc func(r *http.Request) (*http.Response, error) +type RoundTripFunc func(*http.Request) (*http.Response, error) +// RoundTrip ... func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } - -// Chain wraps given base RoundTripper, which is used to make HTTP requests -// (e.g. http.DefaultTransport) with RoundTripper middlewares. -func Chain(base http.RoundTripper, mw ...func(http.RoundTripper) http.RoundTripper) *chain { - if base == nil { - base = http.DefaultTransport - } - - // Filter out nil transports. - mws := []func(http.RoundTripper) http.RoundTripper{} - for _, fn := range mw { - if fn != nil { - mws = append(mws, fn) - } - } - - if c, ok := base.(*chain); ok { - c.middlewares = append(c.middlewares, mws...) - return c - } - - return &chain{ - baseTransport: base, - middlewares: mws, - } -} - -type chain struct { - baseTransport http.RoundTripper - middlewares []func(http.RoundTripper) http.RoundTripper -} - -func (c *chain) RoundTrip(req *http.Request) (*http.Response, error) { - rt := c.baseTransport - - // Apply middlewares in reversed order so the first middleware becomes - // the innermost onion layer and the last becomes the outermost. Example: - // Given - // [Auth, VCTraceID, Debug], - // the middlewares are applied in this order: - // rt = Debug(rt) - // rt = VCTraceID(rt) - // rt = Auth(rt) - // The Auth and VCTraceID are called before the Debug middleware, - // which can then see the final request headers, as seen by http.DefaultTransport. - for i := len(c.middlewares) - 1; i >= 0; i-- { - rt = c.middlewares[i](rt) - } - - return rt.RoundTrip(req) -}