46
README.md
Normal file
46
README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# echo-ratelimit
|
||||
|
||||
Small in-memory rate-limiting middleware for [Echo](https://github.com/labstack/echo), keyed by client IP.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
go get github.com/labstack/echo/v4
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
ratelimit "git.carlpearson.net/cwpearson/echo-ratelimit"
|
||||
)
|
||||
|
||||
func main() {
|
||||
e := echo.New()
|
||||
e.Use(ratelimit.Middleware(60, time.Minute))
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e.Logger.Fatal(e.Start(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
## Custom Configuration
|
||||
|
||||
```go
|
||||
e.Use(ratelimit.WithConfig(ratelimit.Config{
|
||||
Limit: 10,
|
||||
Window: time.Minute,
|
||||
}))
|
||||
```
|
||||
|
||||
Rate-limit headers are emitted by default.
|
||||
18
go.mod
Normal file
18
go.mod
Normal file
@@ -0,0 +1,18 @@
|
||||
module git.carlpearson.net/cwpearson/echo-ratelimit
|
||||
|
||||
go 1.26.0
|
||||
|
||||
require github.com/labstack/echo/v4 v4.15.1
|
||||
|
||||
require (
|
||||
github.com/labstack/gommon v0.4.2 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||
golang.org/x/crypto v0.46.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
)
|
||||
31
go.sum
Normal file
31
go.sum
Normal file
@@ -0,0 +1,31 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/labstack/echo/v4 v4.15.1 h1:S9keusg26gZpjMmPqB5hOEvNKnmd1lNmcHrbbH2lnFs=
|
||||
github.com/labstack/echo/v4 v4.15.1/go.mod h1:xmw1clThob0BSVRX1CRQkGQ/vjwcpOMjQZSZa9fKA/c=
|
||||
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
|
||||
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
|
||||
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
186
ratelimit.go
Normal file
186
ratelimit.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
)
|
||||
|
||||
// ErrInvalidConfig reports invalid middleware configuration.
|
||||
var ErrInvalidConfig = errors.New("ratelimit: invalid config")
|
||||
|
||||
// Decision contains the current request result and window metadata.
|
||||
type Decision struct {
|
||||
Allowed bool
|
||||
Limit int
|
||||
Remaining int
|
||||
ResetAt time.Time
|
||||
RetryAfter time.Duration
|
||||
}
|
||||
|
||||
// Config controls middleware behavior.
|
||||
type Config struct {
|
||||
Skipper middleware.Skipper
|
||||
Limit int
|
||||
Window time.Duration
|
||||
SetHeaders bool
|
||||
ErrorHandler func(echo.Context, error) error
|
||||
DenyHandler func(echo.Context, Decision) error
|
||||
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type clientState struct {
|
||||
count int
|
||||
resetAt time.Time
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
type memoryLimiter struct {
|
||||
limit int
|
||||
window time.Duration
|
||||
clients map[string]*clientState
|
||||
cleanupEvery uint64
|
||||
requests uint64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Middleware returns a middleware with sensible defaults.
|
||||
func Middleware(limit int, window time.Duration) echo.MiddlewareFunc {
|
||||
return WithConfig(Config{
|
||||
Limit: limit,
|
||||
Window: window,
|
||||
})
|
||||
}
|
||||
|
||||
// WithConfig creates an in-memory IP-based rate limiter middleware for Echo.
|
||||
func WithConfig(config Config) echo.MiddlewareFunc {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = middleware.DefaultSkipper
|
||||
}
|
||||
if config.ErrorHandler == nil {
|
||||
config.ErrorHandler = func(_ echo.Context, err error) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if config.DenyHandler == nil {
|
||||
config.DenyHandler = func(c echo.Context, _ Decision) error {
|
||||
return echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
|
||||
}
|
||||
}
|
||||
if config.now == nil {
|
||||
config.now = time.Now
|
||||
}
|
||||
if !config.SetHeaders {
|
||||
config.SetHeaders = true
|
||||
}
|
||||
if config.Limit <= 0 {
|
||||
panic(ErrInvalidConfig)
|
||||
}
|
||||
if config.Window <= 0 {
|
||||
panic(ErrInvalidConfig)
|
||||
}
|
||||
|
||||
limiter := &memoryLimiter{
|
||||
limit: config.Limit,
|
||||
window: config.Window,
|
||||
clients: make(map[string]*clientState),
|
||||
cleanupEvery: 256,
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
key := c.RealIP()
|
||||
if key == "" {
|
||||
err := errors.New("ratelimit: unable to determine client IP")
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
now := config.now()
|
||||
decision := limiter.allow(key, now)
|
||||
if config.SetHeaders {
|
||||
writeHeaders(c.Response().Header(), decision)
|
||||
}
|
||||
if !decision.Allowed {
|
||||
return config.DenyHandler(c, decision)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *memoryLimiter) allow(key string, now time.Time) Decision {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.requests++
|
||||
if l.requests%l.cleanupEvery == 0 {
|
||||
l.cleanup(now)
|
||||
}
|
||||
|
||||
state, ok := l.clients[key]
|
||||
if !ok || !now.Before(state.resetAt) {
|
||||
state = &clientState{
|
||||
count: 0,
|
||||
resetAt: now.Add(l.window),
|
||||
lastSeen: now,
|
||||
}
|
||||
l.clients[key] = state
|
||||
}
|
||||
|
||||
state.lastSeen = now
|
||||
remaining := l.limit - state.count - 1
|
||||
if state.count >= l.limit {
|
||||
retryAfter := state.resetAt.Sub(now)
|
||||
if retryAfter < 0 {
|
||||
retryAfter = 0
|
||||
}
|
||||
return Decision{
|
||||
Allowed: false,
|
||||
Limit: l.limit,
|
||||
Remaining: 0,
|
||||
ResetAt: state.resetAt,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
state.count++
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
return Decision{
|
||||
Allowed: true,
|
||||
Limit: l.limit,
|
||||
Remaining: remaining,
|
||||
ResetAt: state.resetAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *memoryLimiter) cleanup(now time.Time) {
|
||||
for key, state := range l.clients {
|
||||
if now.Sub(state.lastSeen) >= l.window && !now.Before(state.resetAt) {
|
||||
delete(l.clients, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeHeaders(header http.Header, decision Decision) {
|
||||
header.Set("X-RateLimit-Limit", strconv.Itoa(decision.Limit))
|
||||
header.Set("X-RateLimit-Remaining", strconv.Itoa(decision.Remaining))
|
||||
header.Set("X-RateLimit-Reset", strconv.FormatInt(decision.ResetAt.Unix(), 10))
|
||||
if !decision.Allowed {
|
||||
header.Set("Retry-After", strconv.Itoa(int(math.Ceil(decision.RetryAfter.Seconds()))))
|
||||
}
|
||||
}
|
||||
198
ratelimit_test.go
Normal file
198
ratelimit_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func TestAllowsRequestsWithinWindow(t *testing.T) {
|
||||
e := echo.New()
|
||||
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 2,
|
||||
Window: time.Minute,
|
||||
SetHeaders: true,
|
||||
now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
rec1 := performRequest(e, handler, "203.0.113.1")
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first request status = %d", rec1.Code)
|
||||
}
|
||||
if got := rec1.Header().Get("X-RateLimit-Remaining"); got != "1" {
|
||||
t.Fatalf("first remaining = %q", got)
|
||||
}
|
||||
|
||||
rec2 := performRequest(e, handler, "203.0.113.1")
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("second request status = %d", rec2.Code)
|
||||
}
|
||||
if got := rec2.Header().Get("X-RateLimit-Remaining"); got != "0" {
|
||||
t.Fatalf("second remaining = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadersEnabledByDefault(t *testing.T) {
|
||||
e := echo.New()
|
||||
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: time.Minute,
|
||||
now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
rec := performRequest(e, handler, "203.0.113.20")
|
||||
if got := rec.Header().Get("X-RateLimit-Limit"); got != "1" {
|
||||
t.Fatalf("limit header = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeniesRequestsOverLimit(t *testing.T) {
|
||||
e := echo.New()
|
||||
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: 30 * time.Second,
|
||||
SetHeaders: true,
|
||||
now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
_ = performRequest(e, handler, "203.0.113.10")
|
||||
rec := performRequest(e, handler, "203.0.113.10")
|
||||
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("status = %d", rec.Code)
|
||||
}
|
||||
if got := rec.Header().Get("Retry-After"); got != "30" {
|
||||
t.Fatalf("retry-after = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowResets(t *testing.T) {
|
||||
e := echo.New()
|
||||
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: 10 * time.Second,
|
||||
now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
rec1 := performRequest(e, handler, "198.51.100.1")
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first status = %d", rec1.Code)
|
||||
}
|
||||
|
||||
rec2 := performRequest(e, handler, "198.51.100.1")
|
||||
if rec2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("second status = %d", rec2.Code)
|
||||
}
|
||||
|
||||
now = now.Add(11 * time.Second)
|
||||
rec3 := performRequest(e, handler, "198.51.100.1")
|
||||
if rec3.Code != http.StatusOK {
|
||||
t.Fatalf("third status = %d", rec3.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipperBypassesLimiter(t *testing.T) {
|
||||
e := echo.New()
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: time.Minute,
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return c.Path() == "/health"
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
rec1 := performRequestWithPath(e, handler, "203.0.113.5", "/health")
|
||||
rec2 := performRequestWithPath(e, handler, "203.0.113.5", "/health")
|
||||
if rec1.Code != http.StatusOK || rec2.Code != http.StatusOK {
|
||||
t.Fatalf("skipped requests should both pass: %d %d", rec1.Code, rec2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealIPErrorsPropagate(t *testing.T) {
|
||||
e := echo.New()
|
||||
expected := errors.New("boom")
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: time.Minute,
|
||||
ErrorHandler: func(_ echo.Context, err error) error {
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
return expected
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = ":1234"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func performRequest(e *echo.Echo, h echo.HandlerFunc, ip string) *httptest.ResponseRecorder {
|
||||
return performRequestWithPath(e, h, ip, "/")
|
||||
}
|
||||
|
||||
func performRequestWithPath(e *echo.Echo, h echo.HandlerFunc, ip, path string) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
req.RemoteAddr = ip + ":1234"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetPath(path)
|
||||
|
||||
err := h(c)
|
||||
if err != nil {
|
||||
e.HTTPErrorHandler(err, c)
|
||||
}
|
||||
|
||||
return rec
|
||||
}
|
||||
Reference in New Issue
Block a user