diff --git a/models/db.go b/models/db.go index dafb885..d3a800c 100644 --- a/models/db.go +++ b/models/db.go @@ -1,478 +1,491 @@ package models import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/hex" "errors" "fmt" - "io" "log" + "os" + "reflect" + "smt/utils" - "strings" - "time" + + "github.com/jmoiron/sqlx" + "golang.org/x/crypto/bcrypt" + _ "modernc.org/sqlite" ) -const nonceSize = 12 +var db *sqlx.DB -// We use the json:"-" field tag to prevent showing these details to the user -type Secret struct { - SecretId int `db:"SecretId" json:"secretId"` - SafeId int `db:"SafeId" json:"safeId"` - DeviceName string `db:"DeviceName" json:"deviceName"` - DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"` - UserName string `db:"UserName" json:"userName"` - Secret string `db:"Secret" json:"secret"` - LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"` -} +const ( + sqlFile = "smt.db" +) -// SecretRestricted is for when we want to output a Secret but not the protected information -type SecretRestricted struct { - SecretId int `db:"SecretId" json:"secretId"` - SafeId int `db:"SafeId" json:"safeId"` - DeviceName string `db:"DeviceName" json:"deviceName"` - DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"` - UserName string `db:"UserName" json:"userName"` - Secret string `db:"Secret" json:"-"` - LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"` -} +const createUsers string = ` + CREATE TABLE IF NOT EXISTS users ( + UserId INTEGER PRIMARY KEY AUTOINCREMENT, + GroupId INTEGER, + UserName VARCHAR, + Password VARCHAR, + Admin BOOLEAN DEFAULT 0, + LdapUser BOOLEAN DEFAULT 0, + LastLogin datetime DEFAULT (datetime('1970-01-01 00:00:00')), + ); +` -// Used for querying all secrets the user has access to -// Since there are some ambiguous column names (eg UserName is present in both users and secrets table), the order of fields in this struct matters -type UserSecret struct { - Secret - UserUserId int `db:"UserUserId"` - User - //Group - Permission -} +const createSafes string = ` + CREATE TABLE IF NOT EXISTS safes ( + SafeId INTEGER PRIMARY KEY AUTOINCREMENT, + SafeName VARCHAR + ); +` -// This method allows us to use an interface to avoid adding duplicate entries to a []Secret -func (s Secret) GetId() int { - return s.SecretId -} +const createGroups string = ` + CREATE TABLE IF NOT EXISTS groups ( + GroupId INTEGER PRIMARY KEY AUTOINCREMENT, + GroupName VARCHAR, + LdapGroup BOOLEAN DEFAULT 0, + LdapDn VARCHAR DEFAULT '', + Admin BOOLEAN DEFAULT 0 + ); +` -func (s *Secret) SaveSecret() (*Secret, error) { +const createPermissions = ` + CREATE TABLE IF NOT EXISTS permissions ( + PermissionId INTEGER PRIMARY KEY AUTOINCREMENT, + Description VARCHAR DEFAULT '', + ReadOnly BOOLEAN DEFAULT 0, + SafeId INTEGER, + UserId INTEGER DEFAULT 0, + GroupId INTEGER DEFAULT 0, + FOREIGN KEY (SafeId) REFERENCES safes(SafeId) + ); +` + +const createSecrets string = ` + CREATE TABLE IF NOT EXISTS secrets ( + SecretId INTEGER PRIMARY KEY AUTOINCREMENT, + SafeId INTEGER, + DeviceName VARCHAR, + DeviceCategory VARCHAR, + UserName VARCHAR, + Secret VARCHAR, + LastUpdated datetime DEFAULT (datetime('1970-01-01 00:00:00')), + FOREIGN KEY (SafeId) REFERENCES safes(SafeId) + ); +` + +const createSchema string = ` + CREATE TABLE IF NOT EXISTS schema ( + Version INTEGER + ); +` + +const createAudit string = ` + CREATE TABLE IF NOT EXISTS audit ( + AuditId INTEGER PRIMARY KEY AUTOINCREMENT, + UserId INTEGER DEFAULT 0, + SecretId INTEGER DEFAULT 0, + EventText VARCHAR, + EventTime datetime + ); +` + +// Establish connection to sqlite database +func ConnectDatabase() { var err error - // Populate timestamp field if not already set - if s.LastUpdated.IsZero() { - s.LastUpdated = time.Now().UTC() - } - - log.Printf("SaveSecret storing values '%v'\n", s) - result, err := db.NamedExec((`INSERT INTO secrets (SafeId, DeviceName, DeviceCategory, UserName, Secret, LastUpdated) VALUES (:SafeId, :DeviceName, :DeviceCategory, :UserName, :Secret, :LastUpdated)`), s) + // Try using sqlite as our database + sqlPath := utils.GetFilePath(sqlFile) + db, err = sqlx.Open("sqlite", sqlPath) if err != nil { - log.Printf("StoreSecret error executing sql record : '%s'\n", err) - return s, err - } - - affected, _ := result.RowsAffected() - id, _ := result.LastInsertId() - s.SecretId = int(id) - log.Printf("StoreSecret insert returned result id '%d' affecting %d row(s).\n", id, affected) - - return s, nil -} - -// SecretsGetAllowed returns all allowed secrets matching the specified parameters in s -func SecretsGetAllowed(s *Secret, userId int) ([]UserSecret, error) { - var err error - var secretResults []UserSecret - - // Query for group access - queryArgs := []interface{}{} - query := ` - SELECT users.UserId AS UserUserId, permissions.*, - secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName - FROM users - INNER JOIN groups ON users.GroupId = groups.GroupId - INNER JOIN permissions ON groups.GroupId = permissions.GroupId - INNER JOIN secrets on secrets.SafeId = permissions.SafeId - WHERE users.UserId = ? ` - queryArgs = append(queryArgs, userId) - - // Add any other arguments to the query if they were specified - if s.SecretId > 0 { - query += " AND SecretId = ? " - queryArgs = append(queryArgs, s.SecretId) - } - - if s.DeviceName != "" { - query += " AND DeviceName LIKE ? " - queryArgs = append(queryArgs, s.DeviceName) - } - - if s.DeviceCategory != "" { - query += " AND DeviceCategory LIKE ? " - queryArgs = append(queryArgs, s.DeviceCategory) - } - - if s.UserName != "" { - query += " AND secrets.UserName LIKE ? " - queryArgs = append(queryArgs, s.UserName) - } - - // Query for user access - query += ` - UNION - SELECT users.UserId AS UserUserId, permissions.*, - secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName - FROM users - INNER JOIN permissions ON users.UserId = permissions.UserId - INNER JOIN safes on permissions.SafeId = safes.SafeId - INNER JOIN secrets on secrets.SafeId = safes.SafeId - WHERE users.UserId = ?` - queryArgs = append(queryArgs, userId) - - // Add any other arguments to the query if they were specified - if s.SecretId > 0 { - query += " AND SecretId = ? " - queryArgs = append(queryArgs, s.SecretId) - } - if s.DeviceName != "" { - query += " AND DeviceName LIKE ? " - queryArgs = append(queryArgs, s.DeviceName) - } - - if s.DeviceCategory != "" { - query += " AND DeviceCategory LIKE ? " - queryArgs = append(queryArgs, s.DeviceCategory) - } - - if s.UserName != "" { - query += " AND secrets.UserName LIKE ? " - queryArgs = append(queryArgs, s.UserName) - } - - // Execute the query - log.Printf("SecretsGetAllowed query string : '%s'\nArguments:%+v\n", query, queryArgs) - rows, err := db.Queryx(query, queryArgs...) - - if err != nil { - log.Printf("SecretsGetAllowed error executing sql record : '%s'\n", err) - return secretResults, err + log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err) + os.Exit(1) } else { - //log.Printf("SecretsGetAllowed any error '%s'\n", rows.Err()) - // parse all the results into a slice - for rows.Next() { - //log.Printf("SecretsGetAllowed processing row\n") - var r UserSecret - err = rows.StructScan(&r) - //log.Printf("SecretsGetAllowed performed struct scan\n") - if err != nil { - log.Printf("SecretsGetAllowed error parsing sql record : '%s'\n", err) - return secretResults, err - } - //log.Printf("r: %v\n", r) + log.Printf("Connected to sqlite database file '%s'\n", sqlPath) + } - //log.Printf("SecretsGetAllowed performed err check\n") + //sqlx.NameMapper = func(s string) string { return s } - // work around to get the UserId populated in the User field of the struct - r.User.UserId = r.UserUserId + // Make sure our tables exist + CreateTables() - // For debugging purposes - debugPrint := utils.PrintStructContents(&r, 0) - log.Println(debugPrint) + //defer db.Close() +} - //log.Printf("SecretsGetAllowed performed debug print\n") +func DisconnectDatabase() { + log.Printf("DisconnectDatabase called") + defer db.Close() +} - // Append the secrets to the query output, don't decrypt the secrets (we didn't SELECT them anyway) - //secretResults = append(secretResults, r) +func CreateTables() { + var err error + var rowCount int - // Use generics and the GetID() method on the UserSecret struct - // to avoid adding this element to the results - // if there is already a secret with the same ID present - secretResults = utils.AppendIfNotExists(secretResults, r) + // Create database tables if it doesn't exist - //log.Printf("SecretsGetAllowed added secret results\n") + // groups table + if _, err = db.Exec(createGroups); err != nil { + log.Printf("Error checking groups table : '%s'", err) + os.Exit(1) + } + + // Add initial groups + rowCount, _ = CheckCount("groups") + if rowCount == 0 { + if _, err = db.Exec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(1, 'Administrators', 1);"); err != nil { + log.Printf("Error adding initial group entry id 1 : '%s'", err) + os.Exit(1) + } + if _, err = db.Exec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(2, 'Users', 0);"); err != nil { + log.Printf("Error adding initial group entry id 2 : '%s'", err) + os.Exit(1) } - log.Printf("SecretsGetAllowed retrieved '%d' results\n", len(secretResults)) } - return secretResults, nil + // Users table + if _, err = db.Exec(createUsers); err != nil { + log.Printf("Error checking users table : '%s'", err) + os.Exit(1) + } + rowCount, _ = CheckCount("users") + if rowCount == 0 { + // Check if there was an initial password defined in the .env file + initialPassword := os.Getenv("INITIAL_PASSWORD") + if initialPassword == "" { + initialPassword = "password" + } else if initialPassword[:4] == "$2a$" { + log.Printf("CreateTables inital admin password is already a hash") + } else { + cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost) + initialPassword = string(cryptText) + } + if _, err = db.Exec("INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(1, 1, 'Administrator', ?, false, true);", initialPassword); err != nil { + log.Printf("Error adding initial admin role : '%s'", err) + os.Exit(1) + } + if _, err = db.Exec("INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(2, 2, 'User', ?, false, false);", initialPassword); err != nil { + log.Printf("Error adding initial admin role : '%s'", err) + os.Exit(1) + } + } + + // Safes table + if _, err = db.Exec(createSafes); err != nil { + log.Printf("Error checking safes table : '%s'", err) + os.Exit(1) + } + // Create an initial safe + rowCount, _ = CheckCount("safes") + if rowCount == 0 { + if _, err = db.Exec("INSERT INTO safes VALUES(1, 'Default Safe');"); err != nil { + log.Printf("Error adding initial safe entry : '%s'", err) + os.Exit(1) + } + } + + // Secrets table + if _, err = db.Exec(createSecrets); err != nil { + log.Printf("Error checking secrets table : '%s'", err) + os.Exit(1) + } + + // permissions table + if _, err = db.Exec(createPermissions); err != nil { + log.Printf("Error checking permissions table : '%s'", err) + os.Exit(1) + } + + // Add initial permissions + rowCount, _ = CheckCount("permissions") + if rowCount == 0 { + if _, err = db.Exec("INSERT INTO permissions (Description, ReadOnly, GroupId, SafeId) VALUES('Default Admin Group Permission', false, 1, 1);"); err != nil { + log.Printf("Error adding initial permissions entry userid 1 : '%s'", err) + os.Exit(1) + } + if _, err = db.Exec("INSERT INTO permissions (Description, ReadOnly, SafeId, GroupId) VALUES('Default User Group Permission', false, 1, 2);"); err != nil { + log.Printf("Error adding initial permissions entry userid 2 : '%s'", err) + os.Exit(1) + } + } + + // Schema table should go last so we know if the database has a value in the schema table then everything was created properly + if _, err = db.Exec(createSchema); err != nil { + log.Printf("Error checking schema table : '%s'", err) + os.Exit(1) + } + schemaCheck, _ := CheckColumnExists("schema", "Version") + if !schemaCheck { + if _, err = db.Exec("INSERT INTO schema VALUES(2);"); err != nil { + log.Printf("Error adding initial schema version : '%s'", err) + os.Exit(1) + } + } + + // Audit log table + if _, err = db.Exec(createAudit); err != nil { + log.Printf("Error checking audit table : '%s'", err) + os.Exit(1) + } + + // Remove users RoleId column + userRoleIdCheck, _ := CheckColumnExists("users", "RoleId") + if userRoleIdCheck { + //_, err := db.Exec("ALTER TABLE users DROP COLUMN RoleId;") + _, err := db.Exec(` + PRAGMA foreign_keys=off; + BEGIN TRANSACTION; + ALTER TABLE users RENAME TO _users_old; + CREATE TABLE users + ( + UserId INTEGER PRIMARY KEY AUTOINCREMENT, + GroupId INTEGER, + UserName VARCHAR, + Password VARCHAR, + Admin BOOLEAN DEFAULT 0, + LdapUser BOOLEAN DEFAULT 0 + ); + INSERT INTO users SELECT * FROM _users_old; + COMMIT; + PRAGMA foreign_keys=on; + DROP TABLE _users_old; + `) + if err != nil { + log.Printf("Error altering users table to drop RoleId column : '%s'\n", err) + os.Exit(1) + } + } + + // Set any unassigned secrets to the default safe id + if _, err = db.Exec("UPDATE users SET LdapUser = 0 WHERE LdapUser is null;"); err != nil { + log.Printf("Error setting LdapUser flag to false for existing users : '%s'", err) + os.Exit(1) + } + + // Remove LdapGroup column from roles table + ldapCheck, _ := CheckColumnExists("roles", "LdapGroup") + if ldapCheck { + _, err := db.Exec("ALTER TABLE roles DROP COLUMN LdapGroup;") + if err != nil { + log.Printf("Error altering roles table to renmove LdapGroup column : '%s'\n", err) + os.Exit(1) + } + } + + // Add SafeId column to secrets table + safeIdCheck, _ := CheckColumnExists("secrets", "SafeId") + if !safeIdCheck { + // Add the column for LdapGroup in the roles table + _, err := db.Exec("ALTER TABLE secrets ADD COLUMN SafeId INTEGER REFERENCES safes(SafeId);") + if err != nil { + log.Printf("Error altering secrets table to add SafeId column : '%s'\n", err) + os.Exit(1) + } + } + + // Set any unassigned secrets to the default safe id + if _, err = db.Exec("UPDATE secrets SET SafeId = 1 WHERE SafeId is null;"); err != nil { + log.Printf("Error setting safe ID of existing secrets : '%s'", err) + os.Exit(1) + } + + // Remove RoleId column from secrets table + secretsRoleIdCheck, _ := CheckColumnExists("secrets", "RoleId") + if secretsRoleIdCheck { + _, err := db.Exec(` + PRAGMA foreign_keys=off; + BEGIN TRANSACTION; + ALTER TABLE secrets RENAME TO _secrets_old; + CREATE TABLE secrets + ( + SecretId INTEGER PRIMARY KEY AUTOINCREMENT, + RoleId INTEGER, + SafeId INTEGER, + DeviceName VARCHAR, + DeviceCategory VARCHAR, + UserName VARCHAR, + Secret VARCHAR, + FOREIGN KEY (SafeId) REFERENCES safes(SafeId) + ); + INSERT INTO secrets SELECT SecretId, RoleId, SafeId, DeviceName, DeviceCategory, UserName, Secret FROM _secrets_old; + ALTER TABLE secrets DROP COLUMN RoleId; + ALTER TABLE secrets ADD COLUMN LastUpdated datetime; + UPDATE secrets SET LastUpdated = (datetime('1970-01-01 00:00:00')) WHERE LastUpdated is null; + COMMIT; + PRAGMA foreign_keys=on; + DROP TABLE _secrets_old; + `) + if err != nil { + log.Printf("Error altering secrets table to remove RoleId column : '%s'\n", err) + os.Exit(1) + } + } + + // Remove the Admin column from roles table + rolesAdminCheck, _ := CheckColumnExists("roles", "Admin") + if rolesAdminCheck { + _, err := db.Exec("ALTER TABLE roles DROP COLUMN Admin;") + if err != nil { + log.Printf("Error altering roles table to remove Admin column : '%s'\n", err) + os.Exit(1) + } + } + + // Remove the RoleId from permissiosn table + permissionsRoleIdCheck, _ := CheckColumnExists("permissions", "RoleId") + if permissionsRoleIdCheck { + _, err := db.Exec(` + PRAGMA foreign_keys=off; + BEGIN TRANSACTION; + ALTER TABLE permissions RENAME TO _permissions_old; + CREATE TABLE permissions + ( + PermissionId INTEGER PRIMARY KEY AUTOINCREMENT, + Description VARCHAR DEFAULT '', + ReadOnly BOOLEAN DEFAULT 0, + SafeId INTEGER, + UserId INTEGER DEFAULT 0, + GroupId INTEGER DEFAULT 0, + FOREIGN KEY (SafeId) REFERENCES safes(SafeId) + ); + INSERT INTO permissions SELECT PermissionId, SafeId, UserId, GroupId, '' AS Description, 0 as ReadOnly FROM _permissions_old; + UPDATE permissions SET ReadOnly = 0 WHERE ReadOnly is null; + UPDATE permissions SET Description = '' WHERE Description is null; + UPDATE permissions SET UserId = 0 WHERE UserId is null; + UPDATE permissions SET GroupId = 0 WHERE GroupId is null; + COMMIT; + PRAGMA foreign_keys=on; + DROP TABLE _permissions_old; + `) + if err != nil { + log.Printf("Error altering permissions table to remove RoleId column : '%s'\n", err) + os.Exit(1) + } + } + + secretsLastUpdatedCheck, _ := CheckColumnExists("secrets", "LastUpdated") + if !secretsLastUpdatedCheck { + // Add the column for LastUpdated in the secrets table + _, err := db.Exec("ALTER TABLE secrets ADD COLUMN LastUpdated datetime;") + if err != nil { + log.Printf("Error altering secrets table to add LastUpdated column : '%s'\n", err) + os.Exit(1) + } + + // Set the default value + if _, err = db.Exec("UPDATE secrets SET LastUpdated = (datetime('1970-01-01 00:00:00')) WHERE LastUpdated is null;"); err != nil { + log.Printf("Error setting LastUpdated of existing secrets : '%s'", err) + os.Exit(1) + } + } + + lastLoginCheck, _ := CheckColumnExists("users", "LastLogin") + if !lastLoginCheck { + // Add the column for LastUpdated in the secrets table + _, err := db.Exec("ALTER TABLE users ADD COLUMN LastLogin datetime;") + if err != nil { + log.Printf("Error altering users table to add LastLogin column : '%s'\n", err) + os.Exit(1) + } + + // Set the default value + if _, err = db.Exec("UPDATE users SET LastLogin = (datetime('1970-01-01 00:00:00')) WHERE LastLogin is null;"); err != nil { + log.Printf("Error setting LastLogin of existing users : '%s'", err) + os.Exit(1) + } + } } -// SecretsGetFromMultipleSafes queries the specified safes for matching secrets -func SecretsGetFromMultipleSafes(s *Secret, safeIds []int) ([]Secret, error) { - var err error - var secretResults []Secret - - queryArgs := []interface{}{} - var query string - // Generate placeholders for the IN clause to match multiple SafeId values - placeholders := make([]string, len(safeIds)) - for i := range safeIds { - placeholders[i] = "?" - } - placeholderStr := strings.Join(placeholders, ",") - - // Create query with the necessary placeholders - query = fmt.Sprintf("SELECT * FROM secrets WHERE SafeId IN (%s) ", placeholderStr) - - // Add the Safe Ids to the arguments list - for _, g := range safeIds { - queryArgs = append(queryArgs, g) - } - - // Add any other arguments to the query if they were specified - if s.SecretId > 0 { - query += " AND SecretId = ? " - queryArgs = append(queryArgs, s.SecretId) - } - - if s.DeviceName != "" { - query += " AND DeviceName LIKE ? " - queryArgs = append(queryArgs, s.DeviceName) - } - - if s.DeviceCategory != "" { - query += " AND DeviceCategory LIKE ? " - queryArgs = append(queryArgs, s.DeviceCategory) - } - - if s.UserName != "" { - query += " AND UserName LIKE ? " - queryArgs = append(queryArgs, s.UserName) - } - - // Execute the query - log.Printf("SecretsGetMultipleSafes query string :\n'%s'\nQuery Args : %+v\n", query, queryArgs) - rows, err := db.Queryx(query, queryArgs...) - +// Count the number of records in the sqlite database +// Borrowed from https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592 +func CheckCount(tablename string) (int, error) { + var count int + stmt, err := db.Prepare("SELECT COUNT(*) as count FROM " + tablename) if err != nil { - log.Printf("SecretsGetMultipleSafes error executing sql record : '%s'\n", err) - return secretResults, err - } else { - // parse all the results into a slice - for rows.Next() { - var r Secret - err = rows.StructScan(&r) - if err != nil { - log.Printf("SecretsGetMultipleSafes error parsing sql record : '%s'\n", err) - return secretResults, err - } + log.Printf("CheckCount error preparing sqlite statement : '%s'\n", err) + return 0, err + } + err = stmt.QueryRow().Scan(&count) + if err != nil { + log.Printf("CheckCount error querying database record count : '%s'\n", err) + return 0, err + } + stmt.Close() // or use defer rows.Close(), idc + return count, nil +} - // Decrypt the secret - _, err = r.DecryptSecret() - if err != nil { - log.Printf("SecretsGetMultipleSafes unable to decrypt stored secret : '%s'\n", err) - rows.Close() - return secretResults, err +// From https://stackoverflow.com/a/60100045 +func GenerateInsertMethod(q interface{}) (string, error) { + if reflect.ValueOf(q).Kind() == reflect.Struct { + query := fmt.Sprintf("INSERT INTO %s", reflect.TypeOf(q).Name()) + fieldNames := "" + fieldValues := "" + v := reflect.ValueOf(q) + for i := 0; i < v.NumField(); i++ { + if i == 0 { + fieldNames = fmt.Sprintf("%s%s", fieldNames, v.Type().Field(i).Name) } else { - secretResults = append(secretResults, r) + fieldNames = fmt.Sprintf("%s, %s", fieldNames, v.Type().Field(i).Name) + } + switch v.Field(i).Kind() { + case reflect.Int: + if i == 0 { + fieldValues = fmt.Sprintf("%s%d", fieldValues, v.Field(i).Int()) + } else { + fieldValues = fmt.Sprintf("%s, %d", fieldValues, v.Field(i).Int()) + } + case reflect.String: + if i == 0 { + fieldValues = fmt.Sprintf("%s\"%s\"", fieldValues, v.Field(i).String()) + } else { + fieldValues = fmt.Sprintf("%s, \"%s\"", fieldValues, v.Field(i).String()) + } + case reflect.Bool: + var boolSet int8 + if v.Field(i).Bool() { + boolSet = 1 + } + if i == 0 { + fieldValues = fmt.Sprintf("%s%d", fieldValues, boolSet) + } else { + fieldValues = fmt.Sprintf("%s, %d", fieldValues, boolSet) + } + default: + log.Printf("Unsupported type '%s'\n", v.Field(i).Kind()) } - } - log.Printf("SecretsGetMultipleSafes retrieved '%d' results\n", len(secretResults)) + query = fmt.Sprintf("%s(%s) VALUES (%s)", query, fieldNames, fieldValues) + return query, nil } - - return secretResults, nil + return "", errors.New("SqlGenerationError") } -func (s *Secret) UpdateSecret() (*Secret, error) { - - var err error - - // Populate timestamp field if not already set - if s.LastUpdated.IsZero() { - s.LastUpdated = time.Now().UTC() - } - - log.Printf("UpdateSecret storing values '%v'\n", s) - - if s.SecretId == 0 { - err = errors.New("UpdateSecret unable to locate secret with empty secretId field") - log.Printf("UpdateSecret error in pre-check : '%s'\n", err) - return s, err - } - - result, err := db.NamedExec((`UPDATE secrets SET DeviceName = :DeviceName, DeviceCategory = :DeviceCategory, UserName = :UserName, Secret = :Secret, LastUpdated = :LastUpdated WHERE SecretId = :SecretId`), s) +func CheckColumnExists(table string, column string) (bool, error) { + var count int64 + rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM pragma_table_info('" + table + "') WHERE name='" + column + "';") if err != nil { - log.Printf("UpdateSecret error executing sql record : '%s'\n", err) - return &Secret{}, err - } else { - affected, _ := result.RowsAffected() - id, _ := result.LastInsertId() - log.Printf("UpdateSecret insert returned result id '%d' affecting %d row(s).\n", id, affected) + log.Printf("CheckColumnExists error querying database for existence of column '%s' : '%s'\n", column, err) + return false, err + } + defer rows.Close() + for rows.Next() { + // cols is an []interface{} of all of the column results + cols, _ := rows.SliceScan() + log.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column) + count = cols[0].(int64) + + if count == 1 { + return true, nil + } else { + return false, nil + } } - return s, nil -} - -func (s *Secret) DeleteSecret() (*Secret, error) { - - var err error - - log.Printf("DeleteSecret deleting record with values '%v'\n", s) - - if s.SecretId == 0 { - err = errors.New("unable to locate secret with empty secretId field") - log.Printf("DeleteSecret error in pre-check : '%s'\n", err) - return s, err - } - - result, err := db.NamedExec((`DELETE FROM secrets WHERE SecretId = :SecretId`), s) - if err != nil { - log.Printf("DeleteSecret error executing sql record : '%s'\n", err) - return &Secret{}, err - } else { - affected, _ := result.RowsAffected() - id, _ := result.LastInsertId() - log.Printf("DeleteSecret delete returned result id '%d' affecting %d row(s).\n", id, affected) - } - - return s, nil -} - -// startCipher does the initial setup of the AES256 GCM mode cipher -func startCipher() (cipher.AEAD, error) { - key, err := ProvideKey() - if err != nil { - return nil, err - } - - block, err := aes.NewCipher(key) - if err != nil { - log.Printf("startCipher NewCipher error '%s'\n", err) - return nil, err - } - - aesgcm, err := cipher.NewGCM(block) - if err != nil { - log.Printf("startCipher NewGCM error '%s'\n", err) - return nil, err - } - - return aesgcm, nil -} - -func (s *Secret) EncryptSecret() (*Secret, error) { - - //keyString := os.Getenv("SECRETS_KEY") - //keyString := secretKey - - // The key argument should be the AES key, either 16 or 32 bytes - // to select AES-128 or AES-256. - //key := []byte(keyString) - /* - key, err := ProvideKey() - if err != nil { - return s, err - } - */ - - plaintext := []byte(s.Secret) - - // TODO : move block and aesgcm generation to separate function since the identical code is used for encrypt and decrypt - /* - log.Printf("EncryptSecret applying key '%v' of length '%d' to plaintext secret '%s'\n", key, len(key), s.Secret) - block, err := aes.NewCipher(key) - if err != nil { - log.Printf("EncryptSecret NewCipher error '%s'\n", err) - return s, err - } - - aesgcm, err := cipher.NewGCM(block) - if err != nil { - log.Printf("EncryptSecret NewGCM error '%s'\n", err) - return s, err - } - */ - - aesgcm, err := startCipher() - if err != nil { - log.Printf("EncryptSecret error commencing GCM cipher '%s'\n", err) - return s, err - } - - // Never use more than 2^32 random nonces with a given key because of the risk of a repeat. - nonce := make([]byte, nonceSize) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - log.Printf("EncryptSecret nonce generation error '%s'\n", err) - return s, err - } - //log.Printf("EncryptSecret random nonce value is '%x'\n", nonce) - - ciphertext := aesgcm.Seal(nil, nonce, plaintext, nil) - //log.Printf("EncryptSecret generated ciphertext '%x''\n", ciphertext) - - // Create a new slice to store nonce at the start and then the resulting ciphertext - // Nonce is always 12 bytes - combinedText := append(nonce, ciphertext...) - //log.Printf("EncryptSecret combined secret value is now '%x'\n", combinedText) - - // Store the value back into the struct ready for database operations - s.Secret = hex.EncodeToString(combinedText) - - return s, nil - - //return string(ciphertext[:]), nil -} - -func (s *Secret) DecryptSecret() (*Secret, error) { - // The key argument should be the AES key, either 16 or 32 bytes - // to select AES-128 or AES-256. - //keyString := os.Getenv("SECRETS_KEY") - //keyString := secretKey - - //key := []byte(keyString) - /* - key, err := ProvideKey() - if err != nil { - return s, err - } - */ - - if len(s.Secret) < nonceSize { - log.Printf("DecryptSecret ciphertext is too short to decrypt\n") - return s, errors.New("ciphertext is too short") - } - - crypted, err := hex.DecodeString(s.Secret) - if err != nil { - log.Printf("DecryptSecret unable to convert hex encoded string due to error '%s'\n", err) - return s, err - } - - //log.Printf("DecryptSecret processing secret '%x'\n", crypted) - - // The nonce is the first 12 bytes from the ciphertext - nonce := crypted[:nonceSize] - ciphertext := crypted[nonceSize:] - - /* - log.Printf("DecryptSecret applying key '%v' and nonce '%x' to ciphertext '%x'\n", key, nonce, ciphertext) - - block, err := aes.NewCipher(key) - if err != nil { - log.Printf("DecryptSecret NewCipher error '%s'\n", err) - return s, err - } - - aesgcm, err := cipher.NewGCM(block) - if err != nil { - log.Printf("DecryptSecret NewGCM error '%s'\n", err) - return s, err - } - */ - - aesgcm, err := startCipher() - if err != nil { - log.Printf("DecryptSecret error commencing GCM cipher '%s'\n", err) - return s, err - } - - plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - log.Printf("DecryptSecret Open error '%s'\n", err) - return s, err - } - - //log.Printf("DecryptSecret plaintext is '%s'\n", plaintext) - - s.Secret = string(plaintext) - return s, nil + err = rows.Err() + if err != nil { + log.Printf("CheckColumnExists error getting results : '%s'\n", err) + return false, err + } + + return false, nil } diff --git a/models/secret.go b/models/secret.go new file mode 100644 index 0000000..dafb885 --- /dev/null +++ b/models/secret.go @@ -0,0 +1,478 @@ +package models + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "smt/utils" + "strings" + "time" +) + +const nonceSize = 12 + +// We use the json:"-" field tag to prevent showing these details to the user +type Secret struct { + SecretId int `db:"SecretId" json:"secretId"` + SafeId int `db:"SafeId" json:"safeId"` + DeviceName string `db:"DeviceName" json:"deviceName"` + DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"` + UserName string `db:"UserName" json:"userName"` + Secret string `db:"Secret" json:"secret"` + LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"` +} + +// SecretRestricted is for when we want to output a Secret but not the protected information +type SecretRestricted struct { + SecretId int `db:"SecretId" json:"secretId"` + SafeId int `db:"SafeId" json:"safeId"` + DeviceName string `db:"DeviceName" json:"deviceName"` + DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"` + UserName string `db:"UserName" json:"userName"` + Secret string `db:"Secret" json:"-"` + LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"` +} + +// Used for querying all secrets the user has access to +// Since there are some ambiguous column names (eg UserName is present in both users and secrets table), the order of fields in this struct matters +type UserSecret struct { + Secret + UserUserId int `db:"UserUserId"` + User + //Group + Permission +} + +// This method allows us to use an interface to avoid adding duplicate entries to a []Secret +func (s Secret) GetId() int { + return s.SecretId +} + +func (s *Secret) SaveSecret() (*Secret, error) { + var err error + + // Populate timestamp field if not already set + if s.LastUpdated.IsZero() { + s.LastUpdated = time.Now().UTC() + } + + log.Printf("SaveSecret storing values '%v'\n", s) + result, err := db.NamedExec((`INSERT INTO secrets (SafeId, DeviceName, DeviceCategory, UserName, Secret, LastUpdated) VALUES (:SafeId, :DeviceName, :DeviceCategory, :UserName, :Secret, :LastUpdated)`), s) + + if err != nil { + log.Printf("StoreSecret error executing sql record : '%s'\n", err) + return s, err + } + + affected, _ := result.RowsAffected() + id, _ := result.LastInsertId() + s.SecretId = int(id) + log.Printf("StoreSecret insert returned result id '%d' affecting %d row(s).\n", id, affected) + + return s, nil +} + +// SecretsGetAllowed returns all allowed secrets matching the specified parameters in s +func SecretsGetAllowed(s *Secret, userId int) ([]UserSecret, error) { + var err error + var secretResults []UserSecret + + // Query for group access + queryArgs := []interface{}{} + query := ` + SELECT users.UserId AS UserUserId, permissions.*, + secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName + FROM users + INNER JOIN groups ON users.GroupId = groups.GroupId + INNER JOIN permissions ON groups.GroupId = permissions.GroupId + INNER JOIN secrets on secrets.SafeId = permissions.SafeId + WHERE users.UserId = ? ` + queryArgs = append(queryArgs, userId) + + // Add any other arguments to the query if they were specified + if s.SecretId > 0 { + query += " AND SecretId = ? " + queryArgs = append(queryArgs, s.SecretId) + } + + if s.DeviceName != "" { + query += " AND DeviceName LIKE ? " + queryArgs = append(queryArgs, s.DeviceName) + } + + if s.DeviceCategory != "" { + query += " AND DeviceCategory LIKE ? " + queryArgs = append(queryArgs, s.DeviceCategory) + } + + if s.UserName != "" { + query += " AND secrets.UserName LIKE ? " + queryArgs = append(queryArgs, s.UserName) + } + + // Query for user access + query += ` + UNION + SELECT users.UserId AS UserUserId, permissions.*, + secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName + FROM users + INNER JOIN permissions ON users.UserId = permissions.UserId + INNER JOIN safes on permissions.SafeId = safes.SafeId + INNER JOIN secrets on secrets.SafeId = safes.SafeId + WHERE users.UserId = ?` + queryArgs = append(queryArgs, userId) + + // Add any other arguments to the query if they were specified + if s.SecretId > 0 { + query += " AND SecretId = ? " + queryArgs = append(queryArgs, s.SecretId) + } + if s.DeviceName != "" { + query += " AND DeviceName LIKE ? " + queryArgs = append(queryArgs, s.DeviceName) + } + + if s.DeviceCategory != "" { + query += " AND DeviceCategory LIKE ? " + queryArgs = append(queryArgs, s.DeviceCategory) + } + + if s.UserName != "" { + query += " AND secrets.UserName LIKE ? " + queryArgs = append(queryArgs, s.UserName) + } + + // Execute the query + log.Printf("SecretsGetAllowed query string : '%s'\nArguments:%+v\n", query, queryArgs) + rows, err := db.Queryx(query, queryArgs...) + + if err != nil { + log.Printf("SecretsGetAllowed error executing sql record : '%s'\n", err) + return secretResults, err + } else { + //log.Printf("SecretsGetAllowed any error '%s'\n", rows.Err()) + // parse all the results into a slice + for rows.Next() { + //log.Printf("SecretsGetAllowed processing row\n") + var r UserSecret + err = rows.StructScan(&r) + //log.Printf("SecretsGetAllowed performed struct scan\n") + if err != nil { + log.Printf("SecretsGetAllowed error parsing sql record : '%s'\n", err) + return secretResults, err + } + //log.Printf("r: %v\n", r) + + //log.Printf("SecretsGetAllowed performed err check\n") + + // work around to get the UserId populated in the User field of the struct + r.User.UserId = r.UserUserId + + // For debugging purposes + debugPrint := utils.PrintStructContents(&r, 0) + log.Println(debugPrint) + + //log.Printf("SecretsGetAllowed performed debug print\n") + + // Append the secrets to the query output, don't decrypt the secrets (we didn't SELECT them anyway) + //secretResults = append(secretResults, r) + + // Use generics and the GetID() method on the UserSecret struct + // to avoid adding this element to the results + // if there is already a secret with the same ID present + secretResults = utils.AppendIfNotExists(secretResults, r) + + //log.Printf("SecretsGetAllowed added secret results\n") + } + log.Printf("SecretsGetAllowed retrieved '%d' results\n", len(secretResults)) + } + + return secretResults, nil +} + +// SecretsGetFromMultipleSafes queries the specified safes for matching secrets +func SecretsGetFromMultipleSafes(s *Secret, safeIds []int) ([]Secret, error) { + var err error + var secretResults []Secret + + queryArgs := []interface{}{} + var query string + // Generate placeholders for the IN clause to match multiple SafeId values + placeholders := make([]string, len(safeIds)) + for i := range safeIds { + placeholders[i] = "?" + } + placeholderStr := strings.Join(placeholders, ",") + + // Create query with the necessary placeholders + query = fmt.Sprintf("SELECT * FROM secrets WHERE SafeId IN (%s) ", placeholderStr) + + // Add the Safe Ids to the arguments list + for _, g := range safeIds { + queryArgs = append(queryArgs, g) + } + + // Add any other arguments to the query if they were specified + if s.SecretId > 0 { + query += " AND SecretId = ? " + queryArgs = append(queryArgs, s.SecretId) + } + + if s.DeviceName != "" { + query += " AND DeviceName LIKE ? " + queryArgs = append(queryArgs, s.DeviceName) + } + + if s.DeviceCategory != "" { + query += " AND DeviceCategory LIKE ? " + queryArgs = append(queryArgs, s.DeviceCategory) + } + + if s.UserName != "" { + query += " AND UserName LIKE ? " + queryArgs = append(queryArgs, s.UserName) + } + + // Execute the query + log.Printf("SecretsGetMultipleSafes query string :\n'%s'\nQuery Args : %+v\n", query, queryArgs) + rows, err := db.Queryx(query, queryArgs...) + + if err != nil { + log.Printf("SecretsGetMultipleSafes error executing sql record : '%s'\n", err) + return secretResults, err + } else { + // parse all the results into a slice + for rows.Next() { + var r Secret + err = rows.StructScan(&r) + if err != nil { + log.Printf("SecretsGetMultipleSafes error parsing sql record : '%s'\n", err) + return secretResults, err + } + + // Decrypt the secret + _, err = r.DecryptSecret() + if err != nil { + log.Printf("SecretsGetMultipleSafes unable to decrypt stored secret : '%s'\n", err) + rows.Close() + return secretResults, err + } else { + secretResults = append(secretResults, r) + } + + } + log.Printf("SecretsGetMultipleSafes retrieved '%d' results\n", len(secretResults)) + } + + return secretResults, nil +} + +func (s *Secret) UpdateSecret() (*Secret, error) { + + var err error + + // Populate timestamp field if not already set + if s.LastUpdated.IsZero() { + s.LastUpdated = time.Now().UTC() + } + + log.Printf("UpdateSecret storing values '%v'\n", s) + + if s.SecretId == 0 { + err = errors.New("UpdateSecret unable to locate secret with empty secretId field") + log.Printf("UpdateSecret error in pre-check : '%s'\n", err) + return s, err + } + + result, err := db.NamedExec((`UPDATE secrets SET DeviceName = :DeviceName, DeviceCategory = :DeviceCategory, UserName = :UserName, Secret = :Secret, LastUpdated = :LastUpdated WHERE SecretId = :SecretId`), s) + if err != nil { + log.Printf("UpdateSecret error executing sql record : '%s'\n", err) + return &Secret{}, err + } else { + affected, _ := result.RowsAffected() + id, _ := result.LastInsertId() + log.Printf("UpdateSecret insert returned result id '%d' affecting %d row(s).\n", id, affected) + } + + return s, nil +} + +func (s *Secret) DeleteSecret() (*Secret, error) { + + var err error + + log.Printf("DeleteSecret deleting record with values '%v'\n", s) + + if s.SecretId == 0 { + err = errors.New("unable to locate secret with empty secretId field") + log.Printf("DeleteSecret error in pre-check : '%s'\n", err) + return s, err + } + + result, err := db.NamedExec((`DELETE FROM secrets WHERE SecretId = :SecretId`), s) + if err != nil { + log.Printf("DeleteSecret error executing sql record : '%s'\n", err) + return &Secret{}, err + } else { + affected, _ := result.RowsAffected() + id, _ := result.LastInsertId() + log.Printf("DeleteSecret delete returned result id '%d' affecting %d row(s).\n", id, affected) + } + + return s, nil +} + +// startCipher does the initial setup of the AES256 GCM mode cipher +func startCipher() (cipher.AEAD, error) { + key, err := ProvideKey() + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + log.Printf("startCipher NewCipher error '%s'\n", err) + return nil, err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + log.Printf("startCipher NewGCM error '%s'\n", err) + return nil, err + } + + return aesgcm, nil +} + +func (s *Secret) EncryptSecret() (*Secret, error) { + + //keyString := os.Getenv("SECRETS_KEY") + //keyString := secretKey + + // The key argument should be the AES key, either 16 or 32 bytes + // to select AES-128 or AES-256. + //key := []byte(keyString) + /* + key, err := ProvideKey() + if err != nil { + return s, err + } + */ + + plaintext := []byte(s.Secret) + + // TODO : move block and aesgcm generation to separate function since the identical code is used for encrypt and decrypt + /* + log.Printf("EncryptSecret applying key '%v' of length '%d' to plaintext secret '%s'\n", key, len(key), s.Secret) + block, err := aes.NewCipher(key) + if err != nil { + log.Printf("EncryptSecret NewCipher error '%s'\n", err) + return s, err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + log.Printf("EncryptSecret NewGCM error '%s'\n", err) + return s, err + } + */ + + aesgcm, err := startCipher() + if err != nil { + log.Printf("EncryptSecret error commencing GCM cipher '%s'\n", err) + return s, err + } + + // Never use more than 2^32 random nonces with a given key because of the risk of a repeat. + nonce := make([]byte, nonceSize) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + log.Printf("EncryptSecret nonce generation error '%s'\n", err) + return s, err + } + //log.Printf("EncryptSecret random nonce value is '%x'\n", nonce) + + ciphertext := aesgcm.Seal(nil, nonce, plaintext, nil) + //log.Printf("EncryptSecret generated ciphertext '%x''\n", ciphertext) + + // Create a new slice to store nonce at the start and then the resulting ciphertext + // Nonce is always 12 bytes + combinedText := append(nonce, ciphertext...) + //log.Printf("EncryptSecret combined secret value is now '%x'\n", combinedText) + + // Store the value back into the struct ready for database operations + s.Secret = hex.EncodeToString(combinedText) + + return s, nil + + //return string(ciphertext[:]), nil +} + +func (s *Secret) DecryptSecret() (*Secret, error) { + // The key argument should be the AES key, either 16 or 32 bytes + // to select AES-128 or AES-256. + //keyString := os.Getenv("SECRETS_KEY") + //keyString := secretKey + + //key := []byte(keyString) + /* + key, err := ProvideKey() + if err != nil { + return s, err + } + */ + + if len(s.Secret) < nonceSize { + log.Printf("DecryptSecret ciphertext is too short to decrypt\n") + return s, errors.New("ciphertext is too short") + } + + crypted, err := hex.DecodeString(s.Secret) + if err != nil { + log.Printf("DecryptSecret unable to convert hex encoded string due to error '%s'\n", err) + return s, err + } + + //log.Printf("DecryptSecret processing secret '%x'\n", crypted) + + // The nonce is the first 12 bytes from the ciphertext + nonce := crypted[:nonceSize] + ciphertext := crypted[nonceSize:] + + /* + log.Printf("DecryptSecret applying key '%v' and nonce '%x' to ciphertext '%x'\n", key, nonce, ciphertext) + + block, err := aes.NewCipher(key) + if err != nil { + log.Printf("DecryptSecret NewCipher error '%s'\n", err) + return s, err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + log.Printf("DecryptSecret NewGCM error '%s'\n", err) + return s, err + } + */ + + aesgcm, err := startCipher() + if err != nil { + log.Printf("DecryptSecret error commencing GCM cipher '%s'\n", err) + return s, err + } + + plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + log.Printf("DecryptSecret Open error '%s'\n", err) + return s, err + } + + //log.Printf("DecryptSecret plaintext is '%s'\n", plaintext) + + s.Secret = string(plaintext) + return s, nil +} diff --git a/models/setup.go b/models/setup.go deleted file mode 100644 index 4a649fb..0000000 --- a/models/setup.go +++ /dev/null @@ -1,527 +0,0 @@ -package models - -import ( - "errors" - "fmt" - "log" - "os" - "reflect" - - "smt/utils" - - "github.com/jmoiron/sqlx" - "golang.org/x/crypto/bcrypt" - _ "modernc.org/sqlite" -) - -var db *sqlx.DB - -const ( - sqlFile = "smt.db" -) - -const createUsers string = ` - CREATE TABLE IF NOT EXISTS users ( - UserId INTEGER PRIMARY KEY AUTOINCREMENT, - GroupId INTEGER, - UserName VARCHAR, - Password VARCHAR, - Admin BOOLEAN DEFAULT 0, - LdapUser BOOLEAN DEFAULT 0 - ); -` - -const createSafes string = ` - CREATE TABLE IF NOT EXISTS safes ( - SafeId INTEGER PRIMARY KEY AUTOINCREMENT, - SafeName VARCHAR - ); -` - -const createGroups string = ` - CREATE TABLE IF NOT EXISTS groups ( - GroupId INTEGER PRIMARY KEY AUTOINCREMENT, - GroupName VARCHAR, - LdapGroup BOOLEAN DEFAULT 0, - LdapDn VARCHAR DEFAULT '', - Admin BOOLEAN DEFAULT 0 - ); -` - -const createPermissions = ` - CREATE TABLE IF NOT EXISTS permissions ( - PermissionId INTEGER PRIMARY KEY AUTOINCREMENT, - Description VARCHAR DEFAULT '', - ReadOnly BOOLEAN DEFAULT 0, - SafeId INTEGER, - UserId INTEGER DEFAULT 0, - GroupId INTEGER DEFAULT 0, - FOREIGN KEY (SafeId) REFERENCES safes(SafeId) - ); -` - -const createSecrets string = ` - CREATE TABLE IF NOT EXISTS secrets ( - SecretId INTEGER PRIMARY KEY AUTOINCREMENT, - SafeId INTEGER, - DeviceName VARCHAR, - DeviceCategory VARCHAR, - UserName VARCHAR, - Secret VARCHAR, - LastUpdated datetime DEFAULT (datetime('1970-01-01 00:00:00')), - FOREIGN KEY (SafeId) REFERENCES safes(SafeId) - ); -` - -const createSchema string = ` - CREATE TABLE IF NOT EXISTS schema ( - Version INTEGER - ); -` - -const createAudit string = ` - CREATE TABLE IF NOT EXISTS audit ( - AuditId INTEGER PRIMARY KEY AUTOINCREMENT, - UserId INTEGER DEFAULT 0, - SecretId INTEGER DEFAULT 0, - EventText VARCHAR, - EventTime datetime - ); -` - -// Establish connection to sqlite database -func ConnectDatabase() { - var err error - - // Try using sqlite as our database - sqlPath := utils.GetFilePath(sqlFile) - db, err = sqlx.Open("sqlite", sqlPath) - - if err != nil { - log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err) - os.Exit(1) - } else { - log.Printf("Connected to sqlite database file '%s'\n", sqlPath) - } - - //sqlx.NameMapper = func(s string) string { return s } - - // Make sure our tables exist - CreateTables() - - //defer db.Close() -} - -func DisconnectDatabase() { - log.Printf("DisconnectDatabase called") - defer db.Close() -} - -func CreateTables() { - var err error - var rowCount int - // Create database tables if it doesn't exist - /* - // Roles table should go first since other tables refer to it - if _, err = db.Exec(createRoles); err != nil { - log.Printf("Error checking roles table : '%s'", err) - os.Exit(1) - } - rowCount, _ = CheckCount("roles") - if rowCount == 0 { - if _, err = db.Exec("INSERT INTO roles VALUES(1, 'Admin', false);"); err != nil { - log.Printf("Error adding initial admin role : '%s'", err) - os.Exit(1) - } - if _, err = db.Exec("INSERT INTO roles VALUES(2, 'UserRole', false);"); err != nil { - log.Printf("Error adding initial user role : '%s'", err) - os.Exit(1) - } - if _, err = db.Exec("INSERT INTO roles VALUES(3, 'GuestRole', true);"); err != nil { - log.Printf("Error adding initial guest role : '%s'", err) - os.Exit(1) - } - } - */ - - // groups table - if _, err = db.Exec(createGroups); err != nil { - log.Printf("Error checking groups table : '%s'", err) - os.Exit(1) - } - - // Add initial groups - rowCount, _ = CheckCount("groups") - if rowCount == 0 { - if _, err = db.Exec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(1, 'Administrators', 1);"); err != nil { - log.Printf("Error adding initial group entry id 1 : '%s'", err) - os.Exit(1) - } - if _, err = db.Exec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(2, 'Users', 0);"); err != nil { - log.Printf("Error adding initial group entry id 2 : '%s'", err) - os.Exit(1) - } - } - - // Users table - if _, err = db.Exec(createUsers); err != nil { - log.Printf("Error checking users table : '%s'", err) - os.Exit(1) - } - rowCount, _ = CheckCount("users") - if rowCount == 0 { - // Check if there was an initial password defined in the .env file - initialPassword := os.Getenv("INITIAL_PASSWORD") - if initialPassword == "" { - initialPassword = "password" - } else if initialPassword[:4] == "$2a$" { - log.Printf("CreateTables inital admin password is already a hash") - } else { - cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost) - initialPassword = string(cryptText) - } - if _, err = db.Exec("INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(1, 1, 'Administrator', ?, false, true);", initialPassword); err != nil { - log.Printf("Error adding initial admin role : '%s'", err) - os.Exit(1) - } - if _, err = db.Exec("INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(2, 2, 'User', ?, false, false);", initialPassword); err != nil { - log.Printf("Error adding initial admin role : '%s'", err) - os.Exit(1) - } - } - - // Safes table - if _, err = db.Exec(createSafes); err != nil { - log.Printf("Error checking safes table : '%s'", err) - os.Exit(1) - } - // Create an initial safe - rowCount, _ = CheckCount("safes") - if rowCount == 0 { - if _, err = db.Exec("INSERT INTO safes VALUES(1, 'Default Safe');"); err != nil { - log.Printf("Error adding initial safe entry : '%s'", err) - os.Exit(1) - } - } - - // Secrets table - if _, err = db.Exec(createSecrets); err != nil { - log.Printf("Error checking secrets table : '%s'", err) - os.Exit(1) - } - - // permissions table - if _, err = db.Exec(createPermissions); err != nil { - log.Printf("Error checking permissions table : '%s'", err) - os.Exit(1) - } - - // Add initial permissions - rowCount, _ = CheckCount("permissions") - if rowCount == 0 { - if _, err = db.Exec("INSERT INTO permissions (Description, ReadOnly, GroupId, SafeId) VALUES('Default Admin Group Permission', false, 1, 1);"); err != nil { - log.Printf("Error adding initial permissions entry userid 1 : '%s'", err) - os.Exit(1) - } - if _, err = db.Exec("INSERT INTO permissions (Description, ReadOnly, SafeId, GroupId) VALUES('Default User Group Permission', false, 1, 2);"); err != nil { - log.Printf("Error adding initial permissions entry userid 2 : '%s'", err) - os.Exit(1) - } - } - - // Schema table should go last so we know if the database has a value in the schema table then everything was created properly - if _, err = db.Exec(createSchema); err != nil { - log.Printf("Error checking schema table : '%s'", err) - os.Exit(1) - } - schemaCheck, _ := CheckColumnExists("schema", "Version") - if !schemaCheck { - if _, err = db.Exec("INSERT INTO schema VALUES(2);"); err != nil { - log.Printf("Error adding initial schema version : '%s'", err) - os.Exit(1) - } - } - - // Audit log table - if _, err = db.Exec(createAudit); err != nil { - log.Printf("Error checking audit table : '%s'", err) - os.Exit(1) - } - - // Remove users RoleId column - userRoleIdCheck, _ := CheckColumnExists("users", "RoleId") - if userRoleIdCheck { - //_, err := db.Exec("ALTER TABLE users DROP COLUMN RoleId;") - _, err := db.Exec(` - PRAGMA foreign_keys=off; - BEGIN TRANSACTION; - ALTER TABLE users RENAME TO _users_old; - CREATE TABLE users - ( - UserId INTEGER PRIMARY KEY AUTOINCREMENT, - GroupId INTEGER, - UserName VARCHAR, - Password VARCHAR, - Admin BOOLEAN DEFAULT 0, - LdapUser BOOLEAN DEFAULT 0 - ); - INSERT INTO users SELECT * FROM _users_old; - COMMIT; - PRAGMA foreign_keys=on; - DROP TABLE _users_old; - `) - if err != nil { - log.Printf("Error altering users table to drop RoleId column : '%s'\n", err) - os.Exit(1) - } - } - - // Set any unassigned secrets to the default safe id - if _, err = db.Exec("UPDATE users SET LdapUser = 0 WHERE LdapUser is null;"); err != nil { - log.Printf("Error setting LdapUser flag to false for existing users : '%s'", err) - os.Exit(1) - } - - // Remove LdapGroup column from roles table - ldapCheck, _ := CheckColumnExists("roles", "LdapGroup") - if ldapCheck { - _, err := db.Exec("ALTER TABLE roles DROP COLUMN LdapGroup;") - if err != nil { - log.Printf("Error altering roles table to renmove LdapGroup column : '%s'\n", err) - os.Exit(1) - } - } - - // Add SafeId column to secrets table - safeIdCheck, _ := CheckColumnExists("secrets", "SafeId") - if !safeIdCheck { - // Add the column for LdapGroup in the roles table - _, err := db.Exec("ALTER TABLE secrets ADD COLUMN SafeId INTEGER REFERENCES safes(SafeId);") - if err != nil { - log.Printf("Error altering secrets table to add SafeId column : '%s'\n", err) - os.Exit(1) - } - } - - // Set any unassigned secrets to the default safe id - if _, err = db.Exec("UPDATE secrets SET SafeId = 1 WHERE SafeId is null;"); err != nil { - log.Printf("Error setting safe ID of existing secrets : '%s'", err) - os.Exit(1) - } - - // Remove RoleId column from secrets table - secretsRoleIdCheck, _ := CheckColumnExists("secrets", "RoleId") - if secretsRoleIdCheck { - _, err := db.Exec(` - PRAGMA foreign_keys=off; - BEGIN TRANSACTION; - ALTER TABLE secrets RENAME TO _secrets_old; - CREATE TABLE secrets - ( - SecretId INTEGER PRIMARY KEY AUTOINCREMENT, - RoleId INTEGER, - SafeId INTEGER, - DeviceName VARCHAR, - DeviceCategory VARCHAR, - UserName VARCHAR, - Secret VARCHAR, - FOREIGN KEY (SafeId) REFERENCES safes(SafeId) - ); - INSERT INTO secrets SELECT SecretId, RoleId, SafeId, DeviceName, DeviceCategory, UserName, Secret FROM _secrets_old; - ALTER TABLE secrets DROP COLUMN RoleId; - ALTER TABLE secrets ADD COLUMN LastUpdated datetime; - UPDATE secrets SET LastUpdated = (datetime('1970-01-01 00:00:00')) WHERE LastUpdated is null; - COMMIT; - PRAGMA foreign_keys=on; - DROP TABLE _secrets_old; - `) - if err != nil { - log.Printf("Error altering secrets table to remove RoleId column : '%s'\n", err) - os.Exit(1) - } - } - - // Remove the Admin column from roles table - rolesAdminCheck, _ := CheckColumnExists("roles", "Admin") - if rolesAdminCheck { - _, err := db.Exec("ALTER TABLE roles DROP COLUMN Admin;") - if err != nil { - log.Printf("Error altering roles table to remove Admin column : '%s'\n", err) - os.Exit(1) - } - } - - // Remove the RoleId from permissiosn table - permissionsRoleIdCheck, _ := CheckColumnExists("permissions", "RoleId") - if permissionsRoleIdCheck { - _, err := db.Exec(` - PRAGMA foreign_keys=off; - BEGIN TRANSACTION; - ALTER TABLE permissions RENAME TO _permissions_old; - CREATE TABLE permissions - ( - PermissionId INTEGER PRIMARY KEY AUTOINCREMENT, - Description VARCHAR DEFAULT '', - ReadOnly BOOLEAN DEFAULT 0, - SafeId INTEGER, - UserId INTEGER DEFAULT 0, - GroupId INTEGER DEFAULT 0, - FOREIGN KEY (SafeId) REFERENCES safes(SafeId) - ); - INSERT INTO permissions SELECT PermissionId, SafeId, UserId, GroupId, '' AS Description, 0 as ReadOnly FROM _permissions_old; - UPDATE permissions SET ReadOnly = 0 WHERE ReadOnly is null; - UPDATE permissions SET Description = '' WHERE Description is null; - UPDATE permissions SET UserId = 0 WHERE UserId is null; - UPDATE permissions SET GroupId = 0 WHERE GroupId is null; - COMMIT; - PRAGMA foreign_keys=on; - DROP TABLE _permissions_old; - `) - if err != nil { - log.Printf("Error altering permissions table to remove RoleId column : '%s'\n", err) - os.Exit(1) - } - } - - secretsLastUpdatedCheck, _ := CheckColumnExists("secrets", "LastUpdated") - if !secretsLastUpdatedCheck { - // Add the column for LastUpdated in the secrets table - _, err := db.Exec("ALTER TABLE secrets ADD COLUMN LastUpdated datetime;") - if err != nil { - log.Printf("Error altering secrets table to add LastUpdated column : '%s'\n", err) - os.Exit(1) - } - - // Set the default value - if _, err = db.Exec("UPDATE secrets SET LastUpdated = (datetime('1970-01-01 00:00:00')) WHERE LastUpdated is null;"); err != nil { - log.Printf("Error setting LastUpdated of existing secrets : '%s'", err) - os.Exit(1) - } - } - - /* - // Database updates added after initial version released - ldapCheck, _ := CheckColumnExists("roles", "LdapGroup") - - if !ldapCheck { - // Add the column for LdapGroup in the roles table - _, err := db.Exec("ALTER TABLE roles ADD COLUMN LdapGroup VARCHAR DEFAULT '';") - if err != nil { - log.Printf("Error altering roles table to add LdapGroup column : '%s'\n", err) - os.Exit(1) - } - } - - // Add the two LDAP columns to the users table if they weren't there - ldapUserCheck, _ := CheckColumnExists("users", "LdapUser") - if !ldapUserCheck { - log.Printf("CreateTables creating ldap columns in user table") - _, err := db.Exec("ALTER TABLE users ADD COLUMN LdapUser BOOLEAN DEFAULT 0;") - if err != nil { - log.Printf("Error altering users table to add LdapUser column : '%s'\n", err) - os.Exit(1) - } - - _, err = db.Exec("ALTER TABLE users ADD COLUMN LdapDn VARCHAR DEFAULT '';") - if err != nil { - log.Printf("Error altering users table to add LdapDn column : '%s'\n", err) - os.Exit(1) - } - } - */ - -} - -// Count the number of records in the sqlite database -// Borrowed from https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592 -func CheckCount(tablename string) (int, error) { - var count int - stmt, err := db.Prepare("SELECT COUNT(*) as count FROM " + tablename) - if err != nil { - log.Printf("CheckCount error preparing sqlite statement : '%s'\n", err) - return 0, err - } - err = stmt.QueryRow().Scan(&count) - if err != nil { - log.Printf("CheckCount error querying database record count : '%s'\n", err) - return 0, err - } - stmt.Close() // or use defer rows.Close(), idc - return count, nil -} - -// From https://stackoverflow.com/a/60100045 -func GenerateInsertMethod(q interface{}) (string, error) { - if reflect.ValueOf(q).Kind() == reflect.Struct { - query := fmt.Sprintf("INSERT INTO %s", reflect.TypeOf(q).Name()) - fieldNames := "" - fieldValues := "" - v := reflect.ValueOf(q) - for i := 0; i < v.NumField(); i++ { - if i == 0 { - fieldNames = fmt.Sprintf("%s%s", fieldNames, v.Type().Field(i).Name) - } else { - fieldNames = fmt.Sprintf("%s, %s", fieldNames, v.Type().Field(i).Name) - } - switch v.Field(i).Kind() { - case reflect.Int: - if i == 0 { - fieldValues = fmt.Sprintf("%s%d", fieldValues, v.Field(i).Int()) - } else { - fieldValues = fmt.Sprintf("%s, %d", fieldValues, v.Field(i).Int()) - } - case reflect.String: - if i == 0 { - fieldValues = fmt.Sprintf("%s\"%s\"", fieldValues, v.Field(i).String()) - } else { - fieldValues = fmt.Sprintf("%s, \"%s\"", fieldValues, v.Field(i).String()) - } - case reflect.Bool: - var boolSet int8 - if v.Field(i).Bool() { - boolSet = 1 - } - if i == 0 { - fieldValues = fmt.Sprintf("%s%d", fieldValues, boolSet) - } else { - fieldValues = fmt.Sprintf("%s, %d", fieldValues, boolSet) - } - default: - log.Printf("Unsupported type '%s'\n", v.Field(i).Kind()) - } - } - query = fmt.Sprintf("%s(%s) VALUES (%s)", query, fieldNames, fieldValues) - return query, nil - } - return "", errors.New("SqlGenerationError") -} - -func CheckColumnExists(table string, column string) (bool, error) { - var count int64 - rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM pragma_table_info('" + table + "') WHERE name='" + column + "';") - if err != nil { - log.Printf("CheckColumnExists error querying database for existence of column '%s' : '%s'\n", column, err) - return false, err - } - defer rows.Close() - for rows.Next() { - // cols is an []interface{} of all of the column results - cols, _ := rows.SliceScan() - log.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column) - count = cols[0].(int64) - - if count == 1 { - return true, nil - } else { - return false, nil - } - } - - err = rows.Err() - if err != nil { - log.Printf("CheckColumnExists error getting results : '%s'\n", err) - return false, err - } - - return false, nil -} diff --git a/models/user.go b/models/user.go index 746e4b5..33eeb67 100644 --- a/models/user.go +++ b/models/user.go @@ -7,18 +7,19 @@ import ( "log" "smt/utils" "smt/utils/token" + "time" "golang.org/x/crypto/bcrypt" ) type User struct { - UserId int `db:"UserId" json:"userId"` - GroupId int `db:"GroupId" json:"groupId"` - UserName string `db:"UserName" json:"userName"` - Password string `db:"Password" json:"-"` - LdapUser bool `db:"LdapUser" json:"ldapUser"` - Admin bool `db:"Admin"` - //LdapDn string `db:"LdapDn" json:"ldapDn"` + UserId int `db:"UserId" json:"userId"` + GroupId int `db:"GroupId" json:"groupId"` + UserName string `db:"UserName" json:"userName"` + Password string `db:"Password" json:"-"` + LdapUser bool `db:"LdapUser" json:"ldapUser"` + Admin bool `db:"Admin"` + LastLogin time.Time `db:"LastLogin" json:"lastLogin"` } type UserRole struct { @@ -207,6 +208,8 @@ func LoginCheck(username string, password string) (string, error) { return "", err } + u.UserSetLastLogin() + return token, nil } @@ -286,15 +289,23 @@ func UserLdapNewLoginCheck(username string, password string) (User, error) { return u, nil } -/* -// StoreLdapUser creates a user record in the database and returns the corresponding userId -func StoreLdapUser(u *User) error { +func (u *User) UserSetLastLogin() error { - // TODO + u.LastLogin = time.Now().UTC() + + result, err := db.NamedExec((`UPDATE users SET LastLogin = :LastLogin WHERE UserId = :UserId`), u) + + if err != nil { + log.Printf("UserSetLastLogin error executing sql update : '%s'\n", err) + return err + } else { + affected, _ := result.RowsAffected() + id, _ := result.LastInsertId() + log.Printf("UserSetLastLogin returned result id '%d' affecting %d row(s).\n", id, affected) + } return nil } -*/ func UserGetByID(uid uint) (User, error) {