package middleware import ( "context" "encoding/json" "log/slog" "net/http" "strings" "time" "vctp/internal/auth" "vctp/internal/settings" ) const ( authModeDisabled = "disabled" authModeOptional = "optional" authModeRequired = "required" RoleViewer = "viewer" RoleAdmin = "admin" ) type authClaimsContextKey struct{} // ClaimsFromContext returns validated JWT claims injected by RequireAuth. func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) { if ctx == nil { return auth.Claims{}, false } claims, ok := ctx.Value(authClaimsContextKey{}).(auth.Claims) return claims, ok } // RequireAuth validates Bearer tokens according to settings.auth_mode: // - disabled: auth bypassed // - optional: missing token allowed, provided token must be valid // - required: token required and must be valid func RequireAuth(logger *slog.Logger, cfg *settings.Settings) Handler { if logger == nil { logger = slog.Default() } if cfg == nil || cfg.Values == nil { return defaultHandler } values := cfg.Values.Settings mode := strings.ToLower(strings.TrimSpace(values.AuthMode)) if mode == "" { mode = authModeDisabled } if !values.AuthEnabled || mode == authModeDisabled { return defaultHandler } jwtSvc, err := auth.NewJWTService(auth.JWTConfig{ SigningKeyBase64: values.AuthJWTSigningKey, Issuer: values.AuthJWTIssuer, Audience: values.AuthJWTAudience, TokenLifespan: time.Duration(values.AuthTokenLifespanMinutes) * time.Minute, ClockSkew: time.Duration(values.AuthClockSkewSeconds) * time.Second, }) if err != nil { logger.Error("auth middleware init failed", "error", err) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writeJSONAuthError(w, http.StatusServiceUnavailable, "authentication service unavailable") }) } } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token, hasHeader, parseOK := extractBearerToken(r.Header.Get("Authorization")) if !hasHeader { if mode == authModeRequired { writeJSONAuthError(w, http.StatusUnauthorized, "missing bearer token") return } next.ServeHTTP(w, r) return } if !parseOK { writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token") return } claims, err := jwtSvc.VerifyToken(token) if err != nil { logger.Warn("auth middleware token validation failed", "path", r.URL.Path, "error", err) writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token") return } ctx := context.WithValue(r.Context(), authClaimsContextKey{}, claims) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // RequireRole checks JWT claims injected by RequireAuth and enforces role policy. // Returns: // - 401 when no validated auth claims are present // - 403 when claims are present but missing required role(s) func RequireRole(requiredRoles ...string) Handler { normalizedRequired := normalizeRoles(requiredRoles) if len(normalizedRequired) == 0 { return defaultHandler } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims, ok := ClaimsFromContext(r.Context()) if !ok { writeJSONAuthError(w, http.StatusUnauthorized, "missing authentication context") return } if !hasAnyRequiredRole(claims.Roles, normalizedRequired) { writeJSONAuthError(w, http.StatusForbidden, "insufficient role") return } next.ServeHTTP(w, r) }) } } func writeJSONAuthError(w http.ResponseWriter, statusCode int, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) _ = json.NewEncoder(w).Encode(map[string]string{ "status": "ERROR", "message": message, }) } func extractBearerToken(headerValue string) (token string, hasHeader bool, ok bool) { headerValue = strings.TrimSpace(headerValue) if headerValue == "" { return "", false, false } parts := strings.Fields(headerValue) if len(parts) != 2 { return "", true, false } if !strings.EqualFold(parts[0], "Bearer") { return "", true, false } token = strings.TrimSpace(parts[1]) if token == "" { return "", true, false } return token, true, true } func normalizeRoles(roles []string) []string { if len(roles) == 0 { return nil } seen := make(map[string]struct{}, len(roles)) out := make([]string, 0, len(roles)) for _, role := range roles { role = strings.ToLower(strings.TrimSpace(role)) if role == "" { continue } if _, ok := seen[role]; ok { continue } seen[role] = struct{}{} out = append(out, role) } if len(out) == 0 { return nil } return out } func hasAnyRequiredRole(userRoles []string, requiredRoles []string) bool { if len(requiredRoles) == 0 { return true } userRoleSet := make(map[string]struct{}, len(userRoles)) for _, role := range normalizeRoles(userRoles) { userRoleSet[role] = struct{}{} } if len(userRoleSet) == 0 { return false } for _, requiredRole := range requiredRoles { if hasRoleWithHierarchy(userRoleSet, requiredRole) { return true } } return false } func hasRoleWithHierarchy(userRoleSet map[string]struct{}, requiredRole string) bool { if _, ok := userRoleSet[requiredRole]; ok { return true } // Admin implies viewer access. if requiredRole == RoleViewer { _, ok := userRoleSet[RoleAdmin] return ok } return false }