package auth

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"net"
	"net/url"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/go-ldap/ldap/v3"
)

var (
	// ErrInvalidLDAPConfig is returned (wrapped) when the authenticator
	// configuration or the trust certificate file it references is unusable.
	ErrInvalidLDAPConfig = errors.New("invalid ldap config")
	// ErrLDAPInvalidCredentials is returned (wrapped) when the username or
	// password is empty, or when the directory rejects the bind with
	// invalidCredentials.
	ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials")
	// ErrLDAPOperationFailed is returned (wrapped) for connection, TLS, bind, and
	// search failures other than rejected credentials.
	ErrLDAPOperationFailed = errors.New("ldap operation failed")
)

// LDAPConfig holds the settings used by NewLDAPAuthenticator.
type LDAPConfig struct {
	// BindAddress is the directory URL, e.g. ldap://host:389 or ldaps://host:636.
	// Required.
	BindAddress string

	// BaseDN is required. It is the default for UserBaseDN and GroupBaseDN, and
	// its dc= components are used to derive the UPN suffix for bare usernames.
	BaseDN string

	// UserBaseDN is the subtree searched for user entries. Defaults to BaseDN.
	UserBaseDN string

	// GroupBaseDN defaults to BaseDN. It is currently only reported in
	// diagnostics: group membership is read from the user entry's memberOf.
	GroupBaseDN string

	// TrustCertFile optionally names a PEM file whose certificates are used as
	// the root CAs when verifying the server certificate.
	TrustCertFile string

	// DisableValidation skips server certificate verification.
	DisableValidation bool

	// Insecure skips server certificate verification and, for ldap:// URLs,
	// also skips the StartTLS upgrade (traffic is then sent unencrypted).
	Insecure bool

	// DialTimeout is used both as the dial timeout and as the per-request
	// timeout on the connection. Non-positive values default to 10 seconds.
	DialTimeout time.Duration
}

// LDAPIdentity describes a user that successfully authenticated.
type LDAPIdentity struct {
	// Username is the first non-empty of uid, sAMAccountName,
	// userPrincipalName, cn from the user entry, or the trimmed input username
	// when no entry (or no such attribute) was found.
	Username string

	// UserDN is the user entry's DN when found, otherwise the DN reported by
	// WhoAmI, otherwise the name used for the bind.
	UserDN string

	// Groups holds the sorted, de-duplicated memberOf values of the user entry.
	Groups []string

	BindDuration                  time.Duration
	UserLookupDuration            time.Duration
	GroupMembershipLookupDuration time.Duration

	// Diagnostics contains non-sensitive LDAP processing notes useful for debugging auth decisions.
	Diagnostics []string
}

// LDAPAuthenticator authenticates users with an LDAP simple bind and reads
// their direct group memberships. Create it with NewLDAPAuthenticator.
type LDAPAuthenticator struct {
	bindAddress       string
	baseDN            string
	userBaseDN        string
	groupBaseDN       string
	trustCertFile     string
	disableValidation bool
	insecure          bool
	dialTimeout       time.Duration
}

// NewLDAPAuthenticator validates cfg and returns an authenticator. String
// settings are trimmed; empty UserBaseDN/GroupBaseDN fall back to BaseDN and a
// non-positive DialTimeout falls back to 10s. Errors wrap ErrInvalidLDAPConfig.
// The trust certificate file is not read here; it is read on every connection.
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
	bindAddress := strings.TrimSpace(cfg.BindAddress)
	baseDN := strings.TrimSpace(cfg.BaseDN)
	userBaseDN := strings.TrimSpace(cfg.UserBaseDN)
	groupBaseDN := strings.TrimSpace(cfg.GroupBaseDN)
	trustCertFile := strings.TrimSpace(cfg.TrustCertFile)

	if bindAddress == "" {
		return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig)
	}
	if baseDN == "" {
		return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig)
	}
	if userBaseDN == "" {
		userBaseDN = baseDN
	}
	if groupBaseDN == "" {
		groupBaseDN = baseDN
	}
	if _, err := url.ParseRequestURI(bindAddress); err != nil {
		return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err)
	}

	dialTimeout := cfg.DialTimeout
	if dialTimeout <= 0 {
		dialTimeout = 10 * time.Second
	}

	return &LDAPAuthenticator{
		bindAddress:       bindAddress,
		baseDN:            baseDN,
		userBaseDN:        userBaseDN,
		groupBaseDN:       groupBaseDN,
		trustCertFile:     trustCertFile,
		disableValidation: cfg.DisableValidation,
		insecure:          cfg.Insecure,
		dialTimeout:       dialTimeout,
	}, nil
}

// AuthenticateAndFetchGroups binds to the directory as username/password and,
// on success, looks up the user's entry to resolve the canonical username, the
// DN and the direct group memberships (memberOf).
//
// Bare usernames and DOMAIN\user names are rewritten to user@<domain> before
// the bind when a domain can be derived from the base DN's dc= components.
//
// ctx is only checked between steps; it does not interrupt an in-flight LDAP
// request (those are bounded by the configured timeout instead).
//
// Errors wrap ErrLDAPInvalidCredentials for empty input or a rejected bind, and
// ErrLDAPOperationFailed or ErrInvalidLDAPConfig for other failures; a done
// ctx yields ctx.Err(). A missing user entry is not an error: the identity is
// returned without groups and with a "user_entry_not_found" diagnostic.
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
	inputUsername := strings.TrimSpace(username)
	// An empty-password simple bind is an "unauthenticated" bind (RFC 4513
	// §5.1.2) that some servers report as successful, so reject it up front.
	if inputUsername == "" || password == "" {
		return LDAPIdentity{}, ErrLDAPInvalidCredentials
	}
	if err := ctxErr(ctx); err != nil {
		return LDAPIdentity{}, err
	}

	bindUsername, rewrittenToUPN := normalizeBindUsername(inputUsername, a.baseDN)

	conn, err := a.connect()
	if err != nil {
		return LDAPIdentity{}, err
	}
	defer conn.Close()

	bindStartedAt := time.Now()
	err = conn.Bind(bindUsername, password)
	bindDuration := time.Since(bindStartedAt)
	if err != nil {
		if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
			return LDAPIdentity{}, fmt.Errorf("%w: ldap bind rejected credentials (bind_duration=%s)", ErrLDAPInvalidCredentials, bindDuration)
		}
		return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v (bind_duration=%s)", ErrLDAPOperationFailed, err, bindDuration)
	}
	if err := ctxErr(ctx); err != nil {
		return LDAPIdentity{}, err
	}

	identity := LDAPIdentity{
		Username:     inputUsername,
		UserDN:       bindUsername,
		BindDuration: bindDuration,
	}
	identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("bind_duration_ms=%d", bindDuration.Milliseconds()))
	if rewrittenToUPN {
		identity.Diagnostics = append(identity.Diagnostics, "bind_username_rewritten_to_upn")
	}
	identity.Diagnostics = append(identity.Diagnostics,
		"user_lookup_base_dn="+a.userBaseDN,
		"group_lookup_base_dn="+a.groupBaseDN,
	)

	// The bind name may be a UPN or DOMAIN\user rather than a DN; prefer the DN
	// the server reports for the bound identity when it provides one.
	if whoami, err := conn.WhoAmI(nil); err != nil {
		identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("whoami_failed:%v", err))
	} else if boundDN := parseWhoAmIDN(whoami.AuthzID); boundDN != "" {
		identity.UserDN = boundDN
		identity.Diagnostics = append(identity.Diagnostics, "whoami_dn_resolved")
	} else if strings.TrimSpace(whoami.AuthzID) == "" {
		identity.Diagnostics = append(identity.Diagnostics, "whoami_dn_empty")
	} else {
		identity.Diagnostics = append(identity.Diagnostics, "whoami_non_dn_authzid")
	}

	userLookupStartedAt := time.Now()
	entry, lookupStrategy, err := a.lookupUserEntry(conn, inputUsername, identity.UserDN)
	identity.UserLookupDuration = time.Since(userLookupStartedAt)
	identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("user_lookup_duration_ms=%d", identity.UserLookupDuration.Milliseconds()))
	if err != nil {
		return LDAPIdentity{}, fmt.Errorf("%w: %v (bind_duration=%s user_lookup_duration=%s)", ErrLDAPOperationFailed, err, identity.BindDuration, identity.UserLookupDuration)
	}

	if entry != nil {
		if lookupStrategy == "" {
			lookupStrategy = "unknown"
		}
		identity.Diagnostics = append(identity.Diagnostics, "user_entry_found:"+lookupStrategy)
		if strings.TrimSpace(entry.DN) != "" {
			identity.UserDN = entry.DN
		}
		if v := firstNonEmpty(
			entry.GetAttributeValue("uid"),
			entry.GetAttributeValue("sAMAccountName"),
			entry.GetAttributeValue("userPrincipalName"),
			entry.GetAttributeValue("cn"),
		); v != "" {
			identity.Username = v
		}
	} else {
		identity.Diagnostics = append(identity.Diagnostics, "user_entry_not_found")
	}

	// Authorization is intentionally based only on the direct memberships listed
	// in the user entry's memberOf attribute; no subtree group search is run, so
	// groups reachable only through nesting, or entries without memberOf, yield
	// no groups.
	groupLookupStartedAt := time.Now()
	groupSet := make(map[string]struct{})
	if entry != nil {
		for _, groupDN := range entry.GetAttributeValues("memberOf") {
			groupDN = strings.TrimSpace(groupDN)
			if groupDN == "" {
				continue
			}
			groupSet[groupDN] = struct{}{}
		}
	}
	identity.GroupMembershipLookupDuration = time.Since(groupLookupStartedAt)
	identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("group_lookup_duration_ms=%d", identity.GroupMembershipLookupDuration.Milliseconds()))
	identity.Diagnostics = append(identity.Diagnostics, "group_search_skipped_direct_memberof_only")

	identity.Groups = mapKeysSorted(groupSet)
	identity.Diagnostics = compactTrimmedStrings(identity.Diagnostics)
	return identity, nil
}

// ResolveRoles maps group DNs to roles using groupRoleMappings (group DN ->
// role). DNs are compared after trimming and lower-casing; roles are trimmed
// and lower-cased, and mappings with an empty DN or role are ignored. The
// result is sorted and de-duplicated, or nil when no group matches.
func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string {
	if len(groupDNs) == 0 || len(groupRoleMappings) == 0 {
		return nil
	}

	// Visit mapping keys in sorted order so that, when two keys normalize to the
	// same DN, the surviving role is deterministic instead of depending on Go's
	// randomized map iteration order.
	normalizedMappings := make(map[string]string, len(groupRoleMappings))
	for _, rawDN := range mapKeysSorted(groupRoleMappings) {
		groupDN := normalizeDN(rawDN)
		role := strings.ToLower(strings.TrimSpace(groupRoleMappings[rawDN]))
		if groupDN == "" || role == "" {
			continue
		}
		normalizedMappings[groupDN] = role
	}

	roleSet := make(map[string]struct{})
	for _, groupDN := range groupDNs {
		if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok {
			roleSet[role] = struct{}{}
		}
	}
	return mapKeysSorted(roleSet)
}

// HasAnyGroup reports whether groupDNs contains at least one of
// requiredGroupDNs, comparing DNs after trimming and lower-casing. When no
// required groups remain after compactTrimmedStrings, it returns true.
func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool {
	requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs)
	if len(requiredGroupDNs) == 0 {
		return true
	}
	if len(groupDNs) == 0 {
		return false
	}

	required := make(map[string]struct{}, len(requiredGroupDNs))
	for _, groupDN := range requiredGroupDNs {
		required[normalizeDN(groupDN)] = struct{}{}
	}
	for _, groupDN := range groupDNs {
		if _, ok := required[normalizeDN(groupDN)]; ok {
			return true
		}
	}
	return false
}

// connect dials the configured URL and, for ldap:// URLs not marked insecure,
// upgrades the connection with StartTLS. The caller must Close the returned
// connection.
func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
	tlsConfig, err := a.buildTLSConfig()
	if err != nil {
		return nil, err
	}

	parsedURL, err := url.Parse(a.bindAddress)
	if err != nil {
		return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err)
	}

	// StartTLS hands this config to tls.Client, which refuses to verify a
	// certificate when ServerName is empty. tls.DialWithDialer (used for
	// ldaps://) infers the name from the address, but StartTLS does not, so set
	// it from the URL to make certificate validation work for both schemes.
	if tlsConfig.ServerName == "" {
		tlsConfig.ServerName = parsedURL.Hostname()
	}

	options := []ldap.DialOpt{
		ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}),
		ldap.DialWithTLSConfig(tlsConfig),
	}
	conn, err := ldap.DialURL(a.bindAddress, options...)
	if err != nil {
		return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err)
	}
	conn.SetTimeout(a.dialTimeout)

	// For ldap://, upgrade to TLS unless explicitly configured as insecure.
	if parsedURL.Scheme == "ldap" && !a.insecure {
		if err := conn.StartTLS(tlsConfig); err != nil {
			conn.Close()
			return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err)
		}
	}
	return conn, nil
}

// buildTLSConfig returns a fresh TLS config (TLS 1.2+). Verification is skipped
// when Insecure or DisableValidation is set; when a trust cert file is
// configured, its PEM certificates become the root CA pool.
func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) {
	tlsConfig := &tls.Config{
		MinVersion:         tls.VersionTLS12,
		InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags
	}
	if a.trustCertFile == "" {
		return tlsConfig, nil
	}

	caPEM, err := os.ReadFile(a.trustCertFile)
	if err != nil {
		return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err)
	}
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(caPEM) {
		return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig)
	}
	tlsConfig.RootCAs = roots
	return tlsConfig, nil
}

// lookupUserEntry finds the authenticated user's entry.
//
// It first reads userDNHint and username as base objects when they look like
// DNs (noSuchObject is skipped). It then runs equality searches under the user
// base DN for each principal candidate:
//   - userPrincipalName for names containing "@";
//   - otherwise sAMAccountName, then uid.
//
// It returns the entry plus a short strategy label for diagnostics, or
// (nil, "", nil) when nothing matched.
func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string, userDNHint string) (*ldap.Entry, string, error) {
	dnCandidates := make([]string, 0, 2)
	if looksLikeDN(userDNHint) {
		dnCandidates = append(dnCandidates, strings.TrimSpace(userDNHint))
	}
	if looksLikeDN(username) {
		dnCandidates = append(dnCandidates, strings.TrimSpace(username))
	}

	seenDN := make(map[string]struct{}, len(dnCandidates))
	for _, dn := range dnCandidates {
		key := normalizeDN(dn)
		if key == "" {
			continue
		}
		if _, ok := seenDN[key]; ok {
			continue
		}
		seenDN[key] = struct{}{}

		searchRes, err := conn.Search(ldap.NewSearchRequest(
			dn,
			ldap.ScopeBaseObject,
			ldap.NeverDerefAliases,
			1,
			0,
			false,
			"(objectClass=*)",
			[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
			nil,
		))
		if err != nil {
			if ldap.IsErrorWithCode(err, ldap.LDAPResultNoSuchObject) {
				continue
			}
			return nil, "", fmt.Errorf("%w: unable to load user entry by dn: %v", ErrLDAPOperationFailed, err)
		}
		if len(searchRes.Entries) > 0 {
			return searchRes.Entries[0], "dn", nil
		}
	}

	for _, principal := range principalCandidates(username) {
		if strings.Contains(principal, "@") {
			entry, err := a.searchUserByAttribute(conn, "userPrincipalName", principal)
			if err != nil {
				return nil, "", err
			}
			if entry != nil {
				return entry, "principal_upn", nil
			}
			// For UPN principals, avoid fallback attribute probes that are unlikely to match
			// and can be expensive on large directory trees.
			continue
		}

		entry, err := a.searchUserByAttribute(conn, "sAMAccountName", principal)
		if err != nil {
			return nil, "", err
		}
		if entry != nil {
			return entry, "principal_samaccountname", nil
		}

		// Keep uid lookup as a fallback for non-AD LDAP directories.
		entry, err = a.searchUserByAttribute(conn, "uid", principal)
		if err != nil {
			return nil, "", err
		}
		if entry != nil {
			return entry, "principal_uid", nil
		}
	}

	return nil, "", nil
}

// searchUserByAttribute runs an equality search (attribute=value, with value
// filter-escaped) over the user base DN subtree.
//
// It returns (nil, nil) when either argument is blank or nothing matches. More
// than one match is reported as an error wrapping ErrLDAPOperationFailed: the
// chosen entry's memberOf values drive authorization, so an ambiguous match
// must not be resolved by picking an arbitrary entry.
func (a *LDAPAuthenticator) searchUserByAttribute(conn *ldap.Conn, attribute string, value string) (*ldap.Entry, error) {
	attribute = strings.TrimSpace(attribute)
	value = strings.TrimSpace(value)
	if attribute == "" || value == "" {
		return nil, nil
	}

	// A size limit of 2 is enough to detect an ambiguous match without pulling
	// back every matching entry.
	searchRes, err := conn.Search(ldap.NewSearchRequest(
		a.userBaseDN,
		ldap.ScopeWholeSubtree,
		ldap.NeverDerefAliases,
		2,
		0,
		false,
		fmt.Sprintf("(%s=%s)", attribute, ldap.EscapeFilter(value)),
		[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
		nil,
	))
	if err != nil {
		// sizeLimitExceeded means more entries matched than the limit of 2.
		if ldap.IsErrorWithCode(err, ldap.LDAPResultSizeLimitExceeded) {
			return nil, fmt.Errorf("%w: user lookup failed (%s): multiple entries matched", ErrLDAPOperationFailed, attribute)
		}
		return nil, fmt.Errorf("%w: user lookup failed (%s): %v", ErrLDAPOperationFailed, attribute, err)
	}

	switch len(searchRes.Entries) {
	case 0:
		return nil, nil
	case 1:
		return searchRes.Entries[0], nil
	default:
		return nil, fmt.Errorf("%w: user lookup failed (%s): multiple entries matched", ErrLDAPOperationFailed, attribute)
	}
}

// normalizeDN returns a comparison key for a DN: trimmed and lower-cased. It
// does not canonicalize whitespace around RDN separators or escapes.
func normalizeDN(value string) string {
	return strings.ToLower(strings.TrimSpace(value))
}

// mapKeysSorted returns m's keys in ascending order, or nil for an empty map.
func mapKeysSorted[K ~string, V any](m map[K]V) []K {
	if len(m) == 0 {
		return nil
	}
	out := make([]K, 0, len(m))
	for key := range m {
		out = append(out, key)
	}
	sort.Slice(out, func(i, j int) bool {
		return out[i] < out[j]
	})
	return out
}

// firstNonEmpty returns the first value that is non-empty after trimming
// (trimmed), or "" if there is none.
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		value = strings.TrimSpace(value)
		if value != "" {
			return value
		}
	}
	return ""
}

// looksLikeDN is a cheap heuristic: a DN candidate contains both "=" and ",".
func looksLikeDN(value string) bool {
	value = strings.TrimSpace(value)
	return strings.Contains(value, "=") && strings.Contains(value, ",")
}

// parseWhoAmIDN extracts a DN from a WhoAmI authzID. An optional,
// case-insensitive "dn:" prefix is stripped; the remainder is returned only if
// it looks like a DN, otherwise "" (e.g. for "u:"-style authzIDs).
func parseWhoAmIDN(authzID string) string {
	authzID = strings.TrimSpace(authzID)
	if authzID == "" {
		return ""
	}
	lower := strings.ToLower(authzID)
	if strings.HasPrefix(lower, "dn:") {
		authzID = strings.TrimSpace(authzID[3:])
	}
	if !looksLikeDN(authzID) {
		return ""
	}
	return authzID
}

// normalizeBindUsername returns the name to bind with and whether it was
// rewritten to a UPN. DNs and names already containing "@" are used as-is.
// Otherwise any DOMAIN\ prefix is dropped and "@<domain>" is appended, where
// <domain> is derived from the base DN's dc= components. When no domain can be
// derived, the input is returned unchanged so a DOMAIN\user name keeps its
// domain qualifier.
func normalizeBindUsername(username string, baseDN string) (string, bool) {
	username = strings.TrimSpace(username)
	if username == "" {
		return "", false
	}
	if looksLikeDN(username) || strings.Contains(username, "@") {
		return username, false
	}

	domain := upnDomainFromBaseDN(baseDN)
	if domain == "" {
		return username, false
	}

	user := username
	// Convert DOMAIN\user to user before the UPN rewrite.
	if idx := strings.LastIndex(user, `\`); idx >= 0 && idx < len(user)-1 {
		user = strings.TrimSpace(user[idx+1:])
	}
	return user + "@" + domain, true
}

// upnDomainFromBaseDN joins the values of the dc= components of baseDN with
// dots (e.g. "dc=corp,dc=example,dc=com" -> "corp.example.com"), or returns ""
// when there are none. Other RDN types are ignored.
func upnDomainFromBaseDN(baseDN string) string {
	baseDN = strings.TrimSpace(baseDN)
	if baseDN == "" {
		return ""
	}

	parts := strings.Split(baseDN, ",")
	labels := make([]string, 0, len(parts))
	for _, part := range parts {
		part = strings.TrimSpace(part)
		if len(part) < 3 || !strings.EqualFold(part[:3], "dc=") {
			continue
		}
		label := strings.TrimSpace(part[3:])
		if label == "" {
			continue
		}
		labels = append(labels, label)
	}
	if len(labels) == 0 {
		return ""
	}
	return strings.Join(labels, ".")
}

// principalCandidates returns the search values to try for username, in order:
// the username itself, the part after the last backslash, and the part before
// the first "@". Blank values and case-insensitive duplicates are dropped.
func principalCandidates(username string) []string {
	username = strings.TrimSpace(username)
	if username == "" {
		return nil
	}

	seen := make(map[string]struct{}, 4)
	candidates := make([]string, 0, 4)
	add := func(value string) {
		value = strings.TrimSpace(value)
		if value == "" {
			return
		}
		key := strings.ToLower(value)
		if _, ok := seen[key]; ok {
			return
		}
		seen[key] = struct{}{}
		candidates = append(candidates, value)
	}

	add(username)
	if idx := strings.LastIndex(username, `\`); idx >= 0 && idx < len(username)-1 {
		add(username[idx+1:])
	}
	if idx := strings.Index(username, "@"); idx > 0 {
		add(username[:idx])
	}
	return candidates
}

// buildGroupMembershipFilter builds an OR filter matching groups that list
// userDN via member/uniqueMember or a principal via memberUid, falling back to
// "(objectClass=group)" when there is nothing to match on. All values are
// filter-escaped. It is not called from this file, since AuthenticateAndFetchGroups
// currently skips the group search.
func buildGroupMembershipFilter(userDN string, principals []string) string {
	clauses := make([]string, 0, 2+len(principals))
	userDN = strings.TrimSpace(userDN)
	if userDN != "" {
		escapedDN := ldap.EscapeFilter(userDN)
		clauses = append(clauses, "(member="+escapedDN+")", "(uniqueMember="+escapedDN+")")
	}
	for _, principal := range principals {
		principal = strings.TrimSpace(principal)
		if principal == "" {
			continue
		}
		clauses = append(clauses, "(memberUid="+ldap.EscapeFilter(principal)+")")
	}
	if len(clauses) == 0 {
		return "(objectClass=group)"
	}
	return "(|" + strings.Join(clauses, "") + ")"
}

// ctxErr returns ctx.Err() if ctx is already done, without blocking. A nil ctx
// is treated as never cancelled.
func ctxErr(ctx context.Context) error {
	if ctx == nil {
		return nil
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		return nil
	}
}