add client
parent
7861f56eb5
commit
4adb23c247
570
rest/client.go
570
rest/client.go
|
|
@ -1,275 +1,531 @@
|
||||||
|
/*
|
||||||
|
* TencentBlueKing is pleased to support the open source community by making
|
||||||
|
* 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available.
|
||||||
|
* Copyright (C) 2025 Tencent. All rights reserved.
|
||||||
|
* Licensed under the MIT License (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on
|
||||||
|
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||||
|
* either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations under the License.
|
||||||
|
* We undertake not to change the open source license (MIT license) applicable
|
||||||
|
* to the current version of the project delivered to anyone in the future.
|
||||||
|
*/
|
||||||
|
|
||||||
package rest
|
package rest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json/v2"
|
"encoding/json/v2"
|
||||||
"encoding/xml"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"maps"
|
"maps"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/html"
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||||
|
|
||||||
|
"git.ifooth.com/common/pkg/rest/transport"
|
||||||
|
"git.ifooth.com/common/pkg/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client[T any] struct {
|
var (
|
||||||
httpClient *http.Client
|
// ErrUnsupported is an error to check supported type or content-type
|
||||||
commonReq *http.Request
|
ErrUnsupported = errors.New("unsupported")
|
||||||
commonFormData url.Values
|
userAgent = fmt.Sprintf("pkg/%s", version.Version)
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultTransport default transport, same as http.DefaultTransport but change MaxIdleConns / MaxIdleConnsPerHost
|
||||||
|
var defaultTransport = &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 200, // 最大空闲连接数
|
||||||
|
MaxIdleConnsPerHost: 20, // 每个主机的最大空闲连接数
|
||||||
|
IdleConnTimeout: 30 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient a http.Client or rest.Client Do interface
|
||||||
|
type HTTPClient interface {
|
||||||
|
Do(*http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client is a HTTP client with common url, header, params and transports.
|
||||||
|
type Client struct {
|
||||||
|
baseTransport http.RoundTripper
|
||||||
|
transports []func(http.RoundTripper) http.RoundTripper
|
||||||
|
baseURL *url.URL
|
||||||
|
timeout time.Duration
|
||||||
|
commonHeader http.Header
|
||||||
commonQueryParams url.Values
|
commonQueryParams url.Values
|
||||||
commonCookies []*http.Cookie
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient[T any]() *Client[T] {
|
// NewClient creates a new client
|
||||||
r := &Client[T]{
|
func NewClient() *Client {
|
||||||
commonReq: &http.Request{Header: make(http.Header)},
|
commonHeader := make(http.Header)
|
||||||
httpClient: &http.Client{},
|
commonHeader.Set("User-Agent", userAgent)
|
||||||
commonQueryParams: url.Values{},
|
|
||||||
}
|
// 默认统一添加trace
|
||||||
return r
|
baseTransport := otelhttp.NewTransport(defaultTransport)
|
||||||
|
|
||||||
|
c := &Client{
|
||||||
|
baseTransport: baseTransport,
|
||||||
|
commonQueryParams: make(url.Values),
|
||||||
|
commonHeader: commonHeader,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client[T]) SetHeader(header, value string) *Client[T] {
|
|
||||||
c.commonReq.Header.Set(header, value)
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client[T]) SetQueryParam(param, value string) *Client[T] {
|
func (c *Client) parseBaseURL(baseURL string) (*url.URL, error) {
|
||||||
|
if baseURL == "" {
|
||||||
|
return nil, fmt.Errorf("baseURL is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("baseURL %s not valid, err: %w", baseURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Host == "" || u.Scheme == "" {
|
||||||
|
return nil, fmt.Errorf("baseURL %s not valid, host/schema is required", baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBaseURL sets the base URL for the client.
|
||||||
|
func (c *Client) SetBaseURL(baseURL string) *Client {
|
||||||
|
c.baseURL, c.err = c.parseBaseURL(baseURL)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTimeout set a timeout for the client.
|
||||||
|
func (c *Client) SetTimeout(timeout time.Duration) *Client {
|
||||||
|
c.timeout = timeout
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHeader set the common header for the client.
|
||||||
|
func (c *Client) SetHeader(header, value string) *Client {
|
||||||
|
c.commonHeader.Set(header, value)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQueryParam set the common query param for the client.
|
||||||
|
func (c *Client) SetQueryParam(param, value string) *Client {
|
||||||
c.commonQueryParams.Set(param, value)
|
c.commonQueryParams.Set(param, value)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client[T]) SetTransport(transport http.RoundTripper) *Client[T] {
|
// SetTransport set the base transport for the client.
|
||||||
|
func (c *Client) SetTransport(transport http.RoundTripper) *Client {
|
||||||
if transport != nil {
|
if transport != nil {
|
||||||
c.httpClient.Transport = transport
|
c.baseTransport = transport
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client[T]) R() *Request[T] {
|
// GetTransport return the http.RoundTripper by chaining the transports for client.
|
||||||
req := &http.Request{
|
func (c *Client) GetTransport() http.RoundTripper {
|
||||||
Method: http.MethodGet,
|
// freeze transports
|
||||||
Proto: "HTTP/1.1",
|
next := c.chain()
|
||||||
ProtoMajor: 1,
|
|
||||||
ProtoMinor: 1,
|
|
||||||
Header: make(http.Header),
|
|
||||||
}
|
|
||||||
req = req.WithContext(context.Background())
|
|
||||||
|
|
||||||
r := &Request[T]{
|
rt := func(req *http.Request) (*http.Response, error) {
|
||||||
req: req,
|
newReq, err := c.finalReq(req)
|
||||||
client: c,
|
|
||||||
queryParams: url.Values{},
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client[T]) execute(req *http.Request) (*http.Response, error) {
|
|
||||||
// change header
|
|
||||||
maps.Copy(req.Header, c.commonReq.Header)
|
|
||||||
|
|
||||||
// change query
|
|
||||||
query := req.URL.Query()
|
|
||||||
maps.Copy(query, c.commonQueryParams)
|
|
||||||
req.URL.RawQuery = query.Encode()
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return resp, nil
|
|
||||||
|
return next.RoundTrip(newReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return transport.RoundTripFunc(rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With adds inline http.RoundTripper for an base RoundTripper for the client.
|
||||||
|
func (c *Client) With(rt ...func(http.RoundTripper) http.RoundTripper) *Client {
|
||||||
|
c.transports = append(c.transports, rt...)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) chain() http.RoundTripper {
|
||||||
|
rt := c.baseTransport
|
||||||
|
|
||||||
|
for i := len(c.transports) - 1; i >= 0; i-- {
|
||||||
|
rt = c.transports[i](rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rt
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeURL return the merged url with common header and params
|
||||||
|
func (c *Client) mergeURL(u *url.URL) (*url.URL, error) {
|
||||||
|
query := make(url.Values)
|
||||||
|
|
||||||
|
if c.baseURL != nil {
|
||||||
|
if u.Host != "" || u.Scheme != "" {
|
||||||
|
return nil, fmt.Errorf("baseURL and request url host/schema can not be set at the same time")
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Scheme = c.baseURL.Scheme
|
||||||
|
u.Host = c.baseURL.Host
|
||||||
|
u.Path = path.Join(c.baseURL.Path, u.Path)
|
||||||
|
maps.Copy(query, c.baseURL.Query())
|
||||||
|
}
|
||||||
|
|
||||||
|
// change the url query
|
||||||
|
maps.Copy(query, c.commonQueryParams)
|
||||||
|
maps.Copy(query, u.Query())
|
||||||
|
u.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) finalReq(req *http.Request) (*http.Request, error) {
|
||||||
|
if c.err != nil {
|
||||||
|
return nil, c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
newReq := req.Clone(req.Context())
|
||||||
|
|
||||||
|
u, err := c.mergeURL(newReq.URL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newReq.URL = u
|
||||||
|
|
||||||
|
// change the req header
|
||||||
|
header := make(http.Header)
|
||||||
|
maps.Copy(header, c.commonHeader)
|
||||||
|
maps.Copy(header, newReq.Header)
|
||||||
|
newReq.Header = header
|
||||||
|
|
||||||
|
return newReq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do implement HTTPClient interface, same as the http.Client
|
||||||
|
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||||
|
newReq, err := c.finalReq(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := &http.Client{Transport: c.chain(), Timeout: c.timeout}
|
||||||
|
return httpClient.Do(newReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request is a generic HTTP request with response type T
|
||||||
type Request[T any] struct {
|
type Request[T any] struct {
|
||||||
|
client HTTPClient
|
||||||
req *http.Request
|
req *http.Request
|
||||||
client *Client[T]
|
transports []func(http.RoundTripper) http.RoundTripper
|
||||||
queryParams url.Values
|
queryParams url.Values
|
||||||
|
body any
|
||||||
|
timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Request[T]) Clone() *Request[T] {
|
// R create a new request with the default client and response type T
|
||||||
newReq := r.client.R()
|
func R[T any]() *Request[T] {
|
||||||
newReq.req = r.req.Clone(r.req.Context())
|
req := &http.Request{Header: make(http.Header)}
|
||||||
newReq.client = r.client
|
req = req.WithContext(context.Background())
|
||||||
newReq.queryParams = r.queryParams
|
|
||||||
return newReq
|
r := &Request[T]{
|
||||||
|
client: NewClient(), // default client and transport
|
||||||
|
req: req,
|
||||||
|
queryParams: make(url.Values),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetClient set the rest.Client or http.Client for the request.
|
||||||
|
func (r *Request[T]) SetClient(client HTTPClient) *Request[T] {
|
||||||
|
r.client = client
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// With adds inline http.RoundTripper for the request.
|
||||||
|
func (r *Request[T]) With(rt ...func(http.RoundTripper) http.RoundTripper) *Request[T] {
|
||||||
|
r.transports = append(r.transports, rt...)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTimeout set a timeout for the current request.
|
||||||
|
func (r *Request[T]) SetTimeout(timeout time.Duration) *Request[T] {
|
||||||
|
r.timeout = timeout
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHeader set the request header for the request.
|
||||||
func (r *Request[T]) SetHeader(header, value string) *Request[T] {
|
func (r *Request[T]) SetHeader(header, value string) *Request[T] {
|
||||||
r.req.Header.Set(header, value)
|
r.req.Header.Set(header, value)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetQueryParam set the request header for the request.
|
||||||
func (r *Request[T]) SetQueryParam(param, value string) *Request[T] {
|
func (r *Request[T]) SetQueryParam(param, value string) *Request[T] {
|
||||||
r.queryParams.Set(param, value)
|
r.queryParams.Set(param, value)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetContext set the request header for the request.
|
||||||
func (r *Request[T]) SetContext(ctx context.Context) *Request[T] {
|
func (r *Request[T]) SetContext(ctx context.Context) *Request[T] {
|
||||||
r.req = r.req.WithContext(ctx)
|
r.req = r.req.WithContext(ctx)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Request[T]) Get(url string) (*T, error) {
|
// SetBody set the request body for the request.
|
||||||
return r.Execute(http.MethodGet, url)
|
func (r *Request[T]) SetBody(body any) *Request[T] {
|
||||||
|
r.body = body
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Request[T]) Execute(method string, url string) (*T, error) {
|
// Get execute a get request
|
||||||
out := new(T)
|
func (r *Request[T]) Get(url string, args ...any) (*T, error) {
|
||||||
rt := reflect.TypeFor[T]()
|
return r.Execute(http.MethodGet, url, args...)
|
||||||
|
}
|
||||||
|
|
||||||
newReq, err := http.NewRequestWithContext(r.req.Context(), method, url, r.req.Body)
|
// Post execute a post request
|
||||||
|
func (r *Request[T]) Post(url string, args ...any) (*T, error) {
|
||||||
|
return r.Execute(http.MethodPost, url, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put execute a put request
|
||||||
|
func (r *Request[T]) Put(url string, args ...any) (*T, error) {
|
||||||
|
return r.Execute(http.MethodPut, url, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Patch execute a patch request
|
||||||
|
func (r *Request[T]) Patch(url string, args ...any) (*T, error) {
|
||||||
|
return r.Execute(http.MethodPatch, url, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete execute a delete request
|
||||||
|
func (r *Request[T]) Delete(url string, args ...any) (*T, error) {
|
||||||
|
return r.Execute(http.MethodDelete, url, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBody get body reader or marshal reader
|
||||||
|
func (r *Request[T]) getBody() (io.Reader, error) {
|
||||||
|
if r.body == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch body := r.body.(type) {
|
||||||
|
case []byte:
|
||||||
|
return bytes.NewReader(body), nil
|
||||||
|
case string:
|
||||||
|
return strings.NewReader(body), nil
|
||||||
|
case io.Reader:
|
||||||
|
return body, nil
|
||||||
|
default:
|
||||||
|
// continue check kind below
|
||||||
|
}
|
||||||
|
|
||||||
|
rv := reflect.ValueOf(r.body)
|
||||||
|
if rv.Kind() == reflect.Pointer {
|
||||||
|
rv = rv.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Map, reflect.Slice, reflect.Struct:
|
||||||
|
body, err := json.Marshal(r.body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r.SetHeader("Content-Type", "application/json")
|
||||||
|
|
||||||
|
return bytes.NewReader(body), nil
|
||||||
|
default:
|
||||||
|
// continue check kind below
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("%w body type: %T", ErrUnsupported, r.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getURL get the final url for the request.
|
||||||
|
func (r *Request[T]) getURL(u string) (string, error) {
|
||||||
|
newURL, err := url.Parse(u)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// change the url query
|
||||||
|
query := newURL.Query()
|
||||||
|
maps.Copy(query, r.queryParams)
|
||||||
|
|
||||||
|
newURL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
return newURL.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// withStdHTTPClient chain with http Client and request transports
|
||||||
|
func (r *Request[T]) withStdHTTPClient(c *http.Client) HTTPClient {
|
||||||
|
rt := c.Transport
|
||||||
|
if rt == nil {
|
||||||
|
rt = http.DefaultTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := len(r.transports) - 1; i >= 0; i-- {
|
||||||
|
rt = r.transports[i](rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
newC := *c
|
||||||
|
newC.Transport = rt
|
||||||
|
return &newC
|
||||||
|
}
|
||||||
|
|
||||||
|
// withRestClient chain with rest client and request transports
|
||||||
|
func (r *Request[T]) withRestClient(c *Client) HTTPClient {
|
||||||
|
newTransports := make([]func(http.RoundTripper) http.RoundTripper, len(c.transports))
|
||||||
|
copy(newTransports, c.transports)
|
||||||
|
|
||||||
|
newC := *c
|
||||||
|
newC.transports = newTransports
|
||||||
|
|
||||||
|
// append request transports
|
||||||
|
newC.With(r.transports...)
|
||||||
|
|
||||||
|
return &newC
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request[T]) getHTTPClient() HTTPClient {
|
||||||
|
switch c := r.client.(type) {
|
||||||
|
case *Client:
|
||||||
|
return r.withRestClient(c)
|
||||||
|
case *http.Client:
|
||||||
|
return r.withStdHTTPClient(c)
|
||||||
|
default:
|
||||||
|
// will ignore request transport
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute execute a request with method and url
|
||||||
|
func (r *Request[T]) Execute(method string, url string, args ...any) (*T, error) {
|
||||||
|
body, err := r.getBody()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// change a req header
|
// fix non-constant format string in call
|
||||||
|
if len(args) > 0 {
|
||||||
|
url = fmt.Sprintf(url, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
finalURL, err := r.getURL(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newReq, err := http.NewRequestWithContext(r.req.Context(), method, finalURL, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// change the req header
|
||||||
maps.Copy(newReq.Header, r.req.Header)
|
maps.Copy(newReq.Header, r.req.Header)
|
||||||
|
|
||||||
// change a req query
|
if r.timeout > 0 {
|
||||||
query := newReq.URL.Query()
|
ctx, cancel := context.WithTimeout(newReq.Context(), r.timeout)
|
||||||
maps.Copy(query, r.queryParams)
|
defer cancel()
|
||||||
newReq.URL.RawQuery = query.Encode()
|
|
||||||
|
|
||||||
resp, err := r.client.execute(newReq)
|
newReq = newReq.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := r.getHTTPClient()
|
||||||
|
resp, err := client.Do(newReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return r.handleResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request[T]) handleResp(resp *http.Response) (*T, error) {
|
||||||
|
out := new(T)
|
||||||
|
|
||||||
// io.ReadCloser 需要自己调用Close
|
// io.ReadCloser 需要自己调用Close
|
||||||
switch any(out).(type) {
|
switch any(out).(type) {
|
||||||
case *http.Response:
|
case *http.Response:
|
||||||
return any(resp).(*T), nil
|
return any(resp).(*T), nil
|
||||||
case *io.ReadCloser:
|
case *io.ReadCloser:
|
||||||
return any(&resp.Body).(*T), nil
|
return any(&resp.Body).(*T), nil
|
||||||
case *Response: // for addition trace info, reuse body an so on
|
case *Response: // for additional trace info, reuse body and so on
|
||||||
r := Response{resp: resp, trace: &TraceInfo{}}
|
r := Response{Response: resp, receivedAt: time.Now()}
|
||||||
return any(&r).(*T), nil
|
return any(&r).(*T), nil
|
||||||
|
default:
|
||||||
|
// continue check kind below
|
||||||
}
|
}
|
||||||
|
|
||||||
// body 自动close
|
// body 自动close
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 需要读取整个body
|
||||||
switch any(out).(type) {
|
switch any(out).(type) {
|
||||||
case *[]byte:
|
case *[]byte:
|
||||||
// 需要读取整个body
|
body, err := io.ReadAll(resp.Body)
|
||||||
body, rErr := io.ReadAll(resp.Body)
|
if err != nil {
|
||||||
if rErr != nil {
|
return nil, err
|
||||||
return nil, rErr
|
|
||||||
}
|
}
|
||||||
return any(&body).(*T), nil
|
return any(&body).(*T), nil
|
||||||
case *string:
|
case *string:
|
||||||
// 需要读取整个body
|
body, err := io.ReadAll(resp.Body)
|
||||||
body, rErr := io.ReadAll(resp.Body)
|
if err != nil {
|
||||||
if rErr != nil {
|
return nil, err
|
||||||
return nil, rErr
|
|
||||||
}
|
}
|
||||||
v := string(body)
|
v := string(body)
|
||||||
return any(&v).(*T), nil
|
return any(&v).(*T), nil
|
||||||
case *io.Reader:
|
case *io.Reader:
|
||||||
// 需要读取整个body
|
body, err := io.ReadAll(resp.Body)
|
||||||
body, rErr := io.ReadAll(resp.Body)
|
if err != nil {
|
||||||
if rErr != nil {
|
return nil, err
|
||||||
return nil, rErr
|
|
||||||
}
|
}
|
||||||
v := io.Reader(bytes.NewBuffer(body))
|
v := io.Reader(bytes.NewBuffer(body))
|
||||||
return any(&v).(*T), nil
|
return any(&v).(*T), nil
|
||||||
case *html.Node:
|
default:
|
||||||
node, rErr := html.Parse(resp.Body)
|
// continue check kind below
|
||||||
if rErr != nil {
|
|
||||||
return nil, rErr
|
|
||||||
}
|
|
||||||
return any(node).(*T), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rt.Kind() != reflect.Struct {
|
rt := reflect.TypeFor[T]()
|
||||||
return nil, fmt.Errorf("unsupported type: %v", rt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 按类型反序列化
|
// 其他只处理json, 自定义序列化请使用[]byte或者io.Reader
|
||||||
contentType := resp.Header.Get("Content-Type")
|
switch rt.Kind() {
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
case reflect.Map, reflect.Slice, reflect.Struct:
|
||||||
if err = json.UnmarshalRead(resp.Body, out); err != nil {
|
if err := json.UnmarshalRead(resp.Body, out); err != nil {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(contentType, "application/xml") {
|
|
||||||
// 需要读取整个body
|
|
||||||
body, rErr := io.ReadAll(resp.Body)
|
|
||||||
if rErr != nil {
|
|
||||||
return nil, rErr
|
|
||||||
}
|
|
||||||
if err = xml.Unmarshal(body, out); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
|
default:
|
||||||
|
// continue check kind below
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported content-type: %s", contentType)
|
return nil, fmt.Errorf("%w type: %v", ErrUnsupported, rt)
|
||||||
}
|
|
||||||
|
|
||||||
// TraceInfo struct is used to provide request trace info such as DNS lookup
|
|
||||||
// duration, Connection obtain duration, Server processing duration, etc.
|
|
||||||
type TraceInfo struct {
|
|
||||||
// DNSLookup is the duration that transport took to perform
|
|
||||||
// DNS lookup.
|
|
||||||
DNSLookup time.Duration `json:"dns_lookup_time"`
|
|
||||||
|
|
||||||
// ConnTime is the duration it took to obtain a successful connection.
|
|
||||||
ConnTime time.Duration `json:"connection_time"`
|
|
||||||
|
|
||||||
// TCPConnTime is the duration it took to obtain the TCP connection.
|
|
||||||
TCPConnTime time.Duration `json:"tcp_connection_time"`
|
|
||||||
|
|
||||||
// TLSHandshake is the duration of the TLS handshake.
|
|
||||||
TLSHandshake time.Duration `json:"tls_handshake_time"`
|
|
||||||
|
|
||||||
// ServerTime is the server's duration for responding to the first byte.
|
|
||||||
ServerTime time.Duration `json:"server_time"`
|
|
||||||
|
|
||||||
// ResponseTime is the duration since the first response byte from the server to
|
|
||||||
// request completion.
|
|
||||||
ResponseTime time.Duration `json:"response_time"`
|
|
||||||
|
|
||||||
// TotalTime is the duration of the total time request taken end-to-end.
|
|
||||||
TotalTime time.Duration `json:"total_time"`
|
|
||||||
|
|
||||||
// IsConnReused is whether this connection has been previously
|
|
||||||
// used for another HTTP request.
|
|
||||||
IsConnReused bool `json:"is_connection_reused"`
|
|
||||||
|
|
||||||
// IsConnWasIdle is whether this connection was obtained from an
|
|
||||||
// idle pool.
|
|
||||||
IsConnWasIdle bool `json:"is_connection_was_idle"`
|
|
||||||
|
|
||||||
// ConnIdleTime is the duration how long the connection that was previously
|
|
||||||
// idle, if IsConnWasIdle is true.
|
|
||||||
ConnIdleTime time.Duration `json:"connection_idle_time"`
|
|
||||||
|
|
||||||
// RequestAttempt is to represent the request attempt made during a Resty
|
|
||||||
// request execution flow, including retry count.
|
|
||||||
RequestAttempt int `json:"request_attempt"`
|
|
||||||
|
|
||||||
// RemoteAddr returns the remote network address.
|
|
||||||
RemoteAddr string `json:"remote_address"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Response struct holds response values of executed requests.
|
||||||
type Response struct {
|
type Response struct {
|
||||||
trace *TraceInfo
|
*http.Response
|
||||||
resp *http.Response
|
|
||||||
receivedAt time.Time
|
receivedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Response) Duration() time.Duration {
|
|
||||||
return r.trace.TotalTime
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReceivedAt method returns the time we received a response from the server for the request.
|
// ReceivedAt method returns the time we received a response from the server for the request.
|
||||||
func (r *Response) ReceivedAt() time.Time {
|
func (r *Response) ReceivedAt() time.Time {
|
||||||
return r.receivedAt
|
return r.receivedAt
|
||||||
|
|
|
||||||
|
|
@ -1,116 +1,406 @@
|
||||||
|
/*
|
||||||
|
* TencentBlueKing is pleased to support the open source community by making
|
||||||
|
* 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available.
|
||||||
|
* Copyright (C) 2025 Tencent. All rights reserved.
|
||||||
|
* Licensed under the MIT License (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on
|
||||||
|
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||||
|
* either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations under the License.
|
||||||
|
* We undertake not to change the open source license (MIT license) applicable
|
||||||
|
* to the current version of the project delivered to anyone in the future.
|
||||||
|
*/
|
||||||
|
|
||||||
package rest
|
package rest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json/v2"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"git.ifooth.com/common/pkg/rest/transport"
|
"git.ifooth.com/common/pkg/rest/transport"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"golang.org/x/net/html"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Resp1 struct {
|
func testServer() *httptest.Server {
|
||||||
|
r := http.NewServeMux()
|
||||||
|
// 正常数据返回
|
||||||
|
r.HandleFunc("GET /uuid", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
data := fmt.Sprintf(`{"uuid": "%s"}`, uuid.NewString())
|
||||||
|
|
||||||
|
// 原样返回query
|
||||||
|
for k, v := range r.URL.Query() {
|
||||||
|
w.Header().Set("query-"+k, strings.Join(v, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("content-type", "application/json")
|
||||||
|
w.Write([]byte(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
// 状态码测试
|
||||||
|
r.HandleFunc("/status/{code}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("content-type", "text/html; charset=utf-8")
|
||||||
|
|
||||||
|
code, err := strconv.ParseInt(r.PathValue("code"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(int(code))
|
||||||
|
w.Write([]byte(r.PathValue("code")))
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
r.HandleFunc("/bkResult/ok", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("content-type", "application/json")
|
||||||
|
|
||||||
|
w.WriteHeader(int(400))
|
||||||
|
w.Write([]byte(`{"code": 0, "data": {"name": "ko"}}`))
|
||||||
|
})
|
||||||
|
|
||||||
|
r.HandleFunc("/bkResult/failed", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("content-type", "application/json")
|
||||||
|
|
||||||
|
w.WriteHeader(int(400))
|
||||||
|
w.Write([]byte(`{"code": 100400, "data": null}`))
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
// 原样返回任何body数据
|
||||||
|
r.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("content-type", "text/html; charset=utf-8")
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原样返回query
|
||||||
|
for k, v := range r.URL.Query() {
|
||||||
|
w.Header().Set("query-"+k, strings.Join(v, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write(body)
|
||||||
|
})
|
||||||
|
|
||||||
|
ts := httptest.NewServer(r)
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
type testResp struct {
|
||||||
UUID string `json:"uuid"`
|
UUID string `json:"uuid"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type BkResult[T any] struct {
|
type BkResp[T any] struct {
|
||||||
Code int
|
Code int `json:"code"`
|
||||||
Data *T
|
Data *T `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTitle(n *html.Node) string {
|
type bkResult struct {
|
||||||
if n.Type == html.ElementNode && n.Data == "title" {
|
Name string `json:"name"`
|
||||||
if n.FirstChild != nil {
|
|
||||||
return n.FirstChild.Data
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
|
||||||
if title := getTitle(c); title != "" {
|
|
||||||
return title
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGet(t *testing.T) {
|
type BkError struct {
|
||||||
resp, err := http.Get("http://127.0.0.1:38033/uuid")
|
Code int `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func CustomC() *Client {
|
||||||
|
c := NewClient().
|
||||||
|
SetHeader("abc", "abc").
|
||||||
|
SetQueryParam("name", "ko").
|
||||||
|
With(transport.CurlLog())
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientUseAsDo(t *testing.T) {
|
||||||
|
svr := testServer()
|
||||||
|
defer svr.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata")))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// 用法1: Client使用Do处理请求
|
||||||
|
client := NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("testdata"), body)
|
||||||
|
assert.Equal(t, resp.Header.Get("query-ko"), "set")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientUseAsTransport(t *testing.T) {
|
||||||
|
svr := testServer()
|
||||||
|
defer svr.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata")))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// 用法2: Client作为第三方http库的transport处理请求
|
||||||
|
client := &http.Client{}
|
||||||
|
client.Transport = NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set").GetTransport()
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("testdata"), body)
|
||||||
|
assert.Equal(t, resp.Header.Get("query-ko"), "set")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUseHTTPClient(t *testing.T) {
|
||||||
|
svr := testServer()
|
||||||
|
defer svr.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(svr.URL + "/uuid")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
fmt.Println(string(body))
|
r := new(testResp)
|
||||||
|
|
||||||
r := new(Resp1)
|
|
||||||
err = json.Unmarshal(body, r)
|
err = json.Unmarshal(body, r)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, len(r.UUID) == 36)
|
assert.Equal(t, len(r.UUID), 36)
|
||||||
|
|
||||||
|
// 用法3: Request[T]使用原生HTTP Client
|
||||||
|
r1, err := R[string]().SetClient(http.DefaultClient).SetBody("testdata").Post(svr.URL + "/echo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "testdata", *r1)
|
||||||
|
|
||||||
|
r2, err := R[http.Response]().SetClient(http.DefaultClient).With(transport.CurlLog()).SetQueryParam("ko", "hi").SetBody("testdata").Post(svr.URL + "/echo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
body, err = io.ReadAll(r2.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "testdata", string(body))
|
||||||
|
assert.Equal(t, "hi", r2.Header.Get("query-ko"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func CustomC[T any]() *Client[T] {
|
func TestRequestUseRestClient(t *testing.T) {
|
||||||
c := NewClient[T]()
|
svr := testServer()
|
||||||
c.SetHeader("abc", "abc")
|
defer svr.Close()
|
||||||
c.SetQueryParam("name", "ko")
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetRest(t *testing.T) {
|
// 用法4: Request[T]使用rest.Client
|
||||||
r, err := NewClient[Resp1]().R().Get("http://127.0.0.1:38033/uuid")
|
r, err := R[testResp]().Get(svr.URL + "/uuid")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, len(r.UUID), 36)
|
||||||
|
|
||||||
|
rb, err := R[[]byte]().Get(svr.URL + "/uuid")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, len(*rb), 48)
|
||||||
|
|
||||||
|
r3, err := R[io.Reader]().Get(svr.URL + "/uuid")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
fmt.Println(r)
|
|
||||||
assert.True(t, len(r.UUID) == 36)
|
|
||||||
|
|
||||||
rb, err := NewClient[[]byte]().R().Get("http://127.0.0.1:38033/uuid")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
fmt.Println("abc", string(*rb))
|
|
||||||
|
|
||||||
r3, err := NewClient[io.Reader]().R().Get("http://127.0.0.1:38033/uuid")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
r31, err := io.ReadAll(*r3)
|
r31, err := io.ReadAll(*r3)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, len(r31), 48)
|
||||||
|
|
||||||
fmt.Println("r31", string(r31))
|
r4, err := R[io.ReadCloser]().Get(svr.URL + "/uuid")
|
||||||
|
|
||||||
r4, err := NewClient[io.ReadCloser]().R().Get("http://127.0.0.1:38033/uuid")
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
r41, err := io.ReadAll(*r4)
|
r41, err := io.ReadAll(*r4)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
(*r4).Close()
|
(*r4).Close()
|
||||||
|
assert.Equal(t, len(r41), 48)
|
||||||
|
|
||||||
fmt.Println("r41", string(r41))
|
r5, err := R[http.Response]().SetClient(CustomC()).
|
||||||
|
SetContext(t.Context()).
|
||||||
ts := transport.Chain(http.DefaultTransport,
|
SetHeader("ko", "abc").
|
||||||
transport.RequestID,
|
|
||||||
transport.CurlLogTransport(map[string]string{}),
|
|
||||||
)
|
|
||||||
|
|
||||||
r5, err := CustomC[http.Response]().SetTransport(ts).R().SetHeader("ko", "abc").
|
|
||||||
SetQueryParam("age", "12").
|
SetQueryParam("age", "12").
|
||||||
Get("http://127.0.0.1:38033/uuid?dd=abc")
|
Get(svr.URL + "/uuid?dd=abc")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
r51, err := io.ReadAll(r5.Body)
|
r51, err := io.ReadAll(r5.Body)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
r5.Body.Close()
|
r5.Body.Close()
|
||||||
|
assert.Equal(t, len(r51), 48)
|
||||||
fmt.Println("r51", string(r51))
|
assert.Equal(t, "12", r5.Header.Get("query-age"))
|
||||||
|
|
||||||
r6, err := NewClient[html.Node]().R().Get("http://127.0.0.1:38033")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
title := getTitle(r6)
|
|
||||||
assert.Equal(t, "httpbin.org", title)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetRestError(t *testing.T) {
|
func TestClientReuseConn(t *testing.T) {
|
||||||
r, err := NewClient[int]().R().Get("http://127.0.0.1:38033/uuid")
|
svr := testServer()
|
||||||
|
defer svr.Close()
|
||||||
|
|
||||||
|
t.Logf("svrAddr: %s", svr.URL)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(t.Context(), time.Second*2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
wg sync.WaitGroup
|
||||||
|
reqCount atomic.Int64
|
||||||
|
newConn atomic.Int64
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = httptrace.WithClientTrace(
|
||||||
|
ctx,
|
||||||
|
&httptrace.ClientTrace{ConnectStart: func(_ string, _ string) {
|
||||||
|
newConn.Add(1)
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
|
||||||
|
concurrent := int64(10)
|
||||||
|
for range concurrent {
|
||||||
|
wg.Go(func() {
|
||||||
|
for {
|
||||||
|
reqCount.Add(1)
|
||||||
|
|
||||||
|
r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid")
|
||||||
|
if err != nil {
|
||||||
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, len(*r) > 10)
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
t.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load())
|
||||||
|
assert.Less(t, newConn.Load(), concurrent*2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkClientReuseConn(b *testing.B) {
|
||||||
|
svr := testServer()
|
||||||
|
defer svr.Close()
|
||||||
|
|
||||||
|
b.Logf("addr: %s", svr.URL)
|
||||||
|
|
||||||
|
var (
|
||||||
|
reqCount atomic.Int64
|
||||||
|
newConn atomic.Int64
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := httptrace.WithClientTrace(
|
||||||
|
b.Context(),
|
||||||
|
&httptrace.ClientTrace{ConnectStart: func(_ string, _ string) {
|
||||||
|
newConn.Add(1)
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
reqCount.Add(1)
|
||||||
|
|
||||||
|
r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid")
|
||||||
|
if err != nil {
|
||||||
|
assert.ErrorIs(b, err, context.Canceled)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(b, err)
|
||||||
|
assert.True(b, len(*r) > 10)
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load())
|
||||||
|
assert.Less(b, newConn.Load(), int64(20))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientBody(t *testing.T) {
|
||||||
|
svr := testServer()
|
||||||
|
defer svr.Close()
|
||||||
|
|
||||||
|
r, err := R[string]().SetBody("test").Get(svr.URL + "/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, "test")
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody([]byte("test")).Get(svr.URL + "/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, "test")
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody(strings.NewReader("test")).Get(svr.URL + "/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, "test")
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody(map[string]string{"name": "test"}).Get(svr.URL + "/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, `{"name":"test"}`)
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody([]int{1, 2, 3}).Get(svr.URL + "/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, `[1,2,3]`)
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody(testResp{UUID: "test"}).Get(svr.URL + "/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, `{"uuid":"test"}`)
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody(&testResp{UUID: "test"}).Get(svr.URL + "/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, `{"uuid":"test"}`)
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody(nil).Get(svr.URL + "/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, "")
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody(fmt.Errorf("test")).Get(svr.URL + "/echo")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
fmt.Println(err)
|
assert.Nil(t, r)
|
||||||
|
|
||||||
|
r, err = R[string]().SetBody(123).Get(svr.URL + "/echo")
|
||||||
|
assert.ErrorIs(t, err, ErrUnsupported)
|
||||||
assert.Nil(t, r)
|
assert.Nil(t, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientURL(t *testing.T) {
|
||||||
|
svr := testServer()
|
||||||
|
defer svr.Close()
|
||||||
|
|
||||||
|
r, err := R[string]().SetClient(NewClient().SetBaseURL("")).
|
||||||
|
SetBody("test").Get("/echo")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, "baseURL is required", err.Error())
|
||||||
|
assert.Nil(t, r)
|
||||||
|
|
||||||
|
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL)).
|
||||||
|
SetBody("test").Get("/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, "test")
|
||||||
|
|
||||||
|
// send as json
|
||||||
|
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL).With(transport.CurlLog())).
|
||||||
|
SetBody(map[string]string{"name": "ko"}).Get("/echo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, *r, "{\"name\":\"ko\"}")
|
||||||
|
|
||||||
|
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc")).
|
||||||
|
SetBody("test").Get(svr.URL + "/echo")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, r)
|
||||||
|
|
||||||
|
r2, err := R[http.Response]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc?name=ko").With(transport.CurlLog())).
|
||||||
|
SetBody("test").Get("/echo?abc=ddd")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, r2.StatusCode, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
/*
|
||||||
|
* TencentBlueKing is pleased to support the open source community by making
|
||||||
|
* 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available.
|
||||||
|
* Copyright (C) 2025 Tencent. All rights reserved.
|
||||||
|
* Licensed under the MIT License (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on
|
||||||
|
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||||
|
* either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations under the License.
|
||||||
|
* We undertake not to change the open source license (MIT license) applicable
|
||||||
|
* to the current version of the project delivered to anyone in the future.
|
||||||
|
*/
|
||||||
|
|
||||||
package transport
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
@ -6,19 +22,34 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CurlLog make a http curl log transport
|
||||||
|
func CurlLog(maskKeys ...string) func(http.RoundTripper) http.RoundTripper {
|
||||||
|
keys := make(map[string]struct{}, len(maskKeys))
|
||||||
|
for _, key := range maskKeys {
|
||||||
|
keys[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(base http.RoundTripper) http.RoundTripper {
|
||||||
|
t := &curlLogTransport{baseTransport: base, maskKeys: keys}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// reqToCurl curl 格式的请求日志
|
// reqToCurl curl 格式的请求日志
|
||||||
func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
|
func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
|
||||||
|
headers := strings.Builder{}
|
||||||
|
|
||||||
// 过滤掉敏感信息, header 和 query
|
// 过滤掉敏感信息, header 和 query
|
||||||
headers := ""
|
|
||||||
for key, values := range r.Header {
|
for key, values := range r.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if _, ok := c.maskKeys[key]; ok {
|
if _, ok := c.maskKeys[key]; ok {
|
||||||
value = "***"
|
value = "***"
|
||||||
}
|
}
|
||||||
headers += fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value))
|
headers.WriteString(fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -31,7 +62,7 @@ func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
|
||||||
}
|
}
|
||||||
rawURL.RawQuery = queryValue.Encode()
|
rawURL.RawQuery = queryValue.Encode()
|
||||||
|
|
||||||
reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers)
|
reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers.String())
|
||||||
if r.Body != nil {
|
if r.Body != nil {
|
||||||
bodyBytes, err := io.ReadAll(r.Body)
|
bodyBytes, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -52,7 +83,7 @@ func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// respToCurl 返回日志
|
// respToCurl 返回日志
|
||||||
func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) {
|
func (c *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) {
|
||||||
var (
|
var (
|
||||||
bodyBytes []byte
|
bodyBytes []byte
|
||||||
err error
|
err error
|
||||||
|
|
@ -68,59 +99,50 @@ func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(bodyBytes) > 1024 {
|
if len(bodyBytes) > 1024 {
|
||||||
respMsg := fmt.Sprintf("[%s] %s %s...(Total %s)",
|
respMsg := fmt.Sprintf("[%d] %s %s...(Total %dB)", resp.StatusCode, time.Since(st), bodyBytes[:1024],
|
||||||
resp.Status, time.Since(st), bodyBytes[:1024], len(bodyBytes))
|
len(bodyBytes))
|
||||||
return respMsg, nil
|
return respMsg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(bodyBytes) > 0 {
|
if len(bodyBytes) > 0 {
|
||||||
respMsg := fmt.Sprintf("[%s] %s %s", resp.Status, time.Since(st), bodyBytes)
|
respMsg := fmt.Sprintf("[%d] %s %s", resp.StatusCode, time.Since(st), bodyBytes)
|
||||||
return respMsg, nil
|
return respMsg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
respMsg := fmt.Sprintf("[%s] %s", resp.Status, time.Since(st))
|
respMsg := fmt.Sprintf("[%d] %s", resp.StatusCode, time.Since(st))
|
||||||
return respMsg, nil
|
return respMsg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// curlLogTransport print curl log transport
|
// curlLogTransport print curl log transport
|
||||||
type curlLogTransport struct {
|
type curlLogTransport struct {
|
||||||
maskKeys map[string]string
|
baseTransport http.RoundTripper
|
||||||
|
maskKeys map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip curlLog Transport
|
// RoundTrip curlLog Transport
|
||||||
func (t *curlLogTransport) Transport(next http.RoundTripper) http.RoundTripper {
|
func (c *curlLogTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
f := func(req *http.Request) (*http.Response, error) {
|
ctx := req.Context()
|
||||||
logger := slog.With("request_id", req.Header.Get("X-Request-Id"))
|
|
||||||
st := time.Now()
|
st := time.Now()
|
||||||
|
|
||||||
// 记录请求
|
// 记录请求
|
||||||
rbody, err := t.reqToCurl(req)
|
rbody, err := c.reqToCurl(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
logger.Info("curl", "REQ", rbody)
|
slog.InfoContext(ctx, "curl REQ: "+rbody)
|
||||||
|
|
||||||
resp, err := next.RoundTrip(req)
|
resp, err := c.baseTransport.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("curl", "RESP: [err]", err)
|
slog.ErrorContext(ctx, "curl RESP: [err] "+err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 记录返回
|
// 记录返回
|
||||||
respBody, err := t.respToCurl(resp, st)
|
respBody, err := c.respToCurl(resp, st)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
logger.Info("curl", "RESP", respBody)
|
slog.InfoContext(ctx, "curl RESP: "+respBody)
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
return RoundTripFunc(f)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// CurlLogTransport make a new curl log transport, default transport can be nil
|
|
||||||
func CurlLogTransport(maskkeys map[string]string) func(http.RoundTripper) http.RoundTripper {
|
|
||||||
t := curlLogTransport{map[string]string{}}
|
|
||||||
return t.Transport
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,62 +1,13 @@
|
||||||
|
// Package transport provides common HTTP client middleware(aka transport) components.
|
||||||
package transport
|
package transport
|
||||||
|
|
||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
||||||
// RoundTripFunc, similar to http.HandlerFunc, is an adapter
|
// RoundTripFunc similar to http.HandlerFunc, is an adapter
|
||||||
// to allow the use of ordinary functions as http.RoundTrippers.
|
// to allow the use of ordinary functions as http.RoundTrippers.
|
||||||
type RoundTripFunc func(r *http.Request) (*http.Response, error)
|
type RoundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
// RoundTrip ...
|
||||||
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return f(req)
|
return f(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Chain wraps given base RoundTripper, which is used to make HTTP requests
|
|
||||||
// (e.g. http.DefaultTransport) with RoundTripper middlewares.
|
|
||||||
func Chain(base http.RoundTripper, mw ...func(http.RoundTripper) http.RoundTripper) *chain {
|
|
||||||
if base == nil {
|
|
||||||
base = http.DefaultTransport
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter out nil transports.
|
|
||||||
mws := []func(http.RoundTripper) http.RoundTripper{}
|
|
||||||
for _, fn := range mw {
|
|
||||||
if fn != nil {
|
|
||||||
mws = append(mws, fn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c, ok := base.(*chain); ok {
|
|
||||||
c.middlewares = append(c.middlewares, mws...)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
return &chain{
|
|
||||||
baseTransport: base,
|
|
||||||
middlewares: mws,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type chain struct {
|
|
||||||
baseTransport http.RoundTripper
|
|
||||||
middlewares []func(http.RoundTripper) http.RoundTripper
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *chain) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
rt := c.baseTransport
|
|
||||||
|
|
||||||
// Apply middlewares in reversed order so the first middleware becomes
|
|
||||||
// the innermost onion layer and the last becomes the outermost. Example:
|
|
||||||
// Given
|
|
||||||
// [Auth, VCTraceID, Debug],
|
|
||||||
// the middlewares are applied in this order:
|
|
||||||
// rt = Debug(rt)
|
|
||||||
// rt = VCTraceID(rt)
|
|
||||||
// rt = Auth(rt)
|
|
||||||
// The Auth and VCTraceID are called before the Debug middleware,
|
|
||||||
// which can then see the final request headers, as seen by http.DefaultTransport.
|
|
||||||
for i := len(c.middlewares) - 1; i >= 0; i-- {
|
|
||||||
rt = c.middlewares[i](rt)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rt.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue