Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
ec4a0ac231 | |||
5885b42816 | |||
216fe93a55 |
53
mux.go
53
mux.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
@@ -35,37 +36,38 @@ func (m *Mux) Use(h ...func(http.Handler) http.Handler) {
|
|||||||
if m == nil {
|
if m == nil {
|
||||||
panic("mux: func Use was called on nil")
|
panic("mux: func Use was called on nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
m.middlewares = append(m.middlewares, h...)
|
m.middlewares = append(m.middlewares, h...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GET method route
|
// GET method route
|
||||||
func (m *Mux) GET(pattern string, h http.HandlerFunc) {
|
func (m *Mux) GET(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
|
||||||
m.handle(http.MethodGet, pattern, h)
|
m.handle(http.MethodGet, pattern, h, mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HEAD method route
|
// HEAD method route
|
||||||
func (m *Mux) HEAD(pattern string, h http.HandlerFunc) {
|
func (m *Mux) HEAD(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
|
||||||
m.handle(http.MethodHead, pattern, h)
|
m.handle(http.MethodHead, pattern, h, mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// POST method route
|
// POST method route
|
||||||
func (m *Mux) POST(pattern string, h http.HandlerFunc) {
|
func (m *Mux) POST(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
|
||||||
m.handle(http.MethodPost, pattern, h)
|
m.handle(http.MethodPost, pattern, h, mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PUT method route
|
// PUT method route
|
||||||
func (m *Mux) PUT(pattern string, h http.HandlerFunc) {
|
func (m *Mux) PUT(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
|
||||||
m.handle(http.MethodPut, pattern, h)
|
m.handle(http.MethodPut, pattern, h, mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PATCH method route
|
// PATCH method route
|
||||||
func (m *Mux) PATCH(pattern string, h http.HandlerFunc) {
|
func (m *Mux) PATCH(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
|
||||||
m.handle(http.MethodPatch, pattern, h)
|
m.handle(http.MethodPatch, pattern, h, mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DELETE method route
|
// DELETE method route
|
||||||
func (m *Mux) DELETE(pattern string, h http.HandlerFunc) {
|
func (m *Mux) DELETE(pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
|
||||||
m.handle(http.MethodDelete, pattern, h)
|
m.handle(http.MethodDelete, pattern, h, mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CONNECT method route
|
// CONNECT method route
|
||||||
@@ -86,7 +88,7 @@ func (m *Mux) TRACE(pattern string, h http.HandlerFunc) {
|
|||||||
// handle registers the handler for the given pattern.
|
// handle registers the handler for the given pattern.
|
||||||
// If the given pattern conflicts, with one that is already registered, HandleFunc
|
// If the given pattern conflicts, with one that is already registered, HandleFunc
|
||||||
// panics.
|
// panics.
|
||||||
func (m *Mux) handle(method, pattern string, h http.HandlerFunc) {
|
func (m *Mux) handle(method, pattern string, h http.HandlerFunc, mw ...func(http.Handler) http.Handler) {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
panic("mux: func Handle() was called on nil")
|
panic("mux: func Handle() was called on nil")
|
||||||
}
|
}
|
||||||
@@ -100,19 +102,20 @@ func (m *Mux) handle(method, pattern string, h http.HandlerFunc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
path := fmt.Sprintf("%s %s", method, pattern)
|
path := fmt.Sprintf("%s %s", method, pattern)
|
||||||
m.mux.Handle(path, stack(m.middlewares, h))
|
if len(mw) > 0 {
|
||||||
|
m.mux.Handle(path, stack(h, copyMW(m.middlewares, mw)))
|
||||||
|
} else {
|
||||||
|
m.mux.Handle(path, stack(h, m.middlewares))
|
||||||
|
}
|
||||||
|
|
||||||
m.routes.Add(path)
|
m.routes.Add(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// With adds inline middlewares for an endpoint handler.
|
// With adds inline middlewares for an endpoint handler.
|
||||||
func (m *Mux) With(middleware ...func(http.Handler) http.Handler) *Mux {
|
func (m *Mux) With(mw ...func(http.Handler) http.Handler) *Mux {
|
||||||
mws := make([]func(http.Handler) http.Handler, len(m.middlewares))
|
|
||||||
copy(mws, m.middlewares)
|
|
||||||
mws = append(mws, middleware...)
|
|
||||||
|
|
||||||
im := &Mux{
|
im := &Mux{
|
||||||
mux: m.mux,
|
mux: m.mux,
|
||||||
middlewares: mws,
|
middlewares: copyMW(m.middlewares, mw),
|
||||||
routes: m.routes,
|
routes: m.routes,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,3 +155,13 @@ func (m *Mux) PrintRoutes(w io.Writer) {
|
|||||||
func (m *Mux) RouteList() []string {
|
func (m *Mux) RouteList() []string {
|
||||||
return m.routes.All()
|
return m.routes.All()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyMW(a []func(http.Handler) http.Handler, b []func(http.Handler) http.Handler) []func(http.Handler) http.Handler {
|
||||||
|
if len(b) > 0 {
|
||||||
|
return slices.Concat(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
mws := make([]func(http.Handler) http.Handler, len(a))
|
||||||
|
copy(mws, a)
|
||||||
|
return mws
|
||||||
|
}
|
||||||
|
19
mux_test.go
19
mux_test.go
@@ -153,16 +153,17 @@ func TestRouterGroup(t *testing.T) {
|
|||||||
func TestRouterResource(t *testing.T) {
|
func TestRouterResource(t *testing.T) {
|
||||||
r := New()
|
r := New()
|
||||||
|
|
||||||
r.Resource("/users", func(res *Resource) {
|
r.Resource("/users",
|
||||||
res.Index(func(w http.ResponseWriter, r *http.Request) {
|
func(res *Resource) {
|
||||||
fmt.Fprint(w, "All users")
|
res.Index(func(w http.ResponseWriter, r *http.Request) {
|
||||||
})
|
fmt.Fprint(w, "All users")
|
||||||
|
})
|
||||||
|
|
||||||
res.View(func(w http.ResponseWriter, r *http.Request) {
|
res.View(func(w http.ResponseWriter, r *http.Request) {
|
||||||
id := r.PathValue("id")
|
id := r.PathValue("id")
|
||||||
fmt.Fprintf(w, "User %s", id)
|
fmt.Fprintf(w, "User %s", id)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
|
||||||
|
|
||||||
ts := httptest.NewServer(r)
|
ts := httptest.NewServer(r)
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
@@ -255,7 +256,7 @@ func TestStack(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
middlewares := []func(http.Handler) http.Handler{middleware1, middleware2}
|
middlewares := []func(http.Handler) http.Handler{middleware1, middleware2}
|
||||||
stacked := stack(middlewares, handler)
|
stacked := stack(handler, middlewares)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
10
resource.go
10
resource.go
@@ -21,7 +21,7 @@ type Resource struct {
|
|||||||
// - PUT /pattern/:id update a resource
|
// - PUT /pattern/:id update a resource
|
||||||
// - PATCH /pattern/:id partial update a resource
|
// - PATCH /pattern/:id partial update a resource
|
||||||
// - DELETE /resource/:id delete a resource
|
// - DELETE /resource/:id delete a resource
|
||||||
func (m *Mux) Resource(pattern string, fn func(res *Resource)) {
|
func (m *Mux) Resource(pattern string, fn func(res *Resource), mw ...func(http.Handler) http.Handler) {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
panic("mux: Resource() called on nil")
|
panic("mux: Resource() called on nil")
|
||||||
}
|
}
|
||||||
@@ -34,14 +34,10 @@ func (m *Mux) Resource(pattern string, fn func(res *Resource)) {
|
|||||||
panic("mux: Resource() requires callback")
|
panic("mux: Resource() requires callback")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy root middlewares.
|
|
||||||
mws := make([]func(http.Handler) http.Handler, len(m.middlewares))
|
|
||||||
copy(mws, m.middlewares)
|
|
||||||
|
|
||||||
fn(&Resource{
|
fn(&Resource{
|
||||||
mux: m.mux,
|
mux: m.mux,
|
||||||
pattern: pattern,
|
pattern: pattern,
|
||||||
middlewares: mws,
|
middlewares: copyMW(m.middlewares, mw),
|
||||||
routes: m.routes,
|
routes: m.routes,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -125,7 +121,7 @@ func (res *Resource) handlerFunc(method, pattern string, h http.HandlerFunc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
path := fmt.Sprintf("%s %s", method, pattern)
|
path := fmt.Sprintf("%s %s", method, pattern)
|
||||||
res.mux.Handle(path, stack(res.middlewares, h))
|
res.mux.Handle(path, stack(h, res.middlewares))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use will register middleware(s) on Router stack.
|
// Use will register middleware(s) on Router stack.
|
||||||
|
2
stack.go
2
stack.go
@@ -3,7 +3,7 @@ package mux
|
|||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
||||||
// stack middlewares(http handler) in order they are passed (FIFO)
|
// stack middlewares(http handler) in order they are passed (FIFO)
|
||||||
func stack(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler {
|
func stack(endpoint http.Handler, middlewares []func(http.Handler) http.Handler) http.Handler {
|
||||||
// Return ahead of time if there aren't any middlewares for the chain
|
// Return ahead of time if there aren't any middlewares for the chain
|
||||||
if len(middlewares) == 0 {
|
if len(middlewares) == 0 {
|
||||||
return endpoint
|
return endpoint
|
||||||
|
Reference in New Issue
Block a user