@@ -0,0 +1,146 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"vctp/internal/auth"
|
||||
"vctp/server/models"
|
||||
)
|
||||
|
||||
const (
|
||||
authLoginFailureMessage = "invalid username or password"
|
||||
authLoginRequestTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type ldapAuthenticator interface {
|
||||
AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (auth.LDAPIdentity, error)
|
||||
}
|
||||
|
||||
type jwtService interface {
|
||||
IssueToken(subject string, roles []string, groups []string) (string, auth.Claims, error)
|
||||
}
|
||||
|
||||
var newLDAPAuthenticator = func(cfg auth.LDAPConfig) (ldapAuthenticator, error) {
|
||||
return auth.NewLDAPAuthenticator(cfg)
|
||||
}
|
||||
|
||||
var newJWTService = func(cfg auth.JWTConfig) (jwtService, error) {
|
||||
return auth.NewJWTService(cfg)
|
||||
}
|
||||
|
||||
// AuthLogin authenticates a user against LDAP and returns a signed JWT.
|
||||
// @Summary Login
|
||||
// @Description Authenticates a username/password against LDAP and returns a signed access token.
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param payload body models.AuthLoginRequest true "Login credentials"
|
||||
// @Success 200 {object} models.AuthLoginResponse "Login success"
|
||||
// @Failure 400 {object} models.ErrorResponse "Invalid request"
|
||||
// @Failure 401 {object} models.ErrorResponse "Invalid credentials"
|
||||
// @Failure 500 {object} models.ErrorResponse "Server error"
|
||||
// @Failure 503 {object} models.ErrorResponse "Authentication disabled"
|
||||
// @Router /api/auth/login [post]
|
||||
func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
if h == nil || h.Settings == nil || h.Settings.Values == nil {
|
||||
writeJSONError(w, http.StatusInternalServerError, "authentication is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
cfg := h.Settings.Values.Settings
|
||||
if !cfg.AuthEnabled {
|
||||
writeJSONError(w, http.StatusServiceUnavailable, "authentication is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
var req models.AuthLoginRequest
|
||||
if err := decodeJSONBody(w, r, &req); err != nil {
|
||||
h.Logger.Error("unable to decode auth login request", "error", err)
|
||||
writeJSONError(w, http.StatusBadRequest, "invalid JSON body")
|
||||
return
|
||||
}
|
||||
username := strings.TrimSpace(req.Username)
|
||||
password := req.Password
|
||||
if username == "" || strings.TrimSpace(password) == "" {
|
||||
writeJSONError(w, http.StatusBadRequest, "username and password are required")
|
||||
return
|
||||
}
|
||||
|
||||
ldapAuth, err := newLDAPAuthenticator(auth.LDAPConfig{
|
||||
BindAddress: cfg.LDAPBindAddress,
|
||||
BaseDN: cfg.LDAPBaseDN,
|
||||
TrustCertFile: cfg.LDAPTrustCertFile,
|
||||
DisableValidation: cfg.LDAPDisableValidation,
|
||||
Insecure: cfg.LDAPInsecure,
|
||||
DialTimeout: authLoginRequestTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
h.Logger.Error("failed to initialize ldap authenticator", "error", err)
|
||||
writeJSONError(w, http.StatusInternalServerError, "authentication service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := withRequestTimeout(r, authLoginRequestTimeout)
|
||||
defer cancel()
|
||||
identity, err := ldapAuth.AuthenticateAndFetchGroups(ctx, username, password)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrLDAPInvalidCredentials) {
|
||||
h.Logger.Warn("auth login rejected", "username", username, "reason", "invalid_credentials")
|
||||
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
h.Logger.Warn("auth login ldap timeout", "username", username, "error", err)
|
||||
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
||||
return
|
||||
}
|
||||
h.Logger.Warn("auth login ldap failure", "username", username, "error", err)
|
||||
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
||||
return
|
||||
}
|
||||
|
||||
roles := auth.ResolveRoles(identity.Groups, cfg.AuthGroupRoleMappings)
|
||||
if !auth.HasAnyGroup(identity.Groups, cfg.LDAPGroups) || len(roles) == 0 {
|
||||
h.Logger.Warn("auth login rejected", "username", username, "reason", "group_or_role_denied")
|
||||
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
||||
return
|
||||
}
|
||||
|
||||
jwtSvc, err := newJWTService(auth.JWTConfig{
|
||||
SigningKeyBase64: cfg.AuthJWTSigningKey,
|
||||
Issuer: cfg.AuthJWTIssuer,
|
||||
Audience: cfg.AuthJWTAudience,
|
||||
TokenLifespan: time.Duration(cfg.AuthTokenLifespanMinutes) * time.Minute,
|
||||
ClockSkew: time.Duration(cfg.AuthClockSkewSeconds) * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
h.Logger.Error("failed to initialize jwt service", "error", err)
|
||||
writeJSONError(w, http.StatusInternalServerError, "authentication service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
subject := strings.TrimSpace(identity.Username)
|
||||
if subject == "" {
|
||||
subject = username
|
||||
}
|
||||
token, claims, err := jwtSvc.IssueToken(subject, roles, identity.Groups)
|
||||
if err != nil {
|
||||
h.Logger.Error("failed to issue auth token", "username", username, "error", err)
|
||||
writeJSONError(w, http.StatusInternalServerError, "failed to issue access token")
|
||||
return
|
||||
}
|
||||
|
||||
h.Logger.Info("auth login successful", "username", subject, "roles", roles)
|
||||
writeJSON(w, http.StatusOK, models.AuthLoginResponse{
|
||||
AccessToken: token,
|
||||
ExpiresAt: claims.ExpiresAt,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"vctp/internal/auth"
|
||||
"vctp/internal/settings"
|
||||
"vctp/server/models"
|
||||
)
|
||||
|
||||
type stubLDAPAuthenticator struct {
|
||||
identity auth.LDAPIdentity
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubLDAPAuthenticator) AuthenticateAndFetchGroups(_ context.Context, _ string, _ string) (auth.LDAPIdentity, error) {
|
||||
return s.identity, s.err
|
||||
}
|
||||
|
||||
type stubJWTService struct {
|
||||
token string
|
||||
claims auth.Claims
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubJWTService) IssueToken(_ string, _ []string, _ []string) (string, auth.Claims, error) {
|
||||
return s.token, s.claims, s.err
|
||||
}
|
||||
|
||||
func TestAuthLoginAuthDisabled(t *testing.T) {
|
||||
h := &Handler{
|
||||
Logger: newTestLogger(),
|
||||
Settings: &settings.Settings{Values: &settings.SettingsYML{}},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`))
|
||||
rr := httptest.NewRecorder()
|
||||
h.AuthLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusServiceUnavailable, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthLoginInvalidCredentials(t *testing.T) {
|
||||
restoreFactories := swapAuthFactoriesForTest(
|
||||
func(_ auth.LDAPConfig) (ldapAuthenticator, error) {
|
||||
return &stubLDAPAuthenticator{err: auth.ErrLDAPInvalidCredentials}, nil
|
||||
},
|
||||
func(_ auth.JWTConfig) (jwtService, error) {
|
||||
return &stubJWTService{}, nil
|
||||
},
|
||||
)
|
||||
defer restoreFactories()
|
||||
|
||||
h := &Handler{
|
||||
Logger: newTestLogger(),
|
||||
Settings: testAuthEnabledSettings(),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`))
|
||||
rr := httptest.NewRecorder()
|
||||
h.AuthLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
|
||||
}
|
||||
var payload models.ErrorResponse
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if payload.Message != authLoginFailureMessage {
|
||||
t.Fatalf("unexpected error message: %q", payload.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthLoginRejectsUnmappedRoles(t *testing.T) {
|
||||
restoreFactories := swapAuthFactoriesForTest(
|
||||
func(_ auth.LDAPConfig) (ldapAuthenticator, error) {
|
||||
return &stubLDAPAuthenticator{
|
||||
identity: auth.LDAPIdentity{
|
||||
Username: "alice",
|
||||
Groups: []string{"cn=other-group,ou=groups,dc=example,dc=com"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
func(_ auth.JWTConfig) (jwtService, error) {
|
||||
return &stubJWTService{}, nil
|
||||
},
|
||||
)
|
||||
defer restoreFactories()
|
||||
|
||||
h := &Handler{
|
||||
Logger: newTestLogger(),
|
||||
Settings: testAuthEnabledSettings(),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`))
|
||||
rr := httptest.NewRecorder()
|
||||
h.AuthLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthLoginSuccess(t *testing.T) {
|
||||
restoreFactories := swapAuthFactoriesForTest(
|
||||
func(_ auth.LDAPConfig) (ldapAuthenticator, error) {
|
||||
return &stubLDAPAuthenticator{
|
||||
identity: auth.LDAPIdentity{
|
||||
Username: "alice",
|
||||
UserDN: "cn=alice,ou=users,dc=example,dc=com",
|
||||
Groups: []string{"cn=vctp-admins,ou=groups,dc=example,dc=com"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
func(_ auth.JWTConfig) (jwtService, error) {
|
||||
return &stubJWTService{
|
||||
token: "issued-token",
|
||||
claims: auth.Claims{
|
||||
ExpiresAt: time.Unix(1_700_000_000, 0).Unix(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
defer restoreFactories()
|
||||
|
||||
h := &Handler{
|
||||
Logger: newTestLogger(),
|
||||
Settings: testAuthEnabledSettings(),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`))
|
||||
rr := httptest.NewRecorder()
|
||||
h.AuthLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, rr.Code, rr.Body.String())
|
||||
}
|
||||
var payload models.AuthLoginResponse
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if payload.AccessToken != "issued-token" {
|
||||
t.Fatalf("unexpected token: %q", payload.AccessToken)
|
||||
}
|
||||
if payload.TokenType != "Bearer" {
|
||||
t.Fatalf("unexpected token type: %q", payload.TokenType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthLoginJWTFactoryFailure(t *testing.T) {
|
||||
restoreFactories := swapAuthFactoriesForTest(
|
||||
func(_ auth.LDAPConfig) (ldapAuthenticator, error) {
|
||||
return &stubLDAPAuthenticator{
|
||||
identity: auth.LDAPIdentity{
|
||||
Username: "alice",
|
||||
Groups: []string{"cn=vctp-admins,ou=groups,dc=example,dc=com"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
func(_ auth.JWTConfig) (jwtService, error) {
|
||||
return nil, errors.New("jwt init failed")
|
||||
},
|
||||
)
|
||||
defer restoreFactories()
|
||||
|
||||
h := &Handler{
|
||||
Logger: newTestLogger(),
|
||||
Settings: testAuthEnabledSettings(),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`))
|
||||
rr := httptest.NewRecorder()
|
||||
h.AuthLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func testAuthEnabledSettings() *settings.Settings {
|
||||
cfg := &settings.Settings{Values: &settings.SettingsYML{}}
|
||||
cfg.Values.Settings.AuthEnabled = true
|
||||
cfg.Values.Settings.AuthMode = "required"
|
||||
cfg.Values.Settings.AuthJWTSigningKey = base64.StdEncoding.EncodeToString([]byte("test-signing-key"))
|
||||
cfg.Values.Settings.AuthTokenLifespanMinutes = 120
|
||||
cfg.Values.Settings.AuthJWTIssuer = "vctp"
|
||||
cfg.Values.Settings.AuthJWTAudience = "vctp-api"
|
||||
cfg.Values.Settings.AuthClockSkewSeconds = 60
|
||||
cfg.Values.Settings.LDAPBindAddress = "ldaps://ldap.example.com:636"
|
||||
cfg.Values.Settings.LDAPBaseDN = "dc=example,dc=com"
|
||||
cfg.Values.Settings.AuthGroupRoleMappings = map[string]string{
|
||||
"cn=vctp-admins,ou=groups,dc=example,dc=com": "admin",
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func swapAuthFactoriesForTest(
|
||||
ldapFactory func(auth.LDAPConfig) (ldapAuthenticator, error),
|
||||
jwtFactory func(auth.JWTConfig) (jwtService, error),
|
||||
) func() {
|
||||
origLDAPFactory := newLDAPAuthenticator
|
||||
origJWTFactory := newJWTService
|
||||
newLDAPAuthenticator = ldapFactory
|
||||
newJWTService = jwtFactory
|
||||
return func() {
|
||||
newLDAPAuthenticator = origLDAPFactory
|
||||
newJWTService = origJWTFactory
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,13 @@ func TestMutatingHandlersRejectWrongMethod(t *testing.T) {
|
||||
path string
|
||||
call func(*Handler, *httptest.ResponseRecorder, *http.Request)
|
||||
}{
|
||||
{
|
||||
name: "auth login",
|
||||
path: "/api/auth/login",
|
||||
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
|
||||
h.AuthLogin(rr, req)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "snapshot force hourly",
|
||||
path: "/api/snapshots/hourly/force",
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"vctp/internal/auth"
|
||||
"vctp/internal/settings"
|
||||
)
|
||||
|
||||
const (
|
||||
authModeDisabled = "disabled"
|
||||
authModeOptional = "optional"
|
||||
authModeRequired = "required"
|
||||
|
||||
RoleViewer = "viewer"
|
||||
RoleAdmin = "admin"
|
||||
)
|
||||
|
||||
type authClaimsContextKey struct{}
|
||||
|
||||
// ClaimsFromContext returns validated JWT claims injected by RequireAuth.
|
||||
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
|
||||
if ctx == nil {
|
||||
return auth.Claims{}, false
|
||||
}
|
||||
claims, ok := ctx.Value(authClaimsContextKey{}).(auth.Claims)
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// RequireAuth validates Bearer tokens according to settings.auth_mode:
|
||||
// - disabled: auth bypassed
|
||||
// - optional: missing token allowed, provided token must be valid
|
||||
// - required: token required and must be valid
|
||||
func RequireAuth(logger *slog.Logger, cfg *settings.Settings) Handler {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
if cfg == nil || cfg.Values == nil {
|
||||
return defaultHandler
|
||||
}
|
||||
|
||||
values := cfg.Values.Settings
|
||||
mode := strings.ToLower(strings.TrimSpace(values.AuthMode))
|
||||
if mode == "" {
|
||||
mode = authModeDisabled
|
||||
}
|
||||
if !values.AuthEnabled || mode == authModeDisabled {
|
||||
return defaultHandler
|
||||
}
|
||||
|
||||
jwtSvc, err := auth.NewJWTService(auth.JWTConfig{
|
||||
SigningKeyBase64: values.AuthJWTSigningKey,
|
||||
Issuer: values.AuthJWTIssuer,
|
||||
Audience: values.AuthJWTAudience,
|
||||
TokenLifespan: time.Duration(values.AuthTokenLifespanMinutes) * time.Minute,
|
||||
ClockSkew: time.Duration(values.AuthClockSkewSeconds) * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("auth middleware init failed", "error", err)
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONAuthError(w, http.StatusServiceUnavailable, "authentication service unavailable")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, hasHeader, parseOK := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if !hasHeader {
|
||||
if mode == authModeRequired {
|
||||
writeJSONAuthError(w, http.StatusUnauthorized, "missing bearer token")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if !parseOK {
|
||||
writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := jwtSvc.VerifyToken(token)
|
||||
if err != nil {
|
||||
logger.Warn("auth middleware token validation failed", "path", r.URL.Path, "error", err)
|
||||
writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), authClaimsContextKey{}, claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole checks JWT claims injected by RequireAuth and enforces role policy.
|
||||
// Returns:
|
||||
// - 401 when no validated auth claims are present
|
||||
// - 403 when claims are present but missing required role(s)
|
||||
func RequireRole(requiredRoles ...string) Handler {
|
||||
normalizedRequired := normalizeRoles(requiredRoles)
|
||||
if len(normalizedRequired) == 0 {
|
||||
return defaultHandler
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := ClaimsFromContext(r.Context())
|
||||
if !ok {
|
||||
writeJSONAuthError(w, http.StatusUnauthorized, "missing authentication context")
|
||||
return
|
||||
}
|
||||
if !hasAnyRequiredRole(claims.Roles, normalizedRequired) {
|
||||
writeJSONAuthError(w, http.StatusForbidden, "insufficient role")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func writeJSONAuthError(w http.ResponseWriter, statusCode int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "ERROR",
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
func extractBearerToken(headerValue string) (token string, hasHeader bool, ok bool) {
|
||||
headerValue = strings.TrimSpace(headerValue)
|
||||
if headerValue == "" {
|
||||
return "", false, false
|
||||
}
|
||||
parts := strings.Fields(headerValue)
|
||||
if len(parts) != 2 {
|
||||
return "", true, false
|
||||
}
|
||||
if !strings.EqualFold(parts[0], "Bearer") {
|
||||
return "", true, false
|
||||
}
|
||||
token = strings.TrimSpace(parts[1])
|
||||
if token == "" {
|
||||
return "", true, false
|
||||
}
|
||||
return token, true, true
|
||||
}
|
||||
|
||||
func normalizeRoles(roles []string) []string {
|
||||
if len(roles) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(roles))
|
||||
out := make([]string, 0, len(roles))
|
||||
for _, role := range roles {
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[role]; ok {
|
||||
continue
|
||||
}
|
||||
seen[role] = struct{}{}
|
||||
out = append(out, role)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func hasAnyRequiredRole(userRoles []string, requiredRoles []string) bool {
|
||||
if len(requiredRoles) == 0 {
|
||||
return true
|
||||
}
|
||||
userRoleSet := make(map[string]struct{}, len(userRoles))
|
||||
for _, role := range normalizeRoles(userRoles) {
|
||||
userRoleSet[role] = struct{}{}
|
||||
}
|
||||
if len(userRoleSet) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, requiredRole := range requiredRoles {
|
||||
if hasRoleWithHierarchy(userRoleSet, requiredRole) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasRoleWithHierarchy(userRoleSet map[string]struct{}, requiredRole string) bool {
|
||||
if _, ok := userRoleSet[requiredRole]; ok {
|
||||
return true
|
||||
}
|
||||
// Admin implies viewer access.
|
||||
if requiredRole == RoleViewer {
|
||||
_, ok := userRoleSet[RoleAdmin]
|
||||
return ok
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"vctp/internal/auth"
|
||||
"vctp/internal/settings"
|
||||
)
|
||||
|
||||
func TestRequireAuthRequiredRejectsMissingToken(t *testing.T) {
|
||||
cfg := testAuthSettings(true, authModeRequired)
|
||||
mw := RequireAuth(testLogger(), cfg)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
||||
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthRequiredAcceptsValidTokenAndInjectsClaims(t *testing.T) {
|
||||
cfg := testAuthSettings(true, authModeRequired)
|
||||
token := mustTokenForConfig(t, cfg, "alice")
|
||||
mw := RequireAuth(testLogger(), cfg)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
var gotSubject string
|
||||
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := ClaimsFromContext(r.Context())
|
||||
if !ok {
|
||||
t.Fatal("expected claims in request context")
|
||||
}
|
||||
gotSubject = claims.Subject
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
if gotSubject != "alice" {
|
||||
t.Fatalf("expected subject alice, got %q", gotSubject)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthOptionalAllowsNoToken(t *testing.T) {
|
||||
cfg := testAuthSettings(true, authModeOptional)
|
||||
mw := RequireAuth(testLogger(), cfg)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
||||
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})).ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusNoContent, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthOptionalRejectsInvalidProvidedToken(t *testing.T) {
|
||||
cfg := testAuthSettings(true, authModeOptional)
|
||||
mw := RequireAuth(testLogger(), cfg)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
||||
req.Header.Set("Authorization", "Bearer not-a-jwt")
|
||||
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})).ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthDisabledBypassesMiddleware(t *testing.T) {
|
||||
cfg := testAuthSettings(false, authModeDisabled)
|
||||
mw := RequireAuth(testLogger(), cfg)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
||||
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})).ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusAccepted {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusAccepted, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleRejectsMissingAuthContext(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
||||
RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleRejectsInsufficientRole(t *testing.T) {
|
||||
cfg := testAuthSettings(true, authModeRequired)
|
||||
token := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
protected := RequireAuth(testLogger(), cfg)(
|
||||
RequireRole(RoleAdmin)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})),
|
||||
)
|
||||
protected.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusForbidden, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleViewerAllowsViewerAndAdmin(t *testing.T) {
|
||||
cfg := testAuthSettings(true, authModeRequired)
|
||||
viewerToken := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer})
|
||||
adminToken := mustTokenForConfigWithRoles(t, cfg, "bob", []string{RoleAdmin})
|
||||
|
||||
for name, token := range map[string]string{
|
||||
"viewer": viewerToken,
|
||||
"admin": adminToken,
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
protected := RequireAuth(testLogger(), cfg)(
|
||||
RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})),
|
||||
)
|
||||
protected.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustTokenForConfig(t *testing.T, cfg *settings.Settings, subject string) string {
|
||||
t.Helper()
|
||||
return mustTokenForConfigWithRoles(t, cfg, subject, []string{"admin"})
|
||||
}
|
||||
|
||||
func mustTokenForConfigWithRoles(t *testing.T, cfg *settings.Settings, subject string, roles []string) string {
|
||||
t.Helper()
|
||||
svc, err := auth.NewJWTService(auth.JWTConfig{
|
||||
SigningKeyBase64: cfg.Values.Settings.AuthJWTSigningKey,
|
||||
Issuer: cfg.Values.Settings.AuthJWTIssuer,
|
||||
Audience: cfg.Values.Settings.AuthJWTAudience,
|
||||
TokenLifespan: time.Duration(cfg.Values.Settings.AuthTokenLifespanMinutes) * time.Minute,
|
||||
ClockSkew: time.Duration(cfg.Values.Settings.AuthClockSkewSeconds) * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create jwt service: %v", err)
|
||||
}
|
||||
token, _, err := svc.IssueToken(subject, roles, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to issue token: %v", err)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func testAuthSettings(enabled bool, mode string) *settings.Settings {
|
||||
cfg := &settings.Settings{Values: &settings.SettingsYML{}}
|
||||
cfg.Values.Settings.AuthEnabled = enabled
|
||||
cfg.Values.Settings.AuthMode = mode
|
||||
cfg.Values.Settings.AuthJWTSigningKey = base64.StdEncoding.EncodeToString([]byte("middleware-test-signing-key"))
|
||||
cfg.Values.Settings.AuthTokenLifespanMinutes = 120
|
||||
cfg.Values.Settings.AuthJWTIssuer = "vctp"
|
||||
cfg.Values.Settings.AuthJWTAudience = "vctp-api"
|
||||
cfg.Values.Settings.AuthClockSkewSeconds = 60
|
||||
return cfg
|
||||
}
|
||||
|
||||
func testLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
}
|
||||
@@ -17,6 +17,19 @@ type ErrorResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// AuthLoginRequest represents login payload for LDAP/JWT authentication.
|
||||
type AuthLoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// AuthLoginResponse represents successful auth login response.
|
||||
type AuthLoginResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// SnapshotMigrationStats mirrors the snapshot registry migration stats payload.
|
||||
type SnapshotMigrationStats struct {
|
||||
HourlyRenamed int `json:"HourlyRenamed"`
|
||||
|
||||
+37
-26
@@ -29,6 +29,14 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
requireAuth := middleware.RequireAuth(logger, settings)
|
||||
withAuthRole := func(next http.HandlerFunc, roles ...string) http.Handler {
|
||||
wrapped := http.Handler(http.HandlerFunc(next))
|
||||
if len(roles) > 0 {
|
||||
wrapped = middleware.RequireRole(roles...)(wrapped)
|
||||
}
|
||||
return requireAuth(wrapped)
|
||||
}
|
||||
|
||||
reportsDir := settings.Values.Settings.ReportsDir
|
||||
if reportsDir == "" {
|
||||
@@ -44,37 +52,38 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st
|
||||
mux.Handle("/favicon-32x32.png", middleware.CacheMiddleware(http.FileServer(http.FS(dist.AssetsDir))))
|
||||
mux.Handle("/reports/", http.StripPrefix("/reports/", http.FileServer(http.Dir(filepath.Clean(reportsDir)))))
|
||||
mux.HandleFunc("/", h.Home)
|
||||
mux.HandleFunc("/api/event/vm/create", h.VmCreateEvent)
|
||||
mux.HandleFunc("/api/event/vm/modify", h.VmModifyEvent)
|
||||
mux.HandleFunc("/api/event/vm/move", h.VmMoveEvent)
|
||||
mux.HandleFunc("/api/event/vm/delete", h.VmDeleteEvent)
|
||||
mux.HandleFunc("/api/import/vm", h.VmImport)
|
||||
mux.Handle("/api/event/vm/create", withAuthRole(h.VmCreateEvent, middleware.RoleAdmin))
|
||||
mux.Handle("/api/event/vm/modify", withAuthRole(h.VmModifyEvent, middleware.RoleAdmin))
|
||||
mux.Handle("/api/event/vm/move", withAuthRole(h.VmMoveEvent, middleware.RoleAdmin))
|
||||
mux.Handle("/api/event/vm/delete", withAuthRole(h.VmDeleteEvent, middleware.RoleAdmin))
|
||||
mux.Handle("/api/import/vm", withAuthRole(h.VmImport, middleware.RoleAdmin))
|
||||
// Use this when we need to manually remove a VM from the database to clean up
|
||||
mux.HandleFunc("/api/inventory/vm/delete", h.VmCleanup)
|
||||
mux.Handle("/api/inventory/vm/delete", withAuthRole(h.VmCleanup, middleware.RoleAdmin))
|
||||
|
||||
// add missing data to VMs
|
||||
mux.HandleFunc("/api/inventory/vm/update", h.VmUpdateDetails)
|
||||
mux.Handle("/api/inventory/vm/update", withAuthRole(h.VmUpdateDetails, middleware.RoleAdmin))
|
||||
|
||||
// Legacy/maintenance endpoints are gated by settings.enable_legacy_api.
|
||||
mux.HandleFunc("/api/cleanup/updates", h.UpdateCleanup)
|
||||
mux.Handle("/api/cleanup/updates", withAuthRole(h.UpdateCleanup, middleware.RoleAdmin))
|
||||
//mux.HandleFunc("/api/cleanup/vcenter", h.VcCleanup)
|
||||
|
||||
mux.HandleFunc("/api/report/inventory", h.InventoryReportDownload)
|
||||
mux.HandleFunc("/api/report/updates", h.UpdateReportDownload)
|
||||
mux.HandleFunc("/api/report/snapshot", h.SnapshotReportDownload)
|
||||
mux.HandleFunc("/api/snapshots/aggregate", h.SnapshotAggregateForce)
|
||||
mux.HandleFunc("/api/snapshots/hourly/force", h.SnapshotForceHourly)
|
||||
mux.HandleFunc("/api/snapshots/migrate", h.SnapshotMigrate)
|
||||
mux.HandleFunc("/api/snapshots/repair", h.SnapshotRepair)
|
||||
mux.HandleFunc("/api/snapshots/repair/all", h.SnapshotRepairSuite)
|
||||
mux.HandleFunc("/api/snapshots/regenerate-hourly-reports", h.SnapshotRegenerateHourlyReports)
|
||||
mux.HandleFunc("/api/diagnostics/daily-creation", h.DailyCreationDiagnostics)
|
||||
mux.Handle("/api/report/inventory", withAuthRole(h.InventoryReportDownload, middleware.RoleViewer))
|
||||
mux.Handle("/api/report/updates", withAuthRole(h.UpdateReportDownload, middleware.RoleViewer))
|
||||
mux.Handle("/api/report/snapshot", withAuthRole(h.SnapshotReportDownload, middleware.RoleViewer))
|
||||
mux.Handle("/api/snapshots/aggregate", withAuthRole(h.SnapshotAggregateForce, middleware.RoleAdmin))
|
||||
mux.Handle("/api/snapshots/hourly/force", withAuthRole(h.SnapshotForceHourly, middleware.RoleAdmin))
|
||||
mux.Handle("/api/snapshots/migrate", withAuthRole(h.SnapshotMigrate, middleware.RoleAdmin))
|
||||
mux.Handle("/api/snapshots/repair", withAuthRole(h.SnapshotRepair, middleware.RoleAdmin))
|
||||
mux.Handle("/api/snapshots/repair/all", withAuthRole(h.SnapshotRepairSuite, middleware.RoleAdmin))
|
||||
mux.Handle("/api/snapshots/regenerate-hourly-reports", withAuthRole(h.SnapshotRegenerateHourlyReports, middleware.RoleAdmin))
|
||||
mux.Handle("/api/diagnostics/daily-creation", withAuthRole(h.DailyCreationDiagnostics, middleware.RoleViewer))
|
||||
mux.HandleFunc("/api/auth/login", h.AuthLogin)
|
||||
mux.HandleFunc("/vm/trace", h.VmTrace)
|
||||
mux.HandleFunc("/vcenters", h.VcenterList)
|
||||
mux.HandleFunc("/vcenters/totals", h.VcenterTotals)
|
||||
mux.HandleFunc("/vcenters/totals/daily", h.VcenterTotalsDaily)
|
||||
mux.HandleFunc("/vcenters/totals/hourly", h.VcenterTotalsHourlyDetailed)
|
||||
mux.HandleFunc("/api/vcenters/cache/rebuild", h.VcenterCacheRebuild)
|
||||
mux.Handle("/api/vcenters/cache/rebuild", withAuthRole(h.VcenterCacheRebuild, middleware.RoleAdmin))
|
||||
mux.HandleFunc("/metrics", h.Metrics)
|
||||
|
||||
mux.HandleFunc("/snapshots/hourly", h.SnapshotHourlyList)
|
||||
@@ -82,7 +91,7 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st
|
||||
mux.HandleFunc("/snapshots/monthly", h.SnapshotMonthlyList)
|
||||
|
||||
// endpoint for encrypting vcenter credential
|
||||
mux.HandleFunc("/api/encrypt", h.EncryptData)
|
||||
mux.Handle("/api/encrypt", withAuthRole(h.EncryptData, middleware.RoleAdmin))
|
||||
|
||||
// serve swagger related components from the embedded fs
|
||||
swaggerSub, err := fs.Sub(swaggerUI, "swagger-ui-dist")
|
||||
@@ -100,12 +109,14 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st
|
||||
w.Write(swaggerSpec)
|
||||
})))
|
||||
|
||||
// Register pprof handlers
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
// Register pprof handlers only when enabled, and gate them behind admin auth.
|
||||
if settings.Values.Settings.EnablePprof {
|
||||
mux.Handle("/debug/pprof/", withAuthRole(pprof.Index, middleware.RoleAdmin))
|
||||
mux.Handle("/debug/pprof/cmdline", withAuthRole(pprof.Cmdline, middleware.RoleAdmin))
|
||||
mux.Handle("/debug/pprof/profile", withAuthRole(pprof.Profile, middleware.RoleAdmin))
|
||||
mux.Handle("/debug/pprof/symbol", withAuthRole(pprof.Symbol, middleware.RoleAdmin))
|
||||
mux.Handle("/debug/pprof/trace", withAuthRole(pprof.Trace, middleware.RoleAdmin))
|
||||
}
|
||||
|
||||
return middleware.NewLoggingMiddleware(logger, mux)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user