@@ -0,0 +1,292 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
jwtAlgHS256 = "HS256"
|
||||
jwtTyp = "JWT"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidJWTConfig = errors.New("invalid jwt config")
|
||||
ErrInvalidJWTToken = errors.New("invalid jwt token")
|
||||
ErrInvalidJWTClaims = errors.New("invalid jwt claims")
|
||||
ErrExpiredJWTToken = errors.New("jwt token expired")
|
||||
ErrNotYetValidJWTToken = errors.New("jwt token is not yet valid")
|
||||
)
|
||||
|
||||
type JWTConfig struct {
|
||||
SigningKeyBase64 string
|
||||
Issuer string
|
||||
Audience string
|
||||
TokenLifespan time.Duration
|
||||
ClockSkew time.Duration
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Subject string `json:"sub"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
Issuer string `json:"iss"`
|
||||
Audience string `json:"aud"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
NotBefore int64 `json:"nbf"`
|
||||
ID string `json:"jti"`
|
||||
}
|
||||
|
||||
type JWTService struct {
|
||||
signingKey []byte
|
||||
issuer string
|
||||
audience string
|
||||
tokenLifespan time.Duration
|
||||
clockSkew time.Duration
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type jwtHeader struct {
|
||||
Algorithm string `json:"alg"`
|
||||
Type string `json:"typ"`
|
||||
}
|
||||
|
||||
func NewJWTService(cfg JWTConfig) (*JWTService, error) {
|
||||
issuer := strings.TrimSpace(cfg.Issuer)
|
||||
audience := strings.TrimSpace(cfg.Audience)
|
||||
if issuer == "" {
|
||||
return nil, fmt.Errorf("%w: issuer is required", ErrInvalidJWTConfig)
|
||||
}
|
||||
if audience == "" {
|
||||
return nil, fmt.Errorf("%w: audience is required", ErrInvalidJWTConfig)
|
||||
}
|
||||
if cfg.TokenLifespan <= 0 {
|
||||
return nil, fmt.Errorf("%w: token lifespan must be greater than zero", ErrInvalidJWTConfig)
|
||||
}
|
||||
if cfg.ClockSkew < 0 {
|
||||
return nil, fmt.Errorf("%w: clock skew cannot be negative", ErrInvalidJWTConfig)
|
||||
}
|
||||
|
||||
signingKey, err := decodeBase64Key(strings.TrimSpace(cfg.SigningKeyBase64))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: signing key must be valid base64", ErrInvalidJWTConfig)
|
||||
}
|
||||
if len(signingKey) == 0 {
|
||||
return nil, fmt.Errorf("%w: signing key cannot be empty", ErrInvalidJWTConfig)
|
||||
}
|
||||
|
||||
return &JWTService{
|
||||
signingKey: signingKey,
|
||||
issuer: issuer,
|
||||
audience: audience,
|
||||
tokenLifespan: cfg.TokenLifespan,
|
||||
clockSkew: cfg.ClockSkew,
|
||||
now: time.Now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *JWTService) IssueToken(subject string, roles []string, groups []string) (string, Claims, error) {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return "", Claims{}, fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
claims := Claims{
|
||||
Subject: subject,
|
||||
Roles: compactTrimmedStrings(roles),
|
||||
Groups: compactTrimmedStrings(groups),
|
||||
Issuer: s.issuer,
|
||||
Audience: s.audience,
|
||||
IssuedAt: now.Unix(),
|
||||
ExpiresAt: now.Add(s.tokenLifespan).Unix(),
|
||||
NotBefore: now.Unix(),
|
||||
ID: newTokenID(),
|
||||
}
|
||||
if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
|
||||
return "", Claims{}, err
|
||||
}
|
||||
|
||||
token, err := encodeSignedJWT(claims, s.signingKey)
|
||||
if err != nil {
|
||||
return "", Claims{}, err
|
||||
}
|
||||
return token, claims, nil
|
||||
}
|
||||
|
||||
func (s *JWTService) VerifyToken(token string) (Claims, error) {
|
||||
header, claims, signingInput, signature, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return Claims{}, err
|
||||
}
|
||||
if header.Algorithm != jwtAlgHS256 {
|
||||
return Claims{}, fmt.Errorf("%w: unsupported algorithm", ErrInvalidJWTToken)
|
||||
}
|
||||
if header.Type != "" && header.Type != jwtTyp {
|
||||
return Claims{}, fmt.Errorf("%w: invalid token type", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
expected := signPayload(signingInput, s.signingKey)
|
||||
if !hmac.Equal(signature, expected) {
|
||||
return Claims{}, fmt.Errorf("%w: signature mismatch", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
|
||||
return Claims{}, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func validateClaims(claims Claims, now time.Time, expectedIssuer string, expectedAudience string, clockSkew time.Duration) error {
|
||||
if strings.TrimSpace(claims.Subject) == "" {
|
||||
return fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if strings.TrimSpace(claims.ID) == "" {
|
||||
return fmt.Errorf("%w: jti is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.Issuer != expectedIssuer {
|
||||
return fmt.Errorf("%w: issuer mismatch", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.Audience != expectedAudience {
|
||||
return fmt.Errorf("%w: audience mismatch", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.IssuedAt <= 0 {
|
||||
return fmt.Errorf("%w: iat is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore <= 0 {
|
||||
return fmt.Errorf("%w: nbf is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.ExpiresAt <= 0 {
|
||||
return fmt.Errorf("%w: exp is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.ExpiresAt <= claims.IssuedAt {
|
||||
return fmt.Errorf("%w: exp must be greater than iat", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore > claims.ExpiresAt {
|
||||
return fmt.Errorf("%w: nbf cannot be greater than exp", ErrInvalidJWTClaims)
|
||||
}
|
||||
|
||||
unixNow := now.Unix()
|
||||
skewSeconds := int64(clockSkew / time.Second)
|
||||
if claims.IssuedAt > unixNow+skewSeconds {
|
||||
return fmt.Errorf("%w: iat is in the future", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore > unixNow+skewSeconds {
|
||||
return ErrNotYetValidJWTToken
|
||||
}
|
||||
if unixNow > claims.ExpiresAt+skewSeconds {
|
||||
return ErrExpiredJWTToken
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeSignedJWT(claims Claims, signingKey []byte) (string, error) {
|
||||
headerJSON, err := json.Marshal(jwtHeader{Algorithm: jwtAlgHS256, Type: jwtTyp})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal jwt header: %w", err)
|
||||
}
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal jwt claims: %w", err)
|
||||
}
|
||||
|
||||
headerPart := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
payloadPart := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signingInput := headerPart + "." + payloadPart
|
||||
signature := signPayload(signingInput, signingKey)
|
||||
signaturePart := base64.RawURLEncoding.EncodeToString(signature)
|
||||
|
||||
return signingInput + "." + signaturePart, nil
|
||||
}
|
||||
|
||||
func parseJWT(token string) (jwtHeader, Claims, string, []byte, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
|
||||
}
|
||||
if parts[0] == "" || parts[1] == "" || parts[2] == "" {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid payload encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid signature encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
var header jwtHeader
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header json", ErrInvalidJWTToken)
|
||||
}
|
||||
var claims Claims
|
||||
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid claims json", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
return header, claims, parts[0] + "." + parts[1], signature, nil
|
||||
}
|
||||
|
||||
func signPayload(payload string, signingKey []byte) []byte {
|
||||
mac := hmac.New(sha256.New, signingKey)
|
||||
mac.Write([]byte(payload))
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func newTokenID() string {
|
||||
raw := make([]byte, 16)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return fmt.Sprintf("fallback-%d", time.Now().UTC().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(raw)
|
||||
}
|
||||
|
||||
func decodeBase64Key(value string) ([]byte, error) {
|
||||
encodings := []*base64.Encoding{
|
||||
base64.StdEncoding,
|
||||
base64.RawStdEncoding,
|
||||
base64.URLEncoding,
|
||||
base64.RawURLEncoding,
|
||||
}
|
||||
for _, encoding := range encodings {
|
||||
decoded, err := encoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("invalid base64 encoding")
|
||||
}
|
||||
|
||||
func compactTrimmedStrings(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewJWTServiceRejectsBadConfig(t *testing.T) {
|
||||
_, err := NewJWTService(JWTConfig{
|
||||
SigningKeyBase64: "!!!",
|
||||
Issuer: "vctp",
|
||||
Audience: "vctp-api",
|
||||
TokenLifespan: time.Hour,
|
||||
ClockSkew: time.Minute,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid base64 signing key to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTConfig) {
|
||||
t.Fatalf("expected ErrInvalidJWTConfig, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssueAndVerifyTokenRoundTrip(t *testing.T) {
|
||||
now := time.Unix(1_700_000_000, 0).UTC()
|
||||
svc := mustJWTService(t)
|
||||
svc.now = func() time.Time { return now }
|
||||
|
||||
token, issuedClaims, err := svc.IssueToken("alice", []string{"admin", " viewer "}, []string{"cn=vctp-admins,dc=example,dc=com"})
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Fatal("expected non-empty token")
|
||||
}
|
||||
if issuedClaims.Subject != "alice" {
|
||||
t.Fatalf("expected subject alice, got %q", issuedClaims.Subject)
|
||||
}
|
||||
if issuedClaims.Issuer != "vctp" {
|
||||
t.Fatalf("expected issuer vctp, got %q", issuedClaims.Issuer)
|
||||
}
|
||||
if issuedClaims.Audience != "vctp-api" {
|
||||
t.Fatalf("expected audience vctp-api, got %q", issuedClaims.Audience)
|
||||
}
|
||||
if issuedClaims.IssuedAt != now.Unix() {
|
||||
t.Fatalf("unexpected iat: %d", issuedClaims.IssuedAt)
|
||||
}
|
||||
if issuedClaims.NotBefore != now.Unix() {
|
||||
t.Fatalf("unexpected nbf: %d", issuedClaims.NotBefore)
|
||||
}
|
||||
if issuedClaims.ExpiresAt != now.Add(2*time.Hour).Unix() {
|
||||
t.Fatalf("unexpected exp: %d", issuedClaims.ExpiresAt)
|
||||
}
|
||||
if issuedClaims.ID == "" {
|
||||
t.Fatal("expected jti to be populated")
|
||||
}
|
||||
|
||||
verifiedClaims, err := svc.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyToken returned error: %v", err)
|
||||
}
|
||||
if verifiedClaims.Subject != issuedClaims.Subject {
|
||||
t.Fatalf("subject mismatch: got %q want %q", verifiedClaims.Subject, issuedClaims.Subject)
|
||||
}
|
||||
if verifiedClaims.ID != issuedClaims.ID {
|
||||
t.Fatalf("jti mismatch: got %q want %q", verifiedClaims.ID, issuedClaims.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTokenRejectsInvalidSignature(t *testing.T) {
|
||||
svc := mustJWTService(t)
|
||||
svc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }
|
||||
|
||||
token, _, err := svc.IssueToken("alice", []string{"admin"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
|
||||
other := mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("a different secret key")))
|
||||
other.now = svc.now
|
||||
|
||||
_, err = other.VerifyToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("expected signature mismatch to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTToken) {
|
||||
t.Fatalf("expected ErrInvalidJWTToken, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTokenRejectsIssuerAndAudienceMismatch(t *testing.T) {
|
||||
issuerSvc := mustJWTService(t)
|
||||
issuerSvc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }
|
||||
token, _, err := issuerSvc.IssueToken("alice", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
|
||||
wrongIssuer, err := NewJWTService(JWTConfig{
|
||||
SigningKeyBase64: base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")),
|
||||
Issuer: "other-issuer",
|
||||
Audience: "vctp-api",
|
||||
TokenLifespan: 2 * time.Hour,
|
||||
ClockSkew: time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create verifier with wrong issuer: %v", err)
|
||||
}
|
||||
wrongIssuer.now = issuerSvc.now
|
||||
|
||||
_, err = wrongIssuer.VerifyToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("expected issuer mismatch to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTClaims) {
|
||||
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "issuer") {
|
||||
t.Fatalf("expected issuer mismatch error, got: %v", err)
|
||||
}
|
||||
|
||||
wrongAudience, err := NewJWTService(JWTConfig{
|
||||
SigningKeyBase64: base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")),
|
||||
Issuer: "vctp",
|
||||
Audience: "other-audience",
|
||||
TokenLifespan: 2 * time.Hour,
|
||||
ClockSkew: time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create verifier with wrong audience: %v", err)
|
||||
}
|
||||
wrongAudience.now = issuerSvc.now
|
||||
|
||||
_, err = wrongAudience.VerifyToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("expected audience mismatch to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTClaims) {
|
||||
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "audience") {
|
||||
t.Fatalf("expected audience mismatch error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTokenRejectsExpiredNotBeforeAndFutureIssuedAt(t *testing.T) {
|
||||
base := time.Unix(1_700_000_000, 0).UTC()
|
||||
svc := mustJWTService(t)
|
||||
svc.now = func() time.Time { return base }
|
||||
|
||||
token, claims, err := svc.IssueToken("alice", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
|
||||
svc.now = func() time.Time { return base.Add(3 * time.Hour) }
|
||||
_, err = svc.VerifyToken(token)
|
||||
if !errors.Is(err, ErrExpiredJWTToken) {
|
||||
t.Fatalf("expected ErrExpiredJWTToken, got: %v", err)
|
||||
}
|
||||
|
||||
notBeforeClaims := claims
|
||||
notBeforeClaims.NotBefore = base.Add(10 * time.Minute).Unix()
|
||||
notBeforeClaims.IssuedAt = base.Unix()
|
||||
notBeforeClaims.ExpiresAt = base.Add(2 * time.Hour).Unix()
|
||||
notBeforeClaims.ID = "forced-jti-1"
|
||||
notBeforeToken, err := encodeSignedJWT(notBeforeClaims, svc.signingKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create token with future nbf: %v", err)
|
||||
}
|
||||
svc.now = func() time.Time { return base }
|
||||
_, err = svc.VerifyToken(notBeforeToken)
|
||||
if !errors.Is(err, ErrNotYetValidJWTToken) {
|
||||
t.Fatalf("expected ErrNotYetValidJWTToken, got: %v", err)
|
||||
}
|
||||
|
||||
futureIatClaims := claims
|
||||
futureIatClaims.IssuedAt = base.Add(20 * time.Minute).Unix()
|
||||
futureIatClaims.NotBefore = base.Unix()
|
||||
futureIatClaims.ExpiresAt = base.Add(3 * time.Hour).Unix()
|
||||
futureIatClaims.ID = "forced-jti-2"
|
||||
futureIatToken, err := encodeSignedJWT(futureIatClaims, svc.signingKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create token with future iat: %v", err)
|
||||
}
|
||||
_, err = svc.VerifyToken(futureIatToken)
|
||||
if err == nil {
|
||||
t.Fatal("expected future iat validation to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTClaims) {
|
||||
t.Fatalf("expected ErrInvalidJWTClaims for future iat, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTokenRejectsMissingJTI(t *testing.T) {
|
||||
base := time.Unix(1_700_000_000, 0).UTC()
|
||||
svc := mustJWTService(t)
|
||||
svc.now = func() time.Time { return base }
|
||||
|
||||
token, claims, err := svc.IssueToken("alice", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Fatal("expected non-empty token")
|
||||
}
|
||||
|
||||
claims.ID = ""
|
||||
customToken, err := encodeSignedJWT(claims, svc.signingKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create token without jti: %v", err)
|
||||
}
|
||||
|
||||
_, err = svc.VerifyToken(customToken)
|
||||
if err == nil {
|
||||
t.Fatal("expected missing jti token to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTClaims) {
|
||||
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "jti") {
|
||||
t.Fatalf("expected jti validation error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func mustJWTService(t *testing.T) *JWTService {
|
||||
t.Helper()
|
||||
return mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")))
|
||||
}
|
||||
|
||||
func mustJWTServiceWithKey(t *testing.T, keyBase64 string) *JWTService {
|
||||
t.Helper()
|
||||
svc, err := NewJWTService(JWTConfig{
|
||||
SigningKeyBase64: keyBase64,
|
||||
Issuer: "vctp",
|
||||
Audience: "vctp-api",
|
||||
TokenLifespan: 2 * time.Hour,
|
||||
ClockSkew: time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create jwt service: %v", err)
|
||||
}
|
||||
return svc
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidLDAPConfig = errors.New("invalid ldap config")
|
||||
ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials")
|
||||
ErrLDAPOperationFailed = errors.New("ldap operation failed")
|
||||
)
|
||||
|
||||
type LDAPConfig struct {
|
||||
BindAddress string
|
||||
BaseDN string
|
||||
TrustCertFile string
|
||||
DisableValidation bool
|
||||
Insecure bool
|
||||
DialTimeout time.Duration
|
||||
}
|
||||
|
||||
type LDAPIdentity struct {
|
||||
Username string
|
||||
UserDN string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
type LDAPAuthenticator struct {
|
||||
bindAddress string
|
||||
baseDN string
|
||||
trustCertFile string
|
||||
disableValidation bool
|
||||
insecure bool
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
|
||||
bindAddress := strings.TrimSpace(cfg.BindAddress)
|
||||
baseDN := strings.TrimSpace(cfg.BaseDN)
|
||||
trustCertFile := strings.TrimSpace(cfg.TrustCertFile)
|
||||
|
||||
if bindAddress == "" {
|
||||
return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig)
|
||||
}
|
||||
if baseDN == "" {
|
||||
return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig)
|
||||
}
|
||||
if _, err := url.ParseRequestURI(bindAddress); err != nil {
|
||||
return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
|
||||
dialTimeout := cfg.DialTimeout
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = 10 * time.Second
|
||||
}
|
||||
|
||||
return &LDAPAuthenticator{
|
||||
bindAddress: bindAddress,
|
||||
baseDN: baseDN,
|
||||
trustCertFile: trustCertFile,
|
||||
disableValidation: cfg.DisableValidation,
|
||||
insecure: cfg.Insecure,
|
||||
dialTimeout: dialTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" || password == "" {
|
||||
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
||||
}
|
||||
if err := ctxErr(ctx); err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
|
||||
conn, err := a.connect()
|
||||
if err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.Bind(username, password); err != nil {
|
||||
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
|
||||
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
||||
}
|
||||
return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if err := ctxErr(ctx); err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
|
||||
identity := LDAPIdentity{
|
||||
Username: username,
|
||||
UserDN: username,
|
||||
}
|
||||
|
||||
entry, err := a.lookupUserEntry(conn, username)
|
||||
if err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
if entry != nil {
|
||||
if strings.TrimSpace(entry.DN) != "" {
|
||||
identity.UserDN = entry.DN
|
||||
}
|
||||
if v := firstNonEmpty(
|
||||
entry.GetAttributeValue("uid"),
|
||||
entry.GetAttributeValue("sAMAccountName"),
|
||||
entry.GetAttributeValue("userPrincipalName"),
|
||||
entry.GetAttributeValue("cn"),
|
||||
); v != "" {
|
||||
identity.Username = v
|
||||
}
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
if entry != nil {
|
||||
for _, groupDN := range entry.GetAttributeValues("memberOf") {
|
||||
groupDN = strings.TrimSpace(groupDN)
|
||||
if groupDN == "" {
|
||||
continue
|
||||
}
|
||||
groupSet[groupDN] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
groupEntries, err := conn.Search(ldap.NewSearchRequest(
|
||||
a.baseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
fmt.Sprintf("(|(member=%s)(uniqueMember=%s)(memberUid=%s))",
|
||||
ldap.EscapeFilter(identity.UserDN),
|
||||
ldap.EscapeFilter(identity.UserDN),
|
||||
ldap.EscapeFilter(username),
|
||||
),
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
))
|
||||
if err == nil {
|
||||
for _, e := range groupEntries.Entries {
|
||||
if dn := strings.TrimSpace(e.DN); dn != "" {
|
||||
groupSet[dn] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
identity.Groups = mapKeysSorted(groupSet)
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string {
|
||||
if len(groupDNs) == 0 || len(groupRoleMappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalizedMappings := make(map[string]string, len(groupRoleMappings))
|
||||
for groupDN, role := range groupRoleMappings {
|
||||
groupDN = normalizeDN(groupDN)
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
if groupDN == "" || role == "" {
|
||||
continue
|
||||
}
|
||||
normalizedMappings[groupDN] = role
|
||||
}
|
||||
|
||||
roleSet := make(map[string]struct{})
|
||||
for _, groupDN := range groupDNs {
|
||||
if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok {
|
||||
roleSet[role] = struct{}{}
|
||||
}
|
||||
}
|
||||
return mapKeysSorted(roleSet)
|
||||
}
|
||||
|
||||
func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool {
|
||||
requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs)
|
||||
if len(requiredGroupDNs) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(groupDNs) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
required := make(map[string]struct{}, len(requiredGroupDNs))
|
||||
for _, groupDN := range requiredGroupDNs {
|
||||
required[normalizeDN(groupDN)] = struct{}{}
|
||||
}
|
||||
for _, groupDN := range groupDNs {
|
||||
if _, ok := required[normalizeDN(groupDN)]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
|
||||
tlsConfig, err := a.buildTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(a.bindAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
|
||||
options := []ldap.DialOpt{
|
||||
ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}),
|
||||
ldap.DialWithTLSConfig(tlsConfig),
|
||||
}
|
||||
conn, err := ldap.DialURL(a.bindAddress, options...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
conn.SetTimeout(a.dialTimeout)
|
||||
|
||||
// For ldap://, opportunistically upgrade to TLS unless explicitly configured as insecure.
|
||||
if parsedURL.Scheme == "ldap" && !a.insecure {
|
||||
if err := conn.StartTLS(tlsConfig); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags
|
||||
}
|
||||
|
||||
if a.trustCertFile == "" {
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
caPEM, err := os.ReadFile(a.trustCertFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
roots := x509.NewCertPool()
|
||||
if !roots.AppendCertsFromPEM(caPEM) {
|
||||
return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig)
|
||||
}
|
||||
tlsConfig.RootCAs = roots
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string) (*ldap.Entry, error) {
|
||||
if looksLikeDN(username) {
|
||||
searchRes, err := conn.Search(ldap.NewSearchRequest(
|
||||
username,
|
||||
ldap.ScopeBaseObject,
|
||||
ldap.NeverDerefAliases,
|
||||
1,
|
||||
0,
|
||||
false,
|
||||
"(objectClass=*)",
|
||||
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
|
||||
nil,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: unable to load user entry: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if len(searchRes.Entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return searchRes.Entries[0], nil
|
||||
}
|
||||
|
||||
searchRes, err := conn.Search(ldap.NewSearchRequest(
|
||||
a.baseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
2,
|
||||
0,
|
||||
false,
|
||||
fmt.Sprintf("(|(uid=%s)(cn=%s)(sAMAccountName=%s)(userPrincipalName=%s))",
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
),
|
||||
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
|
||||
nil,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: user lookup failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if len(searchRes.Entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return searchRes.Entries[0], nil
|
||||
}
|
||||
|
||||
func normalizeDN(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func mapKeysSorted[K ~string, V any](m map[K]V) []K {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]K, 0, len(m))
|
||||
for key := range m {
|
||||
out = append(out, key)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
return out[i] < out[j]
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func looksLikeDN(value string) bool {
|
||||
value = strings.TrimSpace(value)
|
||||
return strings.Contains(value, "=") && strings.Contains(value, ",")
|
||||
}
|
||||
|
||||
func ctxErr(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package auth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveRoles(t *testing.T) {
|
||||
roles := ResolveRoles(
|
||||
[]string{
|
||||
"cn=vctp-admins,ou=groups,dc=example,dc=com",
|
||||
" CN=VCTP-VIEWERS,OU=GROUPS,DC=EXAMPLE,DC=COM ",
|
||||
},
|
||||
map[string]string{
|
||||
"cn=vctp-admins,ou=groups,dc=example,dc=com": "admin",
|
||||
"cn=vctp-viewers,ou=groups,dc=example,dc=com": "viewer",
|
||||
},
|
||||
)
|
||||
|
||||
if len(roles) != 2 {
|
||||
t.Fatalf("expected 2 roles, got %d (%#v)", len(roles), roles)
|
||||
}
|
||||
if roles[0] != "admin" || roles[1] != "viewer" {
|
||||
t.Fatalf("unexpected resolved roles: %#v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasAnyGroup(t *testing.T) {
|
||||
groups := []string{
|
||||
"cn=vctp-admins,ou=groups,dc=example,dc=com",
|
||||
}
|
||||
|
||||
if !HasAnyGroup(groups, []string{" cn=vctp-admins,ou=groups,dc=example,dc=com "}) {
|
||||
t.Fatal("expected group intersection to match")
|
||||
}
|
||||
if HasAnyGroup(groups, []string{"cn=vctp-operators,ou=groups,dc=example,dc=com"}) {
|
||||
t.Fatal("expected no intersection")
|
||||
}
|
||||
if !HasAnyGroup(groups, nil) {
|
||||
t.Fatal("expected empty required groups to allow")
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -18,6 +19,20 @@ var (
|
||||
postgresKVPasswordPattern = regexp.MustCompile(`(?i)(\bpassword\s*=\s*)(?:'[^']*'|"[^"]*"|[^\s]+)`)
|
||||
)
|
||||
|
||||
const (
|
||||
authModeDisabled = "disabled"
|
||||
authModeOptional = "optional"
|
||||
authModeRequired = "required"
|
||||
|
||||
authRoleAdmin = "admin"
|
||||
authRoleViewer = "viewer"
|
||||
|
||||
defaultAuthTokenLifespanMinutes = 120
|
||||
defaultAuthJWTIssuer = "vctp"
|
||||
defaultAuthJWTAudience = "vctp-api"
|
||||
defaultAuthClockSkewSeconds = 60
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
SettingsPath string
|
||||
Logger *slog.Logger
|
||||
@@ -50,6 +65,21 @@ type SettingsYML struct {
|
||||
VcenterPassword string `yaml:"vcenter_password"`
|
||||
VcenterInsecure bool `yaml:"vcenter_insecure"`
|
||||
EnableLegacyAPI bool `yaml:"enable_legacy_api"`
|
||||
AuthEnabled bool `yaml:"auth_enabled"`
|
||||
AuthMode string `yaml:"auth_mode"`
|
||||
AuthJWTSigningKey string `yaml:"auth_jwt_signing_key"`
|
||||
AuthTokenLifespanMinutes int `yaml:"auth_token_lifespan_minutes"`
|
||||
AuthJWTIssuer string `yaml:"auth_jwt_issuer"`
|
||||
AuthJWTAudience string `yaml:"auth_jwt_audience"`
|
||||
AuthClockSkewSeconds int `yaml:"auth_clock_skew_seconds"`
|
||||
AuthGroupRoleMappings map[string]string `yaml:"auth_group_role_mappings"`
|
||||
LDAPGroups []string `yaml:"ldap_groups"`
|
||||
LDAPBindAddress string `yaml:"ldap_bind_address"`
|
||||
LDAPBaseDN string `yaml:"ldap_base_dn"`
|
||||
LDAPTrustCertFile string `yaml:"ldap_trust_cert_file"`
|
||||
LDAPDisableValidation bool `yaml:"ldap_disable_validation"`
|
||||
LDAPInsecure bool `yaml:"ldap_insecure"`
|
||||
EnablePprof bool `yaml:"enable_pprof"`
|
||||
VcenterEventPollingSeconds int `yaml:"vcenter_event_polling_seconds"`
|
||||
VcenterInventoryPollingSeconds int `yaml:"vcenter_inventory_polling_seconds"`
|
||||
VcenterInventorySnapshotSeconds int `yaml:"vcenter_inventory_snapshot_seconds"`
|
||||
@@ -112,6 +142,9 @@ func (s *Settings) ReadYMLSettings() error {
|
||||
if err := d.Decode(&settings); err != nil {
|
||||
return fmt.Errorf("unable to decode settings file : '%s'", err)
|
||||
}
|
||||
if err := applyDefaultsAndValidateSettings(&settings); err != nil {
|
||||
return fmt.Errorf("invalid settings file: %w", err)
|
||||
}
|
||||
|
||||
// Avoid logging sensitive fields (e.g., credentials).
|
||||
redacted := settings
|
||||
@@ -119,6 +152,9 @@ func (s *Settings) ReadYMLSettings() error {
|
||||
if redacted.Settings.EncryptionKey != "" {
|
||||
redacted.Settings.EncryptionKey = "REDACTED"
|
||||
}
|
||||
if redacted.Settings.AuthJWTSigningKey != "" {
|
||||
redacted.Settings.AuthJWTSigningKey = "REDACTED"
|
||||
}
|
||||
if redacted.Settings.DatabaseURL != "" {
|
||||
redacted.Settings.DatabaseURL = redactDatabaseURL(redacted.Settings.DatabaseURL)
|
||||
}
|
||||
@@ -189,3 +225,140 @@ func secureSettingsFileMode(mode os.FileMode) os.FileMode {
|
||||
secured |= 0o600
|
||||
return secured
|
||||
}
|
||||
|
||||
func applyDefaultsAndValidateSettings(cfg *SettingsYML) error {
|
||||
if cfg == nil {
|
||||
return errors.New("settings config is nil")
|
||||
}
|
||||
s := &cfg.Settings
|
||||
|
||||
s.AuthMode = strings.ToLower(strings.TrimSpace(s.AuthMode))
|
||||
if s.AuthMode == "" {
|
||||
s.AuthMode = authModeDisabled
|
||||
}
|
||||
if s.AuthTokenLifespanMinutes == 0 {
|
||||
s.AuthTokenLifespanMinutes = defaultAuthTokenLifespanMinutes
|
||||
}
|
||||
s.AuthJWTIssuer = strings.TrimSpace(s.AuthJWTIssuer)
|
||||
if s.AuthJWTIssuer == "" {
|
||||
s.AuthJWTIssuer = defaultAuthJWTIssuer
|
||||
}
|
||||
s.AuthJWTAudience = strings.TrimSpace(s.AuthJWTAudience)
|
||||
if s.AuthJWTAudience == "" {
|
||||
s.AuthJWTAudience = defaultAuthJWTAudience
|
||||
}
|
||||
if s.AuthClockSkewSeconds == 0 {
|
||||
s.AuthClockSkewSeconds = defaultAuthClockSkewSeconds
|
||||
}
|
||||
s.AuthJWTSigningKey = strings.TrimSpace(s.AuthJWTSigningKey)
|
||||
s.LDAPBindAddress = strings.TrimSpace(s.LDAPBindAddress)
|
||||
s.LDAPBaseDN = strings.TrimSpace(s.LDAPBaseDN)
|
||||
s.LDAPTrustCertFile = strings.TrimSpace(s.LDAPTrustCertFile)
|
||||
s.LDAPGroups = compactTrimmedStrings(s.LDAPGroups)
|
||||
|
||||
if !isValidAuthMode(s.AuthMode) {
|
||||
return fmt.Errorf("settings.auth_mode must be one of %q, %q, %q", authModeDisabled, authModeOptional, authModeRequired)
|
||||
}
|
||||
if s.AuthTokenLifespanMinutes <= 0 {
|
||||
return errors.New("settings.auth_token_lifespan_minutes must be greater than 0")
|
||||
}
|
||||
if s.AuthClockSkewSeconds < 0 {
|
||||
return errors.New("settings.auth_clock_skew_seconds must be >= 0")
|
||||
}
|
||||
|
||||
if len(s.AuthGroupRoleMappings) > 0 {
|
||||
normalized := make(map[string]string, len(s.AuthGroupRoleMappings))
|
||||
for groupDN, role := range s.AuthGroupRoleMappings {
|
||||
groupDN = strings.TrimSpace(groupDN)
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
if groupDN == "" {
|
||||
return errors.New("settings.auth_group_role_mappings contains an empty group DN key")
|
||||
}
|
||||
if !isValidAuthRole(role) {
|
||||
return fmt.Errorf("settings.auth_group_role_mappings[%q] has unsupported role %q", groupDN, role)
|
||||
}
|
||||
normalized[groupDN] = role
|
||||
}
|
||||
s.AuthGroupRoleMappings = normalized
|
||||
}
|
||||
|
||||
if !s.AuthEnabled {
|
||||
return nil
|
||||
}
|
||||
if s.AuthMode == authModeDisabled {
|
||||
return errors.New("settings.auth_mode must be optional or required when settings.auth_enabled=true")
|
||||
}
|
||||
if s.AuthJWTSigningKey == "" {
|
||||
return errors.New("settings.auth_jwt_signing_key is required when settings.auth_enabled=true")
|
||||
}
|
||||
decodedKey, err := decodeBase64(s.AuthJWTSigningKey)
|
||||
if err != nil {
|
||||
return errors.New("settings.auth_jwt_signing_key must be valid base64")
|
||||
}
|
||||
if len(decodedKey) == 0 {
|
||||
return errors.New("settings.auth_jwt_signing_key cannot decode to an empty value")
|
||||
}
|
||||
if s.LDAPBindAddress == "" {
|
||||
return errors.New("settings.ldap_bind_address is required when settings.auth_enabled=true")
|
||||
}
|
||||
if s.LDAPBaseDN == "" {
|
||||
return errors.New("settings.ldap_base_dn is required when settings.auth_enabled=true")
|
||||
}
|
||||
if len(s.AuthGroupRoleMappings) == 0 {
|
||||
return errors.New("settings.auth_group_role_mappings must define at least one mapping when settings.auth_enabled=true")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isValidAuthMode(mode string) bool {
|
||||
switch mode {
|
||||
case authModeDisabled, authModeOptional, authModeRequired:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isValidAuthRole(role string) bool {
|
||||
switch role {
|
||||
case authRoleAdmin, authRoleViewer:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func decodeBase64(value string) ([]byte, error) {
|
||||
encodings := []*base64.Encoding{
|
||||
base64.StdEncoding,
|
||||
base64.RawStdEncoding,
|
||||
base64.URLEncoding,
|
||||
base64.RawURLEncoding,
|
||||
}
|
||||
for _, encoding := range encodings {
|
||||
decoded, err := encoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("invalid base64 encoding")
|
||||
}
|
||||
|
||||
func compactTrimmedStrings(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package settings
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRedactDatabaseURL_PostgresURI(t *testing.T) {
|
||||
input := "postgres://vctp_user:Secr3tP%40ss@db-host:5432/vctp?sslmode=disable"
|
||||
@@ -27,3 +34,29 @@ func TestRedactDatabaseURL_UnchangedWhenNoPassword(t *testing.T) {
|
||||
t.Fatalf("expected input to remain unchanged\nwant: %s\ngot: %s", input, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadYMLSettingsRedactsAuthJWTSigningKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
settingsPath := filepath.Join(tmpDir, "vctp.yml")
|
||||
content := `settings:
|
||||
auth_jwt_signing_key: "c2VjcmV0"
|
||||
`
|
||||
if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("failed to write settings file: %v", err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&output, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
s := New(logger, settingsPath)
|
||||
if err := s.ReadYMLSettings(); err != nil {
|
||||
t.Fatalf("expected settings to load, got error: %v", err)
|
||||
}
|
||||
|
||||
logged := output.String()
|
||||
if strings.Contains(logged, "c2VjcmV0") {
|
||||
t.Fatalf("expected auth_jwt_signing_key to be redacted in logs, got log output: %s", logged)
|
||||
}
|
||||
if !strings.Contains(logged, "REDACTED") {
|
||||
t.Fatalf("expected redacted marker in logs, got log output: %s", logged)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,134 @@ func TestReadYMLSettingsRejectsUnknownField(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadYMLSettingsAppliesAuthDefaults(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
settingsPath := filepath.Join(tmpDir, "vctp.yml")
|
||||
content := `settings:
|
||||
log_level: "info"
|
||||
`
|
||||
if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("failed to write settings file: %v", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
s := New(logger, settingsPath)
|
||||
if err := s.ReadYMLSettings(); err != nil {
|
||||
t.Fatalf("expected settings to load, got error: %v", err)
|
||||
}
|
||||
|
||||
got := s.Values.Settings
|
||||
if got.AuthMode != authModeDisabled {
|
||||
t.Fatalf("expected default auth_mode=%q, got %q", authModeDisabled, got.AuthMode)
|
||||
}
|
||||
if got.AuthTokenLifespanMinutes != defaultAuthTokenLifespanMinutes {
|
||||
t.Fatalf("expected default auth_token_lifespan_minutes=%d, got %d", defaultAuthTokenLifespanMinutes, got.AuthTokenLifespanMinutes)
|
||||
}
|
||||
if got.AuthJWTIssuer != defaultAuthJWTIssuer {
|
||||
t.Fatalf("expected default auth_jwt_issuer=%q, got %q", defaultAuthJWTIssuer, got.AuthJWTIssuer)
|
||||
}
|
||||
if got.AuthJWTAudience != defaultAuthJWTAudience {
|
||||
t.Fatalf("expected default auth_jwt_audience=%q, got %q", defaultAuthJWTAudience, got.AuthJWTAudience)
|
||||
}
|
||||
if got.AuthClockSkewSeconds != defaultAuthClockSkewSeconds {
|
||||
t.Fatalf("expected default auth_clock_skew_seconds=%d, got %d", defaultAuthClockSkewSeconds, got.AuthClockSkewSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadYMLSettingsRejectsInvalidAuthMode(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
settingsPath := filepath.Join(tmpDir, "vctp.yml")
|
||||
content := `settings:
|
||||
auth_mode: "sometimes"
|
||||
`
|
||||
if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("failed to write settings file: %v", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
s := New(logger, settingsPath)
|
||||
err := s.ReadYMLSettings()
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid auth_mode to fail")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "auth_mode") {
|
||||
t.Fatalf("expected error to mention auth_mode, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadYMLSettingsRejectsAuthEnabledWithoutSigningKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
settingsPath := filepath.Join(tmpDir, "vctp.yml")
|
||||
content := `settings:
|
||||
auth_enabled: true
|
||||
auth_mode: "required"
|
||||
ldap_bind_address: "ldaps://ldap.example.com:636"
|
||||
ldap_base_dn: "dc=example,dc=com"
|
||||
auth_group_role_mappings:
|
||||
"cn=vctp-admin,ou=groups,dc=example,dc=com": "admin"
|
||||
`
|
||||
if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("failed to write settings file: %v", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
s := New(logger, settingsPath)
|
||||
err := s.ReadYMLSettings()
|
||||
if err == nil {
|
||||
t.Fatal("expected auth_enabled=true without signing key to fail")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "auth_jwt_signing_key") {
|
||||
t.Fatalf("expected error to mention auth_jwt_signing_key, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadYMLSettingsAcceptsValidAuthConfigAndNormalizesMappings(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
settingsPath := filepath.Join(tmpDir, "vctp.yml")
|
||||
content := `settings:
|
||||
auth_enabled: true
|
||||
auth_mode: "REQUIRED"
|
||||
auth_jwt_signing_key: "c2VjcmV0"
|
||||
auth_token_lifespan_minutes: 90
|
||||
auth_jwt_issuer: " custom-issuer "
|
||||
auth_jwt_audience: " custom-audience "
|
||||
auth_clock_skew_seconds: 15
|
||||
ldap_bind_address: "ldaps://ldap.example.com:636"
|
||||
ldap_base_dn: "dc=example,dc=com"
|
||||
ldap_groups:
|
||||
- " cn=vctp-viewers,ou=groups,dc=example,dc=com "
|
||||
auth_group_role_mappings:
|
||||
" cn=vctp-admins,ou=groups,dc=example,dc=com ": " ADMIN "
|
||||
"cn=vctp-viewers,ou=groups,dc=example,dc=com": "viewer"
|
||||
`
|
||||
if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil {
|
||||
t.Fatalf("failed to write settings file: %v", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
s := New(logger, settingsPath)
|
||||
if err := s.ReadYMLSettings(); err != nil {
|
||||
t.Fatalf("expected valid auth config, got error: %v", err)
|
||||
}
|
||||
|
||||
got := s.Values.Settings
|
||||
if got.AuthMode != authModeRequired {
|
||||
t.Fatalf("expected normalized auth_mode=%q, got %q", authModeRequired, got.AuthMode)
|
||||
}
|
||||
if got.AuthJWTIssuer != "custom-issuer" {
|
||||
t.Fatalf("expected trimmed auth_jwt_issuer, got %q", got.AuthJWTIssuer)
|
||||
}
|
||||
if got.AuthJWTAudience != "custom-audience" {
|
||||
t.Fatalf("expected trimmed auth_jwt_audience, got %q", got.AuthJWTAudience)
|
||||
}
|
||||
if len(got.LDAPGroups) != 1 || got.LDAPGroups[0] != "cn=vctp-viewers,ou=groups,dc=example,dc=com" {
|
||||
t.Fatalf("expected ldap_groups to be compacted+trimmed, got %#v", got.LDAPGroups)
|
||||
}
|
||||
if got.AuthGroupRoleMappings["cn=vctp-admins,ou=groups,dc=example,dc=com"] != authRoleAdmin {
|
||||
t.Fatalf("expected admin mapping to normalize role to %q, got %#v", authRoleAdmin, got.AuthGroupRoleMappings)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureSettingsFileMode(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user