@@ -0,0 +1,292 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
jwtAlgHS256 = "HS256"
|
||||
jwtTyp = "JWT"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidJWTConfig = errors.New("invalid jwt config")
|
||||
ErrInvalidJWTToken = errors.New("invalid jwt token")
|
||||
ErrInvalidJWTClaims = errors.New("invalid jwt claims")
|
||||
ErrExpiredJWTToken = errors.New("jwt token expired")
|
||||
ErrNotYetValidJWTToken = errors.New("jwt token is not yet valid")
|
||||
)
|
||||
|
||||
type JWTConfig struct {
|
||||
SigningKeyBase64 string
|
||||
Issuer string
|
||||
Audience string
|
||||
TokenLifespan time.Duration
|
||||
ClockSkew time.Duration
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Subject string `json:"sub"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
Issuer string `json:"iss"`
|
||||
Audience string `json:"aud"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
NotBefore int64 `json:"nbf"`
|
||||
ID string `json:"jti"`
|
||||
}
|
||||
|
||||
type JWTService struct {
|
||||
signingKey []byte
|
||||
issuer string
|
||||
audience string
|
||||
tokenLifespan time.Duration
|
||||
clockSkew time.Duration
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type jwtHeader struct {
|
||||
Algorithm string `json:"alg"`
|
||||
Type string `json:"typ"`
|
||||
}
|
||||
|
||||
func NewJWTService(cfg JWTConfig) (*JWTService, error) {
|
||||
issuer := strings.TrimSpace(cfg.Issuer)
|
||||
audience := strings.TrimSpace(cfg.Audience)
|
||||
if issuer == "" {
|
||||
return nil, fmt.Errorf("%w: issuer is required", ErrInvalidJWTConfig)
|
||||
}
|
||||
if audience == "" {
|
||||
return nil, fmt.Errorf("%w: audience is required", ErrInvalidJWTConfig)
|
||||
}
|
||||
if cfg.TokenLifespan <= 0 {
|
||||
return nil, fmt.Errorf("%w: token lifespan must be greater than zero", ErrInvalidJWTConfig)
|
||||
}
|
||||
if cfg.ClockSkew < 0 {
|
||||
return nil, fmt.Errorf("%w: clock skew cannot be negative", ErrInvalidJWTConfig)
|
||||
}
|
||||
|
||||
signingKey, err := decodeBase64Key(strings.TrimSpace(cfg.SigningKeyBase64))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: signing key must be valid base64", ErrInvalidJWTConfig)
|
||||
}
|
||||
if len(signingKey) == 0 {
|
||||
return nil, fmt.Errorf("%w: signing key cannot be empty", ErrInvalidJWTConfig)
|
||||
}
|
||||
|
||||
return &JWTService{
|
||||
signingKey: signingKey,
|
||||
issuer: issuer,
|
||||
audience: audience,
|
||||
tokenLifespan: cfg.TokenLifespan,
|
||||
clockSkew: cfg.ClockSkew,
|
||||
now: time.Now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *JWTService) IssueToken(subject string, roles []string, groups []string) (string, Claims, error) {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return "", Claims{}, fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
claims := Claims{
|
||||
Subject: subject,
|
||||
Roles: compactTrimmedStrings(roles),
|
||||
Groups: compactTrimmedStrings(groups),
|
||||
Issuer: s.issuer,
|
||||
Audience: s.audience,
|
||||
IssuedAt: now.Unix(),
|
||||
ExpiresAt: now.Add(s.tokenLifespan).Unix(),
|
||||
NotBefore: now.Unix(),
|
||||
ID: newTokenID(),
|
||||
}
|
||||
if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
|
||||
return "", Claims{}, err
|
||||
}
|
||||
|
||||
token, err := encodeSignedJWT(claims, s.signingKey)
|
||||
if err != nil {
|
||||
return "", Claims{}, err
|
||||
}
|
||||
return token, claims, nil
|
||||
}
|
||||
|
||||
func (s *JWTService) VerifyToken(token string) (Claims, error) {
|
||||
header, claims, signingInput, signature, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return Claims{}, err
|
||||
}
|
||||
if header.Algorithm != jwtAlgHS256 {
|
||||
return Claims{}, fmt.Errorf("%w: unsupported algorithm", ErrInvalidJWTToken)
|
||||
}
|
||||
if header.Type != "" && header.Type != jwtTyp {
|
||||
return Claims{}, fmt.Errorf("%w: invalid token type", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
expected := signPayload(signingInput, s.signingKey)
|
||||
if !hmac.Equal(signature, expected) {
|
||||
return Claims{}, fmt.Errorf("%w: signature mismatch", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
|
||||
return Claims{}, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func validateClaims(claims Claims, now time.Time, expectedIssuer string, expectedAudience string, clockSkew time.Duration) error {
|
||||
if strings.TrimSpace(claims.Subject) == "" {
|
||||
return fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if strings.TrimSpace(claims.ID) == "" {
|
||||
return fmt.Errorf("%w: jti is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.Issuer != expectedIssuer {
|
||||
return fmt.Errorf("%w: issuer mismatch", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.Audience != expectedAudience {
|
||||
return fmt.Errorf("%w: audience mismatch", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.IssuedAt <= 0 {
|
||||
return fmt.Errorf("%w: iat is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore <= 0 {
|
||||
return fmt.Errorf("%w: nbf is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.ExpiresAt <= 0 {
|
||||
return fmt.Errorf("%w: exp is required", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.ExpiresAt <= claims.IssuedAt {
|
||||
return fmt.Errorf("%w: exp must be greater than iat", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore > claims.ExpiresAt {
|
||||
return fmt.Errorf("%w: nbf cannot be greater than exp", ErrInvalidJWTClaims)
|
||||
}
|
||||
|
||||
unixNow := now.Unix()
|
||||
skewSeconds := int64(clockSkew / time.Second)
|
||||
if claims.IssuedAt > unixNow+skewSeconds {
|
||||
return fmt.Errorf("%w: iat is in the future", ErrInvalidJWTClaims)
|
||||
}
|
||||
if claims.NotBefore > unixNow+skewSeconds {
|
||||
return ErrNotYetValidJWTToken
|
||||
}
|
||||
if unixNow > claims.ExpiresAt+skewSeconds {
|
||||
return ErrExpiredJWTToken
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeSignedJWT(claims Claims, signingKey []byte) (string, error) {
|
||||
headerJSON, err := json.Marshal(jwtHeader{Algorithm: jwtAlgHS256, Type: jwtTyp})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal jwt header: %w", err)
|
||||
}
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal jwt claims: %w", err)
|
||||
}
|
||||
|
||||
headerPart := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
payloadPart := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signingInput := headerPart + "." + payloadPart
|
||||
signature := signPayload(signingInput, signingKey)
|
||||
signaturePart := base64.RawURLEncoding.EncodeToString(signature)
|
||||
|
||||
return signingInput + "." + signaturePart, nil
|
||||
}
|
||||
|
||||
func parseJWT(token string) (jwtHeader, Claims, string, []byte, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
|
||||
}
|
||||
if parts[0] == "" || parts[1] == "" || parts[2] == "" {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid payload encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid signature encoding", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
var header jwtHeader
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header json", ErrInvalidJWTToken)
|
||||
}
|
||||
var claims Claims
|
||||
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
||||
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid claims json", ErrInvalidJWTToken)
|
||||
}
|
||||
|
||||
return header, claims, parts[0] + "." + parts[1], signature, nil
|
||||
}
|
||||
|
||||
func signPayload(payload string, signingKey []byte) []byte {
|
||||
mac := hmac.New(sha256.New, signingKey)
|
||||
mac.Write([]byte(payload))
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func newTokenID() string {
|
||||
raw := make([]byte, 16)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return fmt.Sprintf("fallback-%d", time.Now().UTC().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(raw)
|
||||
}
|
||||
|
||||
func decodeBase64Key(value string) ([]byte, error) {
|
||||
encodings := []*base64.Encoding{
|
||||
base64.StdEncoding,
|
||||
base64.RawStdEncoding,
|
||||
base64.URLEncoding,
|
||||
base64.RawURLEncoding,
|
||||
}
|
||||
for _, encoding := range encodings {
|
||||
decoded, err := encoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("invalid base64 encoding")
|
||||
}
|
||||
|
||||
func compactTrimmedStrings(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user