/* * TencentBlueKing is pleased to support the open source community by making * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. * Copyright (C) 2025 Tencent. All rights reserved. * Licensed under the MIT License (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at http://opensource.org/licenses/MIT * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, * either express or implied. See the License for the * specific language governing permissions and limitations under the License. * We undertake not to change the open source license (MIT license) applicable * to the current version of the project delivered to anyone in the future. */ package rest import ( "bytes" "context" "encoding/json/v2" "fmt" "io" "net/http" "net/http/httptest" "net/http/httptrace" "strconv" "strings" "sync" "sync/atomic" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "git.ifooth.com/common/pkg/rest/transport" ) func testServer() *httptest.Server { r := http.NewServeMux() // 正常数据返回 r.HandleFunc("GET /uuid", func(w http.ResponseWriter, r *http.Request) { data := fmt.Sprintf(`{"uuid": "%s"}`, uuid.NewString()) // 原样返回query for k, v := range r.URL.Query() { w.Header().Set("query-"+k, strings.Join(v, ",")) } w.Header().Set("content-type", "application/json") w.Write([]byte(data)) }) // 状态码测试 r.HandleFunc("/status/{code}", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("content-type", "text/html; charset=utf-8") code, err := strconv.ParseInt(r.PathValue("code"), 10, 64) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } w.WriteHeader(int(code)) w.Write([]byte(r.PathValue("code"))) }) r.HandleFunc("/bkResult/ok", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("content-type", "application/json") w.WriteHeader(int(400)) w.Write([]byte(`{"code": 0, "data": {"name": "ko"}}`)) }) r.HandleFunc("/bkResult/failed", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("content-type", "application/json") w.WriteHeader(int(400)) w.Write([]byte(`{"code": 100400, "data": null}`)) }) // 原样返回任何body数据 r.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("content-type", "text/html; charset=utf-8") body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // 原样返回query for k, v := range r.URL.Query() { w.Header().Set("query-"+k, strings.Join(v, ",")) } w.Write(body) }) ts := httptest.NewServer(r) return ts } type testResp struct { UUID string `json:"uuid"` } type BkResp[T any] struct { Code int `json:"code"` Data *T `json:"data"` } type bkResult struct { Name string `json:"name"` } type BkError struct { Code int `json:"code"` } func CustomC() *Client { c := NewClient(). SetHeader("abc", "abc"). SetQueryParam("name", "ko"). With(transport.CurlLog()) return c } func TestClientUseAsDo(t *testing.T) { svr := testServer() defer svr.Close() req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata"))) assert.NoError(t, err) // 用法1: Client使用Do处理请求 client := NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set") resp, err := client.Do(req) assert.NoError(t, err) body, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, []byte("testdata"), body) assert.Equal(t, resp.Header.Get("query-ko"), "set") } func TestClientUseAsTransport(t *testing.T) { svr := testServer() defer svr.Close() req, err := http.NewRequest("POST", svr.URL+"/echo", bytes.NewBuffer([]byte("testdata"))) assert.NoError(t, err) // 用法2: Client作为第三方http库的transport处理请求 client := &http.Client{} client.Transport = NewClient().With(transport.CurlLog()).SetQueryParam("ko", "set").GetTransport() resp, err := client.Do(req) assert.NoError(t, err) body, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, []byte("testdata"), body) assert.Equal(t, resp.Header.Get("query-ko"), "set") } func TestRequestUseHTTPClient(t *testing.T) { svr := testServer() defer svr.Close() resp, err := http.Get(svr.URL + "/uuid") assert.NoError(t, err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) assert.NoError(t, err) r := new(testResp) err = json.Unmarshal(body, r) assert.NoError(t, err) assert.Equal(t, len(r.UUID), 36) // 用法3: Request[T]使用原生HTTP Client r1, err := R[string]().SetClient(http.DefaultClient).SetBody("testdata").Post(svr.URL + "/echo") assert.NoError(t, err) assert.Equal(t, "testdata", *r1) r2, err := R[http.Response]().SetClient(http.DefaultClient).With(transport.CurlLog()).SetQueryParam("ko", "hi").SetBody("testdata").Post(svr.URL + "/echo") assert.NoError(t, err) body, err = io.ReadAll(r2.Body) assert.NoError(t, err) assert.Equal(t, "testdata", string(body)) assert.Equal(t, "hi", r2.Header.Get("query-ko")) } func TestRequestUseRestClient(t *testing.T) { svr := testServer() defer svr.Close() // 用法4: Request[T]使用rest.Client r, err := R[testResp]().Get(svr.URL + "/uuid") assert.NoError(t, err) assert.Equal(t, len(r.UUID), 36) rb, err := R[[]byte]().Get(svr.URL + "/uuid") require.NoError(t, err) assert.Equal(t, len(*rb), 48) r3, err := R[io.Reader]().Get(svr.URL + "/uuid") assert.NoError(t, err) r31, err := io.ReadAll(*r3) assert.NoError(t, err) assert.Equal(t, len(r31), 48) r4, err := R[io.ReadCloser]().Get(svr.URL + "/uuid") assert.NoError(t, err) r41, err := io.ReadAll(*r4) assert.NoError(t, err) (*r4).Close() assert.Equal(t, len(r41), 48) r5, err := R[http.Response]().SetClient(CustomC()). SetContext(t.Context()). SetHeader("ko", "abc"). SetQueryParam("age", "12"). Get(svr.URL + "/uuid?dd=abc") assert.NoError(t, err) r51, err := io.ReadAll(r5.Body) assert.NoError(t, err) r5.Body.Close() assert.Equal(t, len(r51), 48) assert.Equal(t, "12", r5.Header.Get("query-age")) } func TestClientReuseConn(t *testing.T) { svr := testServer() defer svr.Close() t.Logf("svrAddr: %s", svr.URL) ctx, cancel := context.WithTimeout(t.Context(), time.Second*2) defer cancel() var ( wg sync.WaitGroup reqCount atomic.Int64 newConn atomic.Int64 ) ctx = httptrace.WithClientTrace( ctx, &httptrace.ClientTrace{ConnectStart: func(_ string, _ string) { newConn.Add(1) }}, ) concurrent := int64(10) for range concurrent { wg.Go(func() { for { reqCount.Add(1) r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid") if err != nil { assert.ErrorIs(t, err, context.DeadlineExceeded) return } assert.True(t, len(*r) > 10) time.Sleep(time.Millisecond * 10) } }) } wg.Wait() t.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load()) assert.Less(t, newConn.Load(), concurrent*2) } func BenchmarkClientReuseConn(b *testing.B) { svr := testServer() defer svr.Close() b.Logf("addr: %s", svr.URL) var ( reqCount atomic.Int64 newConn atomic.Int64 ) ctx := httptrace.WithClientTrace( b.Context(), &httptrace.ClientTrace{ConnectStart: func(_ string, _ string) { newConn.Add(1) }}, ) b.RunParallel(func(pb *testing.PB) { for pb.Next() { reqCount.Add(1) r, err := R[string]().SetContext(ctx).Get(svr.URL + "/uuid") if err != nil { assert.ErrorIs(b, err, context.Canceled) return } assert.NoError(b, err) assert.True(b, len(*r) > 10) time.Sleep(time.Millisecond * 10) } }) b.Logf("newConn: %d, reqCount: %d", newConn.Load(), reqCount.Load()) assert.Less(b, newConn.Load(), int64(20)) } func TestClientBody(t *testing.T) { svr := testServer() defer svr.Close() r, err := R[string]().SetBody("test").Get(svr.URL + "/echo") require.NoError(t, err) assert.Equal(t, *r, "test") r, err = R[string]().SetBody([]byte("test")).Get(svr.URL + "/echo") require.NoError(t, err) assert.Equal(t, *r, "test") r, err = R[string]().SetBody(strings.NewReader("test")).Get(svr.URL + "/echo") require.NoError(t, err) assert.Equal(t, *r, "test") r, err = R[string]().SetBody(map[string]string{"name": "test"}).Get(svr.URL + "/echo") require.NoError(t, err) assert.Equal(t, *r, `{"name":"test"}`) r, err = R[string]().SetBody([]int{1, 2, 3}).Get(svr.URL + "/echo") require.NoError(t, err) assert.Equal(t, *r, `[1,2,3]`) r, err = R[string]().SetBody(testResp{UUID: "test"}).Get(svr.URL + "/echo") require.NoError(t, err) assert.Equal(t, *r, `{"uuid":"test"}`) r, err = R[string]().SetBody(&testResp{UUID: "test"}).Get(svr.URL + "/echo") require.NoError(t, err) assert.Equal(t, *r, `{"uuid":"test"}`) r, err = R[string]().SetBody(nil).Get(svr.URL + "/echo") require.NoError(t, err) assert.Equal(t, *r, "") r, err = R[string]().SetBody(fmt.Errorf("test")).Get(svr.URL + "/echo") assert.Error(t, err) assert.Nil(t, r) r, err = R[string]().SetBody(123).Get(svr.URL + "/echo") assert.ErrorIs(t, err, ErrUnsupported) assert.Nil(t, r) } func TestClientURL(t *testing.T) { svr := testServer() defer svr.Close() r, err := R[string]().SetClient(NewClient().SetBaseURL("")). SetBody("test").Get("/echo") assert.Error(t, err) assert.Equal(t, "baseURL is required", err.Error()) assert.Nil(t, r) r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL)). SetBody("test").Get("/echo") require.NoError(t, err) assert.Equal(t, *r, "test") // send as json r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL).With(transport.CurlLog())). SetBody(map[string]string{"name": "ko"}).Get("/echo") require.NoError(t, err) assert.Equal(t, *r, "{\"name\":\"ko\"}") r, err = R[string]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc")). SetBody("test").Get(svr.URL + "/echo") assert.Error(t, err) assert.Nil(t, r) r2, err := R[http.Response]().SetClient(NewClient().SetBaseURL(svr.URL + "/abc?name=ko").With(transport.CurlLog())). SetBody("test").Get("/echo?abc=ddd") require.NoError(t, err) assert.Equal(t, r2.StatusCode, http.StatusNotFound) }