// pkg/rest/client.go (277 lines, 6.6 KiB, Go)

package rest
import (
"bytes"
"context"
"encoding/json/v2"
"encoding/xml"
"fmt"
"io"
"maps"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"golang.org/x/net/html"
)
// Client holds the configuration shared by every request created through R:
// the underlying http.Client plus client-wide headers and query parameters.
//
// T is the result type that requests built from this client are converted into.
type Client[T any] struct {
	// httpClient performs the actual round trips.
	httpClient *http.Client
	// commonReq serves as a holder for client-wide headers (commonReq.Header);
	// in the code shown it is never sent itself.
	commonReq *http.Request
	// commonFormData and commonQueryParams hold client-wide form values and
	// query parameters. NOTE(review): commonFormData is not read or written in
	// the code shown.
	commonFormData, commonQueryParams url.Values
	// commonCookies holds client-wide cookies. NOTE(review): not used in the
	// code shown.
	commonCookies []*http.Cookie
}
// NewClient builds a Client ready for use: it has an empty client-wide header
// set, an empty client-wide query parameter set and a fresh http.Client with
// no Timeout configured.
func NewClient[T any]() *Client[T] {
	return &Client[T]{
		httpClient:        &http.Client{},
		commonReq:         &http.Request{Header: http.Header{}},
		commonQueryParams: make(url.Values),
	}
}
// SetHeader stores a client-wide header that is merged into each request the
// client executes. Any earlier values stored under the same canonicalised
// header name are replaced. It returns c so calls can be chained.
func (c *Client[T]) SetHeader(header, value string) *Client[T] {
	key := http.CanonicalHeaderKey(header)
	c.commonReq.Header[key] = []string{value}
	return c
}
// SetQueryParam stores a client-wide query parameter that is merged into the
// URL of each request the client executes. Any earlier values for param are
// replaced. It returns c so calls can be chained.
func (c *Client[T]) SetQueryParam(param, value string) *Client[T] {
	c.commonQueryParams[param] = []string{value}
	return c
}
// SetTransport installs transport as the RoundTripper of the underlying
// http.Client and returns c for chaining.
//
// A nil interface value is ignored and the current transport is kept (a nil
// http.Client.Transport means net/http falls back to DefaultTransport). Only an
// untyped nil is ignored; a typed nil pointer is stored as-is.
func (c *Client[T]) SetTransport(transport http.RoundTripper) *Client[T] {
	if transport == nil {
		return c
	}
	c.httpClient.Transport = transport
	return c
}
// R starts a new request bound to c. The request defaults to GET over
// HTTP/1.1 with an empty header set, no query parameters and a background
// context; use the Request setters to customise it before executing.
func (c *Client[T]) R() *Request[T] {
	base := &http.Request{
		Method:     http.MethodGet,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
	}
	return &Request[T]{
		req:         base.WithContext(context.Background()),
		client:      c,
		queryParams: make(url.Values),
	}
}
// execute merges the client-wide headers and query parameters into req and
// sends it with the client's http.Client.
//
// Settings already present on req take precedence. A client-wide header or
// query parameter is added only when req does not set that key itself, so
// per-request values override client defaults instead of being silently
// replaced. Copied header values are cloned so req never shares a backing
// slice with the client's header map.
//
// On success the caller owns the response and must close its Body. On error
// the returned response is always nil.
func (c *Client[T]) execute(req *http.Request) (*http.Response, error) {
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	for key, values := range c.commonReq.Header {
		if _, ok := req.Header[key]; !ok {
			req.Header[key] = append([]string(nil), values...)
		}
	}

	// query is a freshly parsed copy that is encoded immediately, so sharing
	// the client's value slices here cannot leak later mutations.
	query := req.URL.Query()
	for key, values := range c.commonQueryParams {
		if _, ok := query[key]; !ok {
			query[key] = values
		}
	}
	req.URL.RawQuery = query.Encode()

	resp, err := c.httpClient.Do(req)
	if err != nil {
		// http.Client.Do can return a non-nil (already closed) response with an
		// error; callers of execute only ever see nil in that case.
		return nil, err
	}
	return resp, nil
}
// Request is a single request being built on top of a Client. T is the type
// that Execute converts the response into.
type Request[T any] struct {
	// req carries the per-request context, headers and optional body.
	req *http.Request
	// client supplies the http.Client and the client-wide headers and query
	// parameters.
	client *Client[T]
	// queryParams are merged into the request URL's query when executing.
	queryParams url.Values
}
// Clone returns a copy of r that can be modified independently of r.
//
// The underlying *http.Request is copied with http.Request.Clone, keeping r's
// context and deep-copying the headers. The query parameters are copied into a
// new map, so SetQueryParam on the clone does not affect r. Both requests keep
// the same client and share the same request body (http.Request.Clone copies
// Body shallowly).
func (r *Request[T]) Clone() *Request[T] {
	params := maps.Clone(r.queryParams)
	if params == nil {
		// Keep the clone usable with SetQueryParam even if r had no map.
		params = url.Values{}
	}
	return &Request[T]{
		req:         r.req.Clone(r.req.Context()),
		client:      r.client,
		queryParams: params,
	}
}
// SetHeader sets a header on this request only, replacing any earlier values
// stored under the same canonicalised header name. It returns r for chaining.
func (r *Request[T]) SetHeader(header, value string) *Request[T] {
	key := http.CanonicalHeaderKey(header)
	r.req.Header[key] = []string{value}
	return r
}
// SetQueryParam sets a query parameter on this request only, replacing any
// earlier values for param. It returns r for chaining.
func (r *Request[T]) SetQueryParam(param, value string) *Request[T] {
	r.queryParams[param] = []string{value}
	return r
}
// SetContext replaces the context that Execute uses for the outgoing request
// and returns r for chaining. ctx must be non-nil, because
// http.Request.WithContext panics on a nil context.
func (r *Request[T]) SetContext(ctx context.Context) *Request[T] {
	updated := r.req.WithContext(ctx)
	r.req = updated
	return r
}
// Get is shorthand for Execute(http.MethodGet, url).
func (r *Request[T]) Get(url string) (*T, error) {
	const method = http.MethodGet
	return r.Execute(method, url)
}
func (r *Request[T]) Execute(method string, url string) (*T, error) {
out := new(T)
rt := reflect.TypeFor[T]()
newReq, err := http.NewRequestWithContext(r.req.Context(), method, url, r.req.Body)
if err != nil {
return nil, err
}
// change a req header
maps.Copy(newReq.Header, r.req.Header)
// change a req query
query := newReq.URL.Query()
maps.Copy(query, r.queryParams)
newReq.URL.RawQuery = query.Encode()
resp, err := r.client.execute(newReq)
if err != nil {
return nil, err
}
// io.ReadCloser 需要自己调用Close
switch any(out).(type) {
case *http.Response:
return any(resp).(*T), nil
case *io.ReadCloser:
return any(&resp.Body).(*T), nil
case *Response: // for addition trace info, reuse body an so on
r := Response{resp: resp, trace: &TraceInfo{}}
return any(&r).(*T), nil
}
// body 自动close
defer resp.Body.Close()
switch any(out).(type) {
case *[]byte:
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
return any(&body).(*T), nil
case *string:
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
v := string(body)
return any(&v).(*T), nil
case *io.Reader:
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
v := io.Reader(bytes.NewBuffer(body))
return any(&v).(*T), nil
case *html.Node:
node, rErr := html.Parse(resp.Body)
if rErr != nil {
return nil, rErr
}
return any(node).(*T), nil
}
if rt.Kind() != reflect.Struct {
return nil, fmt.Errorf("unsupported type: %v", rt)
}
// 按类型反序列化
contentType := resp.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
if err = json.UnmarshalRead(resp.Body, out); err != nil {
return nil, err
}
return out, nil
}
if strings.HasPrefix(contentType, "application/xml") {
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
if err = xml.Unmarshal(body, out); err != nil {
return nil, err
}
return out, nil
}
return nil, fmt.Errorf("unsupported content-type: %s", contentType)
}
// TraceInfo holds timing and connection details for a single request, such as
// DNS lookup, connection setup, TLS handshake and server processing durations.
//
// NOTE(review): nothing in this file populates these fields. Execute attaches
// an empty TraceInfo to each Response, so every value is currently zero.
type TraceInfo struct {
	// DNSLookup is how long the transport spent resolving the host name.
	DNSLookup time.Duration `json:"dns_lookup_time"`
	// ConnTime is how long it took to obtain a usable connection.
	ConnTime time.Duration `json:"connection_time"`
	// TCPConnTime is how long it took to establish the TCP connection.
	TCPConnTime time.Duration `json:"tcp_connection_time"`
	// TLSHandshake is the duration of the TLS handshake.
	TLSHandshake time.Duration `json:"tls_handshake_time"`
	// ServerTime is how long the server took to send the first response byte.
	ServerTime time.Duration `json:"server_time"`
	// ResponseTime is the time from the first response byte to request
	// completion.
	ResponseTime time.Duration `json:"response_time"`
	// TotalTime is the end-to-end duration of the request.
	TotalTime time.Duration `json:"total_time"`
	// IsConnReused reports whether the connection had already served another
	// HTTP request.
	IsConnReused bool `json:"is_connection_reused"`
	// IsConnWasIdle reports whether the connection was taken from the idle
	// pool.
	IsConnWasIdle bool `json:"is_connection_was_idle"`
	// ConnIdleTime is how long the connection had been idle before reuse; only
	// meaningful when IsConnWasIdle is true.
	ConnIdleTime time.Duration `json:"connection_idle_time"`
	// RequestAttempt is the attempt number of the request, presumably counting
	// retries. NOTE(review): there is no retry logic in this file.
	RequestAttempt int `json:"request_attempt"`
	// RemoteAddr is the remote network address of the connection.
	RemoteAddr string `json:"remote_address"`
}

// Response wraps an *http.Response together with trace information. Execute
// returns it, with the body still open, when T is Response.
type Response struct {
	trace      *TraceInfo
	resp       *http.Response
	receivedAt time.Time
}

// Duration returns the total end-to-end request time recorded in the trace
// (TraceInfo.TotalTime). For a nil Response, or one without trace information,
// it returns 0 instead of panicking.
func (r *Response) Duration() time.Duration {
	if r == nil || r.trace == nil {
		return 0
	}
	return r.trace.TotalTime
}

// ReceivedAt returns the time the response was received from the server. It
// returns the zero time if that time was never recorded or r is nil.
func (r *Response) ReceivedAt() time.Time {
	if r == nil {
		return time.Time{}
	}
	return r.receivedAt
}