pkg/http/rest/generic.go

175 lines
3.8 KiB
Go

package rest
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"time"
"github.com/ggicci/httpin"
httpin_integration "github.com/ggicci/httpin/integration"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"git.ifooth.com/common/pkg/i18n"
)
type (
	// UnaryFunc is the signature of a unary or client-streaming handler: it
	// receives the request context and the decoded input, and returns either
	// an output value or an error.
	UnaryFunc[In, Out any] func(context.Context, *In) (*Out, error)

	// StreamingServer is what a server-streaming or bidi-streaming handler
	// writes to: an http.ResponseWriter plus the request-scoped context.
	StreamingServer interface {
		http.ResponseWriter
		Context() context.Context
	}

	// StreamFunc is the signature of a server-streaming or bidi-streaming
	// handler: it receives the decoded input and a StreamingServer to write
	// the response to, and reports failure through its error result.
	StreamFunc[In any] func(*In, StreamingServer) error
)
// Handle adapts a typed UnaryFunc into a plain net/http handler function.
//
// For every request the returned handler:
//   - overwrites the request's Content-Type header with "application/json"
//     (NOTE(review): presumably so downstream decoding never treats the body
//     as form data — confirm the intent);
//   - decodes the input with decodeReq, logging and rendering APIError on
//     failure;
//   - passes fn a context carrying the Accept-Language value (via
//     i18n.SetLang) and the *http.Request itself (under reqCtxKey);
//   - renders fn's result with APIOK, or its error with APIError;
//   - reports the start time and final error to collectHandleMetrics.
//
// The handler is wrapped by tracingHandler under the name getHandleName(fn).
func Handle[In any, Out any](fn UnaryFunc[In, Out]) func(w http.ResponseWriter, r *http.Request) {
	name := getHandleName(fn)
	serve := func(w http.ResponseWriter, r *http.Request) {
		r.Header.Set("Content-Type", "application/json")
		started := time.Now()

		// handleErr is read by the deferred metrics call, so every step below
		// assigns to it with "=" rather than declaring a new variable.
		var handleErr error
		defer func() {
			collectHandleMetrics(name, r.Method, started, handleErr)
		}()

		var in *In
		if in, handleErr = decodeReq[In](r); handleErr != nil {
			slog.Error("handle decode request failed", "err", handleErr)
			_ = render.Render(w, r, APIError(handleErr))
			return
		}

		// Attach the caller's preferred language and the raw request.
		ctx := context.WithValue(
			i18n.SetLang(r.Context(), r.Header.Get("Accept-Language")),
			reqCtxKey, r,
		)

		var out *Out
		if out, handleErr = fn(ctx, in); handleErr != nil {
			_ = render.Render(w, r, APIError(handleErr))
			return
		}
		_ = render.Render(w, r, APIOK(out))
	}
	return tracingHandler(name, serve)
}
// streamingServer is the concrete StreamingServer handed to StreamFunc
// handlers. Besides the response writer it embeds an *http.ResponseController,
// whose methods (Flush, SetWriteDeadline, ...) are promoted onto the server,
// and it carries the request-scoped context.
type streamingServer struct {
	http.ResponseWriter
	*http.ResponseController

	ctx context.Context // returned by Context
}

// Context returns the request-scoped context the server was built with.
func (srv *streamingServer) Context() context.Context {
	return srv.ctx
}
// streamWriteTracker wraps the http.ResponseWriter handed to a StreamFunc and
// records whether the handler has started sending the response (written
// headers or body, or attempted a flush). Stream uses this to decide whether
// an error can still be rendered as a normal error response.
type streamWriteTracker struct {
	http.ResponseWriter
	started bool
}

// WriteHeader marks the response as started and forwards the call.
func (t *streamWriteTracker) WriteHeader(statusCode int) {
	t.started = true
	t.ResponseWriter.WriteHeader(statusCode)
}

// Write marks the response as started and forwards the call.
func (t *streamWriteTracker) Write(p []byte) (int, error) {
	t.started = true
	return t.ResponseWriter.Write(p)
}

// FlushError is picked up by http.ResponseController.Flush. A flush commits
// the headers, so any flush attempt marks the response as started.
func (t *streamWriteTracker) FlushError() error {
	t.started = true
	return http.NewResponseController(t.ResponseWriter).Flush()
}

// Unwrap exposes the underlying writer so http.ResponseController can reach
// its Hijack / deadline / full-duplex support.
func (t *streamWriteTracker) Unwrap() http.ResponseWriter {
	return t.ResponseWriter
}

// Stream adapts a typed StreamFunc into a plain net/http handler function.
//
// For every request the returned handler:
//   - decodes the input with decodeReq, logging and rendering APIError on
//     failure;
//   - builds a StreamingServer around the response writer, whose context
//     carries the Accept-Language value (via i18n.SetLang) and the
//     *http.Request itself (under reqCtxKey);
//   - runs fn. If fn fails before writing anything, the error is rendered
//     with APIError. If fn had already started the response, the error is
//     only logged. Rendering at that point would trigger a superfluous
//     WriteHeader and splice a JSON error body into the middle of the stream;
//   - reports the start time and final error to collectHandleMetrics.
//
// The handler is wrapped by tracingHandler under the name getHandleName(fn).
func Stream[In any](fn StreamFunc[In]) func(w http.ResponseWriter, r *http.Request) {
	handleName := getHandleName(fn)
	f := func(w http.ResponseWriter, r *http.Request) {
		st := time.Now()
		// err is read by the deferred metrics call; assign it with "=" below.
		var err error
		defer func() {
			collectHandleMetrics(handleName, r.Method, st, err)
		}()
		in, err := decodeReq[In](r)
		if err != nil {
			slog.Error("handle decode request failed ", "err", err)
			_ = render.Render(w, r, APIError(err))
			return
		}
		// Set the language for the handler.
		ctx := i18n.SetLang(r.Context(), r.Header.Get("Accept-Language"))
		ctx = context.WithValue(ctx, reqCtxKey, r)
		tracker := &streamWriteTracker{ResponseWriter: w}
		svr := &streamingServer{
			ResponseWriter:     tracker,
			ResponseController: http.NewResponseController(tracker),
			ctx:                ctx,
		}
		err = fn(in, svr)
		if err == nil {
			return
		}
		if tracker.started {
			// Headers (and likely part of the stream) are already on the
			// wire; the status can no longer change, so just log.
			slog.Error("handle stream failed after response started", "err", err)
			return
		}
		_ = render.Render(w, r, APIError(err))
	}
	f = tracingHandler(handleName, f)
	return f
}
// decodeReq builds the handler input of type T from r.
//
// Special cases:
//   - T is http.Request: r itself is returned (no copy, no decoding).
//   - T is EmptyReq: a fresh zero value is returned without touching r.
//
// Otherwise the input is first populated by httpin.Decode from the
// directives declared in T's `in` struct tags. For GET and DELETE that is
// the final result. For every other method the request body is then read in
// full. If it is non-empty, it is unmarshalled as JSON on top of the
// httpin-decoded value, so JSON fields that are present override values
// already set.
//
// An empty or absent body is accepted and leaves the httpin-decoded value
// unchanged. This lets POST/PUT/HEAD requests whose input comes entirely from
// the URL succeed instead of failing with "unexpected end of JSON input".
// A JSON decoding failure is returned wrapped (%w), so callers can still use
// errors.As on e.g. *json.SyntaxError.
func decodeReq[T any](r *http.Request) (*T, error) {
	// Dispatch on the static type T. A typed nil pointer is enough for the
	// type switch and avoids allocating before we know it is needed.
	switch any((*T)(nil)).(type) {
	case *http.Request:
		// The handler wants the raw request: hand it over unchanged.
		return any(r).(*T), nil
	case *EmptyReq:
		// Nothing to decode.
		return new(T), nil
	}

	in, err := httpin.Decode[T](r)
	if err != nil {
		return nil, err
	}

	// GET/DELETE requests: parameters come from the URL only.
	if r.Method == http.MethodGet || r.Method == http.MethodDelete {
		return in, nil
	}

	// POST and other methods: additionally decode a JSON body, if any.
	if r.Body == nil {
		return in, nil
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		return nil, err
	}
	if len(body) == 0 {
		return in, nil
	}
	if err := json.Unmarshal(body, in); err != nil {
		return nil, fmt.Errorf("unmarshal json body: %w", err)
	}
	return in, nil
}
type (
	// EmptyReq is the input type for handlers that take no request
	// parameters; decodeReq returns a zero value without reading the request.
	EmptyReq struct{}

	// EmptyResp is the output type for handlers that return no payload.
	EmptyResp struct{}
)
type (
	// PaginationReq carries the common paging parameters shared by list
	// endpoints. Each field is bound both to a query parameter (httpin `in`
	// tag) and to a JSON key, and declares a `gte=0` validation rule
	// (enforced by whatever validator the caller runs — not shown here).
	PaginationReq struct {
		Offset int `json:"offset" in:"query=offset" validate:"gte=0"` // "offset" query param / JSON key
		Limit  int `json:"limit" in:"query=limit" validate:"gte=0"`   // "limit" query param / JSON key
	}

	// PaginationResp is the common response envelope for paged endpoints,
	// serialized as {"count": ..., "items": [...]}.
	PaginationResp[T any] struct {
		Count int64 `json:"count"`
		Items []*T  `json:"items"`
	}
)
// init registers chi.URLParam as the lookup for httpin's "path" directive,
// so `in:"path=..."` struct tags resolve against chi route parameters.
func init() {
	const directive = "path"
	httpin_integration.UseGochiURLParam(directive, chi.URLParam)
}