package auth

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"net"
	"net/url"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/go-ldap/ldap/v3"
)

// Sentinel errors wrapped by the functions in this file; match them with errors.Is.
var (
	// ErrInvalidLDAPConfig is wrapped when the configuration (bind address, base DN,
	// trust certificate file) is missing or unusable.
	ErrInvalidLDAPConfig = errors.New("invalid ldap config")

	// ErrLDAPInvalidCredentials is returned for empty credentials and wrapped when
	// the server rejects the bind with an invalid-credentials result.
	ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials")

	// ErrLDAPOperationFailed is wrapped when connecting, binding, StartTLS, or a
	// user lookup fails for any other reason.
	ErrLDAPOperationFailed = errors.New("ldap operation failed")
)

// LDAPConfig carries the settings consumed by NewLDAPAuthenticator.
type LDAPConfig struct {
	// BindAddress is the URL of the LDAP server. It is required.
	BindAddress string

	// BaseDN is the search base for user and group lookups. It is required.
	BaseDN string

	// TrustCertFile is an optional path to a PEM file whose certificates replace
	// the default root CAs used for TLS verification.
	TrustCertFile string

	// DisableValidation turns off TLS certificate verification.
	DisableValidation bool

	// Insecure turns off TLS certificate verification and also skips the StartTLS
	// upgrade on ldap:// connections.
	Insecure bool

	// DialTimeout bounds connection establishment and is also applied as the
	// connection's request timeout. Values <= 0 fall back to 10 seconds.
	DialTimeout time.Duration
}

// LDAPIdentity describes an authenticated user as resolved from the directory.
type LDAPIdentity struct {
	// Username is the login name, replaced by the first non-empty of the entry's
	// uid, sAMAccountName, userPrincipalName, or cn when a user entry is found.
	Username string

	// UserDN is the best known DN of the user: the entry DN, else the WhoAmI DN,
	// else the name that was used to bind.
	UserDN string

	// Groups holds the sorted, de-duplicated DNs of the groups found for the user.
	Groups []string

	// Diagnostics contains non-sensitive LDAP processing notes useful for debugging auth decisions.
	Diagnostics []string
}

// LDAPAuthenticator authenticates users against an LDAP server by binding with
// their credentials. Construct it with NewLDAPAuthenticator.
type LDAPAuthenticator struct {
	bindAddress       string
	baseDN            string
	trustCertFile     string
	disableValidation bool
	insecure          bool
	dialTimeout       time.Duration
}

// NewLDAPAuthenticator validates cfg and returns an authenticator for it.
//
// String settings are trimmed. The bind address and base DN must be non-empty
// and the bind address must parse with url.ParseRequestURI; otherwise an error
// wrapping ErrInvalidLDAPConfig is returned. A non-positive DialTimeout is
// replaced by 10 seconds. No connection is opened here.
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
	bindAddress := strings.TrimSpace(cfg.BindAddress)
	baseDN := strings.TrimSpace(cfg.BaseDN)
	trustCertFile := strings.TrimSpace(cfg.TrustCertFile)

	if bindAddress == "" {
		return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig)
	}
	if baseDN == "" {
		return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig)
	}
	if _, err := url.ParseRequestURI(bindAddress); err != nil {
		return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err)
	}

	dialTimeout := cfg.DialTimeout
	if dialTimeout <= 0 {
		dialTimeout = 10 * time.Second
	}

	return &LDAPAuthenticator{
		bindAddress:       bindAddress,
		baseDN:            baseDN,
		trustCertFile:     trustCertFile,
		disableValidation: cfg.DisableValidation,
		insecure:          cfg.Insecure,
		dialTimeout:       dialTimeout,
	}, nil
}

// AuthenticateAndFetchGroups binds to the directory as username/password and,
// on success, resolves the user's entry and group memberships.
//
// Empty credentials are rejected with ErrLDAPInvalidCredentials before any
// network traffic; this also keeps an empty password from being sent as a bind.
// A bind rejected with the invalid-credentials result code wraps
// ErrLDAPInvalidCredentials; connection, other bind, and user-lookup failures
// wrap ErrLDAPOperationFailed. WhoAmI and group-search failures are not fatal:
// they are recorded in the returned identity's Diagnostics. ctx is polled
// between the network phases and its error is returned once it is done; the
// individual LDAP calls themselves are bounded by the dial timeout, not by ctx.
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
	username = strings.TrimSpace(username)
	if username == "" || password == "" {
		return LDAPIdentity{}, ErrLDAPInvalidCredentials
	}
	if err := ctxErr(ctx); err != nil {
		return LDAPIdentity{}, err
	}

	conn, err := a.connect()
	if err != nil {
		return LDAPIdentity{}, err
	}
	defer conn.Close()

	if err := conn.Bind(username, password); err != nil {
		if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
			return LDAPIdentity{}, fmt.Errorf("%w: ldap bind rejected credentials", ErrLDAPInvalidCredentials)
		}
		return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err)
	}
	if err := ctxErr(ctx); err != nil {
		return LDAPIdentity{}, err
	}

	// Start from the bind name; better values replace it as they are discovered.
	identity := LDAPIdentity{
		Username: username,
		UserDN:   username,
	}

	// Ask the server who we are bound as. Failure here is only a diagnostic.
	if whoami, err := conn.WhoAmI(nil); err != nil {
		identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("whoami_failed:%v", err))
	} else if boundDN := parseWhoAmIDN(whoami.AuthzID); boundDN != "" {
		identity.UserDN = boundDN
		identity.Diagnostics = append(identity.Diagnostics, "whoami_dn_resolved")
	} else if strings.TrimSpace(whoami.AuthzID) == "" {
		identity.Diagnostics = append(identity.Diagnostics, "whoami_dn_empty")
	} else {
		identity.Diagnostics = append(identity.Diagnostics, "whoami_non_dn_authzid")
	}

	entry, lookupStrategy, err := a.lookupUserEntry(conn, username, identity.UserDN)
	if err != nil {
		return LDAPIdentity{}, err
	}
	if entry != nil {
		if lookupStrategy == "" {
			lookupStrategy = "unknown"
		}
		identity.Diagnostics = append(identity.Diagnostics, "user_entry_found:"+lookupStrategy)
		if strings.TrimSpace(entry.DN) != "" {
			identity.UserDN = entry.DN
		}
		if v := firstNonEmpty(
			entry.GetAttributeValue("uid"),
			entry.GetAttributeValue("sAMAccountName"),
			entry.GetAttributeValue("userPrincipalName"),
			entry.GetAttributeValue("cn"),
		); v != "" {
			identity.Username = v
		}
	} else {
		identity.Diagnostics = append(identity.Diagnostics, "user_entry_not_found")
	}

	// Groups come from two sources: the entry's memberOf values and a subtree
	// search for group entries that list the user as a member.
	groupSet := make(map[string]struct{})
	if entry != nil {
		for _, groupDN := range entry.GetAttributeValues("memberOf") {
			groupDN = strings.TrimSpace(groupDN)
			if groupDN == "" {
				continue
			}
			groupSet[groupDN] = struct{}{}
		}
	}

	// Honor cancellation before starting the subtree group search.
	if err := ctxErr(ctx); err != nil {
		return LDAPIdentity{}, err
	}

	groupFilter := buildGroupMembershipFilter(identity.UserDN, principalCandidates(username))
	groupEntries, err := conn.Search(ldap.NewSearchRequest(
		a.baseDN,
		ldap.ScopeWholeSubtree,
		ldap.NeverDerefAliases,
		0,
		0,
		false,
		groupFilter,
		[]string{"dn"},
		nil,
	))
	if err == nil {
		for _, e := range groupEntries.Entries {
			if dn := strings.TrimSpace(e.DN); dn != "" {
				groupSet[dn] = struct{}{}
			}
		}
		if len(groupEntries.Entries) == 0 {
			identity.Diagnostics = append(identity.Diagnostics, "group_search_returned_no_entries")
		}
	} else {
		// Best effort: memberOf values gathered above are still returned.
		identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("group_search_failed:%v", err))
	}

	identity.Groups = mapKeysSorted(groupSet)
	identity.Diagnostics = compactTrimmedStrings(identity.Diagnostics)
	return identity, nil
}

// ResolveRoles maps group DNs to roles using groupRoleMappings (group DN -> role).
//
// DNs on both sides are compared after trimming and lowercasing; roles are
// trimmed and lowercased, and mappings with an empty DN or role are ignored.
// The result is sorted and de-duplicated; it is nil when nothing matches or
// either input is empty.
func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string {
	if len(groupDNs) == 0 || len(groupRoleMappings) == 0 {
		return nil
	}

	normalizedMappings := make(map[string]string, len(groupRoleMappings))
	for groupDN, role := range groupRoleMappings {
		groupDN = normalizeDN(groupDN)
		role = strings.ToLower(strings.TrimSpace(role))
		if groupDN == "" || role == "" {
			continue
		}
		normalizedMappings[groupDN] = role
	}

	roleSet := make(map[string]struct{})
	for _, groupDN := range groupDNs {
		if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok {
			roleSet[role] = struct{}{}
		}
	}
	return mapKeysSorted(roleSet)
}

// HasAnyGroup reports whether groupDNs contains at least one of requiredGroupDNs,
// comparing DNs after trimming and lowercasing. An empty requirement list (after
// compactTrimmedStrings) is treated as "no restriction" and yields true.
func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool {
	requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs)
	if len(requiredGroupDNs) == 0 {
		return true
	}
	if len(groupDNs) == 0 {
		return false
	}

	required := make(map[string]struct{}, len(requiredGroupDNs))
	for _, groupDN := range requiredGroupDNs {
		required[normalizeDN(groupDN)] = struct{}{}
	}
	for _, groupDN := range groupDNs {
		if _, ok := required[normalizeDN(groupDN)]; ok {
			return true
		}
	}
	return false
}

// connect dials the configured bind address and returns an unbound connection.
// For ldap:// URLs the connection is upgraded with StartTLS unless the
// authenticator is configured as insecure. Failures wrap ErrInvalidLDAPConfig
// (TLS setup, unparsable address) or ErrLDAPOperationFailed (dial, StartTLS).
func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
	tlsConfig, err := a.buildTLSConfig()
	if err != nil {
		return nil, err
	}

	parsedURL, err := url.Parse(a.bindAddress)
	if err != nil {
		return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err)
	}

	// A crypto/tls client handshake needs either ServerName or InsecureSkipVerify.
	// The config is handed to StartTLS on an already-open connection, so the name
	// to verify the certificate against has to be supplied here; without it the
	// verified StartTLS path cannot succeed.
	if host := parsedURL.Hostname(); host != "" && tlsConfig.ServerName == "" {
		tlsConfig.ServerName = host
	}

	options := []ldap.DialOpt{
		ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}),
		ldap.DialWithTLSConfig(tlsConfig),
	}
	conn, err := ldap.DialURL(a.bindAddress, options...)
	if err != nil {
		return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err)
	}
	conn.SetTimeout(a.dialTimeout)

	// For ldap://, opportunistically upgrade to TLS unless explicitly configured as insecure.
	if parsedURL.Scheme == "ldap" && !a.insecure {
		if err := conn.StartTLS(tlsConfig); err != nil {
			conn.Close()
			return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err)
		}
	}
	return conn, nil
}

// buildTLSConfig returns a fresh TLS configuration with a TLS 1.2 minimum.
// Certificate verification is skipped when either insecure or disableValidation
// is set. When a trust certificate file is configured, its PEM certificates
// become the root CA pool; an unreadable file or one without any valid
// certificate yields an error wrapping ErrInvalidLDAPConfig.
func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) {
	tlsConfig := &tls.Config{
		MinVersion:         tls.VersionTLS12,
		InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags
	}
	if a.trustCertFile == "" {
		return tlsConfig, nil
	}

	caPEM, err := os.ReadFile(a.trustCertFile)
	if err != nil {
		return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err)
	}
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(caPEM) {
		return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig)
	}
	tlsConfig.RootCAs = roots
	return tlsConfig, nil
}

// lookupUserEntry locates the directory entry of the authenticated user.
//
// It first tries base-object reads of userDNHint and username (each only if it
// looks like a DN, de-duplicated case-insensitively), skipping DNs the server
// reports as non-existent. It then searches the base DN subtree for each
// principal candidate derived from username. The returned strategy is "dn" or
// "principal". A nil entry with a nil error means no entry was found; search
// failures wrap ErrLDAPOperationFailed.
func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string, userDNHint string) (*ldap.Entry, string, error) {
	userAttributes := []string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"}

	dnCandidates := make([]string, 0, 2)
	if looksLikeDN(userDNHint) {
		dnCandidates = append(dnCandidates, strings.TrimSpace(userDNHint))
	}
	if looksLikeDN(username) {
		dnCandidates = append(dnCandidates, strings.TrimSpace(username))
	}

	seenDN := make(map[string]struct{}, len(dnCandidates))
	for _, dn := range dnCandidates {
		key := normalizeDN(dn)
		if key == "" {
			continue
		}
		if _, ok := seenDN[key]; ok {
			continue
		}
		seenDN[key] = struct{}{}

		searchRes, err := conn.Search(ldap.NewSearchRequest(
			dn,
			ldap.ScopeBaseObject,
			ldap.NeverDerefAliases,
			1,
			0,
			false,
			"(objectClass=*)",
			userAttributes,
			nil,
		))
		if err != nil {
			// A missing object just means this candidate is wrong; try the next one.
			if ldap.IsErrorWithCode(err, ldap.LDAPResultNoSuchObject) {
				continue
			}
			return nil, "", fmt.Errorf("%w: unable to load user entry by dn: %v", ErrLDAPOperationFailed, err)
		}
		if len(searchRes.Entries) > 0 {
			return searchRes.Entries[0], "dn", nil
		}
	}

	for _, principal := range principalCandidates(username) {
		escaped := ldap.EscapeFilter(principal)
		filter := fmt.Sprintf(
			"(|(uid=%s)(cn=%s)(sAMAccountName=%s)(userPrincipalName=%s))",
			escaped, escaped, escaped, escaped,
		)
		searchRes, err := conn.Search(ldap.NewSearchRequest(
			a.baseDN,
			ldap.ScopeWholeSubtree,
			ldap.NeverDerefAliases,
			2,
			0,
			false,
			filter,
			userAttributes,
			nil,
		))
		if err != nil {
			return nil, "", fmt.Errorf("%w: user lookup failed: %v", ErrLDAPOperationFailed, err)
		}
		// NOTE(review): the size limit of 2 lets a second match come back, but only
		// the first entry is used. Confirm whether an ambiguous match (two entries
		// sharing a uid/cn/sAMAccountName/userPrincipalName value) should be rejected.
		if len(searchRes.Entries) > 0 {
			return searchRes.Entries[0], "principal", nil
		}
	}
	return nil, "", nil
}

// normalizeDN returns value trimmed and lowercased, for use as a comparison key.
// It performs no other DN canonicalization.
func normalizeDN(value string) string {
	return strings.ToLower(strings.TrimSpace(value))
}

// mapKeysSorted returns the keys of m in ascending order, or nil for an empty map.
func mapKeysSorted[K ~string, V any](m map[K]V) []K {
	if len(m) == 0 {
		return nil
	}
	out := make([]K, 0, len(m))
	for key := range m {
		out = append(out, key)
	}
	sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
	return out
}

// firstNonEmpty returns the first value that is non-empty after trimming
// (the trimmed form is returned), or "" when there is none.
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		value = strings.TrimSpace(value)
		if value != "" {
			return value
		}
	}
	return ""
}

// looksLikeDN is a cheap heuristic: it reports whether the trimmed value
// contains both "=" and ",". It does not validate DN syntax.
func looksLikeDN(value string) bool {
	value = strings.TrimSpace(value)
	return strings.Contains(value, "=") && strings.Contains(value, ",")
}

// parseWhoAmIDN extracts a DN from a WhoAmI authorization identity. A leading
// "dn:" prefix (any letter case) is stripped; the remainder is returned only if
// it looks like a DN, otherwise "" is returned.
func parseWhoAmIDN(authzID string) string {
	authzID = strings.TrimSpace(authzID)
	if authzID == "" {
		return ""
	}
	lower := strings.ToLower(authzID)
	if strings.HasPrefix(lower, "dn:") {
		authzID = strings.TrimSpace(authzID[3:])
	}
	if !looksLikeDN(authzID) {
		return ""
	}
	return authzID
}

// principalCandidates returns the names to try when searching for the user, in
// order: the trimmed username itself, the part after the last backslash
// (DOMAIN\user), and the part before the first "@" (user@domain). Empty values
// and case-insensitive duplicates are dropped. It returns nil for a blank username.
func principalCandidates(username string) []string {
	username = strings.TrimSpace(username)
	if username == "" {
		return nil
	}

	seen := make(map[string]struct{}, 4)
	candidates := make([]string, 0, 4)
	add := func(value string) {
		value = strings.TrimSpace(value)
		if value == "" {
			return
		}
		key := strings.ToLower(value)
		if _, ok := seen[key]; ok {
			return
		}
		seen[key] = struct{}{}
		candidates = append(candidates, value)
	}

	add(username)
	if idx := strings.LastIndex(username, `\`); idx >= 0 && idx < len(username)-1 {
		add(username[idx+1:])
	}
	if idx := strings.Index(username, "@"); idx > 0 {
		add(username[:idx])
	}
	return candidates
}

// buildGroupMembershipFilter builds an OR filter matching entries that reference
// the user: member/uniqueMember equal to userDN, or memberUid equal to any of
// the principals. All values are filter-escaped.
func buildGroupMembershipFilter(userDN string, principals []string) string {
	clauses := make([]string, 0, 2+len(principals))

	userDN = strings.TrimSpace(userDN)
	if userDN != "" {
		escapedDN := ldap.EscapeFilter(userDN)
		clauses = append(clauses, "(member="+escapedDN+")", "(uniqueMember="+escapedDN+")")
	}
	for _, principal := range principals {
		principal = strings.TrimSpace(principal)
		if principal == "" {
			continue
		}
		clauses = append(clauses, "(memberUid="+ldap.EscapeFilter(principal)+")")
	}

	if len(clauses) == 0 {
		// NOTE(review): with no DN and no principals this is not a membership
		// filter; it matches every entry with objectClass=group. The call in
		// AuthenticateAndFetchGroups always passes a non-empty DN, so this branch
		// is not reached from there. Confirm whether a match-nothing filter was intended.
		return "(objectClass=group)"
	}
	return "(|" + strings.Join(clauses, "") + ")"
}

// ctxErr returns ctx.Err() if ctx is already done and nil otherwise, without
// blocking. A nil ctx is tolerated and treated as never done.
func ctxErr(ctx context.Context) error {
	if ctx == nil {
		return nil
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		return nil
	}
}