Files
smt/models/ldap.go
Nathan Coad ee822b5c9d
All checks were successful
continuous-integration/drone/push Build is passing
if username in UPN format for login try searching both user and full UPN string
2024-04-02 16:55:11 +11:00

385 lines
10 KiB
Go

package models
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/go-ldap/ldap"
)
// Code relating to AD integration

// LdapConfig groups the settings needed to reach the directory server.
// NOTE(review): nothing in this file reads LdapConfig; the package-level
// variables below hold the active configuration instead — confirm whether
// other files use this type.
type LdapConfig struct {
	LdapBindAddress string // presumably mirrors LDAP_BIND_ADDRESS (host[:port])
	LdapBaseDn      string // presumably mirrors LDAP_BASE_DN
	LdapCertFile    string // presumably mirrors LDAP_TRUST_CERT_FILE
}

// Package-level LDAP state, populated by loadLdapCert and LdapSetup.
var (
	systemCA            *x509.CertPool // system roots plus any custom certificate added by loadLdapCert
	LdapServer          string         // LDAPS address dialled by LdapSetup and ldapConnect
	CertLoaded          bool           // true once a custom certificate has been added to systemCA
	LdapEnabled         bool           // set by LdapSetup once a bind address is configured
	LdapInsecure        bool           // skip TLS verification in ldapConnect (LDAP_INSECURE_VALIDATION)
	LdapBaseDn          string         // base DN for user searches (LDAP_BASE_DN)
	DefaultDomainSuffix string         // suffix CheckUsername appends to bare usernames
)
// GetFilePath resolves a possibly-relative file name.
//
// An empty path yields "". If os.Stat reports that path does not exist, the
// name is instead joined onto the directory containing the running binary
// (that location is not checked for existence). If the executable's path
// cannot be determined in that case, "" is returned. In every other case —
// the file exists, or Stat failed for a reason other than "not exist" —
// path is returned unchanged.
func GetFilePath(path string) string {
	if path == "" {
		return ""
	}

	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return path
	}

	log.Printf("File '%s' not found, searching in same directory as binary\n", path)
	exe, exeErr := os.Executable()
	if exeErr != nil {
		return ""
	}
	return filepath.Join(filepath.Dir(exe), path)
}
// DomainSuffixFromNamingContext converts a distinguished name such as
// "DC=example,DC=com" into its DNS domain suffix ("example.com").
//
// Only DC (domainComponent) RDNs contribute; other components such as OU or
// CN are skipped. Attribute types are matched case-insensitively and
// surrounding whitespace is ignored, because a DN may legitimately be written
// as "dc=example, dc=com" (RFC 4514). Empty DC values are skipped. An input
// containing no DC components yields "".
func DomainSuffixFromNamingContext(input string) string {
	var labels []string
	for _, rdn := range strings.Split(input, ",") {
		kv := strings.SplitN(rdn, "=", 2)
		if len(kv) != 2 || !strings.EqualFold(strings.TrimSpace(kv[0]), "DC") {
			continue
		}
		if label := strings.TrimSpace(kv[1]); label != "" {
			labels = append(labels, label)
		}
	}
	return strings.Join(labels, ".")
}
// CheckUsername normalises a login name into a form suitable for an LDAP
// simple bind.
//
// Names that are already qualified are returned unchanged: a UPN such as
// "user@example.com", a down-level logon name such as `DOMAIN\user`, or a
// name containing a forward slash (accepted for backward compatibility).
// A bare name gets "@" + DefaultDomainSuffix appended. If no default suffix
// is known, the name is returned unchanged rather than as a malformed "user@".
func CheckUsername(username string) string {
	if strings.ContainsAny(username, `/\@`) {
		// Already qualified (UPN or DOMAIN\user style).
		return username
	}
	if DefaultDomainSuffix == "" {
		log.Printf("CheckUsername no default domain suffix available, using username as supplied\n")
		return username
	}
	// Append suffix to the username
	log.Printf("CheckUsername appending default domain suffix '%s'\n", DefaultDomainSuffix)
	return username + "@" + DefaultDomainSuffix
}
// loadLdapCert initialises systemCA with a copy of the operating system's
// trusted certificate pool. If the LDAP_TRUST_CERT_FILE environment variable
// is set, it resolves the file with GetFilePath and adds every PEM
// "CERTIFICATE" block in it to the pool. This allows, for example, an
// internal CA to validate the LDAPS server.
//
// CertLoaded is set to true only when the certificate(s) from the file were
// added. Every failure is logged and the function returns without adding
// anything from the file. Malformed or non-PEM input is reported, not
// dereferenced.
func loadLdapCert() {
	var err error
	// Get a copy of the system defined CA's
	systemCA, err = x509.SystemCertPool()
	if err != nil {
		log.Printf("LoadLdapCert error getting system certificate pool : '%s'\n", err)
		return
	}

	// only try to load certificate from file if the environment variable was specified
	ldapCertFile := os.Getenv("LDAP_TRUST_CERT_FILE")
	if ldapCertFile == "" {
		log.Printf("LoadLdapCert no certificate specified\n")
		return
	}

	cf, err := os.ReadFile(GetFilePath(ldapCertFile))
	if err != nil {
		log.Printf("LoadLdapCert error opening LDAP certificate file '%s' : '%s'\n", ldapCertFile, err)
		return
	}

	certs, err := parseLdapCertificates(cf)
	if err != nil {
		log.Printf("LoadLdapCert error processing LDAP certificate file '%s' : '%s'\n", ldapCertFile, err)
		return
	}

	// Add custom certificate(s) to the system cert pool
	for _, crt := range certs {
		systemCA.AddCert(crt)
	}
	CertLoaded = true
}

// parseLdapCertificates decodes every PEM "CERTIFICATE" block in data,
// skipping other PEM block types. It returns an error if data holds no
// certificate block (pem.Decode yields a nil block for non-PEM input) or if
// any certificate block fails to parse.
func parseLdapCertificates(data []byte) ([]*x509.Certificate, error) {
	var certs []*x509.Certificate
	for {
		var block *pem.Block
		block, data = pem.Decode(data)
		if block == nil {
			break
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		crt, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, err
		}
		certs = append(certs, crt)
	}
	if len(certs) == 0 {
		return nil, errors.New("no PEM encoded certificate found")
	}
	return certs, nil
}
// LdapSetup loads the LDAP configuration from the environment and checks
// that the LDAPS server is reachable. It also derives DefaultDomainSuffix
// from the server's defaultNamingContext.
//
// Environment variables read here:
//   - LDAP_BIND_ADDRESS (required): server address. Port 636 is appended when
//     no port is given; an explicit port (e.g. 3269) is kept.
//   - LDAP_BASE_DN (required): stored in LdapBaseDn.
//   - LDAP_INSECURE_VALIDATION (optional): parsed with strconv.ParseBool into
//     LdapInsecure. An unparsable value is logged and treated as false.
//
// LdapEnabled is set as soon as a bind address is present, even if a later
// step fails. The TLS connection verifies the server certificate exactly as
// ldapConnect does, so verification is skipped only when LdapInsecure is
// true. It returns true only when the TLS connection succeeds.
func LdapSetup() bool {
	// Load LDAP certificate if necessary
	loadLdapCert()

	LdapServer = os.Getenv("LDAP_BIND_ADDRESS")
	if LdapServer == "" {
		log.Printf("LdapSetup no LDAP bind address supplied\n")
		return false
	}
	LdapEnabled = true

	LdapBaseDn = os.Getenv("LDAP_BASE_DN")
	if LdapBaseDn == "" {
		log.Printf("LdapSetup no LDAP base DN supplied\n")
		return false
	}

	if insecure := os.Getenv("LDAP_INSECURE_VALIDATION"); insecure != "" {
		parsed, err := strconv.ParseBool(insecure)
		if err != nil {
			log.Printf("LdapSetup could not convert environment variable LDAP_INSECURE_VALIDATION with value of '%s'\n", insecure)
		}
		LdapInsecure = parsed // false when parsing failed
	}

	// Add port if not specified in .env file
	if withPort := ldapsAddress(LdapServer); withPort != LdapServer {
		LdapServer = withPort
		log.Printf("LdapSetup updated ldapServer string '%s'\n", LdapServer)
	}

	// Use the same verification policy as ldapConnect: the naming context
	// read below determines the UPN suffix used for logins, so it must not
	// come from an unverified server unless insecure mode was requested.
	tlsConfig := &tls.Config{
		RootCAs:            systemCA,
		InsecureSkipVerify: LdapInsecure,
	}
	ldaps, err := ldap.DialTLS("tcp", LdapServer, tlsConfig)
	if err != nil {
		log.Printf("LdapSetup error connecting to LDAP bind address '%s' : '%s'\n", LdapServer, err)
		return false
	}
	defer ldaps.Close()

	if namingContext := LookupNamingContext(ldaps); namingContext != "" {
		DefaultDomainSuffix = DomainSuffixFromNamingContext(namingContext)
	}
	return true
}

// ldapsAddress returns addr with the default LDAPS port (636) appended when
// addr has no port. Bracketed ("[::1]") and bare ("::1") IPv6 literals are
// handled; an address that already includes a port is returned unchanged.
func ldapsAddress(addr string) string {
	switch {
	case strings.HasPrefix(addr, "["):
		// Bracketed IPv6 literal, with or without a port.
		if strings.Contains(addr, "]:") {
			return addr
		}
		return addr + ":636"
	case strings.Count(addr, ":") == 1:
		// host:port
		return addr
	case strings.Contains(addr, ":"):
		// Bare IPv6 literal: brackets are required before adding a port.
		return "[" + addr + "]:636"
	default:
		return addr + ":636"
	}
}
// ldapConnect opens a new TLS (LDAPS) connection to LdapServer. It trusts the
// certificates in systemCA and skips certificate verification only when
// LdapInsecure is true.
//
// On failure it logs the dial error and returns nil. Callers must check for
// nil before calling Bind/Close, and should Close the connection when done.
func ldapConnect() *ldap.Conn {
	// Set up TLS to use our custom certificate authority
	tlsConfig := &tls.Config{
		RootCAs:            systemCA,
		InsecureSkipVerify: LdapInsecure,
	}
	log.Printf("ldapConnect initiating connection\n")
	ldaps, err := ldap.DialTLS("tcp", LdapServer, tlsConfig)
	if err != nil {
		// Log under this function's own name: several callers share it.
		log.Printf("ldapConnect error connecting to LDAP server '%s' : '%s'\n", LdapServer, err)
		return nil
	}
	log.Printf("ldapConnect connection succeeded\n")
	return ldaps
}
// LookupNamingContext reads the defaultNamingContext attribute from the root
// DSE, using a base-object search of the empty DN on the supplied connection.
// It returns "" and logs the reason in three cases: the search fails, the
// search does not return exactly one entry, or the attribute is empty.
func LookupNamingContext(ldaps *ldap.Conn) string {
	rootDSE := ldap.NewSearchRequest(
		"", ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false,
		"(objectClass=*)", []string{"defaultNamingContext"}, nil,
	)

	result, err := ldaps.Search(rootDSE)
	switch {
	case err != nil:
		log.Printf("LookupNamingContext unable to perform unauthenticated search : '%s'\n", err)
		return ""
	case len(result.Entries) != 1:
		log.Printf("LookupNamingContext unable to retrieve defaultNamingContext\n")
		return ""
	}

	value := result.Entries[0].GetAttributeValue("defaultNamingContext")
	if value == "" {
		log.Printf("LookupNamingContext defaultNamingContext attribute not found\n")
		return ""
	}
	log.Printf("Default Naming Context: '%s'\n", value)
	return value
}
// LdapGetGroupMembership binds as username/password and returns the
// distinguishedNames of the groups in that user's memberOf attribute.
//
// The username is normalised with CheckUsername before binding. The group
// lookup is done by GetGroupsOfUser under LdapBaseDn on the same
// authenticated connection.
//
// An error is returned, and logged, in these cases:
//   - the server cannot be reached;
//   - the bind fails — the message is exactly "invalid user credentials"
//     when the server rejects the credentials;
//   - the group lookup fails.
func LdapGetGroupMembership(username string, password string) ([]string, error) {
	username = CheckUsername(username)

	ldaps := ldapConnect()
	if ldaps == nil {
		// ldapConnect returns nil (after logging why) when the dial fails;
		// calling Bind or Close on a nil *ldap.Conn would panic.
		errString := fmt.Sprintf("LdapGetGroupMembership unable to connect to LDAP server '%s'", LdapServer)
		log.Print(errString)
		return nil, errors.New(errString)
	}
	defer ldaps.Close()

	// try an authenticated bind to AD to verify credentials
	log.Printf("LdapGetGroupMembership Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password))
	if err := ldaps.Bind(username, password); err != nil {
		var ldapErr *ldap.Error
		if errors.As(err, &ldapErr) && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials {
			errString := "invalid user credentials"
			log.Print(errString)
			return nil, errors.New(errString)
		}
		// %w keeps the original message text while letting callers inspect the cause.
		bindErr := fmt.Errorf("LdapGetGroupMembership error binding to LDAP with supplied credentials : '%w'\n", err)
		log.Print(bindErr)
		return nil, bindErr
	}
	log.Printf("LdapGetGroupMembership successfully bound to LDAP\n")

	groups, err := GetGroupsOfUser(username, LdapBaseDn, ldaps)
	if err != nil {
		searchErr := fmt.Errorf("LdapGetGroupMembership group search error : '%w'\n", err)
		log.Print(searchErr)
		return nil, searchErr
	}
	return groups, nil
}
// VerifyLdapCreds checks that username/password can complete an LDAP simple
// bind. The username is normalised with CheckUsername first.
//
// It returns:
//   - nil on success;
//   - an error whose message is exactly "invalid user credentials" when the
//     server rejects the credentials;
//   - another error when the server is unreachable or the bind fails for any
//     other reason.
func VerifyLdapCreds(username string, password string) error {
	username = CheckUsername(username)

	ldaps := ldapConnect()
	if ldaps == nil {
		// ldapConnect returns nil (after logging why) when the dial fails;
		// calling Bind or Close on a nil *ldap.Conn would panic.
		errString := fmt.Sprintf("VerifyLdapCreds unable to connect to LDAP server '%s'", LdapServer)
		log.Print(errString)
		return errors.New(errString)
	}
	defer ldaps.Close()

	// try an authenticated bind to AD to verify credentials
	log.Printf("Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password))
	err := ldaps.Bind(username, password)
	if err == nil {
		log.Printf("VerifyLdapCreds successfully bound to LDAP\n")
		return nil
	}

	var ldapErr *ldap.Error
	if errors.As(err, &ldapErr) && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials {
		errString := "invalid user credentials"
		log.Print(errString)
		return errors.New(errString)
	}
	// %w keeps the original message text while letting callers inspect the cause.
	bindErr := fmt.Errorf("VerifyLdapCreds error binding to LDAP with supplied credentials : '%w'\n", err)
	log.Print(bindErr)
	return bindErr
}
// GetGroupsOfUser returns the memberOf values (group distinguishedNames) of
// the single user whose sAMAccountName matches username. It searches the
// whole subtree under baseDN using conn.
//
// username may take three forms:
//   - a bare account name;
//   - a UPN such as "user@example.com" (the part before the first "@" is used);
//   - a down-level name such as `DOMAIN\user` (the part after the last
//     backslash is used).
//
// The account name is escaped before it is placed in the filter. The search
// fails with an error, or with an error reporting the number of matches when
// it is not exactly one entry.
//
// Adapted from https://github.com/jtblin/go-ldap-client/issues/13#issuecomment-456090979
func GetGroupsOfUser(username string, baseDN string, conn *ldap.Conn) ([]string, error) {
	sAMAccountName := username
	if at := strings.Index(username, "@"); at >= 0 {
		sAMAccountName = username[:at]
	} else if bs := strings.LastIndex(username, `\`); bs >= 0 {
		sAMAccountName = username[bs+1:]
	}

	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		fmt.Sprintf("(sAMAccountName=%s)", ldap.EscapeFilter(sAMAccountName)),
		// Only memberOf is used; don't transfer every attribute of the entry.
		[]string{"memberOf"},
		nil,
	)
	log.Printf("GetGroupsOfUser searching '%s' with filter '%s'\n", baseDN, searchRequest.Filter)

	sr, err := conn.Search(searchRequest)
	if err != nil {
		return nil, err
	}
	switch n := len(sr.Entries); {
	case n == 0:
		return nil, fmt.Errorf("user '%s' does not exist", sAMAccountName)
	case n > 1:
		return nil, fmt.Errorf("user '%s' matched %d entries, expected exactly one", sAMAccountName, n)
	}
	return sr.Entries[0].GetAttributeValues("memberOf"), nil
}
// GetLdapUserDn returns the distinguishedName of the user whose
// sAMAccountName matches username. It searches the whole subtree under
// baseDN using conn.
//
// username may take three forms:
//   - a bare account name;
//   - a UPN (the part before the first "@" is used);
//   - a down-level name such as `DOMAIN\user` (the part after the last
//     backslash is used).
//
// If several entries match, the first one returned by the server is used.
// An error is returned if the search fails, nothing matches, or the entry
// has no distinguishedName.
func GetLdapUserDn(username string, baseDN string, conn *ldap.Conn) (string, error) {
	sAMAccountName := username
	if at := strings.Index(username, "@"); at >= 0 {
		sAMAccountName = username[:at]
	} else if bs := strings.LastIndex(username, `\`); bs >= 0 {
		sAMAccountName = username[bs+1:]
	}

	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		// Escape the caller-supplied value so characters such as '*', '(' or
		// ')' cannot change the meaning of the filter (LDAP injection).
		fmt.Sprintf("(sAMAccountName=%s)", ldap.EscapeFilter(sAMAccountName)),
		[]string{"distinguishedName"},
		nil,
	)
	searchResult, err := conn.Search(searchRequest)
	if err != nil {
		// Report the failure to the caller instead of terminating the process.
		return "", fmt.Errorf("GetLdapUserDn search for user '%s' failed : %w", sAMAccountName, err)
	}
	if len(searchResult.Entries) == 0 {
		return "", fmt.Errorf("user '%s' does not exist", sAMAccountName)
	}

	distinguishedName := searchResult.Entries[0].GetAttributeValue("distinguishedName")
	if distinguishedName == "" {
		return "", fmt.Errorf("could not find distinguishedName for user '%s'", sAMAccountName)
	}
	log.Printf("GetLdapUserDn located user's distinguishedName : '%s'\n", distinguishedName)
	return distinguishedName, nil
}
// GetUserFromUPN returns the user portion of a UPN-formatted username, i.e.
// everything before the first "@". Input containing no "@" is returned
// unchanged.
func GetUserFromUPN(email string) string {
	if at := strings.IndexByte(email, '@'); at >= 0 {
		return email[:at]
	}
	return email
}