4 Commits
v0.5.0 ... main

Author SHA1 Message Date
9d0ab3c0f2 resource extra handler 2025-08-17 20:04:48 +05:30
ec4a0ac231 middleware stacking bug fix 2025-08-16 19:25:00 +05:30
5885b42816 copy mw len fix 2025-08-16 18:20:24 +05:30
216fe93a55 inline middlewares 2025-08-16 14:43:41 +05:30
4 changed files with 79 additions and 41 deletions

53
mux.go
View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"slices"
"strings" "strings"
"sync/atomic" "sync/atomic"
) )
@@ -35,37 +36,38 @@ func (m *Mux) Use(h ...func(http.Handler) http.Handler) {
if m == nil { if m == nil {
panic("mux: func Use was called on nil") panic("mux: func Use was called on nil")
} }
m.middlewares = append(m.middlewares, h...) m.middlewares = append(m.middlewares, h...)
} }
// GET method route // GET method route
func (m *Mux) GET(pattern string, h http.HandlerFunc) { func (m *Mux) GET(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
m.handle(http.MethodGet, pattern, h) m.handle(http.MethodGet, pattern, h, mw...)
} }
// HEAD method route // HEAD method route
func (m *Mux) HEAD(pattern string, h http.HandlerFunc) { func (m *Mux) HEAD(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
m.handle(http.MethodHead, pattern, h) m.handle(http.MethodHead, pattern, h, mw...)
} }
// POST method route // POST method route
func (m *Mux) POST(pattern string, h http.HandlerFunc) { func (m *Mux) POST(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
m.handle(http.MethodPost, pattern, h) m.handle(http.MethodPost, pattern, h, mw...)
} }
// PUT method route // PUT method route
func (m *Mux) PUT(pattern string, h http.HandlerFunc) { func (m *Mux) PUT(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
m.handle(http.MethodPut, pattern, h) m.handle(http.MethodPut, pattern, h, mw...)
} }
// PATCH method route // PATCH method route
func (m *Mux) PATCH(pattern string, h http.HandlerFunc) { func (m *Mux) PATCH(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
m.handle(http.MethodPatch, pattern, h) m.handle(http.MethodPatch, pattern, h, mw...)
} }
// DELETE method route // DELETE method route
func (m *Mux) DELETE(pattern string, h http.HandlerFunc) { func (m *Mux) DELETE(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
m.handle(http.MethodDelete, pattern, h) m.handle(http.MethodDelete, pattern, h, mw...)
} }
// CONNECT method route // CONNECT method route
@@ -86,7 +88,7 @@ func (m *Mux) TRACE(pattern string, h http.HandlerFunc) {
// handle registers the handler for the given pattern. // handle registers the handler for the given pattern.
// If the given pattern conflicts, with one that is already registered, HandleFunc // If the given pattern conflicts, with one that is already registered, HandleFunc
// panics. // panics.
func (m *Mux) handle(method, pattern string, h http.HandlerFunc) { func (m *Mux) handle(method, pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
if m == nil { if m == nil {
panic("mux: func Handle() was called on nil") panic("mux: func Handle() was called on nil")
} }
@@ -100,19 +102,20 @@ func (m *Mux) handle(method, pattern string, h http.HandlerFunc) {
} }
path := fmt.Sprintf("%s %s", method, pattern) path := fmt.Sprintf("%s %s", method, pattern)
m.mux.Handle(path, stack(m.middlewares, h)) if len(mw) > 0 {
m.mux.Handle(path, stack(h, copyMW(m.middlewares, mw)))
} else {
m.mux.Handle(path, stack(h, m.middlewares))
}
m.routes.Add(path) m.routes.Add(path)
} }
// With adds inline middlewares for an endpoint handler. // With adds inline middlewares for an endpoint handler.
func (m *Mux) With(middleware ...func(http.Handler) http.Handler) *Mux { func (m *Mux) With(mw ...func(http.Handler) http.Handler) *Mux {
mws := make([]func(http.Handler) http.Handler, len(m.middlewares))
copy(mws, m.middlewares)
mws = append(mws, middleware...)
im := &Mux{ im := &Mux{
mux: m.mux, mux: m.mux,
middlewares: mws, middlewares: copyMW(m.middlewares, mw),
routes: m.routes, routes: m.routes,
} }
@@ -152,3 +155,13 @@ func (m *Mux) PrintRoutes(w io.Writer) {
func (m *Mux) RouteList() []string { func (m *Mux) RouteList() []string {
return m.routes.All() return m.routes.All()
} }
func copyMW(a []func(http.Handler) http.Handler, b []func(http.Handler) http.Handler) []func(http.Handler) http.Handler {
if len(b) > 0 {
return slices.Concat(a, b)
}
mws := make([]func(http.Handler) http.Handler, len(a))
copy(mws, a)
return mws
}

View File

@@ -153,7 +153,8 @@ func TestRouterGroup(t *testing.T) {
func TestRouterResource(t *testing.T) { func TestRouterResource(t *testing.T) {
r := New() r := New()
r.Resource("/users", func(res *Resource) { r.Resource("/users",
func(res *Resource) {
res.Index(func(w http.ResponseWriter, r *http.Request) { res.Index(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "All users") fmt.Fprint(w, "All users")
}) })
@@ -255,7 +256,7 @@ func TestStack(t *testing.T) {
}) })
middlewares := []func(http.Handler) http.Handler{middleware1, middleware2} middlewares := []func(http.Handler) http.Handler{middleware1, middleware2}
stacked := stack(middlewares, handler) stacked := stack(handler, middlewares)
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil) r := httptest.NewRequest(http.MethodGet, "/", nil)

View File

@@ -21,7 +21,7 @@ type Resource struct {
// - PUT /pattern/:id update a resource // - PUT /pattern/:id update a resource
// - PATCH /pattern/:id partial update a resource // - PATCH /pattern/:id partial update a resource
// - DELETE /resource/:id delete a resource // - DELETE /resource/:id delete a resource
func (m *Mux) Resource(pattern string, fn func(res *Resource)) { func (m *Mux) Resource(pattern string, fn func(res *Resource), mw ...func(http.Handler) http.Handler) {
if m == nil { if m == nil {
panic("mux: Resource() called on nil") panic("mux: Resource() called on nil")
} }
@@ -34,14 +34,10 @@ func (m *Mux) Resource(pattern string, fn func(res *Resource)) {
panic("mux: Resource() requires callback") panic("mux: Resource() requires callback")
} }
// Copy root middlewares.
mws := make([]func(http.Handler) http.Handler, len(m.middlewares))
copy(mws, m.middlewares)
fn(&Resource{ fn(&Resource{
mux: m.mux, mux: m.mux,
pattern: pattern, pattern: pattern,
middlewares: mws, middlewares: copyMW(m.middlewares, mw),
routes: m.routes, routes: m.routes,
}) })
} }
@@ -106,10 +102,38 @@ func (res *Resource) Delete(h http.HandlerFunc) {
res.handlerFunc(http.MethodDelete, p, h) res.handlerFunc(http.MethodDelete, p, h)
} }
func (res *Resource) Handle(pattern string, h http.HandlerFunc) { // HandleGET on /group-pattern/:id/pattern
p := suffixIt(res.pattern, "{id}") func (res *Resource) HandleGET(pattern string, h http.HandlerFunc) {
res.routes.Add(http.MethodDelete + " " + p) res.handle(http.MethodGet, pattern, h)
res.handlerFunc(http.MethodDelete, p, h) }
// HandlePOST on /group-pattern/:id/pattern
func (res *Resource) HandlePOST(pattern string, h http.HandlerFunc) {
res.handle(http.MethodPost, pattern, h)
}
// HandlePUT on /group-pattern/:id/pattern
func (res *Resource) HandlePUT(pattern string, h http.HandlerFunc) {
res.handle(http.MethodPut, pattern, h)
}
// HandlePATCH on /group-pattern/:id/pattern
func (res *Resource) HandlePATCH(pattern string, h http.HandlerFunc) {
res.handle(http.MethodPatch, pattern, h)
}
// HandleDELETE on /group-pattern/:id/pattern
func (res *Resource) HandleDELETE(pattern string, h http.HandlerFunc) {
res.handle(http.MethodDelete, pattern, h)
}
func (res *Resource) handle(method string, pattern string, h http.HandlerFunc) {
if !strings.HasPrefix(pattern, "/") {
pattern = "/" + pattern
}
p := suffixIt(res.pattern, "{id}"+pattern)
res.routes.Add(method + " " + p)
res.handlerFunc(method, p, h)
} }
// handlerFunc registers the handler function for the given pattern. // handlerFunc registers the handler function for the given pattern.
@@ -125,7 +149,7 @@ func (res *Resource) handlerFunc(method, pattern string, h http.HandlerFunc) {
} }
path := fmt.Sprintf("%s %s", method, pattern) path := fmt.Sprintf("%s %s", method, pattern)
res.mux.Handle(path, stack(res.middlewares, h)) res.mux.Handle(path, stack(h, res.middlewares))
} }
// Use will register middleware(s) on Router stack. // Use will register middleware(s) on Router stack.

View File

@@ -3,7 +3,7 @@ package mux
import "net/http" import "net/http"
// stack middlewares(http handler) in order they are passed (FIFO) // stack middlewares(http handler) in order they are passed (FIFO)
func stack(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { func stack(endpoint http.Handler, middlewares []func(http.Handler) http.Handler) http.Handler {
// Return ahead of time if there aren't any middlewares for the chain // Return ahead of time if there aren't any middlewares for the chain
if len(middlewares) == 0 { if len(middlewares) == 0 {
return endpoint return endpoint