@@ -0,0 +1,354 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidLDAPConfig = errors.New("invalid ldap config")
|
||||
ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials")
|
||||
ErrLDAPOperationFailed = errors.New("ldap operation failed")
|
||||
)
|
||||
|
||||
type LDAPConfig struct {
|
||||
BindAddress string
|
||||
BaseDN string
|
||||
TrustCertFile string
|
||||
DisableValidation bool
|
||||
Insecure bool
|
||||
DialTimeout time.Duration
|
||||
}
|
||||
|
||||
type LDAPIdentity struct {
|
||||
Username string
|
||||
UserDN string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
type LDAPAuthenticator struct {
|
||||
bindAddress string
|
||||
baseDN string
|
||||
trustCertFile string
|
||||
disableValidation bool
|
||||
insecure bool
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
|
||||
bindAddress := strings.TrimSpace(cfg.BindAddress)
|
||||
baseDN := strings.TrimSpace(cfg.BaseDN)
|
||||
trustCertFile := strings.TrimSpace(cfg.TrustCertFile)
|
||||
|
||||
if bindAddress == "" {
|
||||
return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig)
|
||||
}
|
||||
if baseDN == "" {
|
||||
return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig)
|
||||
}
|
||||
if _, err := url.ParseRequestURI(bindAddress); err != nil {
|
||||
return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
|
||||
dialTimeout := cfg.DialTimeout
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = 10 * time.Second
|
||||
}
|
||||
|
||||
return &LDAPAuthenticator{
|
||||
bindAddress: bindAddress,
|
||||
baseDN: baseDN,
|
||||
trustCertFile: trustCertFile,
|
||||
disableValidation: cfg.DisableValidation,
|
||||
insecure: cfg.Insecure,
|
||||
dialTimeout: dialTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" || password == "" {
|
||||
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
||||
}
|
||||
if err := ctxErr(ctx); err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
|
||||
conn, err := a.connect()
|
||||
if err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.Bind(username, password); err != nil {
|
||||
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
|
||||
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
||||
}
|
||||
return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if err := ctxErr(ctx); err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
|
||||
identity := LDAPIdentity{
|
||||
Username: username,
|
||||
UserDN: username,
|
||||
}
|
||||
|
||||
entry, err := a.lookupUserEntry(conn, username)
|
||||
if err != nil {
|
||||
return LDAPIdentity{}, err
|
||||
}
|
||||
if entry != nil {
|
||||
if strings.TrimSpace(entry.DN) != "" {
|
||||
identity.UserDN = entry.DN
|
||||
}
|
||||
if v := firstNonEmpty(
|
||||
entry.GetAttributeValue("uid"),
|
||||
entry.GetAttributeValue("sAMAccountName"),
|
||||
entry.GetAttributeValue("userPrincipalName"),
|
||||
entry.GetAttributeValue("cn"),
|
||||
); v != "" {
|
||||
identity.Username = v
|
||||
}
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
if entry != nil {
|
||||
for _, groupDN := range entry.GetAttributeValues("memberOf") {
|
||||
groupDN = strings.TrimSpace(groupDN)
|
||||
if groupDN == "" {
|
||||
continue
|
||||
}
|
||||
groupSet[groupDN] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
groupEntries, err := conn.Search(ldap.NewSearchRequest(
|
||||
a.baseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
fmt.Sprintf("(|(member=%s)(uniqueMember=%s)(memberUid=%s))",
|
||||
ldap.EscapeFilter(identity.UserDN),
|
||||
ldap.EscapeFilter(identity.UserDN),
|
||||
ldap.EscapeFilter(username),
|
||||
),
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
))
|
||||
if err == nil {
|
||||
for _, e := range groupEntries.Entries {
|
||||
if dn := strings.TrimSpace(e.DN); dn != "" {
|
||||
groupSet[dn] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
identity.Groups = mapKeysSorted(groupSet)
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string {
|
||||
if len(groupDNs) == 0 || len(groupRoleMappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalizedMappings := make(map[string]string, len(groupRoleMappings))
|
||||
for groupDN, role := range groupRoleMappings {
|
||||
groupDN = normalizeDN(groupDN)
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
if groupDN == "" || role == "" {
|
||||
continue
|
||||
}
|
||||
normalizedMappings[groupDN] = role
|
||||
}
|
||||
|
||||
roleSet := make(map[string]struct{})
|
||||
for _, groupDN := range groupDNs {
|
||||
if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok {
|
||||
roleSet[role] = struct{}{}
|
||||
}
|
||||
}
|
||||
return mapKeysSorted(roleSet)
|
||||
}
|
||||
|
||||
func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool {
|
||||
requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs)
|
||||
if len(requiredGroupDNs) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(groupDNs) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
required := make(map[string]struct{}, len(requiredGroupDNs))
|
||||
for _, groupDN := range requiredGroupDNs {
|
||||
required[normalizeDN(groupDN)] = struct{}{}
|
||||
}
|
||||
for _, groupDN := range groupDNs {
|
||||
if _, ok := required[normalizeDN(groupDN)]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) {
|
||||
tlsConfig, err := a.buildTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(a.bindAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
|
||||
options := []ldap.DialOpt{
|
||||
ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}),
|
||||
ldap.DialWithTLSConfig(tlsConfig),
|
||||
}
|
||||
conn, err := ldap.DialURL(a.bindAddress, options...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
conn.SetTimeout(a.dialTimeout)
|
||||
|
||||
// For ldap://, opportunistically upgrade to TLS unless explicitly configured as insecure.
|
||||
if parsedURL.Scheme == "ldap" && !a.insecure {
|
||||
if err := conn.StartTLS(tlsConfig); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags
|
||||
}
|
||||
|
||||
if a.trustCertFile == "" {
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
caPEM, err := os.ReadFile(a.trustCertFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err)
|
||||
}
|
||||
roots := x509.NewCertPool()
|
||||
if !roots.AppendCertsFromPEM(caPEM) {
|
||||
return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig)
|
||||
}
|
||||
tlsConfig.RootCAs = roots
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string) (*ldap.Entry, error) {
|
||||
if looksLikeDN(username) {
|
||||
searchRes, err := conn.Search(ldap.NewSearchRequest(
|
||||
username,
|
||||
ldap.ScopeBaseObject,
|
||||
ldap.NeverDerefAliases,
|
||||
1,
|
||||
0,
|
||||
false,
|
||||
"(objectClass=*)",
|
||||
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
|
||||
nil,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: unable to load user entry: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if len(searchRes.Entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return searchRes.Entries[0], nil
|
||||
}
|
||||
|
||||
searchRes, err := conn.Search(ldap.NewSearchRequest(
|
||||
a.baseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
2,
|
||||
0,
|
||||
false,
|
||||
fmt.Sprintf("(|(uid=%s)(cn=%s)(sAMAccountName=%s)(userPrincipalName=%s))",
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username),
|
||||
),
|
||||
[]string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"},
|
||||
nil,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: user lookup failed: %v", ErrLDAPOperationFailed, err)
|
||||
}
|
||||
if len(searchRes.Entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return searchRes.Entries[0], nil
|
||||
}
|
||||
|
||||
func normalizeDN(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func mapKeysSorted[K ~string, V any](m map[K]V) []K {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]K, 0, len(m))
|
||||
for key := range m {
|
||||
out = append(out, key)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
return out[i] < out[j]
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func looksLikeDN(value string) bool {
|
||||
value = strings.TrimSpace(value)
|
||||
return strings.Contains(value, "=") && strings.Contains(value, ",")
|
||||
}
|
||||
|
||||
func ctxErr(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user