add auth support
continuous-integration/drone/push Build is passing

This commit is contained in:
2026-04-17 13:19:08 +10:00
parent 9a561f3b07
commit ae3e2be89a
22 changed files with 2479 additions and 40 deletions
+292
View File
@@ -0,0 +1,292 @@
package auth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
)
const (
jwtAlgHS256 = "HS256"
jwtTyp = "JWT"
)
var (
ErrInvalidJWTConfig = errors.New("invalid jwt config")
ErrInvalidJWTToken = errors.New("invalid jwt token")
ErrInvalidJWTClaims = errors.New("invalid jwt claims")
ErrExpiredJWTToken = errors.New("jwt token expired")
ErrNotYetValidJWTToken = errors.New("jwt token is not yet valid")
)
type JWTConfig struct {
SigningKeyBase64 string
Issuer string
Audience string
TokenLifespan time.Duration
ClockSkew time.Duration
}
type Claims struct {
Subject string `json:"sub"`
Roles []string `json:"roles,omitempty"`
Groups []string `json:"groups,omitempty"`
Issuer string `json:"iss"`
Audience string `json:"aud"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp"`
NotBefore int64 `json:"nbf"`
ID string `json:"jti"`
}
type JWTService struct {
signingKey []byte
issuer string
audience string
tokenLifespan time.Duration
clockSkew time.Duration
now func() time.Time
}
type jwtHeader struct {
Algorithm string `json:"alg"`
Type string `json:"typ"`
}
func NewJWTService(cfg JWTConfig) (*JWTService, error) {
issuer := strings.TrimSpace(cfg.Issuer)
audience := strings.TrimSpace(cfg.Audience)
if issuer == "" {
return nil, fmt.Errorf("%w: issuer is required", ErrInvalidJWTConfig)
}
if audience == "" {
return nil, fmt.Errorf("%w: audience is required", ErrInvalidJWTConfig)
}
if cfg.TokenLifespan <= 0 {
return nil, fmt.Errorf("%w: token lifespan must be greater than zero", ErrInvalidJWTConfig)
}
if cfg.ClockSkew < 0 {
return nil, fmt.Errorf("%w: clock skew cannot be negative", ErrInvalidJWTConfig)
}
signingKey, err := decodeBase64Key(strings.TrimSpace(cfg.SigningKeyBase64))
if err != nil {
return nil, fmt.Errorf("%w: signing key must be valid base64", ErrInvalidJWTConfig)
}
if len(signingKey) == 0 {
return nil, fmt.Errorf("%w: signing key cannot be empty", ErrInvalidJWTConfig)
}
return &JWTService{
signingKey: signingKey,
issuer: issuer,
audience: audience,
tokenLifespan: cfg.TokenLifespan,
clockSkew: cfg.ClockSkew,
now: time.Now,
}, nil
}
func (s *JWTService) IssueToken(subject string, roles []string, groups []string) (string, Claims, error) {
subject = strings.TrimSpace(subject)
if subject == "" {
return "", Claims{}, fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
}
now := s.now().UTC()
claims := Claims{
Subject: subject,
Roles: compactTrimmedStrings(roles),
Groups: compactTrimmedStrings(groups),
Issuer: s.issuer,
Audience: s.audience,
IssuedAt: now.Unix(),
ExpiresAt: now.Add(s.tokenLifespan).Unix(),
NotBefore: now.Unix(),
ID: newTokenID(),
}
if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
return "", Claims{}, err
}
token, err := encodeSignedJWT(claims, s.signingKey)
if err != nil {
return "", Claims{}, err
}
return token, claims, nil
}
func (s *JWTService) VerifyToken(token string) (Claims, error) {
header, claims, signingInput, signature, err := parseJWT(token)
if err != nil {
return Claims{}, err
}
if header.Algorithm != jwtAlgHS256 {
return Claims{}, fmt.Errorf("%w: unsupported algorithm", ErrInvalidJWTToken)
}
if header.Type != "" && header.Type != jwtTyp {
return Claims{}, fmt.Errorf("%w: invalid token type", ErrInvalidJWTToken)
}
expected := signPayload(signingInput, s.signingKey)
if !hmac.Equal(signature, expected) {
return Claims{}, fmt.Errorf("%w: signature mismatch", ErrInvalidJWTToken)
}
now := s.now().UTC()
if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil {
return Claims{}, err
}
return claims, nil
}
func validateClaims(claims Claims, now time.Time, expectedIssuer string, expectedAudience string, clockSkew time.Duration) error {
if strings.TrimSpace(claims.Subject) == "" {
return fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims)
}
if strings.TrimSpace(claims.ID) == "" {
return fmt.Errorf("%w: jti is required", ErrInvalidJWTClaims)
}
if claims.Issuer != expectedIssuer {
return fmt.Errorf("%w: issuer mismatch", ErrInvalidJWTClaims)
}
if claims.Audience != expectedAudience {
return fmt.Errorf("%w: audience mismatch", ErrInvalidJWTClaims)
}
if claims.IssuedAt <= 0 {
return fmt.Errorf("%w: iat is required", ErrInvalidJWTClaims)
}
if claims.NotBefore <= 0 {
return fmt.Errorf("%w: nbf is required", ErrInvalidJWTClaims)
}
if claims.ExpiresAt <= 0 {
return fmt.Errorf("%w: exp is required", ErrInvalidJWTClaims)
}
if claims.ExpiresAt <= claims.IssuedAt {
return fmt.Errorf("%w: exp must be greater than iat", ErrInvalidJWTClaims)
}
if claims.NotBefore > claims.ExpiresAt {
return fmt.Errorf("%w: nbf cannot be greater than exp", ErrInvalidJWTClaims)
}
unixNow := now.Unix()
skewSeconds := int64(clockSkew / time.Second)
if claims.IssuedAt > unixNow+skewSeconds {
return fmt.Errorf("%w: iat is in the future", ErrInvalidJWTClaims)
}
if claims.NotBefore > unixNow+skewSeconds {
return ErrNotYetValidJWTToken
}
if unixNow > claims.ExpiresAt+skewSeconds {
return ErrExpiredJWTToken
}
return nil
}
func encodeSignedJWT(claims Claims, signingKey []byte) (string, error) {
headerJSON, err := json.Marshal(jwtHeader{Algorithm: jwtAlgHS256, Type: jwtTyp})
if err != nil {
return "", fmt.Errorf("marshal jwt header: %w", err)
}
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal jwt claims: %w", err)
}
headerPart := base64.RawURLEncoding.EncodeToString(headerJSON)
payloadPart := base64.RawURLEncoding.EncodeToString(claimsJSON)
signingInput := headerPart + "." + payloadPart
signature := signPayload(signingInput, signingKey)
signaturePart := base64.RawURLEncoding.EncodeToString(signature)
return signingInput + "." + signaturePart, nil
}
func parseJWT(token string) (jwtHeader, Claims, string, []byte, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
}
if parts[0] == "" || parts[1] == "" || parts[2] == "" {
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken)
}
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header encoding", ErrInvalidJWTToken)
}
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid payload encoding", ErrInvalidJWTToken)
}
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid signature encoding", ErrInvalidJWTToken)
}
var header jwtHeader
if err := json.Unmarshal(headerBytes, &header); err != nil {
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header json", ErrInvalidJWTToken)
}
var claims Claims
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid claims json", ErrInvalidJWTToken)
}
return header, claims, parts[0] + "." + parts[1], signature, nil
}
func signPayload(payload string, signingKey []byte) []byte {
mac := hmac.New(sha256.New, signingKey)
mac.Write([]byte(payload))
return mac.Sum(nil)
}
func newTokenID() string {
raw := make([]byte, 16)
if _, err := rand.Read(raw); err != nil {
return fmt.Sprintf("fallback-%d", time.Now().UTC().UnixNano())
}
return hex.EncodeToString(raw)
}
func decodeBase64Key(value string) ([]byte, error) {
encodings := []*base64.Encoding{
base64.StdEncoding,
base64.RawStdEncoding,
base64.URLEncoding,
base64.RawURLEncoding,
}
for _, encoding := range encodings {
decoded, err := encoding.DecodeString(value)
if err == nil {
return decoded, nil
}
}
return nil, errors.New("invalid base64 encoding")
}
func compactTrimmedStrings(values []string) []string {
if len(values) == 0 {
return nil
}
out := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
out = append(out, trimmed)
}
if len(out) == 0 {
return nil
}
return out
}
+247
View File
@@ -0,0 +1,247 @@
package auth
import (
"encoding/base64"
"errors"
"strings"
"testing"
"time"
)
func TestNewJWTServiceRejectsBadConfig(t *testing.T) {
_, err := NewJWTService(JWTConfig{
SigningKeyBase64: "!!!",
Issuer: "vctp",
Audience: "vctp-api",
TokenLifespan: time.Hour,
ClockSkew: time.Minute,
})
if err == nil {
t.Fatal("expected invalid base64 signing key to fail")
}
if !errors.Is(err, ErrInvalidJWTConfig) {
t.Fatalf("expected ErrInvalidJWTConfig, got: %v", err)
}
}
func TestIssueAndVerifyTokenRoundTrip(t *testing.T) {
now := time.Unix(1_700_000_000, 0).UTC()
svc := mustJWTService(t)
svc.now = func() time.Time { return now }
token, issuedClaims, err := svc.IssueToken("alice", []string{"admin", " viewer "}, []string{"cn=vctp-admins,dc=example,dc=com"})
if err != nil {
t.Fatalf("IssueToken returned error: %v", err)
}
if token == "" {
t.Fatal("expected non-empty token")
}
if issuedClaims.Subject != "alice" {
t.Fatalf("expected subject alice, got %q", issuedClaims.Subject)
}
if issuedClaims.Issuer != "vctp" {
t.Fatalf("expected issuer vctp, got %q", issuedClaims.Issuer)
}
if issuedClaims.Audience != "vctp-api" {
t.Fatalf("expected audience vctp-api, got %q", issuedClaims.Audience)
}
if issuedClaims.IssuedAt != now.Unix() {
t.Fatalf("unexpected iat: %d", issuedClaims.IssuedAt)
}
if issuedClaims.NotBefore != now.Unix() {
t.Fatalf("unexpected nbf: %d", issuedClaims.NotBefore)
}
if issuedClaims.ExpiresAt != now.Add(2*time.Hour).Unix() {
t.Fatalf("unexpected exp: %d", issuedClaims.ExpiresAt)
}
if issuedClaims.ID == "" {
t.Fatal("expected jti to be populated")
}
verifiedClaims, err := svc.VerifyToken(token)
if err != nil {
t.Fatalf("VerifyToken returned error: %v", err)
}
if verifiedClaims.Subject != issuedClaims.Subject {
t.Fatalf("subject mismatch: got %q want %q", verifiedClaims.Subject, issuedClaims.Subject)
}
if verifiedClaims.ID != issuedClaims.ID {
t.Fatalf("jti mismatch: got %q want %q", verifiedClaims.ID, issuedClaims.ID)
}
}
func TestVerifyTokenRejectsInvalidSignature(t *testing.T) {
svc := mustJWTService(t)
svc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }
token, _, err := svc.IssueToken("alice", []string{"admin"}, nil)
if err != nil {
t.Fatalf("IssueToken returned error: %v", err)
}
other := mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("a different secret key")))
other.now = svc.now
_, err = other.VerifyToken(token)
if err == nil {
t.Fatal("expected signature mismatch to fail")
}
if !errors.Is(err, ErrInvalidJWTToken) {
t.Fatalf("expected ErrInvalidJWTToken, got: %v", err)
}
}
func TestVerifyTokenRejectsIssuerAndAudienceMismatch(t *testing.T) {
issuerSvc := mustJWTService(t)
issuerSvc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }
token, _, err := issuerSvc.IssueToken("alice", nil, nil)
if err != nil {
t.Fatalf("IssueToken returned error: %v", err)
}
wrongIssuer, err := NewJWTService(JWTConfig{
SigningKeyBase64: base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")),
Issuer: "other-issuer",
Audience: "vctp-api",
TokenLifespan: 2 * time.Hour,
ClockSkew: time.Minute,
})
if err != nil {
t.Fatalf("failed to create verifier with wrong issuer: %v", err)
}
wrongIssuer.now = issuerSvc.now
_, err = wrongIssuer.VerifyToken(token)
if err == nil {
t.Fatal("expected issuer mismatch to fail")
}
if !errors.Is(err, ErrInvalidJWTClaims) {
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
}
if !strings.Contains(strings.ToLower(err.Error()), "issuer") {
t.Fatalf("expected issuer mismatch error, got: %v", err)
}
wrongAudience, err := NewJWTService(JWTConfig{
SigningKeyBase64: base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")),
Issuer: "vctp",
Audience: "other-audience",
TokenLifespan: 2 * time.Hour,
ClockSkew: time.Minute,
})
if err != nil {
t.Fatalf("failed to create verifier with wrong audience: %v", err)
}
wrongAudience.now = issuerSvc.now
_, err = wrongAudience.VerifyToken(token)
if err == nil {
t.Fatal("expected audience mismatch to fail")
}
if !errors.Is(err, ErrInvalidJWTClaims) {
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
}
if !strings.Contains(strings.ToLower(err.Error()), "audience") {
t.Fatalf("expected audience mismatch error, got: %v", err)
}
}
func TestVerifyTokenRejectsExpiredNotBeforeAndFutureIssuedAt(t *testing.T) {
base := time.Unix(1_700_000_000, 0).UTC()
svc := mustJWTService(t)
svc.now = func() time.Time { return base }
token, claims, err := svc.IssueToken("alice", nil, nil)
if err != nil {
t.Fatalf("IssueToken returned error: %v", err)
}
svc.now = func() time.Time { return base.Add(3 * time.Hour) }
_, err = svc.VerifyToken(token)
if !errors.Is(err, ErrExpiredJWTToken) {
t.Fatalf("expected ErrExpiredJWTToken, got: %v", err)
}
notBeforeClaims := claims
notBeforeClaims.NotBefore = base.Add(10 * time.Minute).Unix()
notBeforeClaims.IssuedAt = base.Unix()
notBeforeClaims.ExpiresAt = base.Add(2 * time.Hour).Unix()
notBeforeClaims.ID = "forced-jti-1"
notBeforeToken, err := encodeSignedJWT(notBeforeClaims, svc.signingKey)
if err != nil {
t.Fatalf("failed to create token with future nbf: %v", err)
}
svc.now = func() time.Time { return base }
_, err = svc.VerifyToken(notBeforeToken)
if !errors.Is(err, ErrNotYetValidJWTToken) {
t.Fatalf("expected ErrNotYetValidJWTToken, got: %v", err)
}
futureIatClaims := claims
futureIatClaims.IssuedAt = base.Add(20 * time.Minute).Unix()
futureIatClaims.NotBefore = base.Unix()
futureIatClaims.ExpiresAt = base.Add(3 * time.Hour).Unix()
futureIatClaims.ID = "forced-jti-2"
futureIatToken, err := encodeSignedJWT(futureIatClaims, svc.signingKey)
if err != nil {
t.Fatalf("failed to create token with future iat: %v", err)
}
_, err = svc.VerifyToken(futureIatToken)
if err == nil {
t.Fatal("expected future iat validation to fail")
}
if !errors.Is(err, ErrInvalidJWTClaims) {
t.Fatalf("expected ErrInvalidJWTClaims for future iat, got: %v", err)
}
}
func TestVerifyTokenRejectsMissingJTI(t *testing.T) {
base := time.Unix(1_700_000_000, 0).UTC()
svc := mustJWTService(t)
svc.now = func() time.Time { return base }
token, claims, err := svc.IssueToken("alice", nil, nil)
if err != nil {
t.Fatalf("IssueToken returned error: %v", err)
}
if token == "" {
t.Fatal("expected non-empty token")
}
claims.ID = ""
customToken, err := encodeSignedJWT(claims, svc.signingKey)
if err != nil {
t.Fatalf("failed to create token without jti: %v", err)
}
_, err = svc.VerifyToken(customToken)
if err == nil {
t.Fatal("expected missing jti token to fail")
}
if !errors.Is(err, ErrInvalidJWTClaims) {
t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err)
}
if !strings.Contains(strings.ToLower(err.Error()), "jti") {
t.Fatalf("expected jti validation error, got: %v", err)
}
}
func mustJWTService(t *testing.T) *JWTService {
t.Helper()
return mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")))
}
func mustJWTServiceWithKey(t *testing.T, keyBase64 string) *JWTService {
t.Helper()
svc, err := NewJWTService(JWTConfig{
SigningKeyBase64: keyBase64,
Issuer: "vctp",
Audience: "vctp-api",
TokenLifespan: 2 * time.Hour,
ClockSkew: time.Minute,
})
if err != nil {
t.Fatalf("failed to create jwt service: %v", err)
}
return svc
}
+354
View File
@@ -0,0 +1,354 @@
package auth
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/url"
"os"
"sort"
"strings"
"time"
"github.com/go-ldap/ldap/v3"
)
var (
ErrInvalidLDAPConfig = errors.New("invalid ldap config")
ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials")
ErrLDAPOperationFailed = errors.New("ldap operation failed")
)
type LDAPConfig struct {
BindAddress string
BaseDN string
TrustCertFile string
DisableValidation bool
Insecure bool
DialTimeout time.Duration
}
type LDAPIdentity struct {
Username string
UserDN string
Groups []string
}
type LDAPAuthenticator struct {
bindAddress string
baseDN string
trustCertFile string
disableValidation bool
insecure bool
dialTimeout time.Duration
}
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
bindAddress := strings.TrimSpace(cfg.BindAddress)
baseDN := strings.TrimSpace(cfg.BaseDN)
trustCertFile := strings.TrimSpace(cfg.TrustCertFile)
if bindAddress == "" {
return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig)
}
if baseDN == "" {
return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig)
}
if _, err := url.ParseRequestURI(bindAddress); err != nil {
return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err)
}
dialTimeout := cfg.DialTimeout
if dialTimeout <= 0 {
dialTimeout = 10 * time.Second
}
return &LDAPAuthenticator{
bindAddress: bindAddress,
baseDN: baseDN,
trustCertFile: trustCertFile,
disableValidation: cfg.DisableValidation,
insecure: cfg.Insecure,
dialTimeout: dialTimeout,
}, nil
}
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
username = strings.TrimSpace(username)
if username == "" || password == "" {
return LDAPIdentity{}, ErrLDAPInvalidCredentials
}
if err := ctxErr(ctx); err != nil {
return LDAPIdentity{}, err
}
conn, err := a.connect()
if err != nil {
return LDAPIdentity{}, err
}
defer conn.Close()
if err := conn.Bind(username, password); err != nil {
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
return LDAPIdentity{}, ErrLDAPInvalidCredentials
}
return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err)
}
if err := ctxErr(ctx); err != nil {
return LDAPIdentity{}, err
}
identity := LDAPIdentity{
Username: username,
UserDN: username,
}
entry, err := a.lookupUserEntry(conn, username)
if err != nil {
return LDAPIdentity{}, err
}
if entry != nil {
if strings.TrimSpace(entry.DN) != "" {
identity.UserDN = entry.DN
}
if v := firstNonEmpty(
entry.GetAttributeValue("uid"),
entry.GetAttributeValue("sAMAccountName"),
entry.GetAttributeValue("userPrincipalName"),
entry.GetAttributeValue("cn"),
); v != "" {
identity.Username = v
}
}
groupSet := make(map[string]struct{})
if entry != nil {
for _, groupDN := range entry.GetAttributeValues("memberOf") {
groupDN = strings.TrimSpace(groupDN)
if groupDN == "" {
continue
}
groupSet[groupDN] = struct{}{}
}
}
groupEntries, err := conn.Search(ldap.NewSearchRequest(
a.baseDN,
ldap.ScopeWholeSubtree,
ldap.NeverDerefAliases,
0,
0,
false,
fmt.Sprintf("(|(member=%s)(uniqueMember=%s)(memberUid=%s))",
ldap.EscapeFilter(identity.UserDN),
ldap.EscapeFilter(identity.UserDN),
ldap.EscapeFilter(username),
),
[]string{"dn"},
nil,
))
if err == nil {
for _, e := range groupEntries.Entries {
if dn := strings.TrimSpace(e.DN); dn != "" {
groupSet[dn] = struct{}{}
}
}
}
identity.Groups = mapKeysSorted(groupSet)
return identity, nil
}
func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string {
if len(groupDNs) == 0 || len(groupRoleMappings) == 0 {
return nil
}
normalizedMappings := make(map[string]string, len(groupRoleMappings))
for groupDN, role := range groupRoleMappings {
groupDN = normalizeDN(groupDN)
role = strings.ToLower(strings.TrimSpace(role))
if groupDN == "" || role == "" {
continue
}
normalizedMappings[groupDN] = role
}
roleSet := make(map[string]struct{})
for _, groupDN := range groupDNs {
if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok {
roleSet[role] = struct{}{}
}
}
return mapKeysSorted(roleSet)
}
func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool {
requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs)
if len(requiredGroupDNs) == 0 {
return true
}
if len(groupDNs) == 0 {
return false
}
required := make(map[string]struct{}, len(requiredGroupDNs))
for _, groupDN := range requiredGroupDNs {
required[normalizeDN(groupDN)] = struct{}{}
}
for _, groupDN := range groupDNs {
if _, ok := required[normalizeDN(groupDN)]; ok {
return true
}
}
return false
}
func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
tlsConfig, err := a.buildTLSConfig()
if err != nil {
return nil, err
}
parsedURL, err := url.Parse(a.bindAddress)
if err != nil {
return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err)
}
options := []ldap.DialOpt{
ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}),
ldap.DialWithTLSConfig(tlsConfig),
}
conn, err := ldap.DialURL(a.bindAddress, options...)
if err != nil {
return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err)
}
conn.SetTimeout(a.dialTimeout)
// For ldap://, opportunistically upgrade to TLS unless explicitly configured as insecure.
if parsedURL.Scheme == "ldap" && !a.insecure {
if err := conn.StartTLS(tlsConfig); err != nil {
conn.Close()
return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err)
}
}
return conn, nil
}
func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags
}
if a.trustCertFile == "" {
return tlsConfig, nil
}
caPEM, err := os.ReadFile(a.trustCertFile)
if err != nil {
return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err)
}
roots := x509.NewCertPool()
if !roots.AppendCertsFromPEM(caPEM) {
return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig)
}
tlsConfig.RootCAs = roots
return tlsConfig, nil
}
func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string) (*ldap.Entry, error) {
if looksLikeDN(username) {
searchRes, err := conn.Search(ldap.NewSearchRequest(
username,
ldap.ScopeBaseObject,
ldap.NeverDerefAliases,
1,
0,
false,
"(objectClass=*)",
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
nil,
))
if err != nil {
return nil, fmt.Errorf("%w: unable to load user entry: %v", ErrLDAPOperationFailed, err)
}
if len(searchRes.Entries) == 0 {
return nil, nil
}
return searchRes.Entries[0], nil
}
searchRes, err := conn.Search(ldap.NewSearchRequest(
a.baseDN,
ldap.ScopeWholeSubtree,
ldap.NeverDerefAliases,
2,
0,
false,
fmt.Sprintf("(|(uid=%s)(cn=%s)(sAMAccountName=%s)(userPrincipalName=%s))",
ldap.EscapeFilter(username),
ldap.EscapeFilter(username),
ldap.EscapeFilter(username),
ldap.EscapeFilter(username),
),
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
nil,
))
if err != nil {
return nil, fmt.Errorf("%w: user lookup failed: %v", ErrLDAPOperationFailed, err)
}
if len(searchRes.Entries) == 0 {
return nil, nil
}
return searchRes.Entries[0], nil
}
func normalizeDN(value string) string {
return strings.ToLower(strings.TrimSpace(value))
}
func mapKeysSorted[K ~string, V any](m map[K]V) []K {
if len(m) == 0 {
return nil
}
out := make([]K, 0, len(m))
for key := range m {
out = append(out, key)
}
sort.Slice(out, func(i, j int) bool {
return out[i] < out[j]
})
return out
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
value = strings.TrimSpace(value)
if value != "" {
return value
}
}
return ""
}
func looksLikeDN(value string) bool {
value = strings.TrimSpace(value)
return strings.Contains(value, "=") && strings.Contains(value, ",")
}
func ctxErr(ctx context.Context) error {
if ctx == nil {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
+39
View File
@@ -0,0 +1,39 @@
package auth
import "testing"
func TestResolveRoles(t *testing.T) {
roles := ResolveRoles(
[]string{
"cn=vctp-admins,ou=groups,dc=example,dc=com",
" CN=VCTP-VIEWERS,OU=GROUPS,DC=EXAMPLE,DC=COM ",
},
map[string]string{
"cn=vctp-admins,ou=groups,dc=example,dc=com": "admin",
"cn=vctp-viewers,ou=groups,dc=example,dc=com": "viewer",
},
)
if len(roles) != 2 {
t.Fatalf("expected 2 roles, got %d (%#v)", len(roles), roles)
}
if roles[0] != "admin" || roles[1] != "viewer" {
t.Fatalf("unexpected resolved roles: %#v", roles)
}
}
func TestHasAnyGroup(t *testing.T) {
groups := []string{
"cn=vctp-admins,ou=groups,dc=example,dc=com",
}
if !HasAnyGroup(groups, []string{" cn=vctp-admins,ou=groups,dc=example,dc=com "}) {
t.Fatal("expected group intersection to match")
}
if HasAnyGroup(groups, []string{"cn=vctp-operators,ou=groups,dc=example,dc=com"}) {
t.Fatal("expected no intersection")
}
if !HasAnyGroup(groups, nil) {
t.Fatal("expected empty required groups to allow")
}
}