diff --git a/models/secret.go b/models/secret.go index b902f15..f9dfb12 100644 --- a/models/secret.go +++ b/models/secret.go @@ -35,6 +35,11 @@ type UserSecret struct { Permission } +// This method allows us to use an interface to avoid adding duplicate entries to a []Secret +func (s Secret) GetId() int { + return s.SecretId +} + func (s *Secret) SaveSecret() (*Secret, error) { var err error @@ -141,7 +146,12 @@ func SecretsGetAllowed(s *Secret, userId int) ([]UserSecret, error) { log.Println(debugPrint) // Append the secrets to the query output, don't decrypt the secrets (we didn't SELECT them anyway) - secretResults = append(secretResults, r) + //secretResults = append(secretResults, r) + + // Use generics and the GetID() method on the UserSecret struct + // to avoid adding this element to the results + // if there is already a secret with the same ID present + secretResults = utils.AppendIfNotExists(secretResults, r) } log.Printf("SecretsGetAllowed retrieved '%d' results\n", len(secretResults)) } diff --git a/utils/certOperations.go b/utils/certOperations.go new file mode 100644 index 0000000..7455992 --- /dev/null +++ b/utils/certOperations.go @@ -0,0 +1,139 @@ +package utils + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "log" + "math/big" + "net" + "os" + "path/filepath" + "time" +) + +func GenerateCerts(tlsCert string, tlsKey string) { + // @see https://shaneutt.com/blog/golang-ca-and-signed-cert-go/ + // @see https://golang.org/src/crypto/tls/generate_cert.go + validFrom := "" + validFor := 365 * 24 * time.Hour + isCA := true + + // Get the hostname + hostname, err := os.Hostname() + if err != nil { + panic(err) + } + + // Check that the directory exists + relativePath := filepath.Dir(tlsCert) + log.Printf("GenerateCerts relative path for file creation is '%s'\n", relativePath) + _, err = os.Stat(relativePath) + if os.IsNotExist(err) { + log.Printf("Certificate path does not exist, creating %s before generating certificate\n", relativePath) + os.MkdirAll(relativePath, os.ModePerm) + } + + // Generate a private key + priv, err := rsa.GenerateKey(rand.Reader, rsaBits) + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + + var notBefore time.Time + if len(validFrom) == 0 { + notBefore = time.Now() + } else { + notBefore, err = time.Parse("Jan 2 15:04:05 2006", validFrom) + if err != nil { + log.Fatalf("Failed to parse creation date: %v", err) + } + } + + notAfter := notBefore.Add(validFor) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + log.Fatalf("Failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"DTMS"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + template.DNSNames = append(template.DNSNames, hostname) + + // Add in all the non-local IPs + ifaces, err := net.Interfaces() + + if err != nil { + log.Printf("Error enumerating interfaces: %v\n", err) + } + + for _, i := range ifaces { + addrs, err := i.Addrs() + if err != nil { + log.Printf("Oops: %v\n", err) + } + + for _, address := range addrs { + // check the address type and if it is not a loopback then add it to the list + if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + template.IPAddresses = append(template.IPAddresses, ipnet.IP) + } + } + } + } + + if isCA { + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + log.Fatalf("Failed to create certificate: %v", err) + } + + certOut, err := os.Create(tlsCert) + if err != nil { + log.Fatalf("Failed to open %s for writing: %v", tlsCert, err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + log.Fatalf("Failed to write data to %s: %v", tlsCert, err) + } + if err := certOut.Close(); err != nil { + log.Fatalf("Error closing %s: %v", tlsCert, err) + } + log.Printf("wrote %s\n", tlsCert) + + keyOut, err := os.OpenFile(tlsKey, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + log.Fatalf("Failed to open %s for writing: %v", tlsKey, err) + return + } + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + log.Fatalf("Unable to marshal private key: %v", err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + log.Fatalf("Failed to write data to %s: %v", tlsKey, err) + } + if err := keyOut.Close(); err != nil { + log.Fatalf("Error closing %s: %v", tlsKey, err) + } + log.Printf("wrote %s\n", tlsKey) +} diff --git a/utils/structOperations.go b/utils/structOperations.go new file mode 100644 index 0000000..af78977 --- /dev/null +++ b/utils/structOperations.go @@ -0,0 +1,55 @@ +package utils + +import ( + "fmt" + "reflect" + "strings" +) + +func PrintStructContents(s interface{}, indentLevel int) string { + var result strings.Builder + + val := reflect.ValueOf(s) + + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + typ := val.Type() + + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + + indent := strings.Repeat("\t", indentLevel) + result.WriteString(fmt.Sprintf("%s%s: ", indent, fieldType.Name)) + + switch field.Kind() { + case reflect.Struct: + result.WriteString("\n") + result.WriteString(PrintStructContents(field.Interface(), indentLevel+1)) + default: + result.WriteString(fmt.Sprintf("%v\n", field.Interface())) + } + } + + return result.String() +} + +type Identifiable interface { + GetId() int +} + +// AppendIfNotExists requires a struct to implement the GetId() function +// Then we can use this function to avoid creating duplicate entries in the slice +func AppendIfNotExists[T Identifiable](slice []T, element T) []T { + for _, existingElement := range slice { + if existingElement.GetId() == element.GetId() { + // Element with the same Id already exists, don't append + return slice + } + } + + // Element with the same Id does not exist, append the new element + return append(slice, element) +} diff --git a/utils/utils.go b/utils/utils.go index 4c42345..9703e2f 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,20 +1,10 @@ package utils import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" "log" - "math/big" "net" "os" "path/filepath" - "reflect" - "strings" - "time" ) const rsaBits = 4096 @@ -62,157 +52,3 @@ func FileExists(filename string) bool { } return !info.IsDir() } - -func GenerateCerts(tlsCert string, tlsKey string) { - // @see https://shaneutt.com/blog/golang-ca-and-signed-cert-go/ - // @see https://golang.org/src/crypto/tls/generate_cert.go - validFrom := "" - validFor := 365 * 24 * time.Hour - isCA := true - - // Get the hostname - hostname, err := os.Hostname() - if err != nil { - panic(err) - } - - // Check that the directory exists - relativePath := filepath.Dir(tlsCert) - log.Printf("GenerateCerts relative path for file creation is '%s'\n", relativePath) - _, err = os.Stat(relativePath) - if os.IsNotExist(err) { - log.Printf("Certificate path does not exist, creating %s before generating certificate\n", relativePath) - os.MkdirAll(relativePath, os.ModePerm) - } - - // Generate a private key - priv, err := rsa.GenerateKey(rand.Reader, rsaBits) - if err != nil { - log.Fatalf("Failed to generate private key: %v", err) - } - - var notBefore time.Time - if len(validFrom) == 0 { - notBefore = time.Now() - } else { - notBefore, err = time.Parse("Jan 2 15:04:05 2006", validFrom) - if err != nil { - log.Fatalf("Failed to parse creation date: %v", err) - } - } - - notAfter := notBefore.Add(validFor) - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - log.Fatalf("Failed to generate serial number: %v", err) - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"DTMS"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - template.DNSNames = append(template.DNSNames, hostname) - - // Add in all the non-local IPs - ifaces, err := net.Interfaces() - - if err != nil { - log.Printf("Error enumerating interfaces: %v\n", err) - } - - for _, i := range ifaces { - addrs, err := i.Addrs() - if err != nil { - log.Printf("Oops: %v\n", err) - } - - for _, address := range addrs { - // check the address type and if it is not a loopback then add it to the list - if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { - if ipnet.IP.To4() != nil { - template.IPAddresses = append(template.IPAddresses, ipnet.IP) - } - } - } - } - - if isCA { - template.IsCA = true - template.KeyUsage |= x509.KeyUsageCertSign - } - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - log.Fatalf("Failed to create certificate: %v", err) - } - - certOut, err := os.Create(tlsCert) - if err != nil { - log.Fatalf("Failed to open %s for writing: %v", tlsCert, err) - } - if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - log.Fatalf("Failed to write data to %s: %v", tlsCert, err) - } - if err := certOut.Close(); err != nil { - log.Fatalf("Error closing %s: %v", tlsCert, err) - } - log.Printf("wrote %s\n", tlsCert) - - keyOut, err := os.OpenFile(tlsKey, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - log.Fatalf("Failed to open %s for writing: %v", tlsKey, err) - return - } - privBytes, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - log.Fatalf("Unable to marshal private key: %v", err) - } - if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { - log.Fatalf("Failed to write data to %s: %v", tlsKey, err) - } - if err := keyOut.Close(); err != nil { - log.Fatalf("Error closing %s: %v", tlsKey, err) - } - log.Printf("wrote %s\n", tlsKey) -} - -func PrintStructContents(s interface{}, indentLevel int) string { - var result strings.Builder - - val := reflect.ValueOf(s) - - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - - typ := val.Type() - - for i := 0; i < val.NumField(); i++ { - field := val.Field(i) - fieldType := typ.Field(i) - - indent := strings.Repeat("\t", indentLevel) - result.WriteString(fmt.Sprintf("%s%s: ", indent, fieldType.Name)) - - switch field.Kind() { - case reflect.Struct: - result.WriteString("\n") - result.WriteString(PrintStructContents(field.Interface(), indentLevel+1)) - default: - result.WriteString(fmt.Sprintf("%v\n", field.Interface())) - } - } - - return result.String() -}