Files
smt/models/setup.go
Nathan Coad 1a2b6e5b41
All checks were successful
continuous-integration/drone/push Build is passing
fix RoleId in secrets table
2024-01-08 15:45:08 +11:00

451 lines
12 KiB
Go

package models
import (
	"errors"
	"fmt"
	"log"
	"os"
	"reflect"
	"strconv"
	"strings"

	"smt/utils"

	"github.com/jmoiron/sqlx"
	"golang.org/x/crypto/bcrypt"
	_ "modernc.org/sqlite"
)
// db is the package-wide database handle, assigned by ConnectDatabase.
var db *sqlx.DB

// sqlFile is the sqlite database file name; ConnectDatabase resolves it to a
// full path with utils.GetFilePath.
const sqlFile = "smt.db"
// Table definitions executed by CreateTables. Each statement uses
// CREATE TABLE IF NOT EXISTS, so running them against an existing database is
// a no-op; columns that earlier releases created differently (for example
// roles.LdapGroup or users.RoleId) are handled by the migrations in
// CreateTables, not here.
const (
	// createRoles defines the roles table; permissions rows reference RoleId.
	createRoles string = `
CREATE TABLE IF NOT EXISTS roles (
RoleId INTEGER PRIMARY KEY ASC,
RoleName VARCHAR,
ReadOnly BOOLEAN
);
`

	// createUsers defines the users table; GroupId references groups(GroupId).
	createUsers string = `
CREATE TABLE IF NOT EXISTS users (
UserId INTEGER PRIMARY KEY ASC,
GroupId INTEGER,
UserName VARCHAR,
Password VARCHAR,
Admin BOOLEAN DEFAULT 0,
LdapUser BOOLEAN DEFAULT 0,
FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
`

	// createSafes defines the safes table that secrets and permissions point at.
	createSafes string = `
CREATE TABLE IF NOT EXISTS safes (
SafeId INTEGER PRIMARY KEY ASC,
SafeName VARCHAR
);
`

	// createGroups defines the groups table, including its LDAP mapping columns.
	createGroups string = `
CREATE TABLE IF NOT EXISTS groups (
GroupId INTEGER PRIMARY KEY ASC,
GroupName VARCHAR,
LdapGroup BOOLEAN DEFAULT 0,
LdapDN VARCHAR DEFAULT '',
Admin BOOLEAN DEFAULT 0
);
`

	// createPermissions defines the permissions table linking a role and a
	// safe to a user and/or group.
	createPermissions = `
CREATE TABLE IF NOT EXISTS permissions (
PermissionId INTEGER PRIMARY KEY ASC,
RoleId INTEGER,
SafeId INTEGER,
UserId INTEGER,
GroupId INTEGER,
FOREIGN KEY (RoleId) REFERENCES roles(RoleId),
FOREIGN KEY (SafeId) REFERENCES safes(SafeId),
FOREIGN KEY (UserId) REFERENCES users(UserId),
FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
`

	// createSecrets defines the secrets table; each secret belongs to a safe.
	createSecrets string = `
CREATE TABLE IF NOT EXISTS secrets (
SecretId INTEGER PRIMARY KEY ASC,
SafeId INTEGER,
DeviceName VARCHAR,
DeviceCategory VARCHAR,
UserName VARCHAR,
Secret VARCHAR,
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`

	// createSchema defines the single-column table holding the schema version.
	createSchema string = `
CREATE TABLE IF NOT EXISTS schema (
Version INTEGER
);
`

	// createAudit defines the audit log table.
	createAudit string = `
CREATE TABLE IF NOT EXISTS audit (
UserName VARCHAR,
EventText VARCHAR,
EventTime INTEGER
);
`
)
// ConnectDatabase opens the sqlite database file (sqlFile, resolved through
// utils.GetFilePath) into the package-level db handle, then calls
// CreateTables to make sure the schema exists.
//
// The process exits with status 1 if the database cannot be opened.
func ConnectDatabase() {
	var err error
	sqlPath := utils.GetFilePath(sqlFile)
	db, err = sqlx.Open("sqlite", sqlPath)
	if err != nil {
		log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err)
		os.Exit(1)
	}
	// sql.Open may only validate its arguments without connecting. Ping forces
	// a real connection, so an unusable path fails here and not at the first query.
	if err = db.Ping(); err != nil {
		log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err)
		os.Exit(1)
	}
	log.Printf("Connected to sqlite database file '%s'\n", sqlPath)

	// Make sure our tables exist.
	CreateTables()
}
// DisconnectDatabase closes the package-level database handle opened by
// ConnectDatabase. It is a no-op if no handle has been opened. A failure to
// close is logged rather than ignored.
func DisconnectDatabase() {
	log.Printf("DisconnectDatabase called")
	if db == nil {
		return
	}
	// Close directly: a defer as the function's last statement added nothing
	// and discarded the error.
	if err := db.Close(); err != nil {
		log.Printf("Error closing sqlite database : '%s'\n", err)
	}
}
// CreateTables creates any application table that does not exist yet, seeds
// default rows into empty tables, and migrates databases created by earlier
// releases to the current schema.
//
// Every failure is logged and terminates the process with os.Exit(1), because
// the application cannot run against a partially initialised database.
func CreateTables() {
	// Roles go first since other tables refer to them.
	setupExec(createRoles, "Error checking roles table")
	if setupCount("roles") == 0 {
		setupExec("INSERT INTO roles VALUES(1, 'Admin', false);", "Error adding initial admin role")
		setupExec("INSERT INTO roles VALUES(2, 'UserRole', false);", "Error adding initial user role")
		setupExec("INSERT INTO roles VALUES(3, 'GuestRole', true);", "Error adding initial guest role")
	}

	// Users table, seeded with a single Administrator account in group 1.
	setupExec(createUsers, "Error checking users table")
	if setupCount("users") == 0 {
		// createUsers has no RoleId column. The previous statement named it and
		// supplied five values for four columns, so it failed on every fresh database.
		setupExec("INSERT INTO users (GroupId, UserName, Password, LdapUser) VALUES(1, 'Administrator', ?, 0);",
			"Error adding initial admin user", setupInitialPassword())
	}

	// Safes table, seeded with the default safe (SafeId 1).
	setupExec(createSafes, "Error checking safes table")
	if setupCount("safes") == 0 {
		setupExec("INSERT INTO safes VALUES(1, 'Default Safe');", "Error adding initial safe entry")
	}

	setupExec(createSecrets, "Error checking secrets table")
	setupExec(createGroups, "Error checking groups table")
	setupExec(createPermissions, "Error checking permissions table")

	if setupCount("groups") == 0 {
		setupExec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(1, 'Administrators', 1);", "Error adding initial group entry id 1")
		setupExec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(2, 'Users', 0);", "Error adding initial group entry id 2")
	}

	if setupCount("permissions") == 0 {
		// NOTE(review): only UserId 1 is seeded above. UserId is an INTEGER
		// PRIMARY KEY, so the next accounts created will normally receive ids 2
		// and 3 and pick up the Admin role on the default safe through these
		// rows. Confirm this is intended; perhaps these were meant as GroupId grants.
		setupExec("INSERT INTO permissions (RoleId, SafeId, UserId) VALUES(1, 1, 1);", "Error adding initial permissions entry userid 1")
		setupExec("INSERT INTO permissions (RoleId, SafeId, UserId) VALUES(1, 1, 2);", "Error adding initial permissions entry userid 2")
		setupExec("INSERT INTO permissions (RoleId, SafeId, UserId) VALUES(1, 1, 3);", "Error adding initial permissions entry userid 3")
	}

	setupExec(createAudit, "Error checking audit table")

	// Migrations for databases created by earlier releases.
	if setupColumnExists("users", "RoleId") {
		setupMigrateUsersDropRoleID()
	}
	if setupColumnExists("roles", "LdapGroup") {
		setupExec("ALTER TABLE roles DROP COLUMN LdapGroup;", "Error altering roles table to renmove LdapGroup column")
	}
	if setupColumnExists("secrets", "RoleId") {
		// NOTE(review): SQLite refuses DROP COLUMN on a column used in a FOREIGN
		// KEY constraint. If legacy secrets tables declared one on RoleId, this
		// needs the same table rebuild used for users.
		setupExec("ALTER TABLE secrets DROP COLUMN RoleId;", "Error altering secrets table to renmove RoleId column")
	}
	if !setupColumnExists("secrets", "SafeId") {
		setupExec("ALTER TABLE secrets ADD COLUMN SafeId INTEGER REFERENCES safes(SafeId);", "Error altering secrets table to add SafeId column")
	}

	// Put unassigned secrets in the default safe. This must run after the
	// secrets table exists and has its SafeId column. It previously ran before
	// either was guaranteed, which failed on fresh and legacy databases alike.
	setupExec("UPDATE secrets SET SafeId = 1 WHERE SafeId is null;", "Error setting safe ID of existing secrets")

	// The schema table goes last: a version row means everything above
	// completed. The table is created with a Version column, so checking for
	// that column (as before) could never trigger the insert; check for a row instead.
	setupExec(createSchema, "Error checking schema table")
	if setupCount("schema") == 0 {
		setupExec("INSERT INTO schema VALUES(2);", "Error adding initial schema version")
	}
}

// setupExec runs one schema or seed statement with optional bind arguments.
// On failure it logs errMsg followed by the error and exits the process.
func setupExec(query string, errMsg string, args ...interface{}) {
	if _, err := db.Exec(query, args...); err != nil {
		log.Printf("%s : '%s'", errMsg, err)
		os.Exit(1)
	}
}

// setupCount returns the row count of table, exiting the process if it cannot be read.
func setupCount(table string) int {
	n, err := CheckCount(table)
	if err != nil {
		log.Printf("Error counting rows in %s table : '%s'", table, err)
		os.Exit(1)
	}
	return n
}

// setupColumnExists reports whether table has column, exiting the process on a query error.
func setupColumnExists(table, column string) bool {
	exists, err := CheckColumnExists(table, column)
	if err != nil {
		log.Printf("Error checking for column %s.%s : '%s'", table, column, err)
		os.Exit(1)
	}
	return exists
}

// setupInitialPassword returns the bcrypt hash to store for the seeded
// Administrator account. INITIAL_PASSWORD may hold a plaintext password or an
// existing bcrypt hash; when it is unset, the default "password" is used.
func setupInitialPassword() string {
	initialPassword := os.Getenv("INITIAL_PASSWORD")
	switch {
	case initialPassword == "":
		// Hash the default too; it used to be stored as plaintext.
		// NOTE(review): assumes logins verify with bcrypt, as hashing
		// INITIAL_PASSWORD below implies.
		initialPassword = "password"
	case strings.HasPrefix(initialPassword, "$2a$"),
		strings.HasPrefix(initialPassword, "$2b$"),
		strings.HasPrefix(initialPassword, "$2y$"):
		// HasPrefix is safe for short values, unlike the old [:4] slice.
		log.Printf("CreateTables inital admin password is already a hash")
		return initialPassword
	}
	cryptText, err := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("Error hashing initial admin password : '%s'", err)
		os.Exit(1)
	}
	return string(cryptText)
}

// setupMigrateUsersDropRoleID rebuilds the users table without the legacy
// RoleId column. It follows SQLite's documented order (create new table, copy,
// drop old, rename new) because renaming the old table instead would rewrite
// other tables' FOREIGN KEY references to point at it.
func setupMigrateUsersDropRoleID() {
	// Copy only the columns the legacy table actually has; any missing one
	// takes its default (or NULL). SELECT * cannot work because the legacy
	// table carries the extra RoleId column.
	var cols []string
	for _, c := range []string{"UserId", "GroupId", "UserName", "Password", "Admin", "LdapUser"} {
		if setupColumnExists("users", c) {
			cols = append(cols, c)
		}
	}
	colList := strings.Join(cols, ", ")
	setupExec(fmt.Sprintf(`
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
DROP TABLE IF EXISTS _users_new;
CREATE TABLE _users_new
(
UserId INTEGER PRIMARY KEY ASC,
GroupId INTEGER,
UserName VARCHAR,
Password VARCHAR,
Admin BOOLEAN DEFAULT 0,
LdapUser BOOLEAN DEFAULT 0,
FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
INSERT INTO _users_new (%[1]s) SELECT %[1]s FROM users;
DROP TABLE users;
ALTER TABLE _users_new RENAME TO users;
COMMIT;
PRAGMA foreign_keys=on;
`, colList), "Error altering users table to drop RoleId column")
}
// CheckCount returns the number of rows in tablename.
//
// tablename is concatenated into the SQL text, so it must be a trusted
// identifier, never user input. On error, the error is logged and returned
// with a zero count.
// Borrowed from https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592
func CheckCount(tablename string) (int, error) {
	stmt, err := db.Prepare("SELECT COUNT(*) as count FROM " + tablename)
	if err != nil {
		log.Printf("CheckCount error preparing sqlite statement : '%s'\n", err)
		return 0, err
	}
	// Close on every path; previously the statement leaked when Scan failed.
	defer stmt.Close()

	var count int
	if err := stmt.QueryRow().Scan(&count); err != nil {
		log.Printf("CheckCount error querying database record count : '%s'\n", err)
		return 0, err
	}
	return count, nil
}
// GenerateInsertMethod builds an INSERT statement for q. The struct's type
// name is used as the table name and its field names as column names. A
// non-nil pointer to a struct is dereferenced.
//
// Supported field kinds are signed and unsigned integers, strings (emitted
// as single-quoted SQL literals with embedded quotes doubled) and bools
// (emitted as 0/1). A field of any other kind is logged and skipped, column
// and value together, so the statement stays well-formed.
//
// It returns an error with text "SqlGenerationError" when q is not a struct.
// Adapted from https://stackoverflow.com/a/60100045
func GenerateInsertMethod(q interface{}) (string, error) {
	v := reflect.ValueOf(q)
	if v.Kind() == reflect.Ptr && !v.IsNil() && v.Elem().Kind() == reflect.Struct {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return "", errors.New("SqlGenerationError")
	}

	t := v.Type()
	names := make([]string, 0, v.NumField())
	values := make([]string, 0, v.NumField())
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		var literal string
		switch field.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			literal = strconv.FormatInt(field.Int(), 10)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			literal = strconv.FormatUint(field.Uint(), 10)
		case reflect.String:
			// SQL string literals use single quotes; SQLite may read a
			// double-quoted token as a column name.
			literal = "'" + strings.ReplaceAll(field.String(), "'", "''") + "'"
		case reflect.Bool:
			literal = "0"
			if field.Bool() {
				literal = "1"
			}
		default:
			log.Printf("Unsupported type '%s'\n", field.Kind())
			continue
		}
		names = append(names, t.Field(i).Name)
		values = append(values, literal)
	}
	return fmt.Sprintf("INSERT INTO %s(%s) VALUES (%s)", t.Name(), strings.Join(names, ", "), strings.Join(values, ", ")), nil
}
// CheckColumnExists reports whether table has a column named column. It
// queries SQLite's pragma_table_info table-valued function, which returns no
// rows for an unknown table, so that case reports false without an error.
// A query error is logged and returned.
func CheckColumnExists(table string, column string) (bool, error) {
	var count int64
	// Bind table and column as parameters rather than splicing them into the
	// SQL text. Scanning into int64 also avoids the unchecked type assertion
	// the SliceScan version relied on.
	err := db.QueryRow("SELECT COUNT(*) AS CNTREC FROM pragma_table_info(?) WHERE name = ?;", table, column).Scan(&count)
	if err != nil {
		log.Printf("CheckColumnExists error querying database for existence of column '%s' : '%s'\n", column, err)
		return false, err
	}
	log.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", count, table, column)
	return count > 0, nil
}