package rest

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"time"

	"github.com/ggicci/httpin"
	httpin_integration "github.com/ggicci/httpin/integration"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/render"

	"git.ifooth.com/common/pkg/i18n"
)

// UnaryFunc is the handler signature for unary or client-streaming style
// endpoints: it receives a context and a decoded request and returns a
// response or an error.
type UnaryFunc[In, Out any] func(context.Context, *In) (*Out, error)

// StreamingServer is the server side of a server-streaming or bidi-streaming
// endpoint. It exposes the underlying http.ResponseWriter together with the
// context associated with the request.
type StreamingServer interface {
	http.ResponseWriter
	Context() context.Context
}

// StreamFunc is the handler signature for server-streaming or bidi-streaming
// endpoints: it receives the decoded request and a StreamingServer to write to.
type StreamFunc[In any] func(*In, StreamingServer) error

// Handle adapts a typed UnaryFunc into a standard net/http handler function.
//
// For each request the returned handler:
//   - decodes the request into *In via decodeReq,
//   - builds a context carrying the language taken from the Accept-Language
//     header and the *http.Request itself (under reqCtxKey),
//   - calls fn and renders either APIOK(out) or APIError(err),
//   - reports metrics through collectHandleMetrics when it returns.
func Handle[In any, Out any](fn UnaryFunc[In, Out]) func(w http.ResponseWriter, r *http.Request) {
	handleName := getHandleName(fn)

	f := func(w http.ResponseWriter, r *http.Request) {
		// NOTE(review): this sets Content-Type on the incoming request, not on
		// the response — presumably so downstream decoding treats the body as
		// JSON; confirm this is intentional.
		r.Header.Set("Content-Type", "application/json")

		st := time.Now()

		// err is declared before the defer so the deferred metrics call
		// observes whatever error the handler finished with.
		var err error
		defer func() {
			collectHandleMetrics(handleName, r.Method, st, err)
		}()

		// ":=" reuses the err declared above because both live in the same scope.
		in, err := decodeReq[In](r)
		if err != nil {
			slog.Error("handle decode request failed", "err", err)
			// Render errors are deliberately ignored: nothing more can be
			// done for this response at this point.
			_ = render.Render(w, r, APIError(err))
			return
		}

		// Set the language from the Accept-Language header and attach the
		// original request to the context.
		ctx := i18n.SetLang(r.Context(), r.Header.Get("Accept-Language"))
		ctx = context.WithValue(ctx, reqCtxKey, r)

		out, err := fn(ctx, in)
		if err != nil {
			_ = render.Render(w, r, APIError(err))
			return
		}

		_ = render.Render(w, r, APIOK(out))
	}

	return f
}

// streamingServer is the StreamingServer implementation handed to StreamFunc
// handlers. It embeds the response writer and an http.ResponseController
// created for that writer, and carries the per-request context.
type streamingServer struct {
	http.ResponseWriter
	*http.ResponseController
	ctx context.Context
}

// Context returns the context associated with this streaming server.
func (s *streamingServer) Context() context.Context {
	return s.ctx
}

// Stream adapts a typed StreamFunc into a standard net/http handler function.
//
// For each request the returned handler decodes the request into *In via
// decodeReq, builds a context carrying the language from the Accept-Language
// header and the *http.Request (under reqCtxKey), and invokes fn with a
// streamingServer wrapping the response writer. If decoding or fn fails, the
// error is rendered with APIError. Metrics are reported through
// collectHandleMetrics when the handler returns.
func Stream[In any](fn StreamFunc[In]) func(w http.ResponseWriter, r *http.Request) {
	handleName := getHandleName(fn)

	f := func(w http.ResponseWriter, r *http.Request) {
		st := time.Now()

		// err is declared before the defer so the deferred metrics call
		// observes whatever error the handler finished with.
		var err error
		defer func() {
			collectHandleMetrics(handleName, r.Method, st, err)
		}()

		// ":=" reuses the err declared above because both live in the same scope.
		in, err := decodeReq[In](r)
		if err != nil {
			slog.Error("handle decode request failed ", "err", err)
			_ = render.Render(w, r, APIError(err))
			return
		}

		// Set the language from the Accept-Language header and attach the
		// original request to the context.
		ctx := i18n.SetLang(r.Context(), r.Header.Get("Accept-Language"))
		ctx = context.WithValue(ctx, reqCtxKey, r)

		svr := &streamingServer{
			ResponseWriter:     w,
			ResponseController: http.NewResponseController(w),
			ctx:                ctx,
		}

		err = fn(in, svr)
		if err != nil {
			_ = render.Render(w, r, APIError(err))
		}
	}

	return f
}

// decodeReq builds a *T from the incoming request.
//
// Rules, in order:
//   - if T is http.Request, the request itself is returned unchanged;
//   - if T is EmptyReq, a zero value is returned without any decoding;
//   - otherwise the request is decoded with httpin.Decode; for GET and DELETE
//     that result is returned as is, and for every other method the body is
//     additionally read in full and unmarshalled as JSON into the same value.
//
// A JSON failure is wrapped with %w so callers can inspect the underlying
// error with errors.Is / errors.As.
func decodeReq[T any](r *http.Request) (*T, error) {
	in := new(T)
	var err error

	// T is http.Request: hand the request back directly.
	if _, ok := any(in).(*http.Request); ok {
		return any(r).(*T), nil
	}

	// An empty request type needs no deserialization.
	if _, ok := any(in).(*EmptyReq); ok {
		return in, nil
	}

	in, err = httpin.Decode[T](r)
	if err != nil {
		return nil, err
	}

	// GET/DELETE: parameters come from the URL only.
	if r.Method == http.MethodGet || r.Method == http.MethodDelete {
		return in, nil
	}

	// POST and other methods: additionally read parameters from the body.
	body, err := io.ReadAll(r.Body)
	if err != nil {
		return nil, err
	}

	if err = json.Unmarshal(body, in); err != nil {
		return nil, fmt.Errorf("unmarshal json body: %w", err)
	}

	return in, nil
}

// EmptyReq is the request type for endpoints that take no input.
type EmptyReq struct{}

// EmptyResp is the response type for endpoints that return no data.
type EmptyResp struct{}

// PaginationReq is the common request shape for paginated endpoints.
// Both fields are bound from the query string and from the JSON body.
type PaginationReq struct {
	Offset int `json:"offset" in:"query=offset" validate:"gte=0"`
	Limit  int `json:"limit" in:"query=limit" validate:"gte=0"`
}

// PaginationResp is the common response shape for paginated endpoints.
type PaginationResp[T any] struct {
	Count int64 `json:"count"`
	Items []*T  `json:"items"`
}

// init registers chi.URLParam with httpin under the "path" directive name.
func init() {
	httpin_integration.UseGochiURLParam("path", chi.URLParam)
}