563 lines
15 KiB
Go
563 lines
15 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-ldap/ldap/v3"
|
|
)
|
|
|
|
// Sentinel errors produced by the LDAP authenticator. They are normally
// returned wrapped with additional detail, so compare with errors.Is.
var (
	// ErrInvalidLDAPConfig marks configuration problems such as a missing
	// bind address or base DN, an unparsable bind address, or an unusable
	// trust certificate file.
	ErrInvalidLDAPConfig = errors.New("invalid ldap config")

	// ErrLDAPInvalidCredentials marks blank credentials or a bind that the
	// server rejected with the invalid-credentials result code.
	ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials")

	// ErrLDAPOperationFailed marks connection, StartTLS, bind and search
	// failures that are not a credential rejection.
	ErrLDAPOperationFailed = errors.New("ldap operation failed")
)
|
|
|
|
// LDAPConfig holds the settings accepted by NewLDAPAuthenticator. String
// fields are trimmed of surrounding whitespace before use.
type LDAPConfig struct {
	// BindAddress is the URL of the LDAP server. Required.
	BindAddress string

	// BaseDN is the directory base DN. Required. It is the fallback for
	// UserBaseDN and GroupBaseDN, and its dc= components are used to derive
	// a UPN domain for bare usernames.
	BaseDN string

	// UserBaseDN is the subtree searched for user entries. Defaults to BaseDN.
	UserBaseDN string

	// GroupBaseDN is the base DN reported for group lookups. Defaults to BaseDN.
	GroupBaseDN string

	// TrustCertFile optionally names a PEM file whose certificates become
	// the TLS root pool.
	TrustCertFile string

	// DisableValidation turns off TLS certificate verification.
	DisableValidation bool

	// Insecure skips the StartTLS upgrade on ldap:// connections and also
	// turns off TLS certificate verification.
	Insecure bool

	// DialTimeout bounds connection establishment and is also applied as
	// the connection's request timeout. Non-positive values mean 10 seconds.
	DialTimeout time.Duration
}
|
|
|
|
// LDAPIdentity is the result of a successful AuthenticateAndFetchGroups call.
type LDAPIdentity struct {
	// Username is the trimmed login name, replaced by the first non-empty of
	// uid, sAMAccountName, userPrincipalName or cn when the user's directory
	// entry is found.
	Username string

	// UserDN starts as the name used for the bind and is refined by the
	// WhoAmI response and by the DN of the user's entry when available.
	UserDN string

	// Groups lists, sorted, the distinct memberOf values of the user's entry.
	Groups []string

	// BindDuration is the time spent in the LDAP bind call.
	BindDuration time.Duration

	// UserLookupDuration is the time spent locating the user's entry.
	UserLookupDuration time.Duration

	// GroupMembershipLookupDuration is the time spent collecting group DNs.
	GroupMembershipLookupDuration time.Duration

	// Diagnostics contains non-sensitive LDAP processing notes useful for
	// debugging auth decisions.
	Diagnostics []string
}
|
|
|
|
// LDAPAuthenticator authenticates users against an LDAP directory and reads
// their group memberships. Build one with NewLDAPAuthenticator, which
// validates and normalizes the settings stored here.
type LDAPAuthenticator struct {
	// bindAddress is the LDAP server URL handed to the dialer.
	bindAddress string
	// baseDN is the directory base; it also supplies the UPN domain.
	baseDN string
	// userBaseDN is the subtree searched for user entries.
	userBaseDN string
	// groupBaseDN is the base DN reported for group lookups.
	groupBaseDN string
	// trustCertFile is an optional PEM file of trusted root certificates.
	trustCertFile string
	// disableValidation turns off TLS certificate verification.
	disableValidation bool
	// insecure skips StartTLS on ldap:// and turns off TLS verification.
	insecure bool
	// dialTimeout bounds dialing and is reused as the request timeout.
	dialTimeout time.Duration
}
|
|
|
|
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
|
|
bindAddress := strings.TrimSpace(cfg.BindAddress)
|
|
baseDN := strings.TrimSpace(cfg.BaseDN)
|
|
userBaseDN := strings.TrimSpace(cfg.UserBaseDN)
|
|
groupBaseDN := strings.TrimSpace(cfg.GroupBaseDN)
|
|
trustCertFile := strings.TrimSpace(cfg.TrustCertFile)
|
|
|
|
if bindAddress == "" {
|
|
return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig)
|
|
}
|
|
if baseDN == "" {
|
|
return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig)
|
|
}
|
|
if userBaseDN == "" {
|
|
userBaseDN = baseDN
|
|
}
|
|
if groupBaseDN == "" {
|
|
groupBaseDN = baseDN
|
|
}
|
|
if _, err := url.ParseRequestURI(bindAddress); err != nil {
|
|
return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err)
|
|
}
|
|
|
|
dialTimeout := cfg.DialTimeout
|
|
if dialTimeout <= 0 {
|
|
dialTimeout = 10 * time.Second
|
|
}
|
|
|
|
return &LDAPAuthenticator{
|
|
bindAddress: bindAddress,
|
|
baseDN: baseDN,
|
|
userBaseDN: userBaseDN,
|
|
groupBaseDN: groupBaseDN,
|
|
trustCertFile: trustCertFile,
|
|
disableValidation: cfg.DisableValidation,
|
|
insecure: cfg.Insecure,
|
|
dialTimeout: dialTimeout,
|
|
}, nil
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
|
|
inputUsername := strings.TrimSpace(username)
|
|
if inputUsername == "" || password == "" {
|
|
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
|
}
|
|
if err := ctxErr(ctx); err != nil {
|
|
return LDAPIdentity{}, err
|
|
}
|
|
bindUsername, rewrittenToUPN := normalizeBindUsername(inputUsername, a.baseDN)
|
|
|
|
conn, err := a.connect()
|
|
if err != nil {
|
|
return LDAPIdentity{}, err
|
|
}
|
|
defer conn.Close()
|
|
|
|
bindStartedAt := time.Now()
|
|
err = conn.Bind(bindUsername, password)
|
|
bindDuration := time.Since(bindStartedAt)
|
|
if err != nil {
|
|
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
|
|
return LDAPIdentity{}, fmt.Errorf("%w: ldap bind rejected credentials (bind_duration=%s)", ErrLDAPInvalidCredentials, bindDuration)
|
|
}
|
|
return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v (bind_duration=%s)", ErrLDAPOperationFailed, err, bindDuration)
|
|
}
|
|
if err := ctxErr(ctx); err != nil {
|
|
return LDAPIdentity{}, err
|
|
}
|
|
|
|
identity := LDAPIdentity{
|
|
Username: inputUsername,
|
|
UserDN: bindUsername,
|
|
BindDuration: bindDuration,
|
|
}
|
|
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("bind_duration_ms=%d", bindDuration.Milliseconds()))
|
|
if rewrittenToUPN {
|
|
identity.Diagnostics = append(identity.Diagnostics, "bind_username_rewritten_to_upn")
|
|
}
|
|
identity.Diagnostics = append(identity.Diagnostics,
|
|
"user_lookup_base_dn="+a.userBaseDN,
|
|
"group_lookup_base_dn="+a.groupBaseDN,
|
|
)
|
|
if whoami, err := conn.WhoAmI(nil); err != nil {
|
|
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("whoami_failed:%v", err))
|
|
} else if boundDN := parseWhoAmIDN(whoami.AuthzID); boundDN != "" {
|
|
identity.UserDN = boundDN
|
|
identity.Diagnostics = append(identity.Diagnostics, "whoami_dn_resolved")
|
|
} else if strings.TrimSpace(whoami.AuthzID) == "" {
|
|
identity.Diagnostics = append(identity.Diagnostics, "whoami_dn_empty")
|
|
} else {
|
|
identity.Diagnostics = append(identity.Diagnostics, "whoami_non_dn_authzid")
|
|
}
|
|
|
|
userLookupStartedAt := time.Now()
|
|
entry, lookupStrategy, err := a.lookupUserEntry(conn, inputUsername, identity.UserDN)
|
|
identity.UserLookupDuration = time.Since(userLookupStartedAt)
|
|
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("user_lookup_duration_ms=%d", identity.UserLookupDuration.Milliseconds()))
|
|
if err != nil {
|
|
return LDAPIdentity{}, fmt.Errorf("%w: %v (bind_duration=%s user_lookup_duration=%s)", ErrLDAPOperationFailed, err, identity.BindDuration, identity.UserLookupDuration)
|
|
}
|
|
if entry != nil {
|
|
if lookupStrategy == "" {
|
|
lookupStrategy = "unknown"
|
|
}
|
|
identity.Diagnostics = append(identity.Diagnostics, "user_entry_found:"+lookupStrategy)
|
|
if strings.TrimSpace(entry.DN) != "" {
|
|
identity.UserDN = entry.DN
|
|
}
|
|
if v := firstNonEmpty(
|
|
entry.GetAttributeValue("uid"),
|
|
entry.GetAttributeValue("sAMAccountName"),
|
|
entry.GetAttributeValue("userPrincipalName"),
|
|
entry.GetAttributeValue("cn"),
|
|
); v != "" {
|
|
identity.Username = v
|
|
}
|
|
} else {
|
|
identity.Diagnostics = append(identity.Diagnostics, "user_entry_not_found")
|
|
}
|
|
|
|
groupSet := make(map[string]struct{})
|
|
groupLookupStartedAt := time.Now()
|
|
if entry != nil {
|
|
for _, groupDN := range entry.GetAttributeValues("memberOf") {
|
|
groupDN = strings.TrimSpace(groupDN)
|
|
if groupDN == "" {
|
|
continue
|
|
}
|
|
groupSet[groupDN] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// Intentionally skip subtree group membership search for now.
|
|
// Authorization is based only on direct group membership values present in the user entry (memberOf).
|
|
identity.GroupMembershipLookupDuration = time.Since(groupLookupStartedAt)
|
|
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("group_lookup_duration_ms=%d", identity.GroupMembershipLookupDuration.Milliseconds()))
|
|
identity.Diagnostics = append(identity.Diagnostics, "group_search_skipped_direct_memberof_only")
|
|
|
|
identity.Groups = mapKeysSorted(groupSet)
|
|
identity.Diagnostics = compactTrimmedStrings(identity.Diagnostics)
|
|
return identity, nil
|
|
}
|
|
|
|
func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string {
|
|
if len(groupDNs) == 0 || len(groupRoleMappings) == 0 {
|
|
return nil
|
|
}
|
|
|
|
normalizedMappings := make(map[string]string, len(groupRoleMappings))
|
|
for groupDN, role := range groupRoleMappings {
|
|
groupDN = normalizeDN(groupDN)
|
|
role = strings.ToLower(strings.TrimSpace(role))
|
|
if groupDN == "" || role == "" {
|
|
continue
|
|
}
|
|
normalizedMappings[groupDN] = role
|
|
}
|
|
|
|
roleSet := make(map[string]struct{})
|
|
for _, groupDN := range groupDNs {
|
|
if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok {
|
|
roleSet[role] = struct{}{}
|
|
}
|
|
}
|
|
return mapKeysSorted(roleSet)
|
|
}
|
|
|
|
func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool {
|
|
requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs)
|
|
if len(requiredGroupDNs) == 0 {
|
|
return true
|
|
}
|
|
if len(groupDNs) == 0 {
|
|
return false
|
|
}
|
|
|
|
required := make(map[string]struct{}, len(requiredGroupDNs))
|
|
for _, groupDN := range requiredGroupDNs {
|
|
required[normalizeDN(groupDN)] = struct{}{}
|
|
}
|
|
for _, groupDN := range groupDNs {
|
|
if _, ok := required[normalizeDN(groupDN)]; ok {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
|
|
tlsConfig, err := a.buildTLSConfig()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
parsedURL, err := url.Parse(a.bindAddress)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err)
|
|
}
|
|
|
|
options := []ldap.DialOpt{
|
|
ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}),
|
|
ldap.DialWithTLSConfig(tlsConfig),
|
|
}
|
|
conn, err := ldap.DialURL(a.bindAddress, options...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err)
|
|
}
|
|
conn.SetTimeout(a.dialTimeout)
|
|
|
|
// For ldap://, opportunistically upgrade to TLS unless explicitly configured as insecure.
|
|
if parsedURL.Scheme == "ldap" && !a.insecure {
|
|
if err := conn.StartTLS(tlsConfig); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err)
|
|
}
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) {
|
|
tlsConfig := &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags
|
|
}
|
|
|
|
if a.trustCertFile == "" {
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
caPEM, err := os.ReadFile(a.trustCertFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err)
|
|
}
|
|
roots := x509.NewCertPool()
|
|
if !roots.AppendCertsFromPEM(caPEM) {
|
|
return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig)
|
|
}
|
|
tlsConfig.RootCAs = roots
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string, userDNHint string) (*ldap.Entry, string, error) {
|
|
dnCandidates := make([]string, 0, 2)
|
|
if looksLikeDN(userDNHint) {
|
|
dnCandidates = append(dnCandidates, strings.TrimSpace(userDNHint))
|
|
}
|
|
if looksLikeDN(username) {
|
|
dnCandidates = append(dnCandidates, strings.TrimSpace(username))
|
|
}
|
|
seenDN := make(map[string]struct{}, len(dnCandidates))
|
|
for _, dn := range dnCandidates {
|
|
key := normalizeDN(dn)
|
|
if key == "" {
|
|
continue
|
|
}
|
|
if _, ok := seenDN[key]; ok {
|
|
continue
|
|
}
|
|
seenDN[key] = struct{}{}
|
|
|
|
searchRes, err := conn.Search(ldap.NewSearchRequest(
|
|
dn,
|
|
ldap.ScopeBaseObject,
|
|
ldap.NeverDerefAliases,
|
|
1,
|
|
0,
|
|
false,
|
|
"(objectClass=*)",
|
|
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
|
|
nil,
|
|
))
|
|
if err != nil {
|
|
if ldap.IsErrorWithCode(err, ldap.LDAPResultNoSuchObject) {
|
|
continue
|
|
}
|
|
return nil, "", fmt.Errorf("%w: unable to load user entry by dn: %v", ErrLDAPOperationFailed, err)
|
|
}
|
|
if len(searchRes.Entries) > 0 {
|
|
return searchRes.Entries[0], "dn", nil
|
|
}
|
|
}
|
|
|
|
for _, principal := range principalCandidates(username) {
|
|
if strings.Contains(principal, "@") {
|
|
entry, err := a.searchUserByAttribute(conn, "userPrincipalName", principal)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
if entry != nil {
|
|
return entry, "principal_upn", nil
|
|
}
|
|
// For UPN principals, avoid fallback attribute probes that are unlikely to match
|
|
// and can be expensive on large directory trees.
|
|
continue
|
|
}
|
|
|
|
entry, err := a.searchUserByAttribute(conn, "sAMAccountName", principal)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
if entry != nil {
|
|
return entry, "principal_samaccountname", nil
|
|
}
|
|
|
|
// Keep uid lookup as a fallback for non-AD LDAP directories.
|
|
entry, err = a.searchUserByAttribute(conn, "uid", principal)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
if entry != nil {
|
|
return entry, "principal_uid", nil
|
|
}
|
|
}
|
|
return nil, "", nil
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) searchUserByAttribute(conn *ldap.Conn, attribute string, value string) (*ldap.Entry, error) {
|
|
attribute = strings.TrimSpace(attribute)
|
|
value = strings.TrimSpace(value)
|
|
if attribute == "" || value == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
searchRes, err := conn.Search(ldap.NewSearchRequest(
|
|
a.userBaseDN,
|
|
ldap.ScopeWholeSubtree,
|
|
ldap.NeverDerefAliases,
|
|
2,
|
|
0,
|
|
false,
|
|
fmt.Sprintf("(%s=%s)", attribute, ldap.EscapeFilter(value)),
|
|
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
|
|
nil,
|
|
))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: user lookup failed (%s): %v", ErrLDAPOperationFailed, attribute, err)
|
|
}
|
|
if len(searchRes.Entries) == 0 {
|
|
return nil, nil
|
|
}
|
|
return searchRes.Entries[0], nil
|
|
}
|
|
|
|
// normalizeDN returns value trimmed of surrounding whitespace and lower-cased,
// the form used as a comparison key for DNs. Whitespace inside the DN is left
// untouched.
func normalizeDN(value string) string {
	trimmed := strings.TrimSpace(value)
	return strings.ToLower(trimmed)
}
|
|
|
|
// mapKeysSorted returns the keys of m in ascending order, or nil when m is
// nil or empty.
func mapKeysSorted[K ~string, V any](m map[K]V) []K {
	count := len(m)
	if count == 0 {
		return nil
	}

	keys := make([]K, count)
	next := 0
	for k := range m {
		keys[next] = k
		next++
	}
	sort.Slice(keys, func(a, b int) bool { return keys[a] < keys[b] })
	return keys
}
|
|
|
|
// firstNonEmpty returns the first of values that is non-blank, trimmed of
// surrounding whitespace, or "" when every value is blank.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if candidate := strings.TrimSpace(values[i]); candidate != "" {
			return candidate
		}
	}
	return ""
}
|
|
|
|
// looksLikeDN is a cheap heuristic, not a DN parser: it reports whether value
// contains both an "=" and a ",".
func looksLikeDN(value string) bool {
	candidate := strings.TrimSpace(value)
	if strings.IndexByte(candidate, '=') < 0 {
		return false
	}
	return strings.IndexByte(candidate, ',') >= 0
}
|
|
|
|
func parseWhoAmIDN(authzID string) string {
|
|
authzID = strings.TrimSpace(authzID)
|
|
if authzID == "" {
|
|
return ""
|
|
}
|
|
|
|
lower := strings.ToLower(authzID)
|
|
if strings.HasPrefix(lower, "dn:") {
|
|
authzID = strings.TrimSpace(authzID[3:])
|
|
}
|
|
if !looksLikeDN(authzID) {
|
|
return ""
|
|
}
|
|
return authzID
|
|
}
|
|
|
|
func normalizeBindUsername(username string, baseDN string) (string, bool) {
|
|
username = strings.TrimSpace(username)
|
|
if username == "" {
|
|
return "", false
|
|
}
|
|
if looksLikeDN(username) || strings.Contains(username, "@") {
|
|
return username, false
|
|
}
|
|
|
|
// Convert DOMAIN\user to user before UPN rewrite.
|
|
if idx := strings.LastIndex(username, `\`); idx >= 0 && idx < len(username)-1 {
|
|
username = strings.TrimSpace(username[idx+1:])
|
|
}
|
|
|
|
domain := upnDomainFromBaseDN(baseDN)
|
|
if domain == "" {
|
|
return username, false
|
|
}
|
|
if strings.Contains(username, "@") {
|
|
return username, false
|
|
}
|
|
return username + "@" + domain, true
|
|
}
|
|
|
|
// upnDomainFromBaseDN joins the values of the dc= components of baseDN with
// dots, e.g. "ou=people,dc=example,dc=com" -> "example.com". Components are
// split on ",", matched case-insensitively, and skipped when their value is
// blank. It returns "" when baseDN has no usable dc= component.
func upnDomainFromBaseDN(baseDN string) string {
	trimmed := strings.TrimSpace(baseDN)
	if trimmed == "" {
		return ""
	}

	const dcPrefix = "dc="
	var domain strings.Builder
	for _, rdn := range strings.Split(trimmed, ",") {
		rdn = strings.TrimSpace(rdn)
		if len(rdn) < len(dcPrefix) {
			continue
		}
		if !strings.EqualFold(rdn[:len(dcPrefix)], dcPrefix) {
			continue
		}
		label := strings.TrimSpace(rdn[len(dcPrefix):])
		if label == "" {
			continue
		}
		if domain.Len() > 0 {
			domain.WriteByte('.')
		}
		domain.WriteString(label)
	}
	return domain.String()
}
|
|
|
|
// principalCandidates lists the names to try when searching for a user, in
// priority order and without case-insensitive duplicates:
//
//  1. the trimmed username itself,
//  2. the part after the last backslash (DOMAIN\user -> user),
//  3. the part before the first "@" (user@domain -> user).
//
// Blank input yields nil.
func principalCandidates(username string) []string {
	username = strings.TrimSpace(username)
	if username == "" {
		return nil
	}

	variants := make([]string, 1, 3)
	variants[0] = username
	if slash := strings.LastIndex(username, `\`); slash >= 0 && slash < len(username)-1 {
		variants = append(variants, username[slash+1:])
	}
	if at := strings.Index(username, "@"); at > 0 {
		variants = append(variants, username[:at])
	}

	known := make(map[string]struct{}, 4)
	result := make([]string, 0, 4)
	for _, variant := range variants {
		variant = strings.TrimSpace(variant)
		if variant == "" {
			continue
		}
		key := strings.ToLower(variant)
		if _, dup := known[key]; dup {
			continue
		}
		known[key] = struct{}{}
		result = append(result, variant)
	}
	return result
}
|
|
|
|
func buildGroupMembershipFilter(userDN string, principals []string) string {
|
|
clauses := make([]string, 0, 2+len(principals))
|
|
userDN = strings.TrimSpace(userDN)
|
|
if userDN != "" {
|
|
escapedDN := ldap.EscapeFilter(userDN)
|
|
clauses = append(clauses, "(member="+escapedDN+")", "(uniqueMember="+escapedDN+")")
|
|
}
|
|
for _, principal := range principals {
|
|
principal = strings.TrimSpace(principal)
|
|
if principal == "" {
|
|
continue
|
|
}
|
|
clauses = append(clauses, "(memberUid="+ldap.EscapeFilter(principal)+")")
|
|
}
|
|
if len(clauses) == 0 {
|
|
return "(objectClass=group)"
|
|
}
|
|
return "(|" + strings.Join(clauses, "") + ")"
|
|
}
|
|
|
|
// ctxErr is a non-blocking cancellation check: it returns ctx.Err() if ctx is
// already done, and nil otherwise. A nil ctx is tolerated and never reports
// an error.
func ctxErr(ctx context.Context) error {
	if ctx != nil {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
	}
	return nil
}
|