package auth

import (
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"
)

const (
	jwtAlgHS256 = "HS256"
	jwtTyp      = "JWT"
)

var (
	// ErrInvalidJWTConfig is wrapped by NewJWTService when the configuration
	// is rejected, and by JWTService methods invoked on a service that was not
	// built with NewJWTService.
	ErrInvalidJWTConfig = errors.New("invalid jwt config")

	// ErrInvalidJWTToken is wrapped when a token is malformed, carries an
	// unsupported header, or fails signature verification.
	ErrInvalidJWTToken = errors.New("invalid jwt token")

	// ErrInvalidJWTClaims is wrapped when required claims are missing, the
	// issuer or audience does not match, or the timestamps are inconsistent
	// or lie in the future.
	ErrInvalidJWTClaims = errors.New("invalid jwt claims")

	// ErrExpiredJWTToken is returned when the current time is at or after exp
	// plus the configured clock skew.
	ErrExpiredJWTToken = errors.New("jwt token expired")

	// ErrNotYetValidJWTToken is returned when nbf is later than the current
	// time plus the configured clock skew.
	ErrNotYetValidJWTToken = errors.New("jwt token is not yet valid")
)

// jwtSegmentEncoding decodes token segments canonically: unpadded base64url
// whose unused trailing bits must be zero, so a given signature has exactly
// one accepted textual form.
var jwtSegmentEncoding = base64.RawURLEncoding.Strict()

// JWTConfig configures a JWTService.
type JWTConfig struct {
	// SigningKeyBase64 is the HMAC-SHA256 key, base64-encoded with either the
	// standard or URL-safe alphabet, padded or unpadded. Surrounding
	// whitespace is ignored; the decoded key must not be empty.
	SigningKeyBase64 string

	// Issuer is written to, and required in, the "iss" claim. Surrounding
	// whitespace is trimmed.
	Issuer string

	// Audience is written to, and required in, the "aud" claim. Surrounding
	// whitespace is trimmed.
	Audience string

	// TokenLifespan is how long issued tokens remain valid. It must be at
	// least one second because JWT timestamps have one-second resolution.
	TokenLifespan time.Duration

	// ClockSkew is the tolerance applied to the iat, nbf and exp checks. It
	// must not be negative and is truncated to whole seconds.
	ClockSkew time.Duration
}

// Claims is the claim set carried by tokens from JWTService. IssuedAt,
// ExpiresAt and NotBefore are Unix timestamps in seconds. IssueToken never
// populates Groups.
type Claims struct {
	Subject   string   `json:"sub"`
	Roles     []string `json:"roles,omitempty"`
	Groups    []string `json:"groups,omitempty"`
	Issuer    string   `json:"iss"`
	Audience  string   `json:"aud"`
	IssuedAt  int64    `json:"iat"`
	ExpiresAt int64    `json:"exp"`
	NotBefore int64    `json:"nbf"`
	ID        string   `json:"jti"`
}

// JWTService issues and verifies HS256 JWTs for a single issuer and audience.
// Construct it with NewJWTService; methods on a zero-value service return an
// error wrapping ErrInvalidJWTConfig.
type JWTService struct {
	signingKey    []byte
	issuer        string
	audience      string
	tokenLifespan time.Duration
	clockSkew     time.Duration
	now           func() time.Time
}

type jwtHeader struct {
	Algorithm string `json:"alg"`
	Type      string `json:"typ"`
}

// NewJWTService validates cfg and returns a service that signs and verifies
// with the decoded key. Every validation failure wraps ErrInvalidJWTConfig.
func NewJWTService(cfg JWTConfig) (*JWTService, error) {
	issuer := strings.TrimSpace(cfg.Issuer)
	audience := strings.TrimSpace(cfg.Audience)
	if issuer == "" {
		return nil, fmt.Errorf("%w: issuer is required", ErrInvalidJWTConfig)
	}
	if audience == "" {
		return nil, fmt.Errorf("%w: audience is required", ErrInvalidJWTConfig)
	}
	if cfg.TokenLifespan <= 0 {
		return nil, fmt.Errorf("%w: token lifespan must be greater than zero", ErrInvalidJWTConfig)
	}
	// iat and exp are whole Unix seconds, so a sub-second lifespan can round
	// exp down to iat, which validateClaims rejects. Fail here, once, instead
	// of intermittently at issue time.
	if cfg.TokenLifespan < time.Second {
		return nil, fmt.Errorf("%w: token lifespan must be at least one second", ErrInvalidJWTConfig)
	}
	if cfg.ClockSkew < 0 {
		return nil, fmt.Errorf("%w: clock skew cannot be negative", ErrInvalidJWTConfig)
	}
	signingKey, err := decodeBase64Key(strings.TrimSpace(cfg.SigningKeyBase64))
	if err != nil {
		// The decoder error is deliberately not wrapped so key material never
		// ends up in error text.
		return nil, fmt.Errorf("%w: signing key must be valid base64", ErrInvalidJWTConfig)
	}
	if len(signingKey) == 0 {
		return nil, fmt.Errorf("%w: signing key cannot be empty", ErrInvalidJWTConfig)
	}
	return &JWTService{
		signingKey:    signingKey,
		issuer:        issuer,
		audience:      audience,
		tokenLifespan: cfg.TokenLifespan,
		clockSkew:     cfg.ClockSkew,
		now:           time.Now,
	}, nil
}

// IssueToken signs a token for subject carrying the given roles and returns
// the compact token together with the claims it encodes. The subject is
// trimmed and must not be blank. Role names are trimmed and blank ones
// dropped. groups is ignored: LDAP groups are intentionally omitted from JWTs
// because role claims are sufficient for authorization.
func (s *JWTService) IssueToken(subject string, roles []string, groups []string) (string, Claims, error) {
	if err := s.checkReady(); err != nil {
		return "", Claims{}, err
	}
	subject = strings.TrimSpace(subject)
	if subject == "" {
		return "", Claims{}, fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
	}
	tokenID, err := newTokenID()
	if err != nil {
		return "", Claims{}, err
	}

	now := s.now().UTC()
	// Groups stays nil on purpose (see the doc comment); with omitempty it is
	// left out of the encoded payload.
	claims := Claims{
		Subject:   subject,
		Roles:     compactTrimmedStrings(roles),
		Groups:    nil,
		Issuer:    s.issuer,
		Audience:  s.audience,
		IssuedAt:  now.Unix(),
		ExpiresAt: now.Add(s.tokenLifespan).Unix(),
		NotBefore: now.Unix(),
		ID:        tokenID,
	}
	// Self-check with the same rules VerifyToken applies, so a token that
	// could never verify is never handed out.
	if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
		return "", Claims{}, err
	}

	token, err := encodeSignedJWT(claims, s.signingKey)
	if err != nil {
		return "", Claims{}, err
	}
	return token, claims, nil
}

// VerifyToken parses token, checks its header and HMAC-SHA256 signature, and
// then validates the claims against the service's issuer, audience and clock.
// Errors wrap ErrInvalidJWTToken or ErrInvalidJWTClaims, or are exactly
// ErrExpiredJWTToken / ErrNotYetValidJWTToken.
func (s *JWTService) VerifyToken(token string) (Claims, error) {
	if err := s.checkReady(); err != nil {
		return Claims{}, err
	}
	header, claims, signingInput, signature, err := parseJWT(token)
	if err != nil {
		return Claims{}, err
	}
	// Only HS256 is accepted, regardless of what the token's header asks for.
	if header.Algorithm != jwtAlgHS256 {
		return Claims{}, fmt.Errorf("%w: unsupported algorithm", ErrInvalidJWTToken)
	}
	if header.Type != "" && header.Type != jwtTyp {
		return Claims{}, fmt.Errorf("%w: invalid token type", ErrInvalidJWTToken)
	}

	expected := signPayload(signingInput, s.signingKey)
	// hmac.Equal compares without leaking timing information about how much
	// of the signature matched.
	if !hmac.Equal(signature, expected) {
		return Claims{}, fmt.Errorf("%w: signature mismatch", ErrInvalidJWTToken)
	}

	// The claims are only trusted from here on, after the signature checks out.
	now := s.now().UTC()
	if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
		return Claims{}, err
	}
	return claims, nil
}

// checkReady returns an error wrapping ErrInvalidJWTConfig unless s carries a
// signing key and a clock, as set by NewJWTService. Without this guard, a
// zero-value or partially initialized service could panic on a nil clock or
// sign with an empty key.
func (s *JWTService) checkReady() error {
	if s == nil || len(s.signingKey) == 0 || s.now == nil {
		return fmt.Errorf("%w: service is not initialized; use NewJWTService", ErrInvalidJWTConfig)
	}
	return nil
}

// validateClaims checks that the required claims are present, that iss and aud
// match the expected values, and that the timestamps are consistent with each
// other and with now (allowing clockSkew in each direction).
func validateClaims(claims Claims, now time.Time, expectedIssuer string, expectedAudience string, clockSkew time.Duration) error {
	if strings.TrimSpace(claims.Subject) == "" {
		return fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
	}
	if strings.TrimSpace(claims.ID) == "" {
		return fmt.Errorf("%w: jti is required", ErrInvalidJWTClaims)
	}
	if claims.Issuer != expectedIssuer {
		return fmt.Errorf("%w: issuer mismatch", ErrInvalidJWTClaims)
	}
	if claims.Audience != expectedAudience {
		return fmt.Errorf("%w: audience mismatch", ErrInvalidJWTClaims)
	}
	if claims.IssuedAt <= 0 {
		return fmt.Errorf("%w: iat is required", ErrInvalidJWTClaims)
	}
	if claims.NotBefore <= 0 {
		return fmt.Errorf("%w: nbf is required", ErrInvalidJWTClaims)
	}
	if claims.ExpiresAt <= 0 {
		return fmt.Errorf("%w: exp is required", ErrInvalidJWTClaims)
	}
	if claims.ExpiresAt <= claims.IssuedAt {
		return fmt.Errorf("%w: exp must be greater than iat", ErrInvalidJWTClaims)
	}
	if claims.NotBefore > claims.ExpiresAt {
		return fmt.Errorf("%w: nbf cannot be greater than exp", ErrInvalidJWTClaims)
	}

	unixNow := now.Unix()
	// JWT timestamps are whole seconds, so the skew is applied in whole
	// seconds too; any sub-second remainder is dropped.
	skewSeconds := int64(clockSkew / time.Second)
	if claims.IssuedAt > unixNow+skewSeconds {
		return fmt.Errorf("%w: iat is in the future", ErrInvalidJWTClaims)
	}
	if claims.NotBefore > unixNow+skewSeconds {
		return ErrNotYetValidJWTToken
	}
	// RFC 7519 §4.1.4: the current time MUST be before exp, so a token is
	// already expired at exactly exp (plus skew).
	if unixNow >= claims.ExpiresAt+skewSeconds {
		return ErrExpiredJWTToken
	}
	return nil
}

// encodeSignedJWT serializes claims under an HS256/JWT header and returns the
// compact form header.payload.signature, each part unpadded base64url.
func encodeSignedJWT(claims Claims, signingKey []byte) (string, error) {
	headerJSON, err := json.Marshal(jwtHeader{Algorithm: jwtAlgHS256, Type: jwtTyp})
	if err != nil {
		return "", fmt.Errorf("marshal jwt header: %w", err)
	}
	claimsJSON, err := json.Marshal(claims)
	if err != nil {
		return "", fmt.Errorf("marshal jwt claims: %w", err)
	}

	headerPart := base64.RawURLEncoding.EncodeToString(headerJSON)
	payloadPart := base64.RawURLEncoding.EncodeToString(claimsJSON)
	signingInput := headerPart + "." + payloadPart
	signature := signPayload(signingInput, signingKey)
	signaturePart := base64.RawURLEncoding.EncodeToString(signature)
	return signingInput + "." + signaturePart, nil
}

// parseJWT splits a compact token into its decoded header, decoded claims, the
// exact signing input ("header.payload" as received), and the raw signature
// bytes. It does NOT verify the signature or the claims; callers must.
// Errors wrap ErrInvalidJWTToken.
func parseJWT(token string) (jwtHeader, Claims, string, []byte, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
	}
	if parts[0] == "" || parts[1] == "" || parts[2] == "" {
		return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
	}

	headerBytes, err := decodeSegment(parts[0])
	if err != nil {
		return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header encoding", ErrInvalidJWTToken)
	}
	payloadBytes, err := decodeSegment(parts[1])
	if err != nil {
		return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid payload encoding", ErrInvalidJWTToken)
	}
	signature, err := decodeSegment(parts[2])
	if err != nil {
		return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid signature encoding", ErrInvalidJWTToken)
	}

	var header jwtHeader
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header json", ErrInvalidJWTToken)
	}
	var claims Claims
	if err := json.Unmarshal(payloadBytes, &claims); err != nil {
		return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid claims json", ErrInvalidJWTToken)
	}
	return header, claims, parts[0] + "." + parts[1], signature, nil
}

// decodeSegment decodes one dot-separated token segment canonically. CR and
// LF are rejected explicitly because the base64 decoder skips them even in
// strict mode. Otherwise a signature could be re-encoded without failing
// verification.
func decodeSegment(segment string) ([]byte, error) {
	if strings.ContainsAny(segment, "\r\n") {
		return nil, errors.New("segment contains line breaks")
	}
	return jwtSegmentEncoding.DecodeString(segment)
}

// signPayload returns the HMAC-SHA256 of payload under signingKey.
func signPayload(payload string, signingKey []byte) []byte {
	mac := hmac.New(sha256.New, signingKey)
	// hash.Hash.Write never returns an error, so the result is not checked.
	mac.Write([]byte(payload))
	return mac.Sum(nil)
}

// newTokenID returns a random 128-bit token identifier as 32 hex characters.
// It fails rather than falling back to a timestamp, which would be predictable
// and could collide across processes.
func newTokenID() (string, error) {
	raw := make([]byte, 16)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("generate jwt id: %w", err)
	}
	return hex.EncodeToString(raw), nil
}

// decodeBase64Key decodes value using the first of standard, raw standard,
// URL-safe, or raw URL-safe base64 that accepts it.
func decodeBase64Key(value string) ([]byte, error) {
	encodings := []*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	}
	for _, encoding := range encodings {
		decoded, err := encoding.DecodeString(value)
		if err == nil {
			return decoded, nil
		}
	}
	return nil, errors.New("invalid base64 encoding")
}

// compactTrimmedStrings returns values with surrounding whitespace trimmed and
// blank entries removed, preserving order. It returns nil when nothing is left.
func compactTrimmedStrings(values []string) []string {
	if len(values) == 0 {
		return nil
	}
	out := make([]string, 0, len(values))
	for _, value := range values {
		trimmed := strings.TrimSpace(value)
		if trimmed == "" {
			continue
		}
		out = append(out, trimmed)
	}
	if len(out) == 0 {
		return nil
	}
	return out
}