199 lines
4.4 KiB
Go
199 lines
4.4 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
)
|
|
|
|
func TestAllowsRequestsWithinWindow(t *testing.T) {
|
|
e := echo.New()
|
|
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
|
|
|
mw := WithConfig(Config{
|
|
Limit: 2,
|
|
Window: time.Minute,
|
|
SetHeaders: true,
|
|
now: func() time.Time {
|
|
return now
|
|
},
|
|
})
|
|
|
|
handler := mw(func(c echo.Context) error {
|
|
return c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
rec1 := performRequest(e, handler, "203.0.113.1")
|
|
if rec1.Code != http.StatusOK {
|
|
t.Fatalf("first request status = %d", rec1.Code)
|
|
}
|
|
if got := rec1.Header().Get("X-RateLimit-Remaining"); got != "1" {
|
|
t.Fatalf("first remaining = %q", got)
|
|
}
|
|
|
|
rec2 := performRequest(e, handler, "203.0.113.1")
|
|
if rec2.Code != http.StatusOK {
|
|
t.Fatalf("second request status = %d", rec2.Code)
|
|
}
|
|
if got := rec2.Header().Get("X-RateLimit-Remaining"); got != "0" {
|
|
t.Fatalf("second remaining = %q", got)
|
|
}
|
|
}
|
|
|
|
func TestHeadersEnabledByDefault(t *testing.T) {
|
|
e := echo.New()
|
|
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
|
|
|
mw := WithConfig(Config{
|
|
Limit: 1,
|
|
Window: time.Minute,
|
|
now: func() time.Time {
|
|
return now
|
|
},
|
|
})
|
|
|
|
handler := mw(func(c echo.Context) error {
|
|
return c.NoContent(http.StatusOK)
|
|
})
|
|
|
|
rec := performRequest(e, handler, "203.0.113.20")
|
|
if got := rec.Header().Get("X-RateLimit-Limit"); got != "1" {
|
|
t.Fatalf("limit header = %q", got)
|
|
}
|
|
}
|
|
|
|
func TestDeniesRequestsOverLimit(t *testing.T) {
|
|
e := echo.New()
|
|
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
|
|
|
mw := WithConfig(Config{
|
|
Limit: 1,
|
|
Window: 30 * time.Second,
|
|
SetHeaders: true,
|
|
now: func() time.Time {
|
|
return now
|
|
},
|
|
})
|
|
|
|
handler := mw(func(c echo.Context) error {
|
|
return c.NoContent(http.StatusOK)
|
|
})
|
|
|
|
_ = performRequest(e, handler, "203.0.113.10")
|
|
rec := performRequest(e, handler, "203.0.113.10")
|
|
|
|
if rec.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("status = %d", rec.Code)
|
|
}
|
|
if got := rec.Header().Get("Retry-After"); got != "30" {
|
|
t.Fatalf("retry-after = %q", got)
|
|
}
|
|
}
|
|
|
|
func TestWindowResets(t *testing.T) {
|
|
e := echo.New()
|
|
now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC)
|
|
|
|
mw := WithConfig(Config{
|
|
Limit: 1,
|
|
Window: 10 * time.Second,
|
|
now: func() time.Time {
|
|
return now
|
|
},
|
|
})
|
|
|
|
handler := mw(func(c echo.Context) error {
|
|
return c.NoContent(http.StatusOK)
|
|
})
|
|
|
|
rec1 := performRequest(e, handler, "198.51.100.1")
|
|
if rec1.Code != http.StatusOK {
|
|
t.Fatalf("first status = %d", rec1.Code)
|
|
}
|
|
|
|
rec2 := performRequest(e, handler, "198.51.100.1")
|
|
if rec2.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("second status = %d", rec2.Code)
|
|
}
|
|
|
|
now = now.Add(11 * time.Second)
|
|
rec3 := performRequest(e, handler, "198.51.100.1")
|
|
if rec3.Code != http.StatusOK {
|
|
t.Fatalf("third status = %d", rec3.Code)
|
|
}
|
|
}
|
|
|
|
func TestSkipperBypassesLimiter(t *testing.T) {
|
|
e := echo.New()
|
|
mw := WithConfig(Config{
|
|
Limit: 1,
|
|
Window: time.Minute,
|
|
Skipper: func(c echo.Context) bool {
|
|
return c.Path() == "/health"
|
|
},
|
|
})
|
|
|
|
handler := mw(func(c echo.Context) error {
|
|
return c.NoContent(http.StatusOK)
|
|
})
|
|
|
|
rec1 := performRequestWithPath(e, handler, "203.0.113.5", "/health")
|
|
rec2 := performRequestWithPath(e, handler, "203.0.113.5", "/health")
|
|
if rec1.Code != http.StatusOK || rec2.Code != http.StatusOK {
|
|
t.Fatalf("skipped requests should both pass: %d %d", rec1.Code, rec2.Code)
|
|
}
|
|
}
|
|
|
|
func TestRealIPErrorsPropagate(t *testing.T) {
|
|
e := echo.New()
|
|
expected := errors.New("boom")
|
|
|
|
mw := WithConfig(Config{
|
|
Limit: 1,
|
|
Window: time.Minute,
|
|
ErrorHandler: func(_ echo.Context, err error) error {
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
return expected
|
|
},
|
|
})
|
|
|
|
handler := mw(func(c echo.Context) error {
|
|
return c.NoContent(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = ":1234"
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := handler(c)
|
|
if !errors.Is(err, expected) {
|
|
t.Fatalf("error = %v", err)
|
|
}
|
|
}
|
|
|
|
func performRequest(e *echo.Echo, h echo.HandlerFunc, ip string) *httptest.ResponseRecorder {
|
|
return performRequestWithPath(e, h, ip, "/")
|
|
}
|
|
|
|
func performRequestWithPath(e *echo.Echo, h echo.HandlerFunc, ip, path string) *httptest.ResponseRecorder {
|
|
req := httptest.NewRequest(http.MethodGet, path, nil)
|
|
req.RemoteAddr = ip + ":1234"
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
c.SetPath(path)
|
|
|
|
err := h(c)
|
|
if err != nil {
|
|
e.HTTPErrorHandler(err, c)
|
|
}
|
|
|
|
return rec
|
|
}
|