pkg/rest/client_test.go

407 lines
10 KiB
Go

/*
* TencentBlueKing is pleased to support the open source community by making
* 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest
import (
"bytes"
"context"
"encoding/json/v2"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.ifooth.com/common/pkg/rest/transport"
)
// testServer starts an in-process HTTP server exposing the fixture endpoints
// used by the tests in this file:
//
//   - GET /uuid          JSON document holding a freshly generated UUID
//   - /status/{code}     replies with the status code given in the path
//   - /bkResult/ok       HTTP 400 with a JSON body whose "code" is 0
//   - /bkResult/failed   HTTP 400 with a JSON body whose "code" is 100400
//   - /echo              sends the request body back unchanged
//
// The /uuid and /echo handlers also copy every query parameter into a
// "query-<name>" response header so tests can observe the query string the
// server received. The caller must Close the returned server.
func testServer() *httptest.Server {
	mux := http.NewServeMux()

	// echoQuery mirrors each query parameter into a "query-<name>" header.
	echoQuery := func(w http.ResponseWriter, r *http.Request) {
		for k, v := range r.URL.Query() {
			w.Header().Set("query-"+k, strings.Join(v, ","))
		}
	}

	// Regular data response.
	mux.HandleFunc("GET /uuid", func(w http.ResponseWriter, r *http.Request) {
		data := fmt.Sprintf(`{"uuid": "%s"}`, uuid.NewString())
		echoQuery(w, r)
		w.Header().Set("content-type", "application/json")
		// Write errors are deliberately ignored: this is a test fixture and
		// the client side of each test reports any transport failure.
		_, _ = w.Write([]byte(data))
	})

	// Status code test.
	mux.HandleFunc("/status/{code}", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("content-type", "text/html; charset=utf-8")
		raw := r.PathValue("code")
		code, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// WriteHeader panics for codes outside 100..999, so reject those
		// instead of crashing the handler.
		if code < 100 || code > 999 {
			http.Error(w, fmt.Sprintf("invalid status code %d", code), http.StatusBadRequest)
			return
		}
		w.WriteHeader(int(code))
		_, _ = w.Write([]byte(raw))
	})

	mux.HandleFunc("/bkResult/ok", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("content-type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte(`{"code": 0, "data": {"name": "ko"}}`))
	})

	mux.HandleFunc("/bkResult/failed", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("content-type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte(`{"code": 100400, "data": null}`))
	})

	// Return whatever body was sent, unchanged.
	mux.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("content-type", "text/html; charset=utf-8")
		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		echoQuery(w, r)
		_, _ = w.Write(body)
	})

	return httptest.NewServer(mux)
}
// testResp is the decoded form of the JSON document served by the /uuid
// fixture endpoint.
type testResp struct {
	// UUID is read from / written to the "uuid" JSON key.
	UUID string `json:"uuid"`
}
// BkResp is a generic response envelope made of a numeric code and an
// optional, typed payload. Its JSON keys ("code", "data") match the bodies
// served by the /bkResult/* fixture endpoints.
type BkResp[T any] struct {
	// Code is the value of the "code" JSON key.
	Code int `json:"code"`
	// Data is the value of the "data" JSON key; it is a pointer so that an
	// absent or null payload can be represented as nil.
	Data *T `json:"data"`
}
// bkResult is a payload type carrying a single name. Its JSON key matches the
// "data" object served by the /bkResult/ok fixture endpoint.
type bkResult struct {
	// Name is the value of the "name" JSON key.
	Name string `json:"name"`
}
// BkError carries only the numeric "code" of a response body.
// NOTE(review): it is not referenced in the visible part of this file —
// presumably used to decode failed /bkResult responses; confirm with its users.
type BkError struct {
	// Code is the value of the "code" JSON key.
	Code int `json:"code"`
}
// CustomC returns a new Client preconfigured with the header "abc: abc", the
// query parameter "name=ko" and transport.CurlLog() applied through With.
func CustomC() *Client {
	return NewClient().
		SetHeader("abc", "abc").
		SetQueryParam("name", "ko").
		With(transport.CurlLog())
}
// TestClientUseAsDo covers usage 1: sending a prebuilt *http.Request through
// Client.Do. It posts a body to /echo and checks that the body comes back
// unchanged and that the query parameter configured on the client is reported
// back by the server in the "query-ko" header.
func TestClientUseAsDo(t *testing.T) {
	svr := testServer()
	defer svr.Close()

	req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata")))
	require.NoError(t, err)

	// Usage 1: the Client handles the request through Do.
	client := NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set")
	resp, err := client.Do(req)
	// require rather than assert: resp is dereferenced below, so continuing
	// after a failed request would turn the failure into a nil-pointer panic.
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	assert.Equal(t, []byte("testdata"), body)
	assert.Equal(t, "set", resp.Header.Get("query-ko"))
}
// TestClientUseAsTransport covers usage 2: plugging the Client's transport
// (GetTransport) into a plain *http.Client. It posts a body to /echo and
// checks that the body comes back unchanged and that the query parameter
// configured on the Client is reported back in the "query-ko" header.
func TestClientUseAsTransport(t *testing.T) {
	svr := testServer()
	defer svr.Close()

	req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata")))
	require.NoError(t, err)

	// Usage 2: the Client acts as the transport of another HTTP client.
	client := &http.Client{
		Transport: NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set").GetTransport(),
	}
	resp, err := client.Do(req)
	// require rather than assert: resp is dereferenced below, so continuing
	// after a failed request would turn the failure into a nil-pointer panic.
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	assert.Equal(t, []byte("testdata"), body)
	assert.Equal(t, "set", resp.Header.Get("query-ko"))
}
// TestRequestUseHTTPClient covers usage 3: running Request[T] on top of the
// standard library's http.DefaultClient. It first fetches /uuid with plain
// net/http as a baseline, then posts to /echo decoding the response once as a
// string and once as a raw http.Response.
func TestRequestUseHTTPClient(t *testing.T) {
	svr := testServer()
	defer svr.Close()

	// Baseline: the fixture answers a plain net/http request.
	resp, err := http.Get(svr.URL + "/uuid")
	// require rather than assert: resp is nil when the request fails, and the
	// deferred Close below would panic.
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	r := new(testResp)
	require.NoError(t, json.Unmarshal(body, r))
	assert.Len(t, r.UUID, 36)

	// Usage 3: Request[T] with the native HTTP client.
	r1, err := R[string]().SetClient(http.DefaultClient).SetBody("testdata").Post(svr.URL + "/echo")
	require.NoError(t, err)
	assert.Equal(t, "testdata", *r1)

	r2, err := R[http.Response]().
		SetClient(http.DefaultClient).
		With(transport.CurlLog()).
		SetQueryParam("ko", "hi").
		SetBody("testdata").
		Post(svr.URL + "/echo")
	require.NoError(t, err)
	// The raw response is handed to the caller, who has to close its body.
	defer r2.Body.Close()

	body, err = io.ReadAll(r2.Body)
	require.NoError(t, err)
	assert.Equal(t, "testdata", string(body))
	assert.Equal(t, "hi", r2.Header.Get("query-ko"))
}
// TestRequestUseRestClient covers usage 4: Request[T] with the package's own
// client. It fetches /uuid once per supported result type (decoded struct,
// []byte, io.Reader, io.ReadCloser, http.Response) and checks that each one
// yields the full payload. The last request also goes through CustomC and
// verifies that a request-level query parameter reaches the server.
func TestRequestUseRestClient(t *testing.T) {
	svr := testServer()
	defer svr.Close()

	// Length of the /uuid body: `{"uuid": "` (10) + 36-char UUID + `"}` (2).
	const uuidBodyLen = 48

	// Usage 4: Request[T] with rest.Client.
	// require is used after every call whose result is dereferenced, so a
	// failed request reports an error instead of a nil-pointer panic.
	r, err := R[testResp]().Get(svr.URL + "/uuid")
	require.NoError(t, err)
	assert.Len(t, r.UUID, 36)

	rb, err := R[[]byte]().Get(svr.URL + "/uuid")
	require.NoError(t, err)
	assert.Len(t, *rb, uuidBodyLen)

	r3, err := R[io.Reader]().Get(svr.URL + "/uuid")
	require.NoError(t, err)
	r31, err := io.ReadAll(*r3)
	require.NoError(t, err)
	assert.Len(t, r31, uuidBodyLen)

	r4, err := R[io.ReadCloser]().Get(svr.URL + "/uuid")
	require.NoError(t, err)
	defer (*r4).Close()
	r41, err := io.ReadAll(*r4)
	require.NoError(t, err)
	assert.Len(t, r41, uuidBodyLen)

	r5, err := R[http.Response]().SetClient(CustomC()).
		SetContext(t.Context()).
		SetHeader("ko", "abc").
		SetQueryParam("age", "12").
		Get(svr.URL + "/uuid?dd=abc")
	require.NoError(t, err)
	defer r5.Body.Close()
	r51, err := io.ReadAll(r5.Body)
	require.NoError(t, err)
	assert.Len(t, r51, uuidBodyLen)
	assert.Equal(t, "12", r5.Header.Get("query-age"))
}
// TestClientReuseConn runs ten workers that request /uuid in a loop until a
// two-second deadline expires, and checks that the number of connection
// attempts (counted through httptrace's ConnectStart hook) stays below twice
// the number of workers.
func TestClientReuseConn(t *testing.T) {
	svr := testServer()
	defer svr.Close()
	t.Logf("svrAddr: %s", svr.URL)

	const concurrent = int64(10)

	var (
		workers  sync.WaitGroup
		requests atomic.Int64
		dials    atomic.Int64
	)

	deadlineCtx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
	defer cancel()

	// Every ConnectStart callback corresponds to one new dial.
	trace := &httptrace.ClientTrace{
		ConnectStart: func(string, string) { dials.Add(1) },
	}
	ctx := httptrace.WithClientTrace(deadlineCtx, trace)

	// worker keeps requesting until the request fails; the only failure
	// expected here is the deadline of ctx.
	worker := func() {
		for {
			requests.Add(1)
			body, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid")
			if err != nil {
				assert.ErrorIs(t, err, context.DeadlineExceeded)
				return
			}
			assert.True(t, len(*body) > 10)
			time.Sleep(10 * time.Millisecond)
		}
	}

	for range concurrent {
		workers.Go(worker)
	}
	workers.Wait()

	t.Logf("newConn: %d, reqCount: %d", dials.Load(), requests.Load())
	assert.Less(t, dials.Load(), concurrent*2)
}
// BenchmarkClientReuseConn requests /uuid from parallel goroutines and checks
// that the number of connection attempts (counted through httptrace's
// ConnectStart hook) stays below 20.
//
// NOTE(review): the bound is a fixed 20 while RunParallel's goroutine count
// follows GOMAXPROCS — confirm the limit still holds on machines with many
// cores.
func BenchmarkClientReuseConn(b *testing.B) {
	svr := testServer()
	defer svr.Close()
	b.Logf("addr: %s", svr.URL)

	var (
		reqCount atomic.Int64
		newConn  atomic.Int64
	)

	// Every ConnectStart callback corresponds to one new dial.
	ctx := httptrace.WithClientTrace(
		b.Context(),
		&httptrace.ClientTrace{ConnectStart: func(_ string, _ string) {
			newConn.Add(1)
		}},
	)

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			reqCount.Add(1)
			r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid")
			if err != nil {
				// The only failure tolerated is cancellation of ctx.
				assert.ErrorIs(b, err, context.Canceled)
				return
			}
			assert.Greater(b, len(*r), 10)
			// Same pacing as TestClientReuseConn.
			time.Sleep(time.Millisecond * 10)
		}
	})

	b.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load())
	assert.Less(b, newConn.Load(), int64(20))
}
// TestClientBody checks which values SetBody accepts. Each accepted body is
// sent to /echo, which returns it verbatim, so the response shows exactly what
// went over the wire: strings, byte slices and readers are sent as-is, while
// maps, slices and structs arrive as JSON. An error value and a bare int are
// expected to be rejected, the latter with ErrUnsupported.
func TestClientBody(t *testing.T) {
	svr := testServer()
	defer svr.Close()

	accepted := []struct {
		name string
		body any
		want string
	}{
		{"string", "test", "test"},
		{"bytes", []byte("test"), "test"},
		{"reader", strings.NewReader("test"), "test"},
		{"map", map[string]string{"name": "test"}, `{"name":"test"}`},
		{"slice", []int{1, 2, 3}, `[1,2,3]`},
		{"struct", testResp{UUID: "test"}, `{"uuid":"test"}`},
		{"struct pointer", &testResp{UUID: "test"}, `{"uuid":"test"}`},
		{"nil", nil, ""},
	}
	for _, tc := range accepted {
		t.Run(tc.name, func(t *testing.T) {
			r, err := R[string]().SetBody(tc.body).Get(svr.URL + "/echo")
			// require: r is dereferenced on the next line.
			require.NoError(t, err)
			assert.Equal(t, tc.want, *r)
		})
	}

	t.Run("error value", func(t *testing.T) {
		r, err := R[string]().SetBody(fmt.Errorf("test")).Get(svr.URL + "/echo")
		assert.Error(t, err)
		assert.Nil(t, r)
	})

	t.Run("unsupported type", func(t *testing.T) {
		r, err := R[string]().SetBody(123).Get(svr.URL + "/echo")
		assert.ErrorIs(t, err, ErrUnsupported)
		assert.Nil(t, r)
	})
}
// TestClientURL checks how the request URL is combined with the client's
// baseURL: a relative path without a baseURL fails, a relative path with a
// baseURL succeeds, a full URL passed while a baseURL is configured fails, and
// a baseURL that carries its own path and query still produces a request that
// reaches the server (answered with 404 here).
func TestClientURL(t *testing.T) {
	svr := testServer()
	defer svr.Close()

	// Relative path without a baseURL.
	r, err := R[string]().SetClient(NewClient().SetBaseURL("")).
		SetBody("test").Get("/echo")
	// EqualError also fails cleanly when err is nil, instead of panicking on
	// err.Error().
	require.EqualError(t, err, "baseURL is required")
	assert.Nil(t, r)

	// Relative path resolved against the baseURL.
	r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL)).
		SetBody("test").Get("/echo")
	require.NoError(t, err)
	assert.Equal(t, "test", *r)

	// Send as JSON.
	r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL).With(transport.CurlLog())).
		SetBody(map[string]string{"name": "ko"}).Get("/echo")
	require.NoError(t, err)
	assert.Equal(t, "{\"name\":\"ko\"}", *r)

	// Full URL while a baseURL is configured.
	r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc")).
		SetBody("test").Get(svr.URL + "/echo")
	assert.Error(t, err)
	assert.Nil(t, r)

	// baseURL with its own path and query; the fixture answers 404 for the
	// resulting request.
	r2, err := R[http.Response]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc?name=ko").With(transport.CurlLog())).
		SetBody("test").Get("/echo?abc=ddd")
	require.NoError(t, err)
	// The raw response is handed to the caller, who has to close its body.
	defer r2.Body.Close()
	assert.Equal(t, http.StatusNotFound, r2.StatusCode)
}