middleware helmet changes.
router check and panic message change. README enhancement
This commit is contained in:
		
							
								
								
									
										234
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										234
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,115 +1,153 @@ | ||||
| # Mux | ||||
| # Mux - A Lightweight HTTP Router for Go | ||||
|  | ||||
| Tiny wrapper around Go's builtin http.ServeMux with easy routing methods. | ||||
| Mux is a simple, lightweight HTTP router for Go that wraps around the standard `http.ServeMux` to provide additional functionality and a more ergonomic API. | ||||
|  | ||||
| ## Example | ||||
| ## Features | ||||
|  | ||||
| - HTTP method-specific routing (GET, POST, PUT, DELETE, etc.) | ||||
| - Middleware support with flexible stacking | ||||
| - Route grouping for organization and shared middleware | ||||
| - RESTful resource routing | ||||
| - URL parameter extraction | ||||
| - Graceful shutdown support | ||||
| - Minimal dependencies (only uses Go standard library) | ||||
|  | ||||
| ## Installation | ||||
|  | ||||
| ```bash | ||||
| go get code.patial.tech/go/mux | ||||
| ``` | ||||
|  | ||||
| ## Basic Usage | ||||
|  | ||||
| ```go | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"log/slog" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"gitserver.in/patialtech/mux" | ||||
| 	"code.patial.tech/go/mux" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	// create a new router | ||||
| 	r := mux.NewRouter() | ||||
| 	// Create a new router | ||||
| 	router := mux.NewRouter() | ||||
|  | ||||
| 	// you can use any middleware that is: "func(http.Handler) http.Handler" | ||||
| 	// so you can use any of it | ||||
| 	// - https://github.com/gorilla/handlers | ||||
| 	// - https://github.com/go-chi/chi/tree/master/middleware | ||||
|  | ||||
| 	// add some root level middlewares, these will apply to all routes after it | ||||
| 	r.Use(middleware1, middleware2) | ||||
|  | ||||
| 	// let's add a route | ||||
| 	r.GET("/hello", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.Write([]byte("i am route /hello")) | ||||
| 	}) | ||||
| 	// r.Post(pattern string, h http.HandlerFunc) | ||||
| 	// r.Put(pattern string, h http.HandlerFunc) | ||||
| 	// ... | ||||
|  | ||||
| 	// you can inline middleware(s) to a route | ||||
| 	r. | ||||
| 		With(mwInline). | ||||
| 		GET("/hello-2", func(w http.ResponseWriter, r *http.Request) { | ||||
| 			w.Write([]byte("i am route /hello-2 with my own middleware")) | ||||
| 		}) | ||||
|  | ||||
| 	// define a resource | ||||
| 	r.Resource("/photos", func(resource *mux.Resource) { | ||||
| 		// rails style resource routes | ||||
| 		// GET       /photos | ||||
| 		// GET       /photos/new | ||||
| 		// POST      /photos | ||||
| 		// GET       /photos/:id | ||||
| 		// GET       /photos/:id/edit | ||||
| 		// PUT 		 /photos/:id | ||||
| 		// PATCH 	 /photos/:id | ||||
| 		// DELETE    /photos/:id | ||||
| 		resource.Index(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			w.Write([]byte("all photos")) | ||||
| 		}) | ||||
|  | ||||
| 		resource.New(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			w.Write([]byte("upload a new pohoto")) | ||||
| 		}) | ||||
| 	// Define a simple route | ||||
| 	router.GET("/", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		fmt.Fprint(w, "Hello, World!") | ||||
| 	}) | ||||
|  | ||||
| 	// create a group of few routes with their own middlewares | ||||
| 	r.Group(func(grp *mux.Router) { | ||||
| 		grp.Use(mwGroup) | ||||
| 		grp.GET("/group", func(w http.ResponseWriter, r *http.Request) { | ||||
| 			w.Write([]byte("i am route /group")) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	// catches all | ||||
| 	r.GET("/", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.Write([]byte("hello there")) | ||||
| 	}) | ||||
|  | ||||
| 	// Serve allows graceful shutdown, you can use it | ||||
| 	r.Serve(func(srv *http.Server) error { | ||||
| 		srv.Addr = ":3001" | ||||
| 		// srv.ReadTimeout = time.Minute | ||||
| 		// srv.WriteTimeout = time.Minute | ||||
|  | ||||
| 		slog.Info("listening on http://localhost" + srv.Addr) | ||||
| 		return srv.ListenAndServe() | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func middleware1(h http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		slog.Info("i am middleware 1") | ||||
| 		h.ServeHTTP(w, r) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func middleware2(h http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		slog.Info("i am middleware 2") | ||||
| 		h.ServeHTTP(w, r) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func mwInline(h http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		slog.Info("i am inline middleware") | ||||
| 		h.ServeHTTP(w, r) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func mwGroup(h http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		slog.Info("i am group middleware") | ||||
| 		h.ServeHTTP(w, r) | ||||
| 	}) | ||||
| 	// Start the server | ||||
| 	http.ListenAndServe(":8080", router) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## Routing | ||||
|  | ||||
| Mux supports all HTTP methods defined in the Go standard library: | ||||
|  | ||||
| ```go | ||||
| router.GET("/users", listUsers) | ||||
| router.POST("/users", createUser) | ||||
| router.PUT("/users/{id}", updateUser) | ||||
| router.DELETE("/users/{id}", deleteUser) | ||||
| router.PATCH("/users/{id}", partialUpdateUser) | ||||
| router.HEAD("/users", headUsers) | ||||
| router.OPTIONS("/users", optionsUsers) | ||||
| router.TRACE("/users", traceUsers) | ||||
| router.CONNECT("/users", connectUsers) | ||||
| ``` | ||||
|  | ||||
| ## URL Parameters | ||||
|  | ||||
| Mux supports URL parameters using curly braces: | ||||
|  | ||||
| ```go | ||||
| router.GET("/users/{id}", func(w http.ResponseWriter, r *http.Request) { | ||||
| 	id := r.PathValue("id") | ||||
| 	fmt.Fprintf(w, "User ID: %s", id) | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| ## Middleware | ||||
|  | ||||
| Middleware functions take an `http.Handler` and return an `http.Handler`. You can add global middleware to all routes: | ||||
|  | ||||
| ```go | ||||
| // Logging middleware | ||||
| func loggingMiddleware(next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		fmt.Printf("[%s] %s\n", r.Method, r.URL.Path) | ||||
| 		next.ServeHTTP(w, r) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // Add middleware to all routes | ||||
| router.Use(loggingMiddleware) | ||||
| ``` | ||||
|  | ||||
| ## Route Groups | ||||
|  | ||||
| Group related routes and apply middleware to specific groups: | ||||
|  | ||||
| ```go | ||||
| // API routes group | ||||
| router.Group(func(api *mux.Router) { | ||||
| 	// Middleware only for API routes | ||||
| 	api.Use(authMiddleware) | ||||
|  | ||||
| 	// API routes | ||||
| 	api.GET("/api/users", listUsers) | ||||
| 	api.POST("/api/users", createUser) | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| ## RESTful Resources | ||||
|  | ||||
| Easily define RESTful resources: | ||||
|  | ||||
| ```go | ||||
| router.Resource("/posts", func(r *mux.Resource) { | ||||
| 	r.Index(listPosts)    // GET /posts | ||||
| 	r.Show(showPost)      // GET /posts/{id} | ||||
| 	r.Create(createPost)  // POST /posts | ||||
| 	r.Update(updatePost)  // PUT /posts/{id} | ||||
| 	r.Destroy(deletePost) // DELETE /posts/{id} | ||||
| 	r.New(newPostForm)    // GET /posts/new | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| ## Graceful Shutdown | ||||
|  | ||||
| Use the built-in graceful shutdown functionality: | ||||
|  | ||||
| ```go | ||||
| router.Serve(func(srv *http.Server) error { | ||||
| 	srv.Addr = ":8080" | ||||
| 	return srv.ListenAndServe() | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| ## Custom 404 Handler | ||||
|  | ||||
| can be tried like this | ||||
|  | ||||
| ```go | ||||
| router.GET("/", func(writer http.ResponseWriter, request *http.Request) { | ||||
|     if request.URL.Path != "/" { | ||||
|         writer.WriteHeader(404) | ||||
|         writer.Write([]byte(`not found, da xiong dei !!!`)) | ||||
|         return | ||||
|     } | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| ## Full Example | ||||
|  | ||||
| See the [examples directory](./example) for complete working examples. | ||||
|  | ||||
| ## License | ||||
|  | ||||
| This project is licensed under the MIT License - see the [LICENSE](./LICENSE) file for details. | ||||
|   | ||||
| @@ -12,8 +12,22 @@ func main() { | ||||
| 	// create a new router | ||||
| 	r := mux.NewRouter() | ||||
| 	r.Use(middleware.CORS(middleware.CORSOption{ | ||||
| 		AllowedOrigins: []string{"*"}, | ||||
| 		MaxAge:         60, | ||||
| 		AllowedOrigins:   []string{"*"}, | ||||
| 		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, | ||||
| 		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-AccessToken", "X-Real-IP"}, | ||||
| 		ExposedHeaders:   []string{"Link"}, | ||||
| 		AllowCredentials: true, | ||||
| 		MaxAge:           300, | ||||
| 	})) | ||||
|  | ||||
| 	r.Use(middleware.Helmet(middleware.HelmetOption{ | ||||
| 		StrictTransportSecurity: &middleware.TransportSecurity{ | ||||
| 			MaxAge:            31536000, | ||||
| 			IncludeSubDomains: true, | ||||
| 			Preload:           true, | ||||
| 		}, | ||||
| 		XssProtection: true, | ||||
| 		XFrameOption:  middleware.XFrameDeny, | ||||
| 	})) | ||||
|  | ||||
| 	// you can use any middleware that is: "func(http.Handler) http.Handler" | ||||
|   | ||||
| @@ -1,3 +1,7 @@ | ||||
| // Author: Ankit Patial | ||||
| // inspired from Helmet.js | ||||
| // https://github.com/helmetjs/helmet/tree/main | ||||
|  | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| @@ -6,9 +10,6 @@ import ( | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // inspired from Helmet.js | ||||
| // https://github.com/helmetjs/helmet/tree/main | ||||
|  | ||||
| type ( | ||||
| 	HelmetOption struct { | ||||
| 		ContentSecurityPolicy CSP | ||||
| @@ -101,16 +102,16 @@ type ( | ||||
| const ( | ||||
| 	YearDuration = 365 * 24 * 60 * 60 | ||||
|  | ||||
| 	// EmbedderDefault default value will be "require-corp" | ||||
| 	EmbedderRequireCorp    Embedder = "require-corp" | ||||
| 	EmbedderCredentialLess Embedder = "credentialless" | ||||
| 	EmbedderUnsafeNone     Embedder = "unsafe-none" | ||||
|  | ||||
| 	// OpenerDefault default value will be "same-origin" | ||||
| 	// OpenerSameOrigin is default if no value supplied | ||||
| 	OpenerSameOrigin            Opener = "same-origin" | ||||
| 	OpenerSameOriginAllowPopups Opener = "same-origin-allow-popups" | ||||
| 	OpenerUnsafeNone            Opener = "unsafe-none" | ||||
|  | ||||
| 	// EmbedderDefault is default if no value supplied | ||||
| 	EmbedderRequireCorp    Embedder = "require-corp" | ||||
| 	EmbedderCredentialLess Embedder = "credentialless" | ||||
| 	EmbedderUnsafeNone     Embedder = "unsafe-none" | ||||
|  | ||||
| 	// ResourceDefault default value will be "same-origin" | ||||
| 	ResourceSameOrigin  Resource = "same-origin" | ||||
| 	ResourceSameSite    Resource = "same-site" | ||||
| @@ -125,13 +126,13 @@ const ( | ||||
| 	StrictOriginWhenCrossOrigin Referrer = "strict-origin-when-cross-origin" | ||||
| 	UnsafeUrl                   Referrer = "unsafe-url" | ||||
|  | ||||
| 	// CDPDefault default value is  "none" | ||||
| 	// CDPNone is default if no value supplied | ||||
| 	CDPNone          CDP = "none" | ||||
| 	CDPMasterOnly    CDP = "master-only" | ||||
| 	CDPByContentType CDP = "by-content-type" | ||||
| 	CDPAll           CDP = "all" | ||||
|  | ||||
| 	// XFrameDefault default value will be "sameorigin" | ||||
| 	// XFrameSameOrigin is default if no value supplied | ||||
| 	XFrameSameOrigin XFrame = "sameorigin" | ||||
| 	XFrameDeny       XFrame = "deny" | ||||
| ) | ||||
| @@ -142,21 +143,14 @@ func Helmet(opt HelmetOption) func(http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			w.Header().Add("Content-Security-Policy", opt.ContentSecurityPolicy.value()) | ||||
|  | ||||
| 			// Cross-Origin-Embedder-Policy, if nil set default | ||||
| 			if opt.CrossOriginEmbedderPolicy == "" { | ||||
| 				w.Header().Add("Cross-Origin-Embedder-Policy", string(EmbedderRequireCorp)) | ||||
| 			} else { | ||||
| 				w.Header().Add("Cross-Origin-Embedder-Policy", string(opt.CrossOriginEmbedderPolicy)) | ||||
| 			} | ||||
|  | ||||
| 			// Cross-Origin-Opener-Policy, if nil set default | ||||
| 			// Opener-Policy | ||||
| 			if opt.CrossOriginOpenerPolicy == "" { | ||||
| 				w.Header().Add("Cross-Origin-Opener-Policy", string(OpenerSameOrigin)) | ||||
| 			} else { | ||||
| 				w.Header().Add("Cross-Origin-Opener-Policy", string(opt.CrossOriginOpenerPolicy)) | ||||
| 			} | ||||
|  | ||||
| 			// Cross-Origin-Resource-Policy, if nil set default | ||||
| 			// Resource-Policy | ||||
| 			if opt.CrossOriginResourcePolicy == "" { | ||||
| 				w.Header().Add("Cross-Origin-Resource-Policy", string(ResourceSameOrigin)) | ||||
| 			} else { | ||||
|   | ||||
							
								
								
									
										38
									
								
								router.go
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								router.go
									
									
									
									
									
								
							| @@ -30,55 +30,65 @@ func (r *Router) Use(h ...func(http.Handler) http.Handler) { | ||||
|  | ||||
| // GET method route | ||||
| func (r *Router) GET(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodGet, pattern, h) | ||||
| 	r.handle(http.MethodGet, pattern, h) | ||||
| } | ||||
|  | ||||
| // HEAD method route | ||||
| func (r *Router) HEAD(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodHead, pattern, h) | ||||
| 	r.handle(http.MethodHead, pattern, h) | ||||
| } | ||||
|  | ||||
| // POST method route | ||||
| func (r *Router) POST(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodPost, pattern, h) | ||||
| 	r.handle(http.MethodPost, pattern, h) | ||||
| } | ||||
|  | ||||
| // PUT method route | ||||
| func (r *Router) PUT(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodPut, pattern, h) | ||||
| 	r.handle(http.MethodPut, pattern, h) | ||||
| } | ||||
|  | ||||
| // PATCH method route | ||||
| func (r *Router) PATCH(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodPatch, pattern, h) | ||||
| 	r.handle(http.MethodPatch, pattern, h) | ||||
| } | ||||
|  | ||||
| // DELETE method route | ||||
| func (r *Router) DELETE(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodDelete, pattern, h) | ||||
| 	r.handle(http.MethodDelete, pattern, h) | ||||
| } | ||||
|  | ||||
| // CONNECT method route | ||||
| func (r *Router) CONNECT(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodConnect, pattern, h) | ||||
| } // OPTIONS method route | ||||
| 	r.handle(http.MethodConnect, pattern, h) | ||||
| } | ||||
|  | ||||
| // OPTIONS method route | ||||
| func (r *Router) OPTIONS(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodOptions, pattern, h) | ||||
| 	r.handle(http.MethodOptions, pattern, h) | ||||
| } | ||||
|  | ||||
| // TRACE method route | ||||
| func (r *Router) TRACE(pattern string, h http.HandlerFunc) { | ||||
| 	r.handlerFunc(http.MethodTrace, pattern, h) | ||||
| 	r.handle(http.MethodTrace, pattern, h) | ||||
| } | ||||
|  | ||||
| // HandleFunc registers the handler function for the given pattern. | ||||
| // handle registers the handler for the given pattern. | ||||
| // If the given pattern conflicts, with one that is already registered, HandleFunc | ||||
| // panics. | ||||
| func (r *Router) handlerFunc(method, pattern string, h http.HandlerFunc) { | ||||
| func (r *Router) handle(method, pattern string, h http.HandlerFunc) { | ||||
| 	if r == nil { | ||||
| 		panic("mux: func Handle() was called on nil") | ||||
| 	} | ||||
|  | ||||
| 	if strings.TrimSpace(pattern) == "" { | ||||
| 		panic("mux: pattern cannot be empty") | ||||
| 	} | ||||
|  | ||||
| 	if !strings.HasPrefix(pattern, "/") { | ||||
| 		pattern = "/" + pattern | ||||
| 	} | ||||
|  | ||||
| 	path := fmt.Sprintf("%s %s", method, pattern) | ||||
| 	r.mux.Handle(path, stack(r.middlewares, h)) | ||||
| } | ||||
| @@ -101,7 +111,7 @@ func (r *Router) With(middleware ...func(http.Handler) http.Handler) *Router { | ||||
| // path, with a fresh middleware stack for the inline-Router. | ||||
| func (r *Router) Group(fn func(grp *Router)) { | ||||
| 	if r == nil { | ||||
| 		panic("mux: Resource() called on nil") | ||||
| 		panic("mux: Group() called on nil") | ||||
| 	} | ||||
|  | ||||
| 	if fn == nil { | ||||
| @@ -143,7 +153,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { | ||||
| 	r.mux.ServeHTTP(w, req) | ||||
| } | ||||
|  | ||||
| // TODO: proxy for aws lambda | ||||
| // TODO: proxy for aws lambda and other serverless platforms | ||||
|  | ||||
| // stack middlewares(http handler) in order they are passed (FIFO) | ||||
| func stack(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { | ||||
|   | ||||
							
								
								
									
										330
									
								
								router_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										330
									
								
								router_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,330 @@ | ||||
| package mux | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestRouterGET(t *testing.T) { | ||||
| 	r := NewRouter() | ||||
| 	r.GET("/test", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		fmt.Fprint(w, "GET test") | ||||
| 	}) | ||||
|  | ||||
| 	ts := httptest.NewServer(r) | ||||
| 	defer ts.Close() | ||||
|  | ||||
| 	resp, err := http.Get(ts.URL + "/test") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error making GET request: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		t.Errorf("Expected status OK; got %v", resp.Status) | ||||
| 	} | ||||
|  | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error reading response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	expected := "GET test" | ||||
| 	if string(body) != expected { | ||||
| 		t.Errorf("Expected body %q; got %q", expected, string(body)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRouterPOST(t *testing.T) { | ||||
| 	r := NewRouter() | ||||
| 	r.POST("/test", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		fmt.Fprint(w, "POST test") | ||||
| 	}) | ||||
|  | ||||
| 	ts := httptest.NewServer(r) | ||||
| 	defer ts.Close() | ||||
|  | ||||
| 	resp, err := http.Post(ts.URL+"/test", "text/plain", strings.NewReader("test data")) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error making POST request: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		t.Errorf("Expected status OK; got %v", resp.Status) | ||||
| 	} | ||||
|  | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error reading response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	expected := "POST test" | ||||
| 	if string(body) != expected { | ||||
| 		t.Errorf("Expected body %q; got %q", expected, string(body)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRouterWith(t *testing.T) { | ||||
| 	r := NewRouter() | ||||
|  | ||||
| 	middleware := func(h http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			w.Header().Set("X-Test", "middleware") | ||||
| 			h.ServeHTTP(w, r) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	r.With(middleware).GET("/test", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		fmt.Fprint(w, "GET with middleware") | ||||
| 	}) | ||||
|  | ||||
| 	ts := httptest.NewServer(r) | ||||
| 	defer ts.Close() | ||||
|  | ||||
| 	resp, err := http.Get(ts.URL + "/test") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error making GET request: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	if resp.Header.Get("X-Test") != "middleware" { | ||||
| 		t.Errorf("Expected header X-Test to be 'middleware'; got %q", resp.Header.Get("X-Test")) | ||||
| 	} | ||||
|  | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error reading response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	expected := "GET with middleware" | ||||
| 	if string(body) != expected { | ||||
| 		t.Errorf("Expected body %q; got %q", expected, string(body)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRouterGroup(t *testing.T) { | ||||
| 	r := NewRouter() | ||||
|  | ||||
| 	var groupCalled bool | ||||
|  | ||||
| 	r.Group(func(g *Router) { | ||||
| 		groupCalled = true | ||||
| 		g.GET("/group", func(w http.ResponseWriter, r *http.Request) { | ||||
| 			fmt.Fprint(w, "Group route") | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	if !groupCalled { | ||||
| 		t.Error("Expected Group callback to be called") | ||||
| 	} | ||||
|  | ||||
| 	ts := httptest.NewServer(r) | ||||
| 	defer ts.Close() | ||||
|  | ||||
| 	resp, err := http.Get(ts.URL + "/group") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error making GET request: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		t.Errorf("Expected status OK; got %v", resp.Status) | ||||
| 	} | ||||
|  | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error reading response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	expected := "Group route" | ||||
| 	if string(body) != expected { | ||||
| 		t.Errorf("Expected body %q; got %q", expected, string(body)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRouterResource(t *testing.T) { | ||||
| 	r := NewRouter() | ||||
|  | ||||
| 	r.Resource("/users", func(resource *Resource) { | ||||
| 		resource.Index(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			fmt.Fprint(w, "All users") | ||||
| 		}) | ||||
|  | ||||
| 		resource.Show(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			id := r.PathValue("id") | ||||
| 			fmt.Fprintf(w, "User %s", id) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	ts := httptest.NewServer(r) | ||||
| 	defer ts.Close() | ||||
|  | ||||
| 	// Test Index | ||||
| 	resp, err := http.Get(ts.URL + "/users") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error making GET request: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error reading response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	expected := "All users" | ||||
| 	if string(body) != expected { | ||||
| 		t.Errorf("Expected body %q; got %q", expected, string(body)) | ||||
| 	} | ||||
|  | ||||
| 	// Test Show | ||||
| 	resp, err = http.Get(ts.URL + "/users/123") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error making GET request: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	body, err = io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error reading response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	expected = "User 123" | ||||
| 	if string(body) != expected { | ||||
| 		t.Errorf("Expected body %q; got %q", expected, string(body)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestGetParams(t *testing.T) { | ||||
| 	r := NewRouter() | ||||
|  | ||||
| 	r.GET("/users/{id}/posts/{post_id}", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		userId := r.PathValue("id") | ||||
| 		postId := r.PathValue("post_id") | ||||
| 		fmt.Fprintf(w, "User: %s, Post: %s", userId, postId) | ||||
| 	}) | ||||
|  | ||||
| 	ts := httptest.NewServer(r) | ||||
| 	defer ts.Close() | ||||
|  | ||||
| 	resp, err := http.Get(ts.URL + "/users/123/posts/456") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error making GET request: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error reading response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	expected := "User: 123, Post: 456" | ||||
| 	if string(body) != expected { | ||||
| 		t.Errorf("Expected body %q; got %q", expected, string(body)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestStack(t *testing.T) { | ||||
| 	var calls []string | ||||
|  | ||||
| 	middleware1 := func(h http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			calls = append(calls, "middleware1 before") | ||||
| 			h.ServeHTTP(w, r) | ||||
| 			calls = append(calls, "middleware1 after") | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	middleware2 := func(h http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			calls = append(calls, "middleware2 before") | ||||
| 			h.ServeHTTP(w, r) | ||||
| 			calls = append(calls, "middleware2 after") | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		calls = append(calls, "handler") | ||||
| 	}) | ||||
|  | ||||
| 	middlewares := []func(http.Handler) http.Handler{middleware1, middleware2} | ||||
| 	stacked := stack(middlewares, handler) | ||||
|  | ||||
| 	w := httptest.NewRecorder() | ||||
| 	r := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
|  | ||||
| 	stacked.ServeHTTP(w, r) | ||||
|  | ||||
| 	expected := []string{ | ||||
| 		"middleware1 before", | ||||
| 		"middleware2 before", | ||||
| 		"handler", | ||||
| 		"middleware2 after", | ||||
| 		"middleware1 after", | ||||
| 	} | ||||
|  | ||||
| 	if len(calls) != len(expected) { | ||||
| 		t.Errorf("Expected %d calls; got %d", len(expected), len(calls)) | ||||
| 	} | ||||
|  | ||||
| 	for i, call := range calls { | ||||
| 		if i < len(expected) && call != expected[i] { | ||||
| 			t.Errorf("Expected call %d to be %q; got %q", i, expected[i], call) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRouterPanic(t *testing.T) { | ||||
| 	defer func() { | ||||
| 		if r := recover(); r == nil { | ||||
| 			t.Error("Expected panic but none occurred") | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	var r *Router | ||||
| 	r.GET("/", func(w http.ResponseWriter, r *http.Request) {}) | ||||
| } | ||||
|  | ||||
| func BenchmarkRouterSimple(b *testing.B) { | ||||
| 	r := NewRouter() | ||||
|  | ||||
| 	r.GET("/", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		fmt.Fprint(w, "Hello") | ||||
| 	}) | ||||
|  | ||||
| 	req, _ := http.NewRequest(http.MethodGet, "/", nil) | ||||
| 	w := httptest.NewRecorder() | ||||
|  | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		r.ServeHTTP(w, req) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkRouterWithMiddleware(b *testing.B) { | ||||
| 	r := NewRouter() | ||||
|  | ||||
| 	middleware := func(h http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			h.ServeHTTP(w, r) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	r.Use(middleware, middleware) | ||||
|  | ||||
| 	r.GET("/", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		fmt.Fprint(w, "Hello") | ||||
| 	}) | ||||
|  | ||||
| 	req, _ := http.NewRequest(http.MethodGet, "/", nil) | ||||
| 	w := httptest.NewRecorder() | ||||
|  | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		r.ServeHTTP(w, req) | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user