diff --git a/models/ldap.go b/models/ldap.go index 22c1e30..4da21bf 100644 --- a/models/ldap.go +++ b/models/ldap.go @@ -236,7 +236,7 @@ func LdapGetGroupMembership(username string, password string) ([]string, error) defer ldaps.Close() // try an authenticated bind to AD to verify credentials - log.Printf("GetLdapGroupMembership Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password)) + log.Printf("LdapGetGroupMembership Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password)) err = ldaps.Bind(username, password) if err != nil { if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials { @@ -244,17 +244,17 @@ func LdapGetGroupMembership(username string, password string) ([]string, error) log.Print(errString) return nil, errors.New(errString) } else { - errString := fmt.Sprintf("GetLdapGroupMembership error binding to LDAP with supplied credentials : '%s'\n", err) + errString := fmt.Sprintf("LdapGetGroupMembership error binding to LDAP with supplied credentials : '%s'\n", err) log.Print(errString) return nil, errors.New(errString) } } else { - log.Printf("GetLdapGroupMembership successfully bound to LDAP\n") + log.Printf("LdapGetGroupMembership successfully bound to LDAP\n") } groups, err := GetGroupsOfUser(username, LdapBaseDn, ldaps) if err != nil { - errString := fmt.Sprintf("GetLdapGroupMembership group search error : '%s'\n", err) + errString := fmt.Sprintf("LdapGetGroupMembership group search error : '%s'\n", err) log.Print(errString) return nil, errors.New(errString) } @@ -373,3 +373,12 @@ func GetLdapUserDn(username string, baseDN string, conn *ldap.Conn) (string, err } } } + +// Returns the user portion of a UPN formatted username +func GetUserFromUPN(email string) string { + parts := strings.Split(email, "@") + if len(parts) > 0 { + return parts[0] + } + return "" +} diff --git a/models/user.go b/models/user.go index 769e814..e5cbad4 100644 --- a/models/user.go +++ b/models/user.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "smt/utils/token" + "strings" "time" "golang.org/x/crypto/bcrypt" @@ -123,18 +124,32 @@ func LoginCheck(username string, password string) (string, error) { // Query database for matching user object // Use IFNULL to handle situation where a user might not be a member of a group + // Join on groups table so we can get the value in LdapGroup column - // TODO join on groups table so we can get the value in LdapGroup column - - err = db.QueryRowx(` + // if username is UPN format then get just the user portion + if strings.Contains(username, "@") { + plainUser := GetUserFromUPN(username) + // check for original username or plainUser + err = db.QueryRowx(` + SELECT users.UserId, IFNULL(users.GroupId, 0) GroupId, UserName, Password, LdapUser, users.Admin, groups.LdapGroup FROM Users + INNER JOIN groups ON users.GroupId = groups.GroupId + WHERE Username=? OR Username=?`, username, plainUser).StructScan(&u) + } else { + err = db.QueryRowx(` SELECT users.UserId, IFNULL(users.GroupId, 0) GroupId, UserName, Password, LdapUser, users.Admin, groups.LdapGroup FROM Users INNER JOIN groups ON users.GroupId = groups.GroupId WHERE Username=?`, username).StructScan(&u) + } if err != nil { if err == sql.ErrNoRows { + log.Printf("LoginCheck found no users matching username '%s'\n", username) + + // TODO - if username contains UPN style login then try extracting just the username and doing a query on that + // check LDAP if enabled if LdapEnabled { + log.Printf("LoginCheck initiating ldap lookup for username '%s'\n", username) ldapUser, err := UserLdapNewLoginCheck(username, password) if err != nil { errString := fmt.Sprintf("LoginCheck error checking LDAP for user : '%s'\n", err)