From 25510c63e5443a4017c64b2734efffe18787e485 Mon Sep 17 00:00:00 2001 From: Nathan Coad Date: Tue, 9 Jan 2024 22:34:41 +1100 Subject: [PATCH] test storesecret update --- controllers/retrieve_secrets.go | 16 ++++--- controllers/store_secrets.go | 14 ++++-- models/secret.go | 75 +++------------------------------ models/user.go | 45 +++++++++++--------- 4 files changed, 50 insertions(+), 100 deletions(-) diff --git a/controllers/retrieve_secrets.go b/controllers/retrieve_secrets.go index cb20e73..7192916 100644 --- a/controllers/retrieve_secrets.go +++ b/controllers/retrieve_secrets.go @@ -162,7 +162,7 @@ func retrieveSpecifiedSecret(s *models.Secret, c *gin.Context) { return } */ - + var UserId int var results []models.Secret /* user_id, err := token.ExtractTokenID(c) @@ -171,10 +171,16 @@ func retrieveSpecifiedSecret(s *models.Secret, c *gin.Context) { return } */ - user_id := c.GetInt("user-id") + // Get userId that we stored in the context earlier + if val, ok := c.Get("user-id"); !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "error determining user"}) + return + } else { + UserId = val.(int) + } // Work out which safe to query for this user if the safe was not specified - safeList, err := models.UserGetSafesAllowed(int(user_id)) + safeList, err := models.UserGetSafesAllowed(int(UserId)) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "error determining user safes"}) @@ -189,7 +195,7 @@ func retrieveSpecifiedSecret(s *models.Secret, c *gin.Context) { return } else if len(safeList) == 1 { s.SafeId = safeList[0].SafeId - results, err = models.SecretsGetMultipleSafes(s, []int{s.SafeId}) + results, err = models.SecretsGetFromMultipleSafes(s, []int{s.SafeId}) } else { // Create a list of all the safes this user can access var safeIds []int @@ -197,7 +203,7 @@ func retrieveSpecifiedSecret(s *models.Secret, c *gin.Context) { safeIds = append(safeIds, safe.SafeId) } - results, err = models.SecretsGetMultipleSafes(s, safeIds) + results, err = models.SecretsGetFromMultipleSafes(s, safeIds) } if err != nil { diff --git a/controllers/store_secrets.go b/controllers/store_secrets.go index 12bdb68..c30d3d2 100644 --- a/controllers/store_secrets.go +++ b/controllers/store_secrets.go @@ -31,8 +31,8 @@ type SecretInput struct { SecretValue string `json:"secretValue"` } -func FindSafeId(UserId int, input SecretInput) (int, error) { - +// CheckSafeAllowed returns the SafeId of an allowed safe containing the secret specified by SafeId or SafeName +func CheckSafeAllowed(UserId int, input SecretInput) (int, error) { // Check which safes a user is allowed to access allowedSafes, err := models.UserGetSafesAllowed(UserId) if err != nil { @@ -122,7 +122,9 @@ func StoreSecret(c *gin.Context) { //log.Printf("user_id: %v\n", user_id) } - safeId, err := FindSafeId(UserId, input) + // TODO replace FindSafeId with models.SecretsGetAllowed() + + safeId, err := CheckSafeAllowed(UserId, input) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -130,7 +132,11 @@ func StoreSecret(c *gin.Context) { s.SafeId = safeId // If this secret already exists in the database then generate an error - checkExists, err := models.GetSecrets(&s, false) + //checkExists, err := models.GetSecrets(&s, false) + checkExists, err := models.SecretsGetFromMultipleSafes(&s, []int{safeId}) + + // TODO replace GetSecrets with SecretsGetFromMultipleSafes + if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return diff --git a/models/secret.go b/models/secret.go index d642d03..a502f1e 100644 --- a/models/secret.go +++ b/models/secret.go @@ -11,8 +11,6 @@ import ( "log" "smt/utils" "strings" - - "github.com/jmoiron/sqlx" ) const nonceSize = 12 @@ -56,6 +54,7 @@ func (s *Secret) SaveSecret() (*Secret, error) { return s, nil } +// SecretsGetAllowed returns all allowed secrets matching the specified parameters in s func SecretsGetAllowed(s *Secret, userId int) ([]UserSecret, error) { var err error var secretResults []UserSecret @@ -150,74 +149,8 @@ func SecretsGetAllowed(s *Secret, userId int) ([]UserSecret, error) { return secretResults, nil } -/* -func SecretsSearchAllSafes(s *Secret) ([]Secret, error) { - var err error - var secretResults []Secret - - args := []interface{}{} - query := "SELECT * FROM secrets WHERE 1=1 " - - // Make sure at least one parameter was specified - if s.DeviceName == "" && s.DeviceCategory == "" && s.UserName == "" { - err = errors.New("no search parameters specified") - log.Println(err) - return secretResults, err - } - - // Add any other arguments to the query if they were specified - if s.DeviceName != "" { - query += " AND DeviceName LIKE ? " - args = append(args, s.DeviceName) - } - - if s.DeviceCategory != "" { - query += " AND DeviceCategory LIKE ? " - args = append(args, s.DeviceCategory) - } - - if s.UserName != "" { - query += " AND UserName LIKE ? " - args = append(args, s.UserName) - } - - // Execute the query - log.Printf("SecretsSearchAllSafes query string : '%s'\n%+v\n", query, args) - rows, err := db.Queryx(query, args...) - - if err != nil { - log.Printf("SecretsSearchAllSafes error executing sql record : '%s'\n", err) - return secretResults, err - } else { - // parse all the results into a slice - for rows.Next() { - var r Secret - err = rows.StructScan(&r) - if err != nil { - log.Printf("SecretsSearchAllSafes error parsing sql record : '%s'\n", err) - return secretResults, err - } - - // Decrypt the secret - _, err = r.DecryptSecret() - if err != nil { - //log.Printf("GetSecret unable to decrypt stored secret '%v' : '%s'\n", r.Secret, err) - log.Printf("SecretsSearchAllSafes unable to decrypt stored secret : '%s'\n", err) - return secretResults, err - } else { - secretResults = append(secretResults, r) - } - - } - log.Printf("SecretsSearchAllSafes retrieved '%d' results\n", len(secretResults)) - } - - return secretResults, nil -} -*/ - -// SecretsGetMultipleSafes queries the specified safes for matching secrets -func SecretsGetMultipleSafes(s *Secret, safeIds []int) ([]Secret, error) { +// SecretsGetFromMultipleSafes queries the specified safes for matching secrets +func SecretsGetFromMultipleSafes(s *Secret, safeIds []int) ([]Secret, error) { var err error var secretResults []Secret @@ -302,6 +235,7 @@ func SecretsGetMultipleSafes(s *Secret, safeIds []int) ([]Secret, error) { return secretResults, nil } +/* // Returns all matching secrets, up to caller to determine how to deal with multiple results func GetSecrets(s *Secret, adminRole bool) ([]Secret, error) { var err error @@ -388,6 +322,7 @@ func GetSecrets(s *Secret, adminRole bool) ([]Secret, error) { return secretResults, nil } +*/ func (s *Secret) UpdateSecret() (*Secret, error) { diff --git a/models/user.go b/models/user.go index ec7b011..4640797 100644 --- a/models/user.go +++ b/models/user.go @@ -373,12 +373,14 @@ func UserGetSafesAllowed(userId int) ([]UserSafe, error) { defer rows.Close() // Get columns from rows for debugging - columns, err := rows.Columns() - if err != nil { - log.Printf("UserGetSafesAllowed error getting column listing : '%s'\n", err) - return results, err - } - log.Printf("columns: %v\n", columns) + /* + columns, err := rows.Columns() + if err != nil { + log.Printf("UserGetSafesAllowed error getting column listing : '%s'\n", err) + return results, err + } + log.Printf("columns: %v\n", columns) + */ // parse all the results into a slice for rows.Next() { @@ -394,23 +396,24 @@ func UserGetSafesAllowed(userId int) ([]UserSafe, error) { results = append(results, us) - // Create a map to store column names and values - rowValues := make(map[string]interface{}) + /* + // Create a map to store column names and values + rowValues := make(map[string]interface{}) - // Scan each row into the map - err := rows.MapScan(rowValues) - if err != nil { - log.Println(err) - continue - } - - // Print the raw row record - log.Println("-----------") - for _, column := range columns { - log.Printf("%s: %v\n", column, rowValues[column]) - } - log.Println("-----------") + // Scan each row into the map + err := rows.MapScan(rowValues) + if err != nil { + log.Println(err) + continue + } + // Print the raw row record + log.Println("-----------") + for _, column := range columns { + log.Printf("%s: %v\n", column, rowValues[column]) + } + log.Println("-----------") + */ } log.Printf("UserGetSafesAllowed retrieved '%d' results\n", len(results)) }