diff --git a/rest/generic.go b/rest/generic.go index fcfdb48..1b799d5 100644 --- a/rest/generic.go +++ b/rest/generic.go @@ -24,6 +24,10 @@ import ( "time" ) +var ( + reqCtxKey = ctxKey("generic.HTTPRequest") +) + // UnaryFunc Unary or ClientStreaming handle function type UnaryFunc[Req, Resp any] func(context.Context, *Req) (*Resp, error) @@ -55,14 +59,17 @@ func Handle[Req, Resp any](fn UnaryFunc[Req, Resp]) func(w http.ResponseWriter, return } + ctx := r.Context() + ctx = context.WithValue(ctx, reqCtxKey, r) + // 参数校验 - if err = validateReq(r.Context(), in); err != nil { + if err = validateReq(ctx, in); err != nil { slog.Error("validate req failed", "err", err) APIError(err).Render(w, r) return } - out, err := fn(r.Context(), in) + out, err := fn(ctx, in) if err != nil { APIError(err).Render(w, r) return @@ -102,8 +109,11 @@ func Stream[Req any](fn StreamFunc[Req]) func(w http.ResponseWriter, r *http.Req return } + ctx := r.Context() + ctx = context.WithValue(ctx, reqCtxKey, r) + // 参数校验 - if err = validateReq(r.Context(), in); err != nil { + if err = validateReq(ctx, in); err != nil { slog.Error("validate stream req failed", "err", err) APIError(err).Render(w, r) return @@ -112,7 +122,7 @@ func Stream[Req any](fn StreamFunc[Req]) func(w http.ResponseWriter, r *http.Req svr := &streamingServer{ ResponseWriter: w, ResponseController: http.NewResponseController(w), - ctx: r.Context(), + ctx: ctx, } err = fn(in, svr) @@ -123,6 +133,15 @@ func Stream[Req any](fn StreamFunc[Req]) func(w http.ResponseWriter, r *http.Req return f } +// HTTPRequest return svr's request +func HTTPRequest(ctx context.Context) *http.Request { + val, ok := ctx.Value(reqCtxKey).(*http.Request) + if !ok { + panic("missing request in context") + } + return val +} + // EmptyReq 空的请求 type EmptyReq struct{}