diff --git a/rest/init.go b/rest/init.go deleted file mode 100644 index 2296cf9..0000000 --- a/rest/init.go +++ /dev/null @@ -1,7 +0,0 @@ -package rest - -import "git.ifooth.com/common/pkg/logger" - -func init() { - logger.Init() -} diff --git a/rest/middleware/middleware.go b/rest/middleware/middleware.go new file mode 100644 index 0000000..5898b68 --- /dev/null +++ b/rest/middleware/middleware.go @@ -0,0 +1,18 @@ +/* + * TencentBlueKing is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +// Package middleware provides common HTTP server middleware components. +package middleware diff --git a/rest/middleware/profile.go b/rest/middleware/profile.go new file mode 100644 index 0000000..956b275 --- /dev/null +++ b/rest/middleware/profile.go @@ -0,0 +1,48 @@ +/* + * TencentBlueKing is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package middleware + +import ( + "net/http" + "net/http/pprof" + + "git.ifooth.com/common/pkg/rest" +) + +// Profiler is a convenient subrouter used for mounting net/http/pprof +func Profiler() http.Handler { + r := rest.NewRouter() + + r.Get("/{$}", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, r.RequestURI+"/pprof/", http.StatusMovedPermanently) + }) + + r.HandleFunc("/pprof/{_...}", pprof.Index) + r.HandleFunc("/pprof/cmdline", pprof.Cmdline) + r.HandleFunc("/pprof/profile", pprof.Profile) + r.HandleFunc("/pprof/symbol", pprof.Symbol) + r.HandleFunc("/pprof/trace", pprof.Trace) + + r.Handle("/pprof/goroutine", pprof.Handler("goroutine")) + r.Handle("/pprof/threadcreate", pprof.Handler("threadcreate")) + r.Handle("/pprof/mutex", pprof.Handler("mutex")) + r.Handle("/pprof/heap", pprof.Handler("heap")) + r.Handle("/pprof/block", pprof.Handler("block")) + r.Handle("/pprof/allocs", pprof.Handler("allocs")) + + return r +} diff --git a/rest/middleware/recover.go b/rest/middleware/recover.go new file mode 100644 index 0000000..b973119 --- /dev/null +++ b/rest/middleware/recover.go @@ -0,0 +1,40 @@ +/* + * TencentBlueKing is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package middleware + +import ( + "fmt" + "log/slog" + "net/http" + "runtime/debug" +) + +// Recoverer is a middleware that recovers from panics, logs the panic and stack trace +func Recoverer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if fatalErr := recover(); fatalErr != nil { + w.WriteHeader(http.StatusInternalServerError) + + msg := fmt.Sprintf("panic err: %v", fatalErr) + slog.ErrorContext(r.Context(), msg, "stack_trace", debug.Stack()) + } + }() + + next.ServeHTTP(w, r) + }) +} diff --git a/rest/middleware/wrape_writer.go b/rest/middleware/wrape_writer.go new file mode 100644 index 0000000..f7baf0f --- /dev/null +++ b/rest/middleware/wrape_writer.go @@ -0,0 +1,88 @@ +/* + * TencentBlueKing is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package middleware + +import ( + "bufio" + "net" + "net/http" +) + +// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook +// into various parts of the response process. +type WrapResponseWriter interface { + http.ResponseWriter + Flush() error + Hijack() (net.Conn, *bufio.ReadWriter, error) + Status() int + BytesWritten() int +} + +// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to +// hook into various parts of the response process. +func NewWrapResponseWriter(w http.ResponseWriter) WrapResponseWriter { + bw := basicWriter{ + ResponseWriter: w, + ResponseController: http.NewResponseController(w), + } + return &bw +} + +// basicWriter wraps a http.ResponseWriter that implements the minimal +// http.ResponseWriter interface. +type basicWriter struct { + http.ResponseWriter + *http.ResponseController + code int + bytes int +} + +// Header implement the http.ResponseWriter interface Header method +func (bw *basicWriter) Header() http.Header { + return bw.ResponseWriter.Header() +} + +// WriteHeader implement the http.ResponseWriter interface WriteHeader method +func (bw *basicWriter) WriteHeader(code int) { + if bw.code != 0 { + return + } + + bw.code = code + bw.ResponseWriter.WriteHeader(code) +} + +// Write implement the http.ResponseWriter interface Write method +func (bw *basicWriter) Write(buf []byte) (int, error) { + if bw.code == 0 { + bw.WriteHeader(http.StatusOK) + } + + n, err := bw.ResponseWriter.Write(buf) + bw.bytes += n + return n, err +} + +// Status return the status code of the response +func (bw *basicWriter) Status() int { + return bw.code +} + +// BytesWritten return the bytes written +func (bw *basicWriter) BytesWritten() int { + return bw.bytes +} diff --git a/rest/mux.go b/rest/mux.go new file mode 100644 index 0000000..3efd8e0 --- /dev/null +++ b/rest/mux.go @@ -0,0 +1,362 @@ +/* + * TencentBlueKing is pleased to support the open source community by making + * 蓝鲸智云 - 配置平台 (BlueKing - CMDB) available. + * Copyright (C) 2025 Tencent. All rights reserved. + * Licensed under the MIT License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at http://opensource.org/licenses/MIT + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * We undertake not to change the open source license (MIT license) applicable + * to the current version of the project delivered to anyone in the future. + */ + +package rest + +import ( + "bytes" + "context" + "fmt" + "net/http" + "path" + "strings" +) + +type ctxKey string + +var ( + patternCtxKey = ctxKey("mux.pattern") +) + +// Router consisting of the core routing methods using only the standard net/http handler +type Router interface { + http.Handler + + // Use appends one or more middlewares onto the Router stack. + // Note: This should be used only before any routes are added to the mux. + Use(middlewares ...func(http.Handler) http.Handler) + + // With adds inline middlewares for an endpoint handler. + With(middlewares ...func(http.Handler) http.Handler) Router + + // Group adds a new inline-Router along the current routing + // path, with a fresh middleware stack for the inline-Router. + Group(fn func(r Router)) Router + + // Route mounts a sub-Router along a `pattern` string. + Route(pattern string, fn func(r Router)) Router + + // Mount attaches another http.Handler along `pattern/` string + Mount(pattern string, h http.Handler) + + // Handle and HandleFunc adds routes for `pattern{$}` that matches all HTTP methods. + Handle(pattern string, h http.Handler) + HandleFunc(pattern string, h http.HandlerFunc) + + // HTTP-method routing along `pattern{$}` + Get(pattern string, h http.HandlerFunc) + Post(pattern string, h http.HandlerFunc) + Put(pattern string, h http.HandlerFunc) + Patch(pattern string, h http.HandlerFunc) + Delete(pattern string, h http.HandlerFunc) + + // NotFound defines a handler to respond whenever a route could not be found. + NotFound(h http.HandlerFunc) + + // MethodNotAllowed defines a handler to respond whenever a method is not allowed. + MethodNotAllowed(h http.HandlerFunc) +} + +// router is a HTTP route multiplexer that parses a request path, base the standard net/http mux +type router struct { + // The underlying mux to register the routes to + mux *http.ServeMux + // The Subrouter base path + basePath string + // The middleware stack + middlewares []func(http.Handler) http.Handler + // Custom route not found handler + notFoundHandler http.HandlerFunc + // Custom method not allowed handler + methodNotAllowedHandler http.HandlerFunc +} + +// NewRouter returns a new mux object that implements the Router interface. +func NewRouter() Router { + r := &router{ + mux: http.NewServeMux(), + middlewares: make([]func(http.Handler) http.Handler, 0), + } + return r +} + +// Use appends a middleware handler to the mux middleware stack. +func (r *router) Use(middlewares ...func(http.Handler) http.Handler) { + r.middlewares = append(r.middlewares, middlewares...) +} + +// With adds inline middlewares for an endpoint handler. +func (r *router) With(middlewares ...func(http.Handler) http.Handler) Router { + newRouter := r.clone() + newRouter.middlewares = append(newRouter.middlewares, middlewares...) + return newRouter +} + +// Group creates a new inline-mux with a copy of middleware stack. It's useful +// for a group of handlers along the same routing path that use an additional +// set of middlewares. +func (r *router) Group(fn func(r Router)) Router { + newRouter := r.clone() + if fn != nil { + fn(newRouter) + } + + return newRouter +} + +// Route creates a new mux and mounts it along the `pattern` as a subrouter. +func (r *router) Route(pattern string, fn func(r Router)) Router { + newRouter := r.clone() + newRouter.basePath = path.Join(r.basePath, pattern) + if fn != nil { + fn(newRouter) + } + + return newRouter +} + +// Mount attaches another http.Handler or mux Router as a subrouter along a routing +// path. It's very useful to split up a large API as many independent routers and +// compose them as a single service using Mount +func (r *router) Mount(pattern string, handler http.Handler) { + if pattern == "" || pattern[0] != '/' { + panic(fmt.Errorf("pattern must begin with /")) + } + + basePattern := strings.TrimRight(pattern, "/") + r.mux.Handle(basePattern+"/", r.stripMountPrefix(basePattern, r.chain(handler))) +} + +// stripMountPrefix trims the mount prefix or {var} from the url if present. +func (r *router) stripMountPrefix(pattern string, handler http.Handler) http.Handler { + if pattern == "" { + return handler + } + + // add a next / prefix + count := strings.Count(pattern, "/") + 1 + + f := func(w http.ResponseWriter, req *http.Request) { + // 忽略路径中的前缀 + skipCount := count + offset := strings.IndexFunc(req.URL.Path, func(r rune) bool { + if r == '/' { + skipCount-- + } + if skipCount == 0 { + return true + } + return false + }) + req.URL.Path = req.URL.Path[offset:] + + // the mount route pattern + rp := strings.TrimSuffix(RoutePattern(req), "/") + ctx := context.WithValue(req.Context(), patternCtxKey, rp) + req = req.WithContext(ctx) + + handler.ServeHTTP(w, req) + } + + return http.HandlerFunc(f) +} + +// Handle adds the route `pattern` that matches any http method to +// execute the `handler` http.Handler. +func (r *router) Handle(pattern string, handler http.Handler) { + r.register("", pattern, handler) +} + +// HandleFunc adds the route `pattern` that matches any http method to +// execute the `handlerFn` http.HandlerFunc. +func (r *router) HandleFunc(pattern string, handler http.HandlerFunc) { + r.register("", pattern, handler) +} + +// Get adds the route `pattern` that matches a GET http method to +// execute the `handlerFn` http.HandlerFunc. +func (r *router) Get(pattern string, handler http.HandlerFunc) { + r.register(http.MethodGet, pattern, handler) +} + +// Post adds the route `pattern` that matches a POST http method to +// execute the `handlerFn` http.HandlerFunc. +func (r *router) Post(pattern string, handler http.HandlerFunc) { + r.register(http.MethodPost, pattern, handler) +} + +// Put adds the route `pattern` that matches a Put http method to +// execute the `handlerFn` http.HandlerFunc. +func (r *router) Put(pattern string, handler http.HandlerFunc) { + r.register(http.MethodPut, pattern, handler) +} + +// Patch adds the route `pattern` that matches a Patch http method to +// execute the `handlerFn` http.HandlerFunc. +func (r *router) Patch(pattern string, handler http.HandlerFunc) { + r.register(http.MethodPatch, pattern, handler) +} + +// Delete adds the route `pattern` that matches a Delete http method to +// execute the `handlerFn` http.HandlerFunc. +func (r *router) Delete(pattern string, handler http.HandlerFunc) { + r.register(http.MethodDelete, pattern, handler) +} + +// NotFound sets a custom http.HandlerFunc for routing paths that could +// not be found. The default 404 handler is `http.NotFound`. +func (r *router) NotFound(h http.HandlerFunc) { + r.notFoundHandler = h +} + +// MethodNotAllowed sets a custom http.HandlerFunc for routing paths where the +// method is unresolved. The default handler returns a 405 with an empty body. +func (r *router) MethodNotAllowed(h http.HandlerFunc) { + r.methodNotAllowedHandler = h +} + +// ServeHTTP is the single method of the http.Handler interface that makes +// Mux interoperable with the standard library +func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // 没有自定义404和405处理函数,直接调用mux的ServeHTTP, 减少一次路由匹配 + if r.notFoundHandler == nil && r.methodNotAllowedHandler == nil { + r.mux.ServeHTTP(w, req) + return + } + + h, pattern := r.mux.Handler(req) + + // handle custom 404 and 405 route + if pattern == "" { + h = r.serveNotFound(h) + } else { + h = r.mux + } + + h.ServeHTTP(w, req) +} + +func (r *router) serveNotFound(handler http.Handler) http.Handler { + f := func(w http.ResponseWriter, req *http.Request) { + wrapper := ¬FoundWrapper{ + header: http.Header{}, + buf: bytes.NewBuffer(nil), + } + handler.ServeHTTP(wrapper, req) + + // 404 + if r.notFoundHandler != nil && wrapper.code == http.StatusNotFound { + r.notFoundHandler(w, req) + return + } + + // 405 + if r.methodNotAllowedHandler != nil && wrapper.code == http.StatusMethodNotAllowed { + w.Header().Set("Allow", wrapper.header.Get("Allow")) + r.methodNotAllowedHandler(w, req) + return + } + + wrapper.ServeHTTP(w, req) + } + + return http.HandlerFunc(f) +} + +func (r *router) register(method string, pattern string, handler http.Handler) { + if r.basePath != "" { + pattern = path.Join(r.basePath, pattern) + } + + if strings.HasSuffix(pattern, "/") { + pattern = pattern + "{$}" // 除开mount, 其他路径都精准匹配 + } + + if method != "" { + r.mux.Handle(method+" "+pattern, r.chain(handler)) + } else { + r.mux.Handle(pattern, r.chain(handler)) + } +} + +// chain builds a http.Handler composed of an inline middleware stack and endpoint +// handler in the order they are passed. +func (r *router) chain(endpoint http.Handler) http.Handler { + for i := len(r.middlewares) - 1; i >= 0; i-- { + endpoint = r.middlewares[i](endpoint) + } + + return endpoint +} + +func (r *router) clone() *router { + newMiddlewares := make([]func(http.Handler) http.Handler, len(r.middlewares)) + copy(newMiddlewares, r.middlewares) + + newRouter := &router{ + mux: r.mux, + basePath: r.basePath, + middlewares: newMiddlewares, + } + return newRouter +} + +// notFoundWrapper that implements the minimal http.ResponseWriter interface. +type notFoundWrapper struct { + code int + header http.Header + buf *bytes.Buffer +} + +// Header implement the http.ResponseWriter interface Header method +func (nf *notFoundWrapper) Header() http.Header { + return nf.header +} + +// WriteHeader implement the http.ResponseWriter interface WriteHeader method +func (nf *notFoundWrapper) WriteHeader(code int) { + nf.code = code +} + +// Write implement the http.ResponseWriter interface Write method +func (nf *notFoundWrapper) Write(buf []byte) (int, error) { + return nf.buf.Write(buf) +} + +// ServeHTTP implement the http.Handler interface ServeHTTP method +func (nf *notFoundWrapper) ServeHTTP(w http.ResponseWriter, req *http.Request) { + for k, v := range nf.header { + for i := range v { + w.Header().Add(k, v[i]) + } + } + + w.WriteHeader(nf.code) + w.Write(nf.buf.Bytes()) +} + +// RoutePattern returns the matched route pattern include mount prefix, but ignore method and host +func RoutePattern(req *http.Request) string { + pattern, _ := req.Context().Value(patternCtxKey).(string) + + // ignore method and host + i := strings.IndexByte(req.Pattern, '/') + if i < 0 { + return pattern + } + + return pattern + req.Pattern[i:] +} diff --git a/rest/mux_test.go b/rest/mux_test.go new file mode 100644 index 0000000..ca279b2 --- /dev/null +++ b/rest/mux_test.go @@ -0,0 +1,509 @@ +package rest + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testRequest(t *testing.T, ts *httptest.Server, method, urlPath string, body io.Reader) (*http.Response, string) { + req, err := http.NewRequest(method, ts.URL+urlPath, body) + if err != nil { + t.Fatal(err) + return nil, "" + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + return nil, "" + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + return nil, "" + } + defer resp.Body.Close() + + return resp, string(respBody) +} + +func TestDefaultRouter(t *testing.T) { + r := NewRouter() + // r := chi.NewRouter() + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, body, "index") + + resp, body = testRequest(t, ts, "GET", "/xxx", nil) + assert.Equal(t, resp.StatusCode, http.StatusNotFound) + assert.Equal(t, "404 page not found\n", body) + + resp, body = testRequest(t, ts, "POST", "/", nil) + assert.Equal(t, resp.StatusCode, http.StatusMethodNotAllowed) + assert.Equal(t, "GET, HEAD", resp.Header.Get("Allow")) + assert.Equal(t, "Method Not Allowed\n", body) +} + +func TestConflictsRouter(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + err, ok := r.(error) + require.True(t, ok) + assert.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with pattern") + }() + + r := NewRouter() + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + + r.Mount("/api", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + })) + + r.Get("/{bizID}/user", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() +} + +func TestInvalidURLRouter(t *testing.T) { + r := NewRouter() + + checkPanic := func(errMsg string) { + r := recover() + require.NotNil(t, r) + err, ok := r.(error) + require.True(t, ok) + assert.Error(t, err) + assert.Contains(t, err.Error(), errMsg) + } + + t.Run("empty url", func(t *testing.T) { + defer checkPanic("host/path missing /") + + r.Get("", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + }) + + t.Run("not start with /", func(t *testing.T) { + defer checkPanic("host/path missing /") + + r.Get("abc", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + }) + + t.Run("mount not start with /", func(t *testing.T) { + defer checkPanic("pattern must begin with /") + + r.Mount("api", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + })) + }) + + ts := httptest.NewServer(r) + defer ts.Close() +} + +func TestParamsRouter(t *testing.T) { + r := NewRouter() + // r := chi.NewRouter() + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + + r.Get("/{bizID}", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.PathValue("bizID"))) + }) + + r.Get("/{bizID}/module/{moduleID}", func(w http.ResponseWriter, r *http.Request) { + body := r.PathValue("bizID") + ":" + r.PathValue("moduleID") + w.Write([]byte(body)) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, body, "index") + + resp, body = testRequest(t, ts, "GET", "/xxx", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "xxx", body) + + resp, body = testRequest(t, ts, "GET", "/xxx/module/abc", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "xxx:abc", body) + + resp, body = testRequest(t, ts, "GET", "/xxx/module/abc/xxx", nil) + assert.Equal(t, resp.StatusCode, http.StatusNotFound) + assert.Equal(t, "404 page not found\n", body) +} + +func TestMountRouter(t *testing.T) { + r := NewRouter() + // r := chi.NewRouter() + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + + r.Mount("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := "mount index:" + r.URL.Path + w.Write([]byte(body)) + })) + + subR := NewRouter() + subR.Get("/user", func(w http.ResponseWriter, r *http.Request) { + body := "api index:" + r.URL.Path + w.Write([]byte(body)) + }) + r.Mount("/api", subR) + + r.Get("/biz/{bizID}", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.PathValue("bizID"))) + }) + + r.Get("/biz/{bizID}/module/{moduleID}", func(w http.ResponseWriter, r *http.Request) { + body := r.PathValue("bizID") + ":" + r.PathValue("moduleID") + w.Write([]byte(body)) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, body, "index") + + resp, body = testRequest(t, ts, "GET", "/xxx", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "mount index:/xxx", body) + + resp, body = testRequest(t, ts, "GET", "/biz/xxx/module/abc", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "xxx:abc", body) + + resp, body = testRequest(t, ts, "GET", "/biz/xxx/module/abc/xxx", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "mount index:/biz/xxx/module/abc/xxx", body) + + resp, body = testRequest(t, ts, "GET", "/api/user", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "api index:/user", body) + +} + +func TestMuxPattern(t *testing.T) { + r := NewRouter() + // r := chi.NewRouter() + + r.Route("/org/{org}", func(r Router) { + r.Get("/user/pattern", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.Pattern)) + }) + }) + + subR := NewRouter() + subR.Get("/user/{user}/pattern", func(w http.ResponseWriter, r *http.Request) { + // pattern := chi.RouteContext(r.Context()).RoutePattern() + w.Write([]byte(r.Pattern)) + }) + + subR.Get("/user/{user}/pattern/ctx", func(w http.ResponseWriter, r *http.Request) { + // pattern := chi.RouteContext(r.Context()).RoutePattern() + w.Write([]byte(RoutePattern(r))) + }) + + subR2 := NewRouter() + subR2.Mount("/pattern/{pattern}", subR) + + r.Mount("/", subR) + r.Mount("/org", subR) + r.Mount("/mount/{org}", subR) + r.Mount("/mount1/{org}", subR2) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/org/1/user/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /org/{org}/user/pattern", body) + + resp, body = testRequest(t, ts, "GET", "/mount/1/user/2/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /user/{user}/pattern", body) + + resp, body = testRequest(t, ts, "GET", "/mount/1/user/2/pattern/ctx", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "/mount/{org}/user/{user}/pattern/ctx", body) + + resp, body = testRequest(t, ts, "GET", "/mount1/1/pattern/11/user/2/pattern/ctx", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "/mount1/{org}/pattern/{pattern}/user/{user}/pattern/ctx", body) + + resp, body = testRequest(t, ts, "GET", "/user/2/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /user/{user}/pattern", body) + + resp, body = testRequest(t, ts, "GET", "/org/user/2/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /user/{user}/pattern", body) +} + +func TestRoutePattern(t *testing.T) { + + r := NewRouter() + // r := chi.NewRouter() + + r.Route("/org/{org}", func(r Router) { + r.Get("/user/pattern", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.Pattern)) + }) + }) + + subR := NewRouter() + subR.Get("/user/{user}/pattern", func(w http.ResponseWriter, r *http.Request) { + // pattern := chi.RouteContext(r.Context()).RoutePattern() + w.Write([]byte(r.Pattern)) + }) + + subR.Get("/user/{user}/pattern/ctx", func(w http.ResponseWriter, r *http.Request) { + // pattern := chi.RouteContext(r.Context()).RoutePattern() + w.Write([]byte(RoutePattern(r))) + }) + + r.Mount("/mount/{org}", subR) + r.Mount("/", subR) + + r.Mount("/org", subR) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/org/1/user/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /org/{org}/user/pattern", body) + + resp, body = testRequest(t, ts, "GET", "/mount/1/user/2/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /user/{user}/pattern", body) + + resp, body = testRequest(t, ts, "GET", "/mount/1/user/2/pattern/ctx", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "/mount/{org}/user/{user}/pattern/ctx", body) + + resp, body = testRequest(t, ts, "GET", "/user/2/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /user/{user}/pattern", body) + + resp, body = testRequest(t, ts, "GET", "/org/user/2/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /user/{user}/pattern", body) + + resp, body = testRequest(t, ts, "GET", "/org/user/2/pattern", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET /user/{user}/pattern", body) +} +func TestNotFoundRouter(t *testing.T) { + r := NewRouter() + // r := chi.NewRouter() + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + + r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("404")) + })) + + r.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("not allowed")) + })) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "index", body) + + resp, body = testRequest(t, ts, "GET", "/xxx", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "404", body) + + resp, body = testRequest(t, ts, "POST", "/", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "GET, HEAD", resp.Header.Get("Allow")) + assert.Equal(t, "not allowed", body) +} + +func TestSubRouter(t *testing.T) { + r := NewRouter() + // r := chi.NewRouter() + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + + r.Route("/api", func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("api index")) + }) + + r.Get("/user", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("api user")) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "index", body) + + resp, body = testRequest(t, ts, "GET", "/xxx", nil) + assert.Equal(t, resp.StatusCode, http.StatusNotFound) + assert.Equal(t, "404 page not found\n", body) + + resp, body = testRequest(t, ts, "GET", "/api", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "api index", body) + + resp, body = testRequest(t, ts, "GET", "/api/user", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "api user", body) +} + +func TestGroupRouter(t *testing.T) { + r := NewRouter() + // r := chi.NewRouter() + + r.Group(func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("group index")) + }) + + r.Get("/user", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("group user")) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "group index", body) + + resp, body = testRequest(t, ts, "GET", "/xxx", nil) + assert.Equal(t, resp.StatusCode, http.StatusNotFound) + assert.Equal(t, "404 page not found\n", body) + + resp, body = testRequest(t, ts, "GET", "/user", nil) + assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, "group user", body) +} + +func TestGroupConflictsRouter(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + err, ok := r.(error) + require.True(t, ok) + assert.Error(t, err) + assert.Contains(t, err.Error(), "conflicts with pattern") + }() + + r := NewRouter() + // r := chi.NewRouter() + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + + r.Group(func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("group index")) + }) + + r.Get("/user", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("group user")) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() +} + +func TestUseRouter(t *testing.T) { + r := NewRouter() + // r := chi.NewRouter() + + h := func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + } + + m := func(name string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(name + ":")) + next.ServeHTTP(w, r) + }) + } + } + + r.Use(m("m1")) + r.Use(m("m2")) + + r.With(m("m3")).Get("/user", h) + r.Get("/", h) + + r.Group(func(r Router) { + r.Use(m("m4")) + r.Get("/user2", h) + }) + + r.Route("/api", func(r Router) { + r.Use(m("m5")) + r.Get("/user3", h) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, body := testRequest(t, ts, "GET", "/", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "m1:m2:index", body) + + resp, body = testRequest(t, ts, "GET", "/user", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "m1:m2:m3:index", body) + + resp, body = testRequest(t, ts, "GET", "/api/user3", nil) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "m1:m2:m5:index", body) +}