213 lines
6.1 KiB
Go
213 lines
6.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
"vctp/internal/auth"
|
|
"vctp/internal/settings"
|
|
"vctp/server/audit"
|
|
)
|
|
|
|
// Accepted values of settings.auth_mode. RequireAuth lowercases and trims the
// configured value before comparing it against these.
const (
	authModeDisabled = "disabled"
	authModeOptional = "optional"
	authModeRequired = "required"
)

// Role names understood by RequireRole. An admin role also satisfies a viewer
// requirement.
const (
	// RoleViewer is the role required for read-only access.
	RoleViewer = "viewer"
	// RoleAdmin is the administrative role; it implies RoleViewer.
	RoleAdmin = "admin"
)
|
|
|
|
type authClaimsContextKey struct{}
|
|
|
|
// ClaimsFromContext returns validated JWT claims injected by RequireAuth.
|
|
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
|
|
if ctx == nil {
|
|
return auth.Claims{}, false
|
|
}
|
|
claims, ok := ctx.Value(authClaimsContextKey{}).(auth.Claims)
|
|
return claims, ok
|
|
}
|
|
|
|
// RequireAuth validates Bearer tokens according to settings.auth_mode:
|
|
// - disabled: auth bypassed
|
|
// - optional: missing token allowed, provided token must be valid
|
|
// - required: token required and must be valid
|
|
func RequireAuth(logger *slog.Logger, cfg *settings.Settings) Handler {
|
|
if logger == nil {
|
|
logger = slog.Default()
|
|
}
|
|
if cfg == nil || cfg.Values == nil {
|
|
return defaultHandler
|
|
}
|
|
|
|
values := cfg.Values.Settings
|
|
mode := strings.ToLower(strings.TrimSpace(values.AuthMode))
|
|
if mode == "" {
|
|
mode = authModeDisabled
|
|
}
|
|
if !values.AuthEnabled || mode == authModeDisabled {
|
|
return defaultHandler
|
|
}
|
|
|
|
jwtSvc, err := auth.NewJWTService(auth.JWTConfig{
|
|
SigningKeyBase64: values.AuthJWTSigningKey,
|
|
Issuer: values.AuthJWTIssuer,
|
|
Audience: values.AuthJWTAudience,
|
|
TokenLifespan: time.Duration(values.AuthTokenLifespanMinutes) * time.Minute,
|
|
ClockSkew: time.Duration(values.AuthClockSkewSeconds) * time.Second,
|
|
})
|
|
if err != nil {
|
|
logger.Error("auth middleware init failed", "error", err)
|
|
audit.LogAuthEvent(logger, nil, "auth_middleware_init", "error", "reason", "jwt_service_init_failed", "error", err)
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
writeJSONAuthError(w, http.StatusServiceUnavailable, "authentication service unavailable")
|
|
})
|
|
}
|
|
}
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
token, hasHeader, parseOK := extractBearerToken(r.Header.Get("Authorization"))
|
|
if !hasHeader {
|
|
if mode == authModeRequired {
|
|
audit.LogAuthEvent(logger, r, "token_validation", "deny", "reason", "missing_bearer_token", "auth_mode", mode)
|
|
writeJSONAuthError(w, http.StatusUnauthorized, "missing bearer token")
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
if !parseOK {
|
|
audit.LogAuthEvent(logger, r, "token_validation", "deny", "reason", "invalid_bearer_header", "auth_mode", mode)
|
|
writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token")
|
|
return
|
|
}
|
|
|
|
claims, err := jwtSvc.VerifyToken(token)
|
|
if err != nil {
|
|
audit.LogAuthEvent(logger, r, "token_validation", "deny", "reason", "invalid_token", "error", err)
|
|
writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token")
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), authClaimsContextKey{}, claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// RequireRole checks JWT claims injected by RequireAuth and enforces role policy.
|
|
// Returns:
|
|
// - 401 when no validated auth claims are present
|
|
// - 403 when claims are present but missing required role(s)
|
|
func RequireRole(requiredRoles ...string) Handler {
|
|
normalizedRequired := normalizeRoles(requiredRoles)
|
|
if len(normalizedRequired) == 0 {
|
|
return defaultHandler
|
|
}
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
claims, ok := ClaimsFromContext(r.Context())
|
|
if !ok {
|
|
audit.LogAuthEvent(nil, r, "role_authorization", "deny", "reason", "missing_auth_context", "required_roles", normalizedRequired)
|
|
writeJSONAuthError(w, http.StatusUnauthorized, "missing authentication context")
|
|
return
|
|
}
|
|
if !hasAnyRequiredRole(claims.Roles, normalizedRequired) {
|
|
audit.LogAuthEvent(nil, r, "role_authorization", "deny", "reason", "insufficient_role", "required_roles", normalizedRequired, "user_roles", normalizeRoles(claims.Roles), "subject", claims.Subject)
|
|
writeJSONAuthError(w, http.StatusForbidden, "insufficient role")
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// writeJSONAuthError writes an authentication/authorization failure with the
// given HTTP status and a JSON body {"message": message, "status": "ERROR"}.
//
// For 401 responses it also sets a "WWW-Authenticate: Bearer" challenge.
// HTTP requires a 401 to carry at least one challenge (RFC 9110 §15.5.2), and
// RFC 6750 §3 defines the Bearer scheme used by RequireAuth.
func writeJSONAuthError(w http.ResponseWriter, statusCode int, message string) {
	header := w.Header()
	header.Set("Content-Type", "application/json")
	if statusCode == http.StatusUnauthorized {
		header.Set("WWW-Authenticate", "Bearer")
	}
	w.WriteHeader(statusCode)
	// The status line has already been sent, so an encode/write failure
	// (typically a disconnected client) cannot be reported; ignore it.
	_ = json.NewEncoder(w).Encode(map[string]string{
		"status":  "ERROR",
		"message": message,
	})
}
|
|
|
|
// extractBearerToken parses an Authorization header value of the form
// "Bearer <token>" (scheme matched case-insensitively, any Unicode whitespace
// as separator).
//
// hasHeader is false when the value is empty or whitespace-only. ok is true
// only when the value is exactly two fields and the first is "Bearer", in which
// case token is the second field. Otherwise token is "" and ok is false.
func extractBearerToken(headerValue string) (token string, hasHeader bool, ok bool) {
	fields := strings.Fields(headerValue)
	if len(fields) == 0 {
		return "", false, false
	}
	// strings.Fields never yields empty elements, so fields[1] is non-empty.
	if len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") {
		return fields[1], true, true
	}
	return "", true, false
}
|
|
|
|
// normalizeRoles trims and lowercases each role, drops blank entries and
// duplicates, and keeps first-occurrence order. It returns nil when nothing
// remains (including for nil or empty input).
func normalizeRoles(roles []string) []string {
	var normalized []string
	seen := make(map[string]struct{}, len(roles))
	for _, raw := range roles {
		name := strings.ToLower(strings.TrimSpace(raw))
		if name == "" {
			continue
		}
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		normalized = append(normalized, name)
	}
	return normalized
}
|
|
|
|
func hasAnyRequiredRole(userRoles []string, requiredRoles []string) bool {
|
|
if len(requiredRoles) == 0 {
|
|
return true
|
|
}
|
|
userRoleSet := make(map[string]struct{}, len(userRoles))
|
|
for _, role := range normalizeRoles(userRoles) {
|
|
userRoleSet[role] = struct{}{}
|
|
}
|
|
if len(userRoleSet) == 0 {
|
|
return false
|
|
}
|
|
for _, requiredRole := range requiredRoles {
|
|
if hasRoleWithHierarchy(userRoleSet, requiredRole) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func hasRoleWithHierarchy(userRoleSet map[string]struct{}, requiredRole string) bool {
|
|
if _, ok := userRoleSet[requiredRole]; ok {
|
|
return true
|
|
}
|
|
// Admin implies viewer access.
|
|
if requiredRole == RoleViewer {
|
|
_, ok := userRoleSet[RoleAdmin]
|
|
return ok
|
|
}
|
|
return false
|
|
}
|