From 7861f56eb5b48438523c170eff5ef85b69d0ba95 Mon Sep 17 00:00:00 2001 From: joelei Date: Thu, 2 Oct 2025 22:17:46 +0800 Subject: [PATCH] add rest --- .drone.yml | 20 --- .envrc | 2 + .vscode/settings.json | 6 +- go.mod | 4 +- rest/client.go | 276 +++++++++++++++++++++++++++++++++++ rest/client_test.go | 116 +++++++++++++++ rest/codec/codec.go | 133 +++++++++++++++++ rest/codec/form.go | 77 ++++++++++ rest/codec/header.go | 59 ++++++++ rest/codec/json.go | 61 ++++++++ rest/codec/path.go | 57 ++++++++ rest/codec/query.go | 58 ++++++++ rest/codec/tag.go | 62 ++++++++ rest/codec/value_parser.go | 195 +++++++++++++++++++++++++ rest/generic.go | 142 ++++++++++++++++++ rest/init.go | 7 + rest/metrics.go | 76 ++++++++++ rest/request.go | 62 ++++++++ rest/request_test.go | 134 +++++++++++++++++ rest/response.go | 63 ++++++++ rest/transport/curl.go | 126 ++++++++++++++++ rest/transport/request_id.go | 30 ++++ rest/transport/retry.go | 1 + rest/transport/transport.go | 62 ++++++++ 24 files changed, 1805 insertions(+), 24 deletions(-) delete mode 100644 .drone.yml create mode 100644 rest/client.go create mode 100644 rest/client_test.go create mode 100644 rest/codec/codec.go create mode 100644 rest/codec/form.go create mode 100644 rest/codec/header.go create mode 100644 rest/codec/json.go create mode 100644 rest/codec/path.go create mode 100644 rest/codec/query.go create mode 100644 rest/codec/tag.go create mode 100644 rest/codec/value_parser.go create mode 100644 rest/generic.go create mode 100644 rest/init.go create mode 100644 rest/metrics.go create mode 100644 rest/request.go create mode 100644 rest/request_test.go create mode 100644 rest/response.go create mode 100644 rest/transport/curl.go create mode 100644 rest/transport/request_id.go create mode 100644 rest/transport/retry.go create mode 100644 rest/transport/transport.go diff --git a/.drone.yml b/.drone.yml deleted file mode 100644 index e9a16ba..0000000 --- a/.drone.yml +++ /dev/null @@ -1,20 +0,0 @@ ---- -kind: pipeline -type: docker -name: test - -steps: - - name: fetch - image: alpine/git - commands: - - git fetch --tags - - - name: test - image: golang:1.20-alpine - commands: - - apk --update --no-cache add make tzdata git - - make test - - make benchmark - -node: - host: devcloud diff --git a/.envrc b/.envrc index e69de29..5383e68 100644 --- a/.envrc +++ b/.envrc @@ -0,0 +1,2 @@ +export PATH=/opt/go/sdk/go1/bin:$PATH +export GOEXPERIMENT=jsonv2 diff --git a/.vscode/settings.json b/.vscode/settings.json index 7efe88f..cdfb17e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,6 @@ { - "go.goroot": "/opt/go/sdk/go", - "go.gopath": "/root/.go" + "go.goroot": "/opt/go/sdk/go1", + "go.toolsEnvVars": { + "GOEXPERIMENT": "jsonv2" + } } \ No newline at end of file diff --git a/go.mod b/go.mod index bf8899a..32d6f16 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module git.ifooth.com/common/pkg -go 1.21 +go 1.25 require ( github.com/dustin/go-humanize v1.0.1 @@ -18,6 +18,7 @@ require ( github.com/samber/lo v1.47.0 github.com/stretchr/testify v1.9.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0 + golang.org/x/net v0.25.0 ) require ( @@ -50,7 +51,6 @@ require ( go.opentelemetry.io/otel/metric v1.27.0 // indirect go.opentelemetry.io/otel/trace v1.27.0 // indirect golang.org/x/crypto v0.23.0 // indirect - golang.org/x/net v0.25.0 // indirect golang.org/x/sys v0.20.0 // indirect golang.org/x/text v0.16.0 // indirect google.golang.org/protobuf v1.34.1 // indirect diff --git a/rest/client.go b/rest/client.go new file mode 100644 index 0000000..06febac --- /dev/null +++ b/rest/client.go @@ -0,0 +1,276 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json/v2" + "encoding/xml" + "fmt" + "io" + "maps" + "net/http" + "net/url" + "reflect" + "strings" + "time" + + "golang.org/x/net/html" +) + +type Client[T any] struct { + httpClient *http.Client + commonReq *http.Request + commonFormData url.Values + commonQueryParams url.Values + commonCookies []*http.Cookie +} + +func NewClient[T any]() *Client[T] { + r := &Client[T]{ + commonReq: &http.Request{Header: make(http.Header)}, + httpClient: &http.Client{}, + commonQueryParams: url.Values{}, + } + return r +} + +func (c *Client[T]) SetHeader(header, value string) *Client[T] { + c.commonReq.Header.Set(header, value) + return c +} + +func (c *Client[T]) SetQueryParam(param, value string) *Client[T] { + c.commonQueryParams.Set(param, value) + return c +} + +func (c *Client[T]) SetTransport(transport http.RoundTripper) *Client[T] { + if transport != nil { + c.httpClient.Transport = transport + } + + return c +} + +func (c *Client[T]) R() *Request[T] { + req := &http.Request{ + Method: http.MethodGet, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + req = req.WithContext(context.Background()) + + r := &Request[T]{ + req: req, + client: c, + queryParams: url.Values{}, + } + return r +} + +func (c *Client[T]) execute(req *http.Request) (*http.Response, error) { + // change header + maps.Copy(req.Header, c.commonReq.Header) + + // change query + query := req.URL.Query() + maps.Copy(query, c.commonQueryParams) + req.URL.RawQuery = query.Encode() + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + return resp, nil +} + +type Request[T any] struct { + req *http.Request + client *Client[T] + queryParams url.Values +} + +func (r *Request[T]) Clone() *Request[T] { + newReq := r.client.R() + newReq.req = r.req.Clone(r.req.Context()) + newReq.client = r.client + newReq.queryParams = r.queryParams + return newReq +} + +func (r *Request[T]) SetHeader(header, value string) *Request[T] { + r.req.Header.Set(header, value) + return r +} + +func (r *Request[T]) SetQueryParam(param, value string) *Request[T] { + r.queryParams.Set(param, value) + return r +} + +func (r *Request[T]) SetContext(ctx context.Context) *Request[T] { + r.req = r.req.WithContext(ctx) + return r +} + +func (r *Request[T]) Get(url string) (*T, error) { + return r.Execute(http.MethodGet, url) +} + +func (r *Request[T]) Execute(method string, url string) (*T, error) { + out := new(T) + rt := reflect.TypeFor[T]() + + newReq, err := http.NewRequestWithContext(r.req.Context(), method, url, r.req.Body) + if err != nil { + return nil, err + } + + // change a req header + maps.Copy(newReq.Header, r.req.Header) + + // change a req query + query := newReq.URL.Query() + maps.Copy(query, r.queryParams) + newReq.URL.RawQuery = query.Encode() + + resp, err := r.client.execute(newReq) + if err != nil { + return nil, err + } + + // io.ReadCloser 需要自己调用Close + switch any(out).(type) { + case *http.Response: + return any(resp).(*T), nil + case *io.ReadCloser: + return any(&resp.Body).(*T), nil + case *Response: // for addition trace info, reuse body an so on + r := Response{resp: resp, trace: &TraceInfo{}} + return any(&r).(*T), nil + } + + // body 自动close + defer resp.Body.Close() + + switch any(out).(type) { + case *[]byte: + // 需要读取整个body + body, rErr := io.ReadAll(resp.Body) + if rErr != nil { + return nil, rErr + } + return any(&body).(*T), nil + case *string: + // 需要读取整个body + body, rErr := io.ReadAll(resp.Body) + if rErr != nil { + return nil, rErr + } + v := string(body) + return any(&v).(*T), nil + case *io.Reader: + // 需要读取整个body + body, rErr := io.ReadAll(resp.Body) + if rErr != nil { + return nil, rErr + } + v := io.Reader(bytes.NewBuffer(body)) + return any(&v).(*T), nil + case *html.Node: + node, rErr := html.Parse(resp.Body) + if rErr != nil { + return nil, rErr + } + return any(node).(*T), nil + } + + if rt.Kind() != reflect.Struct { + return nil, fmt.Errorf("unsupported type: %v", rt) + } + + // 按类型反序列化 + contentType := resp.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + if err = json.UnmarshalRead(resp.Body, out); err != nil { + return nil, err + } + return out, nil + } + if strings.HasPrefix(contentType, "application/xml") { + // 需要读取整个body + body, rErr := io.ReadAll(resp.Body) + if rErr != nil { + return nil, rErr + } + if err = xml.Unmarshal(body, out); err != nil { + return nil, err + } + return out, nil + } + + return nil, fmt.Errorf("unsupported content-type: %s", contentType) +} + +// TraceInfo struct is used to provide request trace info such as DNS lookup +// duration, Connection obtain duration, Server processing duration, etc. +type TraceInfo struct { + // DNSLookup is the duration that transport took to perform + // DNS lookup. + DNSLookup time.Duration `json:"dns_lookup_time"` + + // ConnTime is the duration it took to obtain a successful connection. + ConnTime time.Duration `json:"connection_time"` + + // TCPConnTime is the duration it took to obtain the TCP connection. + TCPConnTime time.Duration `json:"tcp_connection_time"` + + // TLSHandshake is the duration of the TLS handshake. + TLSHandshake time.Duration `json:"tls_handshake_time"` + + // ServerTime is the server's duration for responding to the first byte. + ServerTime time.Duration `json:"server_time"` + + // ResponseTime is the duration since the first response byte from the server to + // request completion. + ResponseTime time.Duration `json:"response_time"` + + // TotalTime is the duration of the total time request taken end-to-end. + TotalTime time.Duration `json:"total_time"` + + // IsConnReused is whether this connection has been previously + // used for another HTTP request. + IsConnReused bool `json:"is_connection_reused"` + + // IsConnWasIdle is whether this connection was obtained from an + // idle pool. + IsConnWasIdle bool `json:"is_connection_was_idle"` + + // ConnIdleTime is the duration how long the connection that was previously + // idle, if IsConnWasIdle is true. + ConnIdleTime time.Duration `json:"connection_idle_time"` + + // RequestAttempt is to represent the request attempt made during a Resty + // request execution flow, including retry count. + RequestAttempt int `json:"request_attempt"` + + // RemoteAddr returns the remote network address. + RemoteAddr string `json:"remote_address"` +} + +type Response struct { + trace *TraceInfo + resp *http.Response + receivedAt time.Time +} + +func (r *Response) Duration() time.Duration { + return r.trace.TotalTime +} + +// ReceivedAt method returns the time we received a response from the server for the request. +func (r *Response) ReceivedAt() time.Time { + return r.receivedAt +} diff --git a/rest/client_test.go b/rest/client_test.go new file mode 100644 index 0000000..a3b7fdf --- /dev/null +++ b/rest/client_test.go @@ -0,0 +1,116 @@ +package rest + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + + "git.ifooth.com/common/pkg/rest/transport" + "github.com/stretchr/testify/assert" + "golang.org/x/net/html" +) + +type Resp1 struct { + UUID string `json:"uuid"` +} + +type BkResult[T any] struct { + Code int + Data *T +} + +func getTitle(n *html.Node) string { + if n.Type == html.ElementNode && n.Data == "title" { + if n.FirstChild != nil { + return n.FirstChild.Data + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + if title := getTitle(c); title != "" { + return title + } + } + return "" +} + +func TestGet(t *testing.T) { + resp, err := http.Get("http://127.0.0.1:38033/uuid") + assert.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + + fmt.Println(string(body)) + + r := new(Resp1) + err = json.Unmarshal(body, r) + assert.NoError(t, err) + + assert.True(t, len(r.UUID) == 36) +} + +func CustomC[T any]() *Client[T] { + c := NewClient[T]() + c.SetHeader("abc", "abc") + c.SetQueryParam("name", "ko") + return c +} + +func TestGetRest(t *testing.T) { + r, err := NewClient[Resp1]().R().Get("http://127.0.0.1:38033/uuid") + assert.NoError(t, err) + + fmt.Println(r) + assert.True(t, len(r.UUID) == 36) + + rb, err := NewClient[[]byte]().R().Get("http://127.0.0.1:38033/uuid") + assert.NoError(t, err) + + fmt.Println("abc", string(*rb)) + + r3, err := NewClient[io.Reader]().R().Get("http://127.0.0.1:38033/uuid") + assert.NoError(t, err) + r31, err := io.ReadAll(*r3) + assert.NoError(t, err) + + fmt.Println("r31", string(r31)) + + r4, err := NewClient[io.ReadCloser]().R().Get("http://127.0.0.1:38033/uuid") + assert.NoError(t, err) + r41, err := io.ReadAll(*r4) + assert.NoError(t, err) + (*r4).Close() + + fmt.Println("r41", string(r41)) + + ts := transport.Chain(http.DefaultTransport, + transport.RequestID, + transport.CurlLogTransport(map[string]string{}), + ) + + r5, err := CustomC[http.Response]().SetTransport(ts).R().SetHeader("ko", "abc"). + SetQueryParam("age", "12"). + Get("http://127.0.0.1:38033/uuid?dd=abc") + assert.NoError(t, err) + r51, err := io.ReadAll(r5.Body) + assert.NoError(t, err) + r5.Body.Close() + + fmt.Println("r51", string(r51)) + + r6, err := NewClient[html.Node]().R().Get("http://127.0.0.1:38033") + assert.NoError(t, err) + title := getTitle(r6) + assert.Equal(t, "httpbin.org", title) + +} + +func TestGetRestError(t *testing.T) { + r, err := NewClient[int]().R().Get("http://127.0.0.1:38033/uuid") + assert.Error(t, err) + fmt.Println(err) + assert.Nil(t, r) +} diff --git a/rest/codec/codec.go b/rest/codec/codec.go new file mode 100644 index 0000000..47e2d32 --- /dev/null +++ b/rest/codec/codec.go @@ -0,0 +1,133 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +// Package codec provides encoding and decoding utilities across various formats +package codec + +import ( + "errors" + "fmt" + "net/http" + "reflect" +) + +func decodeTo(r *http.Request, val any) error { + rt := reflect.TypeOf(val).Elem() + rv := reflect.ValueOf(val).Elem() + + // json 整个解析 + jsonCodec := NewJsonCodec(r) + if err := jsonCodec.Decode(val); err != nil { + return err + } + + formCodec, err := NewFormCodec(r) + if err != nil { + return err + } + + pathCodec := NewPathCodec(r) + queryCodec := NewQueryCodec(r) + headerCodec := NewHeaderCodec(r) + + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + + // 非导出需要跳过, 无法设置值 + if !field.IsExported() { + continue + } + + tagStr := field.Tag.Get(tagName) + if tagStr == "" { + continue + } + tag, err := parseTag(tagStr) + if err != nil { + return err + } + + fv := rv.Field(i) + if err := formCodec.Decode(field, fv, tag); err != nil { + return err + } + if err := queryCodec.Decode(field, fv, tag); err != nil { + return err + } + if err := headerCodec.Decode(field, fv, tag); err != nil { + return err + } + if err := pathCodec.Decode(field, fv, tag); err != nil { + return err + } + } + + return nil +} + +// Decode 按结构体反序列化Request +func Decode[T any](r *http.Request) (*T, error) { + rt := reflect.TypeFor[T]() + if rt.Kind() != reflect.Struct { + return nil, errors.New("generic type T must be a struct type") + } + + t := new(T) + err := decodeTo(r, t) + if err != nil { + return nil, fmt.Errorf("codec decode: %w", err) + } + + return t, nil +} + +// getFieldValue 获取字段值 +func getFieldValue(field reflect.Type, tag *Tag, values []string) (reflect.Value, error) { + // 指针类型 + if field.Kind() == reflect.Pointer { + typ := field.Elem() + rv, err := getFieldValue(typ, tag, values) + if err != nil { + return reflect.Value{}, err + } + + newPtr := reflect.New(typ) + newPtr.Elem().Set(rv) + return newPtr, nil + } + + // slice类型 + if field.Kind() == reflect.Slice { + typ := field.Elem() + + // []byte 特殊处理 + if typ == byteType { + return ParseValue(field, values[0], tag.Option) + } + + val := reflect.MakeSlice(field, 0, len(values)) + for _, v := range values { + rv, err := getFieldValue(typ, tag, []string{v}) + if err != nil { + return reflect.Value{}, err + } + val = reflect.Append(val, rv) + } + return val, nil + } + + return ParseValue(field, values[0], tag.Option) +} diff --git a/rest/codec/form.go b/rest/codec/form.go new file mode 100644 index 0000000..cd1f47a --- /dev/null +++ b/rest/codec/form.go @@ -0,0 +1,77 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package codec + +import ( + "fmt" + "net/http" + "net/url" + "reflect" + "strings" +) + +type formCodec struct { + values url.Values + isForm bool +} + +// NewFormCodec ... +func NewFormCodec(r *http.Request) (*formCodec, error) { + isForm := false + + contentType := r.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") { + if err := r.ParseForm(); err != nil { + return nil, err + } + isForm = true + } + + c := &formCodec{ + values: r.PostForm, + isForm: isForm, + } + return c, nil +} + +// Decode ... +func (c *formCodec) Decode(field reflect.StructField, fv reflect.Value, tag *Tag) error { + if !c.isForm { + return nil + } + + formTag, ok := tag.Option["form"] + if !ok { + return nil + } + + v := c.values[formTag] + if len(v) == 0 { + return nil + } + + rv, err := getFieldValue(field.Type, tag, v) + if err != nil { + return err + } + if !rv.IsValid() { + return fmt.Errorf("%s not valid", rv) + } + + fv.Set(rv) + return nil +} diff --git a/rest/codec/header.go b/rest/codec/header.go new file mode 100644 index 0000000..251e72f --- /dev/null +++ b/rest/codec/header.go @@ -0,0 +1,59 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package codec + +import ( + "fmt" + "net/http" + "reflect" +) + +type headerCodec struct { + values http.Header +} + +// NewHeaderCodec ... +func NewHeaderCodec(r *http.Request) *headerCodec { + c := &headerCodec{values: r.Header} + return c +} + +// Decode ... +func (c *headerCodec) Decode(field reflect.StructField, fv reflect.Value, tag *Tag) error { + headTag, ok := tag.Option["header"] + if !ok { + return nil + } + + // header统一格式 + key := http.CanonicalHeaderKey(headTag) + v := c.values[key] + if len(v) == 0 { + return nil + } + + rv, err := getFieldValue(field.Type, tag, v) + if err != nil { + return err + } + if !rv.IsValid() { + return fmt.Errorf("%s not valid", rv) + } + + fv.Set(rv) + return nil +} diff --git a/rest/codec/json.go b/rest/codec/json.go new file mode 100644 index 0000000..fcfc336 --- /dev/null +++ b/rest/codec/json.go @@ -0,0 +1,61 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package codec + +import ( + "encoding/json/v2" + "fmt" + "io" + "net/http" + "strings" +) + +type jsonCodec struct { + isJson bool + req *http.Request +} + +// NewJsonCodec ... +func NewJsonCodec(r *http.Request) *jsonCodec { + isJson := false + + contentType := r.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + isJson = true + } + + return &jsonCodec{req: r, isJson: isJson} +} + +// Decode ... +func (j *jsonCodec) Decode(val any) error { + if !j.isJson { + return nil + } + + body, err := io.ReadAll(j.req.Body) + if err != nil { + return err + } + if len(body) == 0 { + return fmt.Errorf("json body is empty") + } + if err := json.Unmarshal(body, val); err != nil { + return fmt.Errorf("unmarshal json body: %s", err) + } + return nil +} diff --git a/rest/codec/path.go b/rest/codec/path.go new file mode 100644 index 0000000..d423714 --- /dev/null +++ b/rest/codec/path.go @@ -0,0 +1,57 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package codec + +import ( + "fmt" + "net/http" + "reflect" +) + +type pathCodec struct { + req *http.Request +} + +// NewPathCodec ... +func NewPathCodec(r *http.Request) *pathCodec { + c := &pathCodec{req: r} + return c +} + +// Decode ... +func (c *pathCodec) Decode(field reflect.StructField, fv reflect.Value, tag *Tag) error { + pathTag, ok := tag.Option["path"] + if !ok { + return nil + } + + pv := c.req.PathValue(pathTag) + if pv == "" { + return nil + } + + rv, err := getFieldValue(field.Type, tag, []string{pv}) + if err != nil { + return err + } + if !rv.IsValid() { + return fmt.Errorf("%s not valid", rv) + } + + fv.Set(rv) + return nil +} diff --git a/rest/codec/query.go b/rest/codec/query.go new file mode 100644 index 0000000..ed5d773 --- /dev/null +++ b/rest/codec/query.go @@ -0,0 +1,58 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package codec + +import ( + "fmt" + "net/http" + "net/url" + "reflect" +) + +type queryCodec struct { + values url.Values +} + +// NewQueryCodec ... +func NewQueryCodec(r *http.Request) *queryCodec { + c := &queryCodec{values: r.URL.Query()} + return c +} + +// Decode ... +func (c *queryCodec) Decode(field reflect.StructField, fv reflect.Value, tag *Tag) error { + queryTag, ok := tag.Option["query"] + if !ok { + return nil + } + + v := c.values[queryTag] + if len(v) == 0 { + return nil + } + + rv, err := getFieldValue(field.Type, tag, v) + if err != nil { + return err + } + if !rv.IsValid() { + return fmt.Errorf("%s not valid", rv) + } + + fv.Set(rv) + return nil +} diff --git a/rest/codec/tag.go b/rest/codec/tag.go new file mode 100644 index 0000000..540c2d2 --- /dev/null +++ b/rest/codec/tag.go @@ -0,0 +1,62 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package codec + +import ( + "fmt" + "strings" +) + +const ( + // tagName 结构体tag名称 + // 格式参考 https://pkg.go.dev/encoding/json/v2#example-package-FormatFlags + tagName = "req" +) + +// Tag is a struct tag +type Tag struct { + Option map[string]string +} + +func parseTag(tagStr string) (*Tag, error) { + tagStr = strings.TrimSpace(tagStr) + if tagStr == "" { + return nil, fmt.Errorf("tag is empty") + } + + parts := strings.Split(tagStr, ",") + t := &Tag{ + Option: map[string]string{}, + } + + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + return nil, fmt.Errorf("tag option not valid") + } + + opt := strings.SplitN(part, ":", 2) + key := opt[0] + val := "" + if len(opt) == 2 { + val = opt[1] + } + t.Option[key] = val + } + + return t, nil +} diff --git a/rest/codec/value_parser.go b/rest/codec/value_parser.go new file mode 100644 index 0000000..60cbd28 --- /dev/null +++ b/rest/codec/value_parser.go @@ -0,0 +1,195 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package codec + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "time" +) + +var ( + // ErrUnsupportedType unsupported parse type + ErrUnsupportedType = errors.New("unsupported type") + + parserRegistry = make(map[reflect.Type]Parser) + byteType = reflect.TypeFor[byte]() +) + +// Parser defines the interface for converting a string to a reflect.Value +// Implementations should parse a string representation into a Go value +type Parser interface { + Parse(s string) (reflect.Value, error) +} + +// NewParser defines a factory interface for creating Parser instances +// Implementations should return a new Parser configured with the provided options +type NewParser interface { + // New creates a new Parser instance with the given configuration options + // The options map contains key-value pairs for parser configuration + New(opt map[string]string) Parser +} + +// ParserFunc is a function type that implements the Parser interface +// Allows regular functions to be used as Parser implementations +type ParserFunc func(s string) (reflect.Value, error) + +// Parse implements the Parser interface for ParserFunc +// This adapter allows any function with the signature func(string) (reflect.Value, error) +// to be used as a Parser without defining a new type +func (p ParserFunc) Parse(s string) (reflect.Value, error) { + return p(s) +} + +// RegisterParser registers a parser implementation for a specific type T +func RegisterParser[T any](p Parser) { + parserRegistry[reflect.TypeFor[T]()] = p +} + +// ParseValue converts a string to a value of the specified type using registered parsers +// Returns an error if parsing fails or the value type is unsupported +func ParseValue(rt reflect.Type, s string, opt map[string]string) (reflect.Value, error) { + parser, ok := parserRegistry[rt] + if !ok { + return reflect.Value{}, fmt.Errorf("%w: %v", ErrUnsupportedType, rt) + } + + // 实现自定义初始化 + if v, ok := parser.(NewParser); ok { + parser = v.New(opt) + } + + return parser.Parse(s) +} + +// StringParser ... +func StringParser(s string) (reflect.Value, error) { + return reflect.ValueOf(s), nil +} + +// BoolParser ... +func BoolParser(s string) (reflect.Value, error) { + v, err := strconv.ParseBool(s) + if err != nil { + return reflect.Value{}, err + } + + return reflect.ValueOf(v), nil +} + +// Int generic int parser +type Int[T int | int8 | int16 | int32 | int64] struct { + bitSize int +} + +// Parse int parser +func (i Int[T]) Parse(s string) (reflect.Value, error) { + v, err := strconv.ParseInt(s, 10, i.bitSize) + if err != nil { + return reflect.Value{}, err + } + + return reflect.ValueOf(T(v)), nil +} + +// Uint generic uint parser +type Uint[T uint | uint8 | uint16 | uint32 | uint64] struct { + bitSize int +} + +// Parse uint parser +func (i Uint[T]) Parse(s string) (reflect.Value, error) { + v, err := strconv.ParseUint(s, 10, i.bitSize) + if err != nil { + return reflect.Value{}, err + } + + return reflect.ValueOf(T(v)), nil +} + +// Float generic float parser +type Float[T float32 | float64] struct { + bitSize int +} + +// Parse float parser +func (f Float[T]) Parse(s string) (reflect.Value, error) { + v, err := strconv.ParseFloat(s, f.bitSize) + if err != nil { + return reflect.Value{}, err + } + + return reflect.ValueOf(T(v)), nil +} + +// ByteSlice is a wrapper of []byte to implement Parser +type ByteSlice []byte + +// Parse ... +func (bs ByteSlice) Parse(s string) (reflect.Value, error) { + v := []byte(s) + + return reflect.ValueOf(v), nil +} + +// Time is a parser with format option +type Time struct { + option map[string]string +} + +// Parse ... +func (t Time) Parse(s string) (reflect.Value, error) { + format := t.option["format"] + if format == "" { + format = time.DateTime + } + + v, err := time.Parse(format, s) + if err != nil { + return reflect.Value{}, err + } + + return reflect.ValueOf(v), nil +} + +// New ... +func (t *Time) New(opt map[string]string) Parser { + newParser := &Time{option: opt} + return newParser +} + +func init() { + // buildin parser + RegisterParser[string](ParserFunc(StringParser)) + RegisterParser[bool](ParserFunc(BoolParser)) + RegisterParser[int](Int[int]{0}) + RegisterParser[int8](Int[int8]{8}) + RegisterParser[int16](Int[int16]{16}) + RegisterParser[int32](Int[int32]{32}) + RegisterParser[int64](Int[int64]{64}) + RegisterParser[uint](Uint[uint]{0}) + RegisterParser[uint8](Uint[uint8]{8}) + RegisterParser[uint16](Uint[uint16]{16}) + RegisterParser[uint32](Uint[uint32]{32}) + RegisterParser[uint64](Uint[uint64]{64}) + RegisterParser[float32](Float[float32]{32}) + RegisterParser[float64](Float[float64]{64}) + RegisterParser[[]byte](ByteSlice{}) + RegisterParser[time.Time](&Time{}) +} diff --git a/rest/generic.go b/rest/generic.go new file mode 100644 index 0000000..ca526d8 --- /dev/null +++ b/rest/generic.go @@ -0,0 +1,142 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +// Package rest framework +package rest + +import ( + "context" + "log/slog" + "net/http" + "time" +) + +// UnaryFunc Unary or ClientStreaming handle function +type UnaryFunc[Req, Resp any] func(context.Context, *Req) (*Resp, error) + +// StreamingServer server or bidi streaming server +type StreamingServer interface { + http.ResponseWriter + Context() context.Context +} + +// StreamFunc ServerStreaming or BidiStreaming handle function +type StreamFunc[Req any] func(*Req, StreamingServer) error + +// Handle Composable HTTP Handlers using generics +func Handle[Req, Resp any](fn UnaryFunc[Req, Resp]) func(w http.ResponseWriter, r *http.Request) { + handleName := getHandleName(fn) + + f := func(w http.ResponseWriter, r *http.Request) { + st := time.Now() + var err error + defer func() { + collectHandleMetrics(handleName, r.Method, st, err) + }() + + // 反序列化 + in, err := decodeReq[Req](r) + if err != nil { + slog.Error("handle decode request failed", "err", err) + _ = APIError(err).Render(w, r) + return + } + + // 参数校验 + if err = validateReq(r.Context(), in); err != nil { + slog.Error("validate req failed", "err", err) + _ = APIError(err).Render(w, r) + return + } + + out, err := fn(r.Context(), in) + if err != nil { + _ = APIError(err).Render(w, r) + return + } + _ = APIOK(out).Render(w, r) + } + return f +} + +type streamingServer struct { + http.ResponseWriter + *http.ResponseController + ctx context.Context +} + +// Context return svr's context +func (s *streamingServer) Context() context.Context { + return s.ctx +} + +// Stream Composable HTTP Handlers using generics +func Stream[Req any](fn StreamFunc[Req]) func(w http.ResponseWriter, r *http.Request) { + handleName := getHandleName(fn) + + f := func(w http.ResponseWriter, r *http.Request) { + st := time.Now() + var err error + defer func() { + collectHandleMetrics(handleName, r.Method, st, err) + }() + + // 反序列化 + in, err := decodeReq[Req](r) + if err != nil { + slog.Error("handle decode stream request failed", "err", err) + _ = APIError(err).Render(w, r) + return + } + + // 参数校验 + if err = validateReq(r.Context(), in); err != nil { + slog.Error("validate stream req failed", "err", err) + _ = APIError(err).Render(w, r) + return + } + + svr := &streamingServer{ + ResponseWriter: w, + ResponseController: http.NewResponseController(w), + ctx: r.Context(), + } + + err = fn(in, svr) + if err != nil { + _ = APIError(err).Render(w, r) + } + } + return f +} + +// EmptyReq 空的请求 +type EmptyReq struct{} + +// EmptyResp 空的返回 +type EmptyResp struct{} + +// PaginationReq 分页接口通用请求 +type PaginationReq struct { + Offset int `json:"offset" in:"query=offset" validate:"gte=0"` + Limit int `json:"limit" in:"query=limit" validate:"gte=0"` +} + +// PaginationResp 分页接口通用返回 +type PaginationResp[T any] struct { + Count int64 `json:"count"` + Items []T `json:"items"` +} diff --git a/rest/init.go b/rest/init.go new file mode 100644 index 0000000..2296cf9 --- /dev/null +++ b/rest/init.go @@ -0,0 +1,7 @@ +package rest + +import "git.ifooth.com/common/pkg/logger" + +func init() { + logger.Init() +} diff --git a/rest/metrics.go b/rest/metrics.go new file mode 100644 index 0000000..e106e64 --- /dev/null +++ b/rest/metrics.go @@ -0,0 +1,76 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package rest + +import ( + "reflect" + "runtime" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +var ( + requestCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "http_requests_total", + Help: "Number of get requests.", + }, + []string{"handler", "method", "code"}, + ) + responseTimeDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_request_duration_seconds", + Help: "Histogram of response time for HTTP requests.", + Buckets: []float64{0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60}, + }, + []string{"handler", "method", "code"}, + ) +) + +// getHandleName 获取FuncHandle/StreamHandle函数名 +func getHandleName(fn any) string { + fullName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + if fullName == "" { + panic("get func name is empty") + } + + parts := strings.Split(fullName, ".") + lastPart := parts[len(parts)-1] + name := strings.TrimSuffix(lastPart, "-fm") + return name +} + +// collectHandleMetrics api指标数据 +func collectHandleMetrics(funcName, method string, st time.Time, err error) { + code := 200 + if err != nil { + code = APIError(err).(*APIResponse).HTTPCode + } + + codeStr := strconv.Itoa(code) + requestCounter.WithLabelValues(funcName, method, codeStr).Inc() + duration := time.Since(st).Seconds() + responseTimeDuration.WithLabelValues(funcName, method, codeStr).Observe(duration) +} + +func init() { + prometheus.MustRegister(requestCounter) + prometheus.MustRegister(responseTimeDuration) +} diff --git a/rest/request.go b/rest/request.go new file mode 100644 index 0000000..3f6b862 --- /dev/null +++ b/rest/request.go @@ -0,0 +1,62 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package rest + +import ( + "context" + "net/http" + + "git.ifooth.com/common/pkg/rest/codec" + "git.ifooth.com/common/pkg/validator" +) + +// decodeReq ... +func decodeReq[T any](r *http.Request) (*T, error) { + in := new(T) + + // http.Request 直接返回 + if _, ok := any(in).(*http.Request); ok { + return any(r).(*T), nil + } + + // 空值不需要反序列化 + if _, ok := any(in).(*EmptyReq); ok { + return in, nil + } + + in, err := codec.Decode[T](r) + if err != nil { + return nil, err + } + + return in, nil +} + +// validate 参数校验 +func validateReq(ctx context.Context, req any) error { + // http.Request 直接返回 + if _, ok := req.(*http.Request); ok { + return nil + } + + // 空值不需要校验 + if _, ok := req.(*EmptyReq); ok { + return nil + } + + return validator.Struct(ctx, req) +} diff --git a/rest/request_test.go b/rest/request_test.go new file mode 100644 index 0000000..a8cdef6 --- /dev/null +++ b/rest/request_test.go @@ -0,0 +1,134 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package rest + +import ( + "bytes" + "io" + "net/http" + "net/url" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "git.ifooth.com/common/pkg/rest/codec" +) + +// reqStruct for rest +type reqStruct struct { + Org string `json:"org" req:"path:org" in:"path=org"` + Name string `json:"name" req:"query:name" in:"query=name"` + Age int32 `json:"age" req:"form:age" in:"form=age"` + Bool bool `json:"bool" req:"query:bool" in:"query=bool"` + AgePtr *int `json:"agePtr" req:"header:age_ptr" in:"header=age_ptr"` + SliceStr []string `json:"sliceStr" req:"query:slice_str" in:"query=slice_str"` + Page int64 `json:"page"` +} + +// newMockRequest creates a new mock request. +func newMockRequest(t testing.TB, method string, header map[string]string, body io.Reader) *http.Request { + req, err := http.NewRequest(method, "/{org}/vm/xxx?name=alices&age=20&age_ptr=21&slice_str=1&slice_str=2", body) + req.SetPathValue("org", "myOrg") + for k, v := range header { + req.Header.Set(k, v) + } + req.Body = io.NopCloser(body) + + require.NoError(t, err) + return req +} + +func TestDecode(t *testing.T) { + header := map[string]string{ + "age_ptr": "21", + } + r := newMockRequest(t, http.MethodGet, header, nil) + + req, err := decodeReq[reqStruct](r) + assert.NoError(t, err) + assert.Equal(t, "myOrg", req.Org) + assert.Equal(t, "alices", req.Name) + assert.Equal(t, lo.ToPtr(21), req.AgePtr) + assert.Equal(t, []string{"1", "2"}, req.SliceStr) +} + +func TestFormDecode(t *testing.T) { + header := map[string]string{ + "age_ptr": "21", + "Content-Type": "application/x-www-form-urlencoded", + } + + formData := url.Values{} + formData.Set("age", "20") + + r := newMockRequest(t, http.MethodPost, header, bytes.NewBufferString(formData.Encode())) + + req, err := decodeReq[reqStruct](r) + assert.NoError(t, err) + assert.Equal(t, "myOrg", req.Org) + assert.Equal(t, "alices", req.Name) + assert.Equal(t, int32(20), req.Age) + assert.Equal(t, lo.ToPtr(21), req.AgePtr) + assert.Equal(t, []string{"1", "2"}, req.SliceStr) +} + +func TestJsonDecode(t *testing.T) { + header := map[string]string{ + "age_ptr": "21", + "Content-Type": "application/json", + } + + jsonData := `{"page": 64}` + r := newMockRequest(t, http.MethodPost, header, bytes.NewBufferString(jsonData)) + + req, err := decodeReq[reqStruct](r) + assert.NoError(t, err) + assert.Equal(t, "myOrg", req.Org) + assert.Equal(t, "alices", req.Name) + assert.Equal(t, lo.ToPtr(21), req.AgePtr) + assert.Equal(t, []string{"1", "2"}, req.SliceStr) + assert.Equal(t, int64(64), req.Page) +} + +func TestDecodeErr(t *testing.T) { + header := map[string]string{ + "age_ptr": "21", + } + + // array not support + type Req2 struct { + SliceStr [1]string `json:"sliceStr" req:"query:slice_str"` + } + r := newMockRequest(t, http.MethodGet, header, nil) + _, err := decodeReq[Req2](r) + assert.ErrorIs(t, err, codec.ErrUnsupportedType) +} + +func BenchmarkDecodeReq(b *testing.B) { + for b.Loop() { + r := newMockRequest(b, http.MethodGet, nil, nil) + req, err := decodeReq[reqStruct](r) + if err != nil { + b.Fatal(err) + } + if req.Name != "alices" { + b.Fatal("name not equal") + } + } +} diff --git a/rest/response.go b/rest/response.go new file mode 100644 index 0000000..390e40c --- /dev/null +++ b/rest/response.go @@ -0,0 +1,63 @@ +/* + * Tencent is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package rest + +import ( + "encoding/json/v2" + "net/http" +) + +// Renderer interface for managing response payloads. +type Renderer interface { + Render(w http.ResponseWriter, r *http.Request) error +} + +// APIResponse response for api request +type APIResponse struct { + HTTPCode int `json:"-"` // http response status code + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data"` +} + +// Render chi render interface implementation +func (e *APIResponse) Render(w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + + w.WriteHeader(e.HTTPCode) + + return json.MarshalWrite(w, e) +} + +// APIOK 正常返回 +func APIOK(data any) Renderer { + return &APIResponse{ + Message: "request OK", + HTTPCode: http.StatusOK, + Code: 0, + Data: data, + } +} + +// APIError 错误返回 +func APIError(err error) Renderer { + return &APIResponse{ + Message: err.Error(), + HTTPCode: http.StatusBadRequest, + Code: 40000, + } +} diff --git a/rest/transport/curl.go b/rest/transport/curl.go new file mode 100644 index 0000000..70d292c --- /dev/null +++ b/rest/transport/curl.go @@ -0,0 +1,126 @@ +package transport + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "net/http" + "time" +) + +// reqToCurl curl 格式的请求日志 +func (c *curlLogTransport) reqToCurl(r *http.Request) (string, error) { + // 过滤掉敏感信息, header 和 query + headers := "" + for key, values := range r.Header { + for _, value := range values { + if _, ok := c.maskKeys[key]; ok { + value = "***" + } + headers += fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value)) + } + } + + rawURL := *r.URL + queryValue := rawURL.Query() + for key := range queryValue { + if _, ok := c.maskKeys[key]; ok { + queryValue.Set(key, "") + } + } + rawURL.RawQuery = queryValue.Encode() + + reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers) + if r.Body != nil { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return "", err + } + r.Body.Close() + + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + if len(bodyBytes) > 1024 { + reqMsg += fmt.Sprintf(" -d '%s...(total %dB)'", bodyBytes[:1024], len(bodyBytes)) + } else { + reqMsg += fmt.Sprintf(" -d '%s'", bodyBytes) + } + + } + + return reqMsg, nil +} + +// respToCurl 返回日志 +func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) { + var ( + bodyBytes []byte + err error + ) + if resp.Body != nil { + bodyBytes, err = io.ReadAll(resp.Body) + if err != nil { + return "", err + } + resp.Body.Close() + + resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + if len(bodyBytes) > 1024 { + respMsg := fmt.Sprintf("[%s] %s %s...(Total %s)", + resp.Status, time.Since(st), bodyBytes[:1024], len(bodyBytes)) + return respMsg, nil + } + + if len(bodyBytes) > 0 { + respMsg := fmt.Sprintf("[%s] %s %s", resp.Status, time.Since(st), bodyBytes) + return respMsg, nil + } + + respMsg := fmt.Sprintf("[%s] %s", resp.Status, time.Since(st)) + return respMsg, nil +} + +// curlLogTransport print curl log transport +type curlLogTransport struct { + maskKeys map[string]string +} + +// RoundTrip curlLog Transport +func (t *curlLogTransport) Transport(next http.RoundTripper) http.RoundTripper { + f := func(req *http.Request) (*http.Response, error) { + logger := slog.With("request_id", req.Header.Get("X-Request-Id")) + st := time.Now() + + // 记录请求 + rbody, err := t.reqToCurl(req) + if err != nil { + return nil, err + } + logger.Info("curl", "REQ", rbody) + + resp, err := next.RoundTrip(req) + if err != nil { + logger.Error("curl", "RESP: [err]", err) + return nil, err + } + + // 记录返回 + respBody, err := t.respToCurl(resp, st) + if err != nil { + return nil, err + } + logger.Info("curl", "RESP", respBody) + + return resp, nil + } + return RoundTripFunc(f) + +} + +// CurlLogTransport make a new curl log transport, default transport can be nil +func CurlLogTransport(maskkeys map[string]string) func(http.RoundTripper) http.RoundTripper { + t := curlLogTransport{map[string]string{}} + return t.Transport +} diff --git a/rest/transport/request_id.go b/rest/transport/request_id.go new file mode 100644 index 0000000..1b40103 --- /dev/null +++ b/rest/transport/request_id.go @@ -0,0 +1,30 @@ +package transport + +import ( + "math/rand/v2" + "net/http" +) + +const charset = "abcdefghijklmnopqrstuvwxyz0123456789" + +// genRequestID generates a random string of a given length. +func genRequestID() string { + b := make([]byte, 32) + for i := range b { + b[i] = charset[rand.IntN(len(charset))] + } + return string(b) +} + +// RequestID ... +func RequestID(next http.RoundTripper) http.RoundTripper { + f := func(req *http.Request) (*http.Response, error) { + requestID := genRequestID() + ctx := req.Context() + req.Header.Set("X-Request-Id", requestID) + req = req.WithContext(ctx) + return next.RoundTrip(req) + } + + return RoundTripFunc(f) +} diff --git a/rest/transport/retry.go b/rest/transport/retry.go new file mode 100644 index 0000000..d11d0be --- /dev/null +++ b/rest/transport/retry.go @@ -0,0 +1 @@ +package transport diff --git a/rest/transport/transport.go b/rest/transport/transport.go new file mode 100644 index 0000000..aace24c --- /dev/null +++ b/rest/transport/transport.go @@ -0,0 +1,62 @@ +package transport + +import "net/http" + +// RoundTripFunc, similar to http.HandlerFunc, is an adapter +// to allow the use of ordinary functions as http.RoundTrippers. +type RoundTripFunc func(r *http.Request) (*http.Response, error) + +func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// Chain wraps given base RoundTripper, which is used to make HTTP requests +// (e.g. http.DefaultTransport) with RoundTripper middlewares. +func Chain(base http.RoundTripper, mw ...func(http.RoundTripper) http.RoundTripper) *chain { + if base == nil { + base = http.DefaultTransport + } + + // Filter out nil transports. + mws := []func(http.RoundTripper) http.RoundTripper{} + for _, fn := range mw { + if fn != nil { + mws = append(mws, fn) + } + } + + if c, ok := base.(*chain); ok { + c.middlewares = append(c.middlewares, mws...) + return c + } + + return &chain{ + baseTransport: base, + middlewares: mws, + } +} + +type chain struct { + baseTransport http.RoundTripper + middlewares []func(http.RoundTripper) http.RoundTripper +} + +func (c *chain) RoundTrip(req *http.Request) (*http.Response, error) { + rt := c.baseTransport + + // Apply middlewares in reversed order so the first middleware becomes + // the innermost onion layer and the last becomes the outermost. Example: + // Given + // [Auth, VCTraceID, Debug], + // the middlewares are applied in this order: + // rt = Debug(rt) + // rt = VCTraceID(rt) + // rt = Auth(rt) + // The Auth and VCTraceID are called before the Debug middleware, + // which can then see the final request headers, as seen by http.DefaultTransport. + for i := len(c.middlewares) - 1; i >= 0; i-- { + rt = c.middlewares[i](rt) + } + + return rt.RoundTrip(req) +}