add client

main
git 2025-12-09 21:04:52 +08:00
parent c0a27b8d1f
commit f2322befd3
Signed by: git
GPG Key ID: 3F65EFFA44207ADD
7 changed files with 1065 additions and 7 deletions

View File

@ -1,7 +0,0 @@
package rest
import "git.ifooth.com/common/pkg/logger"
func init() {
logger.Init()
}

View File

@ -0,0 +1,18 @@
/*
* TencentBlueKing is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
// Package middleware provides common HTTP server middleware components.
package middleware

View File

@ -0,0 +1,48 @@
/*
* TencentBlueKing is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package middleware
import (
"net/http"
"net/http/pprof"
"git.ifooth.com/common/pkg/rest"
)
// Profiler is a convenient subrouter used for mounting net/http/pprof
func Profiler() http.Handler {
r := rest.NewRouter()
r.Get("/{$}", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, r.RequestURI+"/pprof/", http.StatusMovedPermanently)
})
r.HandleFunc("/pprof/{_...}", pprof.Index)
r.HandleFunc("/pprof/cmdline", pprof.Cmdline)
r.HandleFunc("/pprof/profile", pprof.Profile)
r.HandleFunc("/pprof/symbol", pprof.Symbol)
r.HandleFunc("/pprof/trace", pprof.Trace)
r.Handle("/pprof/goroutine", pprof.Handler("goroutine"))
r.Handle("/pprof/threadcreate", pprof.Handler("threadcreate"))
r.Handle("/pprof/mutex", pprof.Handler("mutex"))
r.Handle("/pprof/heap", pprof.Handler("heap"))
r.Handle("/pprof/block", pprof.Handler("block"))
r.Handle("/pprof/allocs", pprof.Handler("allocs"))
return r
}

View File

@ -0,0 +1,40 @@
/*
* TencentBlueKing is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package middleware
import (
"fmt"
"log/slog"
"net/http"
"runtime/debug"
)
// Recoverer is a middleware that recovers from panics, logs the panic and stack trace
func Recoverer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if fatalErr := recover(); fatalErr != nil {
w.WriteHeader(http.StatusInternalServerError)
msg := fmt.Sprintf("panic err: %v", fatalErr)
slog.ErrorContext(r.Context(), msg, "stack_trace", debug.Stack())
}
}()
next.ServeHTTP(w, r)
})
}

View File

@ -0,0 +1,88 @@
/*
* TencentBlueKing is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package middleware
import (
"bufio"
"net"
"net/http"
)
// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook
// into various parts of the response process.
type WrapResponseWriter interface {
http.ResponseWriter
Flush() error
Hijack() (net.Conn, *bufio.ReadWriter, error)
Status() int
BytesWritten() int
}
// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to
// hook into various parts of the response process.
func NewWrapResponseWriter(w http.ResponseWriter) WrapResponseWriter {
bw := basicWriter{
ResponseWriter: w,
ResponseController: http.NewResponseController(w),
}
return &bw
}
// basicWriter wraps a http.ResponseWriter that implements the minimal
// http.ResponseWriter interface.
type basicWriter struct {
http.ResponseWriter
*http.ResponseController
code int
bytes int
}
// Header implement the http.ResponseWriter interface Header method
func (bw *basicWriter) Header() http.Header {
return bw.ResponseWriter.Header()
}
// WriteHeader implement the http.ResponseWriter interface WriteHeader method
func (bw *basicWriter) WriteHeader(code int) {
if bw.code != 0 {
return
}
bw.code = code
bw.ResponseWriter.WriteHeader(code)
}
// Write implement the http.ResponseWriter interface Write method
func (bw *basicWriter) Write(buf []byte) (int, error) {
if bw.code == 0 {
bw.WriteHeader(http.StatusOK)
}
n, err := bw.ResponseWriter.Write(buf)
bw.bytes += n
return n, err
}
// Status return the status code of the response
func (bw *basicWriter) Status() int {
return bw.code
}
// BytesWritten return the bytes written
func (bw *basicWriter) BytesWritten() int {
return bw.bytes
}

362
rest/mux.go Normal file
View File

@ -0,0 +1,362 @@
/*
* TencentBlueKing is pleased to support the open source community by making
* - (BlueKing - CMDB) available.
* Copyright (C) 2025 Tencent. All rights reserved.
* Licensed under the MIT License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://opensource.org/licenses/MIT
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
* We undertake not to change the open source license (MIT license) applicable
* to the current version of the project delivered to anyone in the future.
*/
package rest
import (
"bytes"
"context"
"fmt"
"net/http"
"path"
"strings"
)
type ctxKey string
var (
patternCtxKey = ctxKey("mux.pattern")
)
// Router consisting of the core routing methods using only the standard net/http handler
type Router interface {
http.Handler
// Use appends one or more middlewares onto the Router stack.
// Note: This should be used only before any routes are added to the mux.
Use(middlewares ...func(http.Handler) http.Handler)
// With adds inline middlewares for an endpoint handler.
With(middlewares ...func(http.Handler) http.Handler) Router
// Group adds a new inline-Router along the current routing
// path, with a fresh middleware stack for the inline-Router.
Group(fn func(r Router)) Router
// Route mounts a sub-Router along a `pattern` string.
Route(pattern string, fn func(r Router)) Router
// Mount attaches another http.Handler along `pattern/` string
Mount(pattern string, h http.Handler)
// Handle and HandleFunc adds routes for `pattern{$}` that matches all HTTP methods.
Handle(pattern string, h http.Handler)
HandleFunc(pattern string, h http.HandlerFunc)
// HTTP-method routing along `pattern{$}`
Get(pattern string, h http.HandlerFunc)
Post(pattern string, h http.HandlerFunc)
Put(pattern string, h http.HandlerFunc)
Patch(pattern string, h http.HandlerFunc)
Delete(pattern string, h http.HandlerFunc)
// NotFound defines a handler to respond whenever a route could not be found.
NotFound(h http.HandlerFunc)
// MethodNotAllowed defines a handler to respond whenever a method is not allowed.
MethodNotAllowed(h http.HandlerFunc)
}
// router is a HTTP route multiplexer that parses a request path, base the standard net/http mux
type router struct {
// The underlying mux to register the routes to
mux *http.ServeMux
// The Subrouter base path
basePath string
// The middleware stack
middlewares []func(http.Handler) http.Handler
// Custom route not found handler
notFoundHandler http.HandlerFunc
// Custom method not allowed handler
methodNotAllowedHandler http.HandlerFunc
}
// NewRouter returns a new mux object that implements the Router interface.
func NewRouter() Router {
r := &router{
mux: http.NewServeMux(),
middlewares: make([]func(http.Handler) http.Handler, 0),
}
return r
}
// Use appends a middleware handler to the mux middleware stack.
func (r *router) Use(middlewares ...func(http.Handler) http.Handler) {
r.middlewares = append(r.middlewares, middlewares...)
}
// With adds inline middlewares for an endpoint handler.
func (r *router) With(middlewares ...func(http.Handler) http.Handler) Router {
newRouter := r.clone()
newRouter.middlewares = append(newRouter.middlewares, middlewares...)
return newRouter
}
// Group creates a new inline-mux with a copy of middleware stack. It's useful
// for a group of handlers along the same routing path that use an additional
// set of middlewares.
func (r *router) Group(fn func(r Router)) Router {
newRouter := r.clone()
if fn != nil {
fn(newRouter)
}
return newRouter
}
// Route creates a new mux and mounts it along the `pattern` as a subrouter.
func (r *router) Route(pattern string, fn func(r Router)) Router {
newRouter := r.clone()
newRouter.basePath = path.Join(r.basePath, pattern)
if fn != nil {
fn(newRouter)
}
return newRouter
}
// Mount attaches another http.Handler or mux Router as a subrouter along a routing
// path. It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount
func (r *router) Mount(pattern string, handler http.Handler) {
if pattern == "" || pattern[0] != '/' {
panic(fmt.Errorf("pattern must begin with /"))
}
basePattern := strings.TrimRight(pattern, "/")
r.mux.Handle(basePattern+"/", r.stripMountPrefix(basePattern, r.chain(handler)))
}
// stripMountPrefix trims the mount prefix or {var} from the url if present.
func (r *router) stripMountPrefix(pattern string, handler http.Handler) http.Handler {
if pattern == "" {
return handler
}
// add a next / prefix
count := strings.Count(pattern, "/") + 1
f := func(w http.ResponseWriter, req *http.Request) {
// 忽略路径中的前缀
skipCount := count
offset := strings.IndexFunc(req.URL.Path, func(r rune) bool {
if r == '/' {
skipCount--
}
if skipCount == 0 {
return true
}
return false
})
req.URL.Path = req.URL.Path[offset:]
// the mount route pattern
rp := strings.TrimSuffix(RoutePattern(req), "/")
ctx := context.WithValue(req.Context(), patternCtxKey, rp)
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
}
return http.HandlerFunc(f)
}
// Handle adds the route `pattern` that matches any http method to
// execute the `handler` http.Handler.
func (r *router) Handle(pattern string, handler http.Handler) {
r.register("", pattern, handler)
}
// HandleFunc adds the route `pattern` that matches any http method to
// execute the `handlerFn` http.HandlerFunc.
func (r *router) HandleFunc(pattern string, handler http.HandlerFunc) {
r.register("", pattern, handler)
}
// Get adds the route `pattern` that matches a GET http method to
// execute the `handlerFn` http.HandlerFunc.
func (r *router) Get(pattern string, handler http.HandlerFunc) {
r.register(http.MethodGet, pattern, handler)
}
// Post adds the route `pattern` that matches a POST http method to
// execute the `handlerFn` http.HandlerFunc.
func (r *router) Post(pattern string, handler http.HandlerFunc) {
r.register(http.MethodPost, pattern, handler)
}
// Put adds the route `pattern` that matches a Put http method to
// execute the `handlerFn` http.HandlerFunc.
func (r *router) Put(pattern string, handler http.HandlerFunc) {
r.register(http.MethodPut, pattern, handler)
}
// Patch adds the route `pattern` that matches a Patch http method to
// execute the `handlerFn` http.HandlerFunc.
func (r *router) Patch(pattern string, handler http.HandlerFunc) {
r.register(http.MethodPatch, pattern, handler)
}
// Delete adds the route `pattern` that matches a Delete http method to
// execute the `handlerFn` http.HandlerFunc.
func (r *router) Delete(pattern string, handler http.HandlerFunc) {
r.register(http.MethodDelete, pattern, handler)
}
// NotFound sets a custom http.HandlerFunc for routing paths that could
// not be found. The default 404 handler is `http.NotFound`.
func (r *router) NotFound(h http.HandlerFunc) {
r.notFoundHandler = h
}
// MethodNotAllowed sets a custom http.HandlerFunc for routing paths where the
// method is unresolved. The default handler returns a 405 with an empty body.
func (r *router) MethodNotAllowed(h http.HandlerFunc) {
r.methodNotAllowedHandler = h
}
// ServeHTTP is the single method of the http.Handler interface that makes
// Mux interoperable with the standard library
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// 没有自定义404和405处理函数直接调用mux的ServeHTTP, 减少一次路由匹配
if r.notFoundHandler == nil && r.methodNotAllowedHandler == nil {
r.mux.ServeHTTP(w, req)
return
}
h, pattern := r.mux.Handler(req)
// handle custom 404 and 405 route
if pattern == "" {
h = r.serveNotFound(h)
} else {
h = r.mux
}
h.ServeHTTP(w, req)
}
func (r *router) serveNotFound(handler http.Handler) http.Handler {
f := func(w http.ResponseWriter, req *http.Request) {
wrapper := &notFoundWrapper{
header: http.Header{},
buf: bytes.NewBuffer(nil),
}
handler.ServeHTTP(wrapper, req)
// 404
if r.notFoundHandler != nil && wrapper.code == http.StatusNotFound {
r.notFoundHandler(w, req)
return
}
// 405
if r.methodNotAllowedHandler != nil && wrapper.code == http.StatusMethodNotAllowed {
w.Header().Set("Allow", wrapper.header.Get("Allow"))
r.methodNotAllowedHandler(w, req)
return
}
wrapper.ServeHTTP(w, req)
}
return http.HandlerFunc(f)
}
func (r *router) register(method string, pattern string, handler http.Handler) {
if r.basePath != "" {
pattern = path.Join(r.basePath, pattern)
}
if strings.HasSuffix(pattern, "/") {
pattern = pattern + "{$}" // 除开mount, 其他路径都精准匹配
}
if method != "" {
r.mux.Handle(method+" "+pattern, r.chain(handler))
} else {
r.mux.Handle(pattern, r.chain(handler))
}
}
// chain builds a http.Handler composed of an inline middleware stack and endpoint
// handler in the order they are passed.
func (r *router) chain(endpoint http.Handler) http.Handler {
for i := len(r.middlewares) - 1; i >= 0; i-- {
endpoint = r.middlewares[i](endpoint)
}
return endpoint
}
func (r *router) clone() *router {
newMiddlewares := make([]func(http.Handler) http.Handler, len(r.middlewares))
copy(newMiddlewares, r.middlewares)
newRouter := &router{
mux: r.mux,
basePath: r.basePath,
middlewares: newMiddlewares,
}
return newRouter
}
// notFoundWrapper that implements the minimal http.ResponseWriter interface.
type notFoundWrapper struct {
code int
header http.Header
buf *bytes.Buffer
}
// Header implement the http.ResponseWriter interface Header method
func (nf *notFoundWrapper) Header() http.Header {
return nf.header
}
// WriteHeader implement the http.ResponseWriter interface WriteHeader method
func (nf *notFoundWrapper) WriteHeader(code int) {
nf.code = code
}
// Write implement the http.ResponseWriter interface Write method
func (nf *notFoundWrapper) Write(buf []byte) (int, error) {
return nf.buf.Write(buf)
}
// ServeHTTP implement the http.Handler interface ServeHTTP method
func (nf *notFoundWrapper) ServeHTTP(w http.ResponseWriter, req *http.Request) {
for k, v := range nf.header {
for i := range v {
w.Header().Add(k, v[i])
}
}
w.WriteHeader(nf.code)
w.Write(nf.buf.Bytes())
}
// RoutePattern returns the matched route pattern include mount prefix, but ignore method and host
func RoutePattern(req *http.Request) string {
pattern, _ := req.Context().Value(patternCtxKey).(string)
// ignore method and host
i := strings.IndexByte(req.Pattern, '/')
if i < 0 {
return pattern
}
return pattern + req.Pattern[i:]
}

509
rest/mux_test.go Normal file
View File

@ -0,0 +1,509 @@
package rest
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testRequest(t *testing.T, ts *httptest.Server, method, urlPath string, body io.Reader) (*http.Response, string) {
req, err := http.NewRequest(method, ts.URL+urlPath, body)
if err != nil {
t.Fatal(err)
return nil, ""
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
return nil, ""
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
return nil, ""
}
defer resp.Body.Close()
return resp, string(respBody)
}
func TestDefaultRouter(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, body, "index")
resp, body = testRequest(t, ts, "GET", "/xxx", nil)
assert.Equal(t, resp.StatusCode, http.StatusNotFound)
assert.Equal(t, "404 page not found\n", body)
resp, body = testRequest(t, ts, "POST", "/", nil)
assert.Equal(t, resp.StatusCode, http.StatusMethodNotAllowed)
assert.Equal(t, "GET, HEAD", resp.Header.Get("Allow"))
assert.Equal(t, "Method Not Allowed\n", body)
}
func TestConflictsRouter(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
err, ok := r.(error)
require.True(t, ok)
assert.Error(t, err)
assert.Contains(t, err.Error(), "conflicts with pattern")
}()
r := NewRouter()
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
r.Mount("/api", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
}))
r.Get("/{bizID}/user", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
ts := httptest.NewServer(r)
defer ts.Close()
}
func TestInvalidURLRouter(t *testing.T) {
r := NewRouter()
checkPanic := func(errMsg string) {
r := recover()
require.NotNil(t, r)
err, ok := r.(error)
require.True(t, ok)
assert.Error(t, err)
assert.Contains(t, err.Error(), errMsg)
}
t.Run("empty url", func(t *testing.T) {
defer checkPanic("host/path missing /")
r.Get("", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
})
t.Run("not start with /", func(t *testing.T) {
defer checkPanic("host/path missing /")
r.Get("abc", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
})
t.Run("mount not start with /", func(t *testing.T) {
defer checkPanic("pattern must begin with /")
r.Mount("api", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
}))
})
ts := httptest.NewServer(r)
defer ts.Close()
}
func TestParamsRouter(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
r.Get("/{bizID}", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.PathValue("bizID")))
})
r.Get("/{bizID}/module/{moduleID}", func(w http.ResponseWriter, r *http.Request) {
body := r.PathValue("bizID") + ":" + r.PathValue("moduleID")
w.Write([]byte(body))
})
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, body, "index")
resp, body = testRequest(t, ts, "GET", "/xxx", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "xxx", body)
resp, body = testRequest(t, ts, "GET", "/xxx/module/abc", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "xxx:abc", body)
resp, body = testRequest(t, ts, "GET", "/xxx/module/abc/xxx", nil)
assert.Equal(t, resp.StatusCode, http.StatusNotFound)
assert.Equal(t, "404 page not found\n", body)
}
func TestMountRouter(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
r.Mount("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body := "mount index:" + r.URL.Path
w.Write([]byte(body))
}))
subR := NewRouter()
subR.Get("/user", func(w http.ResponseWriter, r *http.Request) {
body := "api index:" + r.URL.Path
w.Write([]byte(body))
})
r.Mount("/api", subR)
r.Get("/biz/{bizID}", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.PathValue("bizID")))
})
r.Get("/biz/{bizID}/module/{moduleID}", func(w http.ResponseWriter, r *http.Request) {
body := r.PathValue("bizID") + ":" + r.PathValue("moduleID")
w.Write([]byte(body))
})
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, body, "index")
resp, body = testRequest(t, ts, "GET", "/xxx", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "mount index:/xxx", body)
resp, body = testRequest(t, ts, "GET", "/biz/xxx/module/abc", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "xxx:abc", body)
resp, body = testRequest(t, ts, "GET", "/biz/xxx/module/abc/xxx", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "mount index:/biz/xxx/module/abc/xxx", body)
resp, body = testRequest(t, ts, "GET", "/api/user", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "api index:/user", body)
}
func TestMuxPattern(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
r.Route("/org/{org}", func(r Router) {
r.Get("/user/pattern", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.Pattern))
})
})
subR := NewRouter()
subR.Get("/user/{user}/pattern", func(w http.ResponseWriter, r *http.Request) {
// pattern := chi.RouteContext(r.Context()).RoutePattern()
w.Write([]byte(r.Pattern))
})
subR.Get("/user/{user}/pattern/ctx", func(w http.ResponseWriter, r *http.Request) {
// pattern := chi.RouteContext(r.Context()).RoutePattern()
w.Write([]byte(RoutePattern(r)))
})
subR2 := NewRouter()
subR2.Mount("/pattern/{pattern}", subR)
r.Mount("/", subR)
r.Mount("/org", subR)
r.Mount("/mount/{org}", subR)
r.Mount("/mount1/{org}", subR2)
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/org/1/user/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /org/{org}/user/pattern", body)
resp, body = testRequest(t, ts, "GET", "/mount/1/user/2/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /user/{user}/pattern", body)
resp, body = testRequest(t, ts, "GET", "/mount/1/user/2/pattern/ctx", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "/mount/{org}/user/{user}/pattern/ctx", body)
resp, body = testRequest(t, ts, "GET", "/mount1/1/pattern/11/user/2/pattern/ctx", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "/mount1/{org}/pattern/{pattern}/user/{user}/pattern/ctx", body)
resp, body = testRequest(t, ts, "GET", "/user/2/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /user/{user}/pattern", body)
resp, body = testRequest(t, ts, "GET", "/org/user/2/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /user/{user}/pattern", body)
}
func TestRoutePattern(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
r.Route("/org/{org}", func(r Router) {
r.Get("/user/pattern", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.Pattern))
})
})
subR := NewRouter()
subR.Get("/user/{user}/pattern", func(w http.ResponseWriter, r *http.Request) {
// pattern := chi.RouteContext(r.Context()).RoutePattern()
w.Write([]byte(r.Pattern))
})
subR.Get("/user/{user}/pattern/ctx", func(w http.ResponseWriter, r *http.Request) {
// pattern := chi.RouteContext(r.Context()).RoutePattern()
w.Write([]byte(RoutePattern(r)))
})
r.Mount("/mount/{org}", subR)
r.Mount("/", subR)
r.Mount("/org", subR)
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/org/1/user/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /org/{org}/user/pattern", body)
resp, body = testRequest(t, ts, "GET", "/mount/1/user/2/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /user/{user}/pattern", body)
resp, body = testRequest(t, ts, "GET", "/mount/1/user/2/pattern/ctx", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "/mount/{org}/user/{user}/pattern/ctx", body)
resp, body = testRequest(t, ts, "GET", "/user/2/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /user/{user}/pattern", body)
resp, body = testRequest(t, ts, "GET", "/org/user/2/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /user/{user}/pattern", body)
resp, body = testRequest(t, ts, "GET", "/org/user/2/pattern", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET /user/{user}/pattern", body)
}
func TestNotFoundRouter(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("404"))
}))
r.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("not allowed"))
}))
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "index", body)
resp, body = testRequest(t, ts, "GET", "/xxx", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "404", body)
resp, body = testRequest(t, ts, "POST", "/", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "GET, HEAD", resp.Header.Get("Allow"))
assert.Equal(t, "not allowed", body)
}
func TestSubRouter(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
r.Route("/api", func(r Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("api index"))
})
r.Get("/user", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("api user"))
})
})
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "index", body)
resp, body = testRequest(t, ts, "GET", "/xxx", nil)
assert.Equal(t, resp.StatusCode, http.StatusNotFound)
assert.Equal(t, "404 page not found\n", body)
resp, body = testRequest(t, ts, "GET", "/api", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "api index", body)
resp, body = testRequest(t, ts, "GET", "/api/user", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "api user", body)
}
func TestGroupRouter(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
r.Group(func(r Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("group index"))
})
r.Get("/user", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("group user"))
})
})
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "group index", body)
resp, body = testRequest(t, ts, "GET", "/xxx", nil)
assert.Equal(t, resp.StatusCode, http.StatusNotFound)
assert.Equal(t, "404 page not found\n", body)
resp, body = testRequest(t, ts, "GET", "/user", nil)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "group user", body)
}
func TestGroupConflictsRouter(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
err, ok := r.(error)
require.True(t, ok)
assert.Error(t, err)
assert.Contains(t, err.Error(), "conflicts with pattern")
}()
r := NewRouter()
// r := chi.NewRouter()
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
})
r.Group(func(r Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("group index"))
})
r.Get("/user", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("group user"))
})
})
ts := httptest.NewServer(r)
defer ts.Close()
}
func TestUseRouter(t *testing.T) {
r := NewRouter()
// r := chi.NewRouter()
h := func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("index"))
}
m := func(name string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(name + ":"))
next.ServeHTTP(w, r)
})
}
}
r.Use(m("m1"))
r.Use(m("m2"))
r.With(m("m3")).Get("/user", h)
r.Get("/", h)
r.Group(func(r Router) {
r.Use(m("m4"))
r.Get("/user2", h)
})
r.Route("/api", func(r Router) {
r.Use(m("m5"))
r.Get("/user3", h)
})
ts := httptest.NewServer(r)
defer ts.Close()
resp, body := testRequest(t, ts, "GET", "/", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "m1:m2:index", body)
resp, body = testRequest(t, ts, "GET", "/user", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "m1:m2:m3:index", body)
resp, body = testRequest(t, ts, "GET", "/api/user3", nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "m1:m2:m5:index", body)
}