// Author: Ankit Patial
// Inspired by Helmet.js
// https://github.com/helmetjs/helmet/tree/main

package middleware

import (
	"fmt"
	"net/http"
	"strings"
)

type (
	// HelmetOption configures the headers written by Helmet. Every field has
	// a usable zero value; the resulting behavior is documented per field.
	HelmetOption struct {
		// StrictTransportSecurity configures Strict-Transport-Security.
		// The header is omitted when nil.
		StrictTransportSecurity *TransportSecurity

		// XFrameOption sets X-Frame-Options; empty selects XFrameSameOrigin.
		XFrameOption XFrame

		// CrossOriginEmbedderPolicy sets Cross-Origin-Embedder-Policy.
		// The header is omitted when empty (there is no default).
		CrossOriginEmbedderPolicy Embedder

		// CrossOriginOpenerPolicy sets Cross-Origin-Opener-Policy; empty
		// selects OpenerSameOrigin.
		CrossOriginOpenerPolicy Opener

		// CrossOriginResourcePolicy sets Cross-Origin-Resource-Policy; empty
		// selects ResourceSameOrigin.
		CrossOriginResourcePolicy Resource

		// CrossDomainPolicies sets X-Permitted-Cross-Domain-Policies; empty
		// selects CDPNone.
		CrossDomainPolicies CDP

		// ReferrerPolicy sets Referrer-Policy as a comma-separated list;
		// empty selects NoReferrer.
		ReferrerPolicy []Referrer

		// ContentSecurityPolicy configures Content-Security-Policy, which is
		// always sent.
		ContentSecurityPolicy CSP

		// OriginAgentCluster, when true, sends "Origin-Agent-Cluster: ?1".
		OriginAgentCluster bool

		// DisableXDownload, when true, omits "X-Download-Options: noopen".
		DisableXDownload bool

		// DisableDNSPrefetch, when true, sends "X-DNS-Prefetch-Control: off";
		// otherwise "on" is sent.
		DisableDNSPrefetch bool

		// DisableSniffMimeType, when true, omits
		// "X-Content-Type-Options: nosniff".
		DisableSniffMimeType bool

		// XssProtection, when true, sends "X-Xss-Protection: 1; mode=block";
		// otherwise "0" is sent.
		XssProtection bool
	}

	// CSP holds Content-Security-Policy settings. For each directive an empty
	// slice selects the default noted on the field. Entries are trimmed,
	// blank entries are dropped, and the keyword sources recognised by
	// cspQuoted (none, self, unsafe-inline, ...) are single-quoted
	// automatically. Nonce and hash sources are NOT quoted automatically and
	// must be passed pre-quoted, e.g. "'nonce-abc123'".
	//
	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/Sources
	CSP struct {
		// DefaultSrc is default-src; defaults to 'self'.
		DefaultSrc []string

		// ScriptSrc is script-src; defaults to 'self'.
		ScriptSrc []string

		// ScriptSrcAttr is script-src-attr; defaults to 'none'.
		ScriptSrcAttr []string

		// StyleSrc is style-src; defaults to 'self' 'unsafe-inline'.
		StyleSrc []string

		// ImgSrc is img-src; defaults to 'self' data:.
		ImgSrc []string

		// ObjectSrc is object-src; defaults to 'none'.
		ObjectSrc []string

		// BaseUri is base-uri; defaults to 'self'.
		BaseUri []string

		// FontSrc is font-src; defaults to 'self' data:.
		FontSrc []string

		// FormAction is form-action; defaults to 'self'.
		FormAction []string

		// FrameAncestors is frame-ancestors; defaults to 'self'.
		FrameAncestors []string

		// UpgradeInsecureRequests, when true, emits the valueless
		// upgrade-insecure-requests directive first.
		UpgradeInsecureRequests bool
	}

	// TransportSecurity configures the Strict-Transport-Security header.
	TransportSecurity struct {
		// MaxAge is the max-age in seconds; 0 selects YearDuration.
		MaxAge uint

		// IncludeSubDomains appends "; includeSubDomains".
		IncludeSubDomains bool

		// Preload appends "; preload".
		Preload bool
	}

	// Embedder is a Cross-Origin-Embedder-Policy value.
	Embedder string

	// Opener is a Cross-Origin-Opener-Policy value.
	Opener string

	// Resource is a Cross-Origin-Resource-Policy value.
	Resource string

	// Referrer is a Referrer-Policy token.
	Referrer string

	// CDP is an X-Permitted-Cross-Domain-Policies value.
	CDP string

	// XFrame is an X-Frame-Options value.
	XFrame string
)

const (
	// YearDuration is 365 days in seconds; it is the Strict-Transport-Security
	// max-age used when TransportSecurity.MaxAge is 0.
	YearDuration = 365 * 24 * 60 * 60

	// Cross-Origin-Opener-Policy values. OpenerSameOrigin is the default when
	// no value is supplied.
	OpenerSameOrigin            Opener = "same-origin"
	OpenerSameOriginAllowPopups Opener = "same-origin-allow-popups"
	OpenerUnsafeNone            Opener = "unsafe-none"

	// Cross-Origin-Embedder-Policy values. There is no default: the header is
	// only sent when HelmetOption.CrossOriginEmbedderPolicy is set.
	EmbedderRequireCorp    Embedder = "require-corp"
	EmbedderCredentialLess Embedder = "credentialless"
	EmbedderUnsafeNone     Embedder = "unsafe-none"

	// Cross-Origin-Resource-Policy values. ResourceSameOrigin is the default
	// when no value is supplied.
	ResourceSameOrigin  Resource = "same-origin"
	ResourceSameSite    Resource = "same-site"
	ResourceCrossOrigin Resource = "cross-origin"

	// Referrer-Policy tokens. NoReferrer is the default when no value is
	// supplied.
	NoReferrer                  Referrer = "no-referrer"
	NoReferrerWhenDowngrade     Referrer = "no-referrer-when-downgrade"
	SameOrigin                  Referrer = "same-origin"
	Origin                      Referrer = "origin"
	StrictOrigin                Referrer = "strict-origin"
	OriginWhenCrossOrigin       Referrer = "origin-when-cross-origin"
	StrictOriginWhenCrossOrigin Referrer = "strict-origin-when-cross-origin"
	UnsafeUrl                   Referrer = "unsafe-url"

	// X-Permitted-Cross-Domain-Policies values. CDPNone is the default when
	// no value is supplied.
	CDPNone          CDP = "none"
	CDPMasterOnly    CDP = "master-only"
	CDPByContentType CDP = "by-content-type"
	CDPAll           CDP = "all"

	// X-Frame-Options values. XFrameSameOrigin is the default when no value
	// is supplied.
	XFrameSameOrigin XFrame = "sameorigin"
	XFrameDeny       XFrame = "deny"
)

// Helmet returns middleware that sets security-related response headers,
// modelled on Helmet.js. Header values are computed once from opt when Helmet
// is called. On each request they replace (rather than append to) any value
// already set for the same header earlier in the handler chain.
func Helmet(opt HelmetOption) func(http.Handler) http.Handler {
	// Precompute all static header values once at middleware creation time.
	headers := buildHelmetHeaders(opt)

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			hdr := w.Header()
			for _, kv := range headers {
				// Set, not Add: a value set upstream is replaced rather than
				// duplicated into a multi-valued header.
				hdr.Set(kv.key, kv.value)
			}
			// This only removes a value set before this middleware ran; a
			// value written later by next is not affected.
			hdr.Del("X-Powered-By")
			next.ServeHTTP(w, r)
		})
	}
}

// headerKV is one precomputed response header.
type headerKV struct {
	key   string
	value string
}

// buildHelmetHeaders resolves opt into the ordered list of headers Helmet
// writes, substituting the documented default for every empty option.
func buildHelmetHeaders(opt HelmetOption) []headerKV {
	headers := make([]headerKV, 0, 13) // at most 13 headers are produced
	add := func(key, value string) {
		headers = append(headers, headerKV{key: key, value: value})
	}
	// orDefault returns v, or def when v is empty.
	orDefault := func(v, def string) string {
		if v == "" {
			return def
		}
		return v
	}

	add("Content-Security-Policy", opt.ContentSecurityPolicy.value())

	// Cross-Origin-Embedder-Policy has no default; send it only on request.
	if opt.CrossOriginEmbedderPolicy != "" {
		add("Cross-Origin-Embedder-Policy", string(opt.CrossOriginEmbedderPolicy))
	}

	add("Cross-Origin-Opener-Policy",
		orDefault(string(opt.CrossOriginOpenerPolicy), string(OpenerSameOrigin)))
	add("Cross-Origin-Resource-Policy",
		orDefault(string(opt.CrossOriginResourcePolicy), string(ResourceSameOrigin)))

	referrer := string(NoReferrer)
	if len(opt.ReferrerPolicy) > 0 {
		tokens := make([]string, len(opt.ReferrerPolicy))
		for i, p := range opt.ReferrerPolicy {
			tokens[i] = string(p)
		}
		referrer = strings.Join(tokens, ",")
	}
	add("Referrer-Policy", referrer)

	if opt.OriginAgentCluster {
		add("Origin-Agent-Cluster", "?1")
	}

	if ts := opt.StrictTransportSecurity; ts != nil {
		maxAge := ts.MaxAge
		if maxAge == 0 {
			maxAge = YearDuration
		}
		hsts := fmt.Sprintf("max-age=%d", maxAge)
		if ts.IncludeSubDomains {
			hsts += "; includeSubDomains"
		}
		if ts.Preload {
			hsts += "; preload"
		}
		add("Strict-Transport-Security", hsts)
	}

	if !opt.DisableSniffMimeType {
		add("X-Content-Type-Options", "nosniff")
	}

	dnsPrefetch := "on"
	if opt.DisableDNSPrefetch {
		dnsPrefetch = "off"
	}
	add("X-DNS-Prefetch-Control", dnsPrefetch)

	if !opt.DisableXDownload {
		add("X-Download-Options", "noopen")
	}

	add("X-Frame-Options",
		orDefault(string(opt.XFrameOption), string(XFrameSameOrigin)))
	add("X-Permitted-Cross-Domain-Policies",
		orDefault(string(opt.CrossDomainPolicies), string(CDPNone)))

	xss := "0"
	if opt.XssProtection {
		xss = "1; mode=block"
	}
	add("X-Xss-Protection", xss)

	return headers
}

// value renders csp as a Content-Security-Policy header value. Directives are
// separated by "; " with no trailing separator. Each directive uses its
// configured sources, or its default when none are configured.
func (csp *CSP) value() string {
	directives := []struct {
		name     string
		sources  []string
		defaults []string
	}{
		{"default-src", csp.DefaultSrc, []string{"self"}},
		{"script-src", csp.ScriptSrc, []string{"self"}},
		{"script-src-attr", csp.ScriptSrcAttr, []string{"none"}},
		{"style-src", csp.StyleSrc, []string{"self", "unsafe-inline"}},
		{"img-src", csp.ImgSrc, []string{"self", "data:"}},
		{"object-src", csp.ObjectSrc, []string{"none"}},
		{"base-uri", csp.BaseUri, []string{"self"}},
		{"font-src", csp.FontSrc, []string{"self", "data:"}},
		{"form-action", csp.FormAction, []string{"self"}},
		{"frame-ancestors", csp.FrameAncestors, []string{"self"}},
	}

	parts := make([]string, 0, len(directives)+1)
	if csp.UpgradeInsecureRequests {
		// upgrade-insecure-requests takes no value; it is emitted first.
		parts = append(parts, "upgrade-insecure-requests")
	}
	for _, d := range directives {
		// TrimSpace drops the separator when every configured source was blank.
		parts = append(parts, strings.TrimSpace(d.name+" "+cspNormalised(d.sources, d.defaults)))
	}
	return strings.Join(parts, "; ")
}

// cspNormalised renders a CSP source list from v, or from defaultVal when v
// is empty. Each entry is trimmed, blank entries are dropped, keyword sources
// are quoted via cspQuoted, and the results are joined by single spaces.
func cspNormalised(v, defaultVal []string) string {
	if len(v) == 0 {
		v = defaultVal
	}
	out := make([]string, 0, len(v))
	for _, val := range v {
		if val = strings.TrimSpace(val); val != "" {
			out = append(out, cspQuoted(val))
		}
	}
	return strings.Join(out, " ")
}

// cspQuoted wraps the listed CSP keyword sources in single quotes. Any other
// value (hosts, schemes, or already-quoted nonces/hashes) is returned
// unchanged.
func cspQuoted(v string) string {
	switch v {
	case "none", "self", "strict-dynamic", "report-sample", "inline-speculation-rules",
		"unsafe-inline", "unsafe-eval", "unsafe-hashes", "wasm-unsafe-eval":
		return "'" + v + "'"
	default:
		return v
	}
}