407 lines
10 KiB
Go
407 lines
10 KiB
Go
/*
 * TencentBlueKing is pleased to support the open source community by making
 * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available.
 * Copyright (C) 2025 Tencent. All rights reserved.
 * Licensed under the MIT License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at http://opensource.org/licenses/MIT
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 * We undertake not to change the open source license (MIT license) applicable
 * to the current version of the project delivered to anyone in the future.
 */
|
|
|
|
package rest
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json/v2"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httptrace"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"git.ifooth.com/common/pkg/rest/transport"
|
|
)
|
|
|
|
func testServer() *httptest.Server {
|
|
r := http.NewServeMux()
|
|
// 正常数据返回
|
|
r.HandleFunc("GET /uuid", func(w http.ResponseWriter, r *http.Request) {
|
|
data := fmt.Sprintf(`{"uuid": "%s"}`, uuid.NewString())
|
|
|
|
// 原样返回query
|
|
for k, v := range r.URL.Query() {
|
|
w.Header().Set("query-"+k, strings.Join(v, ","))
|
|
}
|
|
|
|
w.Header().Set("content-type", "application/json")
|
|
w.Write([]byte(data))
|
|
})
|
|
|
|
// 状态码测试
|
|
r.HandleFunc("/status/{code}", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("content-type", "text/html; charset=utf-8")
|
|
|
|
code, err := strconv.ParseInt(r.PathValue("code"), 10, 64)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(int(code))
|
|
w.Write([]byte(r.PathValue("code")))
|
|
|
|
})
|
|
|
|
r.HandleFunc("/bkResult/ok", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("content-type", "application/json")
|
|
|
|
w.WriteHeader(int(400))
|
|
w.Write([]byte(`{"code": 0, "data": {"name": "ko"}}`))
|
|
})
|
|
|
|
r.HandleFunc("/bkResult/failed", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("content-type", "application/json")
|
|
|
|
w.WriteHeader(int(400))
|
|
w.Write([]byte(`{"code": 100400, "data": null}`))
|
|
|
|
})
|
|
|
|
// 原样返回任何body数据
|
|
r.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("content-type", "text/html; charset=utf-8")
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// 原样返回query
|
|
for k, v := range r.URL.Query() {
|
|
w.Header().Set("query-"+k, strings.Join(v, ","))
|
|
}
|
|
|
|
w.Write(body)
|
|
})
|
|
|
|
ts := httptest.NewServer(r)
|
|
return ts
|
|
}
|
|
|
|
// Fixture types exchanged with the test server.
type (
	// testResp mirrors the JSON body served by GET /uuid.
	testResp struct {
		UUID string `json:"uuid"`
	}

	// BkResp is a generic BK-style envelope: a numeric code plus a typed data
	// payload. Data is a pointer so a JSON null can be represented as nil.
	BkResp[T any] struct {
		Code int `json:"code"`
		Data *T  `json:"data"`
	}

	// bkResult is the payload carried by the /bkResult/ok fixture.
	bkResult struct {
		Name string `json:"name"`
	}

	// BkError holds only the "code" field of a BK-style response.
	BkError struct {
		Code int `json:"code"`
	}
)
|
|
|
|
func CustomC() *Client {
|
|
c := NewClient().
|
|
SetHeader("abc", "abc").
|
|
SetQueryParam("name", "ko").
|
|
With(transport.CurlLog())
|
|
|
|
return c
|
|
}
|
|
|
|
func TestClientUseAsDo(t *testing.T) {
|
|
svr := testServer()
|
|
defer svr.Close()
|
|
|
|
req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata")))
|
|
assert.NoError(t, err)
|
|
|
|
// 用法1: Client使用Do处理请求
|
|
client := NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set")
|
|
|
|
resp, err := client.Do(req)
|
|
assert.NoError(t, err)
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, []byte("testdata"), body)
|
|
assert.Equal(t, resp.Header.Get("query-ko"), "set")
|
|
}
|
|
|
|
func TestClientUseAsTransport(t *testing.T) {
|
|
svr := testServer()
|
|
defer svr.Close()
|
|
|
|
req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata")))
|
|
assert.NoError(t, err)
|
|
|
|
// 用法2: Client作为第三方http库的transport处理请求
|
|
client := &http.Client{}
|
|
client.Transport = NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set").GetTransport()
|
|
|
|
resp, err := client.Do(req)
|
|
assert.NoError(t, err)
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, []byte("testdata"), body)
|
|
assert.Equal(t, resp.Header.Get("query-ko"), "set")
|
|
}
|
|
|
|
func TestRequestUseHTTPClient(t *testing.T) {
|
|
svr := testServer()
|
|
defer svr.Close()
|
|
|
|
resp, err := http.Get(svr.URL + "/uuid")
|
|
assert.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
assert.NoError(t, err)
|
|
|
|
r := new(testResp)
|
|
err = json.Unmarshal(body, r)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, len(r.UUID), 36)
|
|
|
|
// 用法3: Request[T]使用原生HTTP Client
|
|
r1, err := R[string]().SetClient(http.DefaultClient).SetBody("testdata").Post(svr.URL + "/echo")
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, "testdata", *r1)
|
|
|
|
r2, err := R[http.Response]().SetClient(http.DefaultClient).With(transport.CurlLog()).SetQueryParam("ko", "hi").SetBody("testdata").Post(svr.URL + "/echo")
|
|
assert.NoError(t, err)
|
|
body, err = io.ReadAll(r2.Body)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, "testdata", string(body))
|
|
assert.Equal(t, "hi", r2.Header.Get("query-ko"))
|
|
}
|
|
|
|
func TestRequestUseRestClient(t *testing.T) {
|
|
svr := testServer()
|
|
defer svr.Close()
|
|
|
|
// 用法4: Request[T]使用rest.Client
|
|
r, err := R[testResp]().Get(svr.URL + "/uuid")
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, len(r.UUID), 36)
|
|
|
|
rb, err := R[[]byte]().Get(svr.URL + "/uuid")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, len(*rb), 48)
|
|
|
|
r3, err := R[io.Reader]().Get(svr.URL + "/uuid")
|
|
assert.NoError(t, err)
|
|
|
|
r31, err := io.ReadAll(*r3)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, len(r31), 48)
|
|
|
|
r4, err := R[io.ReadCloser]().Get(svr.URL + "/uuid")
|
|
assert.NoError(t, err)
|
|
r41, err := io.ReadAll(*r4)
|
|
assert.NoError(t, err)
|
|
(*r4).Close()
|
|
assert.Equal(t, len(r41), 48)
|
|
|
|
r5, err := R[http.Response]().SetClient(CustomC()).
|
|
SetContext(t.Context()).
|
|
SetHeader("ko", "abc").
|
|
SetQueryParam("age", "12").
|
|
Get(svr.URL + "/uuid?dd=abc")
|
|
assert.NoError(t, err)
|
|
r51, err := io.ReadAll(r5.Body)
|
|
assert.NoError(t, err)
|
|
r5.Body.Close()
|
|
assert.Equal(t, len(r51), 48)
|
|
assert.Equal(t, "12", r5.Header.Get("query-age"))
|
|
}
|
|
|
|
func TestClientReuseConn(t *testing.T) {
|
|
svr := testServer()
|
|
defer svr.Close()
|
|
|
|
t.Logf("svrAddr: %s", svr.URL)
|
|
|
|
ctx, cancel := context.WithTimeout(t.Context(), time.Second*2)
|
|
defer cancel()
|
|
|
|
var (
|
|
wg sync.WaitGroup
|
|
reqCount atomic.Int64
|
|
newConn atomic.Int64
|
|
)
|
|
|
|
ctx = httptrace.WithClientTrace(
|
|
ctx,
|
|
&httptrace.ClientTrace{ConnectStart: func(_ string, _ string) {
|
|
newConn.Add(1)
|
|
}},
|
|
)
|
|
|
|
concurrent := int64(10)
|
|
for range concurrent {
|
|
wg.Go(func() {
|
|
for {
|
|
reqCount.Add(1)
|
|
|
|
r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid")
|
|
if err != nil {
|
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
return
|
|
}
|
|
|
|
assert.True(t, len(*r) > 10)
|
|
time.Sleep(time.Millisecond * 10)
|
|
}
|
|
})
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
t.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load())
|
|
assert.Less(t, newConn.Load(), concurrent*2)
|
|
}
|
|
|
|
func BenchmarkClientReuseConn(b *testing.B) {
|
|
svr := testServer()
|
|
defer svr.Close()
|
|
|
|
b.Logf("addr: %s", svr.URL)
|
|
|
|
var (
|
|
reqCount atomic.Int64
|
|
newConn atomic.Int64
|
|
)
|
|
|
|
ctx := httptrace.WithClientTrace(
|
|
b.Context(),
|
|
&httptrace.ClientTrace{ConnectStart: func(_ string, _ string) {
|
|
newConn.Add(1)
|
|
}},
|
|
)
|
|
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
for pb.Next() {
|
|
reqCount.Add(1)
|
|
|
|
r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid")
|
|
if err != nil {
|
|
assert.ErrorIs(b, err, context.Canceled)
|
|
return
|
|
}
|
|
|
|
assert.NoError(b, err)
|
|
assert.True(b, len(*r) > 10)
|
|
time.Sleep(time.Millisecond * 10)
|
|
}
|
|
})
|
|
|
|
b.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load())
|
|
assert.Less(b, newConn.Load(), int64(20))
|
|
}
|
|
|
|
func TestClientBody(t *testing.T) {
|
|
svr := testServer()
|
|
defer svr.Close()
|
|
|
|
r, err := R[string]().SetBody("test").Get(svr.URL + "/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, "test")
|
|
|
|
r, err = R[string]().SetBody([]byte("test")).Get(svr.URL + "/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, "test")
|
|
|
|
r, err = R[string]().SetBody(strings.NewReader("test")).Get(svr.URL + "/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, "test")
|
|
|
|
r, err = R[string]().SetBody(map[string]string{"name": "test"}).Get(svr.URL + "/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, `{"name":"test"}`)
|
|
|
|
r, err = R[string]().SetBody([]int{1, 2, 3}).Get(svr.URL + "/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, `[1,2,3]`)
|
|
|
|
r, err = R[string]().SetBody(testResp{UUID: "test"}).Get(svr.URL + "/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, `{"uuid":"test"}`)
|
|
|
|
r, err = R[string]().SetBody(&testResp{UUID: "test"}).Get(svr.URL + "/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, `{"uuid":"test"}`)
|
|
|
|
r, err = R[string]().SetBody(nil).Get(svr.URL + "/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, "")
|
|
|
|
r, err = R[string]().SetBody(fmt.Errorf("test")).Get(svr.URL + "/echo")
|
|
assert.Error(t, err)
|
|
assert.Nil(t, r)
|
|
|
|
r, err = R[string]().SetBody(123).Get(svr.URL + "/echo")
|
|
assert.ErrorIs(t, err, ErrUnsupported)
|
|
assert.Nil(t, r)
|
|
}
|
|
|
|
func TestClientURL(t *testing.T) {
|
|
svr := testServer()
|
|
defer svr.Close()
|
|
|
|
r, err := R[string]().SetClient(NewClient().SetBaseURL("")).
|
|
SetBody("test").Get("/echo")
|
|
assert.Error(t, err)
|
|
assert.Equal(t, "baseURL is required", err.Error())
|
|
assert.Nil(t, r)
|
|
|
|
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL)).
|
|
SetBody("test").Get("/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, "test")
|
|
|
|
// send as json
|
|
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL).With(transport.CurlLog())).
|
|
SetBody(map[string]string{"name": "ko"}).Get("/echo")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, *r, "{\"name\":\"ko\"}")
|
|
|
|
r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc")).
|
|
SetBody("test").Get(svr.URL + "/echo")
|
|
assert.Error(t, err)
|
|
assert.Nil(t, r)
|
|
|
|
r2, err := R[http.Response]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc?name=ko").With(transport.CurlLog())).
|
|
SetBody("test").Get("/echo?abc=ddd")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, r2.StatusCode, http.StatusNotFound)
|
|
}
|