package mux import ( "fmt" "io" "net/http" "net/http/httptest" "strings" "testing" ) func TestRouterGET(t *testing.T) { r := NewRouter() r.GET("/test", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "GET test") }) ts := httptest.NewServer(r) defer ts.Close() resp, err := http.Get(ts.URL + "/test") if err != nil { t.Fatalf("Error making GET request: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("Expected status OK; got %v", resp.Status) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Error reading response body: %v", err) } expected := "GET test" if string(body) != expected { t.Errorf("Expected body %q; got %q", expected, string(body)) } } func TestRouterPOST(t *testing.T) { r := NewRouter() r.POST("/test", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "POST test") }) ts := httptest.NewServer(r) defer ts.Close() resp, err := http.Post(ts.URL+"/test", "text/plain", strings.NewReader("test data")) if err != nil { t.Fatalf("Error making POST request: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("Expected status OK; got %v", resp.Status) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Error reading response body: %v", err) } expected := "POST test" if string(body) != expected { t.Errorf("Expected body %q; got %q", expected, string(body)) } } func TestRouterWith(t *testing.T) { r := NewRouter() middleware := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Test", "middleware") h.ServeHTTP(w, r) }) } r.With(middleware).GET("/test", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "GET with middleware") }) ts := httptest.NewServer(r) defer ts.Close() resp, err := http.Get(ts.URL + "/test") if err != nil { t.Fatalf("Error making GET request: %v", err) } defer resp.Body.Close() if resp.Header.Get("X-Test") != "middleware" { t.Errorf("Expected header X-Test to be 'middleware'; got %q", resp.Header.Get("X-Test")) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Error reading response body: %v", err) } expected := "GET with middleware" if string(body) != expected { t.Errorf("Expected body %q; got %q", expected, string(body)) } } func TestRouterGroup(t *testing.T) { r := NewRouter() var groupCalled bool r.Group(func(g *Router) { groupCalled = true g.GET("/group", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Group route") }) }) if !groupCalled { t.Error("Expected Group callback to be called") } ts := httptest.NewServer(r) defer ts.Close() resp, err := http.Get(ts.URL + "/group") if err != nil { t.Fatalf("Error making GET request: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("Expected status OK; got %v", resp.Status) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Error reading response body: %v", err) } expected := "Group route" if string(body) != expected { t.Errorf("Expected body %q; got %q", expected, string(body)) } } func TestRouterResource(t *testing.T) { r := NewRouter() r.Resource("/users", func(resource *Resource) { resource.Index(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "All users") }) resource.Show(func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") fmt.Fprintf(w, "User %s", id) }) }) ts := httptest.NewServer(r) defer ts.Close() // Test Index resp, err := http.Get(ts.URL + "/users") if err != nil { t.Fatalf("Error making GET request: %v", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Error reading response body: %v", err) } expected := "All users" if string(body) != expected { t.Errorf("Expected body %q; got %q", expected, string(body)) } // Test Show resp, err = http.Get(ts.URL + "/users/123") if err != nil { t.Fatalf("Error making GET request: %v", err) } defer resp.Body.Close() body, err = io.ReadAll(resp.Body) if err != nil { t.Fatalf("Error reading response body: %v", err) } expected = "User 123" if string(body) != expected { t.Errorf("Expected body %q; got %q", expected, string(body)) } } func TestGetParams(t *testing.T) { r := NewRouter() r.GET("/users/{id}/posts/{post_id}", func(w http.ResponseWriter, r *http.Request) { userId := r.PathValue("id") postId := r.PathValue("post_id") fmt.Fprintf(w, "User: %s, Post: %s", userId, postId) }) ts := httptest.NewServer(r) defer ts.Close() resp, err := http.Get(ts.URL + "/users/123/posts/456") if err != nil { t.Fatalf("Error making GET request: %v", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Error reading response body: %v", err) } expected := "User: 123, Post: 456" if string(body) != expected { t.Errorf("Expected body %q; got %q", expected, string(body)) } } func TestStack(t *testing.T) { var calls []string middleware1 := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { calls = append(calls, "middleware1 before") h.ServeHTTP(w, r) calls = append(calls, "middleware1 after") }) } middleware2 := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { calls = append(calls, "middleware2 before") h.ServeHTTP(w, r) calls = append(calls, "middleware2 after") }) } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { calls = append(calls, "handler") }) middlewares := []func(http.Handler) http.Handler{middleware1, middleware2} stacked := stack(middlewares, handler) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) stacked.ServeHTTP(w, r) expected := []string{ "middleware1 before", "middleware2 before", "handler", "middleware2 after", "middleware1 after", } if len(calls) != len(expected) { t.Errorf("Expected %d calls; got %d", len(expected), len(calls)) } for i, call := range calls { if i < len(expected) && call != expected[i] { t.Errorf("Expected call %d to be %q; got %q", i, expected[i], call) } } } func TestRouterPanic(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected panic but none occurred") } }() var r *Router r.GET("/", func(w http.ResponseWriter, r *http.Request) {}) } func BenchmarkRouterSimple(b *testing.B) { r := NewRouter() r.GET("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello") }) req, _ := http.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() b.ResetTimer() for i := 0; i < b.N; i++ { r.ServeHTTP(w, req) } } func BenchmarkRouterWithMiddleware(b *testing.B) { r := NewRouter() middleware := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h.ServeHTTP(w, r) }) } r.Use(middleware, middleware) r.GET("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello") }) req, _ := http.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() b.ResetTimer() for i := 0; i < b.N; i++ { r.ServeHTTP(w, req) } }