package auth import ( "context" "crypto/tls" "crypto/x509" "errors" "fmt" "net" "net/url" "os" "sort" "strings" "time" "github.com/go-ldap/ldap/v3" ) var ( ErrInvalidLDAPConfig = errors.New("invalid ldap config") ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials") ErrLDAPOperationFailed = errors.New("ldap operation failed") ) type LDAPConfig struct { BindAddress string BaseDN string TrustCertFile string DisableValidation bool Insecure bool DialTimeout time.Duration } type LDAPIdentity struct { Username string UserDN string Groups []string } type LDAPAuthenticator struct { bindAddress string baseDN string trustCertFile string disableValidation bool insecure bool dialTimeout time.Duration } func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) { bindAddress := strings.TrimSpace(cfg.BindAddress) baseDN := strings.TrimSpace(cfg.BaseDN) trustCertFile := strings.TrimSpace(cfg.TrustCertFile) if bindAddress == "" { return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig) } if baseDN == "" { return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig) } if _, err := url.ParseRequestURI(bindAddress); err != nil { return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err) } dialTimeout := cfg.DialTimeout if dialTimeout <= 0 { dialTimeout = 10 * time.Second } return &LDAPAuthenticator{ bindAddress: bindAddress, baseDN: baseDN, trustCertFile: trustCertFile, disableValidation: cfg.DisableValidation, insecure: cfg.Insecure, dialTimeout: dialTimeout, }, nil } func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) { username = strings.TrimSpace(username) if username == "" || password == "" { return LDAPIdentity{}, ErrLDAPInvalidCredentials } if err := ctxErr(ctx); err != nil { return LDAPIdentity{}, err } conn, err := a.connect() if err != nil { return LDAPIdentity{}, err } defer conn.Close() if err := conn.Bind(username, password); err != nil { if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) { return LDAPIdentity{}, ErrLDAPInvalidCredentials } return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err) } if err := ctxErr(ctx); err != nil { return LDAPIdentity{}, err } identity := LDAPIdentity{ Username: username, UserDN: username, } entry, err := a.lookupUserEntry(conn, username) if err != nil { return LDAPIdentity{}, err } if entry != nil { if strings.TrimSpace(entry.DN) != "" { identity.UserDN = entry.DN } if v := firstNonEmpty( entry.GetAttributeValue("uid"), entry.GetAttributeValue("sAMAccountName"), entry.GetAttributeValue("userPrincipalName"), entry.GetAttributeValue("cn"), ); v != "" { identity.Username = v } } groupSet := make(map[string]struct{}) if entry != nil { for _, groupDN := range entry.GetAttributeValues("memberOf") { groupDN = strings.TrimSpace(groupDN) if groupDN == "" { continue } groupSet[groupDN] = struct{}{} } } groupEntries, err := conn.Search(ldap.NewSearchRequest( a.baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, fmt.Sprintf("(|(member=%s)(uniqueMember=%s)(memberUid=%s))", ldap.EscapeFilter(identity.UserDN), ldap.EscapeFilter(identity.UserDN), ldap.EscapeFilter(username), ), []string{"dn"}, nil, )) if err == nil { for _, e := range groupEntries.Entries { if dn := strings.TrimSpace(e.DN); dn != "" { groupSet[dn] = struct{}{} } } } identity.Groups = mapKeysSorted(groupSet) return identity, nil } func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string { if len(groupDNs) == 0 || len(groupRoleMappings) == 0 { return nil } normalizedMappings := make(map[string]string, len(groupRoleMappings)) for groupDN, role := range groupRoleMappings { groupDN = normalizeDN(groupDN) role = strings.ToLower(strings.TrimSpace(role)) if groupDN == "" || role == "" { continue } normalizedMappings[groupDN] = role } roleSet := make(map[string]struct{}) for _, groupDN := range groupDNs { if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok { roleSet[role] = struct{}{} } } return mapKeysSorted(roleSet) } func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool { requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs) if len(requiredGroupDNs) == 0 { return true } if len(groupDNs) == 0 { return false } required := make(map[string]struct{}, len(requiredGroupDNs)) for _, groupDN := range requiredGroupDNs { required[normalizeDN(groupDN)] = struct{}{} } for _, groupDN := range groupDNs { if _, ok := required[normalizeDN(groupDN)]; ok { return true } } return false } func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) { tlsConfig, err := a.buildTLSConfig() if err != nil { return nil, err } parsedURL, err := url.Parse(a.bindAddress) if err != nil { return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err) } options := []ldap.DialOpt{ ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}), ldap.DialWithTLSConfig(tlsConfig), } conn, err := ldap.DialURL(a.bindAddress, options...) if err != nil { return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err) } conn.SetTimeout(a.dialTimeout) // For ldap://, opportunistically upgrade to TLS unless explicitly configured as insecure. if parsedURL.Scheme == "ldap" && !a.insecure { if err := conn.StartTLS(tlsConfig); err != nil { conn.Close() return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err) } } return conn, nil } func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags } if a.trustCertFile == "" { return tlsConfig, nil } caPEM, err := os.ReadFile(a.trustCertFile) if err != nil { return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err) } roots := x509.NewCertPool() if !roots.AppendCertsFromPEM(caPEM) { return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig) } tlsConfig.RootCAs = roots return tlsConfig, nil } func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string) (*ldap.Entry, error) { if looksLikeDN(username) { searchRes, err := conn.Search(ldap.NewSearchRequest( username, ldap.ScopeBaseObject, ldap.NeverDerefAliases, 1, 0, false, "(objectClass=*)", []string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"}, nil, )) if err != nil { return nil, fmt.Errorf("%w: unable to load user entry: %v", ErrLDAPOperationFailed, err) } if len(searchRes.Entries) == 0 { return nil, nil } return searchRes.Entries[0], nil } searchRes, err := conn.Search(ldap.NewSearchRequest( a.baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 2, 0, false, fmt.Sprintf("(|(uid=%s)(cn=%s)(sAMAccountName=%s)(userPrincipalName=%s))", ldap.EscapeFilter(username), ldap.EscapeFilter(username), ldap.EscapeFilter(username), ldap.EscapeFilter(username), ), []string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"}, nil, )) if err != nil { return nil, fmt.Errorf("%w: user lookup failed: %v", ErrLDAPOperationFailed, err) } if len(searchRes.Entries) == 0 { return nil, nil } return searchRes.Entries[0], nil } func normalizeDN(value string) string { return strings.ToLower(strings.TrimSpace(value)) } func mapKeysSorted[K ~string, V any](m map[K]V) []K { if len(m) == 0 { return nil } out := make([]K, 0, len(m)) for key := range m { out = append(out, key) } sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) return out } func firstNonEmpty(values ...string) string { for _, value := range values { value = strings.TrimSpace(value) if value != "" { return value } } return "" } func looksLikeDN(value string) bool { value = strings.TrimSpace(value) return strings.Contains(value, "=") && strings.Contains(value, ",") } func ctxErr(ctx context.Context) error { if ctx == nil { return nil } select { case <-ctx.Done(): return ctx.Err() default: return nil } }