From f8cdf3a5115fe172c214a54ce1dffa1a03e6bcf4 Mon Sep 17 00:00:00 2001 From: Ankit Patial Date: Sat, 17 May 2025 18:55:15 +0530 Subject: [PATCH] middleware helmet changes. router check and panic message change. README enhancement --- README.md | 234 +++++++++++++++++------------- example/main.go | 18 ++- middleware/helmet.go | 34 ++--- router.go | 38 +++-- router_test.go | 330 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 520 insertions(+), 134 deletions(-) create mode 100644 router_test.go diff --git a/README.md b/README.md index b5d490e..c7268d9 100644 --- a/README.md +++ b/README.md @@ -1,115 +1,153 @@ -# Mux +# Mux - A Lightweight HTTP Router for Go -Tiny wrapper around Go's builtin http.ServeMux with easy routing methods. +Mux is a simple, lightweight HTTP router for Go that wraps around the standard `http.ServeMux` to provide additional functionality and a more ergonomic API. -## Example +## Features + +- HTTP method-specific routing (GET, POST, PUT, DELETE, etc.) +- Middleware support with flexible stacking +- Route grouping for organization and shared middleware +- RESTful resource routing +- URL parameter extraction +- Graceful shutdown support +- Minimal dependencies (only uses Go standard library) + +## Installation + +```bash +go get code.patial.tech/go/mux +``` + +## Basic Usage ```go package main import ( - "log/slog" + "fmt" "net/http" - "gitserver.in/patialtech/mux" + "code.patial.tech/go/mux" ) func main() { - // create a new router - r := mux.NewRouter() + // Create a new router + router := mux.NewRouter() - // you can use any middleware that is: "func(http.Handler) http.Handler" - // so you can use any of it - // - https://github.com/gorilla/handlers - // - https://github.com/go-chi/chi/tree/master/middleware - - // add some root level middlewares, these will apply to all routes after it - r.Use(middleware1, middleware2) - - // let's add a route - r.GET("/hello", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("i am route /hello")) - }) - // r.Post(pattern string, h http.HandlerFunc) - // r.Put(pattern string, h http.HandlerFunc) - // ... - - // you can inline middleware(s) to a route - r. - With(mwInline). - GET("/hello-2", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("i am route /hello-2 with my own middleware")) - }) - - // define a resource - r.Resource("/photos", func(resource *mux.Resource) { - // rails style resource routes - // GET /photos - // GET /photos/new - // POST /photos - // GET /photos/:id - // GET /photos/:id/edit - // PUT /photos/:id - // PATCH /photos/:id - // DELETE /photos/:id - resource.Index(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("all photos")) - }) - - resource.New(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("upload a new pohoto")) - }) + // Define a simple route + router.GET("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello, World!") }) - // create a group of few routes with their own middlewares - r.Group(func(grp *mux.Router) { - grp.Use(mwGroup) - grp.GET("/group", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("i am route /group")) - }) - }) - - // catches all - r.GET("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hello there")) - }) - - // Serve allows graceful shutdown, you can use it - r.Serve(func(srv *http.Server) error { - srv.Addr = ":3001" - // srv.ReadTimeout = time.Minute - // srv.WriteTimeout = time.Minute - - slog.Info("listening on http://localhost" + srv.Addr) - return srv.ListenAndServe() - }) -} - -func middleware1(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slog.Info("i am middleware 1") - h.ServeHTTP(w, r) - }) -} - -func middleware2(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slog.Info("i am middleware 2") - h.ServeHTTP(w, r) - }) -} - -func mwInline(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slog.Info("i am inline middleware") - h.ServeHTTP(w, r) - }) -} - -func mwGroup(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slog.Info("i am group middleware") - h.ServeHTTP(w, r) - }) + // Start the server + http.ListenAndServe(":8080", router) } ``` + +## Routing + +Mux supports all HTTP methods defined in the Go standard library: + +```go +router.GET("/users", listUsers) +router.POST("/users", createUser) +router.PUT("/users/{id}", updateUser) +router.DELETE("/users/{id}", deleteUser) +router.PATCH("/users/{id}", partialUpdateUser) +router.HEAD("/users", headUsers) +router.OPTIONS("/users", optionsUsers) +router.TRACE("/users", traceUsers) +router.CONNECT("/users", connectUsers) +``` + +## URL Parameters + +Mux supports URL parameters using curly braces: + +```go +router.GET("/users/{id}", func(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + fmt.Fprintf(w, "User ID: %s", id) +}) +``` + +## Middleware + +Middleware functions take an `http.Handler` and return an `http.Handler`. You can add global middleware to all routes: + +```go +// Logging middleware +func loggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("[%s] %s\n", r.Method, r.URL.Path) + next.ServeHTTP(w, r) + }) +} + +// Add middleware to all routes +router.Use(loggingMiddleware) +``` + +## Route Groups + +Group related routes and apply middleware to specific groups: + +```go +// API routes group +router.Group(func(api *mux.Router) { + // Middleware only for API routes + api.Use(authMiddleware) + + // API routes + api.GET("/api/users", listUsers) + api.POST("/api/users", createUser) +}) +``` + +## RESTful Resources + +Easily define RESTful resources: + +```go +router.Resource("/posts", func(r *mux.Resource) { + r.Index(listPosts) // GET /posts + r.Show(showPost) // GET /posts/{id} + r.Create(createPost) // POST /posts + r.Update(updatePost) // PUT /posts/{id} + r.Destroy(deletePost) // DELETE /posts/{id} + r.New(newPostForm) // GET /posts/new +}) +``` + +## Graceful Shutdown + +Use the built-in graceful shutdown functionality: + +```go +router.Serve(func(srv *http.Server) error { + srv.Addr = ":8080" + return srv.ListenAndServe() +}) +``` + +## Custom 404 Handler + +can be tried like this + +```go +router.GET("/", func(writer http.ResponseWriter, request *http.Request) { + if request.URL.Path != "/" { + writer.WriteHeader(404) + writer.Write([]byte(`not found, da xiong dei !!!`)) + return + } +}) +``` + +## Full Example + +See the [examples directory](./example) for complete working examples. + +## License + +This project is licensed under the MIT License - see the [LICENSE](./LICENSE) file for details. diff --git a/example/main.go b/example/main.go index 4cfe9b2..50025c5 100644 --- a/example/main.go +++ b/example/main.go @@ -12,8 +12,22 @@ func main() { // create a new router r := mux.NewRouter() r.Use(middleware.CORS(middleware.CORSOption{ - AllowedOrigins: []string{"*"}, - MaxAge: 60, + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-AccessToken", "X-Real-IP"}, + ExposedHeaders: []string{"Link"}, + AllowCredentials: true, + MaxAge: 300, + })) + + r.Use(middleware.Helmet(middleware.HelmetOption{ + StrictTransportSecurity: &middleware.TransportSecurity{ + MaxAge: 31536000, + IncludeSubDomains: true, + Preload: true, + }, + XssProtection: true, + XFrameOption: middleware.XFrameDeny, })) // you can use any middleware that is: "func(http.Handler) http.Handler" diff --git a/middleware/helmet.go b/middleware/helmet.go index 672d1ec..d8267ba 100644 --- a/middleware/helmet.go +++ b/middleware/helmet.go @@ -1,3 +1,7 @@ +// Author: Ankit Patial +// inspired from Helmet.js +// https://github.com/helmetjs/helmet/tree/main + package middleware import ( @@ -6,9 +10,6 @@ import ( "strings" ) -// inspired from Helmet.js -// https://github.com/helmetjs/helmet/tree/main - type ( HelmetOption struct { ContentSecurityPolicy CSP @@ -101,16 +102,16 @@ type ( const ( YearDuration = 365 * 24 * 60 * 60 - // EmbedderDefault default value will be "require-corp" - EmbedderRequireCorp Embedder = "require-corp" - EmbedderCredentialLess Embedder = "credentialless" - EmbedderUnsafeNone Embedder = "unsafe-none" - - // OpenerDefault default value will be "same-origin" + // OpenerSameOrigin is default if no value supplied OpenerSameOrigin Opener = "same-origin" OpenerSameOriginAllowPopups Opener = "same-origin-allow-popups" OpenerUnsafeNone Opener = "unsafe-none" + // EmbedderDefault is default if no value supplied + EmbedderRequireCorp Embedder = "require-corp" + EmbedderCredentialLess Embedder = "credentialless" + EmbedderUnsafeNone Embedder = "unsafe-none" + // ResourceDefault default value will be "same-origin" ResourceSameOrigin Resource = "same-origin" ResourceSameSite Resource = "same-site" @@ -125,13 +126,13 @@ const ( StrictOriginWhenCrossOrigin Referrer = "strict-origin-when-cross-origin" UnsafeUrl Referrer = "unsafe-url" - // CDPDefault default value is "none" + // CDPNone is default if no value supplied CDPNone CDP = "none" CDPMasterOnly CDP = "master-only" CDPByContentType CDP = "by-content-type" CDPAll CDP = "all" - // XFrameDefault default value will be "sameorigin" + // XFrameSameOrigin is default if no value supplied XFrameSameOrigin XFrame = "sameorigin" XFrameDeny XFrame = "deny" ) @@ -142,21 +143,14 @@ func Helmet(opt HelmetOption) func(http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Security-Policy", opt.ContentSecurityPolicy.value()) - // Cross-Origin-Embedder-Policy, if nil set default - if opt.CrossOriginEmbedderPolicy == "" { - w.Header().Add("Cross-Origin-Embedder-Policy", string(EmbedderRequireCorp)) - } else { - w.Header().Add("Cross-Origin-Embedder-Policy", string(opt.CrossOriginEmbedderPolicy)) - } - - // Cross-Origin-Opener-Policy, if nil set default + // Opener-Policy if opt.CrossOriginOpenerPolicy == "" { w.Header().Add("Cross-Origin-Opener-Policy", string(OpenerSameOrigin)) } else { w.Header().Add("Cross-Origin-Opener-Policy", string(opt.CrossOriginOpenerPolicy)) } - // Cross-Origin-Resource-Policy, if nil set default + // Resource-Policy if opt.CrossOriginResourcePolicy == "" { w.Header().Add("Cross-Origin-Resource-Policy", string(ResourceSameOrigin)) } else { diff --git a/router.go b/router.go index f8b05f4..d9bd8cc 100644 --- a/router.go +++ b/router.go @@ -30,55 +30,65 @@ func (r *Router) Use(h ...func(http.Handler) http.Handler) { // GET method route func (r *Router) GET(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodGet, pattern, h) + r.handle(http.MethodGet, pattern, h) } // HEAD method route func (r *Router) HEAD(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodHead, pattern, h) + r.handle(http.MethodHead, pattern, h) } // POST method route func (r *Router) POST(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodPost, pattern, h) + r.handle(http.MethodPost, pattern, h) } // PUT method route func (r *Router) PUT(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodPut, pattern, h) + r.handle(http.MethodPut, pattern, h) } // PATCH method route func (r *Router) PATCH(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodPatch, pattern, h) + r.handle(http.MethodPatch, pattern, h) } // DELETE method route func (r *Router) DELETE(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodDelete, pattern, h) + r.handle(http.MethodDelete, pattern, h) } // CONNECT method route func (r *Router) CONNECT(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodConnect, pattern, h) -} // OPTIONS method route + r.handle(http.MethodConnect, pattern, h) +} + +// OPTIONS method route func (r *Router) OPTIONS(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodOptions, pattern, h) + r.handle(http.MethodOptions, pattern, h) } // TRACE method route func (r *Router) TRACE(pattern string, h http.HandlerFunc) { - r.handlerFunc(http.MethodTrace, pattern, h) + r.handle(http.MethodTrace, pattern, h) } -// HandleFunc registers the handler function for the given pattern. +// handle registers the handler for the given pattern. // If the given pattern conflicts, with one that is already registered, HandleFunc // panics. -func (r *Router) handlerFunc(method, pattern string, h http.HandlerFunc) { +func (r *Router) handle(method, pattern string, h http.HandlerFunc) { if r == nil { panic("mux: func Handle() was called on nil") } + if strings.TrimSpace(pattern) == "" { + panic("mux: pattern cannot be empty") + } + + if !strings.HasPrefix(pattern, "/") { + pattern = "/" + pattern + } + path := fmt.Sprintf("%s %s", method, pattern) r.mux.Handle(path, stack(r.middlewares, h)) } @@ -101,7 +111,7 @@ func (r *Router) With(middleware ...func(http.Handler) http.Handler) *Router { // path, with a fresh middleware stack for the inline-Router. func (r *Router) Group(fn func(grp *Router)) { if r == nil { - panic("mux: Resource() called on nil") + panic("mux: Group() called on nil") } if fn == nil { @@ -143,7 +153,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.mux.ServeHTTP(w, req) } -// TODO: proxy for aws lambda +// TODO: proxy for aws lambda and other serverless platforms // stack middlewares(http handler) in order they are passed (FIFO) func stack(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { diff --git a/router_test.go b/router_test.go new file mode 100644 index 0000000..95631a8 --- /dev/null +++ b/router_test.go @@ -0,0 +1,330 @@ +package mux + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRouterGET(t *testing.T) { + r := NewRouter() + r.GET("/test", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "GET test") + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/test") + if err != nil { + t.Fatalf("Error making GET request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status OK; got %v", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error reading response body: %v", err) + } + + expected := "GET test" + if string(body) != expected { + t.Errorf("Expected body %q; got %q", expected, string(body)) + } +} + +func TestRouterPOST(t *testing.T) { + r := NewRouter() + r.POST("/test", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "POST test") + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, err := http.Post(ts.URL+"/test", "text/plain", strings.NewReader("test data")) + if err != nil { + t.Fatalf("Error making POST request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status OK; got %v", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error reading response body: %v", err) + } + + expected := "POST test" + if string(body) != expected { + t.Errorf("Expected body %q; got %q", expected, string(body)) + } +} + +func TestRouterWith(t *testing.T) { + r := NewRouter() + + middleware := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "middleware") + h.ServeHTTP(w, r) + }) + } + + r.With(middleware).GET("/test", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "GET with middleware") + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/test") + if err != nil { + t.Fatalf("Error making GET request: %v", err) + } + defer resp.Body.Close() + + if resp.Header.Get("X-Test") != "middleware" { + t.Errorf("Expected header X-Test to be 'middleware'; got %q", resp.Header.Get("X-Test")) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error reading response body: %v", err) + } + + expected := "GET with middleware" + if string(body) != expected { + t.Errorf("Expected body %q; got %q", expected, string(body)) + } +} + +func TestRouterGroup(t *testing.T) { + r := NewRouter() + + var groupCalled bool + + r.Group(func(g *Router) { + groupCalled = true + g.GET("/group", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Group route") + }) + }) + + if !groupCalled { + t.Error("Expected Group callback to be called") + } + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/group") + if err != nil { + t.Fatalf("Error making GET request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status OK; got %v", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error reading response body: %v", err) + } + + expected := "Group route" + if string(body) != expected { + t.Errorf("Expected body %q; got %q", expected, string(body)) + } +} + +func TestRouterResource(t *testing.T) { + r := NewRouter() + + r.Resource("/users", func(resource *Resource) { + resource.Index(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "All users") + }) + + resource.Show(func(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + fmt.Fprintf(w, "User %s", id) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + // Test Index + resp, err := http.Get(ts.URL + "/users") + if err != nil { + t.Fatalf("Error making GET request: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error reading response body: %v", err) + } + + expected := "All users" + if string(body) != expected { + t.Errorf("Expected body %q; got %q", expected, string(body)) + } + + // Test Show + resp, err = http.Get(ts.URL + "/users/123") + if err != nil { + t.Fatalf("Error making GET request: %v", err) + } + defer resp.Body.Close() + + body, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error reading response body: %v", err) + } + + expected = "User 123" + if string(body) != expected { + t.Errorf("Expected body %q; got %q", expected, string(body)) + } +} + +func TestGetParams(t *testing.T) { + r := NewRouter() + + r.GET("/users/{id}/posts/{post_id}", func(w http.ResponseWriter, r *http.Request) { + userId := r.PathValue("id") + postId := r.PathValue("post_id") + fmt.Fprintf(w, "User: %s, Post: %s", userId, postId) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/users/123/posts/456") + if err != nil { + t.Fatalf("Error making GET request: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error reading response body: %v", err) + } + + expected := "User: 123, Post: 456" + if string(body) != expected { + t.Errorf("Expected body %q; got %q", expected, string(body)) + } +} + +func TestStack(t *testing.T) { + var calls []string + + middleware1 := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, "middleware1 before") + h.ServeHTTP(w, r) + calls = append(calls, "middleware1 after") + }) + } + + middleware2 := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, "middleware2 before") + h.ServeHTTP(w, r) + calls = append(calls, "middleware2 after") + }) + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, "handler") + }) + + middlewares := []func(http.Handler) http.Handler{middleware1, middleware2} + stacked := stack(middlewares, handler) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + + stacked.ServeHTTP(w, r) + + expected := []string{ + "middleware1 before", + "middleware2 before", + "handler", + "middleware2 after", + "middleware1 after", + } + + if len(calls) != len(expected) { + t.Errorf("Expected %d calls; got %d", len(expected), len(calls)) + } + + for i, call := range calls { + if i < len(expected) && call != expected[i] { + t.Errorf("Expected call %d to be %q; got %q", i, expected[i], call) + } + } +} + +func TestRouterPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic but none occurred") + } + }() + + var r *Router + r.GET("/", func(w http.ResponseWriter, r *http.Request) {}) +} + +func BenchmarkRouterSimple(b *testing.B) { + r := NewRouter() + + r.GET("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello") + }) + + req, _ := http.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.ServeHTTP(w, req) + } +} + +func BenchmarkRouterWithMiddleware(b *testing.B) { + r := NewRouter() + + middleware := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + }) + } + + r.Use(middleware, middleware) + + r.GET("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello") + }) + + req, _ := http.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.ServeHTTP(w, req) + } +}