From 23e3044f89d5433daca4f238b384598a2ae32d0e Mon Sep 17 00:00:00 2001 From: Nathan Coad Date: Thu, 20 Jul 2023 09:05:48 +1000 Subject: [PATCH] load cert from file rather than embed --- .gitignore | 3 +- main.go | 80 ++++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 89547ba..f7a8015 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -log.txt \ No newline at end of file +log.txt +*.pem \ No newline at end of file diff --git a/main.go b/main.go index c292f7e..81bab01 100644 --- a/main.go +++ b/main.go @@ -3,8 +3,12 @@ package main import ( "crypto/x509" "encoding/json" + "encoding/pem" "flag" "fmt" + "io/ioutil" + "os" + "path/filepath" auth "github.com/korylprince/go-ad-auth/v3" ) @@ -71,6 +75,37 @@ lS3ZUQcHCLtUbTw= -----END CERTIFICATE----- ` +func GetFilePath(path string) string { + // Check for empty filename + if len(path) == 0 { + return "" + } + + // check if filename exists + if _, err := os.Stat(path); os.IsNotExist((err)) { + fmt.Printf("File '%s' not found, searching in same directory as binary\n", path) + // if not, check that it exists in the same directory as the currently executing binary + ex, err2 := os.Executable() + if err2 != nil { + //log.Printf("Error determining binary path : '%s'", err) + return "" + } + binaryPath := filepath.Dir(ex) + path = filepath.Join(binaryPath, path) + } + return path +} + +func isFlagPassed(name string) bool { + found := false + flag.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} + func main() { var output Output @@ -79,6 +114,7 @@ func main() { baseDN := flag.String("baseDN", "OU=Users,DC=example,DC=com", "Base DN to use when attempting to bind to AD") username := flag.String("username", "user", "Username to use when attempting to bind to AD") password := flag.String("password", "pass", "Password to use when attempting to bind to AD") + certFile := flag.String("cert-file", "rootca.pem", "Filename to load trusted certificate from") flag.Parse() output.Server = *server @@ -93,14 +129,42 @@ func main() { return } - // Add custom certificate to the system cert pool - ok := system.AppendCertsFromPEM([]byte(WSDCCertPem)) - if !ok { - output.AuthSuccess = false - output.Error = "failed to parse WSDC intermediate certificate" - b, _ := json.Marshal(output) - fmt.Println(string(b)) - return + // only try to load certificate from file if the command line argument was specified + if isFlagPassed("cert-file") { + // Try to read the file + cf, err := ioutil.ReadFile(GetFilePath(*certFile)) + if err != nil { + output.AuthSuccess = false + output.Error = err.Error() + b, _ := json.Marshal(output) + fmt.Println(string(b)) + return + } + + // Get the certificate from the file + cpb, _ := pem.Decode(cf) + crt, err := x509.ParseCertificate(cpb.Bytes) + + if err != nil { + output.AuthSuccess = false + output.Error = err.Error() + b, _ := json.Marshal(output) + fmt.Println(string(b)) + return + } + + // Add custom certificate to the system cert pool + system.AddCert(crt) + /* + ok := system.AppendCertsFromPEM(crt) + if !ok { + output.AuthSuccess = false + output.Error = "failed to parse WSDC intermediate certificate" + b, _ := json.Marshal(output) + fmt.Println(string(b)) + return + } + */ } config := &auth.Config{