package models import ( "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/hex" "errors" "fmt" "io" "log" "smt/utils" "strings" "time" ) const nonceSize = 12 // We use the json:"-" field tag to prevent showing these details to the user type Secret struct { SecretId int `db:"SecretId" json:"secretId"` SafeId int `db:"SafeId" json:"safeId"` DeviceName string `db:"DeviceName" json:"deviceName"` DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"` UserName string `db:"UserName" json:"userName"` Secret string `db:"Secret" json:"secret"` LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"` } // SecretRestricted is for when we want to output a Secret but not the protected information type SecretRestricted struct { SecretId int `db:"SecretId" json:"secretId"` SafeId int `db:"SafeId" json:"safeId"` DeviceName string `db:"DeviceName" json:"deviceName"` DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"` UserName string `db:"UserName" json:"userName"` Secret string `db:"Secret" json:"-"` LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"` } // Used for querying all secrets the user has access to // Since there are some ambiguous column names (eg UserName is present in both users and secrets table), the order of fields in this struct matters type UserSecret struct { Secret UserUserId int `db:"UserUserId"` User //Group Permission } // This method allows us to use an interface to avoid adding duplicate entries to a []Secret func (s Secret) GetId() int { return s.SecretId } func (s *Secret) SaveSecret() (*Secret, error) { var err error // Populate timestamp field if not already set if s.LastUpdated.IsZero() { s.LastUpdated = time.Now().UTC() } log.Printf("SaveSecret storing values '%v'\n", s) result, err := db.NamedExec((`INSERT INTO secrets (SafeId, DeviceName, DeviceCategory, UserName, Secret, LastUpdated) VALUES (:SafeId, :DeviceName, :DeviceCategory, :UserName, :Secret, :LastUpdated)`), s) if err != nil { log.Printf("StoreSecret error executing sql record : '%s'\n", err) return s, err } affected, _ := result.RowsAffected() id, _ := result.LastInsertId() s.SecretId = int(id) log.Printf("StoreSecret insert returned result id '%d' affecting %d row(s).\n", id, affected) return s, nil } // SecretsGetAllowed returns all allowed secrets matching the specified parameters in s func SecretsGetAllowed(s *Secret, userId int) ([]UserSecret, error) { var err error var secretResults []UserSecret // Query for group access queryArgs := []interface{}{} query := ` SELECT users.UserId AS UserUserId, permissions.*, secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName FROM users INNER JOIN groups ON users.GroupId = groups.GroupId INNER JOIN permissions ON groups.GroupId = permissions.GroupId INNER JOIN secrets on secrets.SafeId = permissions.SafeId WHERE users.UserId = ? ` queryArgs = append(queryArgs, userId) // Add any other arguments to the query if they were specified if s.SecretId > 0 { query += " AND SecretId = ? " queryArgs = append(queryArgs, s.SecretId) } if s.DeviceName != "" { query += " AND DeviceName LIKE ? " queryArgs = append(queryArgs, s.DeviceName) } if s.DeviceCategory != "" { query += " AND DeviceCategory LIKE ? " queryArgs = append(queryArgs, s.DeviceCategory) } if s.UserName != "" { query += " AND secrets.UserName LIKE ? " queryArgs = append(queryArgs, s.UserName) } // Query for user access query += ` UNION SELECT users.UserId AS UserUserId, permissions.*, secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName FROM users INNER JOIN permissions ON users.UserId = permissions.UserId INNER JOIN safes on permissions.SafeId = safes.SafeId INNER JOIN secrets on secrets.SafeId = safes.SafeId WHERE users.UserId = ?` queryArgs = append(queryArgs, userId) // Add any other arguments to the query if they were specified if s.SecretId > 0 { query += " AND SecretId = ? " queryArgs = append(queryArgs, s.SecretId) } if s.DeviceName != "" { query += " AND DeviceName LIKE ? " queryArgs = append(queryArgs, s.DeviceName) } if s.DeviceCategory != "" { query += " AND DeviceCategory LIKE ? " queryArgs = append(queryArgs, s.DeviceCategory) } if s.UserName != "" { query += " AND secrets.UserName LIKE ? " queryArgs = append(queryArgs, s.UserName) } // Execute the query log.Printf("SecretsGetAllowed query string : '%s'\nArguments:%+v\n", query, queryArgs) rows, err := db.Queryx(query, queryArgs...) if err != nil { log.Printf("SecretsGetAllowed error executing sql record : '%s'\n", err) return secretResults, err } else { log.Printf("SecretsGetAllowed any error '%s'\n", rows.Err()) // parse all the results into a slice for rows.Next() { log.Printf("SecretsGetAllowed processing row\n") var r UserSecret err = rows.StructScan(&r) log.Printf("SecretsGetAllowed performed struct scan\n") if err != nil { log.Printf("SecretsGetAllowed error parsing sql record : '%s'\n", err) return secretResults, err } //log.Printf("r: %v\n", r) log.Printf("SecretsGetAllowed performed err check\n") // work around to get the UserId populated in the User field of the struct r.User.UserId = r.UserUserId // For debugging purposes debugPrint := utils.PrintStructContents(&r, 0) log.Println(debugPrint) log.Printf("SecretsGetAllowed performed debug print\n") // Append the secrets to the query output, don't decrypt the secrets (we didn't SELECT them anyway) //secretResults = append(secretResults, r) // Use generics and the GetID() method on the UserSecret struct // to avoid adding this element to the results // if there is already a secret with the same ID present secretResults = utils.AppendIfNotExists(secretResults, r) log.Printf("SecretsGetAllowed added secret results\n") } log.Printf("SecretsGetAllowed retrieved '%d' results\n", len(secretResults)) } return secretResults, nil } // SecretsGetFromMultipleSafes queries the specified safes for matching secrets func SecretsGetFromMultipleSafes(s *Secret, safeIds []int) ([]Secret, error) { var err error var secretResults []Secret queryArgs := []interface{}{} var query string // Generate placeholders for the IN clause to match multiple SafeId values placeholders := make([]string, len(safeIds)) for i := range safeIds { placeholders[i] = "?" } placeholderStr := strings.Join(placeholders, ",") // Create query with the necessary placeholders query = fmt.Sprintf("SELECT * FROM secrets WHERE SafeId IN (%s) ", placeholderStr) // Add the Safe Ids to the arguments list for _, g := range safeIds { queryArgs = append(queryArgs, g) } // Add any other arguments to the query if they were specified if s.SecretId > 0 { query += " AND SecretId = ? " queryArgs = append(queryArgs, s.SecretId) } if s.DeviceName != "" { query += " AND DeviceName LIKE ? " queryArgs = append(queryArgs, s.DeviceName) } if s.DeviceCategory != "" { query += " AND DeviceCategory LIKE ? " queryArgs = append(queryArgs, s.DeviceCategory) } if s.UserName != "" { query += " AND UserName LIKE ? " queryArgs = append(queryArgs, s.UserName) } // Execute the query log.Printf("SecretsGetMultipleSafes query string :\n'%s'\nQuery Args : %+v\n", query, queryArgs) rows, err := db.Queryx(query, queryArgs...) if err != nil { log.Printf("SecretsGetMultipleSafes error executing sql record : '%s'\n", err) return secretResults, err } else { // parse all the results into a slice for rows.Next() { var r Secret err = rows.StructScan(&r) if err != nil { log.Printf("SecretsGetMultipleSafes error parsing sql record : '%s'\n", err) return secretResults, err } // Decrypt the secret _, err = r.DecryptSecret() if err != nil { log.Printf("SecretsGetMultipleSafes unable to decrypt stored secret : '%s'\n", err) rows.Close() return secretResults, err } else { secretResults = append(secretResults, r) } } log.Printf("SecretsGetMultipleSafes retrieved '%d' results\n", len(secretResults)) } return secretResults, nil } func (s *Secret) UpdateSecret() (*Secret, error) { var err error // Populate timestamp field if not already set if s.LastUpdated.IsZero() { s.LastUpdated = time.Now().UTC() } log.Printf("UpdateSecret storing values '%v'\n", s) if s.SecretId == 0 { err = errors.New("UpdateSecret unable to locate secret with empty secretId field") log.Printf("UpdateSecret error in pre-check : '%s'\n", err) return s, err } result, err := db.NamedExec((`UPDATE secrets SET DeviceName = :DeviceName, DeviceCategory = :DeviceCategory, UserName = :UserName, Secret = :Secret, LastUpdated = :LastUpdated WHERE SecretId = :SecretId`), s) if err != nil { log.Printf("UpdateSecret error executing sql record : '%s'\n", err) return &Secret{}, err } else { affected, _ := result.RowsAffected() id, _ := result.LastInsertId() log.Printf("UpdateSecret insert returned result id '%d' affecting %d row(s).\n", id, affected) } return s, nil } func (s *Secret) DeleteSecret() (*Secret, error) { var err error log.Printf("DeleteSecret deleting record with values '%v'\n", s) if s.SecretId == 0 { err = errors.New("unable to locate secret with empty secretId field") log.Printf("DeleteSecret error in pre-check : '%s'\n", err) return s, err } result, err := db.NamedExec((`DELETE FROM secrets WHERE SecretId = :SecretId`), s) if err != nil { log.Printf("DeleteSecret error executing sql record : '%s'\n", err) return &Secret{}, err } else { affected, _ := result.RowsAffected() id, _ := result.LastInsertId() log.Printf("DeleteSecret delete returned result id '%d' affecting %d row(s).\n", id, affected) } return s, nil } // startCipher does the initial setup of the AES256 GCM mode cipher func startCipher() (cipher.AEAD, error) { key, err := ProvideKey() if err != nil { return nil, err } block, err := aes.NewCipher(key) if err != nil { log.Printf("startCipher NewCipher error '%s'\n", err) return nil, err } aesgcm, err := cipher.NewGCM(block) if err != nil { log.Printf("startCipher NewGCM error '%s'\n", err) return nil, err } return aesgcm, nil } func (s *Secret) EncryptSecret() (*Secret, error) { //keyString := os.Getenv("SECRETS_KEY") //keyString := secretKey // The key argument should be the AES key, either 16 or 32 bytes // to select AES-128 or AES-256. //key := []byte(keyString) /* key, err := ProvideKey() if err != nil { return s, err } */ plaintext := []byte(s.Secret) // TODO : move block and aesgcm generation to separate function since the identical code is used for encrypt and decrypt /* log.Printf("EncryptSecret applying key '%v' of length '%d' to plaintext secret '%s'\n", key, len(key), s.Secret) block, err := aes.NewCipher(key) if err != nil { log.Printf("EncryptSecret NewCipher error '%s'\n", err) return s, err } aesgcm, err := cipher.NewGCM(block) if err != nil { log.Printf("EncryptSecret NewGCM error '%s'\n", err) return s, err } */ aesgcm, err := startCipher() if err != nil { log.Printf("EncryptSecret error commencing GCM cipher '%s'\n", err) return s, err } // Never use more than 2^32 random nonces with a given key because of the risk of a repeat. nonce := make([]byte, nonceSize) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { log.Printf("EncryptSecret nonce generation error '%s'\n", err) return s, err } //log.Printf("EncryptSecret random nonce value is '%x'\n", nonce) ciphertext := aesgcm.Seal(nil, nonce, plaintext, nil) //log.Printf("EncryptSecret generated ciphertext '%x''\n", ciphertext) // Create a new slice to store nonce at the start and then the resulting ciphertext // Nonce is always 12 bytes combinedText := append(nonce, ciphertext...) //log.Printf("EncryptSecret combined secret value is now '%x'\n", combinedText) // Store the value back into the struct ready for database operations s.Secret = hex.EncodeToString(combinedText) return s, nil //return string(ciphertext[:]), nil } func (s *Secret) DecryptSecret() (*Secret, error) { // The key argument should be the AES key, either 16 or 32 bytes // to select AES-128 or AES-256. //keyString := os.Getenv("SECRETS_KEY") //keyString := secretKey //key := []byte(keyString) /* key, err := ProvideKey() if err != nil { return s, err } */ if len(s.Secret) < nonceSize { log.Printf("DecryptSecret ciphertext is too short to decrypt\n") return s, errors.New("ciphertext is too short") } crypted, err := hex.DecodeString(s.Secret) if err != nil { log.Printf("DecryptSecret unable to convert hex encoded string due to error '%s'\n", err) return s, err } //log.Printf("DecryptSecret processing secret '%x'\n", crypted) // The nonce is the first 12 bytes from the ciphertext nonce := crypted[:nonceSize] ciphertext := crypted[nonceSize:] /* log.Printf("DecryptSecret applying key '%v' and nonce '%x' to ciphertext '%x'\n", key, nonce, ciphertext) block, err := aes.NewCipher(key) if err != nil { log.Printf("DecryptSecret NewCipher error '%s'\n", err) return s, err } aesgcm, err := cipher.NewGCM(block) if err != nil { log.Printf("DecryptSecret NewGCM error '%s'\n", err) return s, err } */ aesgcm, err := startCipher() if err != nil { log.Printf("DecryptSecret error commencing GCM cipher '%s'\n", err) return s, err } plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) if err != nil { log.Printf("DecryptSecret Open error '%s'\n", err) return s, err } //log.Printf("DecryptSecret plaintext is '%s'\n", plaintext) s.Secret = string(plaintext) return s, nil }