main
git 2025-10-02 22:17:46 +08:00
parent 74435682fa
commit 7861f56eb5
Signed by: git
GPG Key ID: 3F65EFFA44207ADD
24 changed files with 1805 additions and 24 deletions

View File

@ -1,20 +0,0 @@
---
kind: pipeline
type: docker
name: test
steps:
- name: fetch
image: alpine/git
commands:
- git fetch --tags
- name: test
image: golang:1.20-alpine
commands:
- apk --update --no-cache add make tzdata git
- make test
- make benchmark
node:
host: devcloud

2
.envrc
View File

@ -0,0 +1,2 @@
export PATH=/opt/go/sdk/go1/bin:$PATH
export GOEXPERIMENT=jsonv2

View File

@ -1,4 +1,6 @@
{ {
"go.goroot": "/opt/go/sdk/go", "go.goroot": "/opt/go/sdk/go1",
"go.gopath": "/root/.go" "go.toolsEnvVars": {
"GOEXPERIMENT": "jsonv2"
}
} }

4
go.mod
View File

@ -1,6 +1,6 @@
module git.ifooth.com/common/pkg module git.ifooth.com/common/pkg
go 1.21 go 1.25
require ( require (
github.com/dustin/go-humanize v1.0.1 github.com/dustin/go-humanize v1.0.1
@ -18,6 +18,7 @@ require (
github.com/samber/lo v1.47.0 github.com/samber/lo v1.47.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0
golang.org/x/net v0.25.0
) )
require ( require (
@ -50,7 +51,6 @@ require (
go.opentelemetry.io/otel/metric v1.27.0 // indirect go.opentelemetry.io/otel/metric v1.27.0 // indirect
go.opentelemetry.io/otel/trace v1.27.0 // indirect go.opentelemetry.io/otel/trace v1.27.0 // indirect
golang.org/x/crypto v0.23.0 // indirect golang.org/x/crypto v0.23.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0 // indirect golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.16.0 // indirect golang.org/x/text v0.16.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect

276
rest/client.go Normal file
View File

@ -0,0 +1,276 @@
package rest
import (
"bytes"
"context"
"encoding/json/v2"
"encoding/xml"
"fmt"
"io"
"maps"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"golang.org/x/net/html"
)
type Client[T any] struct {
httpClient *http.Client
commonReq *http.Request
commonFormData url.Values
commonQueryParams url.Values
commonCookies []*http.Cookie
}
func NewClient[T any]() *Client[T] {
r := &Client[T]{
commonReq: &http.Request{Header: make(http.Header)},
httpClient: &http.Client{},
commonQueryParams: url.Values{},
}
return r
}
func (c *Client[T]) SetHeader(header, value string) *Client[T] {
c.commonReq.Header.Set(header, value)
return c
}
func (c *Client[T]) SetQueryParam(param, value string) *Client[T] {
c.commonQueryParams.Set(param, value)
return c
}
func (c *Client[T]) SetTransport(transport http.RoundTripper) *Client[T] {
if transport != nil {
c.httpClient.Transport = transport
}
return c
}
func (c *Client[T]) R() *Request[T] {
req := &http.Request{
Method: http.MethodGet,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
}
req = req.WithContext(context.Background())
r := &Request[T]{
req: req,
client: c,
queryParams: url.Values{},
}
return r
}
func (c *Client[T]) execute(req *http.Request) (*http.Response, error) {
// change header
maps.Copy(req.Header, c.commonReq.Header)
// change query
query := req.URL.Query()
maps.Copy(query, c.commonQueryParams)
req.URL.RawQuery = query.Encode()
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
type Request[T any] struct {
req *http.Request
client *Client[T]
queryParams url.Values
}
func (r *Request[T]) Clone() *Request[T] {
newReq := r.client.R()
newReq.req = r.req.Clone(r.req.Context())
newReq.client = r.client
newReq.queryParams = r.queryParams
return newReq
}
func (r *Request[T]) SetHeader(header, value string) *Request[T] {
r.req.Header.Set(header, value)
return r
}
func (r *Request[T]) SetQueryParam(param, value string) *Request[T] {
r.queryParams.Set(param, value)
return r
}
func (r *Request[T]) SetContext(ctx context.Context) *Request[T] {
r.req = r.req.WithContext(ctx)
return r
}
func (r *Request[T]) Get(url string) (*T, error) {
return r.Execute(http.MethodGet, url)
}
func (r *Request[T]) Execute(method string, url string) (*T, error) {
out := new(T)
rt := reflect.TypeFor[T]()
newReq, err := http.NewRequestWithContext(r.req.Context(), method, url, r.req.Body)
if err != nil {
return nil, err
}
// change a req header
maps.Copy(newReq.Header, r.req.Header)
// change a req query
query := newReq.URL.Query()
maps.Copy(query, r.queryParams)
newReq.URL.RawQuery = query.Encode()
resp, err := r.client.execute(newReq)
if err != nil {
return nil, err
}
// io.ReadCloser 需要自己调用Close
switch any(out).(type) {
case *http.Response:
return any(resp).(*T), nil
case *io.ReadCloser:
return any(&resp.Body).(*T), nil
case *Response: // for addition trace info, reuse body an so on
r := Response{resp: resp, trace: &TraceInfo{}}
return any(&r).(*T), nil
}
// body 自动close
defer resp.Body.Close()
switch any(out).(type) {
case *[]byte:
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
return any(&body).(*T), nil
case *string:
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
v := string(body)
return any(&v).(*T), nil
case *io.Reader:
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
v := io.Reader(bytes.NewBuffer(body))
return any(&v).(*T), nil
case *html.Node:
node, rErr := html.Parse(resp.Body)
if rErr != nil {
return nil, rErr
}
return any(node).(*T), nil
}
if rt.Kind() != reflect.Struct {
return nil, fmt.Errorf("unsupported type: %v", rt)
}
// 按类型反序列化
contentType := resp.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
if err = json.UnmarshalRead(resp.Body, out); err != nil {
return nil, err
}
return out, nil
}
if strings.HasPrefix(contentType, "application/xml") {
// 需要读取整个body
body, rErr := io.ReadAll(resp.Body)
if rErr != nil {
return nil, rErr
}
if err = xml.Unmarshal(body, out); err != nil {
return nil, err
}
return out, nil
}
return nil, fmt.Errorf("unsupported content-type: %s", contentType)
}
// TraceInfo struct is used to provide request trace info such as DNS lookup
// duration, Connection obtain duration, Server processing duration, etc.
type TraceInfo struct {
// DNSLookup is the duration that transport took to perform
// DNS lookup.
DNSLookup time.Duration `json:"dns_lookup_time"`
// ConnTime is the duration it took to obtain a successful connection.
ConnTime time.Duration `json:"connection_time"`
// TCPConnTime is the duration it took to obtain the TCP connection.
TCPConnTime time.Duration `json:"tcp_connection_time"`
// TLSHandshake is the duration of the TLS handshake.
TLSHandshake time.Duration `json:"tls_handshake_time"`
// ServerTime is the server's duration for responding to the first byte.
ServerTime time.Duration `json:"server_time"`
// ResponseTime is the duration since the first response byte from the server to
// request completion.
ResponseTime time.Duration `json:"response_time"`
// TotalTime is the duration of the total time request taken end-to-end.
TotalTime time.Duration `json:"total_time"`
// IsConnReused is whether this connection has been previously
// used for another HTTP request.
IsConnReused bool `json:"is_connection_reused"`
// IsConnWasIdle is whether this connection was obtained from an
// idle pool.
IsConnWasIdle bool `json:"is_connection_was_idle"`
// ConnIdleTime is the duration how long the connection that was previously
// idle, if IsConnWasIdle is true.
ConnIdleTime time.Duration `json:"connection_idle_time"`
// RequestAttempt is to represent the request attempt made during a Resty
// request execution flow, including retry count.
RequestAttempt int `json:"request_attempt"`
// RemoteAddr returns the remote network address.
RemoteAddr string `json:"remote_address"`
}
type Response struct {
trace *TraceInfo
resp *http.Response
receivedAt time.Time
}
func (r *Response) Duration() time.Duration {
return r.trace.TotalTime
}
// ReceivedAt method returns the time we received a response from the server for the request.
func (r *Response) ReceivedAt() time.Time {
return r.receivedAt
}

116
rest/client_test.go Normal file
View File

@ -0,0 +1,116 @@
package rest
import (
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"git.ifooth.com/common/pkg/rest/transport"
"github.com/stretchr/testify/assert"
"golang.org/x/net/html"
)
type Resp1 struct {
UUID string `json:"uuid"`
}
type BkResult[T any] struct {
Code int
Data *T
}
func getTitle(n *html.Node) string {
if n.Type == html.ElementNode && n.Data == "title" {
if n.FirstChild != nil {
return n.FirstChild.Data
}
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
if title := getTitle(c); title != "" {
return title
}
}
return ""
}
func TestGet(t *testing.T) {
resp, err := http.Get("http://127.0.0.1:38033/uuid")
assert.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
fmt.Println(string(body))
r := new(Resp1)
err = json.Unmarshal(body, r)
assert.NoError(t, err)
assert.True(t, len(r.UUID) == 36)
}
func CustomC[T any]() *Client[T] {
c := NewClient[T]()
c.SetHeader("abc", "abc")
c.SetQueryParam("name", "ko")
return c
}
func TestGetRest(t *testing.T) {
r, err := NewClient[Resp1]().R().Get("http://127.0.0.1:38033/uuid")
assert.NoError(t, err)
fmt.Println(r)
assert.True(t, len(r.UUID) == 36)
rb, err := NewClient[[]byte]().R().Get("http://127.0.0.1:38033/uuid")
assert.NoError(t, err)
fmt.Println("abc", string(*rb))
r3, err := NewClient[io.Reader]().R().Get("http://127.0.0.1:38033/uuid")
assert.NoError(t, err)
r31, err := io.ReadAll(*r3)
assert.NoError(t, err)
fmt.Println("r31", string(r31))
r4, err := NewClient[io.ReadCloser]().R().Get("http://127.0.0.1:38033/uuid")
assert.NoError(t, err)
r41, err := io.ReadAll(*r4)
assert.NoError(t, err)
(*r4).Close()
fmt.Println("r41", string(r41))
ts := transport.Chain(http.DefaultTransport,
transport.RequestID,
transport.CurlLogTransport(map[string]string{}),
)
r5, err := CustomC[http.Response]().SetTransport(ts).R().SetHeader("ko", "abc").
SetQueryParam("age", "12").
Get("http://127.0.0.1:38033/uuid?dd=abc")
assert.NoError(t, err)
r51, err := io.ReadAll(r5.Body)
assert.NoError(t, err)
r5.Body.Close()
fmt.Println("r51", string(r51))
r6, err := NewClient[html.Node]().R().Get("http://127.0.0.1:38033")
assert.NoError(t, err)
title := getTitle(r6)
assert.Equal(t, "httpbin.org", title)
}
func TestGetRestError(t *testing.T) {
r, err := NewClient[int]().R().Get("http://127.0.0.1:38033/uuid")
assert.Error(t, err)
fmt.Println(err)
assert.Nil(t, r)
}

133
rest/codec/codec.go Normal file
View File

@ -0,0 +1,133 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
// Package codec provides encoding and decoding utilities across various formats
package codec
import (
"errors"
"fmt"
"net/http"
"reflect"
)
func decodeTo(r *http.Request, val any) error {
rt := reflect.TypeOf(val).Elem()
rv := reflect.ValueOf(val).Elem()
// json 整个解析
jsonCodec := NewJsonCodec(r)
if err := jsonCodec.Decode(val); err != nil {
return err
}
formCodec, err := NewFormCodec(r)
if err != nil {
return err
}
pathCodec := NewPathCodec(r)
queryCodec := NewQueryCodec(r)
headerCodec := NewHeaderCodec(r)
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i)
// 非导出需要跳过, 无法设置值
if !field.IsExported() {
continue
}
tagStr := field.Tag.Get(tagName)
if tagStr == "" {
continue
}
tag, err := parseTag(tagStr)
if err != nil {
return err
}
fv := rv.Field(i)
if err := formCodec.Decode(field, fv, tag); err != nil {
return err
}
if err := queryCodec.Decode(field, fv, tag); err != nil {
return err
}
if err := headerCodec.Decode(field, fv, tag); err != nil {
return err
}
if err := pathCodec.Decode(field, fv, tag); err != nil {
return err
}
}
return nil
}
// Decode 按结构体反序列化Request
func Decode[T any](r *http.Request) (*T, error) {
rt := reflect.TypeFor[T]()
if rt.Kind() != reflect.Struct {
return nil, errors.New("generic type T must be a struct type")
}
t := new(T)
err := decodeTo(r, t)
if err != nil {
return nil, fmt.Errorf("codec decode: %w", err)
}
return t, nil
}
// getFieldValue 获取字段值
func getFieldValue(field reflect.Type, tag *Tag, values []string) (reflect.Value, error) {
// 指针类型
if field.Kind() == reflect.Pointer {
typ := field.Elem()
rv, err := getFieldValue(typ, tag, values)
if err != nil {
return reflect.Value{}, err
}
newPtr := reflect.New(typ)
newPtr.Elem().Set(rv)
return newPtr, nil
}
// slice类型
if field.Kind() == reflect.Slice {
typ := field.Elem()
// []byte 特殊处理
if typ == byteType {
return ParseValue(field, values[0], tag.Option)
}
val := reflect.MakeSlice(field, 0, len(values))
for _, v := range values {
rv, err := getFieldValue(typ, tag, []string{v})
if err != nil {
return reflect.Value{}, err
}
val = reflect.Append(val, rv)
}
return val, nil
}
return ParseValue(field, values[0], tag.Option)
}

77
rest/codec/form.go Normal file
View File

@ -0,0 +1,77 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package codec
import (
"fmt"
"net/http"
"net/url"
"reflect"
"strings"
)
type formCodec struct {
values url.Values
isForm bool
}
// NewFormCodec ...
func NewFormCodec(r *http.Request) (*formCodec, error) {
isForm := false
contentType := r.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
if err := r.ParseForm(); err != nil {
return nil, err
}
isForm = true
}
c := &formCodec{
values: r.PostForm,
isForm: isForm,
}
return c, nil
}
// Decode ...
func (c *formCodec) Decode(field reflect.StructField, fv reflect.Value, tag *Tag) error {
if !c.isForm {
return nil
}
formTag, ok := tag.Option["form"]
if !ok {
return nil
}
v := c.values[formTag]
if len(v) == 0 {
return nil
}
rv, err := getFieldValue(field.Type, tag, v)
if err != nil {
return err
}
if !rv.IsValid() {
return fmt.Errorf("%s not valid", rv)
}
fv.Set(rv)
return nil
}

59
rest/codec/header.go Normal file
View File

@ -0,0 +1,59 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package codec
import (
"fmt"
"net/http"
"reflect"
)
type headerCodec struct {
values http.Header
}
// NewHeaderCodec ...
func NewHeaderCodec(r *http.Request) *headerCodec {
c := &headerCodec{values: r.Header}
return c
}
// Decode ...
func (c *headerCodec) Decode(field reflect.StructField, fv reflect.Value, tag *Tag) error {
headTag, ok := tag.Option["header"]
if !ok {
return nil
}
// header统一格式
key := http.CanonicalHeaderKey(headTag)
v := c.values[key]
if len(v) == 0 {
return nil
}
rv, err := getFieldValue(field.Type, tag, v)
if err != nil {
return err
}
if !rv.IsValid() {
return fmt.Errorf("%s not valid", rv)
}
fv.Set(rv)
return nil
}

61
rest/codec/json.go Normal file
View File

@ -0,0 +1,61 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package codec
import (
"encoding/json/v2"
"fmt"
"io"
"net/http"
"strings"
)
type jsonCodec struct {
isJson bool
req *http.Request
}
// NewJsonCodec ...
func NewJsonCodec(r *http.Request) *jsonCodec {
isJson := false
contentType := r.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
isJson = true
}
return &jsonCodec{req: r, isJson: isJson}
}
// Decode ...
func (j *jsonCodec) Decode(val any) error {
if !j.isJson {
return nil
}
body, err := io.ReadAll(j.req.Body)
if err != nil {
return err
}
if len(body) == 0 {
return fmt.Errorf("json body is empty")
}
if err := json.Unmarshal(body, val); err != nil {
return fmt.Errorf("unmarshal json body: %s", err)
}
return nil
}

57
rest/codec/path.go Normal file
View File

@ -0,0 +1,57 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package codec
import (
"fmt"
"net/http"
"reflect"
)
type pathCodec struct {
req *http.Request
}
// NewPathCodec ...
func NewPathCodec(r *http.Request) *pathCodec {
c := &pathCodec{req: r}
return c
}
// Decode ...
func (c *pathCodec) Decode(field reflect.StructField, fv reflect.Value, tag *Tag) error {
pathTag, ok := tag.Option["path"]
if !ok {
return nil
}
pv := c.req.PathValue(pathTag)
if pv == "" {
return nil
}
rv, err := getFieldValue(field.Type, tag, []string{pv})
if err != nil {
return err
}
if !rv.IsValid() {
return fmt.Errorf("%s not valid", rv)
}
fv.Set(rv)
return nil
}

58
rest/codec/query.go Normal file
View File

@ -0,0 +1,58 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package codec
import (
"fmt"
"net/http"
"net/url"
"reflect"
)
type queryCodec struct {
values url.Values
}
// NewQueryCodec ...
func NewQueryCodec(r *http.Request) *queryCodec {
c := &queryCodec{values: r.URL.Query()}
return c
}
// Decode ...
func (c *queryCodec) Decode(field reflect.StructField, fv reflect.Value, tag *Tag) error {
queryTag, ok := tag.Option["query"]
if !ok {
return nil
}
v := c.values[queryTag]
if len(v) == 0 {
return nil
}
rv, err := getFieldValue(field.Type, tag, v)
if err != nil {
return err
}
if !rv.IsValid() {
return fmt.Errorf("%s not valid", rv)
}
fv.Set(rv)
return nil
}

62
rest/codec/tag.go Normal file
View File

@ -0,0 +1,62 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package codec
import (
"fmt"
"strings"
)
const (
// tagName 结构体tag名称
// 格式参考 https://pkg.go.dev/encoding/json/v2#example-package-FormatFlags
tagName = "req"
)
// Tag is a struct tag
type Tag struct {
Option map[string]string
}
func parseTag(tagStr string) (*Tag, error) {
tagStr = strings.TrimSpace(tagStr)
if tagStr == "" {
return nil, fmt.Errorf("tag is empty")
}
parts := strings.Split(tagStr, ",")
t := &Tag{
Option: map[string]string{},
}
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
return nil, fmt.Errorf("tag option not valid")
}
opt := strings.SplitN(part, ":", 2)
key := opt[0]
val := ""
if len(opt) == 2 {
val = opt[1]
}
t.Option[key] = val
}
return t, nil
}

195
rest/codec/value_parser.go Normal file
View File

@ -0,0 +1,195 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package codec
import (
"errors"
"fmt"
"reflect"
"strconv"
"time"
)
var (
// ErrUnsupportedType unsupported parse type
ErrUnsupportedType = errors.New("unsupported type")
parserRegistry = make(map[reflect.Type]Parser)
byteType = reflect.TypeFor[byte]()
)
// Parser defines the interface for converting a string to a reflect.Value
// Implementations should parse a string representation into a Go value
type Parser interface {
Parse(s string) (reflect.Value, error)
}
// NewParser defines a factory interface for creating Parser instances
// Implementations should return a new Parser configured with the provided options
type NewParser interface {
// New creates a new Parser instance with the given configuration options
// The options map contains key-value pairs for parser configuration
New(opt map[string]string) Parser
}
// ParserFunc is a function type that implements the Parser interface
// Allows regular functions to be used as Parser implementations
type ParserFunc func(s string) (reflect.Value, error)
// Parse implements the Parser interface for ParserFunc
// This adapter allows any function with the signature func(string) (reflect.Value, error)
// to be used as a Parser without defining a new type
func (p ParserFunc) Parse(s string) (reflect.Value, error) {
return p(s)
}
// RegisterParser registers a parser implementation for a specific type T
func RegisterParser[T any](p Parser) {
parserRegistry[reflect.TypeFor[T]()] = p
}
// ParseValue converts a string to a value of the specified type using registered parsers
// Returns an error if parsing fails or the value type is unsupported
func ParseValue(rt reflect.Type, s string, opt map[string]string) (reflect.Value, error) {
parser, ok := parserRegistry[rt]
if !ok {
return reflect.Value{}, fmt.Errorf("%w: %v", ErrUnsupportedType, rt)
}
// 实现自定义初始化
if v, ok := parser.(NewParser); ok {
parser = v.New(opt)
}
return parser.Parse(s)
}
// StringParser ...
func StringParser(s string) (reflect.Value, error) {
return reflect.ValueOf(s), nil
}
// BoolParser ...
func BoolParser(s string) (reflect.Value, error) {
v, err := strconv.ParseBool(s)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(v), nil
}
// Int generic int parser
type Int[T int | int8 | int16 | int32 | int64] struct {
bitSize int
}
// Parse int parser
func (i Int[T]) Parse(s string) (reflect.Value, error) {
v, err := strconv.ParseInt(s, 10, i.bitSize)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(T(v)), nil
}
// Uint generic uint parser
type Uint[T uint | uint8 | uint16 | uint32 | uint64] struct {
bitSize int
}
// Parse uint parser
func (i Uint[T]) Parse(s string) (reflect.Value, error) {
v, err := strconv.ParseUint(s, 10, i.bitSize)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(T(v)), nil
}
// Float generic float parser
type Float[T float32 | float64] struct {
bitSize int
}
// Parse float parser
func (f Float[T]) Parse(s string) (reflect.Value, error) {
v, err := strconv.ParseFloat(s, f.bitSize)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(T(v)), nil
}
// ByteSlice is a wrapper of []byte to implement Parser
type ByteSlice []byte
// Parse ...
func (bs ByteSlice) Parse(s string) (reflect.Value, error) {
v := []byte(s)
return reflect.ValueOf(v), nil
}
// Time is a parser with format option
type Time struct {
option map[string]string
}
// Parse ...
func (t Time) Parse(s string) (reflect.Value, error) {
format := t.option["format"]
if format == "" {
format = time.DateTime
}
v, err := time.Parse(format, s)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(v), nil
}
// New ...
func (t *Time) New(opt map[string]string) Parser {
newParser := &Time{option: opt}
return newParser
}
func init() {
// buildin parser
RegisterParser[string](ParserFunc(StringParser))
RegisterParser[bool](ParserFunc(BoolParser))
RegisterParser[int](Int[int]{0})
RegisterParser[int8](Int[int8]{8})
RegisterParser[int16](Int[int16]{16})
RegisterParser[int32](Int[int32]{32})
RegisterParser[int64](Int[int64]{64})
RegisterParser[uint](Uint[uint]{0})
RegisterParser[uint8](Uint[uint8]{8})
RegisterParser[uint16](Uint[uint16]{16})
RegisterParser[uint32](Uint[uint32]{32})
RegisterParser[uint64](Uint[uint64]{64})
RegisterParser[float32](Float[float32]{32})
RegisterParser[float64](Float[float64]{64})
RegisterParser[[]byte](ByteSlice{})
RegisterParser[time.Time](&Time{})
}

142
rest/generic.go Normal file
View File

@ -0,0 +1,142 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
// Package rest framework
package rest
import (
"context"
"log/slog"
"net/http"
"time"
)
// UnaryFunc Unary or ClientStreaming handle function
type UnaryFunc[Req, Resp any] func(context.Context, *Req) (*Resp, error)
// StreamingServer server or bidi streaming server
type StreamingServer interface {
http.ResponseWriter
Context() context.Context
}
// StreamFunc ServerStreaming or BidiStreaming handle function
type StreamFunc[Req any] func(*Req, StreamingServer) error
// Handle Composable HTTP Handlers using generics
func Handle[Req, Resp any](fn UnaryFunc[Req, Resp]) func(w http.ResponseWriter, r *http.Request) {
handleName := getHandleName(fn)
f := func(w http.ResponseWriter, r *http.Request) {
st := time.Now()
var err error
defer func() {
collectHandleMetrics(handleName, r.Method, st, err)
}()
// 反序列化
in, err := decodeReq[Req](r)
if err != nil {
slog.Error("handle decode request failed", "err", err)
_ = APIError(err).Render(w, r)
return
}
// 参数校验
if err = validateReq(r.Context(), in); err != nil {
slog.Error("validate req failed", "err", err)
_ = APIError(err).Render(w, r)
return
}
out, err := fn(r.Context(), in)
if err != nil {
_ = APIError(err).Render(w, r)
return
}
_ = APIOK(out).Render(w, r)
}
return f
}
type streamingServer struct {
http.ResponseWriter
*http.ResponseController
ctx context.Context
}
// Context return svr's context
func (s *streamingServer) Context() context.Context {
return s.ctx
}
// Stream Composable HTTP Handlers using generics
func Stream[Req any](fn StreamFunc[Req]) func(w http.ResponseWriter, r *http.Request) {
handleName := getHandleName(fn)
f := func(w http.ResponseWriter, r *http.Request) {
st := time.Now()
var err error
defer func() {
collectHandleMetrics(handleName, r.Method, st, err)
}()
// 反序列化
in, err := decodeReq[Req](r)
if err != nil {
slog.Error("handle decode stream request failed", "err", err)
_ = APIError(err).Render(w, r)
return
}
// 参数校验
if err = validateReq(r.Context(), in); err != nil {
slog.Error("validate stream req failed", "err", err)
_ = APIError(err).Render(w, r)
return
}
svr := &streamingServer{
ResponseWriter: w,
ResponseController: http.NewResponseController(w),
ctx: r.Context(),
}
err = fn(in, svr)
if err != nil {
_ = APIError(err).Render(w, r)
}
}
return f
}
// EmptyReq 空的请求
type EmptyReq struct{}
// EmptyResp 空的返回
type EmptyResp struct{}
// PaginationReq 分页接口通用请求
type PaginationReq struct {
Offset int `json:"offset" in:"query=offset" validate:"gte=0"`
Limit int `json:"limit" in:"query=limit" validate:"gte=0"`
}
// PaginationResp 分页接口通用返回
type PaginationResp[T any] struct {
Count int64 `json:"count"`
Items []T `json:"items"`
}

7
rest/init.go Normal file
View File

@ -0,0 +1,7 @@
package rest
import "git.ifooth.com/common/pkg/logger"
func init() {
logger.Init()
}

76
rest/metrics.go Normal file
View File

@ -0,0 +1,76 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest
import (
"reflect"
"runtime"
"strconv"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
)
var (
requestCounter = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Number of get requests.",
},
[]string{"handler", "method", "code"},
)
responseTimeDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "Histogram of response time for HTTP requests.",
Buckets: []float64{0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60},
},
[]string{"handler", "method", "code"},
)
)
// getHandleName 获取FuncHandle/StreamHandle函数名
func getHandleName(fn any) string {
fullName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
if fullName == "" {
panic("get func name is empty")
}
parts := strings.Split(fullName, ".")
lastPart := parts[len(parts)-1]
name := strings.TrimSuffix(lastPart, "-fm")
return name
}
// collectHandleMetrics api指标数据
func collectHandleMetrics(funcName, method string, st time.Time, err error) {
code := 200
if err != nil {
code = APIError(err).(*APIResponse).HTTPCode
}
codeStr := strconv.Itoa(code)
requestCounter.WithLabelValues(funcName, method, codeStr).Inc()
duration := time.Since(st).Seconds()
responseTimeDuration.WithLabelValues(funcName, method, codeStr).Observe(duration)
}
func init() {
prometheus.MustRegister(requestCounter)
prometheus.MustRegister(responseTimeDuration)
}

62
rest/request.go Normal file
View File

@ -0,0 +1,62 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest
import (
"context"
"net/http"
"git.ifooth.com/common/pkg/rest/codec"
"git.ifooth.com/common/pkg/validator"
)
// decodeReq ...
func decodeReq[T any](r *http.Request) (*T, error) {
in := new(T)
// http.Request 直接返回
if _, ok := any(in).(*http.Request); ok {
return any(r).(*T), nil
}
// 空值不需要反序列化
if _, ok := any(in).(*EmptyReq); ok {
return in, nil
}
in, err := codec.Decode[T](r)
if err != nil {
return nil, err
}
return in, nil
}
// validate 参数校验
func validateReq(ctx context.Context, req any) error {
// http.Request 直接返回
if _, ok := req.(*http.Request); ok {
return nil
}
// 空值不需要校验
if _, ok := req.(*EmptyReq); ok {
return nil
}
return validator.Struct(ctx, req)
}

134
rest/request_test.go Normal file
View File

@ -0,0 +1,134 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest
import (
"bytes"
"io"
"net/http"
"net/url"
"testing"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.ifooth.com/common/pkg/rest/codec"
)
// reqStruct for rest
type reqStruct struct {
Org string `json:"org" req:"path:org" in:"path=org"`
Name string `json:"name" req:"query:name" in:"query=name"`
Age int32 `json:"age" req:"form:age" in:"form=age"`
Bool bool `json:"bool" req:"query:bool" in:"query=bool"`
AgePtr *int `json:"agePtr" req:"header:age_ptr" in:"header=age_ptr"`
SliceStr []string `json:"sliceStr" req:"query:slice_str" in:"query=slice_str"`
Page int64 `json:"page"`
}
// newMockRequest creates a new mock request.
func newMockRequest(t testing.TB, method string, header map[string]string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, "/{org}/vm/xxx?name=alices&age=20&age_ptr=21&slice_str=1&slice_str=2", body)
req.SetPathValue("org", "myOrg")
for k, v := range header {
req.Header.Set(k, v)
}
req.Body = io.NopCloser(body)
require.NoError(t, err)
return req
}
func TestDecode(t *testing.T) {
header := map[string]string{
"age_ptr": "21",
}
r := newMockRequest(t, http.MethodGet, header, nil)
req, err := decodeReq[reqStruct](r)
assert.NoError(t, err)
assert.Equal(t, "myOrg", req.Org)
assert.Equal(t, "alices", req.Name)
assert.Equal(t, lo.ToPtr(21), req.AgePtr)
assert.Equal(t, []string{"1", "2"}, req.SliceStr)
}
func TestFormDecode(t *testing.T) {
header := map[string]string{
"age_ptr": "21",
"Content-Type": "application/x-www-form-urlencoded",
}
formData := url.Values{}
formData.Set("age", "20")
r := newMockRequest(t, http.MethodPost, header, bytes.NewBufferString(formData.Encode()))
req, err := decodeReq[reqStruct](r)
assert.NoError(t, err)
assert.Equal(t, "myOrg", req.Org)
assert.Equal(t, "alices", req.Name)
assert.Equal(t, int32(20), req.Age)
assert.Equal(t, lo.ToPtr(21), req.AgePtr)
assert.Equal(t, []string{"1", "2"}, req.SliceStr)
}
func TestJsonDecode(t *testing.T) {
header := map[string]string{
"age_ptr": "21",
"Content-Type": "application/json",
}
jsonData := `{"page": 64}`
r := newMockRequest(t, http.MethodPost, header, bytes.NewBufferString(jsonData))
req, err := decodeReq[reqStruct](r)
assert.NoError(t, err)
assert.Equal(t, "myOrg", req.Org)
assert.Equal(t, "alices", req.Name)
assert.Equal(t, lo.ToPtr(21), req.AgePtr)
assert.Equal(t, []string{"1", "2"}, req.SliceStr)
assert.Equal(t, int64(64), req.Page)
}
func TestDecodeErr(t *testing.T) {
header := map[string]string{
"age_ptr": "21",
}
// array not support
type Req2 struct {
SliceStr [1]string `json:"sliceStr" req:"query:slice_str"`
}
r := newMockRequest(t, http.MethodGet, header, nil)
_, err := decodeReq[Req2](r)
assert.ErrorIs(t, err, codec.ErrUnsupportedType)
}
func BenchmarkDecodeReq(b *testing.B) {
for b.Loop() {
r := newMockRequest(b, http.MethodGet, nil, nil)
req, err := decodeReq[reqStruct](r)
if err != nil {
b.Fatal(err)
}
if req.Name != "alices" {
b.Fatal("name not equal")
}
}
}

63
rest/response.go Normal file
View File

@ -0,0 +1,63 @@
/*
* Tencent is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest
import (
"encoding/json/v2"
"net/http"
)
// Renderer interface for managing response payloads.
type Renderer interface {
Render(w http.ResponseWriter, r *http.Request) error
}
// APIResponse response for api request
type APIResponse struct {
HTTPCode int `json:"-"` // http response status code
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
}
// Render chi render interface implementation
func (e *APIResponse) Render(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(e.HTTPCode)
return json.MarshalWrite(w, e)
}
// APIOK 正常返回
func APIOK(data any) Renderer {
return &APIResponse{
Message: "request OK",
HTTPCode: http.StatusOK,
Code: 0,
Data: data,
}
}
// APIError 错误返回
func APIError(err error) Renderer {
return &APIResponse{
Message: err.Error(),
HTTPCode: http.StatusBadRequest,
Code: 40000,
}
}

126
rest/transport/curl.go Normal file
View File

@ -0,0 +1,126 @@
package transport
import (
"bytes"
"fmt"
"io"
"log/slog"
"net/http"
"time"
)
// reqToCurl curl 格式的请求日志
func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
// 过滤掉敏感信息, header 和 query
headers := ""
for key, values := range r.Header {
for _, value := range values {
if _, ok := c.maskKeys[key]; ok {
value = "***"
}
headers += fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value))
}
}
rawURL := *r.URL
queryValue := rawURL.Query()
for key := range queryValue {
if _, ok := c.maskKeys[key]; ok {
queryValue.Set(key, "<masked>")
}
}
rawURL.RawQuery = queryValue.Encode()
reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers)
if r.Body != nil {
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return "", err
}
r.Body.Close()
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if len(bodyBytes) > 1024 {
reqMsg += fmt.Sprintf(" -d '%s...(total %dB)'", bodyBytes[:1024], len(bodyBytes))
} else {
reqMsg += fmt.Sprintf(" -d '%s'", bodyBytes)
}
}
return reqMsg, nil
}
// respToCurl 返回日志
func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) {
var (
bodyBytes []byte
err error
)
if resp.Body != nil {
bodyBytes, err = io.ReadAll(resp.Body)
if err != nil {
return "", err
}
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
if len(bodyBytes) > 1024 {
respMsg := fmt.Sprintf("[%s] %s %s...(Total %s)",
resp.Status, time.Since(st), bodyBytes[:1024], len(bodyBytes))
return respMsg, nil
}
if len(bodyBytes) > 0 {
respMsg := fmt.Sprintf("[%s] %s %s", resp.Status, time.Since(st), bodyBytes)
return respMsg, nil
}
respMsg := fmt.Sprintf("[%s] %s", resp.Status, time.Since(st))
return respMsg, nil
}
// curlLogTransport print curl log transport
type curlLogTransport struct {
maskKeys map[string]string
}
// RoundTrip curlLog Transport
func (t *curlLogTransport) Transport(next http.RoundTripper) http.RoundTripper {
f := func(req *http.Request) (*http.Response, error) {
logger := slog.With("request_id", req.Header.Get("X-Request-Id"))
st := time.Now()
// 记录请求
rbody, err := t.reqToCurl(req)
if err != nil {
return nil, err
}
logger.Info("curl", "REQ", rbody)
resp, err := next.RoundTrip(req)
if err != nil {
logger.Error("curl", "RESP: [err]", err)
return nil, err
}
// 记录返回
respBody, err := t.respToCurl(resp, st)
if err != nil {
return nil, err
}
logger.Info("curl", "RESP", respBody)
return resp, nil
}
return RoundTripFunc(f)
}
// CurlLogTransport make a new curl log transport, default transport can be nil
func CurlLogTransport(maskkeys map[string]string) func(http.RoundTripper) http.RoundTripper {
t := curlLogTransport{map[string]string{}}
return t.Transport
}

View File

@ -0,0 +1,30 @@
package transport
import (
"math/rand/v2"
"net/http"
)
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
// genRequestID generates a random string of a given length.
func genRequestID() string {
b := make([]byte, 32)
for i := range b {
b[i] = charset[rand.IntN(len(charset))]
}
return string(b)
}
// RequestID ...
func RequestID(next http.RoundTripper) http.RoundTripper {
f := func(req *http.Request) (*http.Response, error) {
requestID := genRequestID()
ctx := req.Context()
req.Header.Set("X-Request-Id", requestID)
req = req.WithContext(ctx)
return next.RoundTrip(req)
}
return RoundTripFunc(f)
}

1
rest/transport/retry.go Normal file
View File

@ -0,0 +1 @@
package transport

View File

@ -0,0 +1,62 @@
package transport
import "net/http"
// RoundTripFunc, similar to http.HandlerFunc, is an adapter
// to allow the use of ordinary functions as http.RoundTrippers.
type RoundTripFunc func(r *http.Request) (*http.Response, error)
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
// Chain wraps given base RoundTripper, which is used to make HTTP requests
// (e.g. http.DefaultTransport) with RoundTripper middlewares.
func Chain(base http.RoundTripper, mw ...func(http.RoundTripper) http.RoundTripper) *chain {
if base == nil {
base = http.DefaultTransport
}
// Filter out nil transports.
mws := []func(http.RoundTripper) http.RoundTripper{}
for _, fn := range mw {
if fn != nil {
mws = append(mws, fn)
}
}
if c, ok := base.(*chain); ok {
c.middlewares = append(c.middlewares, mws...)
return c
}
return &chain{
baseTransport: base,
middlewares: mws,
}
}
type chain struct {
baseTransport http.RoundTripper
middlewares []func(http.RoundTripper) http.RoundTripper
}
func (c *chain) RoundTrip(req *http.Request) (*http.Response, error) {
rt := c.baseTransport
// Apply middlewares in reversed order so the first middleware becomes
// the innermost onion layer and the last becomes the outermost. Example:
// Given
// [Auth, VCTraceID, Debug],
// the middlewares are applied in this order:
// rt = Debug(rt)
// rt = VCTraceID(rt)
// rt = Auth(rt)
// The Auth and VCTraceID are called before the Debug middleware,
// which can then see the final request headers, as seen by http.DefaultTransport.
for i := len(c.middlewares) - 1; i >= 0; i-- {
rt = c.middlewares[i](rt)
}
return rt.RoundTrip(req)
}