Browse Source

Add allowed burst for requests to rate limiting middleware

feature/add-rate-limiting
Fabian Vowie 3 years ago
parent
commit
069a5ffe53
No known key found for this signature in database GPG Key ID: C27317C33B27C410
  1. 2
      main.go
  2. 5
      middlewares/ratelimiter.go
  3. 46
      middlewares/ratelimiter_test.go

2
main.go

@ -61,7 +61,7 @@ func main() {
pipes := pipelines.LoadPipelines() pipes := pipelines.LoadPipelines()
authMiddleware := middlewares.CreateAuthenticationMiddleware(appSettings.Token) authMiddleware := middlewares.CreateAuthenticationMiddleware(appSettings.Token)
rateLimiterMiddleware, err := middlewares.CreateRateLimiterMiddleware(appSettings.RequestsPerMinute)
rateLimiterMiddleware, err := middlewares.CreateRateLimiterMiddleware(appSettings.RequestsPerMinute, 0)
if err != nil { if err != nil {
panic(err) panic(err)
} }

5
middlewares/ratelimiter.go

@ -15,7 +15,7 @@ func (middleware RateLimiterMiddleware) Middleware(next http.Handler) http.Handl
return middleware.rateLimiter.RateLimit(next) return middleware.rateLimiter.RateLimit(next)
} }
func CreateRateLimiterMiddleware(requestsPerMinute int) (*RateLimiterMiddleware, error) {
func CreateRateLimiterMiddleware(requestsPerMinute int, allowedBurst int) (*RateLimiterMiddleware, error) {
store, err := memstore.New(65536) store, err := memstore.New(65536)
if err != nil { if err != nil {
@ -23,7 +23,8 @@ func CreateRateLimiterMiddleware(requestsPerMinute int) (*RateLimiterMiddleware,
} }
quota := throttled.RateQuota{ quota := throttled.RateQuota{
MaxRate: throttled.PerMin(requestsPerMinute),
MaxRate: throttled.PerMin(requestsPerMinute),
MaxBurst: allowedBurst,
} }
rateLimiter, err := throttled.NewGCRARateLimiter(store, quota) rateLimiter, err := throttled.NewGCRARateLimiter(store, quota)

46
middlewares/ratelimiter_test.go

@ -8,41 +8,57 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func ExecuteRequest(middlewareHandler http.Handler) int {
request, _ := http.NewRequest("GET", "/", nil)
responseRecorder := httptest.NewRecorder()
middlewareHandler.ServeHTTP(responseRecorder, request)
return responseRecorder.Code
}
func TestRateLimiterMiddleware(t *testing.T) { func TestRateLimiterMiddleware(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
t.Run("AuthorizationMiddleware returns 200 response when rate limit is not hit", func(t *testing.T) { t.Run("AuthorizationMiddleware returns 200 response when rate limit is not hit", func(t *testing.T) {
middleware, err := CreateRateLimiterMiddleware(1)
middleware, err := CreateRateLimiterMiddleware(1, 0)
assert.Nil(t, err) assert.Nil(t, err)
middlewareHandler := middleware.Middleware(handler) middlewareHandler := middleware.Middleware(handler)
request, _ := http.NewRequest("GET", "/", nil)
responseRecorder := httptest.NewRecorder()
assert.Equal(t, 200, ExecuteRequest(middlewareHandler))
})
middlewareHandler.ServeHTTP(responseRecorder, request)
t.Run("AuthorizationMiddleware returns 429 response when rate limit is hit", func(t *testing.T) {
middleware, err := CreateRateLimiterMiddleware(1, 0)
assert.Nil(t, err)
assert.Equal(t, 200, responseRecorder.Code)
middlewareHandler := middleware.Middleware(handler)
assert.Equal(t, 200, ExecuteRequest(middlewareHandler))
assert.Equal(t, 429, ExecuteRequest(middlewareHandler))
}) })
t.Run("AuthorizationMiddleware returns 429 response when rate limit is hit", func(t *testing.T) {
middleware, err := CreateRateLimiterMiddleware(1)
t.Run("AuthorizationMiddleware returns 200 response when rate limit with burst is not hit", func(t *testing.T) {
middleware, err := CreateRateLimiterMiddleware(1, 1)
assert.Nil(t, err) assert.Nil(t, err)
middlewareHandler := middleware.Middleware(handler) middlewareHandler := middleware.Middleware(handler)
request, _ := http.NewRequest("GET", "/", nil)
responseRecorder := httptest.NewRecorder()
assert.Equal(t, 200, ExecuteRequest(middlewareHandler))
assert.Equal(t, 200, ExecuteRequest(middlewareHandler))
})
middlewareHandler.ServeHTTP(responseRecorder, request)
assert.Equal(t, 200, responseRecorder.Code)
t.Run("AuthorizationMiddleware returns 429 response when rate limit with burst is hit", func(t *testing.T) {
middleware, err := CreateRateLimiterMiddleware(1, 1)
assert.Nil(t, err)
request, _ = http.NewRequest("GET", "/", nil)
responseRecorder = httptest.NewRecorder()
middlewareHandler := middleware.Middleware(handler)
middlewareHandler.ServeHTTP(responseRecorder, request)
assert.Equal(t, 429, responseRecorder.Code)
assert.Equal(t, 200, ExecuteRequest(middlewareHandler))
assert.Equal(t, 200, ExecuteRequest(middlewareHandler))
assert.Equal(t, 429, ExecuteRequest(middlewareHandler))
}) })
} }
Loading…
Cancel
Save