Initial commit

Signed-off-by: Carl Pearson <me@carlpearson.net>
This commit is contained in:
2026-04-13 23:06:50 +00:00
commit cf6d36d7cb
5 changed files with 479 additions and 0 deletions

198
ratelimit_test.go Normal file
View File

@@ -0,0 +1,198 @@
package ratelimit
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/labstack/echo/v4"
)
func TestAllowsRequestsWithinWindow(t *testing.T) {
e := echo.New()
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
mw := WithConfig(Config{
Limit: 2,
Window: time.Minute,
SetHeaders: true,
now: func() time.Time {
return now
},
})
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
rec1 := performRequest(e, handler, "203.0.113.1")
if rec1.Code != http.StatusOK {
t.Fatalf("first request status = %d", rec1.Code)
}
if got := rec1.Header().Get("X-RateLimit-Remaining"); got != "1" {
t.Fatalf("first remaining = %q", got)
}
rec2 := performRequest(e, handler, "203.0.113.1")
if rec2.Code != http.StatusOK {
t.Fatalf("second request status = %d", rec2.Code)
}
if got := rec2.Header().Get("X-RateLimit-Remaining"); got != "0" {
t.Fatalf("second remaining = %q", got)
}
}
func TestHeadersEnabledByDefault(t *testing.T) {
e := echo.New()
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
mw := WithConfig(Config{
Limit: 1,
Window: time.Minute,
now: func() time.Time {
return now
},
})
handler := mw(func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
rec := performRequest(e, handler, "203.0.113.20")
if got := rec.Header().Get("X-RateLimit-Limit"); got != "1" {
t.Fatalf("limit header = %q", got)
}
}
func TestDeniesRequestsOverLimit(t *testing.T) {
e := echo.New()
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
mw := WithConfig(Config{
Limit: 1,
Window: 30 * time.Second,
SetHeaders: true,
now: func() time.Time {
return now
},
})
handler := mw(func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
_ = performRequest(e, handler, "203.0.113.10")
rec := performRequest(e, handler, "203.0.113.10")
if rec.Code != http.StatusTooManyRequests {
t.Fatalf("status = %d", rec.Code)
}
if got := rec.Header().Get("Retry-After"); got != "30" {
t.Fatalf("retry-after = %q", got)
}
}
func TestWindowResets(t *testing.T) {
e := echo.New()
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
mw := WithConfig(Config{
Limit: 1,
Window: 10 * time.Second,
now: func() time.Time {
return now
},
})
handler := mw(func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
rec1 := performRequest(e, handler, "198.51.100.1")
if rec1.Code != http.StatusOK {
t.Fatalf("first status = %d", rec1.Code)
}
rec2 := performRequest(e, handler, "198.51.100.1")
if rec2.Code != http.StatusTooManyRequests {
t.Fatalf("second status = %d", rec2.Code)
}
now = now.Add(11 * time.Second)
rec3 := performRequest(e, handler, "198.51.100.1")
if rec3.Code != http.StatusOK {
t.Fatalf("third status = %d", rec3.Code)
}
}
func TestSkipperBypassesLimiter(t *testing.T) {
e := echo.New()
mw := WithConfig(Config{
Limit: 1,
Window: time.Minute,
Skipper: func(c echo.Context) bool {
return c.Path() == "/health"
},
})
handler := mw(func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
rec1 := performRequestWithPath(e, handler, "203.0.113.5", "/health")
rec2 := performRequestWithPath(e, handler, "203.0.113.5", "/health")
if rec1.Code != http.StatusOK || rec2.Code != http.StatusOK {
t.Fatalf("skipped requests should both pass: %d %d", rec1.Code, rec2.Code)
}
}
func TestRealIPErrorsPropagate(t *testing.T) {
e := echo.New()
expected := errors.New("boom")
mw := WithConfig(Config{
Limit: 1,
Window: time.Minute,
ErrorHandler: func(_ echo.Context, err error) error {
if err == nil {
t.Fatal("expected error")
}
return expected
},
})
handler := mw(func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = ":1234"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
if !errors.Is(err, expected) {
t.Fatalf("error = %v", err)
}
}
func performRequest(e *echo.Echo, h echo.HandlerFunc, ip string) *httptest.ResponseRecorder {
return performRequestWithPath(e, h, ip, "/")
}
func performRequestWithPath(e *echo.Echo, h echo.HandlerFunc, ip, path string) *httptest.ResponseRecorder {
req := httptest.NewRequest(http.MethodGet, path, nil)
req.RemoteAddr = ip + ":1234"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetPath(path)
err := h(c)
if err != nil {
e.HTTPErrorHandler(err, c)
}
return rec
}