diff --git a/main.go b/main.go index 38891d2..b50422f 100644 --- a/main.go +++ b/main.go @@ -131,7 +131,7 @@ func UploadHandler(w http.ResponseWriter, r *http.Request, pipes []pipelines.IPi logrus.Info("Pipeline routes registered successfully") } -func RegisterRoutes(r *mux.Router, pipelines []pipelines.IPipeline, storageProvider storage.IStorageProvider) { +func RegisterRoutes(r *mux.Router, appSettings settings.Settings, pipelines []pipelines.IPipeline, storageProvider storage.IStorageProvider) { index := r.Methods(http.MethodGet).Subrouter() index.HandleFunc("/", IndexHandler) @@ -151,6 +151,13 @@ func RegisterRoutes(r *mux.Router, pipelines []pipelines.IPipeline, storageProvi w.WriteHeader(404) }) + + if appSettings.Authentication.Enabled { + authMiddleware := middlewares.CreateAuthenticationMiddleware(appSettings.Authentication.Token) + + upload.Use(authMiddleware.Middleware) + pipeline.Use(authMiddleware.Middleware) + } } func main() { @@ -183,12 +190,6 @@ func main() { r := mux.NewRouter() - if appSettings.Authentication.Enabled { - authMiddleware := middlewares.CreateAuthenticationMiddleware(appSettings.Authentication.Token) - - r.Use(authMiddleware.Middleware) - } - if appSettings.RateLimiter.Enabled { rateLimiterMiddleware, err := middlewares.CreateRateLimiterMiddleware(appSettings.RateLimiter.RequestsPerMinute, appSettings.RateLimiter.AllowedBurst) if err != nil { @@ -198,9 +199,7 @@ func main() { r.Use(rateLimiterMiddleware.Middleware) } - r.HandleFunc("/", IndexHandler) - - RegisterRoutes(r, pipes, storageProvider) + RegisterRoutes(r, appSettings, pipes, storageProvider) logrus.Info("Lithium started, listening for requests...") diff --git a/main_test.go b/main_test.go index bcc73f4..e1c8a74 100644 --- a/main_test.go +++ b/main_test.go @@ -10,8 +10,10 @@ import ( "github.com/bxcodec/faker/v3" "github.com/geplauder/lithium/pipelines" + "github.com/geplauder/lithium/settings" "github.com/geplauder/lithium/storage" "github.com/gorilla/mux" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" ) @@ -35,8 +37,9 @@ func TestEndpointRoute(t *testing.T) { t.Run("Registered pipelines are valid routes", func(t *testing.T) { router := mux.NewRouter() fs := storage.GetMemoryStorageProvider() + appSettings, _ := settings.LoadSettings(afero.NewMemMapFs()) - RegisterRoutes(router, []pipelines.IPipeline{data}, fs) + RegisterRoutes(router, appSettings, []pipelines.IPipeline{data}, fs) request, _ := http.NewRequest("GET", "/pipelines/"+data.Slug, nil) responseRecorder := httptest.NewRecorder() @@ -64,8 +67,9 @@ func TestUploadRoute(t *testing.T) { t.Run("Test uploads missing multipart boundary", func(t *testing.T) { router := mux.NewRouter() fs := storage.GetMemoryStorageProvider() + appSettings, _ := settings.LoadSettings(afero.NewMemMapFs()) - RegisterRoutes(router, []pipelines.IPipeline{pipelines.Pipeline{ + RegisterRoutes(router, appSettings, []pipelines.IPipeline{pipelines.Pipeline{ Name: "", Slug: "", Type: 0, @@ -92,8 +96,9 @@ func TestUploadRoute(t *testing.T) { t.Run("Test uploads missing multipart boundary", func(t *testing.T) { router := mux.NewRouter() fs := storage.GetMemoryStorageProvider() + appSettings, _ := settings.LoadSettings(afero.NewMemMapFs()) - RegisterRoutes(router, []pipelines.IPipeline{pipelines.Pipeline{ + RegisterRoutes(router, appSettings, []pipelines.IPipeline{pipelines.Pipeline{ Name: "", Slug: "", Type: 0, diff --git a/settings/settings.go b/settings/settings.go index 81cd6e4..5f2c447 100644 --- a/settings/settings.go +++ b/settings/settings.go @@ -63,7 +63,7 @@ func LoadSettings(fileSystem afero.Fs) (Settings, error) { defaultSettings := Settings{ Endpoint: "127.0.0.1:8000", Authentication: AuthenticationSettings{ - Enabled: true, + Enabled: false, Token: "changeme", }, RateLimiter: RateLimiterSettings{