package web import ( "context" "log/slog" "net" "net/http" "sync" "time" "golang.org/x/time/rate" ) // securityHeaders are applied to every response. var securityHeaders = map[string]string{ "Referrer-Policy": "strict-origin-when-cross-origin", "X-Content-Type-Options": "nosniff", "X-Frame-Options": "SAMEORIGIN", "Strict-Transport-Security": "max-age=31536000; includeSubDomains", "Permissions-Policy": "camera=(), microphone=(), geolocation=(), interest-cohort=()", "Content-Security-Policy": "default-src 'self'; " + "style-src 'self' 'unsafe-inline'; " + "font-src 'self'; " + "script-src 'self'; " + "img-src 'self' data:; " + "connect-src 'self'; " + "frame-ancestors 'none'; " + "base-uri 'self'; " + "form-action 'self'", } func applySecurityHeaders(w http.ResponseWriter) { for k, v := range securityHeaders { if w.Header().Get(k) == "" { w.Header().Set(k, v) } } } // ---- per-IP rate limiter ---- type ipLimiter struct { limiter *rate.Limiter lastSeen time.Time } // RateLimiter tracks per-IP request rates. type RateLimiter struct { mu sync.Mutex limiters map[string]*ipLimiter r rate.Limit burst int } // NewRateLimiter creates a RateLimiter with the given sustained rate and burst. func NewRateLimiter(r rate.Limit, burst int) *RateLimiter { rl := &RateLimiter{ limiters: make(map[string]*ipLimiter), r: r, burst: burst, } go rl.cleanup() return rl } func (rl *RateLimiter) get(ip string) *rate.Limiter { rl.mu.Lock() defer rl.mu.Unlock() entry, ok := rl.limiters[ip] if !ok { entry = &ipLimiter{limiter: rate.NewLimiter(rl.r, rl.burst)} rl.limiters[ip] = entry } entry.lastSeen = time.Now() return entry.limiter } func (rl *RateLimiter) cleanup() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { rl.mu.Lock() for ip, e := range rl.limiters { if time.Since(e.lastSeen) > 10*time.Minute { delete(rl.limiters, ip) } } rl.mu.Unlock() } } // ---- status recorder for logging ---- type statusRecorder struct { http.ResponseWriter status int wroteHeader bool } func (sr *statusRecorder) WriteHeader(code int) { if !sr.wroteHeader { sr.status = code sr.wroteHeader = true } sr.ResponseWriter.WriteHeader(code) } func (sr *statusRecorder) Write(b []byte) (int, error) { if !sr.wroteHeader { sr.status = http.StatusOK sr.wroteHeader = true } return sr.ResponseWriter.Write(b) } // BuildMiddleware wraps mux with: logging → timeout → rate-limit → security headers. func BuildMiddleware(mux http.Handler, rl *RateLimiter, timeout time.Duration) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() r = r.WithContext(ctx) ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { ip = r.RemoteAddr } if !rl.get(ip).Allow() { http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests) slog.Info("rate limited", "ip", ip, "path", r.URL.Path) return } applySecurityHeaders(w) sr := &statusRecorder{ResponseWriter: w} mux.ServeHTTP(sr, r) slog.Info("request", "method", r.Method, "path", r.URL.Path, "status", sr.status, "ip", ip, "duration", time.Since(start).String(), ) }) }