From cf6d36d7cbfa9873cf9a699b60d07ce6d9d5aaf5 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Mon, 13 Apr 2026 23:06:50 +0000 Subject: [PATCH] Initial commit Signed-off-by: Carl Pearson --- README.md | 46 +++++++++++ go.mod | 18 +++++ go.sum | 31 ++++++++ ratelimit.go | 186 +++++++++++++++++++++++++++++++++++++++++++ ratelimit_test.go | 198 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 479 insertions(+) create mode 100644 README.md create mode 100644 go.mod create mode 100644 go.sum create mode 100644 ratelimit.go create mode 100644 ratelimit_test.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..218733d --- /dev/null +++ b/README.md @@ -0,0 +1,46 @@ +# echo-ratelimit + +Small in-memory rate-limiting middleware for [Echo](https://github.com/labstack/echo), keyed by client IP. + +## Install + +```bash +go get github.com/labstack/echo/v4 +``` + +## Usage + +```go +package main + +import ( + "net/http" + "time" + + "github.com/labstack/echo/v4" + + ratelimit "git.carlpearson.net/cwpearson/echo-ratelimit" +) + +func main() { + e := echo.New() + e.Use(ratelimit.Middleware(60, time.Minute)) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + e.Logger.Fatal(e.Start(":8080")) +} +``` + +## Custom Configuration + +```go +e.Use(ratelimit.WithConfig(ratelimit.Config{ + Limit: 10, + Window: time.Minute, +})) +``` + +Rate-limit headers are emitted by default. diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6dea583 --- /dev/null +++ b/go.mod @@ -0,0 +1,18 @@ +module git.carlpearson.net/cwpearson/echo-ratelimit + +go 1.26.0 + +require github.com/labstack/echo/v4 v4.15.1 + +require ( + github.com/labstack/gommon v0.4.2 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect + golang.org/x/crypto v0.46.0 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect + golang.org/x/time v0.14.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9e7e68d --- /dev/null +++ b/go.sum @@ -0,0 +1,31 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/labstack/echo/v4 v4.15.1 h1:S9keusg26gZpjMmPqB5hOEvNKnmd1lNmcHrbbH2lnFs= +github.com/labstack/echo/v4 v4.15.1/go.mod h1:xmw1clThob0BSVRX1CRQkGQ/vjwcpOMjQZSZa9fKA/c= +github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= +github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 0000000..616c02e --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,186 @@ +package ratelimit + +import ( + "errors" + "math" + "net/http" + "strconv" + "sync" + "time" + + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" +) + +// ErrInvalidConfig reports invalid middleware configuration. +var ErrInvalidConfig = errors.New("ratelimit: invalid config") + +// Decision contains the current request result and window metadata. +type Decision struct { + Allowed bool + Limit int + Remaining int + ResetAt time.Time + RetryAfter time.Duration +} + +// Config controls middleware behavior. +type Config struct { + Skipper middleware.Skipper + Limit int + Window time.Duration + SetHeaders bool + ErrorHandler func(echo.Context, error) error + DenyHandler func(echo.Context, Decision) error + + now func() time.Time +} + +type clientState struct { + count int + resetAt time.Time + lastSeen time.Time +} + +type memoryLimiter struct { + limit int + window time.Duration + clients map[string]*clientState + cleanupEvery uint64 + requests uint64 + mu sync.Mutex +} + +// Middleware returns a middleware with sensible defaults. +func Middleware(limit int, window time.Duration) echo.MiddlewareFunc { + return WithConfig(Config{ + Limit: limit, + Window: window, + }) +} + +// WithConfig creates an in-memory IP-based rate limiter middleware for Echo. +func WithConfig(config Config) echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = middleware.DefaultSkipper + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(_ echo.Context, err error) error { + return err + } + } + if config.DenyHandler == nil { + config.DenyHandler = func(c echo.Context, _ Decision) error { + return echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + } + } + if config.now == nil { + config.now = time.Now + } + if !config.SetHeaders { + config.SetHeaders = true + } + if config.Limit <= 0 { + panic(ErrInvalidConfig) + } + if config.Window <= 0 { + panic(ErrInvalidConfig) + } + + limiter := &memoryLimiter{ + limit: config.Limit, + window: config.Window, + clients: make(map[string]*clientState), + cleanupEvery: 256, + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + key := c.RealIP() + if key == "" { + err := errors.New("ratelimit: unable to determine client IP") + return config.ErrorHandler(c, err) + } + + now := config.now() + decision := limiter.allow(key, now) + if config.SetHeaders { + writeHeaders(c.Response().Header(), decision) + } + if !decision.Allowed { + return config.DenyHandler(c, decision) + } + + return next(c) + } + } +} + +func (l *memoryLimiter) allow(key string, now time.Time) Decision { + l.mu.Lock() + defer l.mu.Unlock() + + l.requests++ + if l.requests%l.cleanupEvery == 0 { + l.cleanup(now) + } + + state, ok := l.clients[key] + if !ok || !now.Before(state.resetAt) { + state = &clientState{ + count: 0, + resetAt: now.Add(l.window), + lastSeen: now, + } + l.clients[key] = state + } + + state.lastSeen = now + remaining := l.limit - state.count - 1 + if state.count >= l.limit { + retryAfter := state.resetAt.Sub(now) + if retryAfter < 0 { + retryAfter = 0 + } + return Decision{ + Allowed: false, + Limit: l.limit, + Remaining: 0, + ResetAt: state.resetAt, + RetryAfter: retryAfter, + } + } + + state.count++ + if remaining < 0 { + remaining = 0 + } + + return Decision{ + Allowed: true, + Limit: l.limit, + Remaining: remaining, + ResetAt: state.resetAt, + } +} + +func (l *memoryLimiter) cleanup(now time.Time) { + for key, state := range l.clients { + if now.Sub(state.lastSeen) >= l.window && !now.Before(state.resetAt) { + delete(l.clients, key) + } + } +} + +func writeHeaders(header http.Header, decision Decision) { + header.Set("X-RateLimit-Limit", strconv.Itoa(decision.Limit)) + header.Set("X-RateLimit-Remaining", strconv.Itoa(decision.Remaining)) + header.Set("X-RateLimit-Reset", strconv.FormatInt(decision.ResetAt.Unix(), 10)) + if !decision.Allowed { + header.Set("Retry-After", strconv.Itoa(int(math.Ceil(decision.RetryAfter.Seconds())))) + } +} diff --git a/ratelimit_test.go b/ratelimit_test.go new file mode 100644 index 0000000..42054f7 --- /dev/null +++ b/ratelimit_test.go @@ -0,0 +1,198 @@ +package ratelimit + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/labstack/echo/v4" +) + +func TestAllowsRequestsWithinWindow(t *testing.T) { + e := echo.New() + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + + mw := WithConfig(Config{ + Limit: 2, + Window: time.Minute, + SetHeaders: true, + now: func() time.Time { + return now + }, + }) + + handler := mw(func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + rec1 := performRequest(e, handler, "203.0.113.1") + if rec1.Code != http.StatusOK { + t.Fatalf("first request status = %d", rec1.Code) + } + if got := rec1.Header().Get("X-RateLimit-Remaining"); got != "1" { + t.Fatalf("first remaining = %q", got) + } + + rec2 := performRequest(e, handler, "203.0.113.1") + if rec2.Code != http.StatusOK { + t.Fatalf("second request status = %d", rec2.Code) + } + if got := rec2.Header().Get("X-RateLimit-Remaining"); got != "0" { + t.Fatalf("second remaining = %q", got) + } +} + +func TestHeadersEnabledByDefault(t *testing.T) { + e := echo.New() + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + + mw := WithConfig(Config{ + Limit: 1, + Window: time.Minute, + now: func() time.Time { + return now + }, + }) + + handler := mw(func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + rec := performRequest(e, handler, "203.0.113.20") + if got := rec.Header().Get("X-RateLimit-Limit"); got != "1" { + t.Fatalf("limit header = %q", got) + } +} + +func TestDeniesRequestsOverLimit(t *testing.T) { + e := echo.New() + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + + mw := WithConfig(Config{ + Limit: 1, + Window: 30 * time.Second, + SetHeaders: true, + now: func() time.Time { + return now + }, + }) + + handler := mw(func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + _ = performRequest(e, handler, "203.0.113.10") + rec := performRequest(e, handler, "203.0.113.10") + + if rec.Code != http.StatusTooManyRequests { + t.Fatalf("status = %d", rec.Code) + } + if got := rec.Header().Get("Retry-After"); got != "30" { + t.Fatalf("retry-after = %q", got) + } +} + +func TestWindowResets(t *testing.T) { + e := echo.New() + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + + mw := WithConfig(Config{ + Limit: 1, + Window: 10 * time.Second, + now: func() time.Time { + return now + }, + }) + + handler := mw(func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + rec1 := performRequest(e, handler, "198.51.100.1") + if rec1.Code != http.StatusOK { + t.Fatalf("first status = %d", rec1.Code) + } + + rec2 := performRequest(e, handler, "198.51.100.1") + if rec2.Code != http.StatusTooManyRequests { + t.Fatalf("second status = %d", rec2.Code) + } + + now = now.Add(11 * time.Second) + rec3 := performRequest(e, handler, "198.51.100.1") + if rec3.Code != http.StatusOK { + t.Fatalf("third status = %d", rec3.Code) + } +} + +func TestSkipperBypassesLimiter(t *testing.T) { + e := echo.New() + mw := WithConfig(Config{ + Limit: 1, + Window: time.Minute, + Skipper: func(c echo.Context) bool { + return c.Path() == "/health" + }, + }) + + handler := mw(func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + rec1 := performRequestWithPath(e, handler, "203.0.113.5", "/health") + rec2 := performRequestWithPath(e, handler, "203.0.113.5", "/health") + if rec1.Code != http.StatusOK || rec2.Code != http.StatusOK { + t.Fatalf("skipped requests should both pass: %d %d", rec1.Code, rec2.Code) + } +} + +func TestRealIPErrorsPropagate(t *testing.T) { + e := echo.New() + expected := errors.New("boom") + + mw := WithConfig(Config{ + Limit: 1, + Window: time.Minute, + ErrorHandler: func(_ echo.Context, err error) error { + if err == nil { + t.Fatal("expected error") + } + return expected + }, + }) + + handler := mw(func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = ":1234" + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + if !errors.Is(err, expected) { + t.Fatalf("error = %v", err) + } +} + +func performRequest(e *echo.Echo, h echo.HandlerFunc, ip string) *httptest.ResponseRecorder { + return performRequestWithPath(e, h, ip, "/") +} + +func performRequestWithPath(e *echo.Echo, h echo.HandlerFunc, ip, path string) *httptest.ResponseRecorder { + req := httptest.NewRequest(http.MethodGet, path, nil) + req.RemoteAddr = ip + ":1234" + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPath(path) + + err := h(c) + if err != nil { + e.HTTPErrorHandler(err, c) + } + + return rec +}