Files
vctp2/internal/auth/ldap.go
T
Nathan Coad 361ba7719b
continuous-integration/drone/push Build is passing
more auth logging
2026-04-21 10:35:10 +10:00

366 lines
9.1 KiB
Go

package auth
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/url"
"os"
"sort"
"strings"
"time"
"github.com/go-ldap/ldap/v3"
)
// Sentinel errors returned by this package; callers should match them with errors.Is.
var (
	// ErrInvalidLDAPConfig reports a missing or malformed LDAP setting, such as
	// an empty bind address or base DN, an unparsable bind URL, or an unreadable
	// or empty trust certificate file.
	ErrInvalidLDAPConfig = errors.New("invalid ldap config")

	// ErrLDAPInvalidCredentials reports an empty username or password, or a bind
	// that the server refused with the invalidCredentials result code.
	ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials")

	// ErrLDAPOperationFailed reports any other directory failure: connecting,
	// StartTLS, a bind error other than bad credentials, or a user lookup search.
	ErrLDAPOperationFailed = errors.New("ldap operation failed")
)
// LDAPConfig holds the settings that NewLDAPAuthenticator validates and copies
// into an LDAPAuthenticator. String fields are whitespace-trimmed there.
type LDAPConfig struct {
	// BindAddress is the directory server URL (for example ldaps://dc1:636).
	// BaseDN is the search root for user and group lookups. TrustCertFile, when
	// non-empty, names a PEM file whose certificates become the trusted roots
	// for the server certificate.
	BindAddress, BaseDN, TrustCertFile string

	// DisableValidation turns off TLS certificate verification. Insecure does the
	// same and additionally skips the StartTLS upgrade on ldap:// URLs.
	DisableValidation, Insecure bool

	// DialTimeout bounds connection setup and is reused as the per-request
	// timeout; zero or negative values fall back to 10 seconds.
	DialTimeout time.Duration
}
// LDAPIdentity describes a user after a successful LDAP bind.
type LDAPIdentity struct {
	// Username is the resolved account name and UserDN the user's entry DN;
	// both start out as the name the caller bound with.
	Username, UserDN string

	// Groups lists group DNs the user belongs to. Diagnostics contains
	// non-sensitive LDAP processing notes useful for debugging auth decisions.
	Groups, Diagnostics []string
}
// LDAPAuthenticator authenticates users against one LDAP directory. Construct
// it with NewLDAPAuthenticator; it stores only normalized configuration and
// opens a fresh connection for each authentication.
type LDAPAuthenticator struct {
	bindAddress, baseDN, trustCertFile string
	disableValidation, insecure        bool
	dialTimeout                        time.Duration
}
// NewLDAPAuthenticator validates cfg and returns an authenticator built from it.
//
// BindAddress, BaseDN and TrustCertFile are whitespace-trimmed. BindAddress and
// BaseDN are required. BindAddress must be an absolute URL whose scheme
// ldap.DialURL can dial (ldap, ldaps, ldapi or cldap). That way a mistake such
// as an http:// URL or a scheme-less path is reported here, not as a connection
// failure on the first login. A non-positive DialTimeout defaults to 10 seconds.
// Every configuration problem is wrapped in ErrInvalidLDAPConfig.
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
	bindAddress := strings.TrimSpace(cfg.BindAddress)
	baseDN := strings.TrimSpace(cfg.BaseDN)
	trustCertFile := strings.TrimSpace(cfg.TrustCertFile)
	if bindAddress == "" {
		return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig)
	}
	if baseDN == "" {
		return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig)
	}
	parsed, err := url.ParseRequestURI(bindAddress)
	if err != nil {
		return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err)
	}
	// url parsing lower-cases the scheme, so "LDAPS://" is accepted as well.
	switch parsed.Scheme {
	case "ldap", "ldaps", "ldapi", "cldap":
	default:
		return nil, fmt.Errorf("%w: bind address scheme %q is not supported (use ldap:// or ldaps://)", ErrInvalidLDAPConfig, parsed.Scheme)
	}
	dialTimeout := cfg.DialTimeout
	if dialTimeout <= 0 {
		dialTimeout = 10 * time.Second
	}
	return &LDAPAuthenticator{
		bindAddress:       bindAddress,
		baseDN:            baseDN,
		trustCertFile:     trustCertFile,
		disableValidation: cfg.DisableValidation,
		insecure:          cfg.Insecure,
		dialTimeout:       dialTimeout,
	}, nil
}
// AuthenticateAndFetchGroups verifies username/password with a simple bind.
// It then uses the same connection to resolve the user's identity and group DNs.
//
// The username is whitespace-trimmed. An empty username or password is rejected
// with ErrLDAPInvalidCredentials before the server is contacted. A bind refused
// with the invalidCredentials result code also yields ErrLDAPInvalidCredentials.
// Connection, other bind, and user-lookup failures are wrapped in
// ErrLDAPOperationFailed. ctx is checked before connecting and again after the
// bind, but it does not interrupt an LDAP request already in flight.
//
// Groups is the sorted union of two sources:
//   - the user entry's memberOf values;
//   - the DNs of entries under the base DN that list the user through member or
//     uniqueMember (by user DN) or memberUid (by bind name or the entry's uid).
//
// The group search is best-effort: a failure is recorded in Diagnostics rather
// than returned.
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
	username = strings.TrimSpace(username)
	// An empty password would make this an unauthenticated bind (RFC 4513
	// §5.1.2) rather than a credential check, so refuse it locally.
	if username == "" || password == "" {
		return LDAPIdentity{}, ErrLDAPInvalidCredentials
	}
	if err := ctxErr(ctx); err != nil {
		return LDAPIdentity{}, err
	}
	conn, err := a.connect()
	if err != nil {
		return LDAPIdentity{}, err
	}
	defer conn.Close()

	if err := conn.Bind(username, password); err != nil {
		if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
			return LDAPIdentity{}, fmt.Errorf("%w: ldap bind rejected credentials", ErrLDAPInvalidCredentials)
		}
		return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err)
	}
	if err := ctxErr(ctx); err != nil {
		return LDAPIdentity{}, err
	}

	identity := LDAPIdentity{
		Username: username,
		UserDN:   username,
	}
	entry, err := a.lookupUserEntry(conn, username)
	if err != nil {
		return LDAPIdentity{}, err
	}

	groupSet := make(map[string]struct{})
	memberUIDs := []string{username}
	if entry != nil {
		identity.Diagnostics = append(identity.Diagnostics, "user_entry_found")
		if strings.TrimSpace(entry.DN) != "" {
			identity.UserDN = entry.DN
		}
		if v := firstNonEmpty(
			entry.GetAttributeValue("uid"),
			entry.GetAttributeValue("sAMAccountName"),
			entry.GetAttributeValue("userPrincipalName"),
			entry.GetAttributeValue("cn"),
		); v != "" {
			identity.Username = v
		}
		for _, groupDN := range entry.GetAttributeValues("memberOf") {
			groupDN = strings.TrimSpace(groupDN)
			if groupDN == "" {
				continue
			}
			groupSet[groupDN] = struct{}{}
		}
		// RFC 2307 posixGroup entries list members by uid in memberUid. A user who
		// bound with a full DN would never match on the bind name alone.
		memberUIDs = append(memberUIDs, entry.GetAttributeValue("uid"))
	} else {
		identity.Diagnostics = append(identity.Diagnostics, "user_entry_not_found")
	}

	groupEntries, err := conn.Search(ldap.NewSearchRequest(
		a.baseDN,
		ldap.ScopeWholeSubtree,
		ldap.NeverDerefAliases,
		0,
		0,
		false,
		ldapGroupMembershipFilter(identity.UserDN, memberUIDs),
		[]string{"dn"},
		nil,
	))
	if err == nil {
		for _, e := range groupEntries.Entries {
			if dn := strings.TrimSpace(e.DN); dn != "" {
				groupSet[dn] = struct{}{}
			}
		}
		if len(groupEntries.Entries) == 0 {
			identity.Diagnostics = append(identity.Diagnostics, "group_search_returned_no_entries")
		}
	} else {
		identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("group_search_failed:%v", err))
	}

	identity.Groups = mapKeysSorted(groupSet)
	identity.Diagnostics = compactTrimmedStrings(identity.Diagnostics)
	return identity, nil
}

// ldapGroupMembershipFilter builds the OR filter used to find group entries that
// list a user. It has member and uniqueMember clauses for userDN, plus one
// memberUid clause per distinct non-blank value in uids. Every value is
// filter-escaped.
func ldapGroupMembershipFilter(userDN string, uids []string) string {
	escapedDN := ldap.EscapeFilter(userDN)

	var uidClauses strings.Builder
	seen := make(map[string]struct{}, len(uids))
	for _, uid := range uids {
		uid = strings.TrimSpace(uid)
		if uid == "" {
			continue
		}
		if _, dup := seen[uid]; dup {
			continue
		}
		seen[uid] = struct{}{}
		uidClauses.WriteString("(memberUid=")
		uidClauses.WriteString(ldap.EscapeFilter(uid))
		uidClauses.WriteString(")")
	}

	return fmt.Sprintf("(|(member=%s)(uniqueMember=%s)%s)", escapedDN, escapedDN, uidClauses.String())
}
// ResolveRoles maps a user's group DNs to application roles.
//
// groupRoleMappings is keyed by group DN. Both the keys and the user's group DNs
// are compared after normalizeDN (trimmed and lower-cased). Roles are trimmed
// and lower-cased, and mappings with a blank key or role are ignored.
//
// Every mapping whose group the user belongs to contributes its role. So if two
// configured keys differ only in case or surrounding whitespace, both roles
// apply. Otherwise one would overwrite the other depending on map iteration
// order, and the outcome would be nondeterministic.
//
// The result is sorted and de-duplicated. It is nil when either input is empty
// or nothing matches.
func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string {
	if len(groupDNs) == 0 || len(groupRoleMappings) == 0 {
		return nil
	}
	rolesByGroup := make(map[string][]string, len(groupRoleMappings))
	for groupDN, role := range groupRoleMappings {
		groupDN = normalizeDN(groupDN)
		role = strings.ToLower(strings.TrimSpace(role))
		if groupDN == "" || role == "" {
			continue
		}
		rolesByGroup[groupDN] = append(rolesByGroup[groupDN], role)
	}
	roleSet := make(map[string]struct{})
	for _, groupDN := range groupDNs {
		for _, role := range rolesByGroup[normalizeDN(groupDN)] {
			roleSet[role] = struct{}{}
		}
	}
	return mapKeysSorted(roleSet)
}
// HasAnyGroup reports whether the user's groupDNs share at least one DN with
// requiredGroupDNs. DNs are compared via normalizeDN (trimmed, lower-cased).
// The required list is first passed through compactTrimmedStrings. If nothing
// remains, there is no requirement and the result is true. A user with no
// groups fails any non-empty requirement.
func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool {
	requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs)
	switch {
	case len(requiredGroupDNs) == 0:
		return true
	case len(groupDNs) == 0:
		return false
	}

	wanted := make(map[string]bool, len(requiredGroupDNs))
	for _, dn := range requiredGroupDNs {
		wanted[normalizeDN(dn)] = true
	}
	for _, dn := range groupDNs {
		if wanted[normalizeDN(dn)] {
			return true
		}
	}
	return false
}
// connect dials the configured bind address and returns a connection whose
// per-request timeout equals the dial timeout. For ldap:// URLs the connection
// is upgraded with StartTLS unless the authenticator is configured as insecure.
// TLS and URL configuration problems are wrapped in ErrInvalidLDAPConfig (or
// returned as-is from buildTLSConfig). Dial and StartTLS failures are wrapped
// in ErrLDAPOperationFailed.
func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
	tlsConfig, err := a.buildTLSConfig()
	if err != nil {
		return nil, err
	}
	parsedURL, err := url.Parse(a.bindAddress)
	if err != nil {
		return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err)
	}
	// crypto/tls refuses to verify a certificate when ServerName is empty. The
	// ldaps:// dial derives it from the address, but StartTLS hands this config
	// to the TLS client unchanged. Without this, every verified ldap:// upgrade
	// would fail. buildTLSConfig allocates a fresh config per call, so setting
	// it here affects only this connection.
	if tlsConfig.ServerName == "" {
		tlsConfig.ServerName = parsedURL.Hostname()
	}
	options := []ldap.DialOpt{
		ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}),
		ldap.DialWithTLSConfig(tlsConfig),
	}
	conn, err := ldap.DialURL(a.bindAddress, options...)
	if err != nil {
		return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err)
	}
	conn.SetTimeout(a.dialTimeout)
	// For ldap://, upgrade to TLS unless explicitly configured as insecure.
	if parsedURL.Scheme == "ldap" && !a.insecure {
		if err := conn.StartTLS(tlsConfig); err != nil {
			conn.Close()
			return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err)
		}
	}
	return conn, nil
}
// buildTLSConfig returns a new TLS client configuration that requires TLS 1.2 or
// later. Certificate verification is skipped when either insecure or
// disableValidation is set. When trustCertFile is set, the PEM certificates it
// contains become RootCAs, which replaces the host's default roots. A file that
// cannot be read, or that holds no certificates, yields an ErrInvalidLDAPConfig
// error.
func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}
	cfg.InsecureSkipVerify = a.insecure || a.disableValidation //nolint:gosec // controlled by explicit config flags

	if a.trustCertFile != "" {
		bundle, err := os.ReadFile(a.trustCertFile)
		if err != nil {
			return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err)
		}
		pool := x509.NewCertPool()
		if ok := pool.AppendCertsFromPEM(bundle); !ok {
			return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig)
		}
		cfg.RootCAs = pool
	}
	return cfg, nil
}
// lookupUserEntry fetches the directory entry for username on an already-bound
// connection. It requests only the identity attributes and memberOf.
//
// A DN-shaped username (see looksLikeDN) is read directly with a base-object
// search. Any other value is matched against uid, cn, sAMAccountName and
// userPrincipalName anywhere under the base DN. It returns (nil, nil) when no
// entry matches, and wraps search errors in ErrLDAPOperationFailed.
//
// NOTE(review): the subtree search asks for at most two entries but always
// returns the first. An ambiguous name therefore resolves to whichever entry
// the server lists first — confirm this is acceptable.
func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string) (*ldap.Entry, error) {
	wanted := []string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"}

	var (
		request *ldap.SearchRequest
		wrap    func(error) error
	)
	if looksLikeDN(username) {
		request = ldap.NewSearchRequest(
			username, ldap.ScopeBaseObject, ldap.NeverDerefAliases,
			1, 0, false,
			"(objectClass=*)",
			wanted, nil,
		)
		wrap = func(err error) error {
			return fmt.Errorf("%w: unable to load user entry: %v", ErrLDAPOperationFailed, err)
		}
	} else {
		v := ldap.EscapeFilter(username)
		request = ldap.NewSearchRequest(
			a.baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
			2, 0, false,
			fmt.Sprintf("(|(uid=%s)(cn=%s)(sAMAccountName=%s)(userPrincipalName=%s))", v, v, v, v),
			wanted, nil,
		)
		wrap = func(err error) error {
			return fmt.Errorf("%w: user lookup failed: %v", ErrLDAPOperationFailed, err)
		}
	}

	result, err := conn.Search(request)
	if err != nil {
		return nil, wrap(err)
	}
	if len(result.Entries) == 0 {
		return nil, nil
	}
	return result.Entries[0], nil
}
// normalizeDN returns value with surrounding whitespace removed and lower-cased,
// for case-insensitive DN comparison. It does not canonicalize spacing inside the DN.
func normalizeDN(value string) string {
	trimmed := strings.TrimSpace(value)
	return strings.ToLower(trimmed)
}
// mapKeysSorted returns the keys of m in ascending order, or nil when m is empty.
func mapKeysSorted[K ~string, V any](m map[K]V) []K {
	var keys []K
	if len(m) > 0 {
		keys = make([]K, 0, len(m))
		for k := range m {
			keys = append(keys, k)
		}
		sort.Slice(keys, func(x, y int) bool { return keys[x] < keys[y] })
	}
	return keys
}
// firstNonEmpty returns the first argument that is not blank, with its
// surrounding whitespace removed, or "" when every argument is blank.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); len(trimmed) > 0 {
			return trimmed
		}
	}
	return ""
}
// looksLikeDN is a cheap heuristic for "this username is a distinguished name":
// after trimming, the value must contain both an '=' and a ','. Single-RDN DNs
// such as "cn=admin" are therefore not recognized.
func looksLikeDN(value string) bool {
	v := strings.TrimSpace(value)
	if !strings.Contains(v, "=") {
		return false
	}
	return strings.Contains(v, ",")
}
// ctxErr reports ctx's cancellation error without blocking. It returns nil when
// ctx is nil or not yet done. Per the context.Context contract, Err is non-nil
// exactly when Done is closed.
func ctxErr(ctx context.Context) error {
	if ctx != nil {
		return ctx.Err()
	}
	return nil
}