package auth

import (
	"encoding/base64"
	"errors"
	"strings"
	"testing"
	"time"
)

// TestNewJWTServiceRejectsBadConfig checks that a signing key which is not
// valid base64 makes NewJWTService fail with ErrInvalidJWTConfig.
func TestNewJWTServiceRejectsBadConfig(t *testing.T) {
	_, err := NewJWTService(JWTConfig{
		SigningKeyBase64: "!!!",
		Issuer:           "vctp",
		Audience:         "vctp-api",
		TokenLifespan:    time.Hour,
		ClockSkew:        time.Minute,
	})
	if err == nil {
		t.Fatal("expected invalid base64 signing key to fail")
	}
	if !errors.Is(err, ErrInvalidJWTConfig) {
		t.Fatalf("expected ErrInvalidJWTConfig, got: %v", err)
	}
}

// TestIssueAndVerifyTokenRoundTrip issues a token at a pinned clock, checks the
// issued claims and the raw JWT payload, then verifies the token and checks
// that the registered claims come back unchanged.
func TestIssueAndVerifyTokenRoundTrip(t *testing.T) {
	now := time.Unix(1_700_000_000, 0).UTC()
	svc := mustJWTService(t)
	svc.now = func() time.Time { return now }

	token, issuedClaims, err := svc.IssueToken("alice", []string{"admin", " viewer "}, []string{"cn=vctp-admins,dc=example,dc=com"})
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}
	if token == "" {
		t.Fatal("expected non-empty token")
	}

	if issuedClaims.Subject != "alice" {
		t.Fatalf("expected subject alice, got %q", issuedClaims.Subject)
	}
	if issuedClaims.Issuer != "vctp" {
		t.Fatalf("expected issuer vctp, got %q", issuedClaims.Issuer)
	}
	if issuedClaims.Audience != "vctp-api" {
		t.Fatalf("expected audience vctp-api, got %q", issuedClaims.Audience)
	}
	if issuedClaims.IssuedAt != now.Unix() {
		t.Fatalf("unexpected iat: %d", issuedClaims.IssuedAt)
	}
	if issuedClaims.NotBefore != now.Unix() {
		t.Fatalf("unexpected nbf: %d", issuedClaims.NotBefore)
	}
	// mustJWTService configures a 2h TokenLifespan.
	if issuedClaims.ExpiresAt != now.Add(2*time.Hour).Unix() {
		t.Fatalf("unexpected exp: %d", issuedClaims.ExpiresAt)
	}
	if issuedClaims.ID == "" {
		t.Fatal("expected jti to be populated")
	}
	if len(issuedClaims.Groups) != 0 {
		t.Fatalf("expected groups to be omitted from issued claims, got %#v", issuedClaims.Groups)
	}

	// Inspect the encoded payload directly so a "groups" claim cannot sneak
	// into the token even if the returned struct looks clean.
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		t.Fatalf("expected jwt to have 3 parts, got %d", len(parts))
	}
	payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		t.Fatalf("failed to decode jwt payload: %v", err)
	}
	if strings.Contains(string(payloadJSON), `"groups"`) {
		t.Fatalf("expected jwt payload to omit groups claim, got payload: %s", string(payloadJSON))
	}

	verifiedClaims, err := svc.VerifyToken(token)
	if err != nil {
		t.Fatalf("VerifyToken returned error: %v", err)
	}
	// Every registered claim written at issue time must be read back verbatim.
	if verifiedClaims.Subject != issuedClaims.Subject {
		t.Fatalf("subject mismatch: got %q want %q", verifiedClaims.Subject, issuedClaims.Subject)
	}
	if verifiedClaims.Issuer != issuedClaims.Issuer {
		t.Fatalf("issuer mismatch: got %q want %q", verifiedClaims.Issuer, issuedClaims.Issuer)
	}
	if verifiedClaims.Audience != issuedClaims.Audience {
		t.Fatalf("audience mismatch: got %q want %q", verifiedClaims.Audience, issuedClaims.Audience)
	}
	if verifiedClaims.IssuedAt != issuedClaims.IssuedAt {
		t.Fatalf("iat mismatch: got %d want %d", verifiedClaims.IssuedAt, issuedClaims.IssuedAt)
	}
	if verifiedClaims.NotBefore != issuedClaims.NotBefore {
		t.Fatalf("nbf mismatch: got %d want %d", verifiedClaims.NotBefore, issuedClaims.NotBefore)
	}
	if verifiedClaims.ExpiresAt != issuedClaims.ExpiresAt {
		t.Fatalf("exp mismatch: got %d want %d", verifiedClaims.ExpiresAt, issuedClaims.ExpiresAt)
	}
	if verifiedClaims.ID != issuedClaims.ID {
		t.Fatalf("jti mismatch: got %q want %q", verifiedClaims.ID, issuedClaims.ID)
	}
}

// TestVerifyTokenRejectsInvalidSignature checks that a token signed with one
// key is rejected with ErrInvalidJWTToken by a service using a different key.
func TestVerifyTokenRejectsInvalidSignature(t *testing.T) {
	svc := mustJWTService(t)
	svc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }

	token, _, err := svc.IssueToken("alice", []string{"admin"}, nil)
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}

	other := mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("a different secret key")))
	other.now = svc.now

	_, err = other.VerifyToken(token)
	if err == nil {
		t.Fatal("expected signature mismatch to fail")
	}
	if !errors.Is(err, ErrInvalidJWTToken) {
		t.Fatalf("expected ErrInvalidJWTToken, got: %v", err)
	}
}

// TestVerifyTokenRejectsIssuerAndAudienceMismatch checks that a correctly
// signed token is rejected with ErrInvalidJWTClaims when the verifier expects
// a different issuer or audience. The error text must name the mismatching
// claim.
func TestVerifyTokenRejectsIssuerAndAudienceMismatch(t *testing.T) {
	issuerSvc := mustJWTService(t)
	issuerSvc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }

	token, _, err := issuerSvc.IssueToken("alice", nil, nil)
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}

	// Same key as mustJWTService, so the signature check passes and only the
	// issuer/audience checks can reject the token.
	signingKey := base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key"))

	cases := []struct {
		claim    string // claim under test; also expected in the error text
		issuer   string
		audience string
	}{
		{claim: "issuer", issuer: "other-issuer", audience: "vctp-api"},
		{claim: "audience", issuer: "vctp", audience: "other-audience"},
	}

	for _, tc := range cases {
		t.Run(tc.claim, func(t *testing.T) {
			verifier, err := NewJWTService(JWTConfig{
				SigningKeyBase64: signingKey,
				Issuer:           tc.issuer,
				Audience:         tc.audience,
				TokenLifespan:    2 * time.Hour,
				ClockSkew:        time.Minute,
			})
			if err != nil {
				t.Fatalf("failed to create verifier with wrong %s: %v", tc.claim, err)
			}
			verifier.now = issuerSvc.now

			_, err = verifier.VerifyToken(token)
			if err == nil {
				t.Fatalf("expected %s mismatch to fail", tc.claim)
			}
			if !errors.Is(err, ErrInvalidJWTClaims) {
				t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
			}
			if !strings.Contains(strings.ToLower(err.Error()), tc.claim) {
				t.Fatalf("expected %s mismatch error, got: %v", tc.claim, err)
			}
		})
	}
}

// TestVerifyTokenRejectsExpiredNotBeforeAndFutureIssuedAt covers three time
// checks:
//   - a token past exp is rejected;
//   - a token whose nbf is beyond the clock skew is rejected;
//   - a token whose iat is beyond the clock skew is rejected.
//
// The subtests run sequentially and share svc, so each one sets svc.now
// explicitly.
func TestVerifyTokenRejectsExpiredNotBeforeAndFutureIssuedAt(t *testing.T) {
	base := time.Unix(1_700_000_000, 0).UTC()
	svc := mustJWTService(t)
	svc.now = func() time.Time { return base }

	token, claims, err := svc.IssueToken("alice", nil, nil)
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}

	t.Run("expired", func(t *testing.T) {
		// exp is base+2h; base+3h is well past it even with 1m of skew.
		svc.now = func() time.Time { return base.Add(3 * time.Hour) }
		_, err := svc.VerifyToken(token)
		if !errors.Is(err, ErrExpiredJWTToken) {
			t.Fatalf("expected ErrExpiredJWTToken, got: %v", err)
		}
	})

	t.Run("not_before", func(t *testing.T) {
		notBeforeClaims := claims
		notBeforeClaims.NotBefore = base.Add(10 * time.Minute).Unix()
		notBeforeClaims.IssuedAt = base.Unix()
		notBeforeClaims.ExpiresAt = base.Add(2 * time.Hour).Unix()
		notBeforeClaims.ID = "forced-jti-1"

		notBeforeToken, err := encodeSignedJWT(notBeforeClaims, svc.signingKey)
		if err != nil {
			t.Fatalf("failed to create token with future nbf: %v", err)
		}

		svc.now = func() time.Time { return base }
		_, err = svc.VerifyToken(notBeforeToken)
		if !errors.Is(err, ErrNotYetValidJWTToken) {
			t.Fatalf("expected ErrNotYetValidJWTToken, got: %v", err)
		}
	})

	t.Run("future_iat", func(t *testing.T) {
		futureIatClaims := claims
		futureIatClaims.IssuedAt = base.Add(20 * time.Minute).Unix()
		futureIatClaims.NotBefore = base.Unix()
		futureIatClaims.ExpiresAt = base.Add(3 * time.Hour).Unix()
		futureIatClaims.ID = "forced-jti-2"

		futureIatToken, err := encodeSignedJWT(futureIatClaims, svc.signingKey)
		if err != nil {
			t.Fatalf("failed to create token with future iat: %v", err)
		}

		svc.now = func() time.Time { return base }
		_, err = svc.VerifyToken(futureIatToken)
		if err == nil {
			t.Fatal("expected future iat validation to fail")
		}
		if !errors.Is(err, ErrInvalidJWTClaims) {
			t.Fatalf("expected ErrInvalidJWTClaims for future iat, got: %v", err)
		}
	})
}

// TestVerifyTokenRejectsMissingJTI checks that a correctly signed token with an
// empty jti is rejected with ErrInvalidJWTClaims, and that the error mentions
// "jti".
func TestVerifyTokenRejectsMissingJTI(t *testing.T) {
	base := time.Unix(1_700_000_000, 0).UTC()
	svc := mustJWTService(t)
	svc.now = func() time.Time { return base }

	token, claims, err := svc.IssueToken("alice", nil, nil)
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}
	if token == "" {
		t.Fatal("expected non-empty token")
	}

	// Re-sign the otherwise valid claims with the jti stripped.
	claims.ID = ""
	customToken, err := encodeSignedJWT(claims, svc.signingKey)
	if err != nil {
		t.Fatalf("failed to create token without jti: %v", err)
	}

	_, err = svc.VerifyToken(customToken)
	if err == nil {
		t.Fatal("expected missing jti token to fail")
	}
	if !errors.Is(err, ErrInvalidJWTClaims) {
		t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
	}
	if !strings.Contains(strings.ToLower(err.Error()), "jti") {
		t.Fatalf("expected jti validation error, got: %v", err)
	}
}

// mustJWTService returns a JWTService with the default test key
// ("super-secret-signing-key"). It uses issuer "vctp", audience "vctp-api",
// a 2h lifespan and 1m of clock skew. The test fails immediately on error.
func mustJWTService(t *testing.T) *JWTService {
	t.Helper()
	return mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")))
}

// mustJWTServiceWithKey is like mustJWTService but takes the base64-encoded
// signing key explicitly.
func mustJWTServiceWithKey(t *testing.T, keyBase64 string) *JWTService {
	t.Helper()
	svc, err := NewJWTService(JWTConfig{
		SigningKeyBase64: keyBase64,
		Issuer:           "vctp",
		Audience:         "vctp-api",
		TokenLifespan:    2 * time.Hour,
		ClockSkew:        time.Minute,
	})
	if err != nil {
		t.Fatalf("failed to create jwt service: %v", err)
	}
	return svc
}