pkg/http/rest/middleware.go

70 lines
1.7 KiB
Go

// Package rest provides HTTP helpers, including request-ID and request-logging middleware.
package rest
import (
"context"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5/middleware"
"git.ifooth.com/common/pkg/http/httpserver"
"git.ifooth.com/common/pkg/http/restyclient"
)
// RequestID is middleware that guarantees every request carries a request ID.
//
// The ID is read from the incoming requestIDHeaderKey header. When that header
// is missing or empty, a fresh ID is obtained from GenRequestID. The ID is then
// stored in the request context in two places:
//   - via WithRequestID;
//   - under chi's middleware.RequestIDKey, so chi-aware code such as
//     middleware.GetReqID can find it.
//
// It is also set on the response header of the same name before the next
// handler is invoked.
func RequestID(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := r.Header.Get(requestIDHeaderKey)
		if id == "" {
			id = GenRequestID()
		}

		withID := context.WithValue(
			WithRequestID(r.Context(), id),
			middleware.RequestIDKey,
			id,
		)

		// Echo the ID back to the caller in the response header.
		w.Header().Set(requestIDHeaderKey, id)
		next.ServeHTTP(w, r.WithContext(withID))
	})
}
// HandleLogger 记录请求日志
func HandleLogger(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
st := time.Now()
// 优先使用蓝鲸网关的request_id
reqId := r.Header.Get("X-Bkapi-Request-ID")
if reqId == "" {
reqId = r.Header.Get("X-Request-Id")
}
ctx := restyclient.WithRequestID(r.Context(), reqId)
r = r.WithContext(ctx)
limit := 2048
reqBuf := httpserver.NewLimitBuffer(limit)
r.Body = httpserver.TeeReadCloser(r.Body, reqBuf)
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
respBuf := httpserver.NewLimitBuffer(limit)
ww.Tee(respBuf)
next.ServeHTTP(ww, r)
// 保证能读取前1K字符
// if reqBuf.Remain() > 0 {
// io.Copy(io.Discard, io.LimitReader(r.Body, int64(reqBuf.Remain())))
// }
msg := fmt.Sprintf("Handle %s %s From %s", r.Method, r.RequestURI, r.RemoteAddr)
slog.Info(msg, "req_id", reqId, "status", ww.Status(), "duration", time.Since(st), "req", reqBuf.String(), "resp", respBuf.String())
}
return http.HandlerFunc(fn)
}