Files
smt/models/ldap.go
Nathan Coad ffa8778d2b
All checks were successful
continuous-integration/drone/push Build is passing
more work on ldap integration
2024-01-04 20:45:34 +11:00

249 lines
6.2 KiB
Go

package models
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/go-ldap/ldap"
)
// Code relating to AD integration
// LdapConfig groups the LDAP connection settings as plain strings.
//
// Nothing in the visible part of this file reads or writes this type; the
// fields presumably mirror the LDAP_BIND_ADDRESS, LDAP_BASE_DN and
// LDAP_TRUST_CERT_FILE environment settings used by LdapSetup and
// loadLdapCert — TODO confirm against the callers.
//
// Field order: bind address, base DN, certificate file.
type LdapConfig struct {
	LdapBindAddress, LdapBaseDn, LdapCertFile string
}
// Package-level LDAP state shared by the functions in this file.
var (
	// systemCA is the certificate pool handed to the TLS configuration in
	// LdapSetup. loadLdapCert fills it from x509.SystemCertPool and adds the
	// certificate named by LDAP_TRUST_CERT_FILE when one is configured.
	systemCA *x509.CertPool

	// ldaps is the connection opened by LdapSetup and used by
	// LookupNamingContext and VerifyLdapCreds.
	ldaps *ldap.Conn

	// CertLoaded is set to true by loadLdapCert once a certificate from file
	// has been added to systemCA.
	CertLoaded bool

	// LdapEnabled is set to true by LdapSetup when a bind address is
	// configured.
	LdapEnabled bool

	// LdapBaseDn holds the value of the LDAP_BASE_DN environment variable as
	// read by LdapSetup; VerifyLdapCreds uses it as its search base.
	LdapBaseDn string
)
// GetFilePath returns a usable location for path.
//
// An empty path yields "". If os.Stat does not report the path as missing,
// it is returned unchanged. Otherwise the path is re-anchored to the
// directory containing the running executable; whether the file exists there
// is not checked. If the executable's location cannot be determined, ""
// is returned.
func GetFilePath(path string) string {
	if path == "" {
		return ""
	}

	// Anything other than a "does not exist" result (including success)
	// means the caller's path is handed back as-is.
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return path
	}

	log.Printf("File '%s' not found, searching in same directory as binary\n", path)

	// Fall back to the directory of the currently executing binary.
	exe, exeErr := os.Executable()
	if exeErr != nil {
		return ""
	}
	return filepath.Join(filepath.Dir(exe), path)
}
// loadLdapCert initialises systemCA from the system certificate pool and,
// when the LDAP_TRUST_CERT_FILE environment variable is set, adds the first
// PEM-encoded certificate found in that file to the pool.
//
// CertLoaded is set to true only when a certificate from file was added.
// Every failure is logged and leaves CertLoaded untouched; the function never
// panics on a malformed or empty certificate file.
func loadLdapCert() {
	var err error

	// Start from a copy of the system defined CAs.
	systemCA, err = x509.SystemCertPool()
	if err != nil {
		log.Printf("LoadLdapCert error getting system certificate pool : '%s'\n", err)
		return
	}

	// Only try to load a certificate from file if the environment variable
	// was specified.
	ldapCertFile := os.Getenv("LDAP_TRUST_CERT_FILE")
	if ldapCertFile == "" {
		log.Printf("LoadLdapCert no certificate specified\n")
		return
	}

	// Read the file, looking next to the binary if it is not found as given.
	cf, err := os.ReadFile(GetFilePath(ldapCertFile))
	if err != nil {
		log.Printf("LoadLdapCert error opening LDAP certificate file '%s' : '%s'\n", ldapCertFile, err)
		return
	}

	// pem.Decode returns a nil block when the input holds no PEM data; guard
	// against that before touching the block's bytes. Only the first PEM
	// block is used, anything after it is ignored.
	cpb, _ := pem.Decode(cf)
	if cpb == nil {
		log.Printf("LoadLdapCert error processing LDAP certificate file '%s' : '%s'\n", ldapCertFile, "no PEM data found")
		return
	}

	crt, err := x509.ParseCertificate(cpb.Bytes)
	if err != nil {
		log.Printf("LoadLdapCert error processing LDAP certificate file '%s' : '%s'\n", ldapCertFile, err)
		return
	}

	// Add the custom certificate to the system cert pool.
	systemCA.AddCert(crt)
	CertLoaded = true
}
// LdapSetup prepares the package-level LDAP connection.
//
// It loads the optional trust certificate, reads LDAP_BIND_ADDRESS and
// LDAP_BASE_DN from the environment, dials the server over TLS and stores the
// resulting connection in ldaps so that later calls (LookupNamingContext,
// VerifyLdapCreds) can use it. The connection is deliberately left open when
// this function returns.
//
// It returns false when the bind address or base DN is missing or when the
// TLS connection cannot be established, and true otherwise. LdapEnabled is
// set to true as soon as a bind address is configured.
func LdapSetup() bool {
	// Load LDAP certificate if necessary.
	loadLdapCert()

	ldapServer := os.Getenv("LDAP_BIND_ADDRESS")
	if ldapServer == "" {
		log.Printf("VerifyLdapCreds no LDAP bind address supplied\n")
		return false
	}
	LdapEnabled = true

	LdapBaseDn = os.Getenv("LDAP_BASE_DN")
	if LdapBaseDn == "" {
		log.Printf("VerifyLdapCreds no LDAP base DN supplied\n")
		return false
	}

	tlsConfig := &tls.Config{
		RootCAs: systemCA,
		// NOTE(review): InsecureSkipVerify disables verification of the
		// server certificate chain and host name, so RootCAs (and therefore
		// the certificate loaded by loadLdapCert) is not consulted. Kept
		// unchanged to avoid breaking existing deployments; this should be
		// turned off once the server certificate validates.
		InsecureSkipVerify: true,
	}

	// Append the default LDAPS port only when the address carries no port of
	// its own. An address ending in "]" is a bracketed IPv6 literal without
	// a port. An explicitly configured port (e.g. ":3269") is left alone.
	if !strings.Contains(ldapServer, ":") || strings.HasSuffix(ldapServer, "]") {
		ldapServer = fmt.Sprintf("%s:636", ldapServer)
		log.Printf("VerifyLdapCreds updated ldapServer string '%s'\n", ldapServer)
	}

	// Try connecting to AD via TLS.
	conn, err := ldap.DialTLS("tcp", ldapServer, tlsConfig)
	if err != nil {
		log.Printf("VerifyLdapCreds error connecting to LDAP bind address '%s' : '%s'\n", ldapServer, err)
		return false
	}

	// Release a connection left over from an earlier call before replacing it.
	if ldaps != nil {
		ldaps.Close()
	}

	// The connection is stored for later use and must stay open; closing it
	// here would leave every subsequent bind/search with a dead connection.
	ldaps = conn
	LdapEnabled = true

	LookupNamingContext()
	return true
}
// LookupNamingContext reads the defaultNamingContext attribute from the entry
// at the empty base DN, using the package-level connection without binding
// first, and logs its value.
//
// Nothing is returned: every outcome (search failure, unexpected number of
// entries, missing attribute, or the value itself) is only written to the log.
func LookupNamingContext() {
	const attr = "defaultNamingContext"

	req := ldap.NewSearchRequest(
		"", // empty base DN
		ldap.ScopeBaseObject,
		ldap.NeverDerefAliases,
		0,     // no size limit
		0,     // no time limit
		false, // return attribute values, not just types
		"(objectClass=*)",
		[]string{attr},
		nil,
	)

	res, err := ldaps.Search(req)
	if err != nil {
		log.Printf("LookupNamingContext unable to perform unauthenticated search : '%s'\n", err)
		return
	}

	// A base-object search is expected to yield exactly one entry.
	if len(res.Entries) != 1 {
		log.Printf("LookupNamingContext unable to retrieve defaultNamingContext\n")
		return
	}

	if namingContext := res.Entries[0].GetAttributeValue(attr); namingContext != "" {
		log.Printf("Default Naming Context: '%s'\n", namingContext)
	} else {
		log.Printf("LookupNamingContext defaultNamingContext attribute not found\n")
	}
}
// VerifyLdapCreds checks a username/password pair by performing a simple bind
// on the package-level LDAP connection, followed by a search from LdapBaseDn.
//
// It returns true only when both the bind and the search succeed. It returns
// false when no connection has been established, when the username or
// password is empty, when the bind is rejected, or when the search fails.
// All outcomes are logged.
func VerifyLdapCreds(username string, password string) bool {
	// Without a connection there is nothing to bind against.
	if ldaps == nil {
		log.Printf("VerifyLdapCreds no LDAP connection available\n")
		return false
	}

	// A simple bind with an empty password is an "unauthenticated bind",
	// which many directory servers accept; treating that as a successful
	// login would let anyone in. Refuse it before talking to the server.
	if username == "" || password == "" {
		log.Printf("VerifyLdapCreds refusing bind with empty username or password\n")
		return false
	}

	// Try an authenticated bind to AD to verify credentials.
	log.Printf("Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password))
	if err := ldaps.Bind(username, password); err != nil {
		if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials {
			log.Printf("VerifyLdapCreds user credentials are incorrect : '%s'\n", err)
		} else {
			log.Printf("VerifyLdapCreds error binding to LDAP with supplied credentials : '%s'\n", err)
		}
		return false
	}
	log.Printf("VerifyLdapCreds successfully bound to LDAP\n")

	// NOTE(review): this searches the whole subtree under the base DN with a
	// match-everything filter and no size limit on every call, and a failure
	// here fails the login. It looks like leftover diagnostics — confirm
	// whether it can be narrowed or removed.
	log.Printf("Attempting LDAP search request from base DN '%s'\n", LdapBaseDn)
	searchReq := ldap.NewSearchRequest(
		LdapBaseDn,
		ldap.ScopeWholeSubtree, // base DN and everything below it
		ldap.NeverDerefAliases,
		0,     // no size limit
		0,     // no time limit
		false, // return attribute values, not just types
		"(objectClass=*)",
		[]string{},
		nil,
	)

	result, err := ldaps.Search(searchReq)
	if err != nil {
		log.Printf("VerifyLdapCreds search error : '%s'\n", err)
		return false
	}

	log.Printf("result: %v\n", result)
	return true
}
// GetGroupsOfUser returns the memberOf attribute values of the single entry
// under baseDN that matches username.
//
// username may be given as "name", "name@domain" or "DOMAIN\name"; only the
// name part is used for the lookup. conn is the LDAP connection to search on.
//
// An error is returned when conn is nil, when the search fails, when no entry
// matches, or when more than one entry matches.
//
// Taken from https://github.com/jtblin/go-ldap-client/issues/13#issuecomment-456090979
func GetGroupsOfUser(username string, baseDN string, conn *ldap.Conn) ([]string, error) {
	if conn == nil {
		return nil, fmt.Errorf("no LDAP connection supplied")
	}

	// Reduce the supplied login to the bare account name.
	samAccountName := username
	switch {
	case strings.Contains(username, "@"):
		// user@domain: keep what precedes the first "@".
		samAccountName = strings.Split(username, "@")[0]
	case strings.Contains(username, "\\"):
		// DOMAIN\user: keep what follows the last backslash.
		parts := strings.Split(username, "\\")
		samAccountName = parts[len(parts)-1]
	}

	// Search for the given username. Only memberOf is read from the result,
	// so request just that attribute instead of every attribute of the entry.
	//
	// NOTE(review): the filter matches on CN although the value is treated as
	// an account name; presumably CN and sAMAccountName coincide in the
	// target directory — confirm.
	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		fmt.Sprintf("(CN=%s)", ldap.EscapeFilter(samAccountName)),
		[]string{"memberOf"},
		nil,
	)

	sr, err := conn.Search(searchRequest)
	if err != nil {
		return nil, err
	}

	switch len(sr.Entries) {
	case 0:
		return nil, fmt.Errorf("user '%s' does not exist", samAccountName)
	case 1:
		return sr.Entries[0].GetAttributeValues("memberOf"), nil
	default:
		// Several matches: refuse to guess which entry is meant.
		return nil, fmt.Errorf("user '%s' is ambiguous: %d entries matched", samAccountName, len(sr.Entries))
	}
}