package models

import (
	"errors"
	"fmt"
	"log"
	"os"
	"reflect"
	"strconv"
	"strings"

	"smt/utils"

	"github.com/jmoiron/sqlx"
	"golang.org/x/crypto/bcrypt"
	_ "modernc.org/sqlite" // registers the "sqlite" driver used by sqlx.Open
)

// db is the package-wide database handle assigned by ConnectDatabase.
var db *sqlx.DB

const (
	// sqlFile is the database file name; ConnectDatabase passes it to
	// utils.GetFilePath to build the full path.
	sqlFile = "smt.db"
)

// TODO drop LdapGroup column
const createRoles string = `
CREATE TABLE IF NOT EXISTS roles (
	RoleId INTEGER PRIMARY KEY ASC,
	RoleName VARCHAR,
	ReadOnly BOOLEAN,
	Admin BOOLEAN
);
`

const createUsers string = `
CREATE TABLE IF NOT EXISTS users (
	UserId INTEGER PRIMARY KEY ASC,
	GroupId INTEGER,
	UserName VARCHAR,
	Password VARCHAR,
	LdapUser BOOLEAN DEFAULT 0,
	LdapDN VARCHAR DEFAULT '',
	FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
`

const createSafes string = `
CREATE TABLE IF NOT EXISTS safes (
	SafeId INTEGER PRIMARY KEY ASC,
	SafeName VARCHAR
);
`

const createGroups string = `
CREATE TABLE IF NOT EXISTS groups (
	GroupId INTEGER PRIMARY KEY ASC,
	GroupName VARCHAR,
	LdapGroup BOOLEAN DEFAULT 0,
	LdapDN VARCHAR DEFAULT ''
);
`

const createPermissions = `
CREATE TABLE IF NOT EXISTS permissions (
	PermissionId INTEGER PRIMARY KEY ASC,
	RoleId INTEGER,
	SafeId INTEGER,
	UserId INTEGER,
	GroupId INTEGER,
	FOREIGN KEY (RoleId) REFERENCES roles(RoleId),
	FOREIGN KEY (SafeId) REFERENCES safes(SafeId),
	FOREIGN KEY (UserId) REFERENCES users(UserId),
	FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
`

const createSecrets string = `
CREATE TABLE IF NOT EXISTS secrets (
	SecretId INTEGER PRIMARY KEY ASC,
	SafeId INTEGER,
	DeviceName VARCHAR,
	DeviceCategory VARCHAR,
	UserName VARCHAR,
	Secret VARCHAR,
	FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`

const createSchema string = `
CREATE TABLE IF NOT EXISTS schema (
	Version INTEGER
);
`

const createAudit string = `
CREATE TABLE IF NOT EXISTS audit (
	UserName VARCHAR,
	EventText VARCHAR,
	EventTime INTEGER
);
`

// dropUsersRoleID rebuilds a users table that still has the obsolete RoleId
// column into the layout of createUsers (the column definitions below must
// be kept in sync with createUsers).
//
// It follows SQLite's documented table-rebuild order: create the new table,
// copy the rows, drop the old table, then rename the new table into place.
// Renaming the OLD table out of the way first would make SQLite rewrite the
// FOREIGN KEY in permissions to reference the renamed table, which is then
// dropped.
//
// NOTE(review): RoleId values are copied into GroupId. This keeps the
// positional mapping of the previous "INSERT ... SELECT *" migration —
// confirm that this is the intended conversion.
const dropUsersRoleID = `
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
DROP TABLE IF EXISTS _users_new;
CREATE TABLE _users_new (
	UserId INTEGER PRIMARY KEY ASC,
	GroupId INTEGER,
	UserName VARCHAR,
	Password VARCHAR,
	LdapUser BOOLEAN DEFAULT 0,
	LdapDN VARCHAR DEFAULT '',
	FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
INSERT INTO _users_new (UserId, GroupId, UserName, Password, LdapUser, LdapDN)
	SELECT UserId, RoleId, UserName, Password, LdapUser, LdapDN FROM users;
DROP TABLE users;
ALTER TABLE _users_new RENAME TO users;
COMMIT;
PRAGMA foreign_keys=on;
`

// ConnectDatabase opens the sqlite database file located via
// utils.GetFilePath into the package-level db handle, verifies the file is
// usable, and then calls CreateTables. Any failure is logged and the process
// exits with status 1.
func ConnectDatabase() {
	var err error

	// Try using sqlite as our database
	sqlPath := utils.GetFilePath(sqlFile)
	db, err = sqlx.Open("sqlite", sqlPath)
	if err != nil {
		log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err)
		os.Exit(1)
	}

	// Open only validates its arguments; Ping forces a real connection so an
	// unusable database file is reported here rather than at the first query.
	if err = db.Ping(); err != nil {
		log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err)
		os.Exit(1)
	}
	log.Printf("Connected to sqlite database file '%s'\n", sqlPath)

	// Make sure our tables exist
	CreateTables()
}

// DisconnectDatabase closes the package-level db handle opened by
// ConnectDatabase. It does nothing if no handle was opened; a Close error is
// logged, not returned.
func DisconnectDatabase() {
	log.Printf("DisconnectDatabase called")
	if db == nil {
		return
	}
	if err := db.Close(); err != nil {
		log.Printf("Error closing sqlite database : '%s'\n", err)
	}
}

// CreateTables brings the database to the current layout. It does the
// following, in order:
//   - creates every missing table;
//   - upgrades tables left behind by older releases;
//   - on an empty database, seeds the default roles and the initial
//     "Administrator" account;
//   - records the schema version.
//
// Every step is mandatory: on any error the failure is logged and the process
// exits with status 1.
func CreateTables() {
	// exec runs one (or several ;-separated) statements, exiting on failure
	// with the log line "<what> : '<err>'".
	exec := func(what, query string, args ...interface{}) {
		if _, err := db.Exec(query, args...); err != nil {
			log.Printf("%s : '%s'\n", what, err)
			os.Exit(1)
		}
	}
	// rowCount exits on error so that a failed count is never mistaken for an
	// empty table, which would trigger re-seeding.
	rowCount := func(table string) int {
		n, err := CheckCount(table)
		if err != nil {
			log.Printf("Error counting rows in %s table : '%s'\n", table, err)
			os.Exit(1)
		}
		return n
	}
	// hasColumn exits on error rather than guessing whether a migration is needed.
	hasColumn := func(table, column string) bool {
		found, err := CheckColumnExists(table, column)
		if err != nil {
			log.Printf("Error checking for column %s in %s table : '%s'\n", column, table, err)
			os.Exit(1)
		}
		return found
	}

	// Roles table should go first since other tables refer to it.
	exec("Error checking roles table", createRoles)
	// Remove the LdapGroup column that older roles tables carry; createRoles
	// no longer defines it.
	if hasColumn("roles", "LdapGroup") {
		exec("Error altering roles table to renmove LdapGroup column", "ALTER TABLE roles DROP COLUMN LdapGroup;")
	}
	if rowCount("roles") == 0 {
		// Explicit column lists keep the seed rows independent of column order.
		exec("Error adding initial admin role", "INSERT INTO roles (RoleId, RoleName, ReadOnly, Admin) VALUES(1, 'Admin', false, true);")
		exec("Error adding initial user role", "INSERT INTO roles (RoleId, RoleName, ReadOnly, Admin) VALUES(2, 'UserRole', false, false);")
		exec("Error adding initial guest role", "INSERT INTO roles (RoleId, RoleName, ReadOnly, Admin) VALUES(3, 'GuestRole', true, false);")
	}

	// Users table. Older users tables have a RoleId column. SQLite's DROP
	// COLUMN refuses columns used in a FOREIGN KEY (presumably the case for
	// RoleId — TODO confirm), so the table is rebuilt instead. The rebuild
	// runs before seeding so the seed INSERT sees the current layout.
	exec("Error checking users table", createUsers)
	if hasColumn("users", "RoleId") {
		exec("Error altering users table to drop RoleId column", dropUsersRoleID)
	}
	if rowCount("users") == 0 {
		// Check if there was an initial password defined in the .env file
		initialPassword := os.Getenv("INITIAL_PASSWORD")
		if initialPassword == "" {
			log.Printf("CreateTables INITIAL_PASSWORD not set, using default admin password")
			initialPassword = "password"
		}
		// Store a bcrypt hash in every case. bcrypt.Cost only succeeds on
		// strings shaped like a bcrypt hash, so such a value is stored as-is
		// and anything else is hashed.
		if _, err := bcrypt.Cost([]byte(initialPassword)); err == nil {
			log.Printf("CreateTables inital admin password is already a hash")
		} else {
			hash, err := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
			if err != nil {
				log.Printf("Error hashing initial admin password : '%s'\n", err)
				os.Exit(1)
			}
			initialPassword = string(hash)
		}
		// GroupId 1 is kept from the previous seed statement; no groups row is
		// created here — TODO confirm which group the Administrator belongs to.
		exec("Error adding initial admin user",
			"INSERT INTO users (GroupId, UserName, Password, LdapUser) VALUES(1, 'Administrator', ?, 0);",
			initialPassword)
	}

	// Safes table
	exec("Error checking safes table", createSafes)

	// Secrets table; older secrets tables lack the SafeId column.
	exec("Error checking secrets table", createSecrets)
	if !hasColumn("secrets", "SafeId") {
		exec("Error altering secrets table to add SafeId column", "ALTER TABLE secrets ADD COLUMN SafeId INTEGER REFERENCES safes(SafeId);")
	}

	// groups table
	exec("Error checking groups table", createGroups)

	// permissions table
	exec("Error checking permissions table", createPermissions)

	// Audit log table
	exec("Error checking audit table", createAudit)

	// Schema table goes last: if it holds a version row, everything above
	// completed successfully.
	exec("Error checking schema table", createSchema)
	// Look for a row, not the Version column: the column always exists right
	// after CREATE TABLE, so a column check would never record the version.
	if rowCount("schema") == 0 {
		exec("Error adding initial schema version", "INSERT INTO schema VALUES(2);")
	}
}

// CheckCount returns the number of rows in tablename.
//
// tablename is concatenated into the SQL text because identifiers cannot be
// bound as parameters; pass only trusted, hard-coded table names.
// Adapted from https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592
func CheckCount(tablename string) (int, error) {
	var count int
	// QueryRow releases its statement and connection itself, including on error.
	if err := db.QueryRow("SELECT COUNT(*) as count FROM " + tablename).Scan(&count); err != nil {
		log.Printf("CheckCount error querying database record count : '%s'\n", err)
		return 0, err
	}
	return count, nil
}

// GenerateInsertMethod builds an INSERT statement from a struct value (or a
// pointer to one). The table name is the struct's type name, and each
// supported field becomes a column named after the field:
//   - signed and unsigned integers are written in decimal;
//   - bools are written as 0/1;
//   - strings are written as single-quoted SQL literals, with embedded quotes
//     doubled.
//
// Fields of any other kind are logged and skipped, so the column list and
// value list always line up.
//
// It returns the error "SqlGenerationError" when q is not a struct (or a
// non-nil pointer to one) or has no supported fields.
//
// Values are inlined into the SQL text; prefer parameterised queries for
// untrusted data where possible.
// Adapted from https://stackoverflow.com/a/60100045
func GenerateInsertMethod(q interface{}) (string, error) {
	v := reflect.Indirect(reflect.ValueOf(q))
	if v.Kind() != reflect.Struct {
		return "", errors.New("SqlGenerationError")
	}
	t := v.Type()

	names := make([]string, 0, v.NumField())
	values := make([]string, 0, v.NumField())
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		var value string
		switch field.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			value = strconv.FormatInt(field.Int(), 10)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			value = strconv.FormatUint(field.Uint(), 10)
		case reflect.String:
			// Single quotes denote a string literal in SQL; double quotes
			// would be resolved as a column name first.
			value = "'" + strings.ReplaceAll(field.String(), "'", "''") + "'"
		case reflect.Bool:
			value = "0"
			if field.Bool() {
				value = "1"
			}
		default:
			log.Printf("Unsupported type '%s'\n", field.Kind())
			continue
		}
		names = append(names, t.Field(i).Name)
		values = append(values, value)
	}
	if len(names) == 0 {
		return "", errors.New("SqlGenerationError")
	}

	return fmt.Sprintf("INSERT INTO %s(%s) VALUES (%s)", t.Name(), strings.Join(names, ", "), strings.Join(values, ", ")), nil
}

// CheckColumnExists reports whether column is present in table, as listed by
// SQLite's pragma_table_info. Both names are passed as bound parameters, not
// spliced into the SQL. The name comparison is case-sensitive.
func CheckColumnExists(table string, column string) (bool, error) {
	var count int64
	err := db.QueryRow("SELECT COUNT(*) AS CNTREC FROM pragma_table_info(?) WHERE name = ?;", table, column).Scan(&count)
	if err != nil {
		log.Printf("CheckColumnExists error querying database for existence of column '%s' : '%s'\n", column, err)
		return false, err
	}
	log.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", count, table, column)
	return count > 0, nil
}