/* * TencentBlueKing is pleased to support the open source community by making * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. * Copyright (C) 2025 Tencent. All rights reserved. * Licensed under the MIT License (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at http://opensource.org/licenses/MIT * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, * either express or implied. See the License for the * specific language governing permissions and limitations under the License. * We undertake not to change the open source license (MIT license) applicable * to the current version of the project delivered to anyone in the future. */ package rest import ( "bytes" "context" "encoding/json/v2" "errors" "fmt" "io" "maps" "net" "net/http" "net/url" "path" "reflect" "strings" "time" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "git.ifooth.com/common/pkg/rest/transport" "git.ifooth.com/common/pkg/version" ) var ( // ErrUnsupported is an error to check supported type or content-type ErrUnsupported = errors.New("unsupported") userAgent = fmt.Sprintf("pkg/%s", version.Version) ) // defaultTransport default transport, same as http.DefaultTransport but change MaxIdleConns / MaxIdleConnsPerHost var defaultTransport = &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 200, // 最大空闲连接数 MaxIdleConnsPerHost: 20, // 每个主机的最大空闲连接数 IdleConnTimeout: 30 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } // HTTPClient a http.Client or rest.Client Do interface type HTTPClient interface { Do(*http.Request) (*http.Response, error) } // Client is a HTTP client with common url, header, params and 
transports. type Client struct { baseTransport http.RoundTripper transports []func(http.RoundTripper) http.RoundTripper baseURL *url.URL timeout time.Duration commonHeader http.Header commonQueryParams url.Values err error } // NewClient creates a new client func NewClient() *Client { commonHeader := make(http.Header) commonHeader.Set("User-Agent", userAgent) // 默认统一添加trace baseTransport := otelhttp.NewTransport(defaultTransport) c := &Client{ baseTransport: baseTransport, commonQueryParams: make(url.Values), commonHeader: commonHeader, } return c } func (c *Client) parseBaseURL(baseURL string) (*url.URL, error) { if baseURL == "" { return nil, fmt.Errorf("baseURL is required") } u, err := url.Parse(baseURL) if err != nil { return nil, fmt.Errorf("baseURL %s not valid, err: %w", baseURL, err) } if u.Host == "" || u.Scheme == "" { return nil, fmt.Errorf("baseURL %s not valid, host/schema is required", baseURL) } return u, nil } // SetBaseURL sets the base URL for the client. func (c *Client) SetBaseURL(baseURL string) *Client { c.baseURL, c.err = c.parseBaseURL(baseURL) return c } // SetTimeout set a timeout for the client. func (c *Client) SetTimeout(timeout time.Duration) *Client { c.timeout = timeout return c } // SetHeader set the common header for the client. func (c *Client) SetHeader(header, value string) *Client { c.commonHeader.Set(header, value) return c } // SetQueryParam set the common query param for the client. func (c *Client) SetQueryParam(param, value string) *Client { c.commonQueryParams.Set(param, value) return c } // SetTransport set the base transport for the client. func (c *Client) SetTransport(transport http.RoundTripper) *Client { if transport != nil { c.baseTransport = transport } return c } // GetTransport return the http.RoundTripper by chaining the transports for client. 
func (c *Client) GetTransport() http.RoundTripper { // freeze transports next := c.chain() rt := func(req *http.Request) (*http.Response, error) { newReq, err := c.finalReq(req) if err != nil { return nil, err } return next.RoundTrip(newReq) } return transport.RoundTripFunc(rt) } // With adds inline http.RoundTripper for an base RoundTripper for the client. func (c *Client) With(rt ...func(http.RoundTripper) http.RoundTripper) *Client { c.transports = append(c.transports, rt...) return c } func (c *Client) chain() http.RoundTripper { rt := c.baseTransport for i := len(c.transports) - 1; i >= 0; i-- { rt = c.transports[i](rt) } return rt } // mergeURL return the merged url with common header and params func (c *Client) mergeURL(u *url.URL) (*url.URL, error) { query := make(url.Values) if c.baseURL != nil { if u.Host != "" || u.Scheme != "" { return nil, fmt.Errorf("baseURL and request url host/schema can not be set at the same time") } u.Scheme = c.baseURL.Scheme u.Host = c.baseURL.Host u.Path = path.Join(c.baseURL.Path, u.Path) maps.Copy(query, c.baseURL.Query()) } // change the url query maps.Copy(query, c.commonQueryParams) maps.Copy(query, u.Query()) u.RawQuery = query.Encode() return u, nil } func (c *Client) finalReq(req *http.Request) (*http.Request, error) { if c.err != nil { return nil, c.err } newReq := req.Clone(req.Context()) u, err := c.mergeURL(newReq.URL) if err != nil { return nil, err } newReq.URL = u // change the req header header := make(http.Header) maps.Copy(header, c.commonHeader) maps.Copy(header, newReq.Header) newReq.Header = header return newReq, nil } // Do implement HTTPClient interface, same as the http.Client func (c *Client) Do(req *http.Request) (*http.Response, error) { newReq, err := c.finalReq(req) if err != nil { return nil, err } httpClient := &http.Client{Transport: c.chain(), Timeout: c.timeout} return httpClient.Do(newReq) } // Request is a generic HTTP request with response type T type Request[T any] struct { client 
HTTPClient req *http.Request transports []func(http.RoundTripper) http.RoundTripper queryParams url.Values body any timeout time.Duration } // R create a new request with the default client and response type T func R[T any]() *Request[T] { req := &http.Request{Header: make(http.Header)} req = req.WithContext(context.Background()) r := &Request[T]{ client: NewClient(), // default client and transport req: req, queryParams: make(url.Values), } return r } // SetClient set the rest.Client or http.Client for the request. func (r *Request[T]) SetClient(client HTTPClient) *Request[T] { r.client = client return r } // With adds inline http.RoundTripper for the request. func (r *Request[T]) With(rt ...func(http.RoundTripper) http.RoundTripper) *Request[T] { r.transports = append(r.transports, rt...) return r } // SetTimeout set a timeout for the current request. func (r *Request[T]) SetTimeout(timeout time.Duration) *Request[T] { r.timeout = timeout return r } // SetHeader set the request header for the request. func (r *Request[T]) SetHeader(header, value string) *Request[T] { r.req.Header.Set(header, value) return r } // SetQueryParam set the request header for the request. func (r *Request[T]) SetQueryParam(param, value string) *Request[T] { r.queryParams.Set(param, value) return r } // SetContext set the request header for the request. func (r *Request[T]) SetContext(ctx context.Context) *Request[T] { r.req = r.req.WithContext(ctx) return r } // SetBody set the request body for the request. func (r *Request[T]) SetBody(body any) *Request[T] { r.body = body return r } // Get execute a get request func (r *Request[T]) Get(url string, args ...any) (*T, error) { return r.Execute(http.MethodGet, url, args...) } // Post execute a post request func (r *Request[T]) Post(url string, args ...any) (*T, error) { return r.Execute(http.MethodPost, url, args...) 
} // Put execute a put request func (r *Request[T]) Put(url string, args ...any) (*T, error) { return r.Execute(http.MethodPut, url, args...) } // Patch execute a patch request func (r *Request[T]) Patch(url string, args ...any) (*T, error) { return r.Execute(http.MethodPatch, url, args...) } // Delete execute a delete request func (r *Request[T]) Delete(url string, args ...any) (*T, error) { return r.Execute(http.MethodDelete, url, args...) } // getBody get body reader or marshal reader func (r *Request[T]) getBody() (io.Reader, error) { if r.body == nil { return nil, nil } switch body := r.body.(type) { case []byte: return bytes.NewReader(body), nil case string: return strings.NewReader(body), nil case io.Reader: return body, nil default: // continue check kind below } rv := reflect.ValueOf(r.body) if rv.Kind() == reflect.Pointer { rv = rv.Elem() } switch rv.Kind() { case reflect.Map, reflect.Slice, reflect.Struct: body, err := json.Marshal(r.body) if err != nil { return nil, err } r.SetHeader("Content-Type", "application/json") return bytes.NewReader(body), nil default: // continue check kind below } return nil, fmt.Errorf("%w body type: %T", ErrUnsupported, r.body) } // getURL get the final url for the request. 
func (r *Request[T]) getURL(u string) (string, error) { newURL, err := url.Parse(u) if err != nil { return "", err } // change the url query query := newURL.Query() maps.Copy(query, r.queryParams) newURL.RawQuery = query.Encode() return newURL.String(), nil } // withStdHTTPClient chain with http Client and request transports func (r *Request[T]) withStdHTTPClient(c *http.Client) HTTPClient { rt := c.Transport if rt == nil { rt = http.DefaultTransport } for i := len(r.transports) - 1; i >= 0; i-- { rt = r.transports[i](rt) } newC := *c newC.Transport = rt return &newC } // withRestClient chain with rest client and request transports func (r *Request[T]) withRestClient(c *Client) HTTPClient { newTransports := make([]func(http.RoundTripper) http.RoundTripper, len(c.transports)) copy(newTransports, c.transports) newC := *c newC.transports = newTransports // append request transports newC.With(r.transports...) return &newC } func (r *Request[T]) getHTTPClient() HTTPClient { switch c := r.client.(type) { case *Client: return r.withRestClient(c) case *http.Client: return r.withStdHTTPClient(c) default: // will ignore request transport return c } } // Execute execute a request with method and url func (r *Request[T]) Execute(method string, url string, args ...any) (*T, error) { body, err := r.getBody() if err != nil { return nil, err } // fix non-constant format string in call if len(args) > 0 { url = fmt.Sprintf(url, args...) 
} finalURL, err := r.getURL(url) if err != nil { return nil, err } newReq, err := http.NewRequestWithContext(r.req.Context(), method, finalURL, body) if err != nil { return nil, err } // change the req header maps.Copy(newReq.Header, r.req.Header) if r.timeout > 0 { ctx, cancel := context.WithTimeout(newReq.Context(), r.timeout) defer cancel() newReq = newReq.WithContext(ctx) } client := r.getHTTPClient() resp, err := client.Do(newReq) if err != nil { return nil, err } return r.handleResp(resp) } func (r *Request[T]) handleResp(resp *http.Response) (*T, error) { out := new(T) // io.ReadCloser 需要自己调用Close switch any(out).(type) { case *http.Response: return any(resp).(*T), nil case *io.ReadCloser: return any(&resp.Body).(*T), nil case *Response: // for additional trace info, reuse body and so on r := Response{Response: resp, receivedAt: time.Now()} return any(&r).(*T), nil default: // continue check kind below } // body 自动close defer resp.Body.Close() // 需要读取整个body switch any(out).(type) { case *[]byte: body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } return any(&body).(*T), nil case *string: body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } v := string(body) return any(&v).(*T), nil case *io.Reader: body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } v := io.Reader(bytes.NewBuffer(body)) return any(&v).(*T), nil default: // continue check kind below } rt := reflect.TypeFor[T]() // 其他只处理json, 自定义序列化请使用[]byte或者io.Reader switch rt.Kind() { case reflect.Map, reflect.Slice, reflect.Struct: if err := json.UnmarshalRead(resp.Body, out); err != nil { return nil, err } return out, nil default: // continue check kind below } return nil, fmt.Errorf("%w type: %v", ErrUnsupported, rt) } // Response struct holds response values of executed requests. type Response struct { *http.Response receivedAt time.Time } // ReceivedAt method returns the time we received a response from the server for the request. 
func (resp *Response) ReceivedAt() time.Time {
	// Plain accessor; in this file receivedAt is only assigned where the
	// Response value is constructed.
	at := resp.receivedAt
	return at
}