533 lines
13 KiB
Go
533 lines
13 KiB
Go
/*
|
|
* TencentBlueKing is pleased to support the open source community by making
|
|
* 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available.
|
|
* Copyright (C) 2025 Tencent. All rights reserved.
|
|
* Licensed under the MIT License (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
|
|
* Unless required by applicable law or agreed to in writing,
|
|
* software distributed under the License is distributed on
|
|
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
|
* either express or implied. See the License for the
|
|
* specific language governing permissions and limitations under the License.
|
|
* We undertake not to change the open source license (MIT license) applicable
|
|
* to the current version of the project delivered to anyone in the future.
|
|
*/
|
|
|
|
package rest
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json/v2"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"maps"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
|
|
|
"git.ifooth.com/common/pkg/rest/transport"
|
|
"git.ifooth.com/common/pkg/version"
|
|
)
|
|
|
|
var (
|
|
// ErrUnsupported is an error to check supported type or content-type
|
|
ErrUnsupported = errors.New("unsupported")
|
|
userAgent = fmt.Sprintf("pkg/%s", version.Version)
|
|
)
|
|
|
|
// defaultTransport is the base transport wrapped by NewClient. Its settings
// follow http.DefaultTransport except for a larger idle-connection pool
// (200 overall and 20 per host, instead of the stdlib's 100 and 2) and a
// shorter idle-connection timeout (30s instead of 90s).
var defaultTransport = &http.Transport{
	Proxy:             http.ProxyFromEnvironment,
	ForceAttemptHTTP2: true,

	// Dialing: same connect timeout and TCP keep-alive as the stdlib default.
	DialContext: (&net.Dialer{
		KeepAlive: 30 * time.Second,
		Timeout:   30 * time.Second,
	}).DialContext,

	// Connection pool sizing.
	MaxIdleConns:        200, // idle connections kept across all hosts
	MaxIdleConnsPerHost: 20,  // idle connections kept per host
	IdleConnTimeout:     30 * time.Second,

	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: time.Second,
}
|
|
|
|
// HTTPClient is the minimal interface used to send requests. Both
// *http.Client and *Client satisfy it, so either can be handed to
// Request.SetClient.
type HTTPClient interface {
	Do(req *http.Request) (*http.Response, error)
}
|
|
|
|
// Client is an HTTP client that applies a shared base URL, common headers,
// common query parameters and a middleware chain to every request it sends.
// Create one with NewClient and configure it with the Set* and With methods.
type Client struct {
	// baseTransport is the innermost RoundTripper that actually sends requests.
	baseTransport http.RoundTripper

	// transports wrap baseTransport; index 0 becomes the outermost layer.
	transports []func(http.RoundTripper) http.RoundTripper

	// baseURL, when set, supplies scheme, host and a path prefix for requests.
	baseURL *url.URL

	// timeout is passed to http.Client.Timeout by Do; zero means no timeout.
	timeout time.Duration

	// commonHeader is sent with every request unless the request overrides it.
	commonHeader http.Header

	// commonQueryParams is merged into every request URL.
	commonQueryParams url.Values

	// err records a SetBaseURL failure; it is reported on the next request.
	err error
}
|
|
|
|
// NewClient creates a new client
|
|
func NewClient() *Client {
|
|
commonHeader := make(http.Header)
|
|
commonHeader.Set("User-Agent", userAgent)
|
|
|
|
// 默认统一添加trace
|
|
baseTransport := otelhttp.NewTransport(defaultTransport)
|
|
|
|
c := &Client{
|
|
baseTransport: baseTransport,
|
|
commonQueryParams: make(url.Values),
|
|
commonHeader: commonHeader,
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
func (c *Client) parseBaseURL(baseURL string) (*url.URL, error) {
|
|
if baseURL == "" {
|
|
return nil, fmt.Errorf("baseURL is required")
|
|
}
|
|
|
|
u, err := url.Parse(baseURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("baseURL %s not valid, err: %w", baseURL, err)
|
|
}
|
|
|
|
if u.Host == "" || u.Scheme == "" {
|
|
return nil, fmt.Errorf("baseURL %s not valid, host/schema is required", baseURL)
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
// SetBaseURL sets the base URL for the client.
|
|
func (c *Client) SetBaseURL(baseURL string) *Client {
|
|
c.baseURL, c.err = c.parseBaseURL(baseURL)
|
|
return c
|
|
}
|
|
|
|
// SetTimeout set a timeout for the client.
|
|
func (c *Client) SetTimeout(timeout time.Duration) *Client {
|
|
c.timeout = timeout
|
|
return c
|
|
}
|
|
|
|
// SetHeader set the common header for the client.
|
|
func (c *Client) SetHeader(header, value string) *Client {
|
|
c.commonHeader.Set(header, value)
|
|
return c
|
|
}
|
|
|
|
// SetQueryParam set the common query param for the client.
|
|
func (c *Client) SetQueryParam(param, value string) *Client {
|
|
c.commonQueryParams.Set(param, value)
|
|
return c
|
|
}
|
|
|
|
// SetTransport set the base transport for the client.
|
|
func (c *Client) SetTransport(transport http.RoundTripper) *Client {
|
|
if transport != nil {
|
|
c.baseTransport = transport
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// GetTransport return the http.RoundTripper by chaining the transports for client.
|
|
func (c *Client) GetTransport() http.RoundTripper {
|
|
// freeze transports
|
|
next := c.chain()
|
|
|
|
rt := func(req *http.Request) (*http.Response, error) {
|
|
newReq, err := c.finalReq(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return next.RoundTrip(newReq)
|
|
}
|
|
|
|
return transport.RoundTripFunc(rt)
|
|
}
|
|
|
|
// With adds inline http.RoundTripper for an base RoundTripper for the client.
|
|
func (c *Client) With(rt ...func(http.RoundTripper) http.RoundTripper) *Client {
|
|
c.transports = append(c.transports, rt...)
|
|
return c
|
|
}
|
|
|
|
func (c *Client) chain() http.RoundTripper {
|
|
rt := c.baseTransport
|
|
|
|
for i := len(c.transports) - 1; i >= 0; i-- {
|
|
rt = c.transports[i](rt)
|
|
}
|
|
|
|
return rt
|
|
}
|
|
|
|
// mergeURL return the merged url with common header and params
|
|
func (c *Client) mergeURL(u *url.URL) (*url.URL, error) {
|
|
query := make(url.Values)
|
|
|
|
if c.baseURL != nil {
|
|
if u.Host != "" || u.Scheme != "" {
|
|
return nil, fmt.Errorf("baseURL and request url host/schema can not be set at the same time")
|
|
}
|
|
|
|
u.Scheme = c.baseURL.Scheme
|
|
u.Host = c.baseURL.Host
|
|
u.Path = path.Join(c.baseURL.Path, u.Path)
|
|
maps.Copy(query, c.baseURL.Query())
|
|
}
|
|
|
|
// change the url query
|
|
maps.Copy(query, c.commonQueryParams)
|
|
maps.Copy(query, u.Query())
|
|
u.RawQuery = query.Encode()
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func (c *Client) finalReq(req *http.Request) (*http.Request, error) {
|
|
if c.err != nil {
|
|
return nil, c.err
|
|
}
|
|
|
|
newReq := req.Clone(req.Context())
|
|
|
|
u, err := c.mergeURL(newReq.URL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
newReq.URL = u
|
|
|
|
// change the req header
|
|
header := make(http.Header)
|
|
maps.Copy(header, c.commonHeader)
|
|
maps.Copy(header, newReq.Header)
|
|
newReq.Header = header
|
|
|
|
return newReq, nil
|
|
}
|
|
|
|
// Do implement HTTPClient interface, same as the http.Client
|
|
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
|
newReq, err := c.finalReq(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpClient := &http.Client{Transport: c.chain(), Timeout: c.timeout}
|
|
return httpClient.Do(newReq)
|
|
}
|
|
|
|
// Request is a generic HTTP request with response type T
|
|
type Request[T any] struct {
|
|
client HTTPClient
|
|
req *http.Request
|
|
transports []func(http.RoundTripper) http.RoundTripper
|
|
queryParams url.Values
|
|
body any
|
|
timeout time.Duration
|
|
}
|
|
|
|
// R create a new request with the default client and response type T
|
|
func R[T any]() *Request[T] {
|
|
req := &http.Request{Header: make(http.Header)}
|
|
req = req.WithContext(context.Background())
|
|
|
|
r := &Request[T]{
|
|
client: NewClient(), // default client and transport
|
|
req: req,
|
|
queryParams: make(url.Values),
|
|
}
|
|
|
|
return r
|
|
}
|
|
|
|
// SetClient set the rest.Client or http.Client for the request.
|
|
func (r *Request[T]) SetClient(client HTTPClient) *Request[T] {
|
|
r.client = client
|
|
return r
|
|
}
|
|
|
|
// With adds inline http.RoundTripper for the request.
|
|
func (r *Request[T]) With(rt ...func(http.RoundTripper) http.RoundTripper) *Request[T] {
|
|
r.transports = append(r.transports, rt...)
|
|
return r
|
|
}
|
|
|
|
// SetTimeout set a timeout for the current request.
|
|
func (r *Request[T]) SetTimeout(timeout time.Duration) *Request[T] {
|
|
r.timeout = timeout
|
|
return r
|
|
}
|
|
|
|
// SetHeader set the request header for the request.
|
|
func (r *Request[T]) SetHeader(header, value string) *Request[T] {
|
|
r.req.Header.Set(header, value)
|
|
return r
|
|
}
|
|
|
|
// SetQueryParam set the request header for the request.
|
|
func (r *Request[T]) SetQueryParam(param, value string) *Request[T] {
|
|
r.queryParams.Set(param, value)
|
|
return r
|
|
}
|
|
|
|
// SetContext set the request header for the request.
|
|
func (r *Request[T]) SetContext(ctx context.Context) *Request[T] {
|
|
r.req = r.req.WithContext(ctx)
|
|
return r
|
|
}
|
|
|
|
// SetBody set the request body for the request.
|
|
func (r *Request[T]) SetBody(body any) *Request[T] {
|
|
r.body = body
|
|
return r
|
|
}
|
|
|
|
// Get execute a get request
|
|
func (r *Request[T]) Get(url string, args ...any) (*T, error) {
|
|
return r.Execute(http.MethodGet, url, args...)
|
|
}
|
|
|
|
// Post execute a post request
|
|
func (r *Request[T]) Post(url string, args ...any) (*T, error) {
|
|
return r.Execute(http.MethodPost, url, args...)
|
|
}
|
|
|
|
// Put execute a put request
|
|
func (r *Request[T]) Put(url string, args ...any) (*T, error) {
|
|
return r.Execute(http.MethodPut, url, args...)
|
|
}
|
|
|
|
// Patch execute a patch request
|
|
func (r *Request[T]) Patch(url string, args ...any) (*T, error) {
|
|
return r.Execute(http.MethodPatch, url, args...)
|
|
}
|
|
|
|
// Delete execute a delete request
|
|
func (r *Request[T]) Delete(url string, args ...any) (*T, error) {
|
|
return r.Execute(http.MethodDelete, url, args...)
|
|
}
|
|
|
|
// getBody get body reader or marshal reader
|
|
func (r *Request[T]) getBody() (io.Reader, error) {
|
|
if r.body == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
switch body := r.body.(type) {
|
|
case []byte:
|
|
return bytes.NewReader(body), nil
|
|
case string:
|
|
return strings.NewReader(body), nil
|
|
case io.Reader:
|
|
return body, nil
|
|
default:
|
|
// continue check kind below
|
|
}
|
|
|
|
rv := reflect.ValueOf(r.body)
|
|
if rv.Kind() == reflect.Pointer {
|
|
rv = rv.Elem()
|
|
}
|
|
|
|
switch rv.Kind() {
|
|
case reflect.Map, reflect.Slice, reflect.Struct:
|
|
body, err := json.Marshal(r.body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
r.SetHeader("Content-Type", "application/json")
|
|
|
|
return bytes.NewReader(body), nil
|
|
default:
|
|
// continue check kind below
|
|
}
|
|
|
|
return nil, fmt.Errorf("%w body type: %T", ErrUnsupported, r.body)
|
|
}
|
|
|
|
// getURL get the final url for the request.
|
|
func (r *Request[T]) getURL(u string) (string, error) {
|
|
newURL, err := url.Parse(u)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// change the url query
|
|
query := newURL.Query()
|
|
maps.Copy(query, r.queryParams)
|
|
|
|
newURL.RawQuery = query.Encode()
|
|
|
|
return newURL.String(), nil
|
|
}
|
|
|
|
// withStdHTTPClient chain with http Client and request transports
|
|
func (r *Request[T]) withStdHTTPClient(c *http.Client) HTTPClient {
|
|
rt := c.Transport
|
|
if rt == nil {
|
|
rt = http.DefaultTransport
|
|
}
|
|
|
|
for i := len(r.transports) - 1; i >= 0; i-- {
|
|
rt = r.transports[i](rt)
|
|
}
|
|
|
|
newC := *c
|
|
newC.Transport = rt
|
|
return &newC
|
|
}
|
|
|
|
// withRestClient chain with rest client and request transports
|
|
func (r *Request[T]) withRestClient(c *Client) HTTPClient {
|
|
newTransports := make([]func(http.RoundTripper) http.RoundTripper, len(c.transports))
|
|
copy(newTransports, c.transports)
|
|
|
|
newC := *c
|
|
newC.transports = newTransports
|
|
|
|
// append request transports
|
|
newC.With(r.transports...)
|
|
|
|
return &newC
|
|
}
|
|
|
|
func (r *Request[T]) getHTTPClient() HTTPClient {
|
|
switch c := r.client.(type) {
|
|
case *Client:
|
|
return r.withRestClient(c)
|
|
case *http.Client:
|
|
return r.withStdHTTPClient(c)
|
|
default:
|
|
// will ignore request transport
|
|
return c
|
|
}
|
|
}
|
|
|
|
// Execute execute a request with method and url
|
|
func (r *Request[T]) Execute(method string, url string, args ...any) (*T, error) {
|
|
body, err := r.getBody()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// fix non-constant format string in call
|
|
if len(args) > 0 {
|
|
url = fmt.Sprintf(url, args...)
|
|
}
|
|
|
|
finalURL, err := r.getURL(url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
newReq, err := http.NewRequestWithContext(r.req.Context(), method, finalURL, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// change the req header
|
|
maps.Copy(newReq.Header, r.req.Header)
|
|
|
|
if r.timeout > 0 {
|
|
ctx, cancel := context.WithTimeout(newReq.Context(), r.timeout)
|
|
defer cancel()
|
|
|
|
newReq = newReq.WithContext(ctx)
|
|
}
|
|
|
|
client := r.getHTTPClient()
|
|
resp, err := client.Do(newReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.handleResp(resp)
|
|
}
|
|
|
|
func (r *Request[T]) handleResp(resp *http.Response) (*T, error) {
|
|
out := new(T)
|
|
|
|
// io.ReadCloser 需要自己调用Close
|
|
switch any(out).(type) {
|
|
case *http.Response:
|
|
return any(resp).(*T), nil
|
|
case *io.ReadCloser:
|
|
return any(&resp.Body).(*T), nil
|
|
case *Response: // for additional trace info, reuse body and so on
|
|
r := Response{Response: resp, receivedAt: time.Now()}
|
|
return any(&r).(*T), nil
|
|
default:
|
|
// continue check kind below
|
|
}
|
|
|
|
// body 自动close
|
|
defer resp.Body.Close()
|
|
|
|
// 需要读取整个body
|
|
switch any(out).(type) {
|
|
case *[]byte:
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return any(&body).(*T), nil
|
|
case *string:
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
v := string(body)
|
|
return any(&v).(*T), nil
|
|
case *io.Reader:
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
v := io.Reader(bytes.NewBuffer(body))
|
|
return any(&v).(*T), nil
|
|
default:
|
|
// continue check kind below
|
|
}
|
|
|
|
rt := reflect.TypeFor[T]()
|
|
|
|
// 其他只处理json, 自定义序列化请使用[]byte或者io.Reader
|
|
switch rt.Kind() {
|
|
case reflect.Map, reflect.Slice, reflect.Struct:
|
|
if err := json.UnmarshalRead(resp.Body, out); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
default:
|
|
// continue check kind below
|
|
}
|
|
|
|
return nil, fmt.Errorf("%w type: %v", ErrUnsupported, rt)
|
|
}
|
|
|
|
// Response is the result type for callers that want the raw *http.Response
// plus extra metadata such as the time it was received. Its body is left
// open, and the caller must close it.
type Response struct {
	*http.Response

	receivedAt time.Time
}

// ReceivedAt returns the time the response was received from the server, as
// recorded when the request was executed.
func (r *Response) ReceivedAt() time.Time {
	received := r.receivedAt
	return received
}
|