163 lines
3.5 KiB
Go
163 lines
3.5 KiB
Go
|
package restyclient
|
||
|
|
||
|
import (
	"bytes"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/dustin/go-humanize"
)
|
||
|
|
||
|
// ctxKey is an unexported type for context keys set by this package; being a
// distinct type, its values cannot collide with keys from other packages.
type ctxKey int

// requestIDCtxKey is the context key under which a request ID string may be stored.
const requestIDCtxKey ctxKey = 1

// requestIDHeaderKey is the HTTP header used to carry the request ID.
const requestIDHeaderKey = "X-Request-Id"
|
||
|
|
||
|
// maskKeys is the set of sensitive query-parameter and header names whose
// values must not appear verbatim in request logs.
var maskKeys = map[string]struct{}{
	"Authorization":         {},
	"X-Bkapi-Authorization": {},
	"bk_app_secret":         {},
	"bk_token":              {},
}
|
||
|
|
||
|
func getRequestID(r *http.Request) string {
|
||
|
v, ok := r.Context().Value(requestIDCtxKey).(string)
|
||
|
if ok && v != "" {
|
||
|
return v
|
||
|
}
|
||
|
|
||
|
rid := r.Header.Get(requestIDHeaderKey)
|
||
|
if rid != "" {
|
||
|
return rid
|
||
|
}
|
||
|
|
||
|
return GenRequestID()
|
||
|
}
|
||
|
|
||
|
// reqToCurl curl 格式的请求日志
|
||
|
func reqToCurl(r *http.Request) (string, error) {
|
||
|
// 过滤掉敏感信息, header 和 query
|
||
|
headers := ""
|
||
|
for key, values := range r.Header {
|
||
|
for _, value := range values {
|
||
|
if _, ok := maskKeys[key]; ok {
|
||
|
value = "***"
|
||
|
}
|
||
|
headers += fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
rawURL := *r.URL
|
||
|
queryValue := rawURL.Query()
|
||
|
for key := range queryValue {
|
||
|
if _, ok := maskKeys[key]; ok {
|
||
|
queryValue.Set(key, "<masked>")
|
||
|
}
|
||
|
}
|
||
|
rawURL.RawQuery = queryValue.Encode()
|
||
|
|
||
|
reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers)
|
||
|
if r.Body != nil {
|
||
|
bodyBytes, err := io.ReadAll(r.Body)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
r.Body.Close() // nolint
|
||
|
|
||
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||
|
if len(bodyBytes) > 1024 {
|
||
|
reqMsg += fmt.Sprintf(" -d '%s...(Total %s)'", bodyBytes[:1024], humanize.Bytes(uint64(len(bodyBytes))))
|
||
|
} else {
|
||
|
reqMsg += fmt.Sprintf(" -d '%s'", bodyBytes)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return reqMsg, nil
|
||
|
}
|
||
|
|
||
|
// respToCurl 返回日志
|
||
|
func respToCurl(resp *http.Response, st time.Time) (string, error) {
|
||
|
var (
|
||
|
bodyBytes []byte
|
||
|
err error
|
||
|
)
|
||
|
if resp.Body != nil {
|
||
|
bodyBytes, err = io.ReadAll(resp.Body)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
resp.Body.Close() // nolint
|
||
|
|
||
|
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||
|
}
|
||
|
|
||
|
if len(bodyBytes) > 1024 {
|
||
|
respMsg := fmt.Sprintf("[%s] %s %s...(Total %s)",
|
||
|
resp.Status, time.Since(st), bodyBytes[:1024], humanize.Bytes(uint64(len(bodyBytes))))
|
||
|
return respMsg, nil
|
||
|
}
|
||
|
|
||
|
if len(bodyBytes) > 0 {
|
||
|
respMsg := fmt.Sprintf("[%s] %s %s", resp.Status, time.Since(st), bodyBytes)
|
||
|
return respMsg, nil
|
||
|
}
|
||
|
|
||
|
respMsg := fmt.Sprintf("[%s] %s", resp.Status, time.Since(st))
|
||
|
return respMsg, nil
|
||
|
}
|
||
|
|
||
|
// curlLogTransport is an http.RoundTripper that logs every request as a curl
// command, and every response as a status/latency/body summary. It delegates
// the actual exchange to Transport.
type curlLogTransport struct {
	// Transport is the underlying round tripper; nil selects http.DefaultTransport.
	Transport http.RoundTripper
}
|
||
|
|
||
|
// RoundTrip curlLog Transport
|
||
|
func (t *curlLogTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||
|
st := time.Now()
|
||
|
|
||
|
rid := getRequestID(req)
|
||
|
req.Header.Set(requestIDHeaderKey, rid)
|
||
|
|
||
|
// 记录请求
|
||
|
rbody, err := reqToCurl(req)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
slog.With("req_id", rid).Info("restyclient REQ", "body", rbody)
|
||
|
|
||
|
resp, err := t.transport(req).RoundTrip(req)
|
||
|
if err != nil {
|
||
|
slog.With("req_id", rid).Error("restyclient RESP", "err", err)
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// 记录返回
|
||
|
respBody, err := respToCurl(resp, st)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
slog.With("req_id", rid).Info("restyclient REQ", "body", respBody)
|
||
|
|
||
|
return resp, nil
|
||
|
}
|
||
|
|
||
|
func (t *curlLogTransport) transport(req *http.Request) http.RoundTripper { //nolint:unparam
|
||
|
if t.Transport != nil {
|
||
|
return t.Transport
|
||
|
}
|
||
|
return http.DefaultTransport
|
||
|
}
|
||
|
|
||
|
// NewCurlLogTransport make a new curl log transport, default transport can be nil
|
||
|
func NewCurlLogTransport(transport http.RoundTripper) http.RoundTripper {
|
||
|
return &curlLogTransport{Transport: transport}
|
||
|
}
|