Files
Nathan Coad fb7e9bdca4
continuous-integration/drone/push Build is passing
dont include groups in JWT
2026-04-21 14:54:19 +10:00

294 lines
8.4 KiB
Go

package auth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
)
// JWT header values used by this package.
const (
	// jwtAlgHS256 is the only "alg" this package emits; VerifyToken rejects
	// every other value.
	jwtAlgHS256 = "HS256"
	// jwtTyp is the "typ" written on issued tokens; VerifyToken also accepts
	// tokens whose header omits "typ" entirely.
	jwtTyp = "JWT"
)
// Sentinel errors returned by this package. Most call sites wrap them with a
// more specific reason via %w, so callers should match with errors.Is.
var (
	// ErrInvalidJWTConfig is wrapped by every validation failure in
	// NewJWTService.
	ErrInvalidJWTConfig = errors.New("invalid jwt config")
	// ErrInvalidJWTToken is wrapped by structural failures while verifying:
	// malformed segments, bad base64/JSON, unsupported alg/typ, and signature
	// mismatch.
	ErrInvalidJWTToken = errors.New("invalid jwt token")
	// ErrInvalidJWTClaims is wrapped by claim-content failures such as a
	// missing sub/jti/iat/nbf/exp, an iss/aud mismatch, or inconsistent
	// timestamps.
	ErrInvalidJWTClaims = errors.New("invalid jwt claims")
	// ErrExpiredJWTToken is returned unwrapped when the token is past its exp
	// (allowing for the configured clock skew).
	ErrExpiredJWTToken = errors.New("jwt token expired")
	// ErrNotYetValidJWTToken is returned unwrapped when nbf is still in the
	// future (allowing for the configured clock skew).
	ErrNotYetValidJWTToken = errors.New("jwt token is not yet valid")
)
// JWTConfig holds the settings passed to NewJWTService.
type JWTConfig struct {
	// SigningKeyBase64 is the HMAC key, base64-encoded in the standard or URL
	// alphabet, with or without padding. It must decode to a non-empty key.
	SigningKeyBase64 string
	// Issuer is written to and required in the "iss" claim. It is trimmed
	// and must be non-blank.
	Issuer string
	// Audience is written to and required in the "aud" claim. It is trimmed
	// and must be non-blank.
	Audience string
	// TokenLifespan is the gap between "iat" and "exp" on issued tokens. It
	// must be positive.
	TokenLifespan time.Duration
	// ClockSkew is the tolerance applied to time-based claim checks. It must
	// not be negative, and sub-second parts are truncated when applied.
	ClockSkew time.Duration
}
// Claims is the JWT payload issued and accepted by JWTService. The JSON tags
// map the fields to the registered claim names (sub, iss, aud, iat, exp, nbf,
// jti) plus the custom roles/groups claims.
type Claims struct {
	// Subject ("sub") identifies the principal. It is required and
	// non-blank.
	Subject string `json:"sub"`
	// Roles carries authorization roles and is omitted from JSON when empty.
	Roles []string `json:"roles,omitempty"`
	// Groups is always left nil by IssueToken, so it is omitted from issued
	// tokens. VerifyToken still decodes it if an incoming payload carries it.
	Groups []string `json:"groups,omitempty"`
	// Issuer ("iss") must equal the service's configured issuer.
	Issuer string `json:"iss"`
	// Audience ("aud") is a single string that must equal the configured
	// audience. A payload whose aud is a JSON array fails to decode.
	Audience string `json:"aud"`
	// IssuedAt ("iat") is in Unix seconds.
	IssuedAt int64 `json:"iat"`
	// ExpiresAt ("exp") is in Unix seconds and must be greater than
	// IssuedAt.
	ExpiresAt int64 `json:"exp"`
	// NotBefore ("nbf") is in Unix seconds and must not exceed ExpiresAt.
	NotBefore int64 `json:"nbf"`
	// ID ("jti") is a per-token identifier. IssueToken fills it from
	// newTokenID.
	ID string `json:"jti"`
}
// JWTService issues and verifies HS256-signed JWTs. Construct it with
// NewJWTService; the zero value has no key, issuer or clock and is not usable.
type JWTService struct {
	// signingKey is the decoded HMAC-SHA256 key.
	signingKey []byte
	// issuer is the trimmed expected/emitted "iss" value.
	issuer string
	// audience is the trimmed expected/emitted "aud" value.
	audience string
	// tokenLifespan is added to the issue time to produce "exp".
	tokenLifespan time.Duration
	// clockSkew is the tolerance used by claim time checks.
	clockSkew time.Duration
	// now is the clock source. NewJWTService sets it to time.Now.
	// NOTE(review): presumably replaced in same-package tests; confirm.
	now func() time.Time
}
// jwtHeader is the JOSE header of a compact JWT. Only "alg" and "typ" are
// read or written; any other header members are ignored on decode.
type jwtHeader struct {
	Algorithm string `json:"alg"`
	Type      string `json:"typ"`
}
// NewJWTService validates cfg and returns a JWTService that issues and
// verifies HS256 tokens.
//
// Validation rules:
//   - Issuer and Audience are trimmed and must be non-blank.
//   - TokenLifespan must be at least one second.
//   - ClockSkew must not be negative.
//   - SigningKeyBase64 must decode (standard or URL alphabet, padded or raw)
//     to a non-empty key.
//
// Every validation failure wraps ErrInvalidJWTConfig.
func NewJWTService(cfg JWTConfig) (*JWTService, error) {
	issuer := strings.TrimSpace(cfg.Issuer)
	audience := strings.TrimSpace(cfg.Audience)
	if issuer == "" {
		return nil, fmt.Errorf("%w: issuer is required", ErrInvalidJWTConfig)
	}
	if audience == "" {
		return nil, fmt.Errorf("%w: audience is required", ErrInvalidJWTConfig)
	}
	if cfg.TokenLifespan <= 0 {
		return nil, fmt.Errorf("%w: token lifespan must be greater than zero", ErrInvalidJWTConfig)
	}
	// iat and exp are stored as whole Unix seconds. With a sub-second lifespan,
	// exp can truncate to the same second as iat. IssueToken would then fail
	// intermittently with "exp must be greater than iat", so reject it here.
	if cfg.TokenLifespan < time.Second {
		return nil, fmt.Errorf("%w: token lifespan must be at least one second", ErrInvalidJWTConfig)
	}
	if cfg.ClockSkew < 0 {
		return nil, fmt.Errorf("%w: clock skew cannot be negative", ErrInvalidJWTConfig)
	}
	signingKey, err := decodeBase64Key(strings.TrimSpace(cfg.SigningKeyBase64))
	if err != nil {
		return nil, fmt.Errorf("%w: signing key must be valid base64", ErrInvalidJWTConfig)
	}
	if len(signingKey) == 0 {
		return nil, fmt.Errorf("%w: signing key cannot be empty", ErrInvalidJWTConfig)
	}
	// NOTE(review): RFC 7518 §3.2 requires HS256 keys of at least 256 bits
	// (32 bytes). This is not enforced, to avoid rejecting existing
	// deployments; consider adding the check.
	return &JWTService{
		signingKey:    signingKey,
		issuer:        issuer,
		audience:      audience,
		tokenLifespan: cfg.TokenLifespan,
		clockSkew:     cfg.ClockSkew,
		now:           time.Now,
	}, nil
}
// IssueToken mints a signed HS256 token for subject.
//
// Claims set on the token:
//   - sub: subject after trimming. A blank subject yields an error wrapping
//     ErrInvalidJWTClaims.
//   - roles: the role list with whitespace trimmed and blank entries dropped.
//   - iat and nbf: the current time.
//   - exp: the current time plus the configured lifespan.
//   - jti: a fresh random ID.
//
// The groups argument is accepted but deliberately not embedded: LDAP groups
// are left out of JWTs because the role claims are sufficient for
// authorization.
//
// On success it returns the compact token together with the claims it
// carries. On failure both results are zero.
func (s *JWTService) IssueToken(subject string, roles []string, groups []string) (string, Claims, error) {
	sub := strings.TrimSpace(subject)
	if sub == "" {
		return "", Claims{}, fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
	}

	issuedAt := s.now().UTC()
	issuedUnix := issuedAt.Unix()

	// Groups is left at its nil zero value on purpose (see doc comment), so
	// "omitempty" drops it from the encoded payload.
	claims := Claims{
		Subject:   sub,
		Roles:     compactTrimmedStrings(roles),
		Issuer:    s.issuer,
		Audience:  s.audience,
		IssuedAt:  issuedUnix,
		ExpiresAt: issuedAt.Add(s.tokenLifespan).Unix(),
		NotBefore: issuedUnix,
		ID:        newTokenID(),
	}

	// Run the same checks a verifier would, so a misconfigured service fails
	// at issue time rather than producing tokens nobody can verify.
	if err := validateClaims(claims, issuedAt, s.issuer, s.audience, s.clockSkew); err != nil {
		return "", Claims{}, err
	}

	signed, err := encodeSignedJWT(claims, s.signingKey)
	if err != nil {
		return "", Claims{}, err
	}
	return signed, claims, nil
}
// VerifyToken checks a compact JWT and returns its claims on success.
//
// Checks, in order:
//  1. The token must parse into three base64url/JSON segments.
//  2. The header must declare alg "HS256".
//  3. The header typ must be absent or "JWT".
//  4. The HMAC-SHA256 signature must match. It is compared with hmac.Equal,
//     which is constant-time.
//  5. The claims must pass validateClaims against this service's issuer,
//     audience, clock and skew.
//
// Steps 1-4 fail with an error wrapping ErrInvalidJWTToken. Step 5 returns
// validateClaims' error unchanged. On failure the returned Claims is zero.
func (s *JWTService) VerifyToken(token string) (Claims, error) {
	header, claims, signedPart, presentedSig, err := parseJWT(token)
	if err != nil {
		return Claims{}, err
	}

	switch {
	case header.Algorithm != jwtAlgHS256:
		return Claims{}, fmt.Errorf("%w: unsupported algorithm", ErrInvalidJWTToken)
	case header.Type != "" && header.Type != jwtTyp:
		return Claims{}, fmt.Errorf("%w: invalid token type", ErrInvalidJWTToken)
	case !hmac.Equal(presentedSig, signPayload(signedPart, s.signingKey)):
		return Claims{}, fmt.Errorf("%w: signature mismatch", ErrInvalidJWTToken)
	}

	// Claims are only trusted (and the clock only consulted) after the
	// signature has been verified.
	if err := validateClaims(claims, s.now().UTC(), s.issuer, s.audience, s.clockSkew); err != nil {
		return Claims{}, err
	}
	return claims, nil
}
// validateClaims checks three things about claims:
//   - they are structurally complete;
//   - they were minted for expectedIssuer and expectedAudience;
//   - they are inside their validity window as seen at now.
//
// clockSkew is applied in whole seconds (sub-second skew is truncated):
//   - iat and nbf may lie up to clockSkew in the future;
//   - the token stays valid while now < exp + clockSkew.
//
// Time-window failures return ErrNotYetValidJWTToken or ErrExpiredJWTToken
// unwrapped. Every other problem returns an error wrapping
// ErrInvalidJWTClaims. It returns nil when the claims are acceptable.
func validateClaims(claims Claims, now time.Time, expectedIssuer string, expectedAudience string, clockSkew time.Duration) error {
	if strings.TrimSpace(claims.Subject) == "" {
		return fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
	}
	if strings.TrimSpace(claims.ID) == "" {
		return fmt.Errorf("%w: jti is required", ErrInvalidJWTClaims)
	}
	if claims.Issuer != expectedIssuer {
		return fmt.Errorf("%w: issuer mismatch", ErrInvalidJWTClaims)
	}
	if claims.Audience != expectedAudience {
		return fmt.Errorf("%w: audience mismatch", ErrInvalidJWTClaims)
	}
	if claims.IssuedAt <= 0 {
		return fmt.Errorf("%w: iat is required", ErrInvalidJWTClaims)
	}
	if claims.NotBefore <= 0 {
		return fmt.Errorf("%w: nbf is required", ErrInvalidJWTClaims)
	}
	if claims.ExpiresAt <= 0 {
		return fmt.Errorf("%w: exp is required", ErrInvalidJWTClaims)
	}
	if claims.ExpiresAt <= claims.IssuedAt {
		return fmt.Errorf("%w: exp must be greater than iat", ErrInvalidJWTClaims)
	}
	if claims.NotBefore > claims.ExpiresAt {
		return fmt.Errorf("%w: nbf cannot be greater than exp", ErrInvalidJWTClaims)
	}

	unixNow := now.Unix()
	skewSeconds := int64(clockSkew / time.Second)
	if claims.IssuedAt > unixNow+skewSeconds {
		return fmt.Errorf("%w: iat is in the future", ErrInvalidJWTClaims)
	}
	if claims.NotBefore > unixNow+skewSeconds {
		return ErrNotYetValidJWTToken
	}
	// RFC 7519 §4.1.4 requires the current time to be before exp, so the
	// boundary second itself counts as expired. Subtracting the skew from now,
	// rather than adding it to the token-supplied exp, also rules out signed
	// overflow for very large exp values.
	if unixNow-skewSeconds >= claims.ExpiresAt {
		return ErrExpiredJWTToken
	}
	return nil
}
// encodeSignedJWT serializes claims as a compact HS256 JWT of the form
// base64url(header) "." base64url(payload) "." base64url(signature).
// All segments use unpadded base64url. The signature is the HMAC-SHA256 of
// the first two segments joined by ".".
// JSON marshalling failures are returned wrapped with the failing part.
func encodeSignedJWT(claims Claims, signingKey []byte) (string, error) {
	b64 := base64.RawURLEncoding

	rawHeader, err := json.Marshal(jwtHeader{Algorithm: jwtAlgHS256, Type: jwtTyp})
	if err != nil {
		return "", fmt.Errorf("marshal jwt header: %w", err)
	}
	rawClaims, err := json.Marshal(claims)
	if err != nil {
		return "", fmt.Errorf("marshal jwt claims: %w", err)
	}

	unsigned := b64.EncodeToString(rawHeader) + "." + b64.EncodeToString(rawClaims)
	mac := b64.EncodeToString(signPayload(unsigned, signingKey))
	return unsigned + "." + mac, nil
}
// parseJWT splits a compact JWT into its decoded parts without verifying it.
//
// It returns, in order:
//   - the decoded header;
//   - the decoded claims;
//   - the signing input (the first two raw segments joined by ".");
//   - the raw signature bytes.
//
// Any malformed segment, bad base64url, or bad JSON yields an error wrapping
// ErrInvalidJWTToken and zero values for everything else. The payload is
// decoded before the signature has been checked, so callers must verify the
// signature before trusting the claims.
func parseJWT(token string) (jwtHeader, Claims, string, []byte, error) {
	reject := func(reason string) (jwtHeader, Claims, string, []byte, error) {
		return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: %s", ErrInvalidJWTToken, reason)
	}

	segments := strings.Split(token, ".")
	if len(segments) != 3 {
		return reject("malformed token")
	}
	for _, seg := range segments {
		if seg == "" {
			return reject("malformed token")
		}
	}
	headerSeg, payloadSeg, sigSeg := segments[0], segments[1], segments[2]

	b64 := base64.RawURLEncoding
	rawHeader, err := b64.DecodeString(headerSeg)
	if err != nil {
		return reject("invalid header encoding")
	}
	rawPayload, err := b64.DecodeString(payloadSeg)
	if err != nil {
		return reject("invalid payload encoding")
	}
	sig, err := b64.DecodeString(sigSeg)
	if err != nil {
		return reject("invalid signature encoding")
	}

	var hdr jwtHeader
	if json.Unmarshal(rawHeader, &hdr) != nil {
		return reject("invalid header json")
	}
	var parsed Claims
	if json.Unmarshal(rawPayload, &parsed) != nil {
		return reject("invalid claims json")
	}
	return hdr, parsed, headerSeg + "." + payloadSeg, sig, nil
}
// signPayload returns the 32-byte HMAC-SHA256 of payload under signingKey.
func signPayload(payload string, signingKey []byte) []byte {
	h := hmac.New(sha256.New, signingKey)
	// hash.Hash.Write is documented never to return an error.
	_, _ = h.Write([]byte(payload))
	return h.Sum(nil)
}
// newTokenID returns a fresh token identifier for the "jti" claim.
//
// Normally this is 16 bytes from crypto/rand rendered as 32 lowercase hex
// characters. If the random source reports an error, it degrades to
// "fallback-<unix nanos>" rather than failing token issuance. That fallback
// is predictable and not guaranteed unique.
func newTokenID() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err == nil {
		return hex.EncodeToString(buf[:])
	}
	return fmt.Sprintf("fallback-%d", time.Now().UTC().UnixNano())
}
// decodeBase64Key decodes value by trying four encodings in order and
// returning the first successful result:
//  1. standard padded;
//  2. standard raw;
//  3. URL-safe padded;
//  4. URL-safe raw.
//
// An empty string decodes to an empty key. If no encoding accepts the input,
// a generic error is returned.
func decodeBase64Key(value string) ([]byte, error) {
	candidates := [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	}
	for i := range candidates {
		key, err := candidates[i].DecodeString(value)
		if err != nil {
			continue
		}
		return key, nil
	}
	return nil, errors.New("invalid base64 encoding")
}
// compactTrimmedStrings returns the entries of values with surrounding
// whitespace removed, skipping any that are blank after trimming. Order is
// preserved.
//
// The result is nil when nothing survives. The JSON "omitempty" fields that
// use it treat nil and empty alike; a nil result also keeps the field absent.
func compactTrimmedStrings(values []string) []string {
	var kept []string
	for _, v := range values {
		t := strings.TrimSpace(v)
		if t == "" {
			continue
		}
		if kept == nil {
			// Allocate lazily, sized for the worst case, only once something
			// is actually kept.
			kept = make([]string, 0, len(values))
		}
		kept = append(kept, t)
	}
	return kept
}