From 855b82e9df65c3fae3d3ea16ecb12c3f14be2308 Mon Sep 17 00:00:00 2001 From: Ankit Patial Date: Sat, 16 Aug 2025 11:19:45 +0530 Subject: [PATCH] Split code in respective files. Resource method name change. Route list func --- example/main.go | 59 ++++++---- middleware/helmet_test.go | 2 +- mux.go | 154 ++++++++++++++++++++++++++ router_test.go => mux_test.go | 74 +++++++------ resource.go | 200 +++++++++++++++++++++------------- route.go | 51 +++++++++ router.go | 177 ------------------------------ router_serve.go | 58 ---------- serve.go | 78 +++++++++++++ stack.go | 19 ++++ 10 files changed, 506 insertions(+), 366 deletions(-) create mode 100644 mux.go rename router_test.go => mux_test.go (82%) create mode 100644 route.go delete mode 100644 router.go delete mode 100644 router_serve.go create mode 100644 serve.go create mode 100644 stack.go diff --git a/example/main.go b/example/main.go index 50025c5..1318534 100644 --- a/example/main.go +++ b/example/main.go @@ -10,8 +10,8 @@ import ( func main() { // create a new router - r := mux.NewRouter() - r.Use(middleware.CORS(middleware.CORSOption{ + m := mux.New() + m.Use(middleware.CORS(middleware.CORSOption{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-AccessToken", "X-Real-IP"}, @@ -20,7 +20,7 @@ func main() { MaxAge: 300, })) - r.Use(middleware.Helmet(middleware.HelmetOption{ + m.Use(middleware.Helmet(middleware.HelmetOption{ StrictTransportSecurity: &middleware.TransportSecurity{ MaxAge: 31536000, IncludeSubDomains: true, @@ -36,10 +36,10 @@ func main() { // - https://github.com/go-chi/chi/tree/master/middleware // add some root level middlewares, these will apply to all routes after it - r.Use(middleware1, middleware2) + m.Use(middleware1, middleware2) // let's add a route - r.GET("/hello", func(w http.ResponseWriter, r *http.Request) { + m.GET("/hello", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("i am route /hello")) }) // r.Post(pattern string, h http.HandlerFunc) @@ -47,34 +47,45 @@ func main() { // ... // you can inline middleware(s) to a route - r. + m. With(mwInline). GET("/hello-2", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("i am route /hello-2 with my own middleware")) }) // define a resource - r.Resource("/photos", func(resource *mux.Resource) { - // Rails style resource routes - // GET /photos - // GET /photos/new - // POST /photos - // GET /photos/:id - // GET /photos/:id/edit - // PUT /photos/:id - // PATCH /photos/:id - // DELETE /photos/:id - resource.Index(func(w http.ResponseWriter, r *http.Request) { + m.Resource("/photos", func(res *mux.Resource) { + res.Index(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("all photos")) }) - resource.New(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("upload a new pohoto")) + res.CreateView(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("new photo view")) + }) + + res.Create(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("new photo")) + }) + + res.View(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("view photo detail")) + }) + + res.Update(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("update photos")) + }) + + res.UpdatePartial(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("update few of photo fields")) + }) + + res.Delete(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("removed a phot")) }) }) // create a group of few routes with their own middlewares - r.Group(func(grp *mux.Router) { + m.Group(func(grp *mux.Mux) { grp.Use(mwGroup) grp.GET("/group", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("i am route /group")) @@ -82,12 +93,16 @@ func main() { }) // catches all - r.GET("/", func(w http.ResponseWriter, r *http.Request) { + m.GET("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello there")) }) + m.GET("/routes", func(w http.ResponseWriter, r *http.Request) { + m.PrintRoutes(w) + }) + // Serve allows graceful shutdown, you can use it - r.Serve(func(srv *http.Server) error { + m.Serve(func(srv *http.Server) error { srv.Addr = ":3001" // srv.ReadTimeout = time.Minute // srv.WriteTimeout = time.Minute diff --git a/middleware/helmet_test.go b/middleware/helmet_test.go index 940a175..822261a 100644 --- a/middleware/helmet_test.go +++ b/middleware/helmet_test.go @@ -10,7 +10,7 @@ import ( ) func TestHelmet(t *testing.T) { - r := mux.NewRouter() + r := mux.New() r.Use(Helmet(HelmetOption{})) r.GET("/hello", func(writer http.ResponseWriter, request *http.Request) { _, _ = writer.Write([]byte("hello there")) diff --git a/mux.go b/mux.go new file mode 100644 index 0000000..db397de --- /dev/null +++ b/mux.go @@ -0,0 +1,154 @@ +package mux + +import ( + "fmt" + "io" + "net/http" + "strings" + "sync/atomic" +) + +// Mux is a wrapper around the go's standard http.ServeMux. +// It's a lean wrapper with methods to make routing easier +type Mux struct { + mux *http.ServeMux + middlewares []func(http.Handler) http.Handler + routes *RouteList + IsShuttingDown atomic.Bool +} + +func New() *Mux { + m := &Mux{ + mux: http.NewServeMux(), + routes: new(RouteList), + } + return m +} + +// HttpServeMux DO NOT USE it for routing, exposed only for edge cases. +func (m *Mux) HttpServeMux() *http.ServeMux { + return m.mux +} + +// Use will register middleware(s) with router stack +func (m *Mux) Use(h ...func(http.Handler) http.Handler) { + if m == nil { + panic("mux: func Use was called on nil") + } + m.middlewares = append(m.middlewares, h...) +} + +// GET method route +func (m *Mux) GET(pattern string, h http.HandlerFunc) { + m.handle(http.MethodGet, pattern, h) +} + +// HEAD method route +func (m *Mux) HEAD(pattern string, h http.HandlerFunc) { + m.handle(http.MethodHead, pattern, h) +} + +// POST method route +func (m *Mux) POST(pattern string, h http.HandlerFunc) { + m.handle(http.MethodPost, pattern, h) +} + +// PUT method route +func (m *Mux) PUT(pattern string, h http.HandlerFunc) { + m.handle(http.MethodPut, pattern, h) +} + +// PATCH method route +func (m *Mux) PATCH(pattern string, h http.HandlerFunc) { + m.handle(http.MethodPatch, pattern, h) +} + +// DELETE method route +func (m *Mux) DELETE(pattern string, h http.HandlerFunc) { + m.handle(http.MethodDelete, pattern, h) +} + +// CONNECT method route +func (m *Mux) CONNECT(pattern string, h http.HandlerFunc) { + m.handle(http.MethodConnect, pattern, h) +} + +// OPTIONS method route +func (m *Mux) OPTIONS(pattern string, h http.HandlerFunc) { + m.handle(http.MethodOptions, pattern, h) +} + +// TRACE method route +func (m *Mux) TRACE(pattern string, h http.HandlerFunc) { + m.handle(http.MethodTrace, pattern, h) +} + +// handle registers the handler for the given pattern. +// If the given pattern conflicts, with one that is already registered, HandleFunc +// panics. +func (m *Mux) handle(method, pattern string, h http.HandlerFunc) { + if m == nil { + panic("mux: func Handle() was called on nil") + } + + if strings.TrimSpace(pattern) == "" { + panic("mux: pattern cannot be empty") + } + + if !strings.HasPrefix(pattern, "/") { + pattern = "/" + pattern + } + + path := fmt.Sprintf("%s %s", method, pattern) + m.mux.Handle(path, stack(m.middlewares, h)) + m.routes.Add(path) +} + +// With adds inline middlewares for an endpoint handler. +func (m *Mux) With(middleware ...func(http.Handler) http.Handler) *Mux { + mws := make([]func(http.Handler) http.Handler, len(m.middlewares)) + copy(mws, m.middlewares) + mws = append(mws, middleware...) + + im := &Mux{ + mux: m.mux, + middlewares: mws, + routes: m.routes, + } + + return im +} + +// Group adds a new inline-Router along the current routing +// path, with a fresh middleware stack for the inline-Router. +func (m *Mux) Group(fn func(grp *Mux)) { + if m == nil { + panic("mux: Group() called on nil") + } + + if fn == nil { + panic("mux: Group() requires callback") + } + + im := m.With() + fn(im) +} + +func (m *Mux) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if m == nil { + panic("mux: method ServeHTTP called on nil") + } + + m.mux.ServeHTTP(w, req) +} + +func (m *Mux) PrintRoutes(w io.Writer) { + for _, route := range m.routes.All() { + w.Write([]byte(route)) + w.Write([]byte("\n")) + } +} + +func (m *Mux) RouteList() []string { + return m.routes.All() +} diff --git a/router_test.go b/mux_test.go similarity index 82% rename from router_test.go rename to mux_test.go index 95631a8..2808d43 100644 --- a/router_test.go +++ b/mux_test.go @@ -3,19 +3,22 @@ package mux import ( "fmt" "io" + "math/rand" "net/http" "net/http/httptest" + "strconv" "strings" "testing" + "time" ) func TestRouterGET(t *testing.T) { - r := NewRouter() - r.GET("/test", func(w http.ResponseWriter, r *http.Request) { + m := New() + m.GET("/test", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "GET test") }) - ts := httptest.NewServer(r) + ts := httptest.NewServer(m) defer ts.Close() resp, err := http.Get(ts.URL + "/test") @@ -40,12 +43,12 @@ func TestRouterGET(t *testing.T) { } func TestRouterPOST(t *testing.T) { - r := NewRouter() - r.POST("/test", func(w http.ResponseWriter, r *http.Request) { + m := New() + m.POST("/test", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "POST test") }) - ts := httptest.NewServer(r) + ts := httptest.NewServer(m) defer ts.Close() resp, err := http.Post(ts.URL+"/test", "text/plain", strings.NewReader("test data")) @@ -70,7 +73,7 @@ func TestRouterPOST(t *testing.T) { } func TestRouterWith(t *testing.T) { - r := NewRouter() + m := New() middleware := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -79,11 +82,11 @@ func TestRouterWith(t *testing.T) { }) } - r.With(middleware).GET("/test", func(w http.ResponseWriter, r *http.Request) { + m.With(middleware).GET("/test", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "GET with middleware") }) - ts := httptest.NewServer(r) + ts := httptest.NewServer(m) defer ts.Close() resp, err := http.Get(ts.URL + "/test") @@ -108,11 +111,11 @@ func TestRouterWith(t *testing.T) { } func TestRouterGroup(t *testing.T) { - r := NewRouter() + r := New() var groupCalled bool - r.Group(func(g *Router) { + r.Group(func(g *Mux) { groupCalled = true g.GET("/group", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Group route") @@ -148,14 +151,14 @@ func TestRouterGroup(t *testing.T) { } func TestRouterResource(t *testing.T) { - r := NewRouter() + r := New() - r.Resource("/users", func(resource *Resource) { - resource.Index(func(w http.ResponseWriter, r *http.Request) { + r.Resource("/users", func(res *Resource) { + res.Index(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "All users") }) - resource.Show(func(w http.ResponseWriter, r *http.Request) { + res.View(func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") fmt.Fprintf(w, "User %s", id) }) @@ -200,7 +203,7 @@ func TestRouterResource(t *testing.T) { } func TestGetParams(t *testing.T) { - r := NewRouter() + r := New() r.GET("/users/{id}/posts/{post_id}", func(w http.ResponseWriter, r *http.Request) { userId := r.PathValue("id") @@ -285,28 +288,36 @@ func TestRouterPanic(t *testing.T) { } }() - var r *Router + var r *Mux r.GET("/", func(w http.ResponseWriter, r *http.Request) {}) } +// BenchmarkRouterSimple-12 1125854 1058 ns/op 1568 B/op 17 allocs/op func BenchmarkRouterSimple(b *testing.B) { - r := NewRouter() + m := New() - r.GET("/", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "Hello") - }) + for i := range 10000 { + m.GET("/"+strconv.Itoa(i), func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello from "+strconv.Itoa(i)) + }) + } - req, _ := http.NewRequest(http.MethodGet, "/", nil) - w := httptest.NewRecorder() + source := rand.NewSource(time.Now().UnixNano()) + r := rand.New(source) - b.ResetTimer() - for i := 0; i < b.N; i++ { - r.ServeHTTP(w, req) + // Generate a random integer between 0 and 99 (inclusive) + rn := r.Intn(10000) + + for b.Loop() { + req, _ := http.NewRequest(http.MethodGet, "/"+strconv.Itoa(rn), nil) + w := httptest.NewRecorder() + m.ServeHTTP(w, req) } } +// BenchmarkRouterWithMiddleware-12 14761327 68.70 ns/op 18 B/op 0 allocs/op func BenchmarkRouterWithMiddleware(b *testing.B) { - r := NewRouter() + m := New() middleware := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -314,17 +325,16 @@ func BenchmarkRouterWithMiddleware(b *testing.B) { }) } - r.Use(middleware, middleware) + m.Use(middleware, middleware) - r.GET("/", func(w http.ResponseWriter, r *http.Request) { + m.GET("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello") }) req, _ := http.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() - b.ResetTimer() - for i := 0; i < b.N; i++ { - r.ServeHTTP(w, req) + for b.Loop() { + m.ServeHTTP(w, req) } } diff --git a/resource.go b/resource.go index 11267d3..fdb16b6 100644 --- a/resource.go +++ b/resource.go @@ -6,28 +6,134 @@ import ( "strings" ) -// Resource is a resourceful route provides a mapping between HTTP verbs and URLs and controller actions. -// By convention, each action also maps to particular CRUD operations in a database. -// A single entry in the routing file, such as -// Index route -// -// GET /resource-name # index route -// -// GET /resource-name/new # create resource page -// -// POST /resource-name # create resource post -// -// GET /resource-name/:id # view resource -// -// GET /resource-name/:id/edit # edit resource -// -// PUT /resource-name/:id # update resource -// -// DELETE /resource-name/:id # delete resource type Resource struct { mux *http.ServeMux pattern string middlewares []func(http.Handler) http.Handler + routes *RouteList +} + +// Resource routes mapping by using HTTP verbs +// - GET /pattern view all resources +// - GET /pattern/create new resource view +// - POST /pattern create a new resource +// - GET /pattern/:id view a resource +// - PUT /pattern/:id update a resource +// - PATCH /pattern/:id partial update a resource +// - DELETE /resource/:id delete a resource +func (m *Mux) Resource(pattern string, fn func(res *Resource)) { + if m == nil { + panic("mux: Resource() called on nil") + } + + if strings.TrimSpace(pattern) == "" { + panic("mux: Resource() requires a patter to work") + } + + if fn == nil { + panic("mux: Resource() requires callback") + } + + // Copy root middlewares. + mws := make([]func(http.Handler) http.Handler, len(m.middlewares)) + copy(mws, m.middlewares) + + fn(&Resource{ + mux: m.mux, + pattern: pattern, + middlewares: mws, + routes: m.routes, + }) +} + +// Index of all resource. +// +// GET /pattern +func (res *Resource) Index(h http.HandlerFunc) { + res.routes.Add(http.MethodGet + " " + res.pattern) + res.handlerFunc(http.MethodGet, res.pattern, h) +} + +// CreateView new resource +// +// GET /pattern/create +func (res *Resource) CreateView(h http.HandlerFunc) { + p := suffixIt(res.pattern, "create") + res.routes.Add(http.MethodGet + " " + p) + res.handlerFunc(http.MethodGet, p, h) +} + +// Create a new resource +// +// POST /pattern/create +func (res *Resource) Create(h http.HandlerFunc) { + res.routes.Add(http.MethodPost + " " + res.pattern) + res.handlerFunc(http.MethodPost, res.pattern, h) +} + +// View a resource +// +// GET /pattern/:id +func (res *Resource) View(h http.HandlerFunc) { + p := suffixIt(res.pattern, "{id}") + res.routes.Add(http.MethodGet + " " + p) + res.handlerFunc(http.MethodGet, p, h) +} + +// Update a resource +// +// PUT /pattern/:id +func (res *Resource) Update(h http.HandlerFunc) { + p := suffixIt(res.pattern, "{id}") + res.routes.Add(http.MethodPut + " " + p) + res.handlerFunc(http.MethodPut, p, h) +} + +// UpdatePartial resource info +// PATCH /pattern/:id +func (res *Resource) UpdatePartial(h http.HandlerFunc) { + p := suffixIt(res.pattern, "{id}") + res.routes.Add(http.MethodPatch + " " + p) + res.handlerFunc(http.MethodPatch, p, h) +} + +// Delete a resource +// +// DELETE /pattern/:id +func (res *Resource) Delete(h http.HandlerFunc) { + p := suffixIt(res.pattern, "{id}") + res.routes.Add(http.MethodDelete + " " + p) + res.handlerFunc(http.MethodDelete, p, h) +} + +func (res *Resource) Handle(pattern string, h http.HandlerFunc) { + p := suffixIt(res.pattern, "{id}") + res.routes.Add(http.MethodDelete + " " + p) + res.handlerFunc(http.MethodDelete, p, h) +} + +// handlerFunc registers the handler function for the given pattern. +// If the given pattern conflicts, with one that is already registered, HandleFunc +// panics. +func (res *Resource) handlerFunc(method, pattern string, h http.HandlerFunc) { + if res == nil { + panic("serve: func handlerFunc() was called on nil") + } + + if res.mux == nil { + panic("serve: router mux is nil") + } + + path := fmt.Sprintf("%s %s", method, pattern) + res.mux.Handle(path, stack(res.middlewares, h)) +} + +// Use will register middleware(s) on Router stack. +func (res *Resource) Use(middlewares ...func(http.Handler) http.Handler) { + if res == nil { + panic("serve: func Use was called on nil") + } + res.middlewares = append(res.middlewares, middlewares...) } func suffixIt(str, suffix string) string { @@ -39,61 +145,3 @@ func suffixIt(str, suffix string) string { p.WriteString(suffix) return p.String() } - -// Index is GET /resource-name -func (r *Resource) Index(h http.HandlerFunc) { - r.handlerFunc(http.MethodGet, r.pattern, h) -} - -// New is GET /resource-name/new -func (r *Resource) New(h http.HandlerFunc) { - r.handlerFunc(http.MethodGet, suffixIt(r.pattern, "new"), h) -} - -// Create is POST /resource-name -func (r *Resource) Create(h http.HandlerFunc) { - r.handlerFunc(http.MethodPost, r.pattern, h) -} - -// Show is GET /resource-name/:id -func (r *Resource) Show(h http.HandlerFunc) { - r.handlerFunc(http.MethodGet, suffixIt(r.pattern, "{id}"), h) -} - -// Update is PUT /resource-name/:id -func (r *Resource) Update(h http.HandlerFunc) { - r.handlerFunc(http.MethodPut, suffixIt(r.pattern, "{id}"), h) -} - -// PartialUpdate is PATCH /resource-name/:id -func (r *Resource) PartialUpdate(h http.HandlerFunc) { - r.handlerFunc(http.MethodPatch, suffixIt(r.pattern, "{id}"), h) -} - -func (r *Resource) Destroy(h http.HandlerFunc) { - r.handlerFunc(http.MethodDelete, suffixIt(r.pattern, "{id}"), h) -} - -// handlerFunc registers the handler function for the given pattern. -// If the given pattern conflicts, with one that is already registered, HandleFunc -// panics. -func (r *Resource) handlerFunc(method, pattern string, h http.HandlerFunc) { - if r == nil { - panic("serve: func handlerFunc() was called on nil") - } - - if r.mux == nil { - panic("serve: router mux is nil") - } - - path := fmt.Sprintf("%s %s", method, pattern) - r.mux.Handle(path, stack(r.middlewares, h)) -} - -// Use will register middleware(s) on Router stack. -func (r *Resource) Use(middlewares ...func(http.Handler) http.Handler) { - if r == nil { - panic("serve: func Use was called on nil") - } - r.middlewares = append(r.middlewares, middlewares...) -} diff --git a/route.go b/route.go new file mode 100644 index 0000000..d642b30 --- /dev/null +++ b/route.go @@ -0,0 +1,51 @@ +package mux + +import ( + "fmt" + "log/slog" + "sync" +) + +type RouteList struct { + mu sync.RWMutex + routes []string +} + +func (s *RouteList) Add(item string) { + if s == nil { + slog.Warn("failed on Add, RouteList is nil") + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.routes = append(s.routes, item) +} + +func (s *RouteList) Get(index int) (string, error) { + if s == nil { + slog.Warn("failed on Get, RouteList is nil") + return "", nil + } + + s.mu.RLock() + defer s.mu.RUnlock() + + if index < 0 || index >= len(s.routes) { + return "0", fmt.Errorf("index out of bounds") + } + return s.routes[index], nil +} + +func (s *RouteList) All() []string { + if s == nil { + slog.Warn("failed on All, RouteList is nil") + return nil + } + + s.mu.RLock() + defer s.mu.RUnlock() + + return s.routes +} diff --git a/router.go b/router.go deleted file mode 100644 index 1aa7837..0000000 --- a/router.go +++ /dev/null @@ -1,177 +0,0 @@ -package mux - -import ( - "fmt" - "net/http" - "strings" -) - -// Router is a wrapper around the go's standard http.ServeMux. -// It's a lean wrapper with methods to make routing easier -type Router struct { - mux *http.ServeMux - middlewares []func(http.Handler) http.Handler -} - -func NewRouter() *Router { - r := &Router{ - mux: http.NewServeMux(), - } - return r -} - -// Mux DO NOT USE it for routing, exposed only for edge cases. -func (r *Router) Mux() *http.ServeMux { - return r.mux -} - -// Use will register middleware(s) with router stack -func (r *Router) Use(h ...func(http.Handler) http.Handler) { - if r == nil { - panic("mux: func Use was called on nil") - } - r.middlewares = append(r.middlewares, h...) -} - -// GET method route -func (r *Router) GET(pattern string, h http.HandlerFunc) { - r.handle(http.MethodGet, pattern, h) -} - -// HEAD method route -func (r *Router) HEAD(pattern string, h http.HandlerFunc) { - r.handle(http.MethodHead, pattern, h) -} - -// POST method route -func (r *Router) POST(pattern string, h http.HandlerFunc) { - r.handle(http.MethodPost, pattern, h) -} - -// PUT method route -func (r *Router) PUT(pattern string, h http.HandlerFunc) { - r.handle(http.MethodPut, pattern, h) -} - -// PATCH method route -func (r *Router) PATCH(pattern string, h http.HandlerFunc) { - r.handle(http.MethodPatch, pattern, h) -} - -// DELETE method route -func (r *Router) DELETE(pattern string, h http.HandlerFunc) { - r.handle(http.MethodDelete, pattern, h) -} - -// CONNECT method route -func (r *Router) CONNECT(pattern string, h http.HandlerFunc) { - r.handle(http.MethodConnect, pattern, h) -} - -// OPTIONS method route -func (r *Router) OPTIONS(pattern string, h http.HandlerFunc) { - r.handle(http.MethodOptions, pattern, h) -} - -// TRACE method route -func (r *Router) TRACE(pattern string, h http.HandlerFunc) { - r.handle(http.MethodTrace, pattern, h) -} - -// handle registers the handler for the given pattern. -// If the given pattern conflicts, with one that is already registered, HandleFunc -// panics. -func (r *Router) handle(method, pattern string, h http.HandlerFunc) { - if r == nil { - panic("mux: func Handle() was called on nil") - } - - if strings.TrimSpace(pattern) == "" { - panic("mux: pattern cannot be empty") - } - - if !strings.HasPrefix(pattern, "/") { - pattern = "/" + pattern - } - - path := fmt.Sprintf("%s %s", method, pattern) - r.mux.Handle(path, stack(r.middlewares, h)) -} - -// With adds inline middlewares for an endpoint handler. -func (r *Router) With(middleware ...func(http.Handler) http.Handler) *Router { - mws := make([]func(http.Handler) http.Handler, len(r.middlewares)) - copy(mws, r.middlewares) - mws = append(mws, middleware...) - - im := &Router{ - mux: r.mux, - middlewares: mws, - } - - return im -} - -// Group adds a new inline-Router along the current routing -// path, with a fresh middleware stack for the inline-Router. -func (r *Router) Group(fn func(grp *Router)) { - if r == nil { - panic("mux: Group() called on nil") - } - - if fn == nil { - panic("mux: Group() requires callback") - } - - im := r.With() - fn(im) -} - -// Resource resourceful route provides a mapping between HTTP verbs for given the pattern -func (r *Router) Resource(pattern string, fn func(resource *Resource)) { - if r == nil { - panic("mux: Resource() called on nil") - } - - if strings.TrimSpace(pattern) == "" { - panic("mux: Resource() requires a patter to work") - } - - if fn == nil { - panic("mux: Resource() requires callback") - } - - mws := make([]func(http.Handler) http.Handler, len(r.middlewares)) - copy(mws, r.middlewares) - fn(&Resource{ - mux: r.mux, - pattern: pattern, - middlewares: mws, - }) -} - -func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if r == nil { - panic("mux: method ServeHTTP called on nil") - } - - r.mux.ServeHTTP(w, req) -} - -// TODO: proxy for aws lambda and other serverless platforms - -// stack middlewares(http handler) in order they are passed (FIFO) -func stack(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { - // Return ahead of time if there aren't any middlewares for the chain - if len(middlewares) == 0 { - return endpoint - } - - // wrap the end handler with the middleware chain - h := middlewares[len(middlewares)-1](endpoint) - for i := len(middlewares) - 2; i >= 0; i-- { - h = middlewares[i](h) - } - - return h -} diff --git a/router_serve.go b/router_serve.go deleted file mode 100644 index 3c579e6..0000000 --- a/router_serve.go +++ /dev/null @@ -1,58 +0,0 @@ -package mux - -import ( - "context" - "errors" - "io" - "log/slog" - "net/http" - "os" - "os/signal" -) - -type ServeCB func(srv *http.Server) error - -// Serve with graceful shutdown -func (r *Router) Serve(cb ServeCB) { - // catch all options - // lets get it thorugh all middlewares - r.OPTIONS("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "0") - if r.ContentLength != 0 { - // Read up to 4KB of OPTIONS body (as mentioned in the - // spec as being reserved for future use), but anything - // over that is considered a waste of server resources - // (or an attack) and we abort and close the connection, - // courtesy of MaxBytesReader's EOF behavior. - mb := http.MaxBytesReader(w, r.Body, 4<<10) - io.Copy(io.Discard, mb) - } - }) - - srv := &http.Server{ - Handler: r, - } - - idleConnsClosed := make(chan struct{}) - go func() { - sigint := make(chan os.Signal, 1) - signal.Notify(sigint, os.Interrupt) - <-sigint - - // We received an interrupt signal, shut down. - if err := srv.Shutdown(context.Background()); err != nil { - // Error from closing listeners, or context timeout: - slog.Error("server shutdown error", "error", err) - } else { - slog.Info("server shutdown") - } - close(idleConnsClosed) - }() - - if err := cb(srv); !errors.Is(err, http.ErrServerClosed) { - // Error starting or closing listener: - slog.Error("start server error", "error", err) - } - - <-idleConnsClosed -} diff --git a/serve.go b/serve.go new file mode 100644 index 0000000..372b8c2 --- /dev/null +++ b/serve.go @@ -0,0 +1,78 @@ +package mux + +import ( + "context" + "errors" + "io" + "log" + "log/slog" + "net" + "net/http" + "os/signal" + "syscall" + "time" +) + +type ServeCB func(srv *http.Server) error + +const ( + shutdownDelay = time.Second * 10 + shutdownHardDelay = time.Second * 5 + drainDelay = time.Second +) + +// Serve with graceful shutdown +func (m *Mux) Serve(cb ServeCB) { + rootCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + // catch all options + // lets get it thorugh all middlewares + m.OPTIONS("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "0") + if r.ContentLength != 0 { + // Read up to 4KB of OPTIONS body (as mentioned in the + // spec as being reserved for future use), but anything + // over that is considered a waste of server resources + // (or an attack) and we abort and close the connection, + // courtesy of MaxBytesReader's EOF behavior. + mb := http.MaxBytesReader(w, r.Body, 4<<10) + io.Copy(io.Discard, mb) + } + }) + + srvCtx, cancelSrvCtx := context.WithCancel(context.Background()) + srv := &http.Server{ + Handler: m, + BaseContext: func(_ net.Listener) context.Context { + return srvCtx + }, + } + + go func() { + if err := cb(srv); !errors.Is(err, http.ErrServerClosed) { + panic(err) + } + }() + + // Wait for interrupt signal + <-rootCtx.Done() + + stop() + m.IsShuttingDown.Store(true) + slog.Info("received interrupt singal, shutting down") + time.Sleep(drainDelay) + slog.Info("readiness check propagated, now waiting for ongoing requests to finish.") + + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownDelay) + defer cancel() + + err := srv.Shutdown(shutdownCtx) + cancelSrvCtx() + if err != nil { + log.Println("failed to wait for ongoing requests to finish, waiting for forced cancellation") + time.Sleep(shutdownHardDelay) + } + + slog.Info("seerver shut down gracefully") +} diff --git a/stack.go b/stack.go new file mode 100644 index 0000000..633e78f --- /dev/null +++ b/stack.go @@ -0,0 +1,19 @@ +package mux + +import "net/http" + +// stack middlewares(http handler) in order they are passed (FIFO) +func stack(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { + // Return ahead of time if there aren't any middlewares for the chain + if len(middlewares) == 0 { + return endpoint + } + + // wrap the end handler with the middleware chain + h := middlewares[len(middlewares)-1](endpoint) + for i := len(middlewares) - 2; i >= 0; i-- { + h = middlewares[i](h) + } + + return h +}