package models

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"log"
	"net"
	"os"
	"path/filepath"
	"strings"

	"github.com/go-ldap/ldap"
)

// Code relating to AD integration

// LdapConfig groups LDAP connection settings. It is not referenced by the code
// in this file; configuration is currently read from environment variables.
type LdapConfig struct {
	LdapBindAddress string
	LdapBaseDn      string
	LdapCertFile    string
}

// defaultLdapsPort is appended to LDAP_BIND_ADDRESS when it carries no port.
const defaultLdapsPort = "636"

var (
	// systemCA holds the system root CAs plus the optional custom certificate
	// loaded from LDAP_TRUST_CERT_FILE. It stays nil if no pool could be built,
	// in which case crypto/tls uses the platform default roots.
	systemCA *x509.CertPool
	// ldaps is the shared LDAPS connection opened by LdapSetup.
	ldaps *ldap.Conn
	// CertLoaded reports whether a custom trust certificate was added to systemCA.
	CertLoaded bool
	// LdapEnabled reports that LDAP is configured (LDAP_BIND_ADDRESS is set).
	// It does NOT guarantee a live connection; VerifyLdapCreds checks that.
	LdapEnabled bool
	// LdapBaseDn is the search base read from LDAP_BASE_DN.
	LdapBaseDn string
	// DefaultDomainSuffix is derived from the server's defaultNamingContext and
	// appended to bare usernames by CheckUsername.
	DefaultDomainSuffix string
)

// GetFilePath resolves path for reading. If path does not exist as given, the
// same name is looked up relative to the directory of the running executable
// (the resulting path is not re-checked for existence). An empty path, or a
// failure to locate the executable, yields "".
func GetFilePath(path string) string {
	if path == "" {
		return ""
	}
	if _, err := os.Stat(path); os.IsNotExist(err) {
		log.Printf("File '%s' not found, searching in same directory as binary\n", path)
		ex, err := os.Executable()
		if err != nil {
			log.Printf("GetFilePath error determining binary path : '%s'\n", err)
			return ""
		}
		path = filepath.Join(filepath.Dir(ex), path)
	}
	return path
}

// DomainSuffixFromNamingContext converts a naming context such as
// "DC=example,DC=com" into a DNS-style suffix such as "example.com".
// Non-DC components are ignored; attribute names are matched
// case-insensitively and surrounding whitespace is tolerated.
func DomainSuffixFromNamingContext(input string) string {
	var labels []string
	for _, rdn := range strings.Split(input, ",") {
		parts := strings.SplitN(strings.TrimSpace(rdn), "=", 2)
		if len(parts) != 2 || !strings.EqualFold(strings.TrimSpace(parts[0]), "DC") {
			continue
		}
		if label := strings.TrimSpace(parts[1]); label != "" {
			labels = append(labels, label)
		}
	}
	return strings.Join(labels, ".")
}

// CheckUsername returns username qualified for an LDAP bind. Names that are
// already qualified (user@domain, DOMAIN\user or DOMAIN/user) are returned
// unchanged; a bare name gets "@"+DefaultDomainSuffix appended when a suffix
// is known.
func CheckUsername(username string) string {
	if strings.ContainsAny(username, `/\@`) {
		// Username already carries a domain (slash, backslash or at symbol)
		return username
	}
	if DefaultDomainSuffix == "" {
		// Appending "@" with no suffix would only produce an invalid UPN.
		return username
	}
	log.Printf("CheckUsername appending default domain suffix '%s'\n", DefaultDomainSuffix)
	return username + "@" + DefaultDomainSuffix
}

// ensurePort returns addr unchanged when it already specifies a port and
// otherwise joins it with port. Bare or bracketed IPv6 literals are handled.
func ensurePort(addr, port string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	host := strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
	return net.JoinHostPort(host, port)
}

// loadLdapCert builds systemCA from the system root pool and, when the
// LDAP_TRUST_CERT_FILE environment variable names a PEM file, adds the first
// certificate from that file and sets CertLoaded.
func loadLdapCert() {
	// Only try to load a certificate from file if the env variable is set.
	ldapCertFile := os.Getenv("LDAP_TRUST_CERT_FILE")

	pool, err := x509.SystemCertPool()
	if err != nil {
		log.Printf("LoadLdapCert error getting system certificate pool : '%s'\n", err)
		if ldapCertFile == "" {
			// Leave systemCA nil so crypto/tls falls back to platform roots.
			return
		}
		// Still honour the explicitly configured certificate.
		pool = x509.NewCertPool()
	}
	systemCA = pool

	if ldapCertFile == "" {
		log.Printf("LoadLdapCert no certificate specified\n")
		return
	}

	cf, err := os.ReadFile(GetFilePath(ldapCertFile))
	if err != nil {
		log.Printf("LoadLdapCert error opening LDAP certificate file '%s' : '%s'\n", ldapCertFile, err)
		return
	}

	// pem.Decode returns a nil block when the file holds no PEM data.
	block, _ := pem.Decode(cf)
	if block == nil {
		log.Printf("LoadLdapCert error processing LDAP certificate file '%s' : '%s'\n", ldapCertFile, "no PEM data found")
		return
	}
	crt, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		log.Printf("LoadLdapCert error processing LDAP certificate file '%s' : '%s'\n", ldapCertFile, err)
		return
	}

	systemCA.AddCert(crt)
	CertLoaded = true
}

// LdapSetup reads LDAP_BIND_ADDRESS and LDAP_BASE_DN, opens an LDAPS
// connection (default port 636) and derives DefaultDomainSuffix from the
// server's defaultNamingContext. It returns true when the connection is up.
//
// The server certificate is verified against systemCA unless
// LDAP_INSECURE_SKIP_VERIFY is "true" (case-insensitive), which should only be
// used for testing because it exposes user passwords to interception.
func LdapSetup() bool {
	var err error

	// Load LDAP certificate if necessary
	loadLdapCert()

	ldapServer := os.Getenv("LDAP_BIND_ADDRESS")
	if ldapServer == "" {
		log.Printf("LdapSetup no LDAP bind address supplied\n")
		return false
	}
	// LDAP is configured from here on, even if connecting fails below;
	// VerifyLdapCreds refuses to authenticate without a connection.
	LdapEnabled = true

	LdapBaseDn = os.Getenv("LDAP_BASE_DN")
	if LdapBaseDn == "" {
		log.Printf("LdapSetup no LDAP base DN supplied\n")
		return false
	}

	insecure := strings.EqualFold(os.Getenv("LDAP_INSECURE_SKIP_VERIFY"), "true")
	if insecure {
		log.Printf("LdapSetup WARNING: TLS certificate verification disabled via LDAP_INSECURE_SKIP_VERIFY\n")
	}
	tlsConfig := &tls.Config{
		RootCAs:            systemCA,
		InsecureSkipVerify: insecure,
	}

	// Add the default LDAPS port only when no port was specified.
	if withPort := ensurePort(ldapServer, defaultLdapsPort); withPort != ldapServer {
		ldapServer = withPort
		log.Printf("LdapSetup updated ldapServer string '%s'\n", ldapServer)
	}

	// Release any connection from an earlier setup before replacing it.
	if ldaps != nil {
		ldaps.Close()
		ldaps = nil
	}

	ldaps, err = ldap.DialTLS("tcp", ldapServer, tlsConfig)
	if err != nil {
		log.Printf("LdapSetup error connecting to LDAP bind address '%s' : '%s'\n", ldapServer, err)
		return false
	}

	if namingContext := LookupNamingContext(); namingContext != "" {
		DefaultDomainSuffix = DomainSuffixFromNamingContext(namingContext)
	}
	return true
}

// LookupNamingContext reads the defaultNamingContext attribute from the root
// DSE using the current connection, returning "" if it cannot be retrieved.
func LookupNamingContext() string {
	if ldaps == nil {
		log.Printf("LookupNamingContext no LDAP connection available\n")
		return ""
	}

	// Retrieve the defaultNamingContext with a base-object search on the empty DN.
	searchRequest := ldap.NewSearchRequest(
		"",
		ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false,
		"(objectClass=*)",
		[]string{"defaultNamingContext"},
		nil,
	)

	searchResult, err := ldaps.Search(searchRequest)
	if err != nil {
		log.Printf("LookupNamingContext unable to perform unauthenticated search : '%s'\n", err)
		return ""
	}
	if len(searchResult.Entries) != 1 {
		log.Printf("LookupNamingContext unable to retrieve defaultNamingContext\n")
		return ""
	}

	defaultNamingContext := searchResult.Entries[0].GetAttributeValue("defaultNamingContext")
	if defaultNamingContext == "" {
		log.Printf("LookupNamingContext defaultNamingContext attribute not found\n")
		return ""
	}
	log.Printf("Default Naming Context: '%s'\n", defaultNamingContext)
	return defaultNamingContext
}

// VerifyLdapCreds checks username/password by performing a simple bind on the
// shared connection, then confirms the user can be found under LdapBaseDn.
// It returns false on any failure, including a missing connection.
//
// NOTE(review): the bind changes the identity of the shared connection, so
// concurrent callers can interleave bind and search; confirm callers serialise.
func VerifyLdapCreds(username string, password string) bool {
	if ldaps == nil {
		log.Printf("VerifyLdapCreds no LDAP connection available\n")
		return false
	}
	// A simple bind with an empty password is an "unauthenticated bind"
	// (RFC 4513 section 5.1.2), which servers such as AD report as success.
	if password == "" {
		log.Printf("VerifyLdapCreds empty password rejected\n")
		return false
	}

	username = CheckUsername(username)

	// Try an authenticated bind to AD to verify credentials.
	log.Printf("Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password))
	if err := ldaps.Bind(username, password); err != nil {
		if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials {
			log.Printf("VerifyLdapCreds user credentials are incorrect : '%s'\n", err)
		} else {
			log.Printf("VerifyLdapCreds error binding to LDAP with supplied credentials : '%s'\n", err)
		}
		return false
	}
	log.Printf("VerifyLdapCreds successfully bound to LDAP\n")

	groups, err := GetGroupsOfUser(username, LdapBaseDn, ldaps)
	if err != nil {
		log.Printf("VerifyLdapCreds group search error : '%s'\n", err)
		return false
	}
	log.Printf("VerifyLdapCreds groups: %v\n", groups)
	return true
}

// GetGroupsOfUser returns the memberOf values of the user whose sAMAccountName
// matches username. The domain part of "user@domain" or "DOMAIN\user" is
// stripped first. An error is returned when the search fails or the name does
// not match exactly one entry.
// Adapted from https://github.com/jtblin/go-ldap-client/issues/13#issuecomment-456090979
func GetGroupsOfUser(username string, baseDN string, conn *ldap.Conn) ([]string, error) {
	samAccountName := username
	switch {
	case strings.Contains(username, "@"):
		samAccountName = username[:strings.Index(username, "@")]
	case strings.Contains(username, `\`):
		samAccountName = username[strings.LastIndex(username, `\`)+1:]
	}

	// Search for the given account name; only memberOf is needed.
	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		fmt.Sprintf("(sAMAccountName=%s)", ldap.EscapeFilter(samAccountName)),
		[]string{"memberOf"},
		nil,
	)

	sr, err := conn.Search(searchRequest)
	if err != nil {
		return nil, err
	}

	switch len(sr.Entries) {
	case 0:
		return nil, fmt.Errorf("user '%s' does not exist", samAccountName)
	case 1:
		return sr.Entries[0].GetAttributeValues("memberOf"), nil
	default:
		return nil, fmt.Errorf("user '%s' matched %d entries", samAccountName, len(sr.Entries))
	}
}