// Package ratelimit provides a simple in-memory, per-IP, sliding-window rate limiter.
package ratelimit
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// cleanupInterval is how often the goroutine started by New sweeps expired
// timestamps out of a Limiter.
const cleanupInterval = 5 * time.Minute

// Limiter tracks request counts per IP within a sliding window.
//
// Create a Limiter with New: the zero value has a nil map and no stop channel,
// so Allow and Stop panic on it. All methods are safe for concurrent use.
// Call Stop once the Limiter is no longer needed so its background cleanup
// goroutine can exit.
type Limiter struct {
	mu      sync.Mutex
	entries map[string][]time.Time // accepted-request timestamps, keyed by IP
	window  time.Duration
	max     int

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // guards close(stop) so Stop is idempotent
}

// New creates a Limiter that allows at most max requests per window per IP.
// It starts a background goroutine that discards expired timestamps every
// cleanupInterval; call Stop to terminate it. A non-positive max rejects
// every request.
func New(window time.Duration, max int) *Limiter {
	l := &Limiter{
		entries: make(map[string][]time.Time),
		window:  window,
		max:     max,
		stop:    make(chan struct{}),
	}
	go l.cleanup()
	return l
}

// Stop terminates the background cleanup goroutine started by New. It may be
// called more than once. The Limiter still enforces limits after Stop, but
// timestamps of IPs that never call Allow again are no longer reclaimed.
func (l *Limiter) Stop() {
	l.stopOnce.Do(func() { close(l.stop) })
}

// Allow reports whether ip is within its rate limit. If it is, the current
// time is recorded for ip. Rejected calls are not recorded, so retrying while
// limited does not push back the moment the client is allowed again.
func (l *Limiter) Allow(ip string) bool {
	now := time.Now()
	cutoff := now.Add(-l.window)

	l.mu.Lock()
	defer l.mu.Unlock()

	recent := pruneBefore(l.entries[ip], cutoff)
	if len(recent) >= l.max {
		l.entries[ip] = recent
		return false
	}
	l.entries[ip] = append(recent, now)
	return true
}

// Middleware wraps next, rejecting over-limit requests with 429 Too Many
// Requests. The client is identified by the host part of r.RemoteAddr, or by
// the whole RemoteAddr if it cannot be split into host and port.
//
// NOTE: forwarding headers such as X-Forwarded-For are not consulted, so
// behind a reverse proxy every client shares the proxy's address and limit.
func (l *Limiter) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			// RemoteAddr has no port (or is malformed); use it verbatim.
			ip = r.RemoteAddr
		}
		if !l.Allow(ip) {
			http.Error(w, "Too many requests. Please wait and try again.", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// cleanup runs until Stop is called, sweeping expired timestamps every
// cleanupInterval so that IPs which stop sending requests do not keep their
// map entries forever.
func (l *Limiter) cleanup() {
	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-l.stop:
			return
		case <-ticker.C:
			l.sweep(time.Now())
		}
	}
}

// sweep drops every timestamp at or before now-window and deletes IPs that
// are left with none.
func (l *Limiter) sweep(now time.Time) {
	cutoff := now.Add(-l.window)

	l.mu.Lock()
	defer l.mu.Unlock()

	for ip, times := range l.entries {
		// Overwriting or deleting the current key during range is permitted.
		if recent := pruneBefore(times, cutoff); len(recent) > 0 {
			l.entries[ip] = recent
		} else {
			delete(l.entries, ip)
		}
	}
}

// pruneBefore filters times in place, keeping only those strictly after
// cutoff, and returns the shortened slice. It reuses the backing array of
// times, so callers must replace their stored slice with the result.
func pruneBefore(times []time.Time, cutoff time.Time) []time.Time {
	kept := times[:0]
	for _, t := range times {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	return kept
}
|