Browse Source

Merge commit '6ac87dc7dcdf4d3be4a40f7ab19b46ffc1a8cf2f' into HEAD

feature/general-cleanup
Jenkins 3 years ago
parent
commit
c313a758ed
  1. 28
      auth/authorization.go
  2. 40
      auth/authorization_test.go
  3. 4
      main.go

28
auth/authorization.go

@ -0,0 +1,28 @@
package auth
import (
"net/http"
"strings"
)
type AuthenticationMiddleware struct {
secret string
}
func (middleware AuthenticationMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authToken := r.Header.Get("Authorization")
if authToken == "" || strings.HasPrefix(authToken, "Bearer ") == false || authToken[7:] != middleware.secret {
http.Error(w, "Forbidden", http.StatusForbidden)
} else {
next.ServeHTTP(w, r)
}
})
}
func CreateAuthenticationMiddleware(secret string) AuthenticationMiddleware {
return AuthenticationMiddleware{
secret: secret,
}
}

40
auth/authorization_test.go

@ -0,0 +1,40 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/bxcodec/faker/v3"
"github.com/stretchr/testify/assert"
)
func TestAuthorizationMiddleware(t *testing.T) {
token := faker.Word()
middleware := CreateAuthenticationMiddleware(token)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middlewareHandler := middleware.Middleware(handler)
t.Run("AuthorizationMiddleware returns 403 response when authorization header is incorrect", func(t *testing.T) {
request, _ := http.NewRequest("GET", "/", nil)
responseRecorder := httptest.NewRecorder()
middlewareHandler.ServeHTTP(responseRecorder, request)
assert.Equal(t, 403, responseRecorder.Code)
})
t.Run("AuthorizationMiddleware continues when authorization header is correct", func(t *testing.T) {
request, _ := http.NewRequest("GET", "/", nil)
request.Header.Set("Authorization", "Bearer "+token)
responseRecorder := httptest.NewRecorder()
middlewareHandler.ServeHTTP(responseRecorder, request)
assert.Equal(t, 200, responseRecorder.Code)
})
}

4
main.go

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/geplauder/lithium/auth"
"github.com/geplauder/lithium/pipelines" "github.com/geplauder/lithium/pipelines"
"github.com/geplauder/lithium/settings" "github.com/geplauder/lithium/settings"
"github.com/geplauder/lithium/storage" "github.com/geplauder/lithium/storage"
@ -56,7 +57,10 @@ func main() {
pipes := pipelines.LoadPipelines() pipes := pipelines.LoadPipelines()
authMiddleware := auth.CreateAuthenticationMiddleware(settings.Token)
r := mux.NewRouter() r := mux.NewRouter()
r.Use(authMiddleware.Middleware)
r.HandleFunc("/", IndexHandler) r.HandleFunc("/", IndexHandler)
RegisterPipelineRoutes(r, pipes, storageProvider) RegisterPipelineRoutes(r, pipes, storageProvider)

Loading…
Cancel
Save