add auth support
continuous-integration/drone/push Build is passing

This commit is contained in:
2026-04-17 13:19:08 +10:00
parent 9a561f3b07
commit ae3e2be89a
22 changed files with 2479 additions and 40 deletions
+206
View File
@@ -0,0 +1,206 @@
package middleware
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strings"
"time"
"vctp/internal/auth"
"vctp/internal/settings"
)
const (
authModeDisabled = "disabled"
authModeOptional = "optional"
authModeRequired = "required"
RoleViewer = "viewer"
RoleAdmin = "admin"
)
type authClaimsContextKey struct{}
// ClaimsFromContext returns validated JWT claims injected by RequireAuth.
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
if ctx == nil {
return auth.Claims{}, false
}
claims, ok := ctx.Value(authClaimsContextKey{}).(auth.Claims)
return claims, ok
}
// RequireAuth validates Bearer tokens according to settings.auth_mode:
// - disabled: auth bypassed
// - optional: missing token allowed, provided token must be valid
// - required: token required and must be valid
func RequireAuth(logger *slog.Logger, cfg *settings.Settings) Handler {
if logger == nil {
logger = slog.Default()
}
if cfg == nil || cfg.Values == nil {
return defaultHandler
}
values := cfg.Values.Settings
mode := strings.ToLower(strings.TrimSpace(values.AuthMode))
if mode == "" {
mode = authModeDisabled
}
if !values.AuthEnabled || mode == authModeDisabled {
return defaultHandler
}
jwtSvc, err := auth.NewJWTService(auth.JWTConfig{
SigningKeyBase64: values.AuthJWTSigningKey,
Issuer: values.AuthJWTIssuer,
Audience: values.AuthJWTAudience,
TokenLifespan: time.Duration(values.AuthTokenLifespanMinutes) * time.Minute,
ClockSkew: time.Duration(values.AuthClockSkewSeconds) * time.Second,
})
if err != nil {
logger.Error("auth middleware init failed", "error", err)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeJSONAuthError(w, http.StatusServiceUnavailable, "authentication service unavailable")
})
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, hasHeader, parseOK := extractBearerToken(r.Header.Get("Authorization"))
if !hasHeader {
if mode == authModeRequired {
writeJSONAuthError(w, http.StatusUnauthorized, "missing bearer token")
return
}
next.ServeHTTP(w, r)
return
}
if !parseOK {
writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token")
return
}
claims, err := jwtSvc.VerifyToken(token)
if err != nil {
logger.Warn("auth middleware token validation failed", "path", r.URL.Path, "error", err)
writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token")
return
}
ctx := context.WithValue(r.Context(), authClaimsContextKey{}, claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RequireRole checks JWT claims injected by RequireAuth and enforces role policy.
// Returns:
// - 401 when no validated auth claims are present
// - 403 when claims are present but missing required role(s)
func RequireRole(requiredRoles ...string) Handler {
normalizedRequired := normalizeRoles(requiredRoles)
if len(normalizedRequired) == 0 {
return defaultHandler
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := ClaimsFromContext(r.Context())
if !ok {
writeJSONAuthError(w, http.StatusUnauthorized, "missing authentication context")
return
}
if !hasAnyRequiredRole(claims.Roles, normalizedRequired) {
writeJSONAuthError(w, http.StatusForbidden, "insufficient role")
return
}
next.ServeHTTP(w, r)
})
}
}
func writeJSONAuthError(w http.ResponseWriter, statusCode int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "ERROR",
"message": message,
})
}
func extractBearerToken(headerValue string) (token string, hasHeader bool, ok bool) {
headerValue = strings.TrimSpace(headerValue)
if headerValue == "" {
return "", false, false
}
parts := strings.Fields(headerValue)
if len(parts) != 2 {
return "", true, false
}
if !strings.EqualFold(parts[0], "Bearer") {
return "", true, false
}
token = strings.TrimSpace(parts[1])
if token == "" {
return "", true, false
}
return token, true, true
}
func normalizeRoles(roles []string) []string {
if len(roles) == 0 {
return nil
}
seen := make(map[string]struct{}, len(roles))
out := make([]string, 0, len(roles))
for _, role := range roles {
role = strings.ToLower(strings.TrimSpace(role))
if role == "" {
continue
}
if _, ok := seen[role]; ok {
continue
}
seen[role] = struct{}{}
out = append(out, role)
}
if len(out) == 0 {
return nil
}
return out
}
func hasAnyRequiredRole(userRoles []string, requiredRoles []string) bool {
if len(requiredRoles) == 0 {
return true
}
userRoleSet := make(map[string]struct{}, len(userRoles))
for _, role := range normalizeRoles(userRoles) {
userRoleSet[role] = struct{}{}
}
if len(userRoleSet) == 0 {
return false
}
for _, requiredRole := range requiredRoles {
if hasRoleWithHierarchy(userRoleSet, requiredRole) {
return true
}
}
return false
}
func hasRoleWithHierarchy(userRoleSet map[string]struct{}, requiredRole string) bool {
if _, ok := userRoleSet[requiredRole]; ok {
return true
}
// Admin implies viewer access.
if requiredRole == RoleViewer {
_, ok := userRoleSet[RoleAdmin]
return ok
}
return false
}
+201
View File
@@ -0,0 +1,201 @@
package middleware
import (
"encoding/base64"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"vctp/internal/auth"
"vctp/internal/settings"
)
func TestRequireAuthRequiredRejectsMissingToken(t *testing.T) {
cfg := testAuthSettings(true, authModeRequired)
mw := RequireAuth(testLogger(), cfg)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
}
}
func TestRequireAuthRequiredAcceptsValidTokenAndInjectsClaims(t *testing.T) {
cfg := testAuthSettings(true, authModeRequired)
token := mustTokenForConfig(t, cfg, "alice")
mw := RequireAuth(testLogger(), cfg)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
req.Header.Set("Authorization", "Bearer "+token)
var gotSubject string
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := ClaimsFromContext(r.Context())
if !ok {
t.Fatal("expected claims in request context")
}
gotSubject = claims.Subject
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
}
if gotSubject != "alice" {
t.Fatalf("expected subject alice, got %q", gotSubject)
}
}
func TestRequireAuthOptionalAllowsNoToken(t *testing.T) {
cfg := testAuthSettings(true, authModeOptional)
mw := RequireAuth(testLogger(), cfg)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})).ServeHTTP(rr, req)
if rr.Code != http.StatusNoContent {
t.Fatalf("expected status %d, got %d", http.StatusNoContent, rr.Code)
}
}
func TestRequireAuthOptionalRejectsInvalidProvidedToken(t *testing.T) {
cfg := testAuthSettings(true, authModeOptional)
mw := RequireAuth(testLogger(), cfg)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
req.Header.Set("Authorization", "Bearer not-a-jwt")
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})).ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
}
}
func TestRequireAuthDisabledBypassesMiddleware(t *testing.T) {
cfg := testAuthSettings(false, authModeDisabled)
mw := RequireAuth(testLogger(), cfg)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
})).ServeHTTP(rr, req)
if rr.Code != http.StatusAccepted {
t.Fatalf("expected status %d, got %d", http.StatusAccepted, rr.Code)
}
}
func TestRequireRoleRejectsMissingAuthContext(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
}
}
func TestRequireRoleRejectsInsufficientRole(t *testing.T) {
cfg := testAuthSettings(true, authModeRequired)
token := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer})
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
req.Header.Set("Authorization", "Bearer "+token)
protected := RequireAuth(testLogger(), cfg)(
RequireRole(RoleAdmin)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})),
)
protected.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Fatalf("expected status %d, got %d", http.StatusForbidden, rr.Code)
}
}
func TestRequireRoleViewerAllowsViewerAndAdmin(t *testing.T) {
cfg := testAuthSettings(true, authModeRequired)
viewerToken := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer})
adminToken := mustTokenForConfigWithRoles(t, cfg, "bob", []string{RoleAdmin})
for name, token := range map[string]string{
"viewer": viewerToken,
"admin": adminToken,
} {
t.Run(name, func(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil)
req.Header.Set("Authorization", "Bearer "+token)
protected := RequireAuth(testLogger(), cfg)(
RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})),
)
protected.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
}
})
}
}
func mustTokenForConfig(t *testing.T, cfg *settings.Settings, subject string) string {
t.Helper()
return mustTokenForConfigWithRoles(t, cfg, subject, []string{"admin"})
}
func mustTokenForConfigWithRoles(t *testing.T, cfg *settings.Settings, subject string, roles []string) string {
t.Helper()
svc, err := auth.NewJWTService(auth.JWTConfig{
SigningKeyBase64: cfg.Values.Settings.AuthJWTSigningKey,
Issuer: cfg.Values.Settings.AuthJWTIssuer,
Audience: cfg.Values.Settings.AuthJWTAudience,
TokenLifespan: time.Duration(cfg.Values.Settings.AuthTokenLifespanMinutes) * time.Minute,
ClockSkew: time.Duration(cfg.Values.Settings.AuthClockSkewSeconds) * time.Second,
})
if err != nil {
t.Fatalf("failed to create jwt service: %v", err)
}
token, _, err := svc.IssueToken(subject, roles, nil)
if err != nil {
t.Fatalf("failed to issue token: %v", err)
}
return token
}
func testAuthSettings(enabled bool, mode string) *settings.Settings {
cfg := &settings.Settings{Values: &settings.SettingsYML{}}
cfg.Values.Settings.AuthEnabled = enabled
cfg.Values.Settings.AuthMode = mode
cfg.Values.Settings.AuthJWTSigningKey = base64.StdEncoding.EncodeToString([]byte("middleware-test-signing-key"))
cfg.Values.Settings.AuthTokenLifespanMinutes = 120
cfg.Values.Settings.AuthJWTIssuer = "vctp"
cfg.Values.Settings.AuthJWTAudience = "vctp-api"
cfg.Values.Settings.AuthClockSkewSeconds = 60
return cfg
}
func testLogger() *slog.Logger {
return slog.New(slog.NewTextHandler(io.Discard, nil))
}