198
ratelimit_test.go
Normal file
198
ratelimit_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func TestAllowsRequestsWithinWindow(t *testing.T) {
|
||||
e := echo.New()
|
||||
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 2,
|
||||
Window: time.Minute,
|
||||
SetHeaders: true,
|
||||
now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
rec1 := performRequest(e, handler, "203.0.113.1")
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first request status = %d", rec1.Code)
|
||||
}
|
||||
if got := rec1.Header().Get("X-RateLimit-Remaining"); got != "1" {
|
||||
t.Fatalf("first remaining = %q", got)
|
||||
}
|
||||
|
||||
rec2 := performRequest(e, handler, "203.0.113.1")
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("second request status = %d", rec2.Code)
|
||||
}
|
||||
if got := rec2.Header().Get("X-RateLimit-Remaining"); got != "0" {
|
||||
t.Fatalf("second remaining = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadersEnabledByDefault(t *testing.T) {
|
||||
e := echo.New()
|
||||
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: time.Minute,
|
||||
now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
rec := performRequest(e, handler, "203.0.113.20")
|
||||
if got := rec.Header().Get("X-RateLimit-Limit"); got != "1" {
|
||||
t.Fatalf("limit header = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeniesRequestsOverLimit(t *testing.T) {
|
||||
e := echo.New()
|
||||
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: 30 * time.Second,
|
||||
SetHeaders: true,
|
||||
now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
_ = performRequest(e, handler, "203.0.113.10")
|
||||
rec := performRequest(e, handler, "203.0.113.10")
|
||||
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("status = %d", rec.Code)
|
||||
}
|
||||
if got := rec.Header().Get("Retry-After"); got != "30" {
|
||||
t.Fatalf("retry-after = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowResets(t *testing.T) {
|
||||
e := echo.New()
|
||||
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: 10 * time.Second,
|
||||
now: func() time.Time {
|
||||
return now
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
rec1 := performRequest(e, handler, "198.51.100.1")
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first status = %d", rec1.Code)
|
||||
}
|
||||
|
||||
rec2 := performRequest(e, handler, "198.51.100.1")
|
||||
if rec2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("second status = %d", rec2.Code)
|
||||
}
|
||||
|
||||
now = now.Add(11 * time.Second)
|
||||
rec3 := performRequest(e, handler, "198.51.100.1")
|
||||
if rec3.Code != http.StatusOK {
|
||||
t.Fatalf("third status = %d", rec3.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipperBypassesLimiter(t *testing.T) {
|
||||
e := echo.New()
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: time.Minute,
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return c.Path() == "/health"
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
rec1 := performRequestWithPath(e, handler, "203.0.113.5", "/health")
|
||||
rec2 := performRequestWithPath(e, handler, "203.0.113.5", "/health")
|
||||
if rec1.Code != http.StatusOK || rec2.Code != http.StatusOK {
|
||||
t.Fatalf("skipped requests should both pass: %d %d", rec1.Code, rec2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealIPErrorsPropagate(t *testing.T) {
|
||||
e := echo.New()
|
||||
expected := errors.New("boom")
|
||||
|
||||
mw := WithConfig(Config{
|
||||
Limit: 1,
|
||||
Window: time.Minute,
|
||||
ErrorHandler: func(_ echo.Context, err error) error {
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
return expected
|
||||
},
|
||||
})
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = ":1234"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func performRequest(e *echo.Echo, h echo.HandlerFunc, ip string) *httptest.ResponseRecorder {
|
||||
return performRequestWithPath(e, h, ip, "/")
|
||||
}
|
||||
|
||||
func performRequestWithPath(e *echo.Echo, h echo.HandlerFunc, ip, path string) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
req.RemoteAddr = ip + ":1234"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetPath(path)
|
||||
|
||||
err := h(c)
|
||||
if err != nil {
|
||||
e.HTTPErrorHandler(err, c)
|
||||
}
|
||||
|
||||
return rec
|
||||
}
|
||||
Reference in New Issue
Block a user