Files
smt/models/setup.go
Nathan Coad d8da3027e2
All checks were successful
continuous-integration/drone/push Build is passing
try sqlite transaction FK removal
2024-01-08 10:04:26 +11:00

376 lines
9.9 KiB
Go

package models
import (
"errors"
"fmt"
"log"
"os"
"reflect"
"smt/utils"
"github.com/jmoiron/sqlx"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
)
// db is the package-wide sqlx handle to the sqlite database. It is assigned by
// ConnectDatabase and used by the query helpers in this file.
var db *sqlx.DB
const (
// sqlFile is the name of the sqlite database file; ConnectDatabase passes it to
// utils.GetFilePath to obtain the path that is opened.
sqlFile = "smt.db"
)
// createRoles is the SQL that creates the roles table when it does not exist yet.
// A role has a name, ReadOnly and Admin flags, and an LdapGroup string that
// defaults to ''.
// TODO drop LdapGroup column
const createRoles string = `
CREATE TABLE IF NOT EXISTS roles (
RoleId INTEGER PRIMARY KEY ASC,
RoleName VARCHAR,
ReadOnly BOOLEAN,
Admin BOOLEAN,
LdapGroup VARCHAR DEFAULT ''
);
`
// createUsers is the SQL that creates the users table when it does not exist yet.
// GroupId is declared as a foreign key to groups(GroupId); LdapUser defaults to 0
// and LdapDN defaults to ''.
const createUsers string = `
CREATE TABLE IF NOT EXISTS users (
UserId INTEGER PRIMARY KEY ASC,
GroupId INTEGER,
UserName VARCHAR,
Password VARCHAR,
LdapUser BOOLEAN DEFAULT 0,
LdapDN VARCHAR DEFAULT '',
FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
`
// createSafes is the SQL that creates the safes table (an id and a name) when it
// does not exist yet. The secrets and permissions tables reference SafeId.
const createSafes string = `
CREATE TABLE IF NOT EXISTS safes (
SafeId INTEGER PRIMARY KEY ASC,
SafeName VARCHAR
);
`
// createGroups is the SQL that creates the groups table when it does not exist
// yet. Unlike roles.LdapGroup (a string), LdapGroup here is a BOOLEAN defaulting
// to 0, with the accompanying LdapDN string defaulting to ''.
const createGroups string = `
CREATE TABLE IF NOT EXISTS groups (
GroupId INTEGER PRIMARY KEY ASC,
GroupName VARCHAR,
LdapGroup BOOLEAN DEFAULT 0,
LdapDN VARCHAR DEFAULT ''
);
`
// createPermissions is the SQL that creates the permissions table when it does
// not exist yet. Each row carries a RoleId, SafeId, UserId and GroupId, each
// declared as a foreign key to the corresponding table.
const createPermissions = `
CREATE TABLE IF NOT EXISTS permissions (
PermissionId INTEGER PRIMARY KEY ASC,
RoleId INTEGER,
SafeId INTEGER,
UserId INTEGER,
GroupId INTEGER,
FOREIGN KEY (RoleId) REFERENCES roles(RoleId),
FOREIGN KEY (SafeId) REFERENCES safes(SafeId),
FOREIGN KEY (UserId) REFERENCES users(UserId),
FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
`
// createSecrets is the SQL that creates the secrets table when it does not exist
// yet. Each secret belongs to a safe through the SafeId foreign key and records a
// device name, device category, user name and the secret value itself.
const createSecrets string = `
CREATE TABLE IF NOT EXISTS secrets (
SecretId INTEGER PRIMARY KEY ASC,
SafeId INTEGER,
DeviceName VARCHAR,
DeviceCategory VARCHAR,
UserName VARCHAR,
Secret VARCHAR,
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`
// createSchema is the SQL that creates the schema table when it does not exist
// yet. The table has a single integer Version column.
const createSchema string = `
CREATE TABLE IF NOT EXISTS schema (
Version INTEGER
);
`
// createAudit is the SQL that creates the audit table when it does not exist yet.
// Each row holds a user name, an event text and an integer event time.
// NOTE(review): the unit of EventTime is not visible here (presumably a Unix
// timestamp) — confirm against the code that writes audit rows.
const createAudit string = `
CREATE TABLE IF NOT EXISTS audit (
UserName VARCHAR,
EventText VARCHAR,
EventTime INTEGER
);
`
// ConnectDatabase opens the sqlite database file, stores the handle in the
// package-level db variable and then calls CreateTables so that every table
// exists. If the database cannot be opened the error is logged and the process
// exits with status 1.
func ConnectDatabase() {
	// sqlite is our database; work out where its file lives
	sqlPath := utils.GetFilePath(sqlFile)

	handle, openErr := sqlx.Open("sqlite", sqlPath)
	db = handle
	if openErr != nil {
		log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, openErr)
		os.Exit(1)
	}
	log.Printf("Connected to sqlite database file '%s'\n", sqlPath)

	//sqlx.NameMapper = func(s string) string { return s }

	// The handle stays open after this function returns (see DisconnectDatabase),
	// so there is deliberately no deferred Close here
	//defer db.Close()

	// Make sure our tables exist
	CreateTables()
}
// DisconnectDatabase closes the package-level database handle opened by
// ConnectDatabase. It is a no-op when no handle was ever opened, and a failure
// to close is logged rather than silently discarded.
func DisconnectDatabase() {
	log.Printf("DisconnectDatabase called")
	// Guard against being called before ConnectDatabase, which would otherwise
	// dereference a nil handle
	if db == nil {
		return
	}
	if err := db.Close(); err != nil {
		log.Printf("Error closing sqlite database connection : '%s'\n", err)
	}
}
func CreateTables() {
var err error
var rowCount int
// Create database tables if it doesn't exist
// Roles table should go first since other tables refer to it
if _, err = db.Exec(createRoles); err != nil {
log.Printf("Error checking roles table : '%s'", err)
os.Exit(1)
}
rowCount, _ = CheckCount("roles")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO roles VALUES(1, 'Admin', false, true, '');"); err != nil {
log.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO roles VALUES(2, 'UserRole', false, false, '');"); err != nil {
log.Printf("Error adding initial user role : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO roles VALUES(3, 'GuestRole', true, false, '');"); err != nil {
log.Printf("Error adding initial guest role : '%s'", err)
os.Exit(1)
}
}
// Users table
if _, err = db.Exec(createUsers); err != nil {
log.Printf("Error checking users table : '%s'", err)
os.Exit(1)
}
rowCount, _ = CheckCount("users")
if rowCount == 0 {
// Check if there was an initial password defined in the .env file
initialPassword := os.Getenv("INITIAL_PASSWORD")
if initialPassword == "" {
initialPassword = "password"
} else if initialPassword[:4] == "$2a$" {
log.Printf("CreateTables inital admin password is already a hash")
} else {
cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
initialPassword = string(cryptText)
}
if _, err = db.Exec("INSERT INTO users (RoleId, UserName, Password, LdapUser) VALUES(1, 1, 'Administrator', ?, 0);", initialPassword); err != nil {
log.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
}
// Safes table
if _, err = db.Exec(createSafes); err != nil {
log.Printf("Error checking safes table : '%s'", err)
os.Exit(1)
}
// Secrets table
if _, err = db.Exec(createSecrets); err != nil {
log.Printf("Error checking secrets table : '%s'", err)
os.Exit(1)
}
// groups table
if _, err = db.Exec(createGroups); err != nil {
log.Printf("Error checking groups table : '%s'", err)
os.Exit(1)
}
// permissions table
if _, err = db.Exec(createPermissions); err != nil {
log.Printf("Error checking permissions table : '%s'", err)
os.Exit(1)
}
// Schema table should go last so we know if the database has a value in the schema table then everything was created properly
if _, err = db.Exec(createSchema); err != nil {
log.Printf("Error checking schema table : '%s'", err)
os.Exit(1)
}
schemaCheck, _ := CheckColumnExists("schema", "Version")
if !schemaCheck {
if _, err = db.Exec("INSERT INTO schema VALUES(2);"); err != nil {
log.Printf("Error adding initial schema version : '%s'", err)
os.Exit(1)
}
}
// Audit log table
if _, err = db.Exec(createAudit); err != nil {
log.Printf("Error checking audit table : '%s'", err)
os.Exit(1)
}
// Remove users RoleId column
userRoleIdCheck, _ := CheckColumnExists("users", "RoleId")
if userRoleIdCheck {
//_, err := db.Exec("ALTER TABLE users DROP COLUMN RoleId;")
_, err := db.Exec(`
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE users RENAME TO _users_old;
CREATE TABLE users
(
UserId INTEGER PRIMARY KEY ASC,
GroupId INTEGER,
UserName VARCHAR,
Password VARCHAR,
LdapUser BOOLEAN DEFAULT 0,
LdapDN VARCHAR DEFAULT '',
FOREIGN KEY (GroupId) REFERENCES groups(GroupId)
);
INSERT INTO users SELECT * FROM _users_old;
COMMIT;
PRAGMA foreign_keys=on;
`)
if err != nil {
log.Printf("Error altering users table to drop RoleId column : '%s'\n", err)
os.Exit(1)
}
}
/*
// Database updates added after initial version released
ldapCheck, _ := CheckColumnExists("roles", "LdapGroup")
if !ldapCheck {
// Add the column for LdapGroup in the roles table
_, err := db.Exec("ALTER TABLE roles ADD COLUMN LdapGroup VARCHAR DEFAULT '';")
if err != nil {
log.Printf("Error altering roles table to add LdapGroup column : '%s'\n", err)
os.Exit(1)
}
}
// Add the two LDAP columns to the users table if they weren't there
ldapUserCheck, _ := CheckColumnExists("users", "LdapUser")
if !ldapUserCheck {
log.Printf("CreateTables creating ldap columns in user table")
_, err := db.Exec("ALTER TABLE users ADD COLUMN LdapUser BOOLEAN DEFAULT 0;")
if err != nil {
log.Printf("Error altering users table to add LdapUser column : '%s'\n", err)
os.Exit(1)
}
_, err = db.Exec("ALTER TABLE users ADD COLUMN LdapDN VARCHAR DEFAULT '';")
if err != nil {
log.Printf("Error altering users table to add LdapDN column : '%s'\n", err)
os.Exit(1)
}
}
*/
}
// CheckCount returns the number of rows in the named table of the sqlite
// database. On failure the error is logged and returned with a count of 0.
//
// tablename is concatenated into the statement because SQL identifiers cannot be
// bound as parameters, so callers must only pass trusted table names.
//
// Borrowed from https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592
func CheckCount(tablename string) (int, error) {
	var count int
	stmt, err := db.Prepare("SELECT COUNT(*) as count FROM " + tablename)
	if err != nil {
		log.Printf("CheckCount error preparing sqlite statement : '%s'\n", err)
		return 0, err
	}
	// Release the prepared statement on every path, including a failed Scan
	defer stmt.Close()

	if err = stmt.QueryRow().Scan(&count); err != nil {
		log.Printf("CheckCount error querying database record count : '%s'\n", err)
		return 0, err
	}
	return count, nil
}
// GenerateInsertMethod builds an INSERT statement for the struct value q. The
// table name is the struct's type name and every field becomes a column named
// after the field, in declaration order.
//
// Supported field kinds are the signed and unsigned integers, strings and bools.
// Strings are written as single-quoted SQL literals with embedded quotes
// doubled; bools are written as 1 or 0. A field of any other kind is logged and
// left out of the statement entirely, so the column and value lists always
// line up.
//
// It returns the error "SqlGenerationError" when q is not a struct or when it
// has no supported fields.
//
// From https://stackoverflow.com/a/60100045
func GenerateInsertMethod(q interface{}) (string, error) {
	v := reflect.ValueOf(q)
	if v.Kind() != reflect.Struct {
		return "", errors.New("SqlGenerationError")
	}
	t := v.Type()

	columns := ""
	values := ""
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)

		var literal string
		switch field.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			literal = fmt.Sprintf("%d", field.Int())
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			literal = fmt.Sprintf("%d", field.Uint())
		case reflect.String:
			literal = sqlStringLiteral(field.String())
		case reflect.Bool:
			literal = "0"
			if field.Bool() {
				literal = "1"
			}
		default:
			// Skip the column as well as the value, otherwise the two lists
			// would end up with different lengths
			log.Printf("Unsupported type '%s'\n", field.Kind())
			continue
		}

		// Separate on "have we emitted anything yet" rather than on the field
		// index, so a skipped first field does not leave a leading comma
		if columns != "" {
			columns += ", "
			values += ", "
		}
		columns += t.Field(i).Name
		values += literal
	}

	// An INSERT without any column is not valid SQL
	if columns == "" {
		return "", errors.New("SqlGenerationError")
	}
	return fmt.Sprintf("INSERT INTO %s(%s) VALUES (%s)", t.Name(), columns, values), nil
}

// sqlStringLiteral returns s as a single-quoted SQL string literal with every
// embedded single quote doubled.
func sqlStringLiteral(s string) string {
	quoted := make([]byte, 0, len(s)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			quoted = append(quoted, '\'')
		}
		quoted = append(quoted, s[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}
// CheckColumnExists reports whether the given table has a column with the given
// name, as listed by sqlite's pragma_table_info. A table that does not exist
// yields false. On a query failure the error is logged and returned together
// with false.
func CheckColumnExists(table string, column string) (bool, error) {
	var count int64
	// pragma_table_info is a table-valued function, so the table and the column
	// name can both be bound as parameters instead of being spliced into the
	// SQL text
	err := db.QueryRow("SELECT COUNT(*) AS CNTREC FROM pragma_table_info(?) WHERE name=?;", table, column).Scan(&count)
	if err != nil {
		log.Printf("CheckColumnExists error querying database for existence of column '%s' : '%s'\n", column, err)
		return false, err
	}
	log.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", count, table, column)
	return count > 0, nil
}