187 lines
4.0 KiB
Go
187 lines
4.0 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"errors"
|
|
"math"
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/labstack/echo/v4/middleware"
|
|
)
|
|
|
|
var (
	// ErrInvalidConfig is the value WithConfig panics with when the supplied
	// Config has a non-positive Limit or Window.
	ErrInvalidConfig = errors.New("ratelimit: invalid config")
)
|
|
|
|
// Decision describes the outcome of one rate-limit check together with the
// state of the caller's current window.
type Decision struct {
	// Allowed reports whether the request may proceed.
	Allowed bool

	// Limit is the maximum number of requests permitted per window;
	// Remaining is how many more are permitted in the current window.
	Limit, Remaining int

	// ResetAt is the instant at which the current window ends.
	ResetAt time.Time

	// RetryAfter is how long a denied caller should wait before retrying.
	// The in-memory limiter leaves it zero for allowed requests.
	RetryAfter time.Duration
}
|
|
|
|
// Config controls middleware behavior.
|
|
type Config struct {
|
|
Skipper middleware.Skipper
|
|
Limit int
|
|
Window time.Duration
|
|
SetHeaders bool
|
|
ErrorHandler func(echo.Context, error) error
|
|
DenyHandler func(echo.Context, Decision) error
|
|
|
|
now func() time.Time
|
|
}
|
|
|
|
// clientState is the per-key bookkeeping for the current fixed window.
type clientState struct {
	// count is the number of requests admitted in the current window.
	count int

	// resetAt is when the current window ends; lastSeen is the time of the
	// most recent check for this key, whether allowed or denied.
	resetAt, lastSeen time.Time
}
|
|
|
|
type memoryLimiter struct {
|
|
limit int
|
|
window time.Duration
|
|
clients map[string]*clientState
|
|
cleanupEvery uint64
|
|
requests uint64
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// Middleware returns a middleware with sensible defaults.
|
|
func Middleware(limit int, window time.Duration) echo.MiddlewareFunc {
|
|
return WithConfig(Config{
|
|
Limit: limit,
|
|
Window: window,
|
|
})
|
|
}
|
|
|
|
// WithConfig creates an in-memory IP-based rate limiter middleware for Echo.
|
|
func WithConfig(config Config) echo.MiddlewareFunc {
|
|
if config.Skipper == nil {
|
|
config.Skipper = middleware.DefaultSkipper
|
|
}
|
|
if config.ErrorHandler == nil {
|
|
config.ErrorHandler = func(_ echo.Context, err error) error {
|
|
return err
|
|
}
|
|
}
|
|
if config.DenyHandler == nil {
|
|
config.DenyHandler = func(c echo.Context, _ Decision) error {
|
|
return echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
|
|
}
|
|
}
|
|
if config.now == nil {
|
|
config.now = time.Now
|
|
}
|
|
if !config.SetHeaders {
|
|
config.SetHeaders = true
|
|
}
|
|
if config.Limit <= 0 {
|
|
panic(ErrInvalidConfig)
|
|
}
|
|
if config.Window <= 0 {
|
|
panic(ErrInvalidConfig)
|
|
}
|
|
|
|
limiter := &memoryLimiter{
|
|
limit: config.Limit,
|
|
window: config.Window,
|
|
clients: make(map[string]*clientState),
|
|
cleanupEvery: 256,
|
|
}
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if config.Skipper(c) {
|
|
return next(c)
|
|
}
|
|
|
|
key := c.RealIP()
|
|
if key == "" {
|
|
err := errors.New("ratelimit: unable to determine client IP")
|
|
return config.ErrorHandler(c, err)
|
|
}
|
|
|
|
now := config.now()
|
|
decision := limiter.allow(key, now)
|
|
if config.SetHeaders {
|
|
writeHeaders(c.Response().Header(), decision)
|
|
}
|
|
if !decision.Allowed {
|
|
return config.DenyHandler(c, decision)
|
|
}
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *memoryLimiter) allow(key string, now time.Time) Decision {
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
|
|
l.requests++
|
|
if l.requests%l.cleanupEvery == 0 {
|
|
l.cleanup(now)
|
|
}
|
|
|
|
state, ok := l.clients[key]
|
|
if !ok || !now.Before(state.resetAt) {
|
|
state = &clientState{
|
|
count: 0,
|
|
resetAt: now.Add(l.window),
|
|
lastSeen: now,
|
|
}
|
|
l.clients[key] = state
|
|
}
|
|
|
|
state.lastSeen = now
|
|
remaining := l.limit - state.count - 1
|
|
if state.count >= l.limit {
|
|
retryAfter := state.resetAt.Sub(now)
|
|
if retryAfter < 0 {
|
|
retryAfter = 0
|
|
}
|
|
return Decision{
|
|
Allowed: false,
|
|
Limit: l.limit,
|
|
Remaining: 0,
|
|
ResetAt: state.resetAt,
|
|
RetryAfter: retryAfter,
|
|
}
|
|
}
|
|
|
|
state.count++
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
|
|
return Decision{
|
|
Allowed: true,
|
|
Limit: l.limit,
|
|
Remaining: remaining,
|
|
ResetAt: state.resetAt,
|
|
}
|
|
}
|
|
|
|
func (l *memoryLimiter) cleanup(now time.Time) {
|
|
for key, state := range l.clients {
|
|
if now.Sub(state.lastSeen) >= l.window && !now.Before(state.resetAt) {
|
|
delete(l.clients, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func writeHeaders(header http.Header, decision Decision) {
|
|
header.Set("X-RateLimit-Limit", strconv.Itoa(decision.Limit))
|
|
header.Set("X-RateLimit-Remaining", strconv.Itoa(decision.Remaining))
|
|
header.Set("X-RateLimit-Reset", strconv.FormatInt(decision.ResetAt.Unix(), 10))
|
|
if !decision.Allowed {
|
|
header.Set("Retry-After", strconv.Itoa(int(math.Ceil(decision.RetryAfter.Seconds()))))
|
|
}
|
|
}
|