add client

main
git 2025-12-06 17:36:29 +08:00
parent 7861f56eb5
commit 4adb23c247
Signed by: git
GPG Key ID: 3F65EFFA44207ADD
4 changed files with 829 additions and 310 deletions

View File

@ -1,275 +1,531 @@
/*
* TencentBlueKing is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest package rest
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json/v2" "encoding/json/v2"
"encoding/xml" "errors"
"fmt" "fmt"
"io" "io"
"maps" "maps"
"net"
"net/http" "net/http"
"net/url" "net/url"
"path"
"reflect" "reflect"
"strings" "strings"
"time" "time"
"golang.org/x/net/html" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"git.ifooth.com/common/pkg/rest/transport"
"git.ifooth.com/common/pkg/version"
) )
type Client[T any] struct { var (
httpClient *http.Client // ErrUnsupported is an error to check supported type or content-type
commonReq *http.Request ErrUnsupported = errors.New("unsupported")
commonFormData url.Values userAgent = fmt.Sprintf("pkg/%s", version.Version)
)
// defaultTransport default transport, same as http.DefaultTransport but change MaxIdleConns / MaxIdleConnsPerHost
var defaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 200, // 最大空闲连接数
MaxIdleConnsPerHost: 20, // 每个主机的最大空闲连接数
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
// HTTPClient a http.Client or rest.Client Do interface
type HTTPClient interface {
Do(*http.Request) (*http.Response, error)
}
// Client is a HTTP client with common url, header, params and transports.
type Client struct {
baseTransport http.RoundTripper
transports []func(http.RoundTripper) http.RoundTripper
baseURL *url.URL
timeout time.Duration
commonHeader http.Header
commonQueryParams url.Values commonQueryParams url.Values
commonCookies []*http.Cookie err error
} }
func NewClient[T any]() *Client[T] { // NewClient creates a new client
r := &Client[T]{ func NewClient() *Client {
commonReq: &http.Request{Header: make(http.Header)}, commonHeader := make(http.Header)
httpClient: &http.Client{}, commonHeader.Set("User-Agent", userAgent)
commonQueryParams: url.Values{},
// 默认统一添加trace
baseTransport := otelhttp.NewTransport(defaultTransport)
c := &Client{
baseTransport: baseTransport,
commonQueryParams: make(url.Values),
commonHeader: commonHeader,
} }
return r
}
func (c *Client[T]) SetHeader(header, value string) *Client[T] {
c.commonReq.Header.Set(header, value)
return c return c
} }
func (c *Client[T]) SetQueryParam(param, value string) *Client[T] { func (c *Client) parseBaseURL(baseURL string) (*url.URL, error) {
if baseURL == "" {
return nil, fmt.Errorf("baseURL is required")
}
u, err := url.Parse(baseURL)
if err != nil {
return nil, fmt.Errorf("baseURL %s not valid, err: %w", baseURL, err)
}
if u.Host == "" || u.Scheme == "" {
return nil, fmt.Errorf("baseURL %s not valid, host/schema is required", baseURL)
}
return u, nil
}
// SetBaseURL sets the base URL for the client.
func (c *Client) SetBaseURL(baseURL string) *Client {
c.baseURL, c.err = c.parseBaseURL(baseURL)
return c
}
// SetTimeout set a timeout for the client.
func (c *Client) SetTimeout(timeout time.Duration) *Client {
c.timeout = timeout
return c
}
// SetHeader set the common header for the client.
func (c *Client) SetHeader(header, value string) *Client {
c.commonHeader.Set(header, value)
return c
}
// SetQueryParam set the common query param for the client.
func (c *Client) SetQueryParam(param, value string) *Client {
c.commonQueryParams.Set(param, value) c.commonQueryParams.Set(param, value)
return c return c
} }
func (c *Client[T]) SetTransport(transport http.RoundTripper) *Client[T] { // SetTransport set the base transport for the client.
func (c *Client) SetTransport(transport http.RoundTripper) *Client {
if transport != nil { if transport != nil {
c.httpClient.Transport = transport c.baseTransport = transport
} }
return c return c
} }
func (c *Client[T]) R() *Request[T] { // GetTransport return the http.RoundTripper by chaining the transports for client.
req := &http.Request{ func (c *Client) GetTransport() http.RoundTripper {
Method: http.MethodGet, // freeze transports
Proto: "HTTP/1.1", next := c.chain()
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
}
req = req.WithContext(context.Background())
r := &Request[T]{ rt := func(req *http.Request) (*http.Response, error) {
req: req, newReq, err := c.finalReq(req)
client: c, if err != nil {
queryParams: url.Values{}, return nil, err
}
return next.RoundTrip(newReq)
} }
return r
return transport.RoundTripFunc(rt)
} }
func (c *Client[T]) execute(req *http.Request) (*http.Response, error) { // With adds inline http.RoundTripper for an base RoundTripper for the client.
// change header func (c *Client) With(rt ...func(http.RoundTripper) http.RoundTripper) *Client {
maps.Copy(req.Header, c.commonReq.Header) c.transports = append(c.transports, rt...)
return c
}
// change query func (c *Client) chain() http.RoundTripper {
query := req.URL.Query() rt := c.baseTransport
for i := len(c.transports) - 1; i >= 0; i-- {
rt = c.transports[i](rt)
}
return rt
}
// mergeURL return the merged url with common header and params
func (c *Client) mergeURL(u *url.URL) (*url.URL, error) {
query := make(url.Values)
if c.baseURL != nil {
if u.Host != "" || u.Scheme != "" {
return nil, fmt.Errorf("baseURL and request url host/schema can not be set at the same time")
}
u.Scheme = c.baseURL.Scheme
u.Host = c.baseURL.Host
u.Path = path.Join(c.baseURL.Path, u.Path)
maps.Copy(query, c.baseURL.Query())
}
// change the url query
maps.Copy(query, c.commonQueryParams) maps.Copy(query, c.commonQueryParams)
req.URL.RawQuery = query.Encode() maps.Copy(query, u.Query())
u.RawQuery = query.Encode()
resp, err := c.httpClient.Do(req) return u, nil
}
func (c *Client) finalReq(req *http.Request) (*http.Request, error) {
if c.err != nil {
return nil, c.err
}
newReq := req.Clone(req.Context())
u, err := c.mergeURL(newReq.URL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return resp, nil newReq.URL = u
// change the req header
header := make(http.Header)
maps.Copy(header, c.commonHeader)
maps.Copy(header, newReq.Header)
newReq.Header = header
return newReq, nil
} }
// Do implement HTTPClient interface, same as the http.Client
func (c *Client) Do(req *http.Request) (*http.Response, error) {
newReq, err := c.finalReq(req)
if err != nil {
return nil, err
}
httpClient := &http.Client{Transport: c.chain(), Timeout: c.timeout}
return httpClient.Do(newReq)
}
// Request is a generic HTTP request with response type T
type Request[T any] struct { type Request[T any] struct {
client HTTPClient
req *http.Request req *http.Request
client *Client[T] transports []func(http.RoundTripper) http.RoundTripper
queryParams url.Values queryParams url.Values
body any
timeout time.Duration
} }
func (r *Request[T]) Clone() *Request[T] { // R create a new request with the default client and response type T
newReq := r.client.R() func R[T any]() *Request[T] {
newReq.req = r.req.Clone(r.req.Context()) req := &http.Request{Header: make(http.Header)}
newReq.client = r.client req = req.WithContext(context.Background())
newReq.queryParams = r.queryParams
return newReq r := &Request[T]{
client: NewClient(), // default client and transport
req: req,
queryParams: make(url.Values),
}
return r
} }
// SetClient set the rest.Client or http.Client for the request.
func (r *Request[T]) SetClient(client HTTPClient) *Request[T] {
r.client = client
return r
}
// With adds inline http.RoundTripper for the request.
func (r *Request[T]) With(rt ...func(http.RoundTripper) http.RoundTripper) *Request[T] {
r.transports = append(r.transports, rt...)
return r
}
// SetTimeout set a timeout for the current request.
func (r *Request[T]) SetTimeout(timeout time.Duration) *Request[T] {
r.timeout = timeout
return r
}
// SetHeader set the request header for the request.
func (r *Request[T]) SetHeader(header, value string) *Request[T] { func (r *Request[T]) SetHeader(header, value string) *Request[T] {
r.req.Header.Set(header, value) r.req.Header.Set(header, value)
return r return r
} }
// SetQueryParam set the request header for the request.
func (r *Request[T]) SetQueryParam(param, value string) *Request[T] { func (r *Request[T]) SetQueryParam(param, value string) *Request[T] {
r.queryParams.Set(param, value) r.queryParams.Set(param, value)
return r return r
} }
// SetContext set the request header for the request.
func (r *Request[T]) SetContext(ctx context.Context) *Request[T] { func (r *Request[T]) SetContext(ctx context.Context) *Request[T] {
r.req = r.req.WithContext(ctx) r.req = r.req.WithContext(ctx)
return r return r
} }
func (r *Request[T]) Get(url string) (*T, error) { // SetBody set the request body for the request.
return r.Execute(http.MethodGet, url) func (r *Request[T]) SetBody(body any) *Request[T] {
r.body = body
return r
} }
func (r *Request[T]) Execute(method string, url string) (*T, error) { // Get execute a get request
out := new(T) func (r *Request[T]) Get(url string, args ...any) (*T, error) {
rt := reflect.TypeFor[T]() return r.Execute(http.MethodGet, url, args...)
}
newReq, err := http.NewRequestWithContext(r.req.Context(), method, url, r.req.Body) // Post execute a post request
func (r *Request[T]) Post(url string, args ...any) (*T, error) {
return r.Execute(http.MethodPost, url, args...)
}
// Put execute a put request
func (r *Request[T]) Put(url string, args ...any) (*T, error) {
return r.Execute(http.MethodPut, url, args...)
}
// Patch execute a patch request
func (r *Request[T]) Patch(url string, args ...any) (*T, error) {
return r.Execute(http.MethodPatch, url, args...)
}
// Delete execute a delete request
func (r *Request[T]) Delete(url string, args ...any) (*T, error) {
return r.Execute(http.MethodDelete, url, args...)
}
// getBody get body reader or marshal reader
func (r *Request[T]) getBody() (io.Reader, error) {
if r.body == nil {
return nil, nil
}
switch body := r.body.(type) {
case []byte:
return bytes.NewReader(body), nil
case string:
return strings.NewReader(body), nil
case io.Reader:
return body, nil
default:
// continue check kind below
}
rv := reflect.ValueOf(r.body)
if rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
switch rv.Kind() {
case reflect.Map, reflect.Slice, reflect.Struct:
body, err := json.Marshal(r.body)
if err != nil {
return nil, err
}
r.SetHeader("Content-Type", "application/json")
return bytes.NewReader(body), nil
default:
// continue check kind below
}
return nil, fmt.Errorf("%w body type: %T", ErrUnsupported, r.body)
}
// getURL get the final url for the request.
func (r *Request[T]) getURL(u string) (string, error) {
newURL, err := url.Parse(u)
if err != nil {
return "", err
}
// change the url query
query := newURL.Query()
maps.Copy(query, r.queryParams)
newURL.RawQuery = query.Encode()
return newURL.String(), nil
}
// withStdHTTPClient chain with http Client and request transports
func (r *Request[T]) withStdHTTPClient(c *http.Client) HTTPClient {
rt := c.Transport
if rt == nil {
rt = http.DefaultTransport
}
for i := len(r.transports) - 1; i >= 0; i-- {
rt = r.transports[i](rt)
}
newC := *c
newC.Transport = rt
return &newC
}
// withRestClient chain with rest client and request transports
func (r *Request[T]) withRestClient(c *Client) HTTPClient {
newTransports := make([]func(http.RoundTripper) http.RoundTripper, len(c.transports))
copy(newTransports, c.transports)
newC := *c
newC.transports = newTransports
// append request transports
newC.With(r.transports...)
return &newC
}
func (r *Request[T]) getHTTPClient() HTTPClient {
switch c := r.client.(type) {
case *Client:
return r.withRestClient(c)
case *http.Client:
return r.withStdHTTPClient(c)
default:
// will ignore request transport
return c
}
}
// Execute execute a request with method and url
func (r *Request[T]) Execute(method string, url string, args ...any) (*T, error) {
body, err := r.getBody()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// change a req header // fix non-constant format string in call
if len(args) > 0 {
url = fmt.Sprintf(url, args...)
}
finalURL, err := r.getURL(url)
if err != nil {
return nil, err
}
newReq, err := http.NewRequestWithContext(r.req.Context(), method, finalURL, body)
if err != nil {
return nil, err
}
// change the req header
maps.Copy(newReq.Header, r.req.Header) maps.Copy(newReq.Header, r.req.Header)
// change a req query if r.timeout > 0 {
query := newReq.URL.Query() ctx, cancel := context.WithTimeout(newReq.Context(), r.timeout)
maps.Copy(query, r.queryParams) defer cancel()
newReq.URL.RawQuery = query.Encode()
resp, err := r.client.execute(newReq) newReq = newReq.WithContext(ctx)
}
client := r.getHTTPClient()
resp, err := client.Do(newReq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r.handleResp(resp)
}
func (r *Request[T]) handleResp(resp *http.Response) (*T, error) {
out := new(T)
// io.ReadCloser 需要自己调用Close // io.ReadCloser 需要自己调用Close
switch any(out).(type) { switch any(out).(type) {
case *http.Response: case *http.Response:
return any(resp).(*T), nil return any(resp).(*T), nil
case *io.ReadCloser: case *io.ReadCloser:
return any(&resp.Body).(*T), nil return any(&resp.Body).(*T), nil
case *Response: // for addition trace info, reuse body an so on case *Response: // for additional trace info, reuse body and so on
r := Response{resp: resp, trace: &TraceInfo{}} r := Response{Response: resp, receivedAt: time.Now()}
return any(&r).(*T), nil return any(&r).(*T), nil
default:
// continue check kind below
} }
// body 自动close // body 自动close
defer resp.Body.Close() defer resp.Body.Close()
// 需要读取整个body
switch any(out).(type) { switch any(out).(type) {
case *[]byte: case *[]byte:
// 需要读取整个body body, err := io.ReadAll(resp.Body)
body, rErr := io.ReadAll(resp.Body) if err != nil {
if rErr != nil { return nil, err
return nil, rErr
} }
return any(&body).(*T), nil return any(&body).(*T), nil
case *string: case *string:
// 需要读取整个body body, err := io.ReadAll(resp.Body)
body, rErr := io.ReadAll(resp.Body) if err != nil {
if rErr != nil { return nil, err
return nil, rErr
} }
v := string(body) v := string(body)
return any(&v).(*T), nil return any(&v).(*T), nil
case *io.Reader: case *io.Reader:
// 需要读取整个body body, err := io.ReadAll(resp.Body)
body, rErr := io.ReadAll(resp.Body) if err != nil {
if rErr != nil { return nil, err
return nil, rErr
} }
v := io.Reader(bytes.NewBuffer(body)) v := io.Reader(bytes.NewBuffer(body))
return any(&v).(*T), nil return any(&v).(*T), nil
case *html.Node: default:
node, rErr := html.Parse(resp.Body) // continue check kind below
if rErr != nil {
return nil, rErr
}
return any(node).(*T), nil
} }
if rt.Kind() != reflect.Struct { rt := reflect.TypeFor[T]()
return nil, fmt.Errorf("unsupported type: %v", rt)
}
// 按类型反序列化 // 其他只处理json, 自定义序列化请使用[]byte或者io.Reader
contentType := resp.Header.Get("Content-Type") switch rt.Kind() {
if strings.HasPrefix(contentType, "application/json") { case reflect.Map, reflect.Slice, reflect.Struct:
if err = json.UnmarshalRead(resp.Body, out); err != nil { if err := json.UnmarshalRead(resp.Body, out); err != nil {
return nil, err
}
return out, nil
}
if strings.HasPrefix(contentType, "application/xml") {
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
if err = xml.Unmarshal(body, out); err != nil {
return nil, err return nil, err
} }
return out, nil return out, nil
default:
// continue check kind below
} }
return nil, fmt.Errorf("unsupported content-type: %s", contentType) return nil, fmt.Errorf("%w type: %v", ErrUnsupported, rt)
}
// TraceInfo struct is used to provide request trace info such as DNS lookup
// duration, Connection obtain duration, Server processing duration, etc.
type TraceInfo struct {
// DNSLookup is the duration that transport took to perform
// DNS lookup.
DNSLookup time.Duration `json:"dns_lookup_time"`
// ConnTime is the duration it took to obtain a successful connection.
ConnTime time.Duration `json:"connection_time"`
// TCPConnTime is the duration it took to obtain the TCP connection.
TCPConnTime time.Duration `json:"tcp_connection_time"`
// TLSHandshake is the duration of the TLS handshake.
TLSHandshake time.Duration `json:"tls_handshake_time"`
// ServerTime is the server's duration for responding to the first byte.
ServerTime time.Duration `json:"server_time"`
// ResponseTime is the duration since the first response byte from the server to
// request completion.
ResponseTime time.Duration `json:"response_time"`
// TotalTime is the duration of the total time request taken end-to-end.
TotalTime time.Duration `json:"total_time"`
// IsConnReused is whether this connection has been previously
// used for another HTTP request.
IsConnReused bool `json:"is_connection_reused"`
// IsConnWasIdle is whether this connection was obtained from an
// idle pool.
IsConnWasIdle bool `json:"is_connection_was_idle"`
// ConnIdleTime is the duration how long the connection that was previously
// idle, if IsConnWasIdle is true.
ConnIdleTime time.Duration `json:"connection_idle_time"`
// RequestAttempt is to represent the request attempt made during a Resty
// request execution flow, including retry count.
RequestAttempt int `json:"request_attempt"`
// RemoteAddr returns the remote network address.
RemoteAddr string `json:"remote_address"`
} }
// Response struct holds response values of executed requests.
type Response struct { type Response struct {
trace *TraceInfo *http.Response
resp *http.Response
receivedAt time.Time receivedAt time.Time
} }
func (r *Response) Duration() time.Duration {
return r.trace.TotalTime
}
// ReceivedAt method returns the time we received a response from the server for the request. // ReceivedAt method returns the time we received a response from the server for the request.
func (r *Response) ReceivedAt() time.Time { func (r *Response) ReceivedAt() time.Time {
return r.receivedAt return r.receivedAt

View File

@ -1,116 +1,406 @@
/*
* TencentBlueKing is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest package rest
import ( import (
"encoding/json" "bytes"
"context"
"encoding/json/v2"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"net/http/httptrace"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing" "testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.ifooth.com/common/pkg/rest/transport" "git.ifooth.com/common/pkg/rest/transport"
"github.com/stretchr/testify/assert"
"golang.org/x/net/html"
) )
type Resp1 struct { func testServer() *httptest.Server {
r := http.NewServeMux()
// 正常数据返回
r.HandleFunc("GET /uuid", func(w http.ResponseWriter, r *http.Request) {
data := fmt.Sprintf(`{"uuid": "%s"}`, uuid.NewString())
// 原样返回query
for k, v := range r.URL.Query() {
w.Header().Set("query-"+k, strings.Join(v, ","))
}
w.Header().Set("content-type", "application/json")
w.Write([]byte(data))
})
// 状态码测试
r.HandleFunc("/status/{code}", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "text/html; charset=utf-8")
code, err := strconv.ParseInt(r.PathValue("code"), 10, 64)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.WriteHeader(int(code))
w.Write([]byte(r.PathValue("code")))
})
r.HandleFunc("/bkResult/ok", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
w.WriteHeader(int(400))
w.Write([]byte(`{"code": 0, "data": {"name": "ko"}}`))
})
r.HandleFunc("/bkResult/failed", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
w.WriteHeader(int(400))
w.Write([]byte(`{"code": 100400, "data": null}`))
})
// 原样返回任何body数据
r.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "text/html; charset=utf-8")
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// 原样返回query
for k, v := range r.URL.Query() {
w.Header().Set("query-"+k, strings.Join(v, ","))
}
w.Write(body)
})
ts := httptest.NewServer(r)
return ts
}
type testResp struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
} }
type BkResult[T any] struct { type BkResp[T any] struct {
Code int Code int `json:"code"`
Data *T Data *T `json:"data"`
} }
func getTitle(n *html.Node) string { type bkResult struct {
if n.Type == html.ElementNode && n.Data == "title" { Name string `json:"name"`
if n.FirstChild != nil {
return n.FirstChild.Data
}
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
if title := getTitle(c); title != "" {
return title
}
}
return ""
} }
func TestGet(t *testing.T) { type BkError struct {
resp, err := http.Get("http://127.0.0.1:38033/uuid") Code int `json:"code"`
}
func CustomC() *Client {
c := NewClient().
SetHeader("abc", "abc").
SetQueryParam("name", "ko").
With(transport.CurlLog())
return c
}
func TestClientUseAsDo(t *testing.T) {
svr := testServer()
defer svr.Close()
req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata")))
assert.NoError(t, err)
// 用法1: Client使用Do处理请求
client := NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set")
resp, err := client.Do(req)
assert.NoError(t, err)
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, []byte("testdata"), body)
assert.Equal(t, resp.Header.Get("query-ko"), "set")
}
func TestClientUseAsTransport(t *testing.T) {
svr := testServer()
defer svr.Close()
req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata")))
assert.NoError(t, err)
// 用法2: Client作为第三方http库的transport处理请求
client := &http.Client{}
client.Transport = NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set").GetTransport()
resp, err := client.Do(req)
assert.NoError(t, err)
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, []byte("testdata"), body)
assert.Equal(t, resp.Header.Get("query-ko"), "set")
}
func TestRequestUseHTTPClient(t *testing.T) {
svr := testServer()
defer svr.Close()
resp, err := http.Get(svr.URL + "/uuid")
assert.NoError(t, err) assert.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
assert.NoError(t, err) assert.NoError(t, err)
fmt.Println(string(body)) r := new(testResp)
r := new(Resp1)
err = json.Unmarshal(body, r) err = json.Unmarshal(body, r)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, len(r.UUID) == 36) assert.Equal(t, len(r.UUID), 36)
// 用法3: Request[T]使用原生HTTP Client
r1, err := R[string]().SetClient(http.DefaultClient).SetBody("testdata").Post(svr.URL + "/echo")
assert.NoError(t, err)
assert.Equal(t, "testdata", *r1)
r2, err := R[http.Response]().SetClient(http.DefaultClient).With(transport.CurlLog()).SetQueryParam("ko", "hi").SetBody("testdata").Post(svr.URL + "/echo")
assert.NoError(t, err)
body, err = io.ReadAll(r2.Body)
assert.NoError(t, err)
assert.Equal(t, "testdata", string(body))
assert.Equal(t, "hi", r2.Header.Get("query-ko"))
} }
func CustomC[T any]() *Client[T] { func TestRequestUseRestClient(t *testing.T) {
c := NewClient[T]() svr := testServer()
c.SetHeader("abc", "abc") defer svr.Close()
c.SetQueryParam("name", "ko")
return c
}
func TestGetRest(t *testing.T) { // 用法4: Request[T]使用rest.Client
r, err := NewClient[Resp1]().R().Get("http://127.0.0.1:38033/uuid") r, err := R[testResp]().Get(svr.URL + "/uuid")
assert.NoError(t, err)
assert.Equal(t, len(r.UUID), 36)
rb, err := R[[]byte]().Get(svr.URL + "/uuid")
require.NoError(t, err)
assert.Equal(t, len(*rb), 48)
r3, err := R[io.Reader]().Get(svr.URL + "/uuid")
assert.NoError(t, err) assert.NoError(t, err)
fmt.Println(r)
assert.True(t, len(r.UUID) == 36)
rb, err := NewClient[[]byte]().R().Get("http://127.0.0.1:38033/uuid")
assert.NoError(t, err)
fmt.Println("abc", string(*rb))
r3, err := NewClient[io.Reader]().R().Get("http://127.0.0.1:38033/uuid")
assert.NoError(t, err)
r31, err := io.ReadAll(*r3) r31, err := io.ReadAll(*r3)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, len(r31), 48)
fmt.Println("r31", string(r31)) r4, err := R[io.ReadCloser]().Get(svr.URL + "/uuid")
r4, err := NewClient[io.ReadCloser]().R().Get("http://127.0.0.1:38033/uuid")
assert.NoError(t, err) assert.NoError(t, err)
r41, err := io.ReadAll(*r4) r41, err := io.ReadAll(*r4)
assert.NoError(t, err) assert.NoError(t, err)
(*r4).Close() (*r4).Close()
assert.Equal(t, len(r41), 48)
fmt.Println("r41", string(r41)) r5, err := R[http.Response]().SetClient(CustomC()).
SetContext(t.Context()).
ts := transport.Chain(http.DefaultTransport, SetHeader("ko", "abc").
transport.RequestID,
transport.CurlLogTransport(map[string]string{}),
)
r5, err := CustomC[http.Response]().SetTransport(ts).R().SetHeader("ko", "abc").
SetQueryParam("age", "12"). SetQueryParam("age", "12").
Get("http://127.0.0.1:38033/uuid?dd=abc") Get(svr.URL + "/uuid?dd=abc")
assert.NoError(t, err) assert.NoError(t, err)
r51, err := io.ReadAll(r5.Body) r51, err := io.ReadAll(r5.Body)
assert.NoError(t, err) assert.NoError(t, err)
r5.Body.Close() r5.Body.Close()
assert.Equal(t, len(r51), 48)
fmt.Println("r51", string(r51)) assert.Equal(t, "12", r5.Header.Get("query-age"))
r6, err := NewClient[html.Node]().R().Get("http://127.0.0.1:38033")
assert.NoError(t, err)
title := getTitle(r6)
assert.Equal(t, "httpbin.org", title)
} }
func TestGetRestError(t *testing.T) { func TestClientReuseConn(t *testing.T) {
r, err := NewClient[int]().R().Get("http://127.0.0.1:38033/uuid") svr := testServer()
defer svr.Close()
t.Logf("svrAddr: %s", svr.URL)
ctx, cancel := context.WithTimeout(t.Context(), time.Second*2)
defer cancel()
var (
wg sync.WaitGroup
reqCount atomic.Int64
newConn atomic.Int64
)
ctx = httptrace.WithClientTrace(
ctx,
&httptrace.ClientTrace{ConnectStart: func(_ string, _ string) {
newConn.Add(1)
}},
)
concurrent := int64(10)
for range concurrent {
wg.Go(func() {
for {
reqCount.Add(1)
r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid")
if err != nil {
assert.ErrorIs(t, err, context.DeadlineExceeded)
return
}
assert.True(t, len(*r) > 10)
time.Sleep(time.Millisecond * 10)
}
})
}
wg.Wait()
t.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load())
assert.Less(t, newConn.Load(), concurrent*2)
}
func BenchmarkClientReuseConn(b *testing.B) {
svr := testServer()
defer svr.Close()
b.Logf("addr: %s", svr.URL)
var (
reqCount atomic.Int64
newConn atomic.Int64
)
ctx := httptrace.WithClientTrace(
b.Context(),
&httptrace.ClientTrace{ConnectStart: func(_ string, _ string) {
newConn.Add(1)
}},
)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
reqCount.Add(1)
r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid")
if err != nil {
assert.ErrorIs(b, err, context.Canceled)
return
}
assert.NoError(b, err)
assert.True(b, len(*r) > 10)
time.Sleep(time.Millisecond * 10)
}
})
b.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load())
assert.Less(b, newConn.Load(), int64(20))
}
func TestClientBody(t *testing.T) {
svr := testServer()
defer svr.Close()
r, err := R[string]().SetBody("test").Get(svr.URL + "/echo")
require.NoError(t, err)
assert.Equal(t, *r, "test")
r, err = R[string]().SetBody([]byte("test")).Get(svr.URL + "/echo")
require.NoError(t, err)
assert.Equal(t, *r, "test")
r, err = R[string]().SetBody(strings.NewReader("test")).Get(svr.URL + "/echo")
require.NoError(t, err)
assert.Equal(t, *r, "test")
r, err = R[string]().SetBody(map[string]string{"name": "test"}).Get(svr.URL + "/echo")
require.NoError(t, err)
assert.Equal(t, *r, `{"name":"test"}`)
r, err = R[string]().SetBody([]int{1, 2, 3}).Get(svr.URL + "/echo")
require.NoError(t, err)
assert.Equal(t, *r, `[1,2,3]`)
r, err = R[string]().SetBody(testResp{UUID: "test"}).Get(svr.URL + "/echo")
require.NoError(t, err)
assert.Equal(t, *r, `{"uuid":"test"}`)
r, err = R[string]().SetBody(&testResp{UUID: "test"}).Get(svr.URL + "/echo")
require.NoError(t, err)
assert.Equal(t, *r, `{"uuid":"test"}`)
r, err = R[string]().SetBody(nil).Get(svr.URL + "/echo")
require.NoError(t, err)
assert.Equal(t, *r, "")
r, err = R[string]().SetBody(fmt.Errorf("test")).Get(svr.URL + "/echo")
assert.Error(t, err) assert.Error(t, err)
fmt.Println(err) assert.Nil(t, r)
r, err = R[string]().SetBody(123).Get(svr.URL + "/echo")
assert.ErrorIs(t, err, ErrUnsupported)
assert.Nil(t, r) assert.Nil(t, r)
} }
func TestClientURL(t *testing.T) {
svr := testServer()
defer svr.Close()
r, err := R[string]().SetClient(NewClient().SetBaseURL("")).
SetBody("test").Get("/echo")
assert.Error(t, err)
assert.Equal(t, "baseURL is required", err.Error())
assert.Nil(t, r)
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL)).
SetBody("test").Get("/echo")
require.NoError(t, err)
assert.Equal(t, *r, "test")
// send as json
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL).With(transport.CurlLog())).
SetBody(map[string]string{"name": "ko"}).Get("/echo")
require.NoError(t, err)
assert.Equal(t, *r, "{\"name\":\"ko\"}")
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc")).
SetBody("test").Get(svr.URL + "/echo")
assert.Error(t, err)
assert.Nil(t, r)
r2, err := R[http.Response]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc?name=ko").With(transport.CurlLog())).
SetBody("test").Get("/echo?abc=ddd")
require.NoError(t, err)
assert.Equal(t, r2.StatusCode, http.StatusNotFound)
}

View File

@ -1,3 +1,19 @@
/*
* TencentBlueKing is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package transport package transport
import ( import (
@ -6,19 +22,34 @@ import (
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
"time" "time"
) )
// CurlLog make a http curl log transport
func CurlLog(maskKeys ...string) func(http.RoundTripper) http.RoundTripper {
keys := make(map[string]struct{}, len(maskKeys))
for _, key := range maskKeys {
keys[key] = struct{}{}
}
return func(base http.RoundTripper) http.RoundTripper {
t := &curlLogTransport{baseTransport: base, maskKeys: keys}
return t
}
}
// reqToCurl curl 格式的请求日志 // reqToCurl curl 格式的请求日志
func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) { func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
headers := strings.Builder{}
// 过滤掉敏感信息, header 和 query // 过滤掉敏感信息, header 和 query
headers := ""
for key, values := range r.Header { for key, values := range r.Header {
for _, value := range values { for _, value := range values {
if _, ok := c.maskKeys[key]; ok { if _, ok := c.maskKeys[key]; ok {
value = "***" value = "***"
} }
headers += fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value)) headers.WriteString(fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value)))
} }
} }
@ -31,7 +62,7 @@ func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
} }
rawURL.RawQuery = queryValue.Encode() rawURL.RawQuery = queryValue.Encode()
reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers) reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers.String())
if r.Body != nil { if r.Body != nil {
bodyBytes, err := io.ReadAll(r.Body) bodyBytes, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@ -52,7 +83,7 @@ func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
} }
// respToCurl 返回日志 // respToCurl 返回日志
func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) { func (c *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) {
var ( var (
bodyBytes []byte bodyBytes []byte
err error err error
@ -68,59 +99,50 @@ func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string
} }
if len(bodyBytes) > 1024 { if len(bodyBytes) > 1024 {
respMsg := fmt.Sprintf("[%s] %s %s...(Total %s)", respMsg := fmt.Sprintf("[%d] %s %s...(Total %dB)", resp.StatusCode, time.Since(st), bodyBytes[:1024],
resp.Status, time.Since(st), bodyBytes[:1024], len(bodyBytes)) len(bodyBytes))
return respMsg, nil return respMsg, nil
} }
if len(bodyBytes) > 0 { if len(bodyBytes) > 0 {
respMsg := fmt.Sprintf("[%s] %s %s", resp.Status, time.Since(st), bodyBytes) respMsg := fmt.Sprintf("[%d] %s %s", resp.StatusCode, time.Since(st), bodyBytes)
return respMsg, nil return respMsg, nil
} }
respMsg := fmt.Sprintf("[%s] %s", resp.Status, time.Since(st)) respMsg := fmt.Sprintf("[%d] %s", resp.StatusCode, time.Since(st))
return respMsg, nil return respMsg, nil
} }
// curlLogTransport print curl log transport // curlLogTransport print curl log transport
type curlLogTransport struct { type curlLogTransport struct {
maskKeys map[string]string baseTransport http.RoundTripper
maskKeys map[string]struct{}
} }
// RoundTrip curlLog Transport // RoundTrip curlLog Transport
func (t *curlLogTransport) Transport(next http.RoundTripper) http.RoundTripper { func (c *curlLogTransport) RoundTrip(req *http.Request) (*http.Response, error) {
f := func(req *http.Request) (*http.Response, error) { ctx := req.Context()
logger := slog.With("request_id", req.Header.Get("X-Request-Id")) st := time.Now()
st := time.Now()
// 记录请求 // 记录请求
rbody, err := t.reqToCurl(req) rbody, err := c.reqToCurl(req)
if err != nil { if err != nil {
return nil, err return nil, err
}
logger.Info("curl", "REQ", rbody)
resp, err := next.RoundTrip(req)
if err != nil {
logger.Error("curl", "RESP: [err]", err)
return nil, err
}
// 记录返回
respBody, err := t.respToCurl(resp, st)
if err != nil {
return nil, err
}
logger.Info("curl", "RESP", respBody)
return resp, nil
} }
return RoundTripFunc(f) slog.InfoContext(ctx, "curl REQ: "+rbody)
} resp, err := c.baseTransport.RoundTrip(req)
if err != nil {
slog.ErrorContext(ctx, "curl RESP: [err] "+err.Error())
return nil, err
}
// CurlLogTransport make a new curl log transport, default transport can be nil // 记录返回
func CurlLogTransport(maskkeys map[string]string) func(http.RoundTripper) http.RoundTripper { respBody, err := c.respToCurl(resp, st)
t := curlLogTransport{map[string]string{}} if err != nil {
return t.Transport return nil, err
}
slog.InfoContext(ctx, "curl RESP: "+respBody)
return resp, nil
} }

View File

@ -1,62 +1,13 @@
// Package transport provides common HTTP client middleware(aka transport) components.
package transport package transport
import "net/http" import "net/http"
// RoundTripFunc, similar to http.HandlerFunc, is an adapter // RoundTripFunc similar to http.HandlerFunc, is an adapter
// to allow the use of ordinary functions as http.RoundTrippers. // to allow the use of ordinary functions as http.RoundTrippers.
type RoundTripFunc func(r *http.Request) (*http.Response, error) type RoundTripFunc func(*http.Request) (*http.Response, error)
// RoundTrip ...
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req) return f(req)
} }
// Chain wraps given base RoundTripper, which is used to make HTTP requests
// (e.g. http.DefaultTransport) with RoundTripper middlewares.
func Chain(base http.RoundTripper, mw ...func(http.RoundTripper) http.RoundTripper) *chain {
if base == nil {
base = http.DefaultTransport
}
// Filter out nil transports.
mws := []func(http.RoundTripper) http.RoundTripper{}
for _, fn := range mw {
if fn != nil {
mws = append(mws, fn)
}
}
if c, ok := base.(*chain); ok {
c.middlewares = append(c.middlewares, mws...)
return c
}
return &chain{
baseTransport: base,
middlewares: mws,
}
}
type chain struct {
baseTransport http.RoundTripper
middlewares []func(http.RoundTripper) http.RoundTripper
}
func (c *chain) RoundTrip(req *http.Request) (*http.Response, error) {
rt := c.baseTransport
// Apply middlewares in reversed order so the first middleware becomes
// the innermost onion layer and the last becomes the outermost. Example:
// Given
// [Auth, VCTraceID, Debug],
// the middlewares are applied in this order:
// rt = Debug(rt)
// rt = VCTraceID(rt)
// rt = Auth(rt)
// The Auth and VCTraceID are called before the Debug middleware,
// which can then see the final request headers, as seen by http.DefaultTransport.
for i := len(c.middlewares) - 1; i >= 0; i-- {
rt = c.middlewares[i](rt)
}
return rt.RoundTrip(req)
}