package ratelimit

import (
	"errors"
	"math"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
)

var (
	// ErrInvalidConfig reports invalid middleware configuration. WithConfig
	// and Middleware panic with this value when Limit or Window is not
	// positive.
	ErrInvalidConfig = errors.New("ratelimit: invalid config")

	// ErrNoClientIP is passed to Config.ErrorHandler when c.RealIP() returns
	// an empty string, so custom handlers can detect it with errors.Is.
	ErrNoClientIP = errors.New("ratelimit: unable to determine client IP")
)

// Decision contains the current request result and window metadata.
type Decision struct {
	// Allowed reports whether the request fits within the current window.
	Allowed bool

	// Limit is the configured maximum number of requests per window.
	Limit int

	// Remaining is how many more requests the client may make in the
	// current window after this one. It is 0 for denied requests.
	Remaining int

	// ResetAt is when the client's current window ends.
	ResetAt time.Time

	// RetryAfter is the time left until ResetAt for denied requests; it is
	// zero for allowed requests.
	RetryAfter time.Duration
}

// Config controls middleware behavior. Only Limit and Window are required;
// every other field has a default.
type Config struct {
	// Skipper decides whether a request bypasses rate limiting.
	// Defaults to middleware.DefaultSkipper.
	Skipper middleware.Skipper

	// Limit is the maximum number of requests each client may make per
	// Window. It must be positive.
	Limit int

	// Window is the length of each fixed rate-limit window. It must be
	// positive.
	Window time.Duration

	// SetHeaders is ignored; it is kept so existing Config literals still
	// compile.
	//
	// Deprecated: rate-limit headers are written by default; set
	// DisableHeaders to suppress them.
	SetHeaders bool

	// DisableHeaders suppresses the X-RateLimit-Limit, X-RateLimit-Remaining,
	// X-RateLimit-Reset and Retry-After response headers. The zero value
	// keeps them enabled.
	DisableHeaders bool

	// ErrorHandler is called with ErrNoClientIP when the client IP cannot be
	// determined. Defaults to returning the error unchanged.
	ErrorHandler func(echo.Context, error) error

	// DenyHandler is called for requests over the limit. Defaults to
	// returning an HTTP 429 echo.HTTPError.
	DenyHandler func(echo.Context, Decision) error

	// now is a clock hook (e.g. for tests); defaults to time.Now.
	now func() time.Time
}

// clientState is the fixed-window counter kept for a single client key.
type clientState struct {
	count   int       // requests accepted in the current window
	resetAt time.Time // end of the current window
}

// memoryLimiter is a fixed-window request counter keyed by client.
type memoryLimiter struct {
	limit  int
	window time.Duration

	// mu guards clients and nextCleanup.
	mu          sync.Mutex
	clients     map[string]*clientState
	nextCleanup time.Time
}

// Middleware returns a rate-limiting middleware that allows limit requests
// per window for each client IP, using defaults for every other Config
// field. It panics with ErrInvalidConfig if limit or window is not positive.
func Middleware(limit int, window time.Duration) echo.MiddlewareFunc {
	return WithConfig(Config{
		Limit:  limit,
		Window: window,
	})
}

// WithConfig creates an in-memory, fixed-window rate limiter middleware for
// Echo, keyed by c.RealIP(). Each returned middleware keeps its own counters
// in process memory. It panics with ErrInvalidConfig if config.Limit or
// config.Window is not positive.
//
// NOTE: the key comes from c.RealIP(), whose trust in forwarding headers
// depends on Echo's IPExtractor setting; configure it for your proxy setup so
// clients cannot pick their own key.
func WithConfig(config Config) echo.MiddlewareFunc {
	if config.Limit <= 0 || config.Window <= 0 {
		panic(ErrInvalidConfig)
	}
	if config.Skipper == nil {
		config.Skipper = middleware.DefaultSkipper
	}
	if config.ErrorHandler == nil {
		config.ErrorHandler = func(_ echo.Context, err error) error { return err }
	}
	if config.DenyHandler == nil {
		config.DenyHandler = func(echo.Context, Decision) error {
			return echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
		}
	}
	if config.now == nil {
		config.now = time.Now
	}
	setHeaders := !config.DisableHeaders

	limiter := &memoryLimiter{
		limit:   config.Limit,
		window:  config.Window,
		clients: make(map[string]*clientState),
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			key := c.RealIP()
			if key == "" {
				return config.ErrorHandler(c, ErrNoClientIP)
			}

			decision := limiter.allow(key, config.now())
			// Headers are set before the handler runs so they are present on
			// both allowed and denied responses.
			if setHeaders {
				writeHeaders(c.Response().Header(), decision)
			}
			if !decision.Allowed {
				return config.DenyHandler(c, decision)
			}
			return next(c)
		}
	}
}

// allow records one request for key at time now and reports whether it fits
// within the limit for the client's current window. Denied requests do not
// consume quota.
func (l *memoryLimiter) allow(key string, now time.Time) Decision {
	l.mu.Lock()
	defer l.mu.Unlock()

	// Sweep expired entries at most once per window, so the O(clients) scan
	// runs per window rather than every few requests.
	if !now.Before(l.nextCleanup) {
		l.cleanup(now)
		l.nextCleanup = now.Add(l.window)
	}

	state, ok := l.clients[key]
	switch {
	case !ok:
		state = &clientState{resetAt: now.Add(l.window)}
		l.clients[key] = state
	case !now.Before(state.resetAt):
		// The previous window has ended; start a fresh one in place.
		state.count = 0
		state.resetAt = now.Add(l.window)
	}

	if state.count >= l.limit {
		// now is strictly before resetAt here (expired windows were reset
		// above and Limit > 0), so RetryAfter is positive.
		return Decision{
			Allowed:    false,
			Limit:      l.limit,
			Remaining:  0,
			ResetAt:    state.resetAt,
			RetryAfter: state.resetAt.Sub(now),
		}
	}

	state.count++
	return Decision{
		Allowed:   true,
		Limit:     l.limit,
		Remaining: l.limit - state.count,
		ResetAt:   state.resetAt,
	}
}

// cleanup deletes clients whose window has ended. Such entries carry no
// information: the client's next request would start a fresh window anyway.
// The caller must hold l.mu.
func (l *memoryLimiter) cleanup(now time.Time) {
	for key, state := range l.clients {
		if !now.Before(state.resetAt) {
			delete(l.clients, key)
		}
	}
}

// writeHeaders sets the rate-limit response headers for decision.
// X-RateLimit-Reset is the Unix time in seconds at which the window ends;
// Retry-After (whole seconds) is only set on denied requests. Both round up
// so a client that waits the advertised time never retries too early.
func writeHeaders(header http.Header, decision Decision) {
	reset := decision.ResetAt.Unix()
	if decision.ResetAt.Nanosecond() > 0 {
		reset++
	}
	header.Set("X-RateLimit-Limit", strconv.Itoa(decision.Limit))
	header.Set("X-RateLimit-Remaining", strconv.Itoa(decision.Remaining))
	header.Set("X-RateLimit-Reset", strconv.FormatInt(reset, 10))
	if !decision.Allowed {
		header.Set("Retry-After", strconv.Itoa(int(math.Ceil(decision.RetryAfter.Seconds()))))
	}
}