// pkg/rest/transport/curl.go
//
// (file-viewer metadata: 127 lines, 2.8 KiB, Go)

package transport
import (
"bytes"
"fmt"
"io"
"log/slog"
"net/http"
"time"
)
// reqToCurl renders r as a single-line curl command for request logging.
//
// Values of headers and query parameters whose names appear in maskKeys are
// masked ("***" for headers, "<masked>" for query values). Header names are
// matched case-insensitively, as HTTP header names are case-insensitive;
// query parameter names are matched exactly.
//
// A request body, if present, is read fully and replaced with an in-memory
// copy so the request can still be sent; at most the first 1024 bytes are
// shown. Single quotes in the URL and body are escaped so the command can be
// pasted into a POSIX shell. The original body is always closed, including
// when reading it fails, as the http.RoundTripper contract requires.
func (t *curlLogTransport) reqToCurl(r *http.Request) (string, error) {
	// shellQuote escapes b for use inside a single-quoted shell string.
	// bytes.ReplaceAll always returns a copy, so b itself is never modified.
	shellQuote := func(b []byte) []byte {
		return bytes.ReplaceAll(b, []byte("'"), []byte(`'\''`))
	}

	// Mask sensitive header values.
	var headers bytes.Buffer
	for key, values := range r.Header {
		masked := t.isMaskedHeader(key)
		for _, value := range values {
			if masked {
				value = "***"
			}
			fmt.Fprintf(&headers, " -H %q", key+": "+value)
		}
	}

	// Mask sensitive query values on a copy so the real request is untouched.
	maskedURL := *r.URL
	query := maskedURL.Query()
	for key := range query {
		if _, ok := t.maskKeys[key]; ok {
			query.Set(key, "<masked>")
		}
	}
	maskedURL.RawQuery = query.Encode()

	var msg bytes.Buffer
	fmt.Fprintf(&msg, "curl -X %s '%s'", r.Method, shellQuote([]byte(maskedURL.String())))
	msg.Write(headers.Bytes())

	// http.NoBody must stay in place: swapping it for another reader makes the
	// client transport treat a zero ContentLength as "unknown length".
	if r.Body == nil || r.Body == http.NoBody {
		return msg.String(), nil
	}

	bodyBytes, err := io.ReadAll(r.Body)
	// Close unconditionally (RoundTripper contract). The close error is ignored
	// because the data has already been read, or the read already failed.
	_ = r.Body.Close()
	if err != nil {
		return "", fmt.Errorf("curl log: reading request body: %w", err)
	}
	r.Body = io.NopCloser(bytes.NewReader(bodyBytes))

	const maxShown = 1024
	if len(bodyBytes) > maxShown {
		fmt.Fprintf(&msg, " -d '%s...(total %dB)'", shellQuote(bodyBytes[:maxShown]), len(bodyBytes))
	} else {
		fmt.Fprintf(&msg, " -d '%s'", shellQuote(bodyBytes))
	}
	return msg.String(), nil
}

// isMaskedHeader reports whether the header name matches a key in maskKeys,
// comparing in canonical header form so the match is case-insensitive.
func (t *curlLogTransport) isMaskedHeader(name string) bool {
	if _, ok := t.maskKeys[name]; ok {
		return true
	}
	canonical := http.CanonicalHeaderKey(name)
	for key := range t.maskKeys {
		if http.CanonicalHeaderKey(key) == canonical {
			return true
		}
	}
	return false
}
// respToCurl summarises resp for response logging as
// "[<status>] <time elapsed since st>" followed by the body, if any. Bodies
// longer than 1024 bytes are truncated and suffixed with their total size.
//
// A non-nil, non-NoBody body is read fully and replaced with an in-memory copy
// so the caller can still consume it. The original body is always closed,
// including when reading fails, so the underlying connection is not leaked.
func (t *curlLogTransport) respToCurl(resp *http.Response, st time.Time) (string, error) {
	var bodyBytes []byte
	if resp.Body != nil && resp.Body != http.NoBody {
		var err error
		bodyBytes, err = io.ReadAll(resp.Body)
		// Close unconditionally; the close error is irrelevant once the data
		// has been read (or the read has already failed).
		_ = resp.Body.Close()
		if err != nil {
			return "", fmt.Errorf("curl log: reading response body: %w", err)
		}
		resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
	}

	elapsed := time.Since(st)
	const maxShown = 1024
	switch {
	case len(bodyBytes) > maxShown:
		// %d (not %s) for the size: %s on an int renders as "%!s(int=N)".
		return fmt.Sprintf("[%s] %s %s...(total %dB)",
			resp.Status, elapsed, bodyBytes[:maxShown], len(bodyBytes)), nil
	case len(bodyBytes) > 0:
		return fmt.Sprintf("[%s] %s %s", resp.Status, elapsed, bodyBytes), nil
	default:
		return fmt.Sprintf("[%s] %s", resp.Status, elapsed), nil
	}
}
// curlLogTransport holds the configuration for the curl-style request/response
// logger; its Transport method wraps another http.RoundTripper.
type curlLogTransport struct {
	// maskKeys holds header names and query-parameter names whose values are
	// replaced with a placeholder in the logs. Only the keys are consulted;
	// the map values are currently unused.
	maskKeys map[string]string
}
// Transport wraps next in a RoundTripper that logs every outgoing request as a
// curl command and every response (status, latency, body) through slog. Each
// log line carries the request's X-Request-Id header value as "request_id".
//
// If the request cannot be rendered, or the response cannot be summarised, the
// error is returned and the response (if any) is not passed on.
func (t *curlLogTransport) Transport(next http.RoundTripper) http.RoundTripper {
	return RoundTripFunc(func(req *http.Request) (*http.Response, error) {
		reqLog := slog.With("request_id", req.Header.Get("X-Request-Id"))
		started := time.Now()

		// Log the outgoing request.
		curlCmd, err := t.reqToCurl(req)
		if err != nil {
			return nil, err
		}
		reqLog.Info("curl", "REQ", curlCmd)

		resp, rtErr := next.RoundTrip(req)
		if rtErr != nil {
			reqLog.Error("curl", "RESP: [err]", rtErr)
			return nil, rtErr
		}

		// Log the response.
		summary, sumErr := t.respToCurl(resp, started)
		if sumErr != nil {
			return nil, sumErr
		}
		reqLog.Info("curl", "RESP", summary)
		return resp, nil
	})
}
// CurlLogTransport returns a middleware that wraps an http.RoundTripper with
// curl-style request/response logging.
//
// Keys of maskkeys name the headers and query parameters whose values are
// masked in the logs; the map is copied, so later changes by the caller have
// no effect, and a nil map masks nothing. The returned function accepts a nil
// RoundTripper, in which case http.DefaultTransport is used.
func CurlLogTransport(maskkeys map[string]string) func(http.RoundTripper) http.RoundTripper {
	// Previously the argument was ignored and an empty map used, so nothing
	// was ever masked. Copy it so the transport owns its configuration.
	keys := make(map[string]string, len(maskkeys))
	for k, v := range maskkeys {
		keys[k] = v
	}
	t := &curlLogTransport{maskKeys: keys}

	return func(next http.RoundTripper) http.RoundTripper {
		// Honour the documented "default transport can be nil" instead of
		// panicking on the first request.
		if next == nil {
			next = http.DefaultTransport
		}
		return t.Transport(next)
	}
}