202 lines
6.4 KiB
Go
202 lines
6.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
"vctp/internal/auth"
|
|
"vctp/internal/settings"
|
|
)
|
|
|
|
func TestRequireAuthRequiredRejectsMissingToken(t *testing.T) {
|
|
cfg := testAuthSettings(true, authModeRequired)
|
|
mw := RequireAuth(testLogger(), cfg)
|
|
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
|
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})).ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthRequiredAcceptsValidTokenAndInjectsClaims(t *testing.T) {
|
|
cfg := testAuthSettings(true, authModeRequired)
|
|
token := mustTokenForConfig(t, cfg, "alice")
|
|
mw := RequireAuth(testLogger(), cfg)
|
|
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
var gotSubject string
|
|
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
claims, ok := ClaimsFromContext(r.Context())
|
|
if !ok {
|
|
t.Fatal("expected claims in request context")
|
|
}
|
|
gotSubject = claims.Subject
|
|
w.WriteHeader(http.StatusOK)
|
|
})).ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
|
|
}
|
|
if gotSubject != "alice" {
|
|
t.Fatalf("expected subject alice, got %q", gotSubject)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthOptionalAllowsNoToken(t *testing.T) {
|
|
cfg := testAuthSettings(true, authModeOptional)
|
|
mw := RequireAuth(testLogger(), cfg)
|
|
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
|
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
})).ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusNoContent {
|
|
t.Fatalf("expected status %d, got %d", http.StatusNoContent, rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthOptionalRejectsInvalidProvidedToken(t *testing.T) {
|
|
cfg := testAuthSettings(true, authModeOptional)
|
|
mw := RequireAuth(testLogger(), cfg)
|
|
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
|
req.Header.Set("Authorization", "Bearer not-a-jwt")
|
|
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
})).ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthDisabledBypassesMiddleware(t *testing.T) {
|
|
cfg := testAuthSettings(false, authModeDisabled)
|
|
mw := RequireAuth(testLogger(), cfg)
|
|
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
|
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusAccepted)
|
|
})).ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusAccepted {
|
|
t.Fatalf("expected status %d, got %d", http.StatusAccepted, rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireRoleRejectsMissingAuthContext(t *testing.T) {
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
|
RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})).ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireRoleRejectsInsufficientRole(t *testing.T) {
|
|
cfg := testAuthSettings(true, authModeRequired)
|
|
token := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer})
|
|
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
protected := RequireAuth(testLogger(), cfg)(
|
|
RequireRole(RoleAdmin)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})),
|
|
)
|
|
protected.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireRoleViewerAllowsViewerAndAdmin(t *testing.T) {
|
|
cfg := testAuthSettings(true, authModeRequired)
|
|
viewerToken := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer})
|
|
adminToken := mustTokenForConfigWithRoles(t, cfg, "bob", []string{RoleAdmin})
|
|
|
|
for name, token := range map[string]string{
|
|
"viewer": viewerToken,
|
|
"admin": adminToken,
|
|
} {
|
|
t.Run(name, func(t *testing.T) {
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
protected := RequireAuth(testLogger(), cfg)(
|
|
RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})),
|
|
)
|
|
protected.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func mustTokenForConfig(t *testing.T, cfg *settings.Settings, subject string) string {
|
|
t.Helper()
|
|
return mustTokenForConfigWithRoles(t, cfg, subject, []string{"admin"})
|
|
}
|
|
|
|
func mustTokenForConfigWithRoles(t *testing.T, cfg *settings.Settings, subject string, roles []string) string {
|
|
t.Helper()
|
|
svc, err := auth.NewJWTService(auth.JWTConfig{
|
|
SigningKeyBase64: cfg.Values.Settings.AuthJWTSigningKey,
|
|
Issuer: cfg.Values.Settings.AuthJWTIssuer,
|
|
Audience: cfg.Values.Settings.AuthJWTAudience,
|
|
TokenLifespan: time.Duration(cfg.Values.Settings.AuthTokenLifespanMinutes) * time.Minute,
|
|
ClockSkew: time.Duration(cfg.Values.Settings.AuthClockSkewSeconds) * time.Second,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to create jwt service: %v", err)
|
|
}
|
|
token, _, err := svc.IssueToken(subject, roles, nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to issue token: %v", err)
|
|
}
|
|
return token
|
|
}
|
|
|
|
func testAuthSettings(enabled bool, mode string) *settings.Settings {
|
|
cfg := &settings.Settings{Values: &settings.SettingsYML{}}
|
|
cfg.Values.Settings.AuthEnabled = enabled
|
|
cfg.Values.Settings.AuthMode = mode
|
|
cfg.Values.Settings.AuthJWTSigningKey = base64.StdEncoding.EncodeToString([]byte("middleware-test-signing-key"))
|
|
cfg.Values.Settings.AuthTokenLifespanMinutes = 120
|
|
cfg.Values.Settings.AuthJWTIssuer = "vctp"
|
|
cfg.Values.Settings.AuthJWTAudience = "vctp-api"
|
|
cfg.Values.Settings.AuthClockSkewSeconds = 60
|
|
return cfg
|
|
}
|
|
|
|
// testLogger returns a new slog.Logger backed by a text handler that writes
// to io.Discard, so log output produced during tests is thrown away.
func testLogger() *slog.Logger {
	discardHandler := slog.NewTextHandler(io.Discard, nil)
	return slog.New(discardHandler)
}
|