diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..b8f21dc --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,85 @@ +# Code generated by gen-lint. DO NOT EDIT. + +run: + timeout: 5m + +issues: + # 显示所有 issue + max-issues-per-linter: 0 + max-same-issues: 0 + exclude-use-default: false + +linters: + disable-all: true + enable: + # enable by default + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - unused + + # custom + - gci + - goconst + - gofmt + - goheader + - goimports + - gosec + - misspell + - nakedret + - revive + - unconvert + - unparam + +linters-settings: + # 只开启特定的规则 + errcheck: + exclude-functions: + - (*os.File).Close + - (io.Closer).Close + - (net/http.ResponseWriter).Write + - io.Copy + - os.RemoveAll + govet: + enable: + - shadow + goimports: + local-prefixes: git.ifooth.com/common/pkg + gci: + sections: + - standard + - default + - prefix(git.ifooth.com/common/pkg) + gosec: + includes: + - G201 # SQL query construction using format string + - G202 # SQL query construction using string concatenation + - G101 # Look for hard coded credentials + - G401 # Detect the usage of DES, RC4, MD5 or SHA1 + - G402 # Look for bad TLS connection settings + - G403 # Ensure minimum RSA key length of 2048 bits + - G404 # Insecure random number source (rand) + - G504 # Import blocklist: net/http/cgi + misspell: + locale: US + revive: + rules: + - name: line-length-limit + arguments: + - 160 + - name: function-length + arguments: + - 80 # statements + - 120 # lines + - name: cyclomatic + arguments: + - 30 + - name: use-any + - name: early-return + - name: exported + arguments: + - checkPrivateReceivers + - sayRepetitiveInsteadOfStutters + - name: package-comments diff --git a/apis/logger.go b/apis/logger.go deleted file mode 100644 index 2826595..0000000 --- a/apis/logger.go +++ /dev/null @@ -1,62 +0,0 @@ -package apis - -import ( - "log/slog" - "net" - "net/http" - "strings" - "time" - - "github.com/go-chi/chi/v5/middleware" -) - -// ClientIP getIP returns the ip address from the http request -func ClientIP(r *http.Request) string { - xForwardedFor := r.Header.Get("X-Forwarded-For") - ip := strings.TrimSpace(strings.Split(xForwardedFor, ",")[0]) - if ip != "" { - return ip - } - - ip = strings.TrimSpace(r.Header.Get("X-Real-Ip")) - if ip != "" { - return ip - } - - if ip, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)); err == nil { - return ip - } - - return "" -} - -// Logger returns a `func(http.Handler) http.Handler` (middleware) that logs requests using slog. -func Logger(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - st := time.Now() - ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - - defer func() { - msg := r.Method + " " + r.RequestURI + " " + r.Proto - - ip := ClientIP(r) - status := ww.Status() - if status == 0 { - status = http.StatusOK - } - - attrs := []slog.Attr{ - slog.String("ip", ip), - slog.String("id", middleware.GetReqID(r.Context())), - slog.Int("status", status), - slog.Duration("latency", time.Since(st)), - slog.Int("length", ww.BytesWritten()), - } - slog.LogAttrs(r.Context(), slog.LevelInfo, msg, attrs...) - }() - - next.ServeHTTP(ww, r) - } - - return http.HandlerFunc(fn) -} diff --git a/apis/middleware.go b/apis/middleware.go deleted file mode 100644 index cfe0293..0000000 --- a/apis/middleware.go +++ /dev/null @@ -1,80 +0,0 @@ -// Package apis for http -package apis - -import ( - "context" - "net/http" - "path/filepath" - "strings" - - "github.com/go-chi/chi/v5/middleware" - "github.com/google/uuid" - - "git.ifooth.com/common/pkg/components" -) - -type contextKey struct { - name string -} - -var ( - reqCtxKey = &contextKey{"HTTPRequest"} -) - -// RequestIdGenerator request_id -func RequestIdGenerator() string { - uid := uuid.New().String() - requestId := strings.Replace(uid, "-", "", -1) - return requestId -} - -// RequestID reuqest_id -func RequestID(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - requestID := r.Header.Get(components.RequestIDHeaderKey) - if requestID == "" { - requestID = RequestIdGenerator() - } - - ctx = components.WithRequestIDValue(ctx, requestID) - ctx = context.WithValue(ctx, middleware.RequestIDKey, requestID) - - w.Header().Set(components.RequestIDHeaderKey, requestID) - next.ServeHTTP(w, r.WithContext(ctx)) - } - return http.HandlerFunc(fn) -} - -// AuthRequired API类型, 兼容多种鉴权模式 -func AuthRequired(next http.Handler) http.Handler { - ignoreExtMap := map[string]struct{}{ - ".js": {}, - ".css": {}, - ".map": {}, - ".png": {}, - } - - fn := func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions { - next.ServeHTTP(w, r) - return - } - - // 静态资源过滤, 注意不会带鉴权信息 - fileExt := filepath.Ext(r.URL.Path) - if _, ok := ignoreExtMap[fileExt]; ok { - next.ServeHTTP(w, r) - return - } - - // switch { - // default: - // render.Render(w, r, rest.AbortWithUnauthorizedError(rest.UnauthorizedError)) - // return - // } - - next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) -} diff --git a/http/httpserver/io.go b/http/httpserver/io.go new file mode 100644 index 0000000..cef9e75 --- /dev/null +++ b/http/httpserver/io.go @@ -0,0 +1,88 @@ +package httpserver + +import ( + "bytes" + "fmt" + "io" +) + +// LimitBuffer buf with limit +type LimitBuffer interface { + io.Writer + String() string + Remain() int +} + +// limitBuffer buf with limit +type limitBuffer struct { + buf *bytes.Buffer + limit int + remain int + length int +} + +// NewLimitBuffer .. +func NewLimitBuffer(limit int) LimitBuffer { + buf := bytes.NewBuffer(make([]byte, 0, limit)) + return &limitBuffer{buf: buf, limit: limit, remain: limit} +} + +func (b *limitBuffer) Write(p []byte) (n int, err error) { + defer func() { + b.length += n + }() + + // discord + if b.remain <= 0 { + return len(p), nil + } + + // write remain + if len(p) > b.remain { + n, err = b.buf.Write(p[:b.remain]) + b.remain -= n + + // alway return all writed length + n = len(p) + return + } + + // write all + n, err = b.buf.Write(p) + b.remain -= n + return +} + +// String .. +func (b *limitBuffer) String() string { + if b.length > b.limit { + return b.buf.String() + fmt.Sprintf("...(total %dB)", b.length) + } + + return b.buf.String() +} + +// Remain 剩余多少个字节没有写满 +func (b *limitBuffer) Remain() int { + return b.remain +} + +type teeReadCloser struct { + tee io.Reader + r io.Closer +} + +// Read implement io.Read interface +func (t *teeReadCloser) Read(p []byte) (n int, err error) { + return t.tee.Read(p) +} + +// Close implement io.Close interface +func (t *teeReadCloser) Close() error { + return t.r.Close() +} + +// TeeReadCloser TeeReader with limit +func TeeReadCloser(r io.ReadCloser, w io.Writer) io.ReadCloser { + return &teeReadCloser{r: r, tee: io.TeeReader(r, w)} +} diff --git a/http/httpserver/io_test.go b/http/httpserver/io_test.go new file mode 100644 index 0000000..aca9655 --- /dev/null +++ b/http/httpserver/io_test.go @@ -0,0 +1,49 @@ +package httpserver + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLimitBufer(t *testing.T) { + buf := NewLimitBuffer(10) + n, err := buf.Write([]byte("1")) + assert.NoError(t, err) + assert.Equal(t, 1, n) + assert.Equal(t, "1", buf.String()) + assert.Equal(t, 9, buf.Remain()) + + n, err = buf.Write([]byte("TestLimitBufer")) + // 设置limit但是不会返回 EOF + assert.NoError(t, err) + assert.Equal(t, 14, n) + assert.Equal(t, "1TestLimit...(total 15B)", buf.String()) + assert.Equal(t, 0, buf.Remain()) + + n, err = buf.Write([]byte("TestLimitBufer")) + // 设置limit但是不会返回 EOF + assert.NoError(t, err) + assert.Equal(t, 14, n) + assert.Equal(t, "1TestLimit...(total 29B)", buf.String()) + assert.Equal(t, 0, buf.Remain()) +} + +func TestTeeReader(t *testing.T) { + a := bytes.NewBuffer([]byte("TestLimitBufer")) + buf := NewLimitBuffer(10) + tee := TeeReadCloser(io.NopCloser(a), buf) + + _, _ = io.CopyN(io.Discard, tee, 1) + assert.Equal(t, "T", buf.String()) + assert.Equal(t, 9, buf.Remain()) + + io.Copy(io.Discard, io.LimitReader(tee, int64(buf.Remain()))) + assert.Equal(t, "TestLimitB", buf.String()) + assert.Equal(t, 0, buf.Remain()) + + io.Copy(io.Discard, tee) + assert.Equal(t, "TestLimitB...(total 14B)", buf.String()) +} diff --git a/apis/generic.go b/http/rest/generic.go similarity index 99% rename from apis/generic.go rename to http/rest/generic.go index 0a4b3f2..816af23 100644 --- a/apis/generic.go +++ b/http/rest/generic.go @@ -1,4 +1,4 @@ -package apis +package rest import ( "context" diff --git a/apis/generic_test.go b/http/rest/generic_test.go similarity index 99% rename from apis/generic_test.go rename to http/rest/generic_test.go index 7190b9e..ddbd803 100644 --- a/apis/generic_test.go +++ b/http/rest/generic_test.go @@ -1,4 +1,4 @@ -package apis +package rest import ( "bytes" diff --git a/apis/handler.go b/http/rest/handler.go similarity index 98% rename from apis/handler.go rename to http/rest/handler.go index fd71723..fcf0150 100644 --- a/apis/handler.go +++ b/http/rest/handler.go @@ -1,4 +1,4 @@ -package apis +package rest import ( "net/http" diff --git a/apis/metrics.go b/http/rest/metrics.go similarity index 99% rename from apis/metrics.go rename to http/rest/metrics.go index ca6dbae..0e4f57f 100644 --- a/apis/metrics.go +++ b/http/rest/metrics.go @@ -1,4 +1,4 @@ -package apis +package rest import ( "reflect" diff --git a/http/rest/middleware.go b/http/rest/middleware.go new file mode 100644 index 0000000..af13c32 --- /dev/null +++ b/http/rest/middleware.go @@ -0,0 +1,84 @@ +// Package apis for http +package rest + +import ( + "fmt" + "log/slog" + "net/http" + "path/filepath" + "time" + + "github.com/go-chi/chi/v5/middleware" + + "git.ifooth.com/common/pkg/http/httpserver" + "git.ifooth.com/common/pkg/http/restyclient" +) + +// HandleLogger 记录请求日志 +func HandleLogger(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + st := time.Now() + + // 优先使用蓝鲸网关的request_id + reqId := r.Header.Get("X-Bkapi-Request-ID") + if reqId == "" { + reqId = r.Header.Get("X-Request-Id") + } + + ctx := restyclient.WithRequestID(r.Context(), reqId) + r = r.WithContext(ctx) + + limit := 2048 + reqBuf := httpserver.NewLimitBuffer(limit) + r.Body = httpserver.TeeReadCloser(r.Body, reqBuf) + + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) + respBuf := httpserver.NewLimitBuffer(limit) + ww.Tee(respBuf) + + next.ServeHTTP(ww, r) + + // 保证能读取前1K字符 + // if reqBuf.Remain() > 0 { + // io.Copy(io.Discard, io.LimitReader(r.Body, int64(reqBuf.Remain()))) + // } + + msg := fmt.Sprintf("Handle %s %s From %s", r.Method, r.RequestURI, r.RemoteAddr) + slog.Info(msg, "req_id", reqId, "status", ww.Status(), "duration", time.Since(st), "req", reqBuf.String(), "resp", respBuf.String()) + } + + return http.HandlerFunc(fn) +} + +// AuthRequired API类型, 兼容多种鉴权模式 +func AuthRequired(next http.Handler) http.Handler { + ignoreExtMap := map[string]struct{}{ + ".js": {}, + ".css": {}, + ".map": {}, + ".png": {}, + } + + fn := func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + next.ServeHTTP(w, r) + return + } + + // 静态资源过滤, 注意不会带鉴权信息 + fileExt := filepath.Ext(r.URL.Path) + if _, ok := ignoreExtMap[fileExt]; ok { + next.ServeHTTP(w, r) + return + } + + // switch { + // default: + // render.Render(w, r, rest.AbortWithUnauthorizedError(rest.UnauthorizedError)) + // return + // } + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} diff --git a/http/rest/request.go b/http/rest/request.go new file mode 100644 index 0000000..b3f0a9f --- /dev/null +++ b/http/rest/request.go @@ -0,0 +1,54 @@ +package rest + +import ( + "context" + "net/http" + + "github.com/google/uuid" +) + +type contextKey struct { + name string +} + +var ( + reqCtxKey = &contextKey{"HTTPRequest"} +) + +type ctxKey int + +const ( + requestIDCtxKey = ctxKey(1) + requestIDHeaderKey = "X-Request-Id" +) + +// HTTPRequest return svr's request +func HTTPRequest(ctx context.Context) *http.Request { // nolint + val, ok := ctx.Value(reqCtxKey).(*http.Request) + if !ok { + panic("missing request in context") + } + return val +} + +// GenRequestID 生产 request_id +func GenRequestID() string { + id, _ := uuid.NewRandom() + return id.String() +} + +// RequestIDValue 获取 RequestId 值 +func RequestIDValue(ctx context.Context) string { + v, ok := ctx.Value(requestIDCtxKey).(string) + if !ok { + return "" + } + + return v +} + +// WithRequestID 设置 request_id +func WithRequestID(ctx context.Context, id string) context.Context { + newCtx := context.WithValue(ctx, requestIDCtxKey, id) + return newCtx +} diff --git a/apis/response.go b/http/rest/response.go similarity index 74% rename from apis/response.go rename to http/rest/response.go index c1ae75c..1707519 100644 --- a/apis/response.go +++ b/http/rest/response.go @@ -1,12 +1,10 @@ -package apis +package rest import ( "net/http" "github.com/go-chi/render" "github.com/pkg/errors" - - "git.ifooth.com/common/pkg/components" ) var ( @@ -14,28 +12,6 @@ var ( UnauthorizedError = errors.New("用户未登入") ) -// HandlerFunc -type RestHandlerFunc func(r *http.Request) (interface{}, error) - -func (h RestHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { - data, err := h(r) - if err != nil { - // handle returned error here. - render.Render(w, r, AbortWithBadRequestError(err)) - return - } - - switch v := data.(type) { - case render.Renderer: - if err := v.Render(w, r); err != nil { - render.Render(w, r, AbortWithBadRequestError(err)) - return - } - default: - render.JSON(w, r, data) - } -} - // APIResponse 返回的标准结构 type APIResponse struct { Err error `json:"-"` // low-level runtime error @@ -53,7 +29,7 @@ func (res *APIResponse) Render(w http.ResponseWriter, r *http.Request) error { statusCode = http.StatusOK } - res.RequestId = components.RequestIDValue(r.Context()) + res.RequestId = RequestIDValue(r.Context()) render.Status(r, statusCode) return nil diff --git a/http/restyclient/client.go b/http/restyclient/client.go new file mode 100644 index 0000000..c91d36f --- /dev/null +++ b/http/restyclient/client.go @@ -0,0 +1,101 @@ +// Package restyclient for http client +package restyclient + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "sync" + "time" + + "github.com/go-resty/resty/v2" + "github.com/google/uuid" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +const ( + timeout = time.Second * 30 +) + +var ( + clientOnce sync.Once + silentClientOnce sync.Once + globalClient *resty.Client + globalSilentClient *resty.Client +) + +var dialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, +} + +// defaultTransport default transport +var defaultTransport http.RoundTripper = &http.Transport{ + DialContext: dialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + // NOCC:gas/tls(设计如此) + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint +} + +// New : 新建 Client, 设置公共参数,tracing 等; 每次新建,cookies不复用 +func New() *resty.Client { + if globalClient == nil { + clientOnce.Do(func() { + globalClient = resty.New(). + SetTransport(otelhttp.NewTransport(defaultTransport)). + SetTimeout(timeout). + SetCookieJar(nil). + SetDebugBodyLimit(1024). + OnAfterResponse(restyAfterResponseHook). + SetPreRequestHook(restyBeforeRequestHook). + OnError(restyErrHook). + SetHeader("User-Agent", "envmgr-restyclient") + }) + } + return globalClient +} + +// silentNew 安静模式,只打印错误日志 +func silentNew() *resty.Client { + if globalSilentClient == nil { + silentClientOnce.Do(func() { + globalSilentClient = resty.New(). + SetTransport(otelhttp.NewTransport(defaultTransport)). + SetTimeout(timeout). + SetCookieJar(nil). + SetDebugBodyLimit(1024). + // OnAfterResponse(restyAfterResponseHook). + // SetPreRequestHook(restyBeforeRequestHook). + OnError(restyErrHook). + SetHeader("User-Agent", "envmgr-restyclient") + }) + } + return globalSilentClient +} + +// R : New().R() 快捷方式, 已设置公共参数,tracing 等 +func R() *resty.Request { + return New().R() +} + +// SilentR : 安静模式,只打印错误日志 已设置公共参数,tracing 等, 只打印错误日志 +func SilentR() *resty.Request { + return silentNew().R() +} + +// GenRequestID 生产 request_id +func GenRequestID() string { + id, _ := uuid.NewRandom() + return id.String() +} + +// WithRequestID 设置 request_id +func WithRequestID(ctx context.Context, id string) context.Context { + newCtx := context.WithValue(ctx, requestIDCtxKey, id) + return newCtx +} diff --git a/http/restyclient/helper.go b/http/restyclient/helper.go new file mode 100644 index 0000000..108f009 --- /dev/null +++ b/http/restyclient/helper.go @@ -0,0 +1,74 @@ +package restyclient + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + + resty "github.com/go-resty/resty/v2" +) + +// CodeNotZeroErr ... +var CodeNotZeroErr = errors.New("resp code != 0") + +// BKResult 蓝鲸返回规范的结构体 +type BKResult[T any] struct { + Result bool `json:"result"` // 部分蓝鲸接口有, 按需校验 + Code any `json:"code"` + Message string `json:"message"` + Data *T `json:"data"` +} + +// NewBKResult create NewBKResult by resp +func NewBKResult[T any](resp *resty.Response) (*BKResult[T], error) { + if !resp.IsSuccess() { + return nil, fmt.Errorf("request failed, status: %s, message: %s", resp.Status(), resp.Body()) + } + + bkResult := new(BKResult[T]) + if err := json.Unmarshal(resp.Body(), bkResult); err != nil { + return nil, err + } + + if err := bkResult.ValidateCode(); err != nil { + return nil, err + } + + return bkResult, nil +} + +// NewBKData only create data by resp +func NewBKData[T any](resp *resty.Response) (*T, error) { + bkResult, err := NewBKResult[T](resp) + if err != nil { + return nil, err + } + + return bkResult.Data, nil +} + +// ValidateCode 返回结果是否OK +func (r *BKResult[T]) ValidateCode() error { + var resultCode int + + switch code := r.Code.(type) { + case int: + resultCode = code + case float64: + resultCode = int(code) + case string: + c, err := strconv.Atoi(code) + if err != nil { + return err + } + resultCode = c + default: + return fmt.Errorf("conversion to int from %T not supported", code) + } + + if resultCode != 0 { + return fmt.Errorf("%w, code=%d, message=%s", CodeNotZeroErr, resultCode, r.Message) + } + return nil +} diff --git a/http/restyclient/hook.go b/http/restyclient/hook.go new file mode 100644 index 0000000..fce5202 --- /dev/null +++ b/http/restyclient/hook.go @@ -0,0 +1,42 @@ +package restyclient + +import ( + "fmt" + "log/slog" + "net/http" + + "github.com/dustin/go-humanize" + "github.com/go-resty/resty/v2" +) + +// restyBeforeRequestHook 请求hook +func restyBeforeRequestHook(c *resty.Client, r *http.Request) error { + rid := getRequestID(r) + r.Header.Set(requestIDHeaderKey, rid) + + rbody, err := reqToCurl(r) + if err != nil { + return err + } + + slog.With("req_id", rid).Info("restyclient REQ", "body", rbody) + return nil +} + +// restyAfterResponseHook 正常返回hook +func restyAfterResponseHook(c *resty.Client, resp *resty.Response) error { + // 最大打印 1024 个字符 + body := string(resp.Body()) + if len(body) > 1024 { + body = fmt.Sprintf("%s...(Total %s)", body[:1024], humanize.Bytes(uint64(len(body)))) + } + + slog.With("req_id", getRequestID(resp.RawResponse.Request)).Info("restyclient RESP", "status", resp.Status(), "duration", resp.Time(), "body", body) + + return nil +} + +// restyErrHook 错误hook +func restyErrHook(r *resty.Request, err error) { + slog.With("req_id", getRequestID(r.RawRequest)).Error("restyclient RESP", "err", err) +} diff --git a/http/restyclient/transport.go b/http/restyclient/transport.go new file mode 100644 index 0000000..6d54a11 --- /dev/null +++ b/http/restyclient/transport.go @@ -0,0 +1,162 @@ +package restyclient + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/dustin/go-humanize" +) + +type ctxKey int + +const ( + requestIDCtxKey = ctxKey(1) + requestIDHeaderKey = "X-Request-Id" +) + +var ( + // maskKeys 敏感参数和头部key + maskKeys = map[string]struct{}{ + "bk_app_secret": {}, + "bk_token": {}, + "Authorization": {}, + "X-Bkapi-Authorization": {}, + } +) + +func getRequestID(r *http.Request) string { + v, ok := r.Context().Value(requestIDCtxKey).(string) + if ok && v != "" { + return v + } + + rid := r.Header.Get(requestIDHeaderKey) + if rid != "" { + return rid + } + + return GenRequestID() +} + +// reqToCurl curl 格式的请求日志 +func reqToCurl(r *http.Request) (string, error) { + // 过滤掉敏感信息, header 和 query + headers := "" + for key, values := range r.Header { + for _, value := range values { + if _, ok := maskKeys[key]; ok { + value = "***" + } + headers += fmt.Sprintf(" -H %q", fmt.Sprintf("%s: %s", key, value)) + } + } + + rawURL := *r.URL + queryValue := rawURL.Query() + for key := range queryValue { + if _, ok := maskKeys[key]; ok { + queryValue.Set(key, "") + } + } + rawURL.RawQuery = queryValue.Encode() + + reqMsg := fmt.Sprintf("curl -X %s '%s'%s", r.Method, rawURL.String(), headers) + if r.Body != nil { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return "", err + } + r.Body.Close() // nolint + + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + if len(bodyBytes) > 1024 { + reqMsg += fmt.Sprintf(" -d '%s...(Total %s)'", bodyBytes[:1024], humanize.Bytes(uint64(len(bodyBytes)))) + } else { + reqMsg += fmt.Sprintf(" -d '%s'", bodyBytes) + } + } + + return reqMsg, nil +} + +// respToCurl 返回日志 +func respToCurl(resp *http.Response, st time.Time) (string, error) { + var ( + bodyBytes []byte + err error + ) + if resp.Body != nil { + bodyBytes, err = io.ReadAll(resp.Body) + if err != nil { + return "", err + } + resp.Body.Close() // nolint + + resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + if len(bodyBytes) > 1024 { + respMsg := fmt.Sprintf("[%s] %s %s...(Total %s)", + resp.Status, time.Since(st), bodyBytes[:1024], humanize.Bytes(uint64(len(bodyBytes)))) + return respMsg, nil + } + + if len(bodyBytes) > 0 { + respMsg := fmt.Sprintf("[%s] %s %s", resp.Status, time.Since(st), bodyBytes) + return respMsg, nil + } + + respMsg := fmt.Sprintf("[%s] %s", resp.Status, time.Since(st)) + return respMsg, nil +} + +// curlLogTransport print curl log transport +type curlLogTransport struct { + Transport http.RoundTripper +} + +// RoundTrip curlLog Transport +func (t *curlLogTransport) RoundTrip(req *http.Request) (*http.Response, error) { + st := time.Now() + + rid := getRequestID(req) + req.Header.Set(requestIDHeaderKey, rid) + + // 记录请求 + rbody, err := reqToCurl(req) + if err != nil { + return nil, err + } + slog.With("req_id", rid).Info("restyclient REQ", "body", rbody) + + resp, err := t.transport(req).RoundTrip(req) + if err != nil { + slog.With("req_id", rid).Error("restyclient RESP", "err", err) + return nil, err + } + + // 记录返回 + respBody, err := respToCurl(resp, st) + if err != nil { + return nil, err + } + slog.With("req_id", rid).Info("restyclient REQ", "body", respBody) + + return resp, nil +} + +func (t *curlLogTransport) transport(req *http.Request) http.RoundTripper { //nolint:unparam + if t.Transport != nil { + return t.Transport + } + return http.DefaultTransport +} + +// NewCurlLogTransport make a new curl log transport, default transport can be nil +func NewCurlLogTransport(transport http.RoundTripper) http.RoundTripper { + return &curlLogTransport{Transport: transport} +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..3549c71 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,36 @@ +// Package logger provider std slog logger +package logger + +import ( + "log/slog" + "os" + "path/filepath" + "strconv" +) + +// Init 初始化 slog +func Init() { + textHandler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelInfo, + ReplaceAttr: ReplaceSourceAttr, + }) + + logger := slog.New(textHandler) + slog.SetDefault(logger) +} + +// ReplaceSourceAttr source 格式化为 dir/file:line 格式 +func ReplaceSourceAttr(groups []string, a slog.Attr) slog.Attr { + if a.Key != slog.SourceKey { + return a + } + + src, ok := a.Value.Any().(*slog.Source) + if !ok { + return a + } + + a.Value = slog.StringValue(filepath.Base(src.File) + ":" + strconv.Itoa(src.Line)) + return a +} diff --git a/metrics/metrics.go b/metrics/metrics.go deleted file mode 100644 index e0b60a1..0000000 --- a/metrics/metrics.go +++ /dev/null @@ -1,66 +0,0 @@ -package metrics - -import ( - "net/http" - "strconv" - "time" - - "github.com/go-chi/chi/v5/middleware" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/collectors" - "github.com/prometheus/client_golang/prometheus/promhttp" -) - -var ( - reg = prometheus.NewRegistry() - - // http 请求总量 - httpRequestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "http_requests_total", - Help: "Counter of HTTP requests to prime", - }, []string{"handler", "method", "code"}) - - // http 请求耗时 - httpRequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: "http_request_duration_seconds", - Help: "Histogram of latencies for HTTP requests to prime.", - Buckets: []float64{0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1}, - }, []string{"handler", "method", "code"}) -) - -func init() { - reg.MustRegister( - collectors.NewGoCollector(), - collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), - ) - - reg.MustRegister(httpRequestsTotal) - reg.MustRegister(httpRequestDuration) -} - -// Handler -func Handler() http.HandlerFunc { - return promhttp.HandlerFor(reg, promhttp.HandlerOpts{Registry: reg}).ServeHTTP -} - -// collectHTTPRequestMetric http metrics 处理 -func collectHTTPRequestMetric(handler, method, code string, duration time.Duration) { - httpRequestsTotal.WithLabelValues(handler, method, code).Inc() - httpRequestDuration.WithLabelValues(handler, method, code).Observe(duration.Seconds()) -} - -// RequestCollect -func RequestCollect(name string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - t1 := time.Now() - defer func() { - collectHTTPRequestMetric(name, r.Method, strconv.Itoa(ww.Status()), time.Since(t1)) - }() - - next.ServeHTTP(ww, r) - } - return http.HandlerFunc(fn) - } -}