Files
vctp2/internal/auth/jwt_test.go
T
nathan ae3e2be89a
continuous-integration/drone/push Build is passing
add auth support
2026-04-17 13:19:08 +10:00

248 lines
7.4 KiB
Go

package auth
import (
"encoding/base64"
"errors"
"strings"
"testing"
"time"
)
func TestNewJWTServiceRejectsBadConfig(t *testing.T) {
_, err := NewJWTService(JWTConfig{
SigningKeyBase64: "!!!",
Issuer: "vctp",
Audience: "vctp-api",
TokenLifespan: time.Hour,
ClockSkew: time.Minute,
})
if err == nil {
t.Fatal("expected invalid base64 signing key to fail")
}
if !errors.Is(err, ErrInvalidJWTConfig) {
t.Fatalf("expected ErrInvalidJWTConfig, got: %v", err)
}
}
// TestIssueAndVerifyTokenRoundTrip issues a token at a fixed instant and checks
// the claims reported by IssueToken. It then checks that VerifyToken recovers
// the same claims from the encoded token.
//
// Every registered claim (sub, iss, aud, iat, nbf, exp, jti) is compared after
// verification, not only sub and jti. This catches an encode/decode path that
// silently drops or rewrites one of them.
func TestIssueAndVerifyTokenRoundTrip(t *testing.T) {
	now := time.Unix(1_700_000_000, 0).UTC()
	svc := mustJWTService(t)
	svc.now = func() time.Time { return now }

	// NOTE(review): the roles (including the whitespace-padded " viewer ") and
	// the group DN are passed in but never asserted. If IssueToken is meant to
	// trim/normalize them, add assertions on the corresponding claims fields
	// once their names are confirmed.
	token, issuedClaims, err := svc.IssueToken("alice", []string{"admin", " viewer "}, []string{"cn=vctp-admins,dc=example,dc=com"})
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}
	if token == "" {
		t.Fatal("expected non-empty token")
	}

	// Claims as reported by IssueToken.
	if issuedClaims.Subject != "alice" {
		t.Fatalf("expected subject alice, got %q", issuedClaims.Subject)
	}
	if issuedClaims.Issuer != "vctp" {
		t.Fatalf("expected issuer vctp, got %q", issuedClaims.Issuer)
	}
	if issuedClaims.Audience != "vctp-api" {
		t.Fatalf("expected audience vctp-api, got %q", issuedClaims.Audience)
	}
	if issuedClaims.IssuedAt != now.Unix() {
		t.Fatalf("unexpected iat: %d", issuedClaims.IssuedAt)
	}
	if issuedClaims.NotBefore != now.Unix() {
		t.Fatalf("unexpected nbf: %d", issuedClaims.NotBefore)
	}
	// mustJWTService configures a two-hour TokenLifespan.
	if issuedClaims.ExpiresAt != now.Add(2*time.Hour).Unix() {
		t.Fatalf("unexpected exp: %d", issuedClaims.ExpiresAt)
	}
	if issuedClaims.ID == "" {
		t.Fatal("expected jti to be populated")
	}

	// Claims as recovered by VerifyToken must match what was issued.
	verifiedClaims, err := svc.VerifyToken(token)
	if err != nil {
		t.Fatalf("VerifyToken returned error: %v", err)
	}
	if verifiedClaims.Subject != issuedClaims.Subject {
		t.Fatalf("subject mismatch: got %q want %q", verifiedClaims.Subject, issuedClaims.Subject)
	}
	if verifiedClaims.Issuer != issuedClaims.Issuer {
		t.Fatalf("issuer mismatch: got %q want %q", verifiedClaims.Issuer, issuedClaims.Issuer)
	}
	if verifiedClaims.Audience != issuedClaims.Audience {
		t.Fatalf("audience mismatch: got %q want %q", verifiedClaims.Audience, issuedClaims.Audience)
	}
	if verifiedClaims.IssuedAt != issuedClaims.IssuedAt {
		t.Fatalf("iat mismatch: got %d want %d", verifiedClaims.IssuedAt, issuedClaims.IssuedAt)
	}
	if verifiedClaims.NotBefore != issuedClaims.NotBefore {
		t.Fatalf("nbf mismatch: got %d want %d", verifiedClaims.NotBefore, issuedClaims.NotBefore)
	}
	if verifiedClaims.ExpiresAt != issuedClaims.ExpiresAt {
		t.Fatalf("exp mismatch: got %d want %d", verifiedClaims.ExpiresAt, issuedClaims.ExpiresAt)
	}
	if verifiedClaims.ID != issuedClaims.ID {
		t.Fatalf("jti mismatch: got %q want %q", verifiedClaims.ID, issuedClaims.ID)
	}
}
// TestVerifyTokenRejectsInvalidSignature signs a token with the default test
// key. It then checks that a service configured with a different key refuses
// the token with an error wrapping ErrInvalidJWTToken.
func TestVerifyTokenRejectsInvalidSignature(t *testing.T) {
	fixedClock := func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }

	signer := mustJWTService(t)
	signer.now = fixedClock
	token, _, err := signer.IssueToken("alice", []string{"admin"}, nil)
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}

	foreignKey := base64.StdEncoding.EncodeToString([]byte("a different secret key"))
	verifier := mustJWTServiceWithKey(t, foreignKey)
	// Same clock as the signer so only the key differs between the two.
	verifier.now = signer.now

	_, err = verifier.VerifyToken(token)
	switch {
	case err == nil:
		t.Fatal("expected signature mismatch to fail")
	case !errors.Is(err, ErrInvalidJWTToken):
		t.Fatalf("expected ErrInvalidJWTToken, got: %v", err)
	}
}
// TestVerifyTokenRejectsIssuerAndAudienceMismatch issues a token for
// iss="vctp"/aud="vctp-api". It then verifies that token with two services
// sharing the same signing key: one expecting a different issuer, one
// expecting a different audience. Each must fail with an error that wraps
// ErrInvalidJWTClaims and names the mismatched claim.
func TestVerifyTokenRejectsIssuerAndAudienceMismatch(t *testing.T) {
	issuerSvc := mustJWTService(t)
	issuerSvc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }
	token, _, err := issuerSvc.IssueToken("alice", nil, nil)
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}

	sharedKey := base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key"))

	// Each verifier differs from the issuing service in exactly one of iss/aud.
	// claim is also the word expected to appear in the resulting error text.
	mismatches := []struct {
		claim    string
		issuer   string
		audience string
	}{
		{claim: "issuer", issuer: "other-issuer", audience: "vctp-api"},
		{claim: "audience", issuer: "vctp", audience: "other-audience"},
	}

	for _, m := range mismatches {
		verifier, err := NewJWTService(JWTConfig{
			SigningKeyBase64: sharedKey,
			Issuer:           m.issuer,
			Audience:         m.audience,
			TokenLifespan:    2 * time.Hour,
			ClockSkew:        time.Minute,
		})
		if err != nil {
			t.Fatalf("failed to create verifier with wrong %s: %v", m.claim, err)
		}
		verifier.now = issuerSvc.now

		_, err = verifier.VerifyToken(token)
		if err == nil {
			t.Fatalf("expected %s mismatch to fail", m.claim)
		}
		if !errors.Is(err, ErrInvalidJWTClaims) {
			t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
		}
		if !strings.Contains(strings.ToLower(err.Error()), m.claim) {
			t.Fatalf("expected %s mismatch error, got: %v", m.claim, err)
		}
	}
}
// TestVerifyTokenRejectsExpiredNotBeforeAndFutureIssuedAt covers three
// time-based rejections by VerifyToken. The service under test comes from
// mustJWTService (two-hour lifespan, one minute of clock skew):
//   - a token verified after its exp fails with ErrExpiredJWTToken;
//   - a token whose nbf is ahead of now fails with ErrNotYetValidJWTToken;
//   - a token whose iat is ahead of now fails with ErrInvalidJWTClaims.
//
// Each case runs as its own subtest and pins svc.now explicitly. A failure in
// one case therefore no longer hides the others, and no case relies on the
// clock left behind by the previous one.
func TestVerifyTokenRejectsExpiredNotBeforeAndFutureIssuedAt(t *testing.T) {
	base := time.Unix(1_700_000_000, 0).UTC()
	svc := mustJWTService(t)
	svc.now = func() time.Time { return base }
	token, claims, err := svc.IssueToken("alice", nil, nil)
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}

	t.Run("expired", func(t *testing.T) {
		// The token's exp is base+2h; three hours later is past it by far
		// more than the one minute of configured skew.
		svc.now = func() time.Time { return base.Add(3 * time.Hour) }
		_, err := svc.VerifyToken(token)
		if !errors.Is(err, ErrExpiredJWTToken) {
			t.Fatalf("expected ErrExpiredJWTToken, got: %v", err)
		}
	})

	t.Run("future_nbf", func(t *testing.T) {
		// Re-sign with the service's own key so the signature is valid and
		// only the timing claims differ from a freshly issued token.
		notBeforeClaims := claims
		notBeforeClaims.NotBefore = base.Add(10 * time.Minute).Unix()
		notBeforeClaims.IssuedAt = base.Unix()
		notBeforeClaims.ExpiresAt = base.Add(2 * time.Hour).Unix()
		notBeforeClaims.ID = "forced-jti-1"
		notBeforeToken, err := encodeSignedJWT(notBeforeClaims, svc.signingKey)
		if err != nil {
			t.Fatalf("failed to create token with future nbf: %v", err)
		}

		svc.now = func() time.Time { return base }
		_, err = svc.VerifyToken(notBeforeToken)
		if !errors.Is(err, ErrNotYetValidJWTToken) {
			t.Fatalf("expected ErrNotYetValidJWTToken, got: %v", err)
		}
	})

	t.Run("future_iat", func(t *testing.T) {
		futureIatClaims := claims
		futureIatClaims.IssuedAt = base.Add(20 * time.Minute).Unix()
		futureIatClaims.NotBefore = base.Unix()
		futureIatClaims.ExpiresAt = base.Add(3 * time.Hour).Unix()
		futureIatClaims.ID = "forced-jti-2"
		futureIatToken, err := encodeSignedJWT(futureIatClaims, svc.signingKey)
		if err != nil {
			t.Fatalf("failed to create token with future iat: %v", err)
		}

		// Pin the clock here rather than inheriting it from an earlier case.
		svc.now = func() time.Time { return base }
		_, err = svc.VerifyToken(futureIatToken)
		if err == nil {
			t.Fatal("expected future iat validation to fail")
		}
		if !errors.Is(err, ErrInvalidJWTClaims) {
			t.Fatalf("expected ErrInvalidJWTClaims for future iat, got: %v", err)
		}
	})
}
// TestVerifyTokenRejectsMissingJTI re-signs an issued token's claims with an
// empty jti. It checks that VerifyToken rejects the result with an error that
// wraps ErrInvalidJWTClaims and mentions "jti".
func TestVerifyTokenRejectsMissingJTI(t *testing.T) {
	issuedAt := time.Unix(1_700_000_000, 0).UTC()
	service := mustJWTService(t)
	service.now = func() time.Time { return issuedAt }

	original, claims, err := service.IssueToken("alice", nil, nil)
	if err != nil {
		t.Fatalf("IssueToken returned error: %v", err)
	}
	if original == "" {
		t.Fatal("expected non-empty token")
	}

	// Blank the jti and re-sign with the service's own key, so the signature
	// remains valid for this service.
	claims.ID = ""
	withoutJTI, err := encodeSignedJWT(claims, service.signingKey)
	if err != nil {
		t.Fatalf("failed to create token without jti: %v", err)
	}

	_, err = service.VerifyToken(withoutJTI)
	switch {
	case err == nil:
		t.Fatal("expected missing jti token to fail")
	case !errors.Is(err, ErrInvalidJWTClaims):
		t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
	case !strings.Contains(strings.ToLower(err.Error()), "jti"):
		t.Fatalf("expected jti validation error, got: %v", err)
	}
}
// mustJWTService returns a JWTService signed with the default test key
// ("super-secret-signing-key", base64-encoded) and the shared test settings
// from mustJWTServiceWithKey. The calling test fails if construction errors.
func mustJWTService(t *testing.T) *JWTService {
	t.Helper()
	defaultKey := base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key"))
	return mustJWTServiceWithKey(t, defaultKey)
}
// mustJWTServiceWithKey builds a JWTService that uses keyBase64 as its
// base64-encoded signing key. It issues for "vctp", targets audience
// "vctp-api", and uses a two-hour token lifespan with one minute of clock
// skew. If NewJWTService returns an error, the calling test is aborted via
// t.Fatalf.
func mustJWTServiceWithKey(t *testing.T, keyBase64 string) *JWTService {
	t.Helper()
	cfg := JWTConfig{
		SigningKeyBase64: keyBase64,
		Issuer:           "vctp",
		Audience:         "vctp-api",
		TokenLifespan:    2 * time.Hour,
		ClockSkew:        time.Minute,
	}
	service, err := NewJWTService(cfg)
	if err != nil {
		t.Fatalf("failed to create jwt service: %v", err)
	}
	return service
}