From ec4a0ac231e7e081fc7e2255a108ae1ab761885e Mon Sep 17 00:00:00 2001 From: Ankit Patial Date: Sat, 16 Aug 2025 19:25:00 +0530 Subject: [PATCH] middleware stacking bug fix --- mux.go | 25 +++++++++++++++---------- mux_test.go | 2 +- resource.go | 13 ++----------- stack.go | 2 +- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/mux.go b/mux.go index 13a9187..1b66beb 100644 --- a/mux.go +++ b/mux.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "slices" "strings" "sync/atomic" ) @@ -102,25 +103,19 @@ func (m *Mux) handle(method, pattern string, h http.HandlerFunc, mw ...func(http path := fmt.Sprintf("%s %s", method, pattern) if len(mw) > 0 { - mws := make([]func(http.Handler) http.Handler, len(m.middlewares)+len(mw)) - copy(mws, m.middlewares) - mws = append(mws, mw...) + m.mux.Handle(path, stack(h, copyMW(m.middlewares, mw))) } else { - m.mux.Handle(path, stack(m.middlewares, h)) + m.mux.Handle(path, stack(h, m.middlewares)) } m.routes.Add(path) } // With adds inline middlewares for an endpoint handler. -func (m *Mux) With(middleware ...func(http.Handler) http.Handler) *Mux { - mws := make([]func(http.Handler) http.Handler, len(m.middlewares)+len(middleware)) - copy(mws, m.middlewares) - mws = append(mws, middleware...) - +func (m *Mux) With(mw ...func(http.Handler) http.Handler) *Mux { im := &Mux{ mux: m.mux, - middlewares: mws, + middlewares: copyMW(m.middlewares, mw), routes: m.routes, } @@ -160,3 +155,13 @@ func (m *Mux) PrintRoutes(w io.Writer) { func (m *Mux) RouteList() []string { return m.routes.All() } + +func copyMW(a []func(http.Handler) http.Handler, b []func(http.Handler) http.Handler) []func(http.Handler) http.Handler { + if len(b) > 0 { + return slices.Concat(a, b) + } + + mws := make([]func(http.Handler) http.Handler, len(a)) + copy(mws, a) + return mws +} diff --git a/mux_test.go b/mux_test.go index 78ebb34..14fb071 100644 --- a/mux_test.go +++ b/mux_test.go @@ -256,7 +256,7 @@ func TestStack(t *testing.T) { }) middlewares := []func(http.Handler) http.Handler{middleware1, middleware2} - stacked := stack(middlewares, handler) + stacked := stack(handler, middlewares) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) diff --git a/resource.go b/resource.go index 7b3e655..ae718fa 100644 --- a/resource.go +++ b/resource.go @@ -34,19 +34,10 @@ func (m *Mux) Resource(pattern string, fn func(res *Resource), mw ...func(http.H panic("mux: Resource() requires callback") } - // Copy root middlewares. - mws := make([]func(http.Handler) http.Handler, len(m.middlewares)+len(mw)) - copy(mws, m.middlewares) - - // Append inline middlewares. - if len(mw) > 0 { - mws = append(mws, mw...) - } - fn(&Resource{ mux: m.mux, pattern: pattern, - middlewares: mws, + middlewares: copyMW(m.middlewares, mw), routes: m.routes, }) } @@ -130,7 +121,7 @@ func (res *Resource) handlerFunc(method, pattern string, h http.HandlerFunc) { } path := fmt.Sprintf("%s %s", method, pattern) - res.mux.Handle(path, stack(res.middlewares, h)) + res.mux.Handle(path, stack(h, res.middlewares)) } // Use will register middleware(s) on Router stack. diff --git a/stack.go b/stack.go index 633e78f..9eac189 100644 --- a/stack.go +++ b/stack.go @@ -3,7 +3,7 @@ package mux import "net/http" // stack middlewares(http handler) in order they are passed (FIFO) -func stack(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { +func stack(endpoint http.Handler, middlewares []func(http.Handler) http.Handler) http.Handler { // Return ahead of time if there aren't any middlewares for the chain if len(middlewares) == 0 { return endpoint