@@ -0,0 +1,292 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
jwtAlgHS256 = "HS256"
|
||||
jwtTyp = "JWT"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidJWTConfig = errors.New("invalid jwt config")
|
||||
ErrInvalidJWTToken = errors.New("invalid jwt token")
|
||||
ErrInvalidJWTClaims = errors.New("invalid jwt claims")
|
||||
ErrExpiredJWTToken = errors.New("jwt token expired")
|
||||
ErrNotYetValidJWTToken = errors.New("jwt token is not yet valid")
|
||||
)
|
||||
|
||||
type JWTConfig struct {
|
||||
SigningKeyBase64 string
|
||||
Issuer string
|
||||
Audience string
|
||||
TokenLifespan time.Duration
|
||||
ClockSkew time.Duration
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Subject string `json:"sub"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
Issuer string `json:"iss"`
|
||||
Audience string `json:"aud"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
NotBefore int64 `json:"nbf"`
|
||||
ID string `json:"jti"`
|
||||
}
|
||||
|
||||
type JWTService struct {
|
||||
signingKey []byte
|
||||
issuer string
|
||||
audience string
|
||||
tokenLifespan time.Duration
|
||||
clockSkew time.Duration
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type jwtHeader struct {
|
||||
Algorithm string `json:"alg"`
|
||||
Type string `json:"typ"`
|
||||
}
|
||||
|
||||
func NewJWTService(cfg JWTConfig) (*JWTService, error) {
|
||||
issuer := strings.TrimSpace(cfg.Issuer)
|
||||
audience := strings.TrimSpace(cfg.Audience)
|
||||
if issuer == "" {
|
||||
return nil, fmt.Errorf("%w: issuer is required", ErrInvalidJWTConfig)
|
||||
}
|
||||
if audience == "" {
|
||||
return nil, fmt.Errorf("%w: audience is required", ErrInvalidJWTConfig)
|
||||
}
|
||||
if cfg.TokenLifespan <= 0 {
|
||||
return nil, fmt.Errorf("%w: token lifespan must be greater than zero", ErrInvalidJWTConfig)
|
||||
}
|
||||
if cfg.ClockSkew < 0 {
|
||||
return nil, fmt.Errorf("%w: clock skew cannot be negative", ErrInvalidJWTConfig)
|
||||
}
|
||||
|
||||
signingKey, err := decodeBase64Key(strings.TrimSpace(cfg.SigningKeyBase64))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: signing key must be valid base64", ErrInvalidJWTConfig)
|
||||
}
|
||||
if len(signingKey) == 0 {
|
||||
return nil, fmt.Errorf("%w: signing key cannot be empty", ErrInvalidJWTConfig)
|
||||
}
|
||||
|
||||
return &JWTService{
|
||||
signingKey: signingKey,
|
||||
issuer: issuer,
|
||||
audience: audience,
|
||||
tokenLifespan: cfg.TokenLifespan,
|
||||
clockSkew: cfg.ClockSkew,
|
||||
now: time.Now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *JWTService) IssueToken(subject string, roles []string, groups []string) (string, Claims, error) {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return "", Claims{}, fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
claims := Claims{
|
||||
Subject: subject,
|
||||
Roles: compactTrimmedStrings(roles),
|
||||
Groups: compactTrimmedStrings(groups),
|
||||
Issuer: s.issuer,
|
||||
Audience: s.audience,
|
||||
IssuedAt: now.Unix(),
|
||||
ExpiresAt: now.Add(s.tokenLifespan).Unix(),
|
||||
NotBefore: now.Unix(),
|
||||
ID: newTokenID(),
|
||||
}
|
||||
if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
|
||||
return "", Claims{}, err
|
||||
}
|
||||
|
||||
token, err := encodeSignedJWT(claims, s.signingKey)
|
||||
if err != nil {
|
||||
return "", Claims{}, err
|
||||
}
|
||||
return token, claims, nil
|
||||
}
|
||||
|
||||
func (s *JWTService) VerifyToken(token string) (Claims, error) {
|
||||
header, claims, signingInput, signature, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return Claims{}, err
|
||||
}
|
||||
if header.Algorithm != jwtAlgHS256 {
|
||||
return Claims{}, fmt.Errorf("%w: unsupported algorithm", ErrInvalidJWTToken)
|
||||
}
|
||||
if header.Type != "" && header.Type != jwtTyp {
|
||||
return Claims{}, fmt.Errorf("%w: invalid token type", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
expected := signPayload(signingInput, s.signingKey)
|
||||
if !hmac.Equal(signature, expected) {
|
||||
return Claims{}, fmt.Errorf("%w: signature mismatch", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
|
||||
return Claims{}, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func validateClaims(claims Claims, now time.Time, expectedIssuer string, expectedAudience string, clockSkew time.Duration) error {
|
||||
if strings.TrimSpace(claims.Subject) == "" {
|
||||
return fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if strings.TrimSpace(claims.ID) == "" {
|
||||
return fmt.Errorf("%w: jti is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.Issuer != expectedIssuer {
|
||||
return fmt.Errorf("%w: issuer mismatch", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.Audience != expectedAudience {
|
||||
return fmt.Errorf("%w: audience mismatch", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.IssuedAt <= 0 {
|
||||
return fmt.Errorf("%w: iat is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore <= 0 {
|
||||
return fmt.Errorf("%w: nbf is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.ExpiresAt <= 0 {
|
||||
return fmt.Errorf("%w: exp is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.ExpiresAt <= claims.IssuedAt {
|
||||
return fmt.Errorf("%w: exp must be greater than iat", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore > claims.ExpiresAt {
|
||||
return fmt.Errorf("%w: nbf cannot be greater than exp", ErrInvalidJWTClaims)
|
||||
}
|
||||
|
||||
unixNow := now.Unix()
|
||||
skewSeconds := int64(clockSkew / time.Second)
|
||||
if claims.IssuedAt > unixNow+skewSeconds {
|
||||
return fmt.Errorf("%w: iat is in the future", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore > unixNow+skewSeconds {
|
||||
return ErrNotYetValidJWTToken
|
||||
}
|
||||
if unixNow > claims.ExpiresAt+skewSeconds {
|
||||
return ErrExpiredJWTToken
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeSignedJWT(claims Claims, signingKey []byte) (string, error) {
|
||||
headerJSON, err := json.Marshal(jwtHeader{Algorithm: jwtAlgHS256, Type: jwtTyp})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal jwt header: %w", err)
|
||||
}
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal jwt claims: %w", err)
|
||||
}
|
||||
|
||||
headerPart := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
payloadPart := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signingInput := headerPart + "." + payloadPart
|
||||
signature := signPayload(signingInput, signingKey)
|
||||
signaturePart := base64.RawURLEncoding.EncodeToString(signature)
|
||||
|
||||
return signingInput + "." + signaturePart, nil
|
||||
}
|
||||
|
||||
func parseJWT(token string) (jwtHeader, Claims, string, []byte, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
|
||||
}
|
||||
if parts[0] == "" || parts[1] == "" || parts[2] == "" {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid payload encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid signature encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
var header jwtHeader
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header json", ErrInvalidJWTToken)
|
||||
}
|
||||
var claims Claims
|
||||
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid claims json", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
return header, claims, parts[0] + "." + parts[1], signature, nil
|
||||
}
|
||||
|
||||
func signPayload(payload string, signingKey []byte) []byte {
|
||||
mac := hmac.New(sha256.New, signingKey)
|
||||
mac.Write([]byte(payload))
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func newTokenID() string {
|
||||
raw := make([]byte, 16)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return fmt.Sprintf("fallback-%d", time.Now().UTC().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(raw)
|
||||
}
|
||||
|
||||
func decodeBase64Key(value string) ([]byte, error) {
|
||||
encodings := []*base64.Encoding{
|
||||
base64.StdEncoding,
|
||||
base64.RawStdEncoding,
|
||||
base64.URLEncoding,
|
||||
base64.RawURLEncoding,
|
||||
}
|
||||
for _, encoding := range encodings {
|
||||
decoded, err := encoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("invalid base64 encoding")
|
||||
}
|
||||
|
||||
func compactTrimmedStrings(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewJWTServiceRejectsBadConfig(t *testing.T) {
|
||||
_, err := NewJWTService(JWTConfig{
|
||||
SigningKeyBase64: "!!!",
|
||||
Issuer: "vctp",
|
||||
Audience: "vctp-api",
|
||||
TokenLifespan: time.Hour,
|
||||
ClockSkew: time.Minute,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid base64 signing key to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTConfig) {
|
||||
t.Fatalf("expected ErrInvalidJWTConfig, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssueAndVerifyTokenRoundTrip(t *testing.T) {
|
||||
now := time.Unix(1_700_000_000, 0).UTC()
|
||||
svc := mustJWTService(t)
|
||||
svc.now = func() time.Time { return now }
|
||||
|
||||
token, issuedClaims, err := svc.IssueToken("alice", []string{"admin", " viewer "}, []string{"cn=vctp-admins,dc=example,dc=com"})
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Fatal("expected non-empty token")
|
||||
}
|
||||
if issuedClaims.Subject != "alice" {
|
||||
t.Fatalf("expected subject alice, got %q", issuedClaims.Subject)
|
||||
}
|
||||
if issuedClaims.Issuer != "vctp" {
|
||||
t.Fatalf("expected issuer vctp, got %q", issuedClaims.Issuer)
|
||||
}
|
||||
if issuedClaims.Audience != "vctp-api" {
|
||||
t.Fatalf("expected audience vctp-api, got %q", issuedClaims.Audience)
|
||||
}
|
||||
if issuedClaims.IssuedAt != now.Unix() {
|
||||
t.Fatalf("unexpected iat: %d", issuedClaims.IssuedAt)
|
||||
}
|
||||
if issuedClaims.NotBefore != now.Unix() {
|
||||
t.Fatalf("unexpected nbf: %d", issuedClaims.NotBefore)
|
||||
}
|
||||
if issuedClaims.ExpiresAt != now.Add(2*time.Hour).Unix() {
|
||||
t.Fatalf("unexpected exp: %d", issuedClaims.ExpiresAt)
|
||||
}
|
||||
if issuedClaims.ID == "" {
|
||||
t.Fatal("expected jti to be populated")
|
||||
}
|
||||
|
||||
verifiedClaims, err := svc.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyToken returned error: %v", err)
|
||||
}
|
||||
if verifiedClaims.Subject != issuedClaims.Subject {
|
||||
t.Fatalf("subject mismatch: got %q want %q", verifiedClaims.Subject, issuedClaims.Subject)
|
||||
}
|
||||
if verifiedClaims.ID != issuedClaims.ID {
|
||||
t.Fatalf("jti mismatch: got %q want %q", verifiedClaims.ID, issuedClaims.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTokenRejectsInvalidSignature(t *testing.T) {
|
||||
svc := mustJWTService(t)
|
||||
svc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }
|
||||
|
||||
token, _, err := svc.IssueToken("alice", []string{"admin"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
|
||||
other := mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("a different secret key")))
|
||||
other.now = svc.now
|
||||
|
||||
_, err = other.VerifyToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("expected signature mismatch to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTToken) {
|
||||
t.Fatalf("expected ErrInvalidJWTToken, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTokenRejectsIssuerAndAudienceMismatch(t *testing.T) {
|
||||
issuerSvc := mustJWTService(t)
|
||||
issuerSvc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }
|
||||
token, _, err := issuerSvc.IssueToken("alice", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
|
||||
wrongIssuer, err := NewJWTService(JWTConfig{
|
||||
SigningKeyBase64: base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")),
|
||||
Issuer: "other-issuer",
|
||||
Audience: "vctp-api",
|
||||
TokenLifespan: 2 * time.Hour,
|
||||
ClockSkew: time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create verifier with wrong issuer: %v", err)
|
||||
}
|
||||
wrongIssuer.now = issuerSvc.now
|
||||
|
||||
_, err = wrongIssuer.VerifyToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("expected issuer mismatch to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTClaims) {
|
||||
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "issuer") {
|
||||
t.Fatalf("expected issuer mismatch error, got: %v", err)
|
||||
}
|
||||
|
||||
wrongAudience, err := NewJWTService(JWTConfig{
|
||||
SigningKeyBase64: base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")),
|
||||
Issuer: "vctp",
|
||||
Audience: "other-audience",
|
||||
TokenLifespan: 2 * time.Hour,
|
||||
ClockSkew: time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create verifier with wrong audience: %v", err)
|
||||
}
|
||||
wrongAudience.now = issuerSvc.now
|
||||
|
||||
_, err = wrongAudience.VerifyToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("expected audience mismatch to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTClaims) {
|
||||
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "audience") {
|
||||
t.Fatalf("expected audience mismatch error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTokenRejectsExpiredNotBeforeAndFutureIssuedAt(t *testing.T) {
|
||||
base := time.Unix(1_700_000_000, 0).UTC()
|
||||
svc := mustJWTService(t)
|
||||
svc.now = func() time.Time { return base }
|
||||
|
||||
token, claims, err := svc.IssueToken("alice", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
|
||||
svc.now = func() time.Time { return base.Add(3 * time.Hour) }
|
||||
_, err = svc.VerifyToken(token)
|
||||
if !errors.Is(err, ErrExpiredJWTToken) {
|
||||
t.Fatalf("expected ErrExpiredJWTToken, got: %v", err)
|
||||
}
|
||||
|
||||
notBeforeClaims := claims
|
||||
notBeforeClaims.NotBefore = base.Add(10 * time.Minute).Unix()
|
||||
notBeforeClaims.IssuedAt = base.Unix()
|
||||
notBeforeClaims.ExpiresAt = base.Add(2 * time.Hour).Unix()
|
||||
notBeforeClaims.ID = "forced-jti-1"
|
||||
notBeforeToken, err := encodeSignedJWT(notBeforeClaims, svc.signingKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create token with future nbf: %v", err)
|
||||
}
|
||||
svc.now = func() time.Time { return base }
|
||||
_, err = svc.VerifyToken(notBeforeToken)
|
||||
if !errors.Is(err, ErrNotYetValidJWTToken) {
|
||||
t.Fatalf("expected ErrNotYetValidJWTToken, got: %v", err)
|
||||
}
|
||||
|
||||
futureIatClaims := claims
|
||||
futureIatClaims.IssuedAt = base.Add(20 * time.Minute).Unix()
|
||||
futureIatClaims.NotBefore = base.Unix()
|
||||
futureIatClaims.ExpiresAt = base.Add(3 * time.Hour).Unix()
|
||||
futureIatClaims.ID = "forced-jti-2"
|
||||
futureIatToken, err := encodeSignedJWT(futureIatClaims, svc.signingKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create token with future iat: %v", err)
|
||||
}
|
||||
_, err = svc.VerifyToken(futureIatToken)
|
||||
if err == nil {
|
||||
t.Fatal("expected future iat validation to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTClaims) {
|
||||
t.Fatalf("expected ErrInvalidJWTClaims for future iat, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTokenRejectsMissingJTI(t *testing.T) {
|
||||
base := time.Unix(1_700_000_000, 0).UTC()
|
||||
svc := mustJWTService(t)
|
||||
svc.now = func() time.Time { return base }
|
||||
|
||||
token, claims, err := svc.IssueToken("alice", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken returned error: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Fatal("expected non-empty token")
|
||||
}
|
||||
|
||||
claims.ID = ""
|
||||
customToken, err := encodeSignedJWT(claims, svc.signingKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create token without jti: %v", err)
|
||||
}
|
||||
|
||||
_, err = svc.VerifyToken(customToken)
|
||||
if err == nil {
|
||||
t.Fatal("expected missing jti token to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidJWTClaims) {
|
||||
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "jti") {
|
||||
t.Fatalf("expected jti validation error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func mustJWTService(t *testing.T) *JWTService {
|
||||
t.Helper()
|
||||
return mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")))
|
||||
}
|
||||
|
||||
func mustJWTServiceWithKey(t *testing.T, keyBase64 string) *JWTService {
|
||||
t.Helper()
|
||||
svc, err := NewJWTService(JWTConfig{
|
||||
SigningKeyBase64: keyBase64,
|
||||
Issuer: "vctp",
|
||||
Audience: "vctp-api",
|
||||
TokenLifespan: 2 * time.Hour,
|
||||
ClockSkew: time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create jwt service: %v", err)
|
||||
}
|
||||
return svc
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidLDAPConfig = errors.New("invalid ldap config")
|
||||
ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials")
|
||||
ErrLDAPOperationFailed = errors.New("ldap operation failed")
|
||||
)
|
||||
|
||||
type LDAPConfig struct {
|
||||
BindAddress string
|
||||
BaseDN string
|
||||
TrustCertFile string
|
||||
DisableValidation bool
|
||||
Insecure bool
|
||||
DialTimeout time.Duration
|
||||
}
|
||||
|
||||
type LDAPIdentity struct {
|
||||
Username string
|
||||
UserDN string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
type LDAPAuthenticator struct {
|
||||
bindAddress string
|
||||
baseDN string
|
||||
trustCertFile string
|
||||
disableValidation bool
|
||||
insecure bool
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
|
||||
bindAddress := strings.TrimSpace(cfg.BindAddress)
|
||||
baseDN := strings.TrimSpace(cfg.BaseDN)
|
||||
trustCertFile := strings.TrimSpace(cfg.TrustCertFile)
|
||||
|
||||
if bindAddress == "" {
|
||||
return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig)
|
||||
}
|
||||
if baseDN == "" {
|
||||
return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig)
|
||||
}
|
||||
if _, err := url.ParseRequestURI(bindAddress); err != nil {
|
||||
return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
|
||||
dialTimeout := cfg.DialTimeout
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = 10 * time.Second
|
||||
}
|
||||
|
||||
return &LDAPAuthenticator{
|
||||
bindAddress: bindAddress,
|
||||
baseDN: baseDN,
|
||||
trustCertFile: trustCertFile,
|
||||
disableValidation: cfg.DisableValidation,
|
||||
insecure: cfg.Insecure,
|
||||
dialTimeout: dialTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" || password == "" {
|
||||
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
||||
}
|
||||
if err := ctxErr(ctx); err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
|
||||
conn, err := a.connect()
|
||||
if err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.Bind(username, password); err != nil {
|
||||
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
|
||||
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
||||
}
|
||||
return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if err := ctxErr(ctx); err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
|
||||
identity := LDAPIdentity{
|
||||
Username: username,
|
||||
UserDN: username,
|
||||
}
|
||||
|
||||
entry, err := a.lookupUserEntry(conn, username)
|
||||
if err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
if entry != nil {
|
||||
if strings.TrimSpace(entry.DN) != "" {
|
||||
identity.UserDN = entry.DN
|
||||
}
|
||||
if v := firstNonEmpty(
|
||||
entry.GetAttributeValue("uid"),
|
||||
entry.GetAttributeValue("sAMAccountName"),
|
||||
entry.GetAttributeValue("userPrincipalName"),
|
||||
entry.GetAttributeValue("cn"),
|
||||
); v != "" {
|
||||
identity.Username = v
|
||||
}
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
if entry != nil {
|
||||
for _, groupDN := range entry.GetAttributeValues("memberOf") {
|
||||
groupDN = strings.TrimSpace(groupDN)
|
||||
if groupDN == "" {
|
||||
continue
|
||||
}
|
||||
groupSet[groupDN] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
groupEntries, err := conn.Search(ldap.NewSearchRequest(
|
||||
a.baseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
fmt.Sprintf("(|(member=%s)(uniqueMember=%s)(memberUid=%s))",
|
||||
ldap.EscapeFilter(identity.UserDN),
|
||||
ldap.EscapeFilter(identity.UserDN),
|
||||
ldap.EscapeFilter(username),
|
||||
),
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
))
|
||||
if err == nil {
|
||||
for _, e := range groupEntries.Entries {
|
||||
if dn := strings.TrimSpace(e.DN); dn != "" {
|
||||
groupSet[dn] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
identity.Groups = mapKeysSorted(groupSet)
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string {
|
||||
if len(groupDNs) == 0 || len(groupRoleMappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalizedMappings := make(map[string]string, len(groupRoleMappings))
|
||||
for groupDN, role := range groupRoleMappings {
|
||||
groupDN = normalizeDN(groupDN)
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
if groupDN == "" || role == "" {
|
||||
continue
|
||||
}
|
||||
normalizedMappings[groupDN] = role
|
||||
}
|
||||
|
||||
roleSet := make(map[string]struct{})
|
||||
for _, groupDN := range groupDNs {
|
||||
if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok {
|
||||
roleSet[role] = struct{}{}
|
||||
}
|
||||
}
|
||||
return mapKeysSorted(roleSet)
|
||||
}
|
||||
|
||||
func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool {
|
||||
requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs)
|
||||
if len(requiredGroupDNs) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(groupDNs) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
required := make(map[string]struct{}, len(requiredGroupDNs))
|
||||
for _, groupDN := range requiredGroupDNs {
|
||||
required[normalizeDN(groupDN)] = struct{}{}
|
||||
}
|
||||
for _, groupDN := range groupDNs {
|
||||
if _, ok := required[normalizeDN(groupDN)]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
|
||||
tlsConfig, err := a.buildTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(a.bindAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
|
||||
options := []ldap.DialOpt{
|
||||
ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}),
|
||||
ldap.DialWithTLSConfig(tlsConfig),
|
||||
}
|
||||
conn, err := ldap.DialURL(a.bindAddress, options...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
conn.SetTimeout(a.dialTimeout)
|
||||
|
||||
// For ldap://, opportunistically upgrade to TLS unless explicitly configured as insecure.
|
||||
if parsedURL.Scheme == "ldap" && !a.insecure {
|
||||
if err := conn.StartTLS(tlsConfig); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags
|
||||
}
|
||||
|
||||
if a.trustCertFile == "" {
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
caPEM, err := os.ReadFile(a.trustCertFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
roots := x509.NewCertPool()
|
||||
if !roots.AppendCertsFromPEM(caPEM) {
|
||||
return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig)
|
||||
}
|
||||
tlsConfig.RootCAs = roots
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string) (*ldap.Entry, error) {
|
||||
if looksLikeDN(username) {
|
||||
searchRes, err := conn.Search(ldap.NewSearchRequest(
|
||||
username,
|
||||
ldap.ScopeBaseObject,
|
||||
ldap.NeverDerefAliases,
|
||||
1,
|
||||
0,
|
||||
false,
|
||||
"(objectClass=*)",
|
||||
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
|
||||
nil,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: unable to load user entry: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if len(searchRes.Entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return searchRes.Entries[0], nil
|
||||
}
|
||||
|
||||
searchRes, err := conn.Search(ldap.NewSearchRequest(
|
||||
a.baseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
2,
|
||||
0,
|
||||
false,
|
||||
fmt.Sprintf("(|(uid=%s)(cn=%s)(sAMAccountName=%s)(userPrincipalName=%s))",
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
),
|
||||
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
|
||||
nil,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: user lookup failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if len(searchRes.Entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return searchRes.Entries[0], nil
|
||||
}
|
||||
|
||||
func normalizeDN(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func mapKeysSorted[K ~string, V any](m map[K]V) []K {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]K, 0, len(m))
|
||||
for key := range m {
|
||||
out = append(out, key)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
return out[i] < out[j]
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func looksLikeDN(value string) bool {
|
||||
value = strings.TrimSpace(value)
|
||||
return strings.Contains(value, "=") && strings.Contains(value, ",")
|
||||
}
|
||||
|
||||
func ctxErr(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package auth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveRoles(t *testing.T) {
|
||||
roles := ResolveRoles(
|
||||
[]string{
|
||||
"cn=vctp-admins,ou=groups,dc=example,dc=com",
|
||||
" CN=VCTP-VIEWERS,OU=GROUPS,DC=EXAMPLE,DC=COM ",
|
||||
},
|
||||
map[string]string{
|
||||
"cn=vctp-admins,ou=groups,dc=example,dc=com": "admin",
|
||||
"cn=vctp-viewers,ou=groups,dc=example,dc=com": "viewer",
|
||||
},
|
||||
)
|
||||
|
||||
if len(roles) != 2 {
|
||||
t.Fatalf("expected 2 roles, got %d (%#v)", len(roles), roles)
|
||||
}
|
||||
if roles[0] != "admin" || roles[1] != "viewer" {
|
||||
t.Fatalf("unexpected resolved roles: %#v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasAnyGroup(t *testing.T) {
|
||||
groups := []string{
|
||||
"cn=vctp-admins,ou=groups,dc=example,dc=com",
|
||||
}
|
||||
|
||||
if !HasAnyGroup(groups, []string{" cn=vctp-admins,ou=groups,dc=example,dc=com "}) {
|
||||
t.Fatal("expected group intersection to match")
|
||||
}
|
||||
if HasAnyGroup(groups, []string{"cn=vctp-operators,ou=groups,dc=example,dc=com"}) {
|
||||
t.Fatal("expected no intersection")
|
||||
}
|
||||
if !HasAnyGroup(groups, nil) {
|
||||
t.Fatal("expected empty required groups to allow")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user