package ratelimit import ( "errors" "net/http" "net/http/httptest" "testing" "time" "github.com/labstack/echo/v4" ) func TestAllowsRequestsWithinWindow(t *testing.T) { e := echo.New() now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) mw := WithConfig(Config{ Limit: 2, Window: time.Minute, SetHeaders: true, now: func() time.Time { return now }, }) handler := mw(func(c echo.Context) error { return c.String(http.StatusOK, "ok") }) rec1 := performRequest(e, handler, "203.0.113.1") if rec1.Code != http.StatusOK { t.Fatalf("first request status = %d", rec1.Code) } if got := rec1.Header().Get("X-RateLimit-Remaining"); got != "1" { t.Fatalf("first remaining = %q", got) } rec2 := performRequest(e, handler, "203.0.113.1") if rec2.Code != http.StatusOK { t.Fatalf("second request status = %d", rec2.Code) } if got := rec2.Header().Get("X-RateLimit-Remaining"); got != "0" { t.Fatalf("second remaining = %q", got) } } func TestHeadersEnabledByDefault(t *testing.T) { e := echo.New() now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) mw := WithConfig(Config{ Limit: 1, Window: time.Minute, now: func() time.Time { return now }, }) handler := mw(func(c echo.Context) error { return c.NoContent(http.StatusOK) }) rec := performRequest(e, handler, "203.0.113.20") if got := rec.Header().Get("X-RateLimit-Limit"); got != "1" { t.Fatalf("limit header = %q", got) } } func TestDeniesRequestsOverLimit(t *testing.T) { e := echo.New() now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) mw := WithConfig(Config{ Limit: 1, Window: 30 * time.Second, SetHeaders: true, now: func() time.Time { return now }, }) handler := mw(func(c echo.Context) error { return c.NoContent(http.StatusOK) }) _ = performRequest(e, handler, "203.0.113.10") rec := performRequest(e, handler, "203.0.113.10") if rec.Code != http.StatusTooManyRequests { t.Fatalf("status = %d", rec.Code) } if got := rec.Header().Get("Retry-After"); got != "30" { t.Fatalf("retry-after = %q", got) } } func TestWindowResets(t *testing.T) { e := echo.New() now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) mw := WithConfig(Config{ Limit: 1, Window: 10 * time.Second, now: func() time.Time { return now }, }) handler := mw(func(c echo.Context) error { return c.NoContent(http.StatusOK) }) rec1 := performRequest(e, handler, "198.51.100.1") if rec1.Code != http.StatusOK { t.Fatalf("first status = %d", rec1.Code) } rec2 := performRequest(e, handler, "198.51.100.1") if rec2.Code != http.StatusTooManyRequests { t.Fatalf("second status = %d", rec2.Code) } now = now.Add(11 * time.Second) rec3 := performRequest(e, handler, "198.51.100.1") if rec3.Code != http.StatusOK { t.Fatalf("third status = %d", rec3.Code) } } func TestSkipperBypassesLimiter(t *testing.T) { e := echo.New() mw := WithConfig(Config{ Limit: 1, Window: time.Minute, Skipper: func(c echo.Context) bool { return c.Path() == "/health" }, }) handler := mw(func(c echo.Context) error { return c.NoContent(http.StatusOK) }) rec1 := performRequestWithPath(e, handler, "203.0.113.5", "/health") rec2 := performRequestWithPath(e, handler, "203.0.113.5", "/health") if rec1.Code != http.StatusOK || rec2.Code != http.StatusOK { t.Fatalf("skipped requests should both pass: %d %d", rec1.Code, rec2.Code) } } func TestRealIPErrorsPropagate(t *testing.T) { e := echo.New() expected := errors.New("boom") mw := WithConfig(Config{ Limit: 1, Window: time.Minute, ErrorHandler: func(_ echo.Context, err error) error { if err == nil { t.Fatal("expected error") } return expected }, }) handler := mw(func(c echo.Context) error { return c.NoContent(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/", nil) req.RemoteAddr = ":1234" rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := handler(c) if !errors.Is(err, expected) { t.Fatalf("error = %v", err) } } func performRequest(e *echo.Echo, h echo.HandlerFunc, ip string) *httptest.ResponseRecorder { return performRequestWithPath(e, h, ip, "/") } func performRequestWithPath(e *echo.Echo, h echo.HandlerFunc, ip, path string) *httptest.ResponseRecorder { req := httptest.NewRequest(http.MethodGet, path, nil) req.RemoteAddr = ip + ":1234" rec := httptest.NewRecorder() c := e.NewContext(req, rec) c.SetPath(path) err := h(c) if err != nil { e.HTTPErrorHandler(err, c) } return rec }