5 Commits

16 changed files with 753 additions and 137 deletions

2
.gitignore vendored
View File

@@ -1,3 +1,5 @@
.claude
# Profiling files # Profiling files
.prof .prof

View File

@@ -7,25 +7,42 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
No unreleased changes.
## [1.0.0] - 2024-12-19
### Changed ### Changed
- Renamed `HandleGET`, `HandlePOST`, `HandlePUT`, `HandlePATCH`, `HandleDELETE` to `MemberGET`, `MemberPOST`, `MemberPUT`, `MemberPATCH`, `MemberDELETE` for better clarity - **BREAKING**: Renamed `HandleGET`, `HandlePOST`, `HandlePUT`, `HandlePATCH`, `HandleDELETE` to `MemberGET`, `MemberPOST`, `MemberPUT`, `MemberPATCH`, `MemberDELETE` for better clarity
- Member routes now explicitly operate on `/pattern/{id}/action` endpoints - Member routes now explicitly operate on `/pattern/{id}/action` endpoints
- Optimized struct field alignment for better memory usage
### Added ### Added
- Collection-level custom route methods: `GET`, `POST`, `PUT`, `PATCH`, `DELETE` for `/pattern/action` endpoints - Collection-level custom route methods: `GET`, `POST`, `PUT`, `PATCH`, `DELETE` for `/pattern/action` endpoints
- Comprehensive README with detailed examples and usage patterns - Comprehensive README with detailed examples and usage patterns
- CONTRIBUTING.md with code quality standards and guidelines - CONTRIBUTING.md with code quality standards and guidelines
- QUICKSTART.md for new users
- DOCS.md as documentation index
- SUMMARY.md documenting all changes
- `.cursorrules` file for AI coding assistants - `.cursorrules` file for AI coding assistants
- GitHub Actions CI/CD workflow - GitHub Actions CI/CD workflow
- Makefile for common development tasks - Makefile for common development tasks
- `check.sh` script for running all quality checks
- golangci-lint configuration - golangci-lint configuration
- Field alignment requirements and checks - Field alignment requirements and checks
- Go 1.25+ requirement enforcement
### Documentation ### Documentation
- Improved README with table of contents and comprehensive examples - Improved README with table of contents and comprehensive examples
- Added distinction between collection and member routes in Resource documentation - Added distinction between collection and member routes in Resource documentation
- Added performance and testing guidelines - Added performance and testing guidelines
- Added examples for all major features - Added examples for all major features
- Added quick start guide
- Added contribution guidelines with code quality standards
### Quality
- All code passes go vet, staticcheck, and fieldalignment
- All tests pass with race detector
- Memory optimized struct layouts
## [0.2.0] - Previous Release ## [0.2.0] - Previous Release
@@ -60,6 +77,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Route conflict detection and panics - Route conflict detection and panics
- Context-aware shutdown signaling - Context-aware shutdown signaling
[Unreleased]: https://github.com/yourusername/mux/compare/v0.2.0...HEAD [Unreleased]: https://code.patial.tech/go/mux/compare/v1.0.0...HEAD
[0.2.0]: https://github.com/yourusername/mux/compare/v0.1.0...v0.2.0 [1.0.0]: https://code.patial.tech/go/mux/compare/v0.7.1...v1.0.0
[0.1.0]: https://github.com/yourusername/mux/releases/tag/v0.1.0 [0.2.0]: https://code.patial.tech/go/mux/compare/v0.1.0...v0.2.0
[0.1.0]: https://code.patial.tech/go/mux/releases/tag/v0.1.0

2
go.mod
View File

@@ -1,3 +1,3 @@
module code.patial.tech/go/mux module code.patial.tech/go/mux
go 1.25 go 1.26

View File

@@ -1,4 +1,4 @@
go 1.25.1 go 1.26
use . use .

407
middleware/compress.go Normal file
View File

@@ -0,0 +1,407 @@
// Originally from: https://github.com/go-chi/chi/blob/master/middleware/compress.go
// Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc.
// MIT License
package middleware
import (
	"bufio"
	"compress/flate"
	"compress/gzip"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
)
var defaultCompressibleContentTypes = []string{
"text/html",
"text/css",
"text/plain",
"text/javascript",
"application/javascript",
"application/x-javascript",
"application/json",
"application/atom+xml",
"application/rss+xml",
"image/svg+xml",
}
// Compress is a middleware that compresses response
// body of a given content types to a data format based
// on Accept-Encoding request header. It uses a given
// compression level.
//
// NOTE: make sure to set the Content-Type header on your response
// otherwise this middleware will not compress the response body. For ex, in
// your handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody))
// or set it manually.
//
// Passing a compression level of 5 is sensible value
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
compressor := NewCompressor(level, types...)
return compressor.Handler
}
// Compressor represents a set of encoding configurations.
type Compressor struct {
// The mapping of encoder names to encoder functions.
encoders map[string]EncoderFunc
// The mapping of pooled encoders to pools.
pooledEncoders map[string]*sync.Pool
// The set of content types allowed to be compressed.
allowedTypes map[string]struct{}
allowedWildcards map[string]struct{}
// The list of encoders in order of decreasing precedence.
encodingPrecedence []string
level int // The compression level.
}
// NewCompressor creates a new Compressor that will handle encoding responses.
//
// The level should be one of the ones defined in the flate package.
// The types are the content types that are allowed to be compressed.
func NewCompressor(level int, types ...string) *Compressor {
// If types are provided, set those as the allowed types. If none are
// provided, use the default list.
allowedTypes := make(map[string]struct{})
allowedWildcards := make(map[string]struct{})
if len(types) > 0 {
for _, t := range types {
if strings.Contains(strings.TrimSuffix(t, "/*"), "*") {
panic(fmt.Sprintf("mw/compress: Unsupported content-type wildcard pattern '%s'. Only '/*' supported", t))
}
if strings.HasSuffix(t, "/*") {
allowedWildcards[strings.TrimSuffix(t, "/*")] = struct{}{}
} else {
allowedTypes[t] = struct{}{}
}
}
} else {
for _, t := range defaultCompressibleContentTypes {
allowedTypes[t] = struct{}{}
}
}
c := &Compressor{
level: level,
encoders: make(map[string]EncoderFunc),
pooledEncoders: make(map[string]*sync.Pool),
allowedTypes: allowedTypes,
allowedWildcards: allowedWildcards,
}
// Set the default encoders. The precedence order uses the reverse
// ordering that the encoders were added. This means adding new encoders
// will move them to the front of the order.
//
// TODO:
// lzma: Opera.
// sdch: Chrome, Android. Gzip output + dictionary header.
// br: Brotli, see https://github.com/go-chi/chi/pull/326
// HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951)
// wrapped with zlib (RFC 1950). The zlib wrapper uses Adler-32
// checksum compared to CRC-32 used in "gzip" and thus is faster.
//
// But.. some old browsers (MSIE, Safari 5.1) incorrectly expect
// raw DEFLATE data only, without the mentioned zlib wrapper.
// Because of this major confusion, most modern browsers try it
// both ways, first looking for zlib headers.
// Quote by Mark Adler: http://stackoverflow.com/a/9186091/385548
//
// The list of browsers having problems is quite big, see:
// http://zoompf.com/blog/2012/02/lose-the-wait-http-compression
// https://web.archive.org/web/20120321182910/http://www.vervestudios.co/projects/compression-tests/results
//
// That's why we prefer gzip over deflate. It's just more reliable
// and not significantly slower than deflate.
c.SetEncoder("deflate", encoderDeflate)
// TODO: Exception for old MSIE browsers that can't handle non-HTML?
// https://zoompf.com/blog/2012/02/lose-the-wait-http-compression
c.SetEncoder("gzip", encoderGzip)
// NOTE: Not implemented, intentionally:
// case "compress": // LZW. Deprecated.
// case "bzip2": // Too slow on-the-fly.
// case "zopfli": // Too slow on-the-fly.
// case "xz": // Too slow on-the-fly.
return c
}
// SetEncoder can be used to set the implementation of a compression algorithm.
//
// The encoding should be a standardised identifier. See:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
//
// For example, add the Brotli algorithm:
//
// import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc"
//
// compressor := middleware.NewCompressor(5, "text/html")
// compressor.SetEncoder("br", func(w io.Writer, level int) io.Writer {
// params := brotli_enc.NewBrotliParams()
// params.SetQuality(level)
// return brotli_enc.NewBrotliWriter(params, w)
// })
func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
encoding = strings.ToLower(encoding)
if encoding == "" {
panic("the encoding can not be empty")
}
if fn == nil {
panic("attempted to set a nil encoder function")
}
// If we are adding a new encoder that is already registered, we have to
// clear that one out first.
delete(c.pooledEncoders, encoding)
delete(c.encoders, encoding)
// If the encoder supports Resetting (IoReseterWriter), then it can be pooled.
encoder := fn(io.Discard, c.level)
if _, ok := encoder.(ioResetterWriter); ok {
pool := &sync.Pool{
New: func() any {
return fn(io.Discard, c.level)
},
}
c.pooledEncoders[encoding] = pool
}
// If the encoder is not in the pooledEncoders, add it to the normal encoders.
if _, ok := c.pooledEncoders[encoding]; !ok {
c.encoders[encoding] = fn
}
for i, v := range c.encodingPrecedence {
if v == encoding {
c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...)
}
}
c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
}
// Handler returns a new middleware that will compress the response based on the
// current Compressor.
func (c *Compressor) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip compression for WebSocket upgrades
if r.Header.Get("Upgrade") == "websocket" {
next.ServeHTTP(w, r)
return
}
encoder, encoding, cleanup := c.selectEncoder(r.Header, w)
cw := &compressResponseWriter{
ResponseWriter: w,
w: w,
contentTypes: c.allowedTypes,
contentWildcards: c.allowedWildcards,
encoding: encoding,
compressible: false, // determined in post-handler
}
if encoder != nil {
cw.w = encoder
}
// Re-add the encoder to the pool if applicable.
defer cleanup()
defer cw.Close()
next.ServeHTTP(cw, r)
})
}
// selectEncoder returns the encoder, the name of the encoder, and a closer function.
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
header := h.Get("Accept-Encoding")
// Parse the names of all accepted algorithms from the header.
accepted := strings.Split(strings.ToLower(header), ",")
// Find supported encoder by accepted list by precedence
for _, name := range c.encodingPrecedence {
if matchAcceptEncoding(accepted, name) {
if pool, ok := c.pooledEncoders[name]; ok {
encoder := pool.Get().(ioResetterWriter)
cleanup := func() {
pool.Put(encoder)
}
encoder.Reset(w)
return encoder, name, cleanup
}
if fn, ok := c.encoders[name]; ok {
return fn(w, c.level), name, func() {}
}
}
}
// No encoder found to match the accepted encoding
return nil, "", func() {}
}
func matchAcceptEncoding(accepted []string, encoding string) bool {
for _, v := range accepted {
v = strings.TrimSpace(v)
// Handle quality values like "gzip;q=0.8"
if idx := strings.Index(v, ";"); idx != -1 {
v = strings.TrimSpace(v[:idx])
}
if v == encoding {
return true
}
}
return false
}
// An EncoderFunc is a function that wraps the provided io.Writer with a
// streaming compression algorithm and returns it.
//
// In case of failure, the function should return nil.
type EncoderFunc func(w io.Writer, level int) io.Writer
// Interface for types that allow resetting io.Writers.
type ioResetterWriter interface {
io.Writer
Reset(w io.Writer)
}
type compressResponseWriter struct {
http.ResponseWriter
// The streaming encoder writer to be used if there is one. Otherwise,
// this is just the normal writer.
w io.Writer
contentTypes map[string]struct{}
contentWildcards map[string]struct{}
encoding string
wroteHeader bool
compressible bool
}
func (cw *compressResponseWriter) isCompressible() bool {
// Parse the first part of the Content-Type response header.
contentType := cw.Header().Get("Content-Type")
contentType, _, _ = strings.Cut(contentType, ";")
// Is the content type compressible?
if _, ok := cw.contentTypes[contentType]; ok {
return true
}
if contentType, _, hadSlash := strings.Cut(contentType, "/"); hadSlash {
_, ok := cw.contentWildcards[contentType]
return ok
}
return false
}
func (cw *compressResponseWriter) WriteHeader(code int) {
if cw.wroteHeader {
cw.ResponseWriter.WriteHeader(code) // Allow multiple calls to propagate.
return
}
cw.wroteHeader = true
defer cw.ResponseWriter.WriteHeader(code)
// Already compressed data?
if cw.Header().Get("Content-Encoding") != "" {
return
}
if !cw.isCompressible() {
cw.compressible = false
return
}
if cw.encoding != "" {
cw.compressible = true
cw.Header().Set("Content-Encoding", cw.encoding)
cw.Header().Add("Vary", "Accept-Encoding")
// The content-length after compression is unknown
cw.Header().Del("Content-Length")
}
}
func (cw *compressResponseWriter) Write(p []byte) (int, error) {
if !cw.wroteHeader {
cw.WriteHeader(http.StatusOK)
}
return cw.writer().Write(p)
}
func (cw *compressResponseWriter) writer() io.Writer {
if cw.compressible {
return cw.w
}
return cw.ResponseWriter
}
type compressFlusher interface {
Flush() error
}
func (cw *compressResponseWriter) Flush() {
if f, ok := cw.writer().(http.Flusher); ok {
f.Flush()
}
// If the underlying writer has a compression flush signature,
// call this Flush() method instead
if f, ok := cw.writer().(compressFlusher); ok {
f.Flush()
// Also flush the underlying response writer
if f, ok := cw.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
}
func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := cw.writer().(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, errors.New("mw/compress: http.Hijacker is unavailable on the writer")
}
func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
if ps, ok := cw.writer().(http.Pusher); ok {
return ps.Push(target, opts)
}
return errors.New("mw/compress: http.Pusher is unavailable on the writer")
}
func (cw *compressResponseWriter) Close() error {
if c, ok := cw.writer().(io.WriteCloser); ok {
return c.Close()
}
return errors.New("mw/compress: io.WriteCloser is unavailable on the writer")
}
func (cw *compressResponseWriter) Unwrap() http.ResponseWriter {
return cw.ResponseWriter
}
func encoderGzip(w io.Writer, level int) io.Writer {
gw, err := gzip.NewWriterLevel(w, level)
if err != nil {
return nil
}
return gw
}
func encoderDeflate(w io.Writer, level int) io.Writer {
dw, err := flate.NewWriter(w, level)
if err != nil {
return nil
}
return dw
}

View File

@@ -109,7 +109,6 @@ func CORS(opts CORSOption) func(http.Handler) http.Handler {
ch.setAllowedMethods(opts.AllowedMethods) ch.setAllowedMethods(opts.AllowedMethods)
ch.setExposedHeaders(opts.ExposedHeaders) ch.setExposedHeaders(opts.ExposedHeaders)
ch.setMaxAge(opts.MaxAge) ch.setMaxAge(opts.MaxAge)
ch.maxAge = opts.MaxAge
ch.allowCredentials = opts.AllowCredentials ch.allowCredentials = opts.AllowCredentials
return ch return ch

View File

@@ -56,7 +56,7 @@ type (
} }
TransportSecurity struct { TransportSecurity struct {
// Age in seconts // Age in seconds
MaxAge uint MaxAge uint
IncludeSubDomains bool IncludeSubDomains bool
Preload bool Preload bool
@@ -111,108 +111,124 @@ const (
// Helmet headers to secure server response // Helmet headers to secure server response
func Helmet(opt HelmetOption) func(http.Handler) http.Handler { func Helmet(opt HelmetOption) func(http.Handler) http.Handler {
// Precompute all static header values once at middleware creation time.
headers := buildHelmetHeaders(opt)
return func(h http.Handler) http.Handler { return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Security-Policy", opt.ContentSecurityPolicy.value()) for _, kv := range headers {
w.Header().Add(kv.key, kv.value)
}
w.Header().Del("X-Powered-By")
h.ServeHTTP(w, r)
})
}
}
// Opener-Policy type headerKV struct {
if opt.CrossOriginOpenerPolicy == "" { key string
w.Header().Add("Cross-Origin-Opener-Policy", string(OpenerSameOrigin)) value string
} else { }
w.Header().Add("Cross-Origin-Opener-Policy", string(opt.CrossOriginOpenerPolicy))
func buildHelmetHeaders(opt HelmetOption) []headerKV {
var headers []headerKV
add := func(key, value string) {
headers = append(headers, headerKV{key: key, value: value})
} }
// Resource-Policy // Content-Security-Policy
if opt.CrossOriginResourcePolicy == "" { add("Content-Security-Policy", opt.ContentSecurityPolicy.value())
w.Header().Add("Cross-Origin-Resource-Policy", string(ResourceSameOrigin))
// Cross-Origin-Opener-Policy
if opt.CrossOriginOpenerPolicy == "" {
add("Cross-Origin-Opener-Policy", string(OpenerSameOrigin))
} else { } else {
w.Header().Add("Cross-Origin-Resource-Policy", string(opt.CrossOriginResourcePolicy)) add("Cross-Origin-Opener-Policy", string(opt.CrossOriginOpenerPolicy))
}
// Cross-Origin-Resource-Policy
if opt.CrossOriginResourcePolicy == "" {
add("Cross-Origin-Resource-Policy", string(ResourceSameOrigin))
} else {
add("Cross-Origin-Resource-Policy", string(opt.CrossOriginResourcePolicy))
} }
// Referrer-Policy // Referrer-Policy
rpCount := len(opt.ReferrerPolicy) if len(opt.ReferrerPolicy) > 0 {
if rpCount > 0 { refP := make([]string, len(opt.ReferrerPolicy))
refP := make([]string, rpCount)
for i, r := range opt.ReferrerPolicy { for i, r := range opt.ReferrerPolicy {
refP[i] = string(r) refP[i] = string(r)
} }
w.Header().Add("Referrer-Policy", string(NoReferrer)) add("Referrer-Policy", strings.Join(refP, ","))
} else { } else {
// default no referer add("Referrer-Policy", string(NoReferrer))
w.Header().Add("Referrer-Policy", string(NoReferrer))
} }
// Origin-Agent-Cluster // Origin-Agent-Cluster
if opt.OriginAgentCluster { if opt.OriginAgentCluster {
w.Header().Add("Origin-Agent-Cluster", "?1") add("Origin-Agent-Cluster", "?1")
} }
// Strict-Transport-Security // Strict-Transport-Security
if opt.StrictTransportSecurity != nil { if opt.StrictTransportSecurity != nil {
var sb strings.Builder var sb strings.Builder
if opt.StrictTransportSecurity.MaxAge == 0 { maxAge := opt.StrictTransportSecurity.MaxAge
opt.StrictTransportSecurity.MaxAge = YearDuration if maxAge == 0 {
maxAge = YearDuration
} }
sb.WriteString(fmt.Sprintf("max-age=%d", opt.StrictTransportSecurity.MaxAge)) sb.WriteString(fmt.Sprintf("max-age=%d", maxAge))
if opt.StrictTransportSecurity.IncludeSubDomains { if opt.StrictTransportSecurity.IncludeSubDomains {
sb.WriteString("; includeSubDomains") sb.WriteString("; includeSubDomains")
} }
if opt.StrictTransportSecurity.Preload { if opt.StrictTransportSecurity.Preload {
sb.WriteString("; preload") sb.WriteString("; preload")
} }
w.Header().Add("Strict-Transport-Security", sb.String()) add("Strict-Transport-Security", sb.String())
} }
// X-Content-Type-Options
if !opt.DisableSniffMimeType { if !opt.DisableSniffMimeType {
// MIME types advertised in the Content-Current headers should be followed and not be changed add("X-Content-Type-Options", "nosniff")
w.Header().Add("X-Content-Type-Options", "nosniff")
} }
// X-DNS-Prefetch-Control
if opt.DisableDNSPrefetch { if opt.DisableDNSPrefetch {
w.Header().Add("X-DNS-Prefetch-Control", "off") add("X-DNS-Prefetch-Control", "off")
} else { } else {
w.Header().Add("X-DNS-Prefetch-Control", "on") add("X-DNS-Prefetch-Control", "on")
} }
// X-Download-Options
if !opt.DisableXDownload { if !opt.DisableXDownload {
// Instructs Internet Explorer not to open the file directly but to offer it for download first. add("X-Download-Options", "noopen")
w.Header().Add("X-Download-Options", "noopen")
} }
// indicate whether a browser should be allowed to render a page in iframe | frame | embed | object // X-Frame-Options
if opt.XFrameOption == "" { if opt.XFrameOption == "" {
w.Header().Add("X-Frame-Options", string(XFrameSameOrigin)) add("X-Frame-Options", string(XFrameSameOrigin))
} else { } else {
w.Header().Add("X-Frame-Options", string(opt.XFrameOption)) add("X-Frame-Options", string(opt.XFrameOption))
} }
// X-Permitted-Cross-Domain-Policies
if opt.CrossDomainPolicies == "" { if opt.CrossDomainPolicies == "" {
w.Header().Add("X-Permitted-Cross-Domain-Policies", string(CDPNone)) add("X-Permitted-Cross-Domain-Policies", string(CDPNone))
} else { } else {
w.Header().Add("X-Permitted-Cross-Domain-Policies", string(opt.CrossDomainPolicies)) add("X-Permitted-Cross-Domain-Policies", string(opt.CrossDomainPolicies))
} }
w.Header().Del("X-Powered-By") // X-Xss-Protection
if opt.XssProtection { if opt.XssProtection {
// feature of IE, Chrome and Safari that stops pages from loading when they detect reflected add("X-Xss-Protection", "1; mode=block")
// cross-site scripting (XSS) attacks.
w.Header().Add("X-Xss-Protection", "1; mode=block")
} else { } else {
// Following a decision by Google Chrome developers to disable Auditor, add("X-Xss-Protection", "0")
// developers should be able to disable the auditor for older browsers and set it to 0.
// The X-XSS-PROTECTION header was found to have a multitude of issues, instead of helping the
// developers protect their application.
w.Header().Add("X-Xss-Protection", "0")
} }
h.ServeHTTP(w, r) return headers
})
}
} }
func (csp *CSP) value() string { func (csp *CSP) value() string {

56
middleware/real_ip.go Normal file
View File

@@ -0,0 +1,56 @@
package middleware
// Ported from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"net"
"net/http"
"strings"
)
var trueClientIP = http.CanonicalHeaderKey("True-Client-IP")
var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
var xRealIP = http.CanonicalHeaderKey("X-Real-IP")
// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers
// (in that order).
//
// This middleware should be inserted fairly early in the middleware stack to
// ensure that subsequent layers (e.g., request loggers) which examine the
// RemoteAddr will see the intended value.
//
// You should only use this middleware if you can trust the headers passed to
// you (in particular, the three headers this middleware uses), for example
// because you have placed a reverse proxy like HAProxy or nginx in front of
// chi. If your reverse proxies are configured to pass along arbitrary header
// values from the client, or if you use this middleware without a reverse
// proxy, malicious clients will be able to make you very sad (or, depending on
// how you're using RemoteAddr, vulnerable to an attack of some sort).
func RealIP(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip := realIP(r); rip != "" {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
func realIP(r *http.Request) string {
var ip string
if tcip := r.Header.Get(trueClientIP); tcip != "" {
ip = tcip
} else if xrip := r.Header.Get(xRealIP); xrip != "" {
ip = xrip
} else if xff := r.Header.Get(xForwardedFor); xff != "" {
ip, _, _ = strings.Cut(xff, ",")
}
if ip == "" || net.ParseIP(ip) == nil {
return ""
}
return ip
}

102
middleware/request_id.go Normal file
View File

@@ -0,0 +1,102 @@
package middleware
// Ported from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"os"
"strings"
"sync/atomic"
)
// Key to use when setting the request ID.
type ctxKeyRequestID int
// RequestIDKey is the key that holds the unique request ID in a request context.
const RequestIDKey ctxKeyRequestID = 0
// RequestIDHeader is the name of the HTTP Header which contains the request id.
// Exported so that it can be changed by developers
var RequestIDHeader = "X-Request-Id"
var prefix string
var reqid atomic.Uint64
// A quick note on the statistics here: we're trying to calculate the chance that
// two randomly generated base62 prefixes will collide. We use the formula from
// http://en.wikipedia.org/wiki/Birthday_problem
//
// P[m, n] \approx 1 - e^{-m^2/2n}
//
// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server
// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$
//
// For a $k$ character base-62 identifier, we have $n(k) = 62^k$
//
// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for
// our purposes, and is surely more than anyone would ever need in practice -- a
// process that is rebooted a handful of times a day for a hundred years has less
// than a millionth of a percent chance of generating two colliding IDs.
func init() {
hostname, err := os.Hostname()
if hostname == "" || err != nil {
hostname = "localhost"
}
var buf [12]byte
var b64 string
for len(b64) < 10 {
rand.Read(buf[:])
b64 = base64.StdEncoding.EncodeToString(buf[:])
b64 = strings.NewReplacer("+", "", "/", "").Replace(b64)
}
prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10])
}
// RequestID is a middleware that injects a request ID into the context of each
// request. A request ID is a string of the form "host.example.com/random-0001",
// where "random" is a base62 random string that uniquely identifies this go
// process, and where the last number is an atomically incremented request
// counter.
// maxRequestIDLen is the maximum length of an incoming request ID header
// to prevent log injection or memory abuse from malicious clients.
const maxRequestIDLen = 200
func RequestID(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := r.Header.Get(RequestIDHeader)
if requestID == "" {
myid := reqid.Add(1)
requestID = fmt.Sprintf("%s-%06d", prefix, myid)
} else if len(requestID) > maxRequestIDLen {
requestID = requestID[:maxRequestIDLen]
}
ctx = context.WithValue(ctx, RequestIDKey, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
// GetReqID returns a request ID from the given context if one is present.
// Returns the empty string if a request ID cannot be found.
func GetReqID(ctx context.Context) string {
if ctx == nil {
return ""
}
if reqID, ok := ctx.Value(RequestIDKey).(string); ok {
return reqID
}
return ""
}
// NextRequestID generates the next request ID in the sequence.
func NextRequestID() uint64 {
return reqid.Add(1)
}

View File

@@ -0,0 +1,18 @@
package middleware
import (
"net/http"
)
// RequestSize is a middleware that will limit request sizes to a specified
// number of bytes. It uses MaxBytesReader to do so.
func RequestSize(bytes int64) func(http.Handler) http.Handler {
f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, bytes)
h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
return f
}

20
mux.go
View File

@@ -23,6 +23,22 @@ func New() *Mux {
mux: http.NewServeMux(), mux: http.NewServeMux(),
routes: new(RouteList), routes: new(RouteList),
} }
// Catch-all OPTIONS handler.
// Pass it through all middlewares and drain oversized bodies.
m.OPTIONS("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "0")
if r.ContentLength != 0 {
// Read up to 4KB of OPTIONS body (as mentioned in the
// spec as being reserved for future use), but anything
// over that is considered a waste of server resources
// (or an attack) and we abort and close the connection,
// courtesy of MaxBytesReader's EOF behavior.
mb := http.MaxBytesReader(w, r.Body, 4<<10)
_, _ = io.Copy(io.Discard, mb)
}
})
return m return m
} }
@@ -147,8 +163,8 @@ func (m *Mux) ServeHTTP(w http.ResponseWriter, req *http.Request) {
func (m *Mux) PrintRoutes(w io.Writer) { func (m *Mux) PrintRoutes(w io.Writer) {
for _, route := range m.routes.All() { for _, route := range m.routes.All() {
w.Write([]byte(route)) _, _ = w.Write([]byte(route))
w.Write([]byte("\n")) _, _ = w.Write([]byte("\n"))
} }
} }

View File

@@ -306,10 +306,8 @@ func BenchmarkRouterSimple(b *testing.B) {
source := rand.NewSource(time.Now().UnixNano()) source := rand.NewSource(time.Now().UnixNano())
r := rand.New(source) r := rand.New(source)
// Generate a random integer between 0 and 99 (inclusive)
rn := r.Intn(10000)
for b.Loop() { for b.Loop() {
rn := r.Intn(10000)
req, _ := http.NewRequest(http.MethodGet, "/"+strconv.Itoa(rn), nil) req, _ := http.NewRequest(http.MethodGet, "/"+strconv.Itoa(rn), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
m.ServeHTTP(w, req) m.ServeHTTP(w, req)

View File

@@ -1,6 +1,6 @@
module code.patial.tech/go/mux/playground module code.patial.tech/go/mux/playground
go 1.24 go 1.26
require ( require (
code.patial.tech/go/mux v0.7.1 code.patial.tech/go/mux v0.7.1

View File

@@ -27,7 +27,7 @@ func (m *Mux) Resource(pattern string, fn func(res *Resource), mw ...func(http.H
} }
if strings.TrimSpace(pattern) == "" { if strings.TrimSpace(pattern) == "" {
panic("mux: Resource() requires a patter to work") panic("mux: Resource() requires a pattern to work")
} }
if fn == nil { if fn == nil {
@@ -215,11 +215,9 @@ func (res *Resource) Use(middlewares ...func(http.Handler) http.Handler) {
} }
func suffixIt(str, suffix string) string { func suffixIt(str, suffix string) string {
var p strings.Builder if strings.HasSuffix(str, "/") {
p.WriteString(str) return str + suffix
if !strings.HasSuffix(str, "/") {
p.WriteString("/")
} }
p.WriteString(suffix)
return p.String() return str + "/" + suffix
} }

View File

@@ -33,7 +33,7 @@ func (s *RouteList) Get(index int) (string, error) {
defer s.mu.RUnlock() defer s.mu.RUnlock()
if index < 0 || index >= len(s.routes) { if index < 0 || index >= len(s.routes) {
return "0", fmt.Errorf("index out of bounds") return "", fmt.Errorf("index out of bounds")
} }
return s.routes[index], nil return s.routes[index], nil
} }
@@ -47,5 +47,7 @@ func (s *RouteList) All() []string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.routes out := make([]string, len(s.routes))
copy(out, s.routes)
return out
} }

View File

@@ -3,7 +3,6 @@ package mux
import ( import (
"context" "context"
"errors" "errors"
"io"
"log" "log"
"log/slog" "log/slog"
"net" "net"
@@ -26,21 +25,6 @@ func (m *Mux) Serve(cb ServeCB) {
rootCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) rootCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop() defer stop()
// catch all options
// lets get it thorugh all middlewares
m.OPTIONS("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "0")
if r.ContentLength != 0 {
// Read up to 4KB of OPTIONS body (as mentioned in the
// spec as being reserved for future use), but anything
// over that is considered a waste of server resources
// (or an attack) and we abort and close the connection,
// courtesy of MaxBytesReader's EOF behavior.
mb := http.MaxBytesReader(w, r.Body, 4<<10)
io.Copy(io.Discard, mb)
}
})
srvCtx, cancelSrvCtx := context.WithCancel(context.Background()) srvCtx, cancelSrvCtx := context.WithCancel(context.Background())
srv := &http.Server{ srv := &http.Server{
Handler: m, Handler: m,
@@ -51,7 +35,7 @@ func (m *Mux) Serve(cb ServeCB) {
go func() { go func() {
if err := cb(srv); !errors.Is(err, http.ErrServerClosed) { if err := cb(srv); !errors.Is(err, http.ErrServerClosed) {
panic(err) log.Fatalf("server error: %v", err)
} }
}() }()
@@ -60,7 +44,7 @@ func (m *Mux) Serve(cb ServeCB) {
stop() stop()
m.IsShuttingDown.Store(true) m.IsShuttingDown.Store(true)
slog.Info("received interrupt singal, shutting down") slog.Info("received interrupt signal, shutting down")
time.Sleep(drainDelay) time.Sleep(drainDelay)
slog.Info("readiness check propagated, now waiting for ongoing requests to finish.") slog.Info("readiness check propagated, now waiting for ongoing requests to finish.")
@@ -74,5 +58,5 @@ func (m *Mux) Serve(cb ServeCB) {
time.Sleep(shutdownHardDelay) time.Sleep(shutdownHardDelay)
} }
slog.Info("seerver shut down gracefully") slog.Info("server shut down gracefully")
} }