From c8f065adf00ba11b4f724d350b0887f510120195 Mon Sep 17 00:00:00 2001 From: Fabian Vowie Date: Sun, 23 Jan 2022 17:10:16 +0100 Subject: [PATCH] Add allowed burst for requests to rate limiting middleware --- main.go | 2 +- middlewares/ratelimiter.go | 5 ++-- middlewares/ratelimiter_test.go | 46 ++++++++++++++++++++++----------- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/main.go b/main.go index 9d3c095..281005a 100644 --- a/main.go +++ b/main.go @@ -61,7 +61,7 @@ func main() { pipes := pipelines.LoadPipelines() authMiddleware := middlewares.CreateAuthenticationMiddleware(appSettings.Token) - rateLimiterMiddleware, err := middlewares.CreateRateLimiterMiddleware(appSettings.RequestsPerMinute) + rateLimiterMiddleware, err := middlewares.CreateRateLimiterMiddleware(appSettings.RequestsPerMinute, 0) if err != nil { panic(err) } diff --git a/middlewares/ratelimiter.go b/middlewares/ratelimiter.go index 49fe1d8..febe7fb 100644 --- a/middlewares/ratelimiter.go +++ b/middlewares/ratelimiter.go @@ -15,7 +15,7 @@ func (middleware RateLimiterMiddleware) Middleware(next http.Handler) http.Handl return middleware.rateLimiter.RateLimit(next) } -func CreateRateLimiterMiddleware(requestsPerMinute int) (*RateLimiterMiddleware, error) { +func CreateRateLimiterMiddleware(requestsPerMinute int, allowedBurst int) (*RateLimiterMiddleware, error) { store, err := memstore.New(65536) if err != nil { @@ -23,7 +23,8 @@ func CreateRateLimiterMiddleware(requestsPerMinute int) (*RateLimiterMiddleware, } quota := throttled.RateQuota{ - MaxRate: throttled.PerMin(requestsPerMinute), + MaxRate: throttled.PerMin(requestsPerMinute), + MaxBurst: allowedBurst, } rateLimiter, err := throttled.NewGCRARateLimiter(store, quota) diff --git a/middlewares/ratelimiter_test.go b/middlewares/ratelimiter_test.go index ce946c0..808f236 100644 --- a/middlewares/ratelimiter_test.go +++ b/middlewares/ratelimiter_test.go @@ -8,41 +8,57 @@ import ( "github.com/stretchr/testify/assert" ) +func ExecuteRequest(middlewareHandler http.Handler) int { + request, _ := http.NewRequest("GET", "/", nil) + responseRecorder := httptest.NewRecorder() + + middlewareHandler.ServeHTTP(responseRecorder, request) + + return responseRecorder.Code +} + func TestRateLimiterMiddleware(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) t.Run("AuthorizationMiddleware returns 200 response when rate limit is not hit", func(t *testing.T) { - middleware, err := CreateRateLimiterMiddleware(1) + middleware, err := CreateRateLimiterMiddleware(1, 0) assert.Nil(t, err) middlewareHandler := middleware.Middleware(handler) - request, _ := http.NewRequest("GET", "/", nil) - responseRecorder := httptest.NewRecorder() + assert.Equal(t, 200, ExecuteRequest(middlewareHandler)) + }) - middlewareHandler.ServeHTTP(responseRecorder, request) + t.Run("AuthorizationMiddleware returns 429 response when rate limit is hit", func(t *testing.T) { + middleware, err := CreateRateLimiterMiddleware(1, 0) + assert.Nil(t, err) - assert.Equal(t, 200, responseRecorder.Code) + middlewareHandler := middleware.Middleware(handler) + + assert.Equal(t, 200, ExecuteRequest(middlewareHandler)) + assert.Equal(t, 429, ExecuteRequest(middlewareHandler)) }) - t.Run("AuthorizationMiddleware returns 429 response when rate limit is hit", func(t *testing.T) { - middleware, err := CreateRateLimiterMiddleware(1) + t.Run("AuthorizationMiddleware returns 200 response when rate limit with burst is not hit", func(t *testing.T) { + middleware, err := CreateRateLimiterMiddleware(1, 1) assert.Nil(t, err) middlewareHandler := middleware.Middleware(handler) - request, _ := http.NewRequest("GET", "/", nil) - responseRecorder := httptest.NewRecorder() + assert.Equal(t, 200, ExecuteRequest(middlewareHandler)) + assert.Equal(t, 200, ExecuteRequest(middlewareHandler)) + }) - middlewareHandler.ServeHTTP(responseRecorder, request) - assert.Equal(t, 200, responseRecorder.Code) + t.Run("AuthorizationMiddleware returns 429 response when rate limit with burst is hit", func(t *testing.T) { + middleware, err := CreateRateLimiterMiddleware(1, 1) + assert.Nil(t, err) - request, _ = http.NewRequest("GET", "/", nil) - responseRecorder = httptest.NewRecorder() + middlewareHandler := middleware.Middleware(handler) - middlewareHandler.ServeHTTP(responseRecorder, request) - assert.Equal(t, 429, responseRecorder.Code) + assert.Equal(t, 200, ExecuteRequest(middlewareHandler)) + assert.Equal(t, 200, ExecuteRequest(middlewareHandler)) + assert.Equal(t, 429, ExecuteRequest(middlewareHandler)) }) }