186
ratelimit.go
Normal file
186
ratelimit.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
)
|
||||
|
||||
// ErrInvalidConfig reports invalid middleware configuration.
|
||||
var ErrInvalidConfig = errors.New("ratelimit: invalid config")
|
||||
|
||||
// Decision contains the current request result and window metadata.
|
||||
type Decision struct {
|
||||
Allowed bool
|
||||
Limit int
|
||||
Remaining int
|
||||
ResetAt time.Time
|
||||
RetryAfter time.Duration
|
||||
}
|
||||
|
||||
// Config controls middleware behavior.
|
||||
type Config struct {
|
||||
Skipper middleware.Skipper
|
||||
Limit int
|
||||
Window time.Duration
|
||||
SetHeaders bool
|
||||
ErrorHandler func(echo.Context, error) error
|
||||
DenyHandler func(echo.Context, Decision) error
|
||||
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type clientState struct {
|
||||
count int
|
||||
resetAt time.Time
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
type memoryLimiter struct {
|
||||
limit int
|
||||
window time.Duration
|
||||
clients map[string]*clientState
|
||||
cleanupEvery uint64
|
||||
requests uint64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Middleware returns a middleware with sensible defaults.
|
||||
func Middleware(limit int, window time.Duration) echo.MiddlewareFunc {
|
||||
return WithConfig(Config{
|
||||
Limit: limit,
|
||||
Window: window,
|
||||
})
|
||||
}
|
||||
|
||||
// WithConfig creates an in-memory IP-based rate limiter middleware for Echo.
|
||||
func WithConfig(config Config) echo.MiddlewareFunc {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = middleware.DefaultSkipper
|
||||
}
|
||||
if config.ErrorHandler == nil {
|
||||
config.ErrorHandler = func(_ echo.Context, err error) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if config.DenyHandler == nil {
|
||||
config.DenyHandler = func(c echo.Context, _ Decision) error {
|
||||
return echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
|
||||
}
|
||||
}
|
||||
if config.now == nil {
|
||||
config.now = time.Now
|
||||
}
|
||||
if !config.SetHeaders {
|
||||
config.SetHeaders = true
|
||||
}
|
||||
if config.Limit <= 0 {
|
||||
panic(ErrInvalidConfig)
|
||||
}
|
||||
if config.Window <= 0 {
|
||||
panic(ErrInvalidConfig)
|
||||
}
|
||||
|
||||
limiter := &memoryLimiter{
|
||||
limit: config.Limit,
|
||||
window: config.Window,
|
||||
clients: make(map[string]*clientState),
|
||||
cleanupEvery: 256,
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
key := c.RealIP()
|
||||
if key == "" {
|
||||
err := errors.New("ratelimit: unable to determine client IP")
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
now := config.now()
|
||||
decision := limiter.allow(key, now)
|
||||
if config.SetHeaders {
|
||||
writeHeaders(c.Response().Header(), decision)
|
||||
}
|
||||
if !decision.Allowed {
|
||||
return config.DenyHandler(c, decision)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *memoryLimiter) allow(key string, now time.Time) Decision {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.requests++
|
||||
if l.requests%l.cleanupEvery == 0 {
|
||||
l.cleanup(now)
|
||||
}
|
||||
|
||||
state, ok := l.clients[key]
|
||||
if !ok || !now.Before(state.resetAt) {
|
||||
state = &clientState{
|
||||
count: 0,
|
||||
resetAt: now.Add(l.window),
|
||||
lastSeen: now,
|
||||
}
|
||||
l.clients[key] = state
|
||||
}
|
||||
|
||||
state.lastSeen = now
|
||||
remaining := l.limit - state.count - 1
|
||||
if state.count >= l.limit {
|
||||
retryAfter := state.resetAt.Sub(now)
|
||||
if retryAfter < 0 {
|
||||
retryAfter = 0
|
||||
}
|
||||
return Decision{
|
||||
Allowed: false,
|
||||
Limit: l.limit,
|
||||
Remaining: 0,
|
||||
ResetAt: state.resetAt,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
state.count++
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
return Decision{
|
||||
Allowed: true,
|
||||
Limit: l.limit,
|
||||
Remaining: remaining,
|
||||
ResetAt: state.resetAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *memoryLimiter) cleanup(now time.Time) {
|
||||
for key, state := range l.clients {
|
||||
if now.Sub(state.lastSeen) >= l.window && !now.Before(state.resetAt) {
|
||||
delete(l.clients, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeHeaders(header http.Header, decision Decision) {
|
||||
header.Set("X-RateLimit-Limit", strconv.Itoa(decision.Limit))
|
||||
header.Set("X-RateLimit-Remaining", strconv.Itoa(decision.Remaining))
|
||||
header.Set("X-RateLimit-Reset", strconv.FormatInt(decision.ResetAt.Unix(), 10))
|
||||
if !decision.Allowed {
|
||||
header.Set("Retry-After", strconv.Itoa(int(math.Ceil(decision.RetryAfter.Seconds()))))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user