diff --git a/.gitignore b/.gitignore index 445d045..1c14916 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +.claude + # Profiling files .prof diff --git a/middleware/cors.go b/middleware/cors.go index 6e194cc..a615302 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -109,7 +109,6 @@ func CORS(opts CORSOption) func(http.Handler) http.Handler { ch.setAllowedMethods(opts.AllowedMethods) ch.setExposedHeaders(opts.ExposedHeaders) ch.setMaxAge(opts.MaxAge) - ch.maxAge = opts.MaxAge ch.allowCredentials = opts.AllowCredentials return ch diff --git a/middleware/helmet.go b/middleware/helmet.go index ce3ab6e..de120f7 100644 --- a/middleware/helmet.go +++ b/middleware/helmet.go @@ -56,7 +56,7 @@ type ( } TransportSecurity struct { - // Age in seconts + // Age in seconds MaxAge uint IncludeSubDomains bool Preload bool @@ -111,110 +111,126 @@ const ( // Helmet headers to secure server response func Helmet(opt HelmetOption) func(http.Handler) http.Handler { + // Precompute all static header values once at middleware creation time. + headers := buildHelmetHeaders(opt) + return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Security-Policy", opt.ContentSecurityPolicy.value()) - - // Opener-Policy - if opt.CrossOriginOpenerPolicy == "" { - w.Header().Add("Cross-Origin-Opener-Policy", string(OpenerSameOrigin)) - } else { - w.Header().Add("Cross-Origin-Opener-Policy", string(opt.CrossOriginOpenerPolicy)) + for _, kv := range headers { + w.Header().Add(kv.key, kv.value) } - - // Resource-Policy - if opt.CrossOriginResourcePolicy == "" { - w.Header().Add("Cross-Origin-Resource-Policy", string(ResourceSameOrigin)) - } else { - w.Header().Add("Cross-Origin-Resource-Policy", string(opt.CrossOriginResourcePolicy)) - } - - // Referrer-Policy - rpCount := len(opt.ReferrerPolicy) - if rpCount > 0 { - refP := make([]string, rpCount) - for i, r := range opt.ReferrerPolicy { - refP[i] = string(r) - } - w.Header().Add("Referrer-Policy", string(NoReferrer)) - } else { - // default no referer - w.Header().Add("Referrer-Policy", string(NoReferrer)) - } - - // Origin-Agent-Cluster - if opt.OriginAgentCluster { - w.Header().Add("Origin-Agent-Cluster", "?1") - } - - // Strict-Transport-Security - if opt.StrictTransportSecurity != nil { - var sb strings.Builder - - if opt.StrictTransportSecurity.MaxAge == 0 { - opt.StrictTransportSecurity.MaxAge = YearDuration - } - - sb.WriteString(fmt.Sprintf("max-age=%d", opt.StrictTransportSecurity.MaxAge)) - if opt.StrictTransportSecurity.IncludeSubDomains { - sb.WriteString("; includeSubDomains") - } - - if opt.StrictTransportSecurity.Preload { - sb.WriteString("; preload") - } - - w.Header().Add("Strict-Transport-Security", sb.String()) - } - - if !opt.DisableSniffMimeType { - // MIME types advertised in the Content-Current headers should be followed and not be changed - w.Header().Add("X-Content-Type-Options", "nosniff") - } - - if opt.DisableDNSPrefetch { - w.Header().Add("X-DNS-Prefetch-Control", "off") - } else { - w.Header().Add("X-DNS-Prefetch-Control", "on") - } - - if !opt.DisableXDownload { - // Instructs Internet Explorer not to open the file directly but to offer it for download first. - w.Header().Add("X-Download-Options", "noopen") - } - - // indicate whether a browser should be allowed to render a page in iframe | frame | embed | object - if opt.XFrameOption == "" { - w.Header().Add("X-Frame-Options", string(XFrameSameOrigin)) - } else { - w.Header().Add("X-Frame-Options", string(opt.XFrameOption)) - } - - if opt.CrossDomainPolicies == "" { - w.Header().Add("X-Permitted-Cross-Domain-Policies", string(CDPNone)) - } else { - w.Header().Add("X-Permitted-Cross-Domain-Policies", string(opt.CrossDomainPolicies)) - } - w.Header().Del("X-Powered-By") - - if opt.XssProtection { - // feature of IE, Chrome and Safari that stops pages from loading when they detect reflected - // cross-site scripting (XSS) attacks. - w.Header().Add("X-Xss-Protection", "1; mode=block") - } else { - // Following a decision by Google Chrome developers to disable Auditor, - // developers should be able to disable the auditor for older browsers and set it to 0. - // The X-XSS-PROTECTION header was found to have a multitude of issues, instead of helping the - // developers protect their application. - w.Header().Add("X-Xss-Protection", "0") - } - h.ServeHTTP(w, r) }) } } +type headerKV struct { + key string + value string +} + +func buildHelmetHeaders(opt HelmetOption) []headerKV { + var headers []headerKV + + add := func(key, value string) { + headers = append(headers, headerKV{key: key, value: value}) + } + + // Content-Security-Policy + add("Content-Security-Policy", opt.ContentSecurityPolicy.value()) + + // Cross-Origin-Opener-Policy + if opt.CrossOriginOpenerPolicy == "" { + add("Cross-Origin-Opener-Policy", string(OpenerSameOrigin)) + } else { + add("Cross-Origin-Opener-Policy", string(opt.CrossOriginOpenerPolicy)) + } + + // Cross-Origin-Resource-Policy + if opt.CrossOriginResourcePolicy == "" { + add("Cross-Origin-Resource-Policy", string(ResourceSameOrigin)) + } else { + add("Cross-Origin-Resource-Policy", string(opt.CrossOriginResourcePolicy)) + } + + // Referrer-Policy + if len(opt.ReferrerPolicy) > 0 { + refP := make([]string, len(opt.ReferrerPolicy)) + for i, r := range opt.ReferrerPolicy { + refP[i] = string(r) + } + add("Referrer-Policy", strings.Join(refP, ",")) + } else { + add("Referrer-Policy", string(NoReferrer)) + } + + // Origin-Agent-Cluster + if opt.OriginAgentCluster { + add("Origin-Agent-Cluster", "?1") + } + + // Strict-Transport-Security + if opt.StrictTransportSecurity != nil { + var sb strings.Builder + + maxAge := opt.StrictTransportSecurity.MaxAge + if maxAge == 0 { + maxAge = YearDuration + } + + sb.WriteString(fmt.Sprintf("max-age=%d", maxAge)) + if opt.StrictTransportSecurity.IncludeSubDomains { + sb.WriteString("; includeSubDomains") + } + if opt.StrictTransportSecurity.Preload { + sb.WriteString("; preload") + } + + add("Strict-Transport-Security", sb.String()) + } + + // X-Content-Type-Options + if !opt.DisableSniffMimeType { + add("X-Content-Type-Options", "nosniff") + } + + // X-DNS-Prefetch-Control + if opt.DisableDNSPrefetch { + add("X-DNS-Prefetch-Control", "off") + } else { + add("X-DNS-Prefetch-Control", "on") + } + + // X-Download-Options + if !opt.DisableXDownload { + add("X-Download-Options", "noopen") + } + + // X-Frame-Options + if opt.XFrameOption == "" { + add("X-Frame-Options", string(XFrameSameOrigin)) + } else { + add("X-Frame-Options", string(opt.XFrameOption)) + } + + // X-Permitted-Cross-Domain-Policies + if opt.CrossDomainPolicies == "" { + add("X-Permitted-Cross-Domain-Policies", string(CDPNone)) + } else { + add("X-Permitted-Cross-Domain-Policies", string(opt.CrossDomainPolicies)) + } + + // X-Xss-Protection + if opt.XssProtection { + add("X-Xss-Protection", "1; mode=block") + } else { + add("X-Xss-Protection", "0") + } + + return headers +} + func (csp *CSP) value() string { var sb strings.Builder diff --git a/middleware/request_id.go b/middleware/request_id.go index e1d4ccb..8f059be 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -64,6 +64,10 @@ func init() { // where "random" is a base62 random string that uniquely identifies this go // process, and where the last number is an atomically incremented request // counter. +// maxRequestIDLen is the maximum length of an incoming request ID header +// to prevent log injection or memory abuse from malicious clients. +const maxRequestIDLen = 200 + func RequestID(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -71,6 +75,8 @@ func RequestID(next http.Handler) http.Handler { if requestID == "" { myid := reqid.Add(1) requestID = fmt.Sprintf("%s-%06d", prefix, myid) + } else if len(requestID) > maxRequestIDLen { + requestID = requestID[:maxRequestIDLen] } ctx = context.WithValue(ctx, RequestIDKey, requestID) next.ServeHTTP(w, r.WithContext(ctx)) diff --git a/mux.go b/mux.go index e5a1f15..0df4b28 100644 --- a/mux.go +++ b/mux.go @@ -23,6 +23,22 @@ func New() *Mux { mux: http.NewServeMux(), routes: new(RouteList), } + + // Catch-all OPTIONS handler. + // Pass it through all middlewares and drain oversized bodies. + m.OPTIONS("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "0") + if r.ContentLength != 0 { + // Read up to 4KB of OPTIONS body (as mentioned in the + // spec as being reserved for future use), but anything + // over that is considered a waste of server resources + // (or an attack) and we abort and close the connection, + // courtesy of MaxBytesReader's EOF behavior. + mb := http.MaxBytesReader(w, r.Body, 4<<10) + _, _ = io.Copy(io.Discard, mb) + } + }) + return m } @@ -147,8 +163,8 @@ func (m *Mux) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (m *Mux) PrintRoutes(w io.Writer) { for _, route := range m.routes.All() { - w.Write([]byte(route)) - w.Write([]byte("\n")) + _, _ = w.Write([]byte(route)) + _, _ = w.Write([]byte("\n")) } } diff --git a/mux_test.go b/mux_test.go index 14fb071..05b2c16 100644 --- a/mux_test.go +++ b/mux_test.go @@ -306,10 +306,8 @@ func BenchmarkRouterSimple(b *testing.B) { source := rand.NewSource(time.Now().UnixNano()) r := rand.New(source) - // Generate a random integer between 0 and 99 (inclusive) - rn := r.Intn(10000) - for b.Loop() { + rn := r.Intn(10000) req, _ := http.NewRequest(http.MethodGet, "/"+strconv.Itoa(rn), nil) w := httptest.NewRecorder() m.ServeHTTP(w, req) diff --git a/resource.go b/resource.go index caecc87..9e116d0 100644 --- a/resource.go +++ b/resource.go @@ -27,7 +27,7 @@ func (m *Mux) Resource(pattern string, fn func(res *Resource), mw ...func(http.H } if strings.TrimSpace(pattern) == "" { - panic("mux: Resource() requires a patter to work") + panic("mux: Resource() requires a pattern to work") } if fn == nil { @@ -215,11 +215,9 @@ func (res *Resource) Use(middlewares ...func(http.Handler) http.Handler) { } func suffixIt(str, suffix string) string { - var p strings.Builder - p.WriteString(str) - if !strings.HasSuffix(str, "/") { - p.WriteString("/") + if strings.HasSuffix(str, "/") { + return str + suffix } - p.WriteString(suffix) - return p.String() + + return str + "/" + suffix } diff --git a/route.go b/route.go index 20b0347..5104612 100644 --- a/route.go +++ b/route.go @@ -33,7 +33,7 @@ func (s *RouteList) Get(index int) (string, error) { defer s.mu.RUnlock() if index < 0 || index >= len(s.routes) { - return "0", fmt.Errorf("index out of bounds") + return "", fmt.Errorf("index out of bounds") } return s.routes[index], nil } @@ -47,5 +47,7 @@ func (s *RouteList) All() []string { s.mu.RLock() defer s.mu.RUnlock() - return s.routes + out := make([]string, len(s.routes)) + copy(out, s.routes) + return out } diff --git a/serve.go b/serve.go index 372b8c2..8d3f672 100644 --- a/serve.go +++ b/serve.go @@ -3,7 +3,6 @@ package mux import ( "context" "errors" - "io" "log" "log/slog" "net" @@ -26,21 +25,6 @@ func (m *Mux) Serve(cb ServeCB) { rootCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - // catch all options - // lets get it thorugh all middlewares - m.OPTIONS("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "0") - if r.ContentLength != 0 { - // Read up to 4KB of OPTIONS body (as mentioned in the - // spec as being reserved for future use), but anything - // over that is considered a waste of server resources - // (or an attack) and we abort and close the connection, - // courtesy of MaxBytesReader's EOF behavior. - mb := http.MaxBytesReader(w, r.Body, 4<<10) - io.Copy(io.Discard, mb) - } - }) - srvCtx, cancelSrvCtx := context.WithCancel(context.Background()) srv := &http.Server{ Handler: m, @@ -51,7 +35,7 @@ func (m *Mux) Serve(cb ServeCB) { go func() { if err := cb(srv); !errors.Is(err, http.ErrServerClosed) { - panic(err) + log.Fatalf("server error: %v", err) } }() @@ -60,7 +44,7 @@ func (m *Mux) Serve(cb ServeCB) { stop() m.IsShuttingDown.Store(true) - slog.Info("received interrupt singal, shutting down") + slog.Info("received interrupt signal, shutting down") time.Sleep(drainDelay) slog.Info("readiness check propagated, now waiting for ongoing requests to finish.") @@ -74,5 +58,5 @@ func (m *Mux) Serve(cb ServeCB) { time.Sleep(shutdownHardDelay) } - slog.Info("seerver shut down gracefully") + slog.Info("server shut down gracefully") }