package mux

import (
	"fmt"
	"io"
	"math/rand"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"
	"time"
)

// TestRouterGET checks that a handler registered with GET is reached by an
// HTTP GET to the same path and that its response body arrives unchanged.
func TestRouterGET(t *testing.T) {
	m := New()
	m.GET("/test", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "GET test")
	})

	ts := httptest.NewServer(m)
	defer ts.Close()

	resp, err := http.Get(ts.URL + "/test")
	if err != nil {
		t.Fatalf("Error making GET request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected status OK; got %v", resp.Status)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Error reading response body: %v", err)
	}

	expected := "GET test"
	if string(body) != expected {
		t.Errorf("Expected body %q; got %q", expected, string(body))
	}
}

// TestRouterPOST checks that a handler registered with POST is reached by an
// HTTP POST to the same path and that its response body arrives unchanged.
func TestRouterPOST(t *testing.T) {
	m := New()
	m.POST("/test", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "POST test")
	})

	ts := httptest.NewServer(m)
	defer ts.Close()

	resp, err := http.Post(ts.URL+"/test", "text/plain", strings.NewReader("test data"))
	if err != nil {
		t.Fatalf("Error making POST request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected status OK; got %v", resp.Status)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Error reading response body: %v", err)
	}

	expected := "POST test"
	if string(body) != expected {
		t.Errorf("Expected body %q; got %q", expected, string(body))
	}
}

// TestRouterWith checks that a middleware attached through With runs for the
// route registered on the returned value: the middleware's X-Test header and
// the handler's body must both be present in the response.
func TestRouterWith(t *testing.T) {
	m := New()

	middleware := func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("X-Test", "middleware")
			h.ServeHTTP(w, r)
		})
	}

	m.With(middleware).GET("/test", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "GET with middleware")
	})

	ts := httptest.NewServer(m)
	defer ts.Close()

	resp, err := http.Get(ts.URL + "/test")
	if err != nil {
		t.Fatalf("Error making GET request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected status OK; got %v", resp.Status)
	}

	if got := resp.Header.Get("X-Test"); got != "middleware" {
		t.Errorf("Expected header X-Test to be 'middleware'; got %q", got)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Error reading response body: %v", err)
	}

	expected := "GET with middleware"
	if string(body) != expected {
		t.Errorf("Expected body %q; got %q", expected, string(body))
	}
}

// TestRouterGroup checks that Group has invoked its callback by the time it
// returns and that a route registered inside the callback is served by the
// router the group was created from.
func TestRouterGroup(t *testing.T) {
	m := New()

	var groupCalled bool
	m.Group(func(g *Mux) {
		groupCalled = true
		g.GET("/group", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprint(w, "Group route")
		})
	})

	if !groupCalled {
		t.Error("Expected Group callback to be called")
	}

	ts := httptest.NewServer(m)
	defer ts.Close()

	resp, err := http.Get(ts.URL + "/group")
	if err != nil {
		t.Fatalf("Error making GET request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected status OK; got %v", resp.Status)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Error reading response body: %v", err)
	}

	expected := "Group route"
	if string(body) != expected {
		t.Errorf("Expected body %q; got %q", expected, string(body))
	}
}

// TestRouterResource checks the routes produced by Resource("/users", ...):
// the Index handler must answer GET /users and the View handler must answer
// GET /users/123 with the "id" path value set to "123".
func TestRouterResource(t *testing.T) {
	m := New()

	m.Resource("/users", func(res *Resource) {
		res.Index(func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprint(w, "All users")
		})
		res.View(func(w http.ResponseWriter, r *http.Request) {
			id := r.PathValue("id")
			fmt.Fprintf(w, "User %s", id)
		})
	})

	ts := httptest.NewServer(m)
	defer ts.Close()

	cases := []struct {
		name     string
		path     string
		expected string
	}{
		{name: "Index", path: "/users", expected: "All users"},
		{name: "View", path: "/users/123", expected: "User 123"},
	}

	for _, tc := range cases {
		// Each case runs in its own subtest so its response body is closed
		// as soon as the case finishes rather than at the end of the test.
		t.Run(tc.name, func(t *testing.T) {
			resp, err := http.Get(ts.URL + tc.path)
			if err != nil {
				t.Fatalf("Error making GET request: %v", err)
			}
			defer resp.Body.Close()

			if resp.StatusCode != http.StatusOK {
				t.Errorf("Expected status OK; got %v", resp.Status)
			}

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("Error reading response body: %v", err)
			}

			if string(body) != tc.expected {
				t.Errorf("Expected body %q; got %q", tc.expected, string(body))
			}
		})
	}
}

// TestGetParams checks that several named path segments in one pattern are
// each exposed to the handler through Request.PathValue.
func TestGetParams(t *testing.T) {
	m := New()

	m.GET("/users/{id}/posts/{post_id}", func(w http.ResponseWriter, r *http.Request) {
		userID := r.PathValue("id")
		postID := r.PathValue("post_id")
		fmt.Fprintf(w, "User: %s, Post: %s", userID, postID)
	})

	ts := httptest.NewServer(m)
	defer ts.Close()

	resp, err := http.Get(ts.URL + "/users/123/posts/456")
	if err != nil {
		t.Fatalf("Error making GET request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected status OK; got %v", resp.Status)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Error reading response body: %v", err)
	}

	expected := "User: 123, Post: 456"
	if string(body) != expected {
		t.Errorf("Expected body %q; got %q", expected, string(body))
	}
}

// TestStack checks the nesting order produced by stack: the first middleware
// in the slice must be the outermost wrapper, so its "before" step runs first
// and its "after" step runs last.
func TestStack(t *testing.T) {
	var calls []string

	middleware1 := func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			calls = append(calls, "middleware1 before")
			h.ServeHTTP(w, r)
			calls = append(calls, "middleware1 after")
		})
	}

	middleware2 := func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			calls = append(calls, "middleware2 before")
			h.ServeHTTP(w, r)
			calls = append(calls, "middleware2 after")
		})
	}

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		calls = append(calls, "handler")
	})

	middlewares := []func(http.Handler) http.Handler{middleware1, middleware2}
	stacked := stack(middlewares, handler)

	w := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	stacked.ServeHTTP(w, r)

	expected := []string{
		"middleware1 before",
		"middleware2 before",
		"handler",
		"middleware2 after",
		"middleware1 after",
	}

	if len(calls) != len(expected) {
		t.Errorf("Expected %d calls; got %d", len(expected), len(calls))
	}

	// Compare only the overlapping prefix; a length mismatch is already
	// reported above and must not cause an out-of-range index here.
	for i, call := range calls {
		if i < len(expected) && call != expected[i] {
			t.Errorf("Expected call %d to be %q; got %q", i, expected[i], call)
		}
	}
}

// TestRouterPanic checks that registering a route on a nil *Mux panics.
func TestRouterPanic(t *testing.T) {
	defer func() {
		if rec := recover(); rec == nil {
			t.Error("Expected panic but none occurred")
		}
	}()

	var m *Mux
	m.GET("/", func(w http.ResponseWriter, r *http.Request) {})
}

// BenchmarkRouterSimple measures dispatching a GET request to one randomly
// chosen route out of 10000 registered routes. Building the request and the
// response recorder is part of the timed loop.
//
// Recorded result:
//
//	BenchmarkRouterSimple-12    1125854    1058 ns/op    1568 B/op    17 allocs/op
func BenchmarkRouterSimple(b *testing.B) {
	m := New()

	for i := range 10000 {
		m.GET("/"+strconv.Itoa(i), func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprint(w, "Hello from "+strconv.Itoa(i))
		})
	}

	source := rand.NewSource(time.Now().UnixNano())
	rng := rand.New(source)

	// Pick a random route index between 0 and 9999 (inclusive).
	rn := rng.Intn(10000)

	for b.Loop() {
		req, err := http.NewRequest(http.MethodGet, "/"+strconv.Itoa(rn), nil)
		if err != nil {
			b.Fatalf("Error creating request: %v", err)
		}
		w := httptest.NewRecorder()
		m.ServeHTTP(w, req)
	}
}

// BenchmarkRouterWithMiddleware measures dispatching a GET request through
// two pass-through middlewares installed with Use. The request and the
// response recorder are created once and reused for every iteration.
//
// Recorded result (taken before the per-iteration body reset was added, when
// the recorder's buffer grew on every iteration):
//
//	BenchmarkRouterWithMiddleware-12    14761327    68.70 ns/op    18 B/op    0 allocs/op
func BenchmarkRouterWithMiddleware(b *testing.B) {
	m := New()

	middleware := func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h.ServeHTTP(w, r)
		})
	}

	m.Use(middleware, middleware)
	m.GET("/", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "Hello")
	})

	req, err := http.NewRequest(http.MethodGet, "/", nil)
	if err != nil {
		b.Fatalf("Error creating request: %v", err)
	}
	w := httptest.NewRecorder()

	for b.Loop() {
		// Drop the previous iteration's output so the shared recorder does
		// not accumulate "Hello" b.N times; Reset keeps the buffer capacity,
		// so the steady state stays allocation-free.
		w.Body.Reset()
		m.ServeHTTP(w, req)
	}
}