diff --git a/main.go b/main.go index 75474ba..53f9b88 100644 --- a/main.go +++ b/main.go @@ -5,10 +5,12 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "flag" "fmt" "os" "path/filepath" + "strings" "github.com/go-ldap/ldap/v3" ) @@ -19,6 +21,7 @@ type Output struct { Error string CertLoaded bool Results string + Groups string } func GetFilePath(path string) string { @@ -52,6 +55,51 @@ func isFlagPassed(name string) bool { return found } +// GetGroupsOfUser returns the group for a user. +// Taken from https://github.com/jtblin/go-ldap-client/issues/13#issuecomment-456090979 +func GetGroupsOfUser(username string, baseDN string, conn *ldap.Conn) ([]string, error) { + + // Get the users DN + // Search for the given username + searchRequest := ldap.NewSearchRequest( + baseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + fmt.Sprintf("(uid=%s)", username), + []string{"dn"}, + nil, + ) + + sr, err := conn.Search(searchRequest) + if err != nil { + return nil, err + } + + if len(sr.Entries) != 1 { + return nil, errors.New("user does not exist") + } + + userdn := sr.Entries[0].DN + + searchRequest = ldap.NewSearchRequest( + baseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + fmt.Sprintf("(memberUid=%s)", userdn), + []string{"cn"}, // can it be something else than "cn"? + nil, + ) + sr, err = conn.Search(searchRequest) + if err != nil { + return nil, err + } + + groups := []string{} + for _, entry := range sr.Entries { + groups = append(groups, entry.GetAttributeValue("cn")) + } + + return groups, nil +} + func main() { var output Output @@ -173,6 +221,15 @@ func main() { } else { output.AuthSuccess = true output.Results = fmt.Sprintf("Search result count: %d; %s", len(result.Entries), result.Entries[0].DN) + + // Since we have a successful connection, try getting group membership + groups, err := GetGroupsOfUser(*username, *baseDN, ldaps) + if err != nil { + output.Results = err.Error() + } else { + output.Groups = strings.Join(groups[:], ",") + } + b, _ := json.Marshal(output) fmt.Println(string(b)) return