Files
echo-ratelimit/ratelimit.go
Carl Pearson cf6d36d7cb Initial commit
Signed-off-by: Carl Pearson <me@carlpearson.net>
2026-04-13 23:06:50 +00:00

187 lines
4.0 KiB
Go

package ratelimit
import (
"errors"
"math"
"net/http"
"strconv"
"sync"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
// ErrInvalidConfig is the value WithConfig (and therefore Middleware) panics
// with when the supplied configuration is unusable, i.e. when Limit or Window
// is not positive.
var (
	ErrInvalidConfig = errors.New("ratelimit: invalid config")
)
// Decision is the outcome of checking a single request against the limiter,
// together with the state of the client's current window.
type Decision struct {
	// Allowed reports whether the request fit within the limit.
	Allowed bool
	// Limit is the configured number of requests per window; Remaining is
	// how many more requests the client may make before the window resets
	// (0 once the limit has been reached).
	Limit, Remaining int
	// ResetAt is when the client's current window ends.
	ResetAt time.Time
	// RetryAfter is how long a denied client has to wait for the window to
	// reset. It is left at zero for allowed requests.
	RetryAfter time.Duration
}
// Config controls middleware behavior.
//
// Pass it to WithConfig. Limit and Window are required. Skipper,
// ErrorHandler and DenyHandler are optional; WithConfig substitutes
// defaults for them when they are nil.
type Config struct {
// Skipper, when it returns true for a request, bypasses rate limiting.
// Nil means middleware.DefaultSkipper.
Skipper middleware.Skipper
// Limit is the number of requests a client may make per Window.
// It must be greater than zero; WithConfig panics with ErrInvalidConfig
// otherwise.
Limit int
// Window is the length of each fixed rate-limit window. It must be greater
// than zero; WithConfig panics with ErrInvalidConfig otherwise.
Window time.Duration
// SetHeaders, when true, makes the middleware write X-RateLimit-Limit,
// X-RateLimit-Remaining and X-RateLimit-Reset response headers, plus
// Retry-After on denied requests.
SetHeaders bool
// ErrorHandler is called when the client IP cannot be determined.
// Nil means the error is returned unchanged.
ErrorHandler func(echo.Context, error) error
// DenyHandler is called when a request exceeds the limit, after any
// rate-limit headers have been written. Nil means a 429 Too Many Requests
// *echo.HTTPError is returned.
DenyHandler func(echo.Context, Decision) error
// now is the clock used to timestamp requests; nil means time.Now.
// NOTE(review): unexported, presumably so in-package tests can inject a
// fake clock — confirm.
now func() time.Time
}
// clientState is the per-client bookkeeping for one fixed window.
type clientState struct {
	// count is the number of requests admitted in the current window.
	count int
	// resetAt is when the current window ends; lastSeen is the time of the
	// client's most recent request, whether it was allowed or denied.
	resetAt, lastSeen time.Time
}
// memoryLimiter keeps fixed-window request counts per client key in a map.
type memoryLimiter struct {
	// limit and window describe the policy; they are set when the limiter is
	// built and not modified afterwards.
	limit  int
	window time.Duration
	// clients maps a client key to the state of its current window.
	clients map[string]*clientState
	// cleanupEvery is the number of allow calls between sweeps of clients and
	// must be non-zero (allow computes requests % cleanupEvery); requests
	// counts allow calls across all keys.
	cleanupEvery, requests uint64
	// mu is held by allow for its whole duration and serializes access to
	// the fields above.
	mu sync.Mutex
}
// Middleware returns a rate-limiting middleware with sensible defaults: at most
// limit requests per client IP in each fixed window of the given duration,
// with rate-limit response headers enabled. It panics with ErrInvalidConfig if
// limit or window is not positive.
func Middleware(limit int, window time.Duration) echo.MiddlewareFunc {
	return WithConfig(Config{
		Limit:      limit,
		Window:     window,
		SetHeaders: true,
	})
}

// WithConfig creates an in-memory, IP-based, fixed-window rate limiter
// middleware for Echo.
//
// Each client is keyed by c.RealIP() and may make config.Limit requests per
// config.Window. A window starts at the client's first request and is replaced
// by a fresh one once it elapses. Counters live in a limiter owned by the
// returned middleware, so they are per-process and not shared between separate
// WithConfig calls.
//
// Nil Skipper, ErrorHandler and DenyHandler are replaced with defaults: skip
// nothing, return the error unchanged, and respond 429 Too Many Requests.
// SetHeaders is used exactly as given; its zero value writes no rate-limit
// headers (Middleware turns them on).
//
// WithConfig panics with ErrInvalidConfig if Limit or Window is not positive.
//
// NOTE(review): Echo's c.RealIP() consults X-Forwarded-For / X-Real-IP unless
// the Echo instance has an IPExtractor configured. When clients are untrusted,
// configure one so the key cannot be spoofed to evade the limit.
func WithConfig(config Config) echo.MiddlewareFunc {
	// Validate before anything else so misconfiguration fails at setup time.
	if config.Limit <= 0 || config.Window <= 0 {
		panic(ErrInvalidConfig)
	}
	if config.Skipper == nil {
		config.Skipper = middleware.DefaultSkipper
	}
	if config.ErrorHandler == nil {
		config.ErrorHandler = func(_ echo.Context, err error) error {
			return err
		}
	}
	if config.DenyHandler == nil {
		config.DenyHandler = func(echo.Context, Decision) error {
			return echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
		}
	}
	if config.now == nil {
		config.now = time.Now
	}
	// config.SetHeaders is deliberately not defaulted. A bool cannot tell
	// "unset" from an explicit false, and forcing it to true would make the
	// option impossible to turn off.

	limiter := &memoryLimiter{
		limit:   config.Limit,
		window:  config.Window,
		clients: make(map[string]*clientState),
		// Sweep expired entries once every 256 requests (across all clients)
		// rather than scanning the map on every call.
		cleanupEvery: 256,
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			key := c.RealIP()
			if key == "" {
				return config.ErrorHandler(c, errors.New("ratelimit: unable to determine client IP"))
			}

			decision := limiter.allow(key, config.now())
			// Headers are written before the deny handler runs so that a
			// 429 response also carries them.
			if config.SetHeaders {
				writeHeaders(c.Response().Header(), decision)
			}
			if !decision.Allowed {
				return config.DenyHandler(c, decision)
			}
			return next(c)
		}
	}
}
// allow counts one request from key at time now against that key's fixed
// window and returns the resulting Decision.
//
// A key with no entry, or whose window has elapsed (now is not before
// resetAt), starts a new window ending at now+l.window. Requests past l.limit
// within a window are denied and are not counted. Every l.cleanupEvery-th
// call, counted across all keys, first sweeps the map via cleanup. The whole
// call runs with l.mu held.
func (l *memoryLimiter) allow(key string, now time.Time) Decision {
	l.mu.Lock()
	defer l.mu.Unlock()

	l.requests++
	if l.requests%l.cleanupEvery == 0 {
		l.cleanup(now)
	}

	st := l.clients[key]
	if st == nil || !now.Before(st.resetAt) {
		// No window yet, or the old one is over: start counting afresh.
		st = &clientState{resetAt: now.Add(l.window)}
		l.clients[key] = st
	}
	st.lastSeen = now

	d := Decision{Limit: l.limit, ResetAt: st.resetAt}

	if st.count >= l.limit {
		// Over quota: tell the client how long until the window resets.
		if wait := st.resetAt.Sub(now); wait > 0 {
			d.RetryAfter = wait
		}
		return d
	}

	st.count++
	d.Allowed = true
	if left := l.limit - st.count; left > 0 {
		d.Remaining = left
	}
	return d
}
// cleanup deletes every client entry whose window has already elapsed
// (now is not before resetAt). It does not lock; allow calls it with l.mu held.
//
// Dropping such entries is always safe: allow never reuses an elapsed window
// and replaces it with a fresh one on the client's next request. The previous
// "idle for a full window" condition was stricter than needed. It implied the
// window had elapsed (lastSeen is never earlier than the window start) but
// kept expired windows of recently active clients around for up to an extra
// window.
func (l *memoryLimiter) cleanup(now time.Time) {
	for key, state := range l.clients {
		// Deleting during range is permitted by the Go spec.
		if !now.Before(state.resetAt) {
			delete(l.clients, key)
		}
	}
}
// writeHeaders describes decision to the client through response headers:
// X-RateLimit-Limit, X-RateLimit-Remaining and X-RateLimit-Reset (Unix
// seconds) on every request, plus Retry-After (whole seconds, rounded up) when
// the request was denied.
func writeHeaders(header http.Header, decision Decision) {
	fields := [...]struct{ name, value string }{
		{"X-RateLimit-Limit", strconv.Itoa(decision.Limit)},
		{"X-RateLimit-Remaining", strconv.Itoa(decision.Remaining)},
		{"X-RateLimit-Reset", strconv.FormatInt(decision.ResetAt.Unix(), 10)},
	}
	for _, f := range fields {
		header.Set(f.name, f.value)
	}

	if decision.Allowed {
		return
	}
	wait := math.Ceil(decision.RetryAfter.Seconds())
	header.Set("Retry-After", strconv.Itoa(int(wait)))
}