pkg/trace/middleware.go

79 lines
1.9 KiB
Go

package trace
import (
"context"
"log/slog"
"net/http"
"strings"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
)
// requestIDHeader names the HTTP header carrying a caller-supplied request ID.
// requestIDContext maps this header to and from the OpenTelemetry trace ID.
const requestIDHeader = "X-Request-Id"
type (
	// requestIDContext is a custom text-map propagator (see its Inject,
	// Extract and Fields methods) that converts between the X-Request-Id
	// header and the OpenTelemetry trace ID. It holds no state, so the zero
	// value is ready to use.
	requestIDContext struct{}
)
// Inject implements the otel/propagation.TextMapPropagator interface.
//
// When the span context stored in ctx has a trace ID, Inject writes that ID's
// string (hex) form into the carrier under the X-Request-Id key, replacing any
// value already there. When there is no trace ID, the carrier is left
// untouched.
func (r requestIDContext) Inject(ctx context.Context, carrier propagation.TextMapCarrier) {
	if sc := trace.SpanContextFromContext(ctx); sc.HasTraceID() {
		carrier.Set(requestIDHeader, sc.TraceID().String())
	}
}
// Extract implements the otel/propagation.TextMapPropagator interface.
//
// Extract only acts when ctx does not already carry a trace ID, for example
// one a traceparent propagator extracted earlier in a composite propagator.
// In that case it reads the X-Request-Id value from the carrier. If the value
// is 32 hex digits, optionally in dashed UUID form and in either letter case,
// Extract returns a context holding a remote span context with that trace ID.
// A missing value leaves ctx unchanged. An unparsable value is logged and
// also leaves ctx unchanged.
func (r requestIDContext) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context {
	spanCtx := trace.SpanContextFromContext(ctx)
	if spanCtx.HasTraceID() {
		// A trace ID is already present (e.g. from traceparent); it takes precedence.
		return ctx
	}
	rid := carrier.Get(requestIDHeader)
	if rid == "" {
		return ctx
	}
	// Accept UUID form by removing dashes. Lowercase the result because
	// trace.TraceIDFromHex only accepts lowercase hex (the W3C trace-id rule),
	// while many clients emit uppercase UUIDs.
	hexID := strings.ToLower(strings.ReplaceAll(rid, "-", ""))
	traceID, err := trace.TraceIDFromHex(hexID)
	if err != nil {
		// NOTE(review): rid is client-supplied, and non-UUID request IDs are
		// common, so error level may be noisy here. Consider a lower level.
		slog.ErrorContext(ctx, "conv rid to traceID failed", "rid", rid, "err", err)
		return ctx
	}
	// Only the trace ID is known. The span ID stays zero, so this span context
	// reports HasTraceID() but not IsValid().
	spanCtx = trace.NewSpanContext(trace.SpanContextConfig{
		TraceID: traceID,
	})
	return trace.ContextWithRemoteSpanContext(ctx, spanCtx)
}
// Fields implements the otel/propagation.TextMapPropagator interface. It
// reports the only carrier key this propagator uses: the X-Request-Id header.
func (requestIDContext) Fields() []string {
	keys := []string{requestIDHeader}
	return keys
}
// Middleware extends the otelhttp trace middleware, which is meant to support
// both traceparent and X-Request-Id. It wraps next with
// otelhttp.NewHandler, using "trace" as the operation name.
//
// No otelhttp.WithPropagators option is passed, so incoming trace context is
// extracted with otelhttp's default propagator (documented as the global
// TextMapPropagator). X-Request-Id is therefore only honoured if
// requestIDContext is part of that propagator.
// NOTE(review): confirm it is registered, after propagation.TraceContext,
// wherever the global propagator is configured.
func Middleware(next http.Handler) http.Handler {
	// next is passed directly; a HandlerFunc that only forwards to
	// next.ServeHTTP adds nothing.
	return otelhttp.NewHandler(next, "trace")
}