package models

import (
	"errors"
	"fmt"
	"log"
	"os"
	"reflect"
	"strings"

	"smt/utils"

	"github.com/jmoiron/sqlx"
	"golang.org/x/crypto/bcrypt"
	_ "modernc.org/sqlite"
)

// db is the package-level database handle. It is assigned by ConnectDatabase
// and closed by DisconnectDatabase.
var db *sqlx.DB

const (
	// sqlFile is the sqlite file name passed to utils.GetFilePath.
	sqlFile = "smt.db"
)

// createUsers is the DDL for the users table.
const createUsers string = `
CREATE TABLE IF NOT EXISTS users (
	UserId INTEGER PRIMARY KEY AUTOINCREMENT,
	GroupId INTEGER,
	UserName VARCHAR,
	Password VARCHAR,
	Admin BOOLEAN DEFAULT 0,
	LdapUser BOOLEAN DEFAULT 0
);
`

// createSafes is the DDL for the safes table.
const createSafes string = `
CREATE TABLE IF NOT EXISTS safes (
	SafeId INTEGER PRIMARY KEY AUTOINCREMENT,
	SafeName VARCHAR
);
`

// createGroups is the DDL for the groups table.
const createGroups string = `
CREATE TABLE IF NOT EXISTS groups (
	GroupId INTEGER PRIMARY KEY AUTOINCREMENT,
	GroupName VARCHAR,
	LdapGroup BOOLEAN DEFAULT 0,
	LdapDn VARCHAR DEFAULT '',
	Admin BOOLEAN DEFAULT 0
);
`

// createPermissions is the DDL for the permissions table.
const createPermissions = `
CREATE TABLE IF NOT EXISTS permissions (
	PermissionId INTEGER PRIMARY KEY AUTOINCREMENT,
	Description VARCHAR DEFAULT '',
	ReadOnly BOOLEAN DEFAULT 0,
	SafeId INTEGER,
	UserId INTEGER DEFAULT 0,
	GroupId INTEGER DEFAULT 0,
	FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`

// createSecrets is the DDL for the secrets table.
const createSecrets string = `
CREATE TABLE IF NOT EXISTS secrets (
	SecretId INTEGER PRIMARY KEY AUTOINCREMENT,
	SafeId INTEGER,
	DeviceName VARCHAR,
	DeviceCategory VARCHAR,
	UserName VARCHAR,
	Secret VARCHAR,
	LastUpdated datetime DEFAULT (datetime('1970-01-01 00:00:00')),
	FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`

// createSchema is the DDL for the single-column schema version table.
const createSchema string = `
CREATE TABLE IF NOT EXISTS schema (
	Version INTEGER
);
`

// createAudit is the DDL for the audit log table.
const createAudit string = `
CREATE TABLE IF NOT EXISTS audit (
	AuditId INTEGER PRIMARY KEY AUTOINCREMENT,
	UserId INTEGER DEFAULT 0,
	SecretId INTEGER DEFAULT 0,
	EventText VARCHAR,
	EventTime datetime
);
`

// ConnectDatabase opens the sqlite database file resolved by
// utils.GetFilePath(sqlFile), stores the handle in the package-level db
// variable and then calls CreateTables. If the database cannot be opened the
// error is logged and the process exits with status 1.
func ConnectDatabase() {
	sqlPath := utils.GetFilePath(sqlFile)

	var err error
	db, err = sqlx.Open("sqlite", sqlPath)
	if err != nil {
		log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err)
		os.Exit(1)
	}
	log.Printf("Connected to sqlite database file '%s'\n", sqlPath)

	// Make sure our tables exist. The handle is intentionally left open;
	// DisconnectDatabase closes it.
	CreateTables()
}

// DisconnectDatabase closes the package-level database handle. It is a no-op
// (apart from the log line) when no connection was ever opened, and a failure
// to close is logged rather than treated as fatal.
func DisconnectDatabase() {
	log.Printf("DisconnectDatabase called")
	if db == nil {
		return
	}
	if err := db.Close(); err != nil {
		log.Printf("Error closing sqlite database : '%s'\n", err)
	}
}

// CreateTables creates every table if it does not exist yet, seeds the initial
// rows into empty tables and then applies the in-place schema migrations.
// Any database error is logged and terminates the process with status 1.
func CreateTables() {
	createAndSeedTables()
	migrateSchema()
}

// createAndSeedTables issues the CREATE TABLE IF NOT EXISTS statements and
// inserts the default groups, users, safe, permissions and schema version
// into tables that are still empty.
func createAndSeedTables() {
	// The legacy roles table is no longer created here; migrateSchema only
	// cleans up columns left over from it.

	// groups table
	execOrExit("Error checking groups table", createGroups)
	if tableIsEmptyOrExit("groups") {
		execOrExit("Error adding initial group entry id 1",
			"INSERT INTO groups (GroupId, GroupName, Admin) VALUES(1, 'Administrators', 1);")
		execOrExit("Error adding initial group entry id 2",
			"INSERT INTO groups (GroupId, GroupName, Admin) VALUES(2, 'Users', 0);")
	}

	// Users table
	execOrExit("Error checking users table", createUsers)
	if tableIsEmptyOrExit("users") {
		passwordHash := initialPasswordHash()
		execOrExit("Error adding initial admin user",
			"INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(1, 1, 'Administrator', ?, false, true);",
			passwordHash)
		execOrExit("Error adding initial user",
			"INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(2, 2, 'User', ?, false, false);",
			passwordHash)
	}

	// Safes table, with one initial safe
	execOrExit("Error checking safes table", createSafes)
	if tableIsEmptyOrExit("safes") {
		execOrExit("Error adding initial safe entry",
			"INSERT INTO safes VALUES(1, 'Default Safe');")
	}

	// Secrets table
	execOrExit("Error checking secrets table", createSecrets)

	// permissions table, granting both default groups access to the default safe
	execOrExit("Error checking permissions table", createPermissions)
	if tableIsEmptyOrExit("permissions") {
		execOrExit("Error adding initial permissions entry userid 1",
			"INSERT INTO permissions (Description, ReadOnly, GroupId, SafeId) VALUES('Default Admin Group Permission', false, 1, 1);")
		execOrExit("Error adding initial permissions entry userid 2",
			"INSERT INTO permissions (Description, ReadOnly, SafeId, GroupId) VALUES('Default User Group Permission', false, 1, 2);")
	}

	// Schema table goes last so that a value in it means everything above was
	// created properly. The version row is inserted when the table is empty;
	// checking for the Version column would always succeed right after the
	// CREATE TABLE and the version would never be recorded.
	execOrExit("Error checking schema table", createSchema)
	if tableIsEmptyOrExit("schema") {
		execOrExit("Error adding initial schema version",
			"INSERT INTO schema VALUES(2);")
	}

	// Audit log table
	execOrExit("Error checking audit table", createAudit)
}

// migrateSchema upgrades databases created by earlier versions: it removes the
// legacy RoleId / roles columns, back-fills NULL values and adds the SafeId and
// LastUpdated columns to the secrets table when they are missing.
func migrateSchema() {
	// Remove users RoleId column by rebuilding the table.
	// NOTE(review): "SELECT *" copies columns positionally, so this only works
	// if _users_old has the same number of columns as the new table - confirm
	// against the legacy users schema.
	if columnExistsOrExit("users", "RoleId") {
		execOrExit("Error altering users table to drop RoleId column", `
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE users RENAME TO _users_old;
CREATE TABLE users (
	UserId INTEGER PRIMARY KEY AUTOINCREMENT,
	GroupId INTEGER,
	UserName VARCHAR,
	Password VARCHAR,
	Admin BOOLEAN DEFAULT 0,
	LdapUser BOOLEAN DEFAULT 0
);
INSERT INTO users SELECT * FROM _users_old;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _users_old;
`)
	}

	// Default the LdapUser flag to false for users that have no value yet
	execOrExit("Error setting LdapUser flag to false for existing users",
		"UPDATE users SET LdapUser = 0 WHERE LdapUser is null;")

	// Remove LdapGroup column from roles table
	if columnExistsOrExit("roles", "LdapGroup") {
		execOrExit("Error altering roles table to renmove LdapGroup column",
			"ALTER TABLE roles DROP COLUMN LdapGroup;")
	}

	// Add SafeId column to secrets table
	if !columnExistsOrExit("secrets", "SafeId") {
		execOrExit("Error altering secrets table to add SafeId column",
			"ALTER TABLE secrets ADD COLUMN SafeId INTEGER REFERENCES safes(SafeId);")
	}

	// Set any unassigned secrets to the default safe id
	execOrExit("Error setting safe ID of existing secrets",
		"UPDATE secrets SET SafeId = 1 WHERE SafeId is null;")

	// Remove RoleId column from secrets table by rebuilding the table
	if columnExistsOrExit("secrets", "RoleId") {
		execOrExit("Error altering secrets table to remove RoleId column", `
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE secrets RENAME TO _secrets_old;
CREATE TABLE secrets (
	SecretId INTEGER PRIMARY KEY AUTOINCREMENT,
	RoleId INTEGER,
	SafeId INTEGER,
	DeviceName VARCHAR,
	DeviceCategory VARCHAR,
	UserName VARCHAR,
	Secret VARCHAR,
	FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
INSERT INTO secrets SELECT SecretId, RoleId, SafeId, DeviceName, DeviceCategory, UserName, Secret FROM _secrets_old;
ALTER TABLE secrets DROP COLUMN RoleId;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _secrets_old;
`)
	}

	// Remove the Admin column from roles table
	if columnExistsOrExit("roles", "Admin") {
		execOrExit("Error altering roles table to remove Admin column",
			"ALTER TABLE roles DROP COLUMN Admin;")
	}

	// Remove the RoleId from permissions table by rebuilding the table. The
	// INSERT names its target columns explicitly: INSERT ... SELECT maps values
	// by position, and the SELECT order differs from the new table's column
	// order.
	if columnExistsOrExit("permissions", "RoleId") {
		execOrExit("Error altering permissions table to remove RoleId column", `
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE permissions RENAME TO _permissions_old;
CREATE TABLE permissions (
	PermissionId INTEGER PRIMARY KEY AUTOINCREMENT,
	Description VARCHAR DEFAULT '',
	ReadOnly BOOLEAN DEFAULT 0,
	SafeId INTEGER,
	UserId INTEGER DEFAULT 0,
	GroupId INTEGER DEFAULT 0,
	FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
INSERT INTO permissions (PermissionId, SafeId, UserId, GroupId, Description, ReadOnly)
	SELECT PermissionId, SafeId, UserId, GroupId, '' AS Description, 0 as ReadOnly FROM _permissions_old;
UPDATE permissions SET ReadOnly = 0 WHERE ReadOnly is null;
UPDATE permissions SET Description = '' WHERE Description is null;
UPDATE permissions SET UserId = 0 WHERE UserId is null;
UPDATE permissions SET GroupId = 0 WHERE GroupId is null;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _permissions_old;
`)
	}

	// Add the column for LastUpdated in the secrets table
	if !columnExistsOrExit("secrets", "LastUpdated") {
		execOrExit("Error altering secrets table to add LastUpdated column",
			"ALTER TABLE secrets ADD COLUMN LastUpdated datetime DEFAULT (datetime('1970-01-01 00:00:00'));")
	}
}

// initialPasswordHash returns the value stored as the password of the seeded
// users. It reads INITIAL_PASSWORD from the environment, falling back to
// "password" when unset. A value that already starts with "$2a$" is treated as
// a bcrypt hash and returned unchanged; anything else, including the fallback,
// is hashed with bcrypt. A hashing failure is logged and exits the process.
func initialPasswordHash() string {
	password := os.Getenv("INITIAL_PASSWORD")
	if password == "" {
		password = "password"
	}

	// HasPrefix is safe for values shorter than the prefix, unlike slicing.
	if strings.HasPrefix(password, "$2a$") {
		log.Printf("CreateTables inital admin password is already a hash")
		return password
	}

	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("Error hashing initial password : '%s'\n", err)
		os.Exit(1)
	}
	return string(hashed)
}

// execOrExit runs query with args against db. On failure it logs
// "<what> : '<error>'" and exits the process with status 1.
func execOrExit(what, query string, args ...interface{}) {
	if _, err := db.Exec(query, args...); err != nil {
		log.Printf("%s : '%s'\n", what, err)
		os.Exit(1)
	}
}

// tableIsEmptyOrExit reports whether table has no rows, exiting the process
// with status 1 if the row count cannot be determined.
func tableIsEmptyOrExit(table string) bool {
	count, err := CheckCount(table)
	if err != nil {
		log.Printf("Error counting rows in %s table : '%s'\n", table, err)
		os.Exit(1)
	}
	return count == 0
}

// columnExistsOrExit reports whether table has the given column, exiting the
// process with status 1 if the lookup itself fails.
func columnExistsOrExit(table, column string) bool {
	exists, err := CheckColumnExists(table, column)
	if err != nil {
		log.Printf("Error checking for column %s in %s table : '%s'\n", column, table, err)
		os.Exit(1)
	}
	return exists
}

// Count the number of records in the sqlite database
// Borrowed from
https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592 func CheckCount(tablename string) (int, error) { var count int stmt, err := db.Prepare("SELECT COUNT(*) as count FROM " + tablename) if err != nil { log.Printf("CheckCount error preparing sqlite statement : '%s'\n", err) return 0, err } err = stmt.QueryRow().Scan(&count) if err != nil { log.Printf("CheckCount error querying database record count : '%s'\n", err) return 0, err } stmt.Close() // or use defer rows.Close(), idc return count, nil } // From https://stackoverflow.com/a/60100045 func GenerateInsertMethod(q interface{}) (string, error) { if reflect.ValueOf(q).Kind() == reflect.Struct { query := fmt.Sprintf("INSERT INTO %s", reflect.TypeOf(q).Name()) fieldNames := "" fieldValues := "" v := reflect.ValueOf(q) for i := 0; i < v.NumField(); i++ { if i == 0 { fieldNames = fmt.Sprintf("%s%s", fieldNames, v.Type().Field(i).Name) } else { fieldNames = fmt.Sprintf("%s, %s", fieldNames, v.Type().Field(i).Name) } switch v.Field(i).Kind() { case reflect.Int: if i == 0 { fieldValues = fmt.Sprintf("%s%d", fieldValues, v.Field(i).Int()) } else { fieldValues = fmt.Sprintf("%s, %d", fieldValues, v.Field(i).Int()) } case reflect.String: if i == 0 { fieldValues = fmt.Sprintf("%s\"%s\"", fieldValues, v.Field(i).String()) } else { fieldValues = fmt.Sprintf("%s, \"%s\"", fieldValues, v.Field(i).String()) } case reflect.Bool: var boolSet int8 if v.Field(i).Bool() { boolSet = 1 } if i == 0 { fieldValues = fmt.Sprintf("%s%d", fieldValues, boolSet) } else { fieldValues = fmt.Sprintf("%s, %d", fieldValues, boolSet) } default: log.Printf("Unsupported type '%s'\n", v.Field(i).Kind()) } } query = fmt.Sprintf("%s(%s) VALUES (%s)", query, fieldNames, fieldValues) return query, nil } return "", errors.New("SqlGenerationError") } func CheckColumnExists(table string, column string) (bool, error) { var count int64 rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM 
pragma_table_info('" + table + "') WHERE name='" + column + "';") if err != nil { log.Printf("CheckColumnExists error querying database for existence of column '%s' : '%s'\n", column, err) return false, err } defer rows.Close() for rows.Next() { // cols is an []interface{} of all of the column results cols, _ := rows.SliceScan() log.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column) count = cols[0].(int64) if count == 1 { return true, nil } else { return false, nil } } err = rows.Err() if err != nil { log.Printf("CheckColumnExists error getting results : '%s'\n", err) return false, err } return false, nil }