pkg/rest/client.go

533 lines
13 KiB
Go

/*
* TencentBlueKing is pleased to support the open source community by making
* 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest
import (
"bytes"
"context"
"encoding/json/v2"
"errors"
"fmt"
"io"
"maps"
"net"
"net/http"
"net/url"
"path"
"reflect"
"strings"
"time"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"git.ifooth.com/common/pkg/rest/transport"
"git.ifooth.com/common/pkg/version"
)
// ErrUnsupported is wrapped into errors reporting a request body type or a
// response type the client cannot handle; match it with errors.Is.
var ErrUnsupported = errors.New("unsupported")

// userAgent is the User-Agent header NewClient installs on every client.
var userAgent = fmt.Sprintf("pkg/%s", version.Version)
// defaultTransport is the base transport used by NewClient. It mirrors
// http.DefaultTransport except for a larger idle-connection pool
// (MaxIdleConns / MaxIdleConnsPerHost) and a shorter IdleConnTimeout.
var defaultTransport = &http.Transport{
	Proxy: http.ProxyFromEnvironment,
	DialContext: (&net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}).DialContext,
	ForceAttemptHTTP2:     true,
	MaxIdleConns:          200,              // max idle connections across all hosts
	MaxIdleConnsPerHost:   20,               // max idle connections kept per host
	IdleConnTimeout:       30 * time.Second, // http.DefaultTransport uses 90s
	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: 1 * time.Second,
}
// HTTPClient is the minimal sending interface satisfied by both *http.Client
// and *Client: anything that turns an *http.Request into an *http.Response.
type HTTPClient interface {
	Do(req *http.Request) (*http.Response, error)
}
// Client is a reusable HTTP client that applies a shared base URL, headers,
// query parameters, timeout and a chain of RoundTripper middlewares to every
// request it sends.
type Client struct {
	// baseTransport is the innermost RoundTripper that actually sends requests.
	baseTransport http.RoundTripper

	// transports wrap baseTransport; the first entry becomes the outermost layer.
	transports []func(http.RoundTripper) http.RoundTripper

	// baseURL supplies scheme, host, path prefix and query for relative request URLs.
	baseURL *url.URL

	// timeout is used as http.Client.Timeout by Do; zero means no timeout.
	timeout time.Duration

	// commonHeader values are sent unless the request sets the same key.
	commonHeader http.Header

	// commonQueryParams values are sent unless the request URL sets the same key.
	commonQueryParams url.Values

	// err records a configuration error (e.g. an invalid base URL) that is
	// returned by every subsequent request.
	err error
}
// NewClient returns a Client that sends through defaultTransport wrapped with
// OpenTelemetry instrumentation and carries the package User-Agent header.
func NewClient() *Client {
	// Trace every outgoing request by default.
	traced := otelhttp.NewTransport(defaultTransport)

	header := http.Header{}
	header.Set("User-Agent", userAgent)

	return &Client{
		baseTransport:     traced,
		commonHeader:      header,
		commonQueryParams: url.Values{},
	}
}
// parseBaseURL validates baseURL and returns it parsed. The value must be
// non-empty, parseable, and carry both a scheme and a host.
func (c *Client) parseBaseURL(baseURL string) (*url.URL, error) {
	if baseURL == "" {
		return nil, fmt.Errorf("baseURL is required")
	}

	parsed, err := url.Parse(baseURL)
	switch {
	case err != nil:
		return nil, fmt.Errorf("baseURL %s not valid, err: %w", baseURL, err)
	case parsed.Scheme == "" || parsed.Host == "":
		return nil, fmt.Errorf("baseURL %s not valid, host/schema is required", baseURL)
	}
	return parsed, nil
}
// SetBaseURL parses baseURL and uses it as the prefix for relative request
// URLs. A parse failure is not reported here; it is stored on the client and
// returned by every later request, keeping the builder chain fluent. A later
// successful call clears the stored error.
func (c *Client) SetBaseURL(baseURL string) *Client {
	parsed, err := c.parseBaseURL(baseURL)
	c.baseURL = parsed
	c.err = err
	return c
}
// SetTimeout limits how long a request sent through Do may take (it becomes
// http.Client.Timeout, which also covers reading the body). Zero disables it.
func (c *Client) SetTimeout(timeout time.Duration) *Client {
	c.timeout = timeout
	return c
}
// SetHeader sets a header sent with every request from this client,
// replacing any earlier value for the same key. A header already present on
// an individual request takes precedence over it.
func (c *Client) SetHeader(header, value string) *Client {
	c.commonHeader.Set(header, value)
	return c
}
// SetQueryParam sets a query parameter added to every request from this
// client, replacing any earlier value for the same name. A parameter of the
// same name in the request URL takes precedence over it.
func (c *Client) SetQueryParam(param, value string) *Client {
	c.commonQueryParams.Set(param, value)
	return c
}
// SetTransport replaces the innermost RoundTripper (by default the traced
// defaultTransport). A nil transport is ignored and the current one is kept.
func (c *Client) SetTransport(transport http.RoundTripper) *Client {
	if transport == nil {
		return c
	}
	c.baseTransport = transport
	return c
}
// GetTransport returns a RoundTripper that applies the client's base URL,
// common headers and query parameters to each request and then sends it
// through the middleware chain. The chain (base transport plus middlewares)
// is captured now; later With/SetTransport calls do not affect it.
func (c *Client) GetTransport() http.RoundTripper {
	chained := c.chain()
	return transport.RoundTripFunc(func(req *http.Request) (*http.Response, error) {
		prepared, err := c.finalReq(req)
		if err != nil {
			return nil, err
		}
		return chained.RoundTrip(prepared)
	})
}
// With appends middlewares that wrap the client's base transport. The first
// middleware registered becomes the outermost layer and sees each request
// first.
func (c *Client) With(rt ...func(http.RoundTripper) http.RoundTripper) *Client {
	for _, mw := range rt {
		c.transports = append(c.transports, mw)
	}
	return c
}
// chain builds the middleware stack around baseTransport. Middlewares are
// applied from last to first so that transports[0] ends up outermost.
func (c *Client) chain() http.RoundTripper {
	rt := c.baseTransport
	for i := range c.transports {
		wrap := c.transports[len(c.transports)-1-i]
		rt = wrap(rt)
	}
	return rt
}
// mergeURL fills in u from the client's base URL and common query params.
//
// When a base URL is configured, u must be relative (no scheme/host): it
// inherits the base scheme, host and user info, and its path is joined onto
// the base path. Query parameters are merged with the precedence
// request URL > common params > base URL query. u is modified in place and
// returned.
func (c *Client) mergeURL(u *url.URL) (*url.URL, error) {
	query := make(url.Values)
	if c.baseURL != nil {
		if u.Host != "" || u.Scheme != "" {
			return nil, fmt.Errorf("baseURL and request url host/schema can not be set at the same time")
		}
		u.Scheme = c.baseURL.Scheme
		u.Host = c.baseURL.Host
		// A relative request URL has no user info of its own; keep credentials
		// embedded in the base URL (http.Client turns URL.User into basic auth).
		u.User = c.baseURL.User

		joined := path.Join(c.baseURL.Path, u.Path)
		// path.Join strips trailing slashes, yet "/items/" and "/items" may be
		// different resources; keep the caller's trailing slash like
		// url.URL.JoinPath does (and like requests without a base URL do).
		if strings.HasSuffix(u.Path, "/") && !strings.HasSuffix(joined, "/") {
			joined += "/"
		}
		u.Path = joined
		maps.Copy(query, c.baseURL.Query())
	}
	// Later copies win: common params override the base query, and the
	// request's own params override both.
	maps.Copy(query, c.commonQueryParams)
	maps.Copy(query, u.Query())
	u.RawQuery = query.Encode()
	return u, nil
}
// finalReq returns a clone of req with the client configuration applied: the
// URL is merged via mergeURL and the common headers are added, with headers
// already present on req taking precedence. req itself is not modified.
// A stored configuration error (see SetBaseURL) is returned unchanged.
func (c *Client) finalReq(req *http.Request) (*http.Request, error) {
	if c.err != nil {
		return nil, c.err
	}

	newReq := req.Clone(req.Context())
	u, err := c.mergeURL(newReq.URL)
	if err != nil {
		return nil, err
	}
	newReq.URL = u

	// Deep-copy the common headers: a shallow map copy would share the value
	// slices, letting a middleware that edits header values in place alter
	// the client's defaults for every later (or concurrent) request.
	header := c.commonHeader.Clone()
	if header == nil {
		header = make(http.Header)
	}
	// Request-specific headers win over the common ones.
	maps.Copy(header, newReq.Header)
	newReq.Header = header
	return newReq, nil
}
// Do implements HTTPClient with the same contract as (*http.Client).Do: it
// applies the client configuration to a copy of req and sends that copy
// through the middleware chain, honouring the client timeout.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
	prepared, err := c.finalReq(req)
	if err != nil {
		return nil, err
	}
	return (&http.Client{
		Transport: c.chain(),
		Timeout:   c.timeout,
	}).Do(prepared)
}
// Request builds and executes a single HTTP call whose response is converted
// into T when it executes.
type Request[T any] struct {
	// client sends the request; R defaults it to NewClient().
	client HTTPClient

	// req is a template holding the headers and context configured on the
	// builder; Execute builds a fresh *http.Request from it.
	req *http.Request

	// transports are request-scoped middlewares. With a *Client they are
	// layered closer to the base transport than the client's own; with an
	// *http.Client they wrap its Transport; other clients ignore them.
	transports []func(http.RoundTripper) http.RoundTripper

	// queryParams are merged into the URL, overriding same-named parameters.
	queryParams url.Values

	// body is the raw payload, encoded by getBody when the request executes.
	body any

	// timeout, when positive, becomes a context deadline on the request.
	timeout time.Duration
}
// R starts a new request whose response will be converted into T. It uses a
// fresh default client (see NewClient) and context.Background().
func R[T any]() *Request[T] {
	template := (&http.Request{Header: http.Header{}}).WithContext(context.Background())
	return &Request[T]{
		client:      NewClient(), // default client and transport
		req:         template,
		queryParams: url.Values{},
	}
}
// SetClient sets the client used to send the request:
//   - a *Client keeps its base URL, headers and middlewares;
//   - an *http.Client has its Transport wrapped with the request middlewares;
//   - any other HTTPClient is used as-is.
//
// A nil client is ignored (matching Client.SetTransport), so the request keeps
// a usable client instead of panicking later in Execute.
func (r *Request[T]) SetClient(client HTTPClient) *Request[T] {
	if client != nil {
		r.client = client
	}
	return r
}
// With adds request-scoped middlewares. With a *Client they are layered
// inside the client's own middlewares; with an *http.Client they wrap its
// Transport; other HTTPClient implementations ignore them.
func (r *Request[T]) With(rt ...func(http.RoundTripper) http.RoundTripper) *Request[T] {
	for _, mw := range rt {
		r.transports = append(r.transports, mw)
	}
	return r
}
// SetTimeout bounds this request with a context deadline. A non-positive
// value adds no deadline beyond whatever the client itself enforces.
func (r *Request[T]) SetTimeout(timeout time.Duration) *Request[T] {
	r.timeout = timeout
	return r
}
// SetHeader sets a header on this request only, replacing any earlier value
// for the key; it also overrides a *Client common header of the same name.
func (r *Request[T]) SetHeader(header, value string) *Request[T] {
	r.req.Header.Set(header, value)
	return r
}
// SetQueryParam sets a query parameter for this request, replacing any value
// of the same name given in the URL string passed to Execute.
func (r *Request[T]) SetQueryParam(param, value string) *Request[T] {
	r.queryParams.Set(param, value)
	return r
}
// SetContext sets the context the request is sent with (cancellation,
// deadline, trace propagation). ctx must be non-nil: http.Request.WithContext
// panics on a nil context.
func (r *Request[T]) SetContext(ctx context.Context) *Request[T] {
	r.req = r.req.WithContext(ctx)
	return r
}
// SetBody stores the request payload; it is encoded when the request runs.
// []byte, string and io.Reader values are sent as-is, while maps, slices and
// structs (or pointers to them) are JSON-encoded. Other types make the
// request fail with an error wrapping ErrUnsupported.
func (r *Request[T]) SetBody(body any) *Request[T] {
	r.body = body
	return r
}
// Get sends a GET request; url is expanded with args as a fmt format string
// when args are given.
func (r *Request[T]) Get(url string, args ...any) (*T, error) {
	const method = http.MethodGet
	return r.Execute(method, url, args...)
}
// Post sends a POST request; url is expanded with args as a fmt format
// string when args are given.
func (r *Request[T]) Post(url string, args ...any) (*T, error) {
	const method = http.MethodPost
	return r.Execute(method, url, args...)
}
// Put sends a PUT request; url is expanded with args as a fmt format string
// when args are given.
func (r *Request[T]) Put(url string, args ...any) (*T, error) {
	const method = http.MethodPut
	return r.Execute(method, url, args...)
}
// Patch sends a PATCH request; url is expanded with args as a fmt format
// string when args are given.
func (r *Request[T]) Patch(url string, args ...any) (*T, error) {
	const method = http.MethodPatch
	return r.Execute(method, url, args...)
}
// Delete sends a DELETE request; url is expanded with args as a fmt format
// string when args are given.
func (r *Request[T]) Delete(url string, args ...any) (*T, error) {
	const method = http.MethodDelete
	return r.Execute(method, url, args...)
}
// getBody get body reader or marshal reader
func (r *Request[T]) getBody() (io.Reader, error) {
if r.body == nil {
return nil, nil
}
switch body := r.body.(type) {
case []byte:
return bytes.NewReader(body), nil
case string:
return strings.NewReader(body), nil
case io.Reader:
return body, nil
default:
// continue check kind below
}
rv := reflect.ValueOf(r.body)
if rv.Kind() == reflect.Pointer {
rv = rv.Elem()
}
switch rv.Kind() {
case reflect.Map, reflect.Slice, reflect.Struct:
body, err := json.Marshal(r.body)
if err != nil {
return nil, err
}
r.SetHeader("Content-Type", "application/json")
return bytes.NewReader(body), nil
default:
// continue check kind below
}
return nil, fmt.Errorf("%w body type: %T", ErrUnsupported, r.body)
}
// getURL parses u and merges the request-level query parameters into it,
// overriding parameters of the same name already present in u. The query
// string is re-encoded in the process.
func (r *Request[T]) getURL(u string) (string, error) {
	parsed, err := url.Parse(u)
	if err != nil {
		return "", err
	}
	merged := parsed.Query()
	for key, values := range r.queryParams {
		merged[key] = values
	}
	parsed.RawQuery = merged.Encode()
	return parsed.String(), nil
}
// withStdHTTPClient returns a shallow copy of c whose Transport is wrapped by
// the request middlewares (r.transports[0] outermost). c is left untouched; a
// nil Transport falls back to http.DefaultTransport, as http.Client does.
func (r *Request[T]) withStdHTTPClient(c *http.Client) HTTPClient {
	base := c.Transport
	if base == nil {
		base = http.DefaultTransport
	}
	for idx := range r.transports {
		base = r.transports[len(r.transports)-1-idx](base)
	}
	clone := *c
	clone.Transport = base
	return &clone
}
// withRestClient returns a shallow copy of c with the request middlewares
// appended after the client's own, so they sit closer to the base transport.
// A new slice is built, so c's middleware list is never appended to in place.
func (r *Request[T]) withRestClient(c *Client) HTTPClient {
	merged := make([]func(http.RoundTripper) http.RoundTripper, 0, len(c.transports)+len(r.transports))
	merged = append(merged, c.transports...)
	merged = append(merged, r.transports...)
	clone := *c
	clone.transports = merged
	return &clone
}
// getHTTPClient returns the client Execute should use, with request-level
// middlewares attached for *Client and *http.Client. Any other HTTPClient is
// returned unchanged, so request middlewares are ignored for it.
func (r *Request[T]) getHTTPClient() HTTPClient {
	if rc, ok := r.client.(*Client); ok {
		return r.withRestClient(rc)
	}
	if hc, ok := r.client.(*http.Client); ok {
		return r.withStdHTTPClient(hc)
	}
	// will ignore request transport
	return r.client
}
// Execute sends a request with the given method and URL and converts the
// response into T (see handleResp).
//
// If args are given, url is treated as a fmt format string and expanded with
// them. Body, headers, query parameters and context come from the builder.
// When a timeout is set, its context is released only when the response body
// is closed (or on error), so streamed results such as *http.Response,
// io.ReadCloser and Response stay readable after Execute returns.
func (r *Request[T]) Execute(method string, url string, args ...any) (*T, error) {
	body, err := r.getBody()
	if err != nil {
		return nil, err
	}

	// Only format when args are present: a plain url is used verbatim, so a
	// literal '%' survives and callers avoid non-constant format strings.
	if len(args) > 0 {
		url = fmt.Sprintf(url, args...)
	}
	finalURL, err := r.getURL(url)
	if err != nil {
		return nil, err
	}

	newReq, err := http.NewRequestWithContext(r.req.Context(), method, finalURL, body)
	if err != nil {
		return nil, err
	}
	// Copy the builder headers (including any Content-Type set by getBody).
	maps.Copy(newReq.Header, r.req.Header)

	var cancel context.CancelFunc
	if r.timeout > 0 {
		var ctx context.Context
		ctx, cancel = context.WithTimeout(newReq.Context(), r.timeout)
		newReq = newReq.WithContext(ctx)
	}

	resp, err := r.getHTTPClient().Do(newReq)
	if err != nil {
		if cancel != nil {
			cancel()
		}
		return nil, err
	}
	if cancel != nil {
		// Cancelling on return (defer) would abort reads of a body handed to
		// the caller; tie the context's lifetime to the body instead.
		resp.Body = &cancelOnCloseBody{ReadCloser: resp.Body, cancel: cancel}
	}

	out, err := r.handleResp(resp)
	if err != nil && cancel != nil {
		cancel() // idempotent; the body may already have been closed
	}
	return out, err
}

// cancelOnCloseBody releases a request's timeout context when the response
// body is closed, so the deadline still guards body reads done by the caller.
type cancelOnCloseBody struct {
	io.ReadCloser
	cancel context.CancelFunc
}

// Close closes the wrapped body and then cancels the request context.
func (b *cancelOnCloseBody) Close() error {
	err := b.ReadCloser.Close()
	b.cancel()
	return err
}
func (r *Request[T]) handleResp(resp *http.Response) (*T, error) {
out := new(T)
// io.ReadCloser 需要自己调用Close
switch any(out).(type) {
case *http.Response:
return any(resp).(*T), nil
case *io.ReadCloser:
return any(&resp.Body).(*T), nil
case *Response: // for additional trace info, reuse body and so on
r := Response{Response: resp, receivedAt: time.Now()}
return any(&r).(*T), nil
default:
// continue check kind below
}
// body 自动close
defer resp.Body.Close()
// 需要读取整个body
switch any(out).(type) {
case *[]byte:
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return any(&body).(*T), nil
case *string:
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
v := string(body)
return any(&v).(*T), nil
case *io.Reader:
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
v := io.Reader(bytes.NewBuffer(body))
return any(&v).(*T), nil
default:
// continue check kind below
}
rt := reflect.TypeFor[T]()
// 其他只处理json, 自定义序列化请使用[]byte或者io.Reader
switch rt.Kind() {
case reflect.Map, reflect.Slice, reflect.Struct:
if err := json.UnmarshalRead(resp.Body, out); err != nil {
return nil, err
}
return out, nil
default:
// continue check kind below
}
return nil, fmt.Errorf("%w type: %v", ErrUnsupported, rt)
}
// Response embeds an *http.Response together with extra metadata recorded by
// the client when the response is wrapped. The body is not consumed.
type Response struct {
	*http.Response

	// receivedAt is the time the response was wrapped by the client.
	receivedAt time.Time
}

// ReceivedAt reports when the client received the response for the request.
func (r *Response) ReceivedAt() time.Time {
	at := r.receivedAt
	return at
}