277 lines
6.6 KiB
Go
277 lines
6.6 KiB
Go
package rest
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json/v2"
|
|
"encoding/xml"
|
|
"fmt"
|
|
"io"
|
|
"maps"
|
|
"net/http"
|
|
"net/url"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/net/html"
|
|
)
|
|
|
|
// Client holds the settings shared by every Request it creates. The type
// parameter T is the result type produced by requests built from this client.
type Client[T any] struct {
	// httpClient performs the actual HTTP round trips.
	httpClient *http.Client
	// commonReq carries client-wide request state; the code in this file
	// only uses its Header field.
	commonReq *http.Request
	// commonFormData is not referenced by the code visible in this file.
	commonFormData url.Values
	// commonQueryParams are query parameters applied to outgoing requests.
	commonQueryParams url.Values
	// commonCookies is not referenced by the code visible in this file.
	commonCookies []*http.Cookie
}
|
|
|
|
func NewClient[T any]() *Client[T] {
|
|
r := &Client[T]{
|
|
commonReq: &http.Request{Header: make(http.Header)},
|
|
httpClient: &http.Client{},
|
|
commonQueryParams: url.Values{},
|
|
}
|
|
return r
|
|
}
|
|
|
|
func (c *Client[T]) SetHeader(header, value string) *Client[T] {
|
|
c.commonReq.Header.Set(header, value)
|
|
return c
|
|
}
|
|
|
|
func (c *Client[T]) SetQueryParam(param, value string) *Client[T] {
|
|
c.commonQueryParams.Set(param, value)
|
|
return c
|
|
}
|
|
|
|
func (c *Client[T]) SetTransport(transport http.RoundTripper) *Client[T] {
|
|
if transport != nil {
|
|
c.httpClient.Transport = transport
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
func (c *Client[T]) R() *Request[T] {
|
|
req := &http.Request{
|
|
Method: http.MethodGet,
|
|
Proto: "HTTP/1.1",
|
|
ProtoMajor: 1,
|
|
ProtoMinor: 1,
|
|
Header: make(http.Header),
|
|
}
|
|
req = req.WithContext(context.Background())
|
|
|
|
r := &Request[T]{
|
|
req: req,
|
|
client: c,
|
|
queryParams: url.Values{},
|
|
}
|
|
return r
|
|
}
|
|
|
|
func (c *Client[T]) execute(req *http.Request) (*http.Response, error) {
|
|
// change header
|
|
maps.Copy(req.Header, c.commonReq.Header)
|
|
|
|
// change query
|
|
query := req.URL.Query()
|
|
maps.Copy(query, c.commonQueryParams)
|
|
req.URL.RawQuery = query.Encode()
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
type Request[T any] struct {
|
|
req *http.Request
|
|
client *Client[T]
|
|
queryParams url.Values
|
|
}
|
|
|
|
func (r *Request[T]) Clone() *Request[T] {
|
|
newReq := r.client.R()
|
|
newReq.req = r.req.Clone(r.req.Context())
|
|
newReq.client = r.client
|
|
newReq.queryParams = r.queryParams
|
|
return newReq
|
|
}
|
|
|
|
func (r *Request[T]) SetHeader(header, value string) *Request[T] {
|
|
r.req.Header.Set(header, value)
|
|
return r
|
|
}
|
|
|
|
func (r *Request[T]) SetQueryParam(param, value string) *Request[T] {
|
|
r.queryParams.Set(param, value)
|
|
return r
|
|
}
|
|
|
|
func (r *Request[T]) SetContext(ctx context.Context) *Request[T] {
|
|
r.req = r.req.WithContext(ctx)
|
|
return r
|
|
}
|
|
|
|
func (r *Request[T]) Get(url string) (*T, error) {
|
|
return r.Execute(http.MethodGet, url)
|
|
}
|
|
|
|
func (r *Request[T]) Execute(method string, url string) (*T, error) {
|
|
out := new(T)
|
|
rt := reflect.TypeFor[T]()
|
|
|
|
newReq, err := http.NewRequestWithContext(r.req.Context(), method, url, r.req.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// change a req header
|
|
maps.Copy(newReq.Header, r.req.Header)
|
|
|
|
// change a req query
|
|
query := newReq.URL.Query()
|
|
maps.Copy(query, r.queryParams)
|
|
newReq.URL.RawQuery = query.Encode()
|
|
|
|
resp, err := r.client.execute(newReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// io.ReadCloser 需要自己调用Close
|
|
switch any(out).(type) {
|
|
case *http.Response:
|
|
return any(resp).(*T), nil
|
|
case *io.ReadCloser:
|
|
return any(&resp.Body).(*T), nil
|
|
case *Response: // for addition trace info, reuse body an so on
|
|
r := Response{resp: resp, trace: &TraceInfo{}}
|
|
return any(&r).(*T), nil
|
|
}
|
|
|
|
// body 自动close
|
|
defer resp.Body.Close()
|
|
|
|
switch any(out).(type) {
|
|
case *[]byte:
|
|
// 需要读取整个body
|
|
body, rErr := io.ReadAll(resp.Body)
|
|
if rErr != nil {
|
|
return nil, rErr
|
|
}
|
|
return any(&body).(*T), nil
|
|
case *string:
|
|
// 需要读取整个body
|
|
body, rErr := io.ReadAll(resp.Body)
|
|
if rErr != nil {
|
|
return nil, rErr
|
|
}
|
|
v := string(body)
|
|
return any(&v).(*T), nil
|
|
case *io.Reader:
|
|
// 需要读取整个body
|
|
body, rErr := io.ReadAll(resp.Body)
|
|
if rErr != nil {
|
|
return nil, rErr
|
|
}
|
|
v := io.Reader(bytes.NewBuffer(body))
|
|
return any(&v).(*T), nil
|
|
case *html.Node:
|
|
node, rErr := html.Parse(resp.Body)
|
|
if rErr != nil {
|
|
return nil, rErr
|
|
}
|
|
return any(node).(*T), nil
|
|
}
|
|
|
|
if rt.Kind() != reflect.Struct {
|
|
return nil, fmt.Errorf("unsupported type: %v", rt)
|
|
}
|
|
|
|
// 按类型反序列化
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if strings.HasPrefix(contentType, "application/json") {
|
|
if err = json.UnmarshalRead(resp.Body, out); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
if strings.HasPrefix(contentType, "application/xml") {
|
|
// 需要读取整个body
|
|
body, rErr := io.ReadAll(resp.Body)
|
|
if rErr != nil {
|
|
return nil, rErr
|
|
}
|
|
if err = xml.Unmarshal(body, out); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("unsupported content-type: %s", contentType)
|
|
}
|
|
|
|
// TraceInfo collects timing and connection details about one request, such
// as DNS lookup time, connection setup time and server processing time.
type TraceInfo struct {
	// DNSLookup is how long the transport spent on the DNS lookup.
	DNSLookup time.Duration `json:"dns_lookup_time"`

	// ConnTime is how long it took to obtain a usable connection.
	ConnTime time.Duration `json:"connection_time"`

	// TCPConnTime is how long it took to establish the TCP connection.
	TCPConnTime time.Duration `json:"tcp_connection_time"`

	// TLSHandshake is how long the TLS handshake took.
	TLSHandshake time.Duration `json:"tls_handshake_time"`

	// ServerTime is how long the server took to send the first response byte.
	ServerTime time.Duration `json:"server_time"`

	// ResponseTime is the time from the first response byte until the
	// request completed.
	ResponseTime time.Duration `json:"response_time"`

	// TotalTime is the end-to-end duration of the request.
	TotalTime time.Duration `json:"total_time"`

	// IsConnReused reports whether the connection had already served
	// another HTTP request.
	IsConnReused bool `json:"is_connection_reused"`

	// IsConnWasIdle reports whether the connection came from an idle pool.
	IsConnWasIdle bool `json:"is_connection_was_idle"`

	// ConnIdleTime is how long the connection had been idle; it is only
	// meaningful when IsConnWasIdle is true.
	ConnIdleTime time.Duration `json:"connection_idle_time"`

	// RequestAttempt is the attempt number within the request execution
	// flow, retries included.
	RequestAttempt int `json:"request_attempt"`

	// RemoteAddr is the remote network address.
	RemoteAddr string `json:"remote_address"`
}
|
|
|
|
type Response struct {
|
|
trace *TraceInfo
|
|
resp *http.Response
|
|
receivedAt time.Time
|
|
}
|
|
|
|
func (r *Response) Duration() time.Duration {
|
|
return r.trace.TotalTime
|
|
}
|
|
|
|
// ReceivedAt method returns the time we received a response from the server for the request.
|
|
func (r *Response) ReceivedAt() time.Time {
|
|
return r.receivedAt
|
|
}
|