pkg/rest/generic.go

143 lines
3.7 KiB
Go

/*
* Tencent is pleased to support the open source community by making
* 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
// Package rest provides a lightweight REST framework built on net/http.
package rest
import (
	"context"
	"errors"
	"log/slog"
	"net/http"
	"time"
)
// UnaryFunc is the signature of a unary (or client-streaming) handler: it is
// given a context and one decoded request, and returns one response or an
// error. Handle invokes it with the incoming HTTP request's context.
type UnaryFunc[Req, Resp any] func(ctx context.Context, req *Req) (resp *Resp, err error)
// StreamingServer is what a StreamFunc writes its (server- or bidirectional-)
// streaming response through: a regular http.ResponseWriter that also exposes
// the context the stream runs under.
type StreamingServer interface {
	// Context returns the context associated with the stream (in this
	// package's implementation, the HTTP request's context).
	Context() context.Context

	http.ResponseWriter
}
// StreamFunc is the signature of a server-streaming or bidirectional-streaming
// handler: it receives the decoded request and writes its response through
// stream, returning a non-nil error if streaming failed.
type StreamFunc[Req any] func(req *Req, stream StreamingServer) error
// Handle adapts a typed unary handler into a function usable as an
// http.HandlerFunc.
//
// For every request the returned function:
//  1. decodes the request with decodeReq,
//  2. validates it with validateReq,
//  3. calls fn with the request's context,
//  4. renders the result with APIOK, or any error with APIError.
//
// Decode and validation failures are logged before being rendered. Whatever
// the outcome, the elapsed time and final error are reported through
// collectHandleMetrics when the request finishes.
func Handle[Req, Resp any](fn UnaryFunc[Req, Resp]) func(w http.ResponseWriter, r *http.Request) {
	name := getHandleName(fn)

	return func(w http.ResponseWriter, r *http.Request) {
		// err is shared with the deferred metrics call, so every step below
		// must assign to it rather than shadow it.
		var err error
		defer func(begin time.Time) {
			collectHandleMetrics(name, r.Method, begin, err)
		}(time.Now())

		// Decode the request.
		req, err := decodeReq[Req](r)
		if err != nil {
			slog.Error("handle decode request failed", "err", err)
			_ = APIError(err).Render(w, r)
			return
		}

		// Validate the decoded request.
		if err = validateReq(r.Context(), req); err != nil {
			slog.Error("validate req failed", "err", err)
			_ = APIError(err).Render(w, r)
			return
		}

		resp, err := fn(r.Context(), req)
		if err != nil {
			_ = APIError(err).Render(w, r)
			return
		}
		_ = APIOK(resp).Render(w, r)
	}
}
// streamingServer is the concrete StreamingServer handed to StreamFunc
// handlers. Besides the raw http.ResponseWriter and an http.ResponseController
// for the same writer, it records whether the response has been committed
// (status line or body bytes sent) so Stream does not try to render an error
// response on top of a stream that has already started.
type streamingServer struct {
	http.ResponseWriter
	*http.ResponseController
	ctx context.Context

	// committed becomes true once a final status, body bytes or a flush has
	// gone through this server. It is only read after the handler returns,
	// so it needs no locking as long as the handler does not keep writing
	// from other goroutines after it has returned.
	committed bool
}

// Context returns svr's context (the HTTP request's context).
func (s *streamingServer) Context() context.Context {
	return s.ctx
}

// WriteHeader forwards to the underlying ResponseWriter and marks the response
// as committed for any final (non-informational) status code.
func (s *streamingServer) WriteHeader(statusCode int) {
	// 1xx codes other than 101 are sent as interim responses and may still be
	// followed by the real status, so they do not commit the response.
	informational := statusCode >= 100 && statusCode <= 199 &&
		statusCode != http.StatusSwitchingProtocols
	if !informational {
		s.committed = true
	}
	s.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards to the underlying ResponseWriter. The first Write implicitly
// sends a 200 status, so the response counts as committed from here on.
func (s *streamingServer) Write(p []byte) (int, error) {
	s.committed = true
	return s.ResponseWriter.Write(p)
}

// Flush forwards to the embedded ResponseController. Flushing sends the
// headers, so the response counts as committed unless flushing is simply not
// supported by the underlying writer (in which case nothing was sent).
func (s *streamingServer) Flush() error {
	err := s.ResponseController.Flush()
	if !errors.Is(err, http.ErrNotSupported) {
		s.committed = true
	}
	return err
}

// Stream adapts a typed streaming handler into a function usable as an
// http.HandlerFunc.
//
// The request is decoded and validated the same way Handle does; decode and
// validation failures are logged and rendered with APIError. On success fn is
// given a StreamingServer through which it writes the response itself.
//
// If fn returns an error before anything has been sent, the error is rendered
// with APIError. Once the response has been committed a second response can no
// longer be written (it would only produce a superfluous WriteHeader and
// corrupt the stream), so the error is logged instead. Timing and the final
// error are always reported through collectHandleMetrics.
func Stream[Req any](fn StreamFunc[Req]) func(w http.ResponseWriter, r *http.Request) {
	handleName := getHandleName(fn)
	f := func(w http.ResponseWriter, r *http.Request) {
		st := time.Now()
		var err error
		defer func() {
			collectHandleMetrics(handleName, r.Method, st, err)
		}()
		// Decode the request.
		in, err := decodeReq[Req](r)
		if err != nil {
			slog.Error("handle decode stream request failed", "err", err)
			_ = APIError(err).Render(w, r)
			return
		}
		// Validate the decoded request.
		if err = validateReq(r.Context(), in); err != nil {
			slog.Error("validate stream req failed", "err", err)
			_ = APIError(err).Render(w, r)
			return
		}
		svr := &streamingServer{
			ResponseWriter:     w,
			ResponseController: http.NewResponseController(w),
			ctx:                r.Context(),
		}
		err = fn(in, svr)
		if err == nil {
			return
		}
		if svr.committed {
			// Headers (and possibly part of the body) are already on the wire;
			// rendering an error envelope now would garble the stream.
			slog.Error("stream handler failed after response started", "err", err)
			return
		}
		_ = APIError(err).Render(w, r)
	}
	return f
}
type (
	// EmptyReq is the request type for endpoints that take no input.
	EmptyReq struct{}

	// EmptyResp is the response type for endpoints that return no payload.
	EmptyResp struct{}
)
// PaginationReq is the shared request shape for paginated list endpoints.
//
// Per its struct tags, each field maps to the JSON key and query parameter of
// the same name, and must be >= 0 when validated. Note that Limit == 0 passes
// validation; how a zero limit is interpreted is up to the handler.
type PaginationReq struct {
	// Offset is the pagination offset (tag: json/query "offset").
	Offset int `json:"offset" in:"query=offset" validate:"gte=0"`
	// Limit is the requested page size (tag: json/query "limit").
	Limit int `json:"limit" in:"query=limit" validate:"gte=0"`
}
// PaginationResp is the shared response envelope for paginated list
// endpoints, serialized as {"count": ..., "items": [...]}.
type PaginationResp[T any] struct {
	// Count is the item count reported to the client; presumably the total
	// number of matches rather than len(Items) — confirm with the handlers
	// that fill it.
	Count int64 `json:"count"`
	// Items holds the returned page. A nil slice is encoded as JSON null.
	Items []T `json:"items"`
}