+82
-36
@@ -32,9 +32,12 @@ type LDAPConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type LDAPIdentity struct {
|
type LDAPIdentity struct {
|
||||||
Username string
|
Username string
|
||||||
UserDN string
|
UserDN string
|
||||||
Groups []string
|
Groups []string
|
||||||
|
BindDuration time.Duration
|
||||||
|
UserLookupDuration time.Duration
|
||||||
|
GroupMembershipLookupDuration time.Duration
|
||||||
// Diagnostics contains non-sensitive LDAP processing notes useful for debugging auth decisions.
|
// Diagnostics contains non-sensitive LDAP processing notes useful for debugging auth decisions.
|
||||||
Diagnostics []string
|
Diagnostics []string
|
||||||
}
|
}
|
||||||
@@ -79,13 +82,14 @@ func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
|
func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) {
|
||||||
username = strings.TrimSpace(username)
|
inputUsername := strings.TrimSpace(username)
|
||||||
if username == "" || password == "" {
|
if inputUsername == "" || password == "" {
|
||||||
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
return LDAPIdentity{}, ErrLDAPInvalidCredentials
|
||||||
}
|
}
|
||||||
if err := ctxErr(ctx); err != nil {
|
if err := ctxErr(ctx); err != nil {
|
||||||
return LDAPIdentity{}, err
|
return LDAPIdentity{}, err
|
||||||
}
|
}
|
||||||
|
bindUsername, rewrittenToUPN := normalizeBindUsername(inputUsername, a.baseDN)
|
||||||
|
|
||||||
conn, err := a.connect()
|
conn, err := a.connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -93,19 +97,27 @@ func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, user
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
if err := conn.Bind(username, password); err != nil {
|
bindStartedAt := time.Now()
|
||||||
|
err = conn.Bind(bindUsername, password)
|
||||||
|
bindDuration := time.Since(bindStartedAt)
|
||||||
|
if err != nil {
|
||||||
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
|
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
|
||||||
return LDAPIdentity{}, fmt.Errorf("%w: ldap bind rejected credentials", ErrLDAPInvalidCredentials)
|
return LDAPIdentity{}, fmt.Errorf("%w: ldap bind rejected credentials (bind_duration=%s)", ErrLDAPInvalidCredentials, bindDuration)
|
||||||
}
|
}
|
||||||
return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err)
|
return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v (bind_duration=%s)", ErrLDAPOperationFailed, err, bindDuration)
|
||||||
}
|
}
|
||||||
if err := ctxErr(ctx); err != nil {
|
if err := ctxErr(ctx); err != nil {
|
||||||
return LDAPIdentity{}, err
|
return LDAPIdentity{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
identity := LDAPIdentity{
|
identity := LDAPIdentity{
|
||||||
Username: username,
|
Username: inputUsername,
|
||||||
UserDN: username,
|
UserDN: bindUsername,
|
||||||
|
BindDuration: bindDuration,
|
||||||
|
}
|
||||||
|
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("bind_duration_ms=%d", bindDuration.Milliseconds()))
|
||||||
|
if rewrittenToUPN {
|
||||||
|
identity.Diagnostics = append(identity.Diagnostics, "bind_username_rewritten_to_upn")
|
||||||
}
|
}
|
||||||
if whoami, err := conn.WhoAmI(nil); err != nil {
|
if whoami, err := conn.WhoAmI(nil); err != nil {
|
||||||
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("whoami_failed:%v", err))
|
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("whoami_failed:%v", err))
|
||||||
@@ -118,9 +130,12 @@ func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, user
|
|||||||
identity.Diagnostics = append(identity.Diagnostics, "whoami_non_dn_authzid")
|
identity.Diagnostics = append(identity.Diagnostics, "whoami_non_dn_authzid")
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, lookupStrategy, err := a.lookupUserEntry(conn, username, identity.UserDN)
|
userLookupStartedAt := time.Now()
|
||||||
|
entry, lookupStrategy, err := a.lookupUserEntry(conn, bindUsername, identity.UserDN)
|
||||||
|
identity.UserLookupDuration = time.Since(userLookupStartedAt)
|
||||||
|
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("user_lookup_duration_ms=%d", identity.UserLookupDuration.Milliseconds()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return LDAPIdentity{}, err
|
return LDAPIdentity{}, fmt.Errorf("%w: %v (bind_duration=%s user_lookup_duration=%s)", ErrLDAPOperationFailed, err, identity.BindDuration, identity.UserLookupDuration)
|
||||||
}
|
}
|
||||||
if entry != nil {
|
if entry != nil {
|
||||||
if lookupStrategy == "" {
|
if lookupStrategy == "" {
|
||||||
@@ -143,6 +158,7 @@ func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
groupSet := make(map[string]struct{})
|
groupSet := make(map[string]struct{})
|
||||||
|
groupLookupStartedAt := time.Now()
|
||||||
if entry != nil {
|
if entry != nil {
|
||||||
for _, groupDN := range entry.GetAttributeValues("memberOf") {
|
for _, groupDN := range entry.GetAttributeValues("memberOf") {
|
||||||
groupDN = strings.TrimSpace(groupDN)
|
groupDN = strings.TrimSpace(groupDN)
|
||||||
@@ -153,30 +169,11 @@ func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, user
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
groupFilter := buildGroupMembershipFilter(identity.UserDN, principalCandidates(username))
|
// Intentionally skip subtree group membership search for now.
|
||||||
groupEntries, err := conn.Search(ldap.NewSearchRequest(
|
// Authorization is based only on direct group membership values present in the user entry (memberOf).
|
||||||
a.baseDN,
|
identity.GroupMembershipLookupDuration = time.Since(groupLookupStartedAt)
|
||||||
ldap.ScopeWholeSubtree,
|
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("group_lookup_duration_ms=%d", identity.GroupMembershipLookupDuration.Milliseconds()))
|
||||||
ldap.NeverDerefAliases,
|
identity.Diagnostics = append(identity.Diagnostics, "group_search_skipped_direct_memberof_only")
|
||||||
0,
|
|
||||||
0,
|
|
||||||
false,
|
|
||||||
groupFilter,
|
|
||||||
[]string{"dn"},
|
|
||||||
nil,
|
|
||||||
))
|
|
||||||
if err == nil {
|
|
||||||
for _, e := range groupEntries.Entries {
|
|
||||||
if dn := strings.TrimSpace(e.DN); dn != "" {
|
|
||||||
groupSet[dn] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(groupEntries.Entries) == 0 {
|
|
||||||
identity.Diagnostics = append(identity.Diagnostics, "group_search_returned_no_entries")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
identity.Diagnostics = append(identity.Diagnostics, fmt.Sprintf("group_search_failed:%v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
identity.Groups = mapKeysSorted(groupSet)
|
identity.Groups = mapKeysSorted(groupSet)
|
||||||
identity.Diagnostics = compactTrimmedStrings(identity.Diagnostics)
|
identity.Diagnostics = compactTrimmedStrings(identity.Diagnostics)
|
||||||
@@ -399,6 +396,55 @@ func parseWhoAmIDN(authzID string) string {
|
|||||||
return authzID
|
return authzID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeBindUsername(username string, baseDN string) (string, bool) {
|
||||||
|
username = strings.TrimSpace(username)
|
||||||
|
if username == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if looksLikeDN(username) || strings.Contains(username, "@") {
|
||||||
|
return username, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert DOMAIN\user to user before UPN rewrite.
|
||||||
|
if idx := strings.LastIndex(username, `\`); idx >= 0 && idx < len(username)-1 {
|
||||||
|
username = strings.TrimSpace(username[idx+1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
domain := upnDomainFromBaseDN(baseDN)
|
||||||
|
if domain == "" {
|
||||||
|
return username, false
|
||||||
|
}
|
||||||
|
if strings.Contains(username, "@") {
|
||||||
|
return username, false
|
||||||
|
}
|
||||||
|
return username + "@" + domain, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func upnDomainFromBaseDN(baseDN string) string {
|
||||||
|
baseDN = strings.TrimSpace(baseDN)
|
||||||
|
if baseDN == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(baseDN, ",")
|
||||||
|
labels := make([]string, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if len(part) < 3 || !strings.EqualFold(part[:3], "dc=") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
label := strings.TrimSpace(part[3:])
|
||||||
|
if label == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
labels = append(labels, label)
|
||||||
|
}
|
||||||
|
if len(labels) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.Join(labels, ".")
|
||||||
|
}
|
||||||
|
|
||||||
func principalCandidates(username string) []string {
|
func principalCandidates(username string) []string {
|
||||||
username = strings.TrimSpace(username)
|
username = strings.TrimSpace(username)
|
||||||
if username == "" {
|
if username == "" {
|
||||||
|
|||||||
@@ -124,3 +124,87 @@ func TestParseWhoAmIDN(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUPNDomainFromBaseDN(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
baseDN string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "standard dc chain",
|
||||||
|
baseDN: "dc=corpau,dc=wbcau,dc=westpac,dc=com,dc=au",
|
||||||
|
want: "corpau.wbcau.westpac.com.au",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed dn parts",
|
||||||
|
baseDN: "ou=Users,dc=example,dc=com",
|
||||||
|
want: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no dc parts",
|
||||||
|
baseDN: "ou=Users,ou=Org",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := upnDomainFromBaseDN(tc.baseDN)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Fatalf("unexpected upn domain from base dn: got=%q want=%q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeBindUsername(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
username string
|
||||||
|
baseDN string
|
||||||
|
wantUser string
|
||||||
|
wantRewrite bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "plain sam rewritten",
|
||||||
|
username: "L075239",
|
||||||
|
baseDN: "dc=corpau,dc=wbcau,dc=westpac,dc=com,dc=au",
|
||||||
|
wantUser: "L075239@corpau.wbcau.westpac.com.au",
|
||||||
|
wantRewrite: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "domain user rewritten",
|
||||||
|
username: `CORPAU\L075239`,
|
||||||
|
baseDN: "dc=corpau,dc=wbcau,dc=westpac,dc=com,dc=au",
|
||||||
|
wantUser: "L075239@corpau.wbcau.westpac.com.au",
|
||||||
|
wantRewrite: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "upn unchanged",
|
||||||
|
username: "L075239@corpau.wbcau.westpac.com.au",
|
||||||
|
baseDN: "dc=corpau,dc=wbcau,dc=westpac,dc=com,dc=au",
|
||||||
|
wantUser: "L075239@corpau.wbcau.westpac.com.au",
|
||||||
|
wantRewrite: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dn unchanged",
|
||||||
|
username: "CN=User,OU=Users,DC=corpau,DC=wbcau,DC=westpac,DC=com,DC=au",
|
||||||
|
baseDN: "dc=corpau,dc=wbcau,dc=westpac,dc=com,dc=au",
|
||||||
|
wantUser: "CN=User,OU=Users,DC=corpau,DC=wbcau,DC=westpac,DC=com,DC=au",
|
||||||
|
wantRewrite: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
gotUser, gotRewrite := normalizeBindUsername(tc.username, tc.baseDN)
|
||||||
|
if gotUser != tc.wantUser {
|
||||||
|
t.Fatalf("unexpected normalized bind username: got=%q want=%q", gotUser, tc.wantUser)
|
||||||
|
}
|
||||||
|
if gotRewrite != tc.wantRewrite {
|
||||||
|
t.Fatalf("unexpected rewrite flag: got=%v want=%v", gotRewrite, tc.wantRewrite)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -109,7 +109,9 @@ func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
ctx, cancel := withRequestTimeout(r, authLoginRequestTimeout)
|
ctx, cancel := withRequestTimeout(r, authLoginRequestTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
ldapAuthStartedAt := time.Now()
|
||||||
identity, err := ldapAuth.AuthenticateAndFetchGroups(ctx, username, password)
|
identity, err := ldapAuth.AuthenticateAndFetchGroups(ctx, username, password)
|
||||||
|
ldapAuthDuration := time.Since(ldapAuthStartedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, auth.ErrLDAPInvalidCredentials) {
|
if errors.Is(err, auth.ErrLDAPInvalidCredentials) {
|
||||||
audit.LogAuthEvent(h.Logger, r, "login", "deny",
|
audit.LogAuthEvent(h.Logger, r, "login", "deny",
|
||||||
@@ -117,6 +119,7 @@ func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
"username", username,
|
"username", username,
|
||||||
"ldap_bind_address", cfg.LDAPBindAddress,
|
"ldap_bind_address", cfg.LDAPBindAddress,
|
||||||
"ldap_base_dn", cfg.LDAPBaseDN,
|
"ldap_base_dn", cfg.LDAPBaseDN,
|
||||||
|
"ldap_auth_total_duration_ms", ldapAuthDuration.Milliseconds(),
|
||||||
"error", err,
|
"error", err,
|
||||||
)
|
)
|
||||||
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
||||||
@@ -129,6 +132,7 @@ func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
"ldap_bind_address", cfg.LDAPBindAddress,
|
"ldap_bind_address", cfg.LDAPBindAddress,
|
||||||
"ldap_base_dn", cfg.LDAPBaseDN,
|
"ldap_base_dn", cfg.LDAPBaseDN,
|
||||||
"timeout_seconds", authLoginRequestTimeout.Seconds(),
|
"timeout_seconds", authLoginRequestTimeout.Seconds(),
|
||||||
|
"ldap_auth_total_duration_ms", ldapAuthDuration.Milliseconds(),
|
||||||
"error", err,
|
"error", err,
|
||||||
)
|
)
|
||||||
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
||||||
@@ -139,6 +143,7 @@ func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
"username", username,
|
"username", username,
|
||||||
"ldap_bind_address", cfg.LDAPBindAddress,
|
"ldap_bind_address", cfg.LDAPBindAddress,
|
||||||
"ldap_base_dn", cfg.LDAPBaseDN,
|
"ldap_base_dn", cfg.LDAPBaseDN,
|
||||||
|
"ldap_auth_total_duration_ms", ldapAuthDuration.Milliseconds(),
|
||||||
"error", err,
|
"error", err,
|
||||||
)
|
)
|
||||||
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
||||||
@@ -151,6 +156,10 @@ func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
"ldap_user_dn", identity.UserDN,
|
"ldap_user_dn", identity.UserDN,
|
||||||
"ldap_group_count", len(identity.Groups),
|
"ldap_group_count", len(identity.Groups),
|
||||||
"ldap_groups", limitStrings(identity.Groups, maxDebugLogListItems),
|
"ldap_groups", limitStrings(identity.Groups, maxDebugLogListItems),
|
||||||
|
"ldap_auth_total_duration_ms", ldapAuthDuration.Milliseconds(),
|
||||||
|
"ldap_bind_duration_ms", identity.BindDuration.Milliseconds(),
|
||||||
|
"ldap_user_lookup_duration_ms", identity.UserLookupDuration.Milliseconds(),
|
||||||
|
"ldap_group_lookup_duration_ms", identity.GroupMembershipLookupDuration.Milliseconds(),
|
||||||
"ldap_diagnostics", limitStrings(identity.Diagnostics, maxDebugLogListItems),
|
"ldap_diagnostics", limitStrings(identity.Diagnostics, maxDebugLogListItems),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -163,6 +172,10 @@ func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
"required_groups", limitStrings(cfg.LDAPGroups, maxDebugLogListItems),
|
"required_groups", limitStrings(cfg.LDAPGroups, maxDebugLogListItems),
|
||||||
"user_groups", limitStrings(identity.Groups, maxDebugLogListItems),
|
"user_groups", limitStrings(identity.Groups, maxDebugLogListItems),
|
||||||
"resolved_roles", roles,
|
"resolved_roles", roles,
|
||||||
|
"ldap_auth_total_duration_ms", ldapAuthDuration.Milliseconds(),
|
||||||
|
"ldap_bind_duration_ms", identity.BindDuration.Milliseconds(),
|
||||||
|
"ldap_user_lookup_duration_ms", identity.UserLookupDuration.Milliseconds(),
|
||||||
|
"ldap_group_lookup_duration_ms", identity.GroupMembershipLookupDuration.Milliseconds(),
|
||||||
"auth_group_role_mapping_keys", limitStrings(sortedStringMapKeys(cfg.AuthGroupRoleMappings), maxDebugLogListItems),
|
"auth_group_role_mapping_keys", limitStrings(sortedStringMapKeys(cfg.AuthGroupRoleMappings), maxDebugLogListItems),
|
||||||
)
|
)
|
||||||
if !hasRequiredGroup || len(roles) == 0 {
|
if !hasRequiredGroup || len(roles) == 0 {
|
||||||
@@ -174,6 +187,10 @@ func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
"required_groups", limitStrings(cfg.LDAPGroups, maxDebugLogListItems),
|
"required_groups", limitStrings(cfg.LDAPGroups, maxDebugLogListItems),
|
||||||
"user_groups", limitStrings(identity.Groups, maxDebugLogListItems),
|
"user_groups", limitStrings(identity.Groups, maxDebugLogListItems),
|
||||||
"resolved_roles", roles,
|
"resolved_roles", roles,
|
||||||
|
"ldap_auth_total_duration_ms", ldapAuthDuration.Milliseconds(),
|
||||||
|
"ldap_bind_duration_ms", identity.BindDuration.Milliseconds(),
|
||||||
|
"ldap_user_lookup_duration_ms", identity.UserLookupDuration.Milliseconds(),
|
||||||
|
"ldap_group_lookup_duration_ms", identity.GroupMembershipLookupDuration.Milliseconds(),
|
||||||
"ldap_diagnostics", limitStrings(identity.Diagnostics, maxDebugLogListItems),
|
"ldap_diagnostics", limitStrings(identity.Diagnostics, maxDebugLogListItems),
|
||||||
)
|
)
|
||||||
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage)
|
||||||
|
|||||||
Reference in New Issue
Block a user