package models import ( "crypto/tls" "crypto/x509" "encoding/pem" "errors" "fmt" "log" "os" "path/filepath" "strconv" "strings" "github.com/go-ldap/ldap/v3" ) // Code relating to AD integration type LdapConfig struct { LdapBindAddress string LdapBaseDn string LdapCertFile string } var systemCA *x509.CertPool // var ldaps *ldap.Conn var LdapServer string var CertLoaded bool var LdapEnabled bool var LdapInsecure bool = false var LdapBaseDn string var DefaultDomainSuffix string func GetFilePath(path string) string { // Check for empty filename if len(path) == 0 { return "" } // check if filename exists if _, err := os.Stat(path); os.IsNotExist((err)) { log.Printf("File '%s' not found, searching in same directory as binary\n", path) // if not, check that it exists in the same directory as the currently executing binary ex, err2 := os.Executable() if err2 != nil { //log.Printf("Error determining binary path : '%s'", err) return "" } binaryPath := filepath.Dir(ex) path = filepath.Join(binaryPath, path) } return path } // DomainSuffixFromNamingContext will convert DC=example,DC=com to example.com func DomainSuffixFromNamingContext(input string) string { tokens := strings.Split(input, ",") var args []string for _, token := range tokens { parts := strings.Split(token, "=") if len(parts) == 2 && parts[0] == "DC" { args = append(args, parts[1]) } } return strings.Join(args, ".") } func CheckUsername(username string) string { if strings.ContainsAny(username, "/@") { // Username contains forward slash or at symbol return username } // Append suffix to the username log.Printf("CheckUsername appending default domain suffix '%s'\n", DefaultDomainSuffix) return username + "@" + DefaultDomainSuffix } func loadLdapCert() { var err error // Get a copy of the system defined CA's systemCA, err = x509.SystemCertPool() if err != nil { log.Printf("LoadLdapCert error getting system certificate pool : '%s'\n", err) return } // only try to load certificate from file if the command line 
argument was specified ldapCertFile := os.Getenv("LDAP_TRUST_CERT_FILE") if ldapCertFile == "" { log.Printf("LoadLdapCert no certificate specified\n") return } else { // Try to read the file cf, err := os.ReadFile(GetFilePath(ldapCertFile)) if err != nil { log.Printf("LoadLdapCert error opening LDAP certificate file '%s' : '%s'\n", ldapCertFile, err) return } // Get the certificate from the file cpb, _ := pem.Decode(cf) crt, err := x509.ParseCertificate(cpb.Bytes) //log.Printf("Loaded certificate with subject %s\n", crt.Subject) if err != nil { log.Printf("LoadLdapCert error processing LDAP certificate file '%s' : '%s'\n", ldapCertFile, err) return } // Add custom certificate to the system cert pool systemCA.AddCert(crt) CertLoaded = true } } func LdapSetup() bool { var err error // Load LDAP certificate if necessary loadLdapCert() LdapServer = os.Getenv("LDAP_BIND_ADDRESS") if LdapServer == "" { log.Printf("VerifyLdapCreds no LDAP bind address supplied\n") return false } else { LdapEnabled = true } LdapBaseDn = os.Getenv("LDAP_BASE_DN") if LdapBaseDn == "" { log.Printf("VerifyLdapCreds no LDAP base DN supplied\n") return false } insecure := os.Getenv("LDAP_INSECURE_VALIDATION") if insecure != "" { LdapInsecure, err = strconv.ParseBool(insecure) if err != nil { log.Printf("LdapSetup could not convert environment variable LDAP_INSECURE_VALIDATION with value of '%s'\n", insecure) } } // Set up TLS to use our custom certificate authority passed in cli argument tlsConfig := &tls.Config{ RootCAs: systemCA, InsecureSkipVerify: true, } // Add port if not specified in .env file if !(strings.HasSuffix(LdapServer, ":636")) { LdapServer = fmt.Sprintf("%s:636", LdapServer) log.Printf("VerifyLdapCreds updated ldapServer string '%s'\n", LdapServer) } // try connecting to AD via TLS and our custom certificate authority ldaps, err := ldap.DialTLS("tcp", LdapServer, tlsConfig) if err != nil { log.Printf("VerifyLdapCreds error connecting to LDAP bind address '%s' : '%s'\n", 
LdapServer, err) return false } LdapEnabled = true namingContext := LookupNamingContext(ldaps) if namingContext != "" { DefaultDomainSuffix = DomainSuffixFromNamingContext(namingContext) } ldaps.Close() return true } // LdapConnect sets up the connection to LDAP to be used by other functions func ldapConnect() *ldap.Conn { // Set up TLS to use our custom certificate authority passed in cli argument tlsConfig := &tls.Config{ RootCAs: systemCA, InsecureSkipVerify: LdapInsecure, } log.Printf("ldapConnect initiating connection\n") ldaps, err := ldap.DialTLS("tcp", LdapServer, tlsConfig) if err != nil { log.Printf("VerifyLdapCreds error connecting to LDAP server '%s' : '%s'\n", LdapServer, err) return nil } log.Printf("ldapConnect connection succeeded\n") return ldaps } func LookupNamingContext(ldaps *ldap.Conn) string { // Retrieve the defaultNamingContext searchRequest := ldap.NewSearchRequest( "", ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false, "(objectClass=*)", []string{"defaultNamingContext"}, nil, ) searchResult, err := ldaps.Search(searchRequest) if err != nil { log.Printf("LookupNamingContext unable to perform unauthenticated search : '%s'\n", err) return "" } if len(searchResult.Entries) != 1 { log.Printf("LookupNamingContext unable to retrieve defaultNamingContext\n") return "" } defaultNamingContext := searchResult.Entries[0].GetAttributeValue("defaultNamingContext") if defaultNamingContext == "" { log.Printf("LookupNamingContext defaultNamingContext attribute not found\n") return "" } log.Printf("Default Naming Context: '%s'\n", defaultNamingContext) return defaultNamingContext } // LdapGetGroupMembership returns a list of distinguishedNames for groups that a user is a member of func LdapGetGroupMembership(username string, password string) ([]string, error) { var err error username = CheckUsername(username) ldaps := ldapConnect() defer ldaps.Close() // try an authenticated bind to AD to verify credentials log.Printf("LdapGetGroupMembership 
Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password)) err = ldaps.Bind(username, password) if err != nil { if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials { errString := "invalid user credentials" log.Print(errString) return nil, errors.New(errString) } else { errString := fmt.Sprintf("LdapGetGroupMembership error binding to LDAP with supplied credentials : '%s'\n", err) log.Print(errString) return nil, errors.New(errString) } } else { log.Printf("LdapGetGroupMembership successfully bound to LDAP\n") } groups, err := GetGroupsOfUser(username, LdapBaseDn, ldaps) if err != nil { errString := fmt.Sprintf("LdapGetGroupMembership group search error : '%s'\n", err) log.Print(errString) return nil, errors.New(errString) } return groups, nil } // VerifyLdapCreds validates that we can bind successfully to LDAP with the supplied credentials func VerifyLdapCreds(username string, password string) error { var err error username = CheckUsername(username) ldaps := ldapConnect() defer ldaps.Close() // try an authenticated bind to AD to verify credentials log.Printf("Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password)) err = ldaps.Bind(username, password) if err != nil { if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials { errString := "invalid user credentials" log.Print(errString) return errors.New(errString) } else { errString := fmt.Sprintf("VerifyLdapCreds error binding to LDAP with supplied credentials : '%s'\n", err) log.Print(errString) return errors.New(errString) } } else { log.Printf("VerifyLdapCreds successfully bound to LDAP\n") } return nil } // GetGroupsOfUser returns the group for a user. 
// The values of the user's memberOf attribute are returned as-is; no recursive
// group lookup is performed.
// Adapted from https://github.com/jtblin/go-ldap-client/issues/13#issuecomment-456090979
//
// username may be a UPN ("user@example.com"), a down-level logon name
// ("DOMAIN\user") or a bare sAMAccountName. It is reduced to the
// sAMAccountName, which is searched for under baseDN using conn. An error is
// returned if the search fails or does not match exactly one entry.
func GetGroupsOfUser(username string, baseDN string, conn *ldap.Conn) ([]string, error) {
	sAMAccountName := sAMAccountNameFromUsername(username)

	// Only memberOf is needed, so don't ask the server for every attribute.
	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		fmt.Sprintf("(sAMAccountName=%s)", ldap.EscapeFilter(sAMAccountName)),
		[]string{"memberOf"},
		nil,
	)
	log.Printf("searchRequest: %v\n", searchRequest)

	sr, err := conn.Search(searchRequest)
	if err != nil {
		return nil, err
	}

	switch len(sr.Entries) {
	case 0:
		return nil, fmt.Errorf("user '%s' does not exist", sAMAccountName)
	case 1:
		return sr.Entries[0].GetAttributeValues("memberOf"), nil
	default:
		return nil, fmt.Errorf("user '%s' is ambiguous: %d entries matched", sAMAccountName, len(sr.Entries))
	}
}

// sAMAccountNameFromUsername reduces username to a bare account name.
// For a name containing "@" it returns the text before the first "@". For a
// name containing "\" it returns the text after the last "\". Any other name
// is returned unchanged.
func sAMAccountNameFromUsername(username string) string {
	if i := strings.Index(username, "@"); i >= 0 {
		return username[:i]
	}
	if i := strings.LastIndex(username, "\\"); i >= 0 {
		return username[i+1:]
	}
	return username
}

// GetLdapUserDn returns the distinguishedName attribute of the user.
//
// username is accepted in the same forms as GetGroupsOfUser and is reduced to
// the sAMAccountName, which is searched for under baseDN using conn. If several
// entries match, the first is used. An error is returned if the search fails,
// no entry matches, or the entry has no distinguishedName value.
func GetLdapUserDn(username string, baseDN string, conn *ldap.Conn) (string, error) {
	sAMAccountName := sAMAccountNameFromUsername(username)

	// The account name comes from user input; escape it so it cannot alter
	// the search filter (LDAP filter injection).
	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		fmt.Sprintf("(sAMAccountName=%s)", ldap.EscapeFilter(sAMAccountName)),
		[]string{"distinguishedName"},
		nil,
	)

	searchResult, err := conn.Search(searchRequest)
	if err != nil {
		// Report the failure to the caller instead of terminating the process.
		return "", fmt.Errorf("GetLdapUserDn search for user '%s' failed: %w", sAMAccountName, err)
	}
	if len(searchResult.Entries) == 0 {
		return "", fmt.Errorf("user '%s' does not exist", sAMAccountName)
	}

	distinguishedName := searchResult.Entries[0].GetAttributeValue("distinguishedName")
	if distinguishedName == "" {
		return "", fmt.Errorf("could not find distinguishedName for user '%s'", sAMAccountName)
	}
	log.Printf("GetLdapUserDn located user's distinguishedName : '%s'\n", distinguishedName)
	return distinguishedName, nil
}

// GetUserFromUPN returns the user portion of a UPN formatted username. That
// is the text before the first "@", or the whole string if it contains no "@".
func GetUserFromUPN(email string) string {
	if i := strings.Index(email, "@"); i >= 0 {
		return email[:i]
	}
	return email
}