All checks were successful
continuous-integration/drone/push Build is passing
290 lines
7.4 KiB
Go
290 lines
7.4 KiB
Go
package models
|
|
|
|
import (
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"log"
	"net"
	"os"
	"path/filepath"
	"strings"

	"github.com/go-ldap/ldap"
)
|
|
|
|
// Code relating to AD integration
|
|
|
|
// LdapConfig groups the LDAP connection settings: the server bind address,
// the base DN used for searches, and the path of an extra trusted CA
// certificate file. All three are plain strings, declared in that order.
type LdapConfig struct {
	LdapBindAddress, LdapBaseDn, LdapCertFile string
}
|
|
|
|
var systemCA *x509.CertPool
|
|
var ldaps *ldap.Conn
|
|
// LDAP status flags and settings populated during setup.
var (
	// CertLoaded becomes true once loadLdapCert has added a certificate
	// from LDAP_TRUST_CERT_FILE to systemCA.
	CertLoaded bool
	// LdapEnabled is set by LdapSetup once LDAP_BIND_ADDRESS is supplied.
	LdapEnabled bool
	// LdapBaseDn is the search base read from LDAP_BASE_DN.
	LdapBaseDn string
	// DefaultDomainSuffix is appended by CheckUsername to bare user names;
	// LdapSetup derives it from the server's defaultNamingContext.
	DefaultDomainSuffix string
)
|
|
|
|
// GetFilePath resolves a user-supplied file path.
//
// An empty path yields "". A path that exists, or whose existence cannot be
// determined (for example a permission error), is returned unchanged so the
// caller surfaces the real error when opening it. An absolute path that does
// not exist is also returned unchanged. A relative path that does not exist
// is re-rooted at the directory containing the running executable, so files
// shipped next to the binary are found regardless of the working directory.
// The re-rooted path is not checked for existence. "" is returned if the
// executable's location cannot be determined.
func GetFilePath(path string) string {
	if path == "" {
		return ""
	}

	if _, err := os.Stat(path); !os.IsNotExist(err) {
		return path
	}

	// Joining an absolute path onto the binary directory would produce a
	// meaningless path, so there is nothing else to search for it.
	if filepath.IsAbs(path) {
		return path
	}

	log.Printf("File '%s' not found, searching in same directory as binary\n", path)
	ex, err := os.Executable()
	if err != nil {
		log.Printf("Error determining binary path : '%s'\n", err)
		return ""
	}
	return filepath.Join(filepath.Dir(ex), path)
}
|
|
|
|
// DomainSuffixFromNamingContext converts a naming context such as
// "DC=example,DC=com" to the DNS-style suffix "example.com".
//
// Only domainComponent (DC) RDNs contribute, in their original order; other
// RDNs such as "OU=Staff" are ignored. Attribute names are matched
// case-insensitively and surrounding whitespace is trimmed, because
// "dc=example, dc=com" is an equally valid spelling of the same DN. Empty DC
// values are skipped. It returns "" when input contains no DC components.
func DomainSuffixFromNamingContext(input string) string {
	var labels []string
	for _, rdn := range strings.Split(input, ",") {
		// Split on the first "=" only; the value side may legally contain
		// further (escaped) "=" characters.
		parts := strings.SplitN(rdn, "=", 2)
		if len(parts) != 2 || !strings.EqualFold(strings.TrimSpace(parts[0]), "DC") {
			continue
		}
		if label := strings.TrimSpace(parts[1]); label != "" {
			labels = append(labels, label)
		}
	}
	return strings.Join(labels, ".")
}
|
|
|
|
// CheckUsername normalises a login name before it is used for an LDAP bind.
//
// Names that already carry a domain are returned unchanged:
//   - "user@example.com" (user principal name)
//   - "DOMAIN\user" (down-level logon name)
//   - any name containing "/"
//
// Otherwise, when DefaultDomainSuffix is known, "@" + DefaultDomainSuffix is
// appended to form a user principal name. With no suffix known, the name is
// returned as supplied rather than producing a dangling "user@".
func CheckUsername(username string) string {
	if strings.ContainsAny(username, `/\@`) {
		return username
	}

	if DefaultDomainSuffix == "" {
		log.Printf("CheckUsername no default domain suffix known, using '%s' as supplied\n", username)
		return username
	}

	log.Printf("CheckUsername appending default domain suffix '%s'\n", DefaultDomainSuffix)
	return username + "@" + DefaultDomainSuffix
}
|
|
|
|
// loadLdapCert prepares the trust pool used for the LDAPS connection.
//
// systemCA is set to a copy of the host's system certificate pool. If the
// LDAP_TRUST_CERT_FILE environment variable names a file (resolved with
// GetFilePath), every PEM "CERTIFICATE" block in it is parsed and added to
// systemCA, and CertLoaded is set to true. Other PEM block types, such as a
// private key bundled in the same file, are skipped.
//
// Problems are logged, not returned. The file is applied all-or-nothing: if
// it cannot be read, contains no certificate, or any certificate fails to
// parse, nothing is added and CertLoaded is left unchanged.
func loadLdapCert() {
	var err error

	// Get a copy of the system defined CAs.
	systemCA, err = x509.SystemCertPool()
	if err != nil {
		log.Printf("LoadLdapCert error getting system certificate pool : '%s'\n", err)
		return
	}

	// Only try to load a certificate from file if one was configured.
	ldapCertFile := os.Getenv("LDAP_TRUST_CERT_FILE")
	if ldapCertFile == "" {
		log.Printf("LoadLdapCert no certificate specified\n")
		return
	}

	cf, err := os.ReadFile(GetFilePath(ldapCertFile))
	if err != nil {
		log.Printf("LoadLdapCert error opening LDAP certificate file '%s' : '%s'\n", ldapCertFile, err)
		return
	}

	// Parse every certificate before touching the pool. pem.Decode returns a
	// nil block once no PEM data remains (including for a file that is not
	// PEM at all), so the block must be checked before its Bytes are used.
	var certs []*x509.Certificate
	rest := cf
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			break
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		crt, perr := x509.ParseCertificate(block.Bytes)
		if perr != nil {
			log.Printf("LoadLdapCert error processing LDAP certificate file '%s' : '%s'\n", ldapCertFile, perr)
			return
		}
		certs = append(certs, crt)
	}

	if len(certs) == 0 {
		log.Printf("LoadLdapCert no PEM certificate found in LDAP certificate file '%s'\n", ldapCertFile)
		return
	}

	// Add the custom certificates to the system cert pool.
	for _, crt := range certs {
		systemCA.AddCert(crt)
	}
	CertLoaded = true
}
|
|
|
|
// LdapSetup configures the package-wide LDAPS connection from the environment.
//
// Environment variables:
//   - LDAP_BIND_ADDRESS (required): host or host:port of the LDAP server.
//     When no port is given, the LDAPS default 636 is used.
//   - LDAP_BASE_DN (required): base DN for searches; stored in LdapBaseDn.
//   - LDAP_TRUST_CERT_FILE (optional): extra CA certificate, read by
//     loadLdapCert.
//   - LDAP_INSECURE_SKIP_VERIFY (optional): set to "true" to disable
//     verification of the server certificate. Intended only for testing.
//
// LdapEnabled is set as soon as a bind address is present, even if a later
// step fails. On success, ldaps holds the open connection and
// DefaultDomainSuffix is derived from the server's defaultNamingContext when
// it can be read. It reports whether the connection was established.
func LdapSetup() bool {
	var err error

	// Load the LDAP trust certificate if one is configured.
	loadLdapCert()

	ldapServer := os.Getenv("LDAP_BIND_ADDRESS")
	if ldapServer == "" {
		log.Printf("LdapSetup no LDAP bind address supplied\n")
		return false
	}
	// NOTE(review): set before the connection is attempted, so it means
	// "configured" rather than "connected"; users of ldaps must not assume it
	// is non-nil just because LdapEnabled is true.
	LdapEnabled = true

	LdapBaseDn = os.Getenv("LDAP_BASE_DN")
	if LdapBaseDn == "" {
		log.Printf("LdapSetup no LDAP base DN supplied\n")
		return false
	}

	// Verify the server against the system roots plus any custom CA loaded
	// above. Verification stays on unless explicitly disabled: without it the
	// trusted CA has no effect and bind passwords could be sent to an
	// impostor server.
	tlsConfig := &tls.Config{
		RootCAs:            systemCA,
		InsecureSkipVerify: ldapInsecureSkipVerify(),
	}

	addr := ldapsAddress(ldapServer)
	if addr != ldapServer {
		log.Printf("LdapSetup updated ldapServer string '%s'\n", addr)
	}

	// Connect to AD via TLS. ServerName is left empty, presumably so the TLS
	// layer infers it from addr for hostname verification.
	ldaps, err = ldap.DialTLS("tcp", addr, tlsConfig)
	if err != nil {
		log.Printf("LdapSetup error connecting to LDAP bind address '%s' : '%s'\n", addr, err)
		return false
	}

	if namingContext := LookupNamingContext(); namingContext != "" {
		DefaultDomainSuffix = DomainSuffixFromNamingContext(namingContext)
	}

	return true
}

// ldapsAddress returns addr unchanged when it already names a port
// ("host:3269", "[::1]:636"). Otherwise it appends the default LDAPS port
// 636, bracketing IPv6 literals as required ("::1" -> "[::1]:636").
func ldapsAddress(addr string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	host := strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
	return net.JoinHostPort(host, "636")
}

// ldapInsecureSkipVerify reports whether LDAP_INSECURE_SKIP_VERIFY is set to
// "true" (case-insensitive), logging a warning when it is.
func ldapInsecureSkipVerify() bool {
	if !strings.EqualFold(os.Getenv("LDAP_INSECURE_SKIP_VERIFY"), "true") {
		return false
	}
	log.Printf("LdapSetup WARNING: LDAP server certificate verification is disabled\n")
	return true
}
|
|
|
|
func LookupNamingContext() string {
|
|
// Retrieve the defaultNamingContext
|
|
searchRequest := ldap.NewSearchRequest(
|
|
"",
|
|
ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false,
|
|
"(objectClass=*)",
|
|
[]string{"defaultNamingContext"},
|
|
nil,
|
|
)
|
|
|
|
searchResult, err := ldaps.Search(searchRequest)
|
|
if err != nil {
|
|
log.Printf("LookupNamingContext unable to perform unauthenticated search : '%s'\n", err)
|
|
return ""
|
|
}
|
|
|
|
if len(searchResult.Entries) != 1 {
|
|
log.Printf("LookupNamingContext unable to retrieve defaultNamingContext\n")
|
|
return ""
|
|
}
|
|
|
|
defaultNamingContext := searchResult.Entries[0].GetAttributeValue("defaultNamingContext")
|
|
if defaultNamingContext == "" {
|
|
log.Printf("LookupNamingContext defaultNamingContext attribute not found\n")
|
|
return ""
|
|
}
|
|
|
|
log.Printf("Default Naming Context: '%s'\n", defaultNamingContext)
|
|
return defaultNamingContext
|
|
}
|
|
|
|
func VerifyLdapCreds(username string, password string) bool {
|
|
var err error
|
|
username = CheckUsername(username)
|
|
|
|
// try an authenticated bind to AD to verify credentials
|
|
log.Printf("Attempting LDAP bind with user '%s' and password length '%d'\n", username, len(password))
|
|
err = ldaps.Bind(username, password)
|
|
if err != nil {
|
|
if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials {
|
|
log.Printf("VerifyLdapCreds user credentials are incorrect : '%s'\n", err)
|
|
return false
|
|
} else {
|
|
log.Printf("VerifyLdapCreds error binding to LDAP with supplied credentials : '%s'\n", err)
|
|
return false
|
|
}
|
|
} else {
|
|
log.Printf("VerifyLdapCreds successfully bound to LDAP\n")
|
|
}
|
|
|
|
/*
|
|
log.Printf("Attempting LDAP search request from base DN '%s'\n", LdapBaseDn)
|
|
searchReq := ldap.NewSearchRequest(
|
|
LdapBaseDn,
|
|
ldap.ScopeWholeSubtree, // you can also use ldap.ScopeWholeSubtree
|
|
ldap.NeverDerefAliases,
|
|
0,
|
|
0,
|
|
false,
|
|
"(objectClass=*)",
|
|
[]string{},
|
|
nil,
|
|
)
|
|
result, err := ldaps.Search(searchReq)
|
|
if err != nil {
|
|
log.Printf("VerifyLdapCreds search error : '%s'\n", err)
|
|
return false
|
|
}
|
|
|
|
log.Printf("result: %v\n", result)
|
|
*/
|
|
|
|
groups, err := GetGroupsOfUser(username, LdapBaseDn, ldaps)
|
|
if err != nil {
|
|
log.Printf("VerifyLdapCreds group search error : '%s'\n", err)
|
|
return false
|
|
}
|
|
fmt.Printf("groups: %v\n", groups)
|
|
|
|
return true
|
|
}
|
|
|
|
// GetGroupsOfUser returns the group for a user.
|
|
// Taken from https://github.com/jtblin/go-ldap-client/issues/13#issuecomment-456090979
|
|
func GetGroupsOfUser(username string, baseDN string, conn *ldap.Conn) ([]string, error) {
|
|
var samAccountName string
|
|
var groups []string
|
|
|
|
if strings.Contains(username, "@") {
|
|
s := strings.Split(username, "@")
|
|
samAccountName = s[0]
|
|
} else if strings.Contains(username, "\\") {
|
|
s := strings.Split(username, "\\")
|
|
samAccountName = s[len(s)-1]
|
|
} else {
|
|
samAccountName = username
|
|
}
|
|
|
|
// Get the users DN
|
|
// Search for the given username
|
|
searchRequest := ldap.NewSearchRequest(
|
|
baseDN,
|
|
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
|
fmt.Sprintf("(sAMAccountName=%s)", ldap.EscapeFilter(samAccountName)),
|
|
[]string{},
|
|
nil,
|
|
)
|
|
|
|
fmt.Printf("searchRequest: %v\n", searchRequest)
|
|
|
|
sr, err := conn.Search(searchRequest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(sr.Entries) != 1 {
|
|
return nil, fmt.Errorf("user '%s' does not exist", samAccountName)
|
|
} else {
|
|
// Get the groups of the first result
|
|
groups = sr.Entries[0].GetAttributeValues("memberOf")
|
|
}
|
|
|
|
return groups, nil
|
|
}
|