diff --git a/main.go b/main.go index 40c9ee1..62aecfc 100644 --- a/main.go +++ b/main.go @@ -147,12 +147,10 @@ func main() { models.ReceiveKey(keyString) } - // Load certificate for LDAP connectivy - models.LoadLdapCert() - + // If LDAP configuration is defined, prepare connection ldapServer := os.Getenv("LDAP_BIND_ADDRESS") if ldapServer != "" { - models.LdapEnabled = true + models.LdapSetup() } // Create context that listens for the interrupt signal from the OS. diff --git a/models/ldap.go b/models/ldap.go index bd40df7..f47fc01 100644 --- a/models/ldap.go +++ b/models/ldap.go @@ -22,8 +22,10 @@ type LdapConfig struct { } var systemCA *x509.CertPool +var ldaps *ldap.Conn var CertLoaded bool var LdapEnabled bool +var LdapBaseDn string func GetFilePath(path string) string { // Check for empty filename @@ -46,7 +48,7 @@ func GetFilePath(path string) string { return path } -func LoadLdapCert() { +func loadLdapCert() { var err error // Get a copy of the system defined CA's systemCA, err = x509.SystemCertPool() @@ -85,9 +87,12 @@ func LoadLdapCert() { } } -func VerifyLdapCreds(username string, password string) bool { - var ldaps *ldap.Conn +func LdapSetup() bool { var err error + + // Load LDAP certificate if necessary + loadLdapCert() + ldapServer := os.Getenv("LDAP_BIND_ADDRESS") if ldapServer == "" { log.Printf("VerifyLdapCreds no LDAP bind address supplied\n") @@ -96,8 +101,8 @@ func VerifyLdapCreds(username string, password string) bool { LdapEnabled = true } - ldapBaseDn := os.Getenv("LDAP_BASE_DN") - if ldapBaseDn == "" { + LdapBaseDn = os.Getenv("LDAP_BASE_DN") + if LdapBaseDn == "" { log.Printf("VerifyLdapCreds no LDAP base DN supplied\n") return false } @@ -123,10 +128,48 @@ func VerifyLdapCreds(username string, password string) bool { defer ldaps.Close() - //ldaps.Debug = true + LdapEnabled = true - // try to bind to AD - log.Printf("Attempting LDAP bind with user '%s' and password '%s'\n", username, password) + LookupNamingContext() + + return true +} + +func LookupNamingContext() { + // Retrieve the defaultNamingContext + searchRequest := ldap.NewSearchRequest( + "", + ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false, + "(objectClass=*)", + []string{"defaultNamingContext"}, + nil, + ) + + searchResult, err := ldaps.Search(searchRequest) + if err != nil { + log.Printf("LookupNamingContext unable to perform unauthenticated search : '%s'\n", err) + return + } + + if len(searchResult.Entries) != 1 { + log.Printf("LookupNamingContext unable to retrieve defaultNamingContext\n") + return + } + + defaultNamingContext := searchResult.Entries[0].GetAttributeValue("defaultNamingContext") + if defaultNamingContext == "" { + log.Printf("LookupNamingContext defaultNamingContext attribute not found\n") + return + } + + log.Printf("Default Naming Context: '%s'\n", defaultNamingContext) +} + +func VerifyLdapCreds(username string, password string) bool { + var err error + + // try an authenticated bind to AD to verify credentials + log.Printf("Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password)) err = ldaps.Bind(username, password) if err != nil { if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials { @@ -140,9 +183,9 @@ func VerifyLdapCreds(username string, password string) bool { log.Printf("VerifyLdapCreds successfully bound to LDAP\n") } - log.Printf("Attempting LDAP search request from base DN '%s'\n", ldapBaseDn) + log.Printf("Attempting LDAP search request from base DN '%s'\n", LdapBaseDn) searchReq := ldap.NewSearchRequest( - ldapBaseDn, + LdapBaseDn, ldap.ScopeWholeSubtree, // you can also use ldap.ScopeWholeSubtree ldap.NeverDerefAliases, 0, @@ -162,3 +205,44 @@ func VerifyLdapCreds(username string, password string) bool { return true } + +// GetGroupsOfUser returns the group for a user. +// Taken from https://github.com/jtblin/go-ldap-client/issues/13#issuecomment-456090979 +func GetGroupsOfUser(username string, baseDN string, conn *ldap.Conn) ([]string, error) { + var samAccountName string + var groups []string + + if strings.Contains(username, "@") { + s := strings.Split(username, "@") + samAccountName = s[0] + } else if strings.Contains(username, "\\") { + s := strings.Split(username, "\\") + samAccountName = s[len(s)-1] + } else { + samAccountName = username + } + + // Get the users DN + // Search for the given username + searchRequest := ldap.NewSearchRequest( + baseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + fmt.Sprintf("(CN=%s)", ldap.EscapeFilter(samAccountName)), + []string{}, + nil, + ) + + sr, err := conn.Search(searchRequest) + if err != nil { + return nil, err + } + + if len(sr.Entries) != 1 { + return nil, fmt.Errorf("user '%s' does not exist", samAccountName) + } else { + // Get the groups of the first result + groups = sr.Entries[0].GetAttributeValues("memberOf") + } + + return groups, nil +}