3 Commits

Author SHA1 Message Date
f1c5b9587b claude code review changes 2026-02-20 17:05:34 +05:30
136957d75d use go version 1.26 2026-02-20 16:16:43 +05:30
43da615326 feat: add RealIP, RequestID, and RequestSize middleware 2025-11-20 23:13:26 +05:30
14 changed files with 324 additions and 133 deletions

2
.gitignore vendored
View File

@@ -1,3 +1,5 @@
.claude
# Profiling files # Profiling files
.prof .prof

2
go.mod
View File

@@ -1,3 +1,3 @@
module code.patial.tech/go/mux module code.patial.tech/go/mux
go 1.25 go 1.26

View File

@@ -1,4 +1,4 @@
go 1.25.1 go 1.26
use . use .

View File

@@ -109,7 +109,6 @@ func CORS(opts CORSOption) func(http.Handler) http.Handler {
ch.setAllowedMethods(opts.AllowedMethods) ch.setAllowedMethods(opts.AllowedMethods)
ch.setExposedHeaders(opts.ExposedHeaders) ch.setExposedHeaders(opts.ExposedHeaders)
ch.setMaxAge(opts.MaxAge) ch.setMaxAge(opts.MaxAge)
ch.maxAge = opts.MaxAge
ch.allowCredentials = opts.AllowCredentials ch.allowCredentials = opts.AllowCredentials
return ch return ch

View File

@@ -56,7 +56,7 @@ type (
} }
TransportSecurity struct { TransportSecurity struct {
// Age in seconts // Age in seconds
MaxAge uint MaxAge uint
IncludeSubDomains bool IncludeSubDomains bool
Preload bool Preload bool
@@ -111,108 +111,124 @@ const (
// Helmet headers to secure server response // Helmet headers to secure server response
func Helmet(opt HelmetOption) func(http.Handler) http.Handler { func Helmet(opt HelmetOption) func(http.Handler) http.Handler {
// Precompute all static header values once at middleware creation time.
headers := buildHelmetHeaders(opt)
return func(h http.Handler) http.Handler { return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Security-Policy", opt.ContentSecurityPolicy.value()) for _, kv := range headers {
w.Header().Add(kv.key, kv.value)
}
w.Header().Del("X-Powered-By")
h.ServeHTTP(w, r)
})
}
}
// Opener-Policy type headerKV struct {
if opt.CrossOriginOpenerPolicy == "" { key string
w.Header().Add("Cross-Origin-Opener-Policy", string(OpenerSameOrigin)) value string
} else { }
w.Header().Add("Cross-Origin-Opener-Policy", string(opt.CrossOriginOpenerPolicy))
func buildHelmetHeaders(opt HelmetOption) []headerKV {
var headers []headerKV
add := func(key, value string) {
headers = append(headers, headerKV{key: key, value: value})
} }
// Resource-Policy // Content-Security-Policy
if opt.CrossOriginResourcePolicy == "" { add("Content-Security-Policy", opt.ContentSecurityPolicy.value())
w.Header().Add("Cross-Origin-Resource-Policy", string(ResourceSameOrigin))
// Cross-Origin-Opener-Policy
if opt.CrossOriginOpenerPolicy == "" {
add("Cross-Origin-Opener-Policy", string(OpenerSameOrigin))
} else { } else {
w.Header().Add("Cross-Origin-Resource-Policy", string(opt.CrossOriginResourcePolicy)) add("Cross-Origin-Opener-Policy", string(opt.CrossOriginOpenerPolicy))
}
// Cross-Origin-Resource-Policy
if opt.CrossOriginResourcePolicy == "" {
add("Cross-Origin-Resource-Policy", string(ResourceSameOrigin))
} else {
add("Cross-Origin-Resource-Policy", string(opt.CrossOriginResourcePolicy))
} }
// Referrer-Policy // Referrer-Policy
rpCount := len(opt.ReferrerPolicy) if len(opt.ReferrerPolicy) > 0 {
if rpCount > 0 { refP := make([]string, len(opt.ReferrerPolicy))
refP := make([]string, rpCount)
for i, r := range opt.ReferrerPolicy { for i, r := range opt.ReferrerPolicy {
refP[i] = string(r) refP[i] = string(r)
} }
w.Header().Add("Referrer-Policy", string(NoReferrer)) add("Referrer-Policy", strings.Join(refP, ","))
} else { } else {
// default no referer add("Referrer-Policy", string(NoReferrer))
w.Header().Add("Referrer-Policy", string(NoReferrer))
} }
// Origin-Agent-Cluster // Origin-Agent-Cluster
if opt.OriginAgentCluster { if opt.OriginAgentCluster {
w.Header().Add("Origin-Agent-Cluster", "?1") add("Origin-Agent-Cluster", "?1")
} }
// Strict-Transport-Security // Strict-Transport-Security
if opt.StrictTransportSecurity != nil { if opt.StrictTransportSecurity != nil {
var sb strings.Builder var sb strings.Builder
if opt.StrictTransportSecurity.MaxAge == 0 { maxAge := opt.StrictTransportSecurity.MaxAge
opt.StrictTransportSecurity.MaxAge = YearDuration if maxAge == 0 {
maxAge = YearDuration
} }
sb.WriteString(fmt.Sprintf("max-age=%d", opt.StrictTransportSecurity.MaxAge)) sb.WriteString(fmt.Sprintf("max-age=%d", maxAge))
if opt.StrictTransportSecurity.IncludeSubDomains { if opt.StrictTransportSecurity.IncludeSubDomains {
sb.WriteString("; includeSubDomains") sb.WriteString("; includeSubDomains")
} }
if opt.StrictTransportSecurity.Preload { if opt.StrictTransportSecurity.Preload {
sb.WriteString("; preload") sb.WriteString("; preload")
} }
w.Header().Add("Strict-Transport-Security", sb.String()) add("Strict-Transport-Security", sb.String())
} }
// X-Content-Type-Options
if !opt.DisableSniffMimeType { if !opt.DisableSniffMimeType {
// MIME types advertised in the Content-Current headers should be followed and not be changed add("X-Content-Type-Options", "nosniff")
w.Header().Add("X-Content-Type-Options", "nosniff")
} }
// X-DNS-Prefetch-Control
if opt.DisableDNSPrefetch { if opt.DisableDNSPrefetch {
w.Header().Add("X-DNS-Prefetch-Control", "off") add("X-DNS-Prefetch-Control", "off")
} else { } else {
w.Header().Add("X-DNS-Prefetch-Control", "on") add("X-DNS-Prefetch-Control", "on")
} }
// X-Download-Options
if !opt.DisableXDownload { if !opt.DisableXDownload {
// Instructs Internet Explorer not to open the file directly but to offer it for download first. add("X-Download-Options", "noopen")
w.Header().Add("X-Download-Options", "noopen")
} }
// indicate whether a browser should be allowed to render a page in iframe | frame | embed | object // X-Frame-Options
if opt.XFrameOption == "" { if opt.XFrameOption == "" {
w.Header().Add("X-Frame-Options", string(XFrameSameOrigin)) add("X-Frame-Options", string(XFrameSameOrigin))
} else { } else {
w.Header().Add("X-Frame-Options", string(opt.XFrameOption)) add("X-Frame-Options", string(opt.XFrameOption))
} }
// X-Permitted-Cross-Domain-Policies
if opt.CrossDomainPolicies == "" { if opt.CrossDomainPolicies == "" {
w.Header().Add("X-Permitted-Cross-Domain-Policies", string(CDPNone)) add("X-Permitted-Cross-Domain-Policies", string(CDPNone))
} else { } else {
w.Header().Add("X-Permitted-Cross-Domain-Policies", string(opt.CrossDomainPolicies)) add("X-Permitted-Cross-Domain-Policies", string(opt.CrossDomainPolicies))
} }
w.Header().Del("X-Powered-By") // X-Xss-Protection
if opt.XssProtection { if opt.XssProtection {
// feature of IE, Chrome and Safari that stops pages from loading when they detect reflected add("X-Xss-Protection", "1; mode=block")
// cross-site scripting (XSS) attacks.
w.Header().Add("X-Xss-Protection", "1; mode=block")
} else { } else {
// Following a decision by Google Chrome developers to disable Auditor, add("X-Xss-Protection", "0")
// developers should be able to disable the auditor for older browsers and set it to 0.
// The X-XSS-PROTECTION header was found to have a multitude of issues, instead of helping the
// developers protect their application.
w.Header().Add("X-Xss-Protection", "0")
} }
h.ServeHTTP(w, r) return headers
})
}
} }
func (csp *CSP) value() string { func (csp *CSP) value() string {

56
middleware/real_ip.go Normal file
View File

@@ -0,0 +1,56 @@
package middleware
// Ported from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"net"
"net/http"
"strings"
)
var trueClientIP = http.CanonicalHeaderKey("True-Client-IP")
var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
var xRealIP = http.CanonicalHeaderKey("X-Real-IP")
// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers
// (in that order).
//
// This middleware should be inserted fairly early in the middleware stack to
// ensure that subsequent layers (e.g., request loggers) which examine the
// RemoteAddr will see the intended value.
//
// You should only use this middleware if you can trust the headers passed to
// you (in particular, the three headers this middleware uses), for example
// because you have placed a reverse proxy like HAProxy or nginx in front of
// chi. If your reverse proxies are configured to pass along arbitrary header
// values from the client, or if you use this middleware without a reverse
// proxy, malicious clients will be able to make you very sad (or, depending on
// how you're using RemoteAddr, vulnerable to an attack of some sort).
func RealIP(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip := realIP(r); rip != "" {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
func realIP(r *http.Request) string {
var ip string
if tcip := r.Header.Get(trueClientIP); tcip != "" {
ip = tcip
} else if xrip := r.Header.Get(xRealIP); xrip != "" {
ip = xrip
} else if xff := r.Header.Get(xForwardedFor); xff != "" {
ip, _, _ = strings.Cut(xff, ",")
}
if ip == "" || net.ParseIP(ip) == nil {
return ""
}
return ip
}

102
middleware/request_id.go Normal file
View File

@@ -0,0 +1,102 @@
package middleware
// Ported from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"os"
"strings"
"sync/atomic"
)
// Key to use when setting the request ID.
type ctxKeyRequestID int
// RequestIDKey is the key that holds the unique request ID in a request context.
const RequestIDKey ctxKeyRequestID = 0
// RequestIDHeader is the name of the HTTP Header which contains the request id.
// Exported so that it can be changed by developers
var RequestIDHeader = "X-Request-Id"
var prefix string
var reqid atomic.Uint64
// A quick note on the statistics here: we're trying to calculate the chance that
// two randomly generated base62 prefixes will collide. We use the formula from
// http://en.wikipedia.org/wiki/Birthday_problem
//
// P[m, n] \approx 1 - e^{-m^2/2n}
//
// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server
// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$
//
// For a $k$ character base-62 identifier, we have $n(k) = 62^k$
//
// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for
// our purposes, and is surely more than anyone would ever need in practice -- a
// process that is rebooted a handful of times a day for a hundred years has less
// than a millionth of a percent chance of generating two colliding IDs.
func init() {
hostname, err := os.Hostname()
if hostname == "" || err != nil {
hostname = "localhost"
}
var buf [12]byte
var b64 string
for len(b64) < 10 {
rand.Read(buf[:])
b64 = base64.StdEncoding.EncodeToString(buf[:])
b64 = strings.NewReplacer("+", "", "/", "").Replace(b64)
}
prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10])
}
// RequestID is a middleware that injects a request ID into the context of each
// request. A request ID is a string of the form "host.example.com/random-0001",
// where "random" is a base62 random string that uniquely identifies this go
// process, and where the last number is an atomically incremented request
// counter.
// maxRequestIDLen is the maximum length of an incoming request ID header
// to prevent log injection or memory abuse from malicious clients.
const maxRequestIDLen = 200
func RequestID(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := r.Header.Get(RequestIDHeader)
if requestID == "" {
myid := reqid.Add(1)
requestID = fmt.Sprintf("%s-%06d", prefix, myid)
} else if len(requestID) > maxRequestIDLen {
requestID = requestID[:maxRequestIDLen]
}
ctx = context.WithValue(ctx, RequestIDKey, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
// GetReqID returns a request ID from the given context if one is present.
// Returns the empty string if a request ID cannot be found.
func GetReqID(ctx context.Context) string {
if ctx == nil {
return ""
}
if reqID, ok := ctx.Value(RequestIDKey).(string); ok {
return reqID
}
return ""
}
// NextRequestID generates the next request ID in the sequence.
func NextRequestID() uint64 {
return reqid.Add(1)
}

View File

@@ -0,0 +1,18 @@
package middleware
import (
"net/http"
)
// RequestSize is a middleware that will limit request sizes to a specified
// number of bytes. It uses MaxBytesReader to do so.
func RequestSize(bytes int64) func(http.Handler) http.Handler {
f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, bytes)
h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
return f
}

20
mux.go
View File

@@ -23,6 +23,22 @@ func New() *Mux {
mux: http.NewServeMux(), mux: http.NewServeMux(),
routes: new(RouteList), routes: new(RouteList),
} }
// Catch-all OPTIONS handler.
// Pass it through all middlewares and drain oversized bodies.
m.OPTIONS("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "0")
if r.ContentLength != 0 {
// Read up to 4KB of OPTIONS body (as mentioned in the
// spec as being reserved for future use), but anything
// over that is considered a waste of server resources
// (or an attack) and we abort and close the connection,
// courtesy of MaxBytesReader's EOF behavior.
mb := http.MaxBytesReader(w, r.Body, 4<<10)
_, _ = io.Copy(io.Discard, mb)
}
})
return m return m
} }
@@ -147,8 +163,8 @@ func (m *Mux) ServeHTTP(w http.ResponseWriter, req *http.Request) {
func (m *Mux) PrintRoutes(w io.Writer) { func (m *Mux) PrintRoutes(w io.Writer) {
for _, route := range m.routes.All() { for _, route := range m.routes.All() {
w.Write([]byte(route)) _, _ = w.Write([]byte(route))
w.Write([]byte("\n")) _, _ = w.Write([]byte("\n"))
} }
} }

View File

@@ -306,10 +306,8 @@ func BenchmarkRouterSimple(b *testing.B) {
source := rand.NewSource(time.Now().UnixNano()) source := rand.NewSource(time.Now().UnixNano())
r := rand.New(source) r := rand.New(source)
// Generate a random integer between 0 and 99 (inclusive)
rn := r.Intn(10000)
for b.Loop() { for b.Loop() {
rn := r.Intn(10000)
req, _ := http.NewRequest(http.MethodGet, "/"+strconv.Itoa(rn), nil) req, _ := http.NewRequest(http.MethodGet, "/"+strconv.Itoa(rn), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
m.ServeHTTP(w, req) m.ServeHTTP(w, req)

View File

@@ -1,6 +1,6 @@
module code.patial.tech/go/mux/playground module code.patial.tech/go/mux/playground
go 1.24 go 1.26
require ( require (
code.patial.tech/go/mux v0.7.1 code.patial.tech/go/mux v0.7.1

View File

@@ -27,7 +27,7 @@ func (m *Mux) Resource(pattern string, fn func(res *Resource), mw ...func(http.H
} }
if strings.TrimSpace(pattern) == "" { if strings.TrimSpace(pattern) == "" {
panic("mux: Resource() requires a patter to work") panic("mux: Resource() requires a pattern to work")
} }
if fn == nil { if fn == nil {
@@ -215,11 +215,9 @@ func (res *Resource) Use(middlewares ...func(http.Handler) http.Handler) {
} }
func suffixIt(str, suffix string) string { func suffixIt(str, suffix string) string {
var p strings.Builder if strings.HasSuffix(str, "/") {
p.WriteString(str) return str + suffix
if !strings.HasSuffix(str, "/") {
p.WriteString("/")
} }
p.WriteString(suffix)
return p.String() return str + "/" + suffix
} }

View File

@@ -33,7 +33,7 @@ func (s *RouteList) Get(index int) (string, error) {
defer s.mu.RUnlock() defer s.mu.RUnlock()
if index < 0 || index >= len(s.routes) { if index < 0 || index >= len(s.routes) {
return "0", fmt.Errorf("index out of bounds") return "", fmt.Errorf("index out of bounds")
} }
return s.routes[index], nil return s.routes[index], nil
} }
@@ -47,5 +47,7 @@ func (s *RouteList) All() []string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.routes out := make([]string, len(s.routes))
copy(out, s.routes)
return out
} }

View File

@@ -3,7 +3,6 @@ package mux
import ( import (
"context" "context"
"errors" "errors"
"io"
"log" "log"
"log/slog" "log/slog"
"net" "net"
@@ -26,21 +25,6 @@ func (m *Mux) Serve(cb ServeCB) {
rootCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) rootCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop() defer stop()
// catch all options
// lets get it thorugh all middlewares
m.OPTIONS("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "0")
if r.ContentLength != 0 {
// Read up to 4KB of OPTIONS body (as mentioned in the
// spec as being reserved for future use), but anything
// over that is considered a waste of server resources
// (or an attack) and we abort and close the connection,
// courtesy of MaxBytesReader's EOF behavior.
mb := http.MaxBytesReader(w, r.Body, 4<<10)
io.Copy(io.Discard, mb)
}
})
srvCtx, cancelSrvCtx := context.WithCancel(context.Background()) srvCtx, cancelSrvCtx := context.WithCancel(context.Background())
srv := &http.Server{ srv := &http.Server{
Handler: m, Handler: m,
@@ -51,7 +35,7 @@ func (m *Mux) Serve(cb ServeCB) {
go func() { go func() {
if err := cb(srv); !errors.Is(err, http.ErrServerClosed) { if err := cb(srv); !errors.Is(err, http.ErrServerClosed) {
panic(err) log.Fatalf("server error: %v", err)
} }
}() }()
@@ -60,7 +44,7 @@ func (m *Mux) Serve(cb ServeCB) {
stop() stop()
m.IsShuttingDown.Store(true) m.IsShuttingDown.Store(true)
slog.Info("received interrupt singal, shutting down") slog.Info("received interrupt signal, shutting down")
time.Sleep(drainDelay) time.Sleep(drainDelay)
slog.Info("readiness check propagated, now waiting for ongoing requests to finish.") slog.Info("readiness check propagated, now waiting for ongoing requests to finish.")
@@ -74,5 +58,5 @@ func (m *Mux) Serve(cb ServeCB) {
time.Sleep(shutdownHardDelay) time.Sleep(shutdownHardDelay)
} }
slog.Info("seerver shut down gracefully") slog.Info("server shut down gracefully")
} }