package middleware import ( "encoding/base64" "io" "log/slog" "net/http" "net/http/httptest" "testing" "time" "vctp/internal/auth" "vctp/internal/settings" ) func TestRequireAuthRequiredRejectsMissingToken(t *testing.T) { cfg := testAuthSettings(true, authModeRequired) mw := RequireAuth(testLogger(), cfg) rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })).ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) } } func TestRequireAuthRequiredAcceptsValidTokenAndInjectsClaims(t *testing.T) { cfg := testAuthSettings(true, authModeRequired) token := mustTokenForConfig(t, cfg, "alice") mw := RequireAuth(testLogger(), cfg) rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) req.Header.Set("Authorization", "Bearer "+token) var gotSubject string mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims, ok := ClaimsFromContext(r.Context()) if !ok { t.Fatal("expected claims in request context") } gotSubject = claims.Subject w.WriteHeader(http.StatusOK) })).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) } if gotSubject != "alice" { t.Fatalf("expected subject alice, got %q", gotSubject) } } func TestRequireAuthOptionalAllowsNoToken(t *testing.T) { cfg := testAuthSettings(true, authModeOptional) mw := RequireAuth(testLogger(), cfg) rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })).ServeHTTP(rr, req) if rr.Code != http.StatusNoContent { t.Fatalf("expected status %d, got %d", http.StatusNoContent, rr.Code) } } func TestRequireAuthOptionalRejectsInvalidProvidedToken(t *testing.T) { cfg := testAuthSettings(true, authModeOptional) mw := RequireAuth(testLogger(), cfg) rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) req.Header.Set("Authorization", "Bearer not-a-jwt") mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) })).ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) } } func TestRequireAuthDisabledBypassesMiddleware(t *testing.T) { cfg := testAuthSettings(false, authModeDisabled) mw := RequireAuth(testLogger(), cfg) rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusAccepted) })).ServeHTTP(rr, req) if rr.Code != http.StatusAccepted { t.Fatalf("expected status %d, got %d", http.StatusAccepted, rr.Code) } } func TestRequireRoleRejectsMissingAuthContext(t *testing.T) { rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })).ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) } } func TestRequireRoleRejectsInsufficientRole(t *testing.T) { cfg := testAuthSettings(true, authModeRequired) token := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer}) rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) req.Header.Set("Authorization", "Bearer "+token) protected := RequireAuth(testLogger(), cfg)( RequireRole(RoleAdmin)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })), ) protected.ServeHTTP(rr, req) if rr.Code != http.StatusForbidden { t.Fatalf("expected status %d, got %d", http.StatusForbidden, rr.Code) } } func TestRequireRoleViewerAllowsViewerAndAdmin(t *testing.T) { cfg := testAuthSettings(true, authModeRequired) viewerToken := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer}) adminToken := mustTokenForConfigWithRoles(t, cfg, "bob", []string{RoleAdmin}) for name, token := range map[string]string{ "viewer": viewerToken, "admin": adminToken, } { t.Run(name, func(t *testing.T) { rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) req.Header.Set("Authorization", "Bearer "+token) protected := RequireAuth(testLogger(), cfg)( RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })), ) protected.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) } }) } } func mustTokenForConfig(t *testing.T, cfg *settings.Settings, subject string) string { t.Helper() return mustTokenForConfigWithRoles(t, cfg, subject, []string{"admin"}) } func mustTokenForConfigWithRoles(t *testing.T, cfg *settings.Settings, subject string, roles []string) string { t.Helper() svc, err := auth.NewJWTService(auth.JWTConfig{ SigningKeyBase64: cfg.Values.Settings.AuthJWTSigningKey, Issuer: cfg.Values.Settings.AuthJWTIssuer, Audience: cfg.Values.Settings.AuthJWTAudience, TokenLifespan: time.Duration(cfg.Values.Settings.AuthTokenLifespanMinutes) * time.Minute, ClockSkew: time.Duration(cfg.Values.Settings.AuthClockSkewSeconds) * time.Second, }) if err != nil { t.Fatalf("failed to create jwt service: %v", err) } token, _, err := svc.IssueToken(subject, roles, nil) if err != nil { t.Fatalf("failed to issue token: %v", err) } return token } func testAuthSettings(enabled bool, mode string) *settings.Settings { cfg := &settings.Settings{Values: &settings.SettingsYML{}} cfg.Values.Settings.AuthEnabled = enabled cfg.Values.Settings.AuthMode = mode cfg.Values.Settings.AuthJWTSigningKey = base64.StdEncoding.EncodeToString([]byte("middleware-test-signing-key")) cfg.Values.Settings.AuthTokenLifespanMinutes = 120 cfg.Values.Settings.AuthJWTIssuer = "vctp" cfg.Values.Settings.AuthJWTAudience = "vctp-api" cfg.Values.Settings.AuthClockSkewSeconds = 60 return cfg } func testLogger() *slog.Logger { return slog.New(slog.NewTextHandler(io.Discard, nil)) }