add lastlogin for user
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2024-01-19 09:17:31 +11:00
parent bb3bf3093d
commit 2ab6240a24
4 changed files with 949 additions and 974 deletions

View File

@@ -1,478 +1,491 @@
package models package models
import ( import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"os"
"reflect"
"smt/utils" "smt/utils"
"strings"
"time" "github.com/jmoiron/sqlx"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
) )
const nonceSize = 12 var db *sqlx.DB
// We use the json:"-" field tag to prevent showing these details to the user const (
type Secret struct { sqlFile = "smt.db"
SecretId int `db:"SecretId" json:"secretId"` )
SafeId int `db:"SafeId" json:"safeId"`
DeviceName string `db:"DeviceName" json:"deviceName"`
DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"`
UserName string `db:"UserName" json:"userName"`
Secret string `db:"Secret" json:"secret"`
LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"`
}
// SecretRestricted is for when we want to output a Secret but not the protected information const createUsers string = `
type SecretRestricted struct { CREATE TABLE IF NOT EXISTS users (
SecretId int `db:"SecretId" json:"secretId"` UserId INTEGER PRIMARY KEY AUTOINCREMENT,
SafeId int `db:"SafeId" json:"safeId"` GroupId INTEGER,
DeviceName string `db:"DeviceName" json:"deviceName"` UserName VARCHAR,
DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"` Password VARCHAR,
UserName string `db:"UserName" json:"userName"` Admin BOOLEAN DEFAULT 0,
Secret string `db:"Secret" json:"-"` LdapUser BOOLEAN DEFAULT 0,
LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"` LastLogin datetime DEFAULT (datetime('1970-01-01 00:00:00')),
} );
`
// Used for querying all secrets the user has access to const createSafes string = `
// Since there are some ambiguous column names (eg UserName is present in both users and secrets table), the order of fields in this struct matters CREATE TABLE IF NOT EXISTS safes (
type UserSecret struct { SafeId INTEGER PRIMARY KEY AUTOINCREMENT,
Secret SafeName VARCHAR
UserUserId int `db:"UserUserId"` );
User `
//Group
Permission
}
// This method allows us to use an interface to avoid adding duplicate entries to a []Secret const createGroups string = `
func (s Secret) GetId() int { CREATE TABLE IF NOT EXISTS groups (
return s.SecretId GroupId INTEGER PRIMARY KEY AUTOINCREMENT,
} GroupName VARCHAR,
LdapGroup BOOLEAN DEFAULT 0,
LdapDn VARCHAR DEFAULT '',
Admin BOOLEAN DEFAULT 0
);
`
func (s *Secret) SaveSecret() (*Secret, error) { const createPermissions = `
CREATE TABLE IF NOT EXISTS permissions (
PermissionId INTEGER PRIMARY KEY AUTOINCREMENT,
Description VARCHAR DEFAULT '',
ReadOnly BOOLEAN DEFAULT 0,
SafeId INTEGER,
UserId INTEGER DEFAULT 0,
GroupId INTEGER DEFAULT 0,
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`
const createSecrets string = `
CREATE TABLE IF NOT EXISTS secrets (
SecretId INTEGER PRIMARY KEY AUTOINCREMENT,
SafeId INTEGER,
DeviceName VARCHAR,
DeviceCategory VARCHAR,
UserName VARCHAR,
Secret VARCHAR,
LastUpdated datetime DEFAULT (datetime('1970-01-01 00:00:00')),
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`
const createSchema string = `
CREATE TABLE IF NOT EXISTS schema (
Version INTEGER
);
`
const createAudit string = `
CREATE TABLE IF NOT EXISTS audit (
AuditId INTEGER PRIMARY KEY AUTOINCREMENT,
UserId INTEGER DEFAULT 0,
SecretId INTEGER DEFAULT 0,
EventText VARCHAR,
EventTime datetime
);
`
// Establish connection to sqlite database
func ConnectDatabase() {
var err error var err error
// Populate timestamp field if not already set // Try using sqlite as our database
if s.LastUpdated.IsZero() { sqlPath := utils.GetFilePath(sqlFile)
s.LastUpdated = time.Now().UTC() db, err = sqlx.Open("sqlite", sqlPath)
}
log.Printf("SaveSecret storing values '%v'\n", s)
result, err := db.NamedExec((`INSERT INTO secrets (SafeId, DeviceName, DeviceCategory, UserName, Secret, LastUpdated) VALUES (:SafeId, :DeviceName, :DeviceCategory, :UserName, :Secret, :LastUpdated)`), s)
if err != nil { if err != nil {
log.Printf("StoreSecret error executing sql record : '%s'\n", err) log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err)
return s, err os.Exit(1)
}
affected, _ := result.RowsAffected()
id, _ := result.LastInsertId()
s.SecretId = int(id)
log.Printf("StoreSecret insert returned result id '%d' affecting %d row(s).\n", id, affected)
return s, nil
}
// SecretsGetAllowed returns all allowed secrets matching the specified parameters in s
func SecretsGetAllowed(s *Secret, userId int) ([]UserSecret, error) {
var err error
var secretResults []UserSecret
// Query for group access
queryArgs := []interface{}{}
query := `
SELECT users.UserId AS UserUserId, permissions.*,
secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName
FROM users
INNER JOIN groups ON users.GroupId = groups.GroupId
INNER JOIN permissions ON groups.GroupId = permissions.GroupId
INNER JOIN secrets on secrets.SafeId = permissions.SafeId
WHERE users.UserId = ? `
queryArgs = append(queryArgs, userId)
// Add any other arguments to the query if they were specified
if s.SecretId > 0 {
query += " AND SecretId = ? "
queryArgs = append(queryArgs, s.SecretId)
}
if s.DeviceName != "" {
query += " AND DeviceName LIKE ? "
queryArgs = append(queryArgs, s.DeviceName)
}
if s.DeviceCategory != "" {
query += " AND DeviceCategory LIKE ? "
queryArgs = append(queryArgs, s.DeviceCategory)
}
if s.UserName != "" {
query += " AND secrets.UserName LIKE ? "
queryArgs = append(queryArgs, s.UserName)
}
// Query for user access
query += `
UNION
SELECT users.UserId AS UserUserId, permissions.*,
secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName
FROM users
INNER JOIN permissions ON users.UserId = permissions.UserId
INNER JOIN safes on permissions.SafeId = safes.SafeId
INNER JOIN secrets on secrets.SafeId = safes.SafeId
WHERE users.UserId = ?`
queryArgs = append(queryArgs, userId)
// Add any other arguments to the query if they were specified
if s.SecretId > 0 {
query += " AND SecretId = ? "
queryArgs = append(queryArgs, s.SecretId)
}
if s.DeviceName != "" {
query += " AND DeviceName LIKE ? "
queryArgs = append(queryArgs, s.DeviceName)
}
if s.DeviceCategory != "" {
query += " AND DeviceCategory LIKE ? "
queryArgs = append(queryArgs, s.DeviceCategory)
}
if s.UserName != "" {
query += " AND secrets.UserName LIKE ? "
queryArgs = append(queryArgs, s.UserName)
}
// Execute the query
log.Printf("SecretsGetAllowed query string : '%s'\nArguments:%+v\n", query, queryArgs)
rows, err := db.Queryx(query, queryArgs...)
if err != nil {
log.Printf("SecretsGetAllowed error executing sql record : '%s'\n", err)
return secretResults, err
} else { } else {
//log.Printf("SecretsGetAllowed any error '%s'\n", rows.Err()) log.Printf("Connected to sqlite database file '%s'\n", sqlPath)
// parse all the results into a slice }
for rows.Next() {
//log.Printf("SecretsGetAllowed processing row\n")
var r UserSecret
err = rows.StructScan(&r)
//log.Printf("SecretsGetAllowed performed struct scan\n")
if err != nil {
log.Printf("SecretsGetAllowed error parsing sql record : '%s'\n", err)
return secretResults, err
}
//log.Printf("r: %v\n", r)
//log.Printf("SecretsGetAllowed performed err check\n") //sqlx.NameMapper = func(s string) string { return s }
// work around to get the UserId populated in the User field of the struct // Make sure our tables exist
r.User.UserId = r.UserUserId CreateTables()
// For debugging purposes //defer db.Close()
debugPrint := utils.PrintStructContents(&r, 0) }
log.Println(debugPrint)
//log.Printf("SecretsGetAllowed performed debug print\n") func DisconnectDatabase() {
log.Printf("DisconnectDatabase called")
defer db.Close()
}
// Append the secrets to the query output, don't decrypt the secrets (we didn't SELECT them anyway) func CreateTables() {
//secretResults = append(secretResults, r) var err error
var rowCount int
// Use generics and the GetID() method on the UserSecret struct // Create database tables if it doesn't exist
// to avoid adding this element to the results
// if there is already a secret with the same ID present
secretResults = utils.AppendIfNotExists(secretResults, r)
//log.Printf("SecretsGetAllowed added secret results\n") // groups table
if _, err = db.Exec(createGroups); err != nil {
log.Printf("Error checking groups table : '%s'", err)
os.Exit(1)
}
// Add initial groups
rowCount, _ = CheckCount("groups")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(1, 'Administrators', 1);"); err != nil {
log.Printf("Error adding initial group entry id 1 : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(2, 'Users', 0);"); err != nil {
log.Printf("Error adding initial group entry id 2 : '%s'", err)
os.Exit(1)
} }
log.Printf("SecretsGetAllowed retrieved '%d' results\n", len(secretResults))
} }
return secretResults, nil // Users table
if _, err = db.Exec(createUsers); err != nil {
log.Printf("Error checking users table : '%s'", err)
os.Exit(1)
}
rowCount, _ = CheckCount("users")
if rowCount == 0 {
// Check if there was an initial password defined in the .env file
initialPassword := os.Getenv("INITIAL_PASSWORD")
if initialPassword == "" {
initialPassword = "password"
} else if initialPassword[:4] == "$2a$" {
log.Printf("CreateTables inital admin password is already a hash")
} else {
cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
initialPassword = string(cryptText)
}
if _, err = db.Exec("INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(1, 1, 'Administrator', ?, false, true);", initialPassword); err != nil {
log.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(2, 2, 'User', ?, false, false);", initialPassword); err != nil {
log.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
}
// Safes table
if _, err = db.Exec(createSafes); err != nil {
log.Printf("Error checking safes table : '%s'", err)
os.Exit(1)
}
// Create an initial safe
rowCount, _ = CheckCount("safes")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO safes VALUES(1, 'Default Safe');"); err != nil {
log.Printf("Error adding initial safe entry : '%s'", err)
os.Exit(1)
}
}
// Secrets table
if _, err = db.Exec(createSecrets); err != nil {
log.Printf("Error checking secrets table : '%s'", err)
os.Exit(1)
}
// permissions table
if _, err = db.Exec(createPermissions); err != nil {
log.Printf("Error checking permissions table : '%s'", err)
os.Exit(1)
}
// Add initial permissions
rowCount, _ = CheckCount("permissions")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO permissions (Description, ReadOnly, GroupId, SafeId) VALUES('Default Admin Group Permission', false, 1, 1);"); err != nil {
log.Printf("Error adding initial permissions entry userid 1 : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO permissions (Description, ReadOnly, SafeId, GroupId) VALUES('Default User Group Permission', false, 1, 2);"); err != nil {
log.Printf("Error adding initial permissions entry userid 2 : '%s'", err)
os.Exit(1)
}
}
// Schema table should go last so we know if the database has a value in the schema table then everything was created properly
if _, err = db.Exec(createSchema); err != nil {
log.Printf("Error checking schema table : '%s'", err)
os.Exit(1)
}
schemaCheck, _ := CheckColumnExists("schema", "Version")
if !schemaCheck {
if _, err = db.Exec("INSERT INTO schema VALUES(2);"); err != nil {
log.Printf("Error adding initial schema version : '%s'", err)
os.Exit(1)
}
}
// Audit log table
if _, err = db.Exec(createAudit); err != nil {
log.Printf("Error checking audit table : '%s'", err)
os.Exit(1)
}
// Remove users RoleId column
userRoleIdCheck, _ := CheckColumnExists("users", "RoleId")
if userRoleIdCheck {
//_, err := db.Exec("ALTER TABLE users DROP COLUMN RoleId;")
_, err := db.Exec(`
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE users RENAME TO _users_old;
CREATE TABLE users
(
UserId INTEGER PRIMARY KEY AUTOINCREMENT,
GroupId INTEGER,
UserName VARCHAR,
Password VARCHAR,
Admin BOOLEAN DEFAULT 0,
LdapUser BOOLEAN DEFAULT 0
);
INSERT INTO users SELECT * FROM _users_old;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _users_old;
`)
if err != nil {
log.Printf("Error altering users table to drop RoleId column : '%s'\n", err)
os.Exit(1)
}
}
// Set any unassigned secrets to the default safe id
if _, err = db.Exec("UPDATE users SET LdapUser = 0 WHERE LdapUser is null;"); err != nil {
log.Printf("Error setting LdapUser flag to false for existing users : '%s'", err)
os.Exit(1)
}
// Remove LdapGroup column from roles table
ldapCheck, _ := CheckColumnExists("roles", "LdapGroup")
if ldapCheck {
_, err := db.Exec("ALTER TABLE roles DROP COLUMN LdapGroup;")
if err != nil {
log.Printf("Error altering roles table to renmove LdapGroup column : '%s'\n", err)
os.Exit(1)
}
}
// Add SafeId column to secrets table
safeIdCheck, _ := CheckColumnExists("secrets", "SafeId")
if !safeIdCheck {
// Add the column for LdapGroup in the roles table
_, err := db.Exec("ALTER TABLE secrets ADD COLUMN SafeId INTEGER REFERENCES safes(SafeId);")
if err != nil {
log.Printf("Error altering secrets table to add SafeId column : '%s'\n", err)
os.Exit(1)
}
}
// Set any unassigned secrets to the default safe id
if _, err = db.Exec("UPDATE secrets SET SafeId = 1 WHERE SafeId is null;"); err != nil {
log.Printf("Error setting safe ID of existing secrets : '%s'", err)
os.Exit(1)
}
// Remove RoleId column from secrets table
secretsRoleIdCheck, _ := CheckColumnExists("secrets", "RoleId")
if secretsRoleIdCheck {
_, err := db.Exec(`
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE secrets RENAME TO _secrets_old;
CREATE TABLE secrets
(
SecretId INTEGER PRIMARY KEY AUTOINCREMENT,
RoleId INTEGER,
SafeId INTEGER,
DeviceName VARCHAR,
DeviceCategory VARCHAR,
UserName VARCHAR,
Secret VARCHAR,
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
INSERT INTO secrets SELECT SecretId, RoleId, SafeId, DeviceName, DeviceCategory, UserName, Secret FROM _secrets_old;
ALTER TABLE secrets DROP COLUMN RoleId;
ALTER TABLE secrets ADD COLUMN LastUpdated datetime;
UPDATE secrets SET LastUpdated = (datetime('1970-01-01 00:00:00')) WHERE LastUpdated is null;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _secrets_old;
`)
if err != nil {
log.Printf("Error altering secrets table to remove RoleId column : '%s'\n", err)
os.Exit(1)
}
}
// Remove the Admin column from roles table
rolesAdminCheck, _ := CheckColumnExists("roles", "Admin")
if rolesAdminCheck {
_, err := db.Exec("ALTER TABLE roles DROP COLUMN Admin;")
if err != nil {
log.Printf("Error altering roles table to remove Admin column : '%s'\n", err)
os.Exit(1)
}
}
// Remove the RoleId from permissiosn table
permissionsRoleIdCheck, _ := CheckColumnExists("permissions", "RoleId")
if permissionsRoleIdCheck {
_, err := db.Exec(`
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE permissions RENAME TO _permissions_old;
CREATE TABLE permissions
(
PermissionId INTEGER PRIMARY KEY AUTOINCREMENT,
Description VARCHAR DEFAULT '',
ReadOnly BOOLEAN DEFAULT 0,
SafeId INTEGER,
UserId INTEGER DEFAULT 0,
GroupId INTEGER DEFAULT 0,
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
INSERT INTO permissions SELECT PermissionId, SafeId, UserId, GroupId, '' AS Description, 0 as ReadOnly FROM _permissions_old;
UPDATE permissions SET ReadOnly = 0 WHERE ReadOnly is null;
UPDATE permissions SET Description = '' WHERE Description is null;
UPDATE permissions SET UserId = 0 WHERE UserId is null;
UPDATE permissions SET GroupId = 0 WHERE GroupId is null;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _permissions_old;
`)
if err != nil {
log.Printf("Error altering permissions table to remove RoleId column : '%s'\n", err)
os.Exit(1)
}
}
secretsLastUpdatedCheck, _ := CheckColumnExists("secrets", "LastUpdated")
if !secretsLastUpdatedCheck {
// Add the column for LastUpdated in the secrets table
_, err := db.Exec("ALTER TABLE secrets ADD COLUMN LastUpdated datetime;")
if err != nil {
log.Printf("Error altering secrets table to add LastUpdated column : '%s'\n", err)
os.Exit(1)
}
// Set the default value
if _, err = db.Exec("UPDATE secrets SET LastUpdated = (datetime('1970-01-01 00:00:00')) WHERE LastUpdated is null;"); err != nil {
log.Printf("Error setting LastUpdated of existing secrets : '%s'", err)
os.Exit(1)
}
}
lastLoginCheck, _ := CheckColumnExists("users", "LastLogin")
if !lastLoginCheck {
// Add the column for LastUpdated in the secrets table
_, err := db.Exec("ALTER TABLE users ADD COLUMN LastLogin datetime;")
if err != nil {
log.Printf("Error altering users table to add LastLogin column : '%s'\n", err)
os.Exit(1)
}
// Set the default value
if _, err = db.Exec("UPDATE users SET LastLogin = (datetime('1970-01-01 00:00:00')) WHERE LastLogin is null;"); err != nil {
log.Printf("Error setting LastLogin of existing users : '%s'", err)
os.Exit(1)
}
}
} }
// SecretsGetFromMultipleSafes queries the specified safes for matching secrets // Count the number of records in the sqlite database
func SecretsGetFromMultipleSafes(s *Secret, safeIds []int) ([]Secret, error) { // Borrowed from https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592
var err error func CheckCount(tablename string) (int, error) {
var secretResults []Secret var count int
stmt, err := db.Prepare("SELECT COUNT(*) as count FROM " + tablename)
queryArgs := []interface{}{}
var query string
// Generate placeholders for the IN clause to match multiple SafeId values
placeholders := make([]string, len(safeIds))
for i := range safeIds {
placeholders[i] = "?"
}
placeholderStr := strings.Join(placeholders, ",")
// Create query with the necessary placeholders
query = fmt.Sprintf("SELECT * FROM secrets WHERE SafeId IN (%s) ", placeholderStr)
// Add the Safe Ids to the arguments list
for _, g := range safeIds {
queryArgs = append(queryArgs, g)
}
// Add any other arguments to the query if they were specified
if s.SecretId > 0 {
query += " AND SecretId = ? "
queryArgs = append(queryArgs, s.SecretId)
}
if s.DeviceName != "" {
query += " AND DeviceName LIKE ? "
queryArgs = append(queryArgs, s.DeviceName)
}
if s.DeviceCategory != "" {
query += " AND DeviceCategory LIKE ? "
queryArgs = append(queryArgs, s.DeviceCategory)
}
if s.UserName != "" {
query += " AND UserName LIKE ? "
queryArgs = append(queryArgs, s.UserName)
}
// Execute the query
log.Printf("SecretsGetMultipleSafes query string :\n'%s'\nQuery Args : %+v\n", query, queryArgs)
rows, err := db.Queryx(query, queryArgs...)
if err != nil { if err != nil {
log.Printf("SecretsGetMultipleSafes error executing sql record : '%s'\n", err) log.Printf("CheckCount error preparing sqlite statement : '%s'\n", err)
return secretResults, err return 0, err
} else { }
// parse all the results into a slice err = stmt.QueryRow().Scan(&count)
for rows.Next() { if err != nil {
var r Secret log.Printf("CheckCount error querying database record count : '%s'\n", err)
err = rows.StructScan(&r) return 0, err
if err != nil { }
log.Printf("SecretsGetMultipleSafes error parsing sql record : '%s'\n", err) stmt.Close() // or use defer rows.Close(), idc
return secretResults, err return count, nil
} }
// Decrypt the secret // From https://stackoverflow.com/a/60100045
_, err = r.DecryptSecret() func GenerateInsertMethod(q interface{}) (string, error) {
if err != nil { if reflect.ValueOf(q).Kind() == reflect.Struct {
log.Printf("SecretsGetMultipleSafes unable to decrypt stored secret : '%s'\n", err) query := fmt.Sprintf("INSERT INTO %s", reflect.TypeOf(q).Name())
rows.Close() fieldNames := ""
return secretResults, err fieldValues := ""
v := reflect.ValueOf(q)
for i := 0; i < v.NumField(); i++ {
if i == 0 {
fieldNames = fmt.Sprintf("%s%s", fieldNames, v.Type().Field(i).Name)
} else { } else {
secretResults = append(secretResults, r) fieldNames = fmt.Sprintf("%s, %s", fieldNames, v.Type().Field(i).Name)
}
switch v.Field(i).Kind() {
case reflect.Int:
if i == 0 {
fieldValues = fmt.Sprintf("%s%d", fieldValues, v.Field(i).Int())
} else {
fieldValues = fmt.Sprintf("%s, %d", fieldValues, v.Field(i).Int())
}
case reflect.String:
if i == 0 {
fieldValues = fmt.Sprintf("%s\"%s\"", fieldValues, v.Field(i).String())
} else {
fieldValues = fmt.Sprintf("%s, \"%s\"", fieldValues, v.Field(i).String())
}
case reflect.Bool:
var boolSet int8
if v.Field(i).Bool() {
boolSet = 1
}
if i == 0 {
fieldValues = fmt.Sprintf("%s%d", fieldValues, boolSet)
} else {
fieldValues = fmt.Sprintf("%s, %d", fieldValues, boolSet)
}
default:
log.Printf("Unsupported type '%s'\n", v.Field(i).Kind())
} }
} }
log.Printf("SecretsGetMultipleSafes retrieved '%d' results\n", len(secretResults)) query = fmt.Sprintf("%s(%s) VALUES (%s)", query, fieldNames, fieldValues)
return query, nil
} }
return "", errors.New("SqlGenerationError")
return secretResults, nil
} }
func (s *Secret) UpdateSecret() (*Secret, error) { func CheckColumnExists(table string, column string) (bool, error) {
var count int64
var err error rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM pragma_table_info('" + table + "') WHERE name='" + column + "';")
// Populate timestamp field if not already set
if s.LastUpdated.IsZero() {
s.LastUpdated = time.Now().UTC()
}
log.Printf("UpdateSecret storing values '%v'\n", s)
if s.SecretId == 0 {
err = errors.New("UpdateSecret unable to locate secret with empty secretId field")
log.Printf("UpdateSecret error in pre-check : '%s'\n", err)
return s, err
}
result, err := db.NamedExec((`UPDATE secrets SET DeviceName = :DeviceName, DeviceCategory = :DeviceCategory, UserName = :UserName, Secret = :Secret, LastUpdated = :LastUpdated WHERE SecretId = :SecretId`), s)
if err != nil { if err != nil {
log.Printf("UpdateSecret error executing sql record : '%s'\n", err) log.Printf("CheckColumnExists error querying database for existence of column '%s' : '%s'\n", column, err)
return &Secret{}, err return false, err
} else { }
affected, _ := result.RowsAffected() defer rows.Close()
id, _ := result.LastInsertId() for rows.Next() {
log.Printf("UpdateSecret insert returned result id '%d' affecting %d row(s).\n", id, affected) // cols is an []interface{} of all of the column results
cols, _ := rows.SliceScan()
log.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column)
count = cols[0].(int64)
if count == 1 {
return true, nil
} else {
return false, nil
}
} }
return s, nil err = rows.Err()
} if err != nil {
log.Printf("CheckColumnExists error getting results : '%s'\n", err)
func (s *Secret) DeleteSecret() (*Secret, error) { return false, err
}
var err error
return false, nil
log.Printf("DeleteSecret deleting record with values '%v'\n", s)
if s.SecretId == 0 {
err = errors.New("unable to locate secret with empty secretId field")
log.Printf("DeleteSecret error in pre-check : '%s'\n", err)
return s, err
}
result, err := db.NamedExec((`DELETE FROM secrets WHERE SecretId = :SecretId`), s)
if err != nil {
log.Printf("DeleteSecret error executing sql record : '%s'\n", err)
return &Secret{}, err
} else {
affected, _ := result.RowsAffected()
id, _ := result.LastInsertId()
log.Printf("DeleteSecret delete returned result id '%d' affecting %d row(s).\n", id, affected)
}
return s, nil
}
// startCipher does the initial setup of the AES256 GCM mode cipher
func startCipher() (cipher.AEAD, error) {
key, err := ProvideKey()
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
log.Printf("startCipher NewCipher error '%s'\n", err)
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
log.Printf("startCipher NewGCM error '%s'\n", err)
return nil, err
}
return aesgcm, nil
}
func (s *Secret) EncryptSecret() (*Secret, error) {
//keyString := os.Getenv("SECRETS_KEY")
//keyString := secretKey
// The key argument should be the AES key, either 16 or 32 bytes
// to select AES-128 or AES-256.
//key := []byte(keyString)
/*
key, err := ProvideKey()
if err != nil {
return s, err
}
*/
plaintext := []byte(s.Secret)
// TODO : move block and aesgcm generation to separate function since the identical code is used for encrypt and decrypt
/*
log.Printf("EncryptSecret applying key '%v' of length '%d' to plaintext secret '%s'\n", key, len(key), s.Secret)
block, err := aes.NewCipher(key)
if err != nil {
log.Printf("EncryptSecret NewCipher error '%s'\n", err)
return s, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
log.Printf("EncryptSecret NewGCM error '%s'\n", err)
return s, err
}
*/
aesgcm, err := startCipher()
if err != nil {
log.Printf("EncryptSecret error commencing GCM cipher '%s'\n", err)
return s, err
}
// Never use more than 2^32 random nonces with a given key because of the risk of a repeat.
nonce := make([]byte, nonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
log.Printf("EncryptSecret nonce generation error '%s'\n", err)
return s, err
}
//log.Printf("EncryptSecret random nonce value is '%x'\n", nonce)
ciphertext := aesgcm.Seal(nil, nonce, plaintext, nil)
//log.Printf("EncryptSecret generated ciphertext '%x''\n", ciphertext)
// Create a new slice to store nonce at the start and then the resulting ciphertext
// Nonce is always 12 bytes
combinedText := append(nonce, ciphertext...)
//log.Printf("EncryptSecret combined secret value is now '%x'\n", combinedText)
// Store the value back into the struct ready for database operations
s.Secret = hex.EncodeToString(combinedText)
return s, nil
//return string(ciphertext[:]), nil
}
func (s *Secret) DecryptSecret() (*Secret, error) {
// The key argument should be the AES key, either 16 or 32 bytes
// to select AES-128 or AES-256.
//keyString := os.Getenv("SECRETS_KEY")
//keyString := secretKey
//key := []byte(keyString)
/*
key, err := ProvideKey()
if err != nil {
return s, err
}
*/
if len(s.Secret) < nonceSize {
log.Printf("DecryptSecret ciphertext is too short to decrypt\n")
return s, errors.New("ciphertext is too short")
}
crypted, err := hex.DecodeString(s.Secret)
if err != nil {
log.Printf("DecryptSecret unable to convert hex encoded string due to error '%s'\n", err)
return s, err
}
//log.Printf("DecryptSecret processing secret '%x'\n", crypted)
// The nonce is the first 12 bytes from the ciphertext
nonce := crypted[:nonceSize]
ciphertext := crypted[nonceSize:]
/*
log.Printf("DecryptSecret applying key '%v' and nonce '%x' to ciphertext '%x'\n", key, nonce, ciphertext)
block, err := aes.NewCipher(key)
if err != nil {
log.Printf("DecryptSecret NewCipher error '%s'\n", err)
return s, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
log.Printf("DecryptSecret NewGCM error '%s'\n", err)
return s, err
}
*/
aesgcm, err := startCipher()
if err != nil {
log.Printf("DecryptSecret error commencing GCM cipher '%s'\n", err)
return s, err
}
plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
log.Printf("DecryptSecret Open error '%s'\n", err)
return s, err
}
//log.Printf("DecryptSecret plaintext is '%s'\n", plaintext)
s.Secret = string(plaintext)
return s, nil
} }

478
models/secret.go Normal file
View File

@@ -0,0 +1,478 @@
package models
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"smt/utils"
"strings"
"time"
)
const nonceSize = 12
// We use the json:"-" field tag to prevent showing these details to the user
type Secret struct {
SecretId int `db:"SecretId" json:"secretId"`
SafeId int `db:"SafeId" json:"safeId"`
DeviceName string `db:"DeviceName" json:"deviceName"`
DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"`
UserName string `db:"UserName" json:"userName"`
Secret string `db:"Secret" json:"secret"`
LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"`
}
// SecretRestricted is for when we want to output a Secret but not the protected information
type SecretRestricted struct {
SecretId int `db:"SecretId" json:"secretId"`
SafeId int `db:"SafeId" json:"safeId"`
DeviceName string `db:"DeviceName" json:"deviceName"`
DeviceCategory string `db:"DeviceCategory" json:"deviceCategory"`
UserName string `db:"UserName" json:"userName"`
Secret string `db:"Secret" json:"-"`
LastUpdated time.Time `db:"LastUpdated" json:"lastUpdated"`
}
// Used for querying all secrets the user has access to
// Since there are some ambiguous column names (eg UserName is present in both users and secrets table), the order of fields in this struct matters
type UserSecret struct {
Secret
UserUserId int `db:"UserUserId"`
User
//Group
Permission
}
// This method allows us to use an interface to avoid adding duplicate entries to a []Secret
func (s Secret) GetId() int {
return s.SecretId
}
func (s *Secret) SaveSecret() (*Secret, error) {
var err error
// Populate timestamp field if not already set
if s.LastUpdated.IsZero() {
s.LastUpdated = time.Now().UTC()
}
log.Printf("SaveSecret storing values '%v'\n", s)
result, err := db.NamedExec((`INSERT INTO secrets (SafeId, DeviceName, DeviceCategory, UserName, Secret, LastUpdated) VALUES (:SafeId, :DeviceName, :DeviceCategory, :UserName, :Secret, :LastUpdated)`), s)
if err != nil {
log.Printf("StoreSecret error executing sql record : '%s'\n", err)
return s, err
}
affected, _ := result.RowsAffected()
id, _ := result.LastInsertId()
s.SecretId = int(id)
log.Printf("StoreSecret insert returned result id '%d' affecting %d row(s).\n", id, affected)
return s, nil
}
// SecretsGetAllowed returns all allowed secrets matching the specified parameters in s
func SecretsGetAllowed(s *Secret, userId int) ([]UserSecret, error) {
var err error
var secretResults []UserSecret
// Query for group access
queryArgs := []interface{}{}
query := `
SELECT users.UserId AS UserUserId, permissions.*,
secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName
FROM users
INNER JOIN groups ON users.GroupId = groups.GroupId
INNER JOIN permissions ON groups.GroupId = permissions.GroupId
INNER JOIN secrets on secrets.SafeId = permissions.SafeId
WHERE users.UserId = ? `
queryArgs = append(queryArgs, userId)
// Add any other arguments to the query if they were specified
if s.SecretId > 0 {
query += " AND SecretId = ? "
queryArgs = append(queryArgs, s.SecretId)
}
if s.DeviceName != "" {
query += " AND DeviceName LIKE ? "
queryArgs = append(queryArgs, s.DeviceName)
}
if s.DeviceCategory != "" {
query += " AND DeviceCategory LIKE ? "
queryArgs = append(queryArgs, s.DeviceCategory)
}
if s.UserName != "" {
query += " AND secrets.UserName LIKE ? "
queryArgs = append(queryArgs, s.UserName)
}
// Query for user access
query += `
UNION
SELECT users.UserId AS UserUserId, permissions.*,
secrets.SecretId, secrets.SafeId, secrets.DeviceName, secrets.DeviceCategory, secrets.UserName
FROM users
INNER JOIN permissions ON users.UserId = permissions.UserId
INNER JOIN safes on permissions.SafeId = safes.SafeId
INNER JOIN secrets on secrets.SafeId = safes.SafeId
WHERE users.UserId = ?`
queryArgs = append(queryArgs, userId)
// Add any other arguments to the query if they were specified
if s.SecretId > 0 {
query += " AND SecretId = ? "
queryArgs = append(queryArgs, s.SecretId)
}
if s.DeviceName != "" {
query += " AND DeviceName LIKE ? "
queryArgs = append(queryArgs, s.DeviceName)
}
if s.DeviceCategory != "" {
query += " AND DeviceCategory LIKE ? "
queryArgs = append(queryArgs, s.DeviceCategory)
}
if s.UserName != "" {
query += " AND secrets.UserName LIKE ? "
queryArgs = append(queryArgs, s.UserName)
}
// Execute the query
log.Printf("SecretsGetAllowed query string : '%s'\nArguments:%+v\n", query, queryArgs)
rows, err := db.Queryx(query, queryArgs...)
if err != nil {
log.Printf("SecretsGetAllowed error executing sql record : '%s'\n", err)
return secretResults, err
} else {
//log.Printf("SecretsGetAllowed any error '%s'\n", rows.Err())
// parse all the results into a slice
for rows.Next() {
//log.Printf("SecretsGetAllowed processing row\n")
var r UserSecret
err = rows.StructScan(&r)
//log.Printf("SecretsGetAllowed performed struct scan\n")
if err != nil {
log.Printf("SecretsGetAllowed error parsing sql record : '%s'\n", err)
return secretResults, err
}
//log.Printf("r: %v\n", r)
//log.Printf("SecretsGetAllowed performed err check\n")
// work around to get the UserId populated in the User field of the struct
r.User.UserId = r.UserUserId
// For debugging purposes
debugPrint := utils.PrintStructContents(&r, 0)
log.Println(debugPrint)
//log.Printf("SecretsGetAllowed performed debug print\n")
// Append the secrets to the query output, don't decrypt the secrets (we didn't SELECT them anyway)
//secretResults = append(secretResults, r)
// Use generics and the GetID() method on the UserSecret struct
// to avoid adding this element to the results
// if there is already a secret with the same ID present
secretResults = utils.AppendIfNotExists(secretResults, r)
//log.Printf("SecretsGetAllowed added secret results\n")
}
log.Printf("SecretsGetAllowed retrieved '%d' results\n", len(secretResults))
}
return secretResults, nil
}
// SecretsGetFromMultipleSafes queries the specified safes for matching secrets
func SecretsGetFromMultipleSafes(s *Secret, safeIds []int) ([]Secret, error) {
var err error
var secretResults []Secret
queryArgs := []interface{}{}
var query string
// Generate placeholders for the IN clause to match multiple SafeId values
placeholders := make([]string, len(safeIds))
for i := range safeIds {
placeholders[i] = "?"
}
placeholderStr := strings.Join(placeholders, ",")
// Create query with the necessary placeholders
query = fmt.Sprintf("SELECT * FROM secrets WHERE SafeId IN (%s) ", placeholderStr)
// Add the Safe Ids to the arguments list
for _, g := range safeIds {
queryArgs = append(queryArgs, g)
}
// Add any other arguments to the query if they were specified
if s.SecretId > 0 {
query += " AND SecretId = ? "
queryArgs = append(queryArgs, s.SecretId)
}
if s.DeviceName != "" {
query += " AND DeviceName LIKE ? "
queryArgs = append(queryArgs, s.DeviceName)
}
if s.DeviceCategory != "" {
query += " AND DeviceCategory LIKE ? "
queryArgs = append(queryArgs, s.DeviceCategory)
}
if s.UserName != "" {
query += " AND UserName LIKE ? "
queryArgs = append(queryArgs, s.UserName)
}
// Execute the query
log.Printf("SecretsGetMultipleSafes query string :\n'%s'\nQuery Args : %+v\n", query, queryArgs)
rows, err := db.Queryx(query, queryArgs...)
if err != nil {
log.Printf("SecretsGetMultipleSafes error executing sql record : '%s'\n", err)
return secretResults, err
} else {
// parse all the results into a slice
for rows.Next() {
var r Secret
err = rows.StructScan(&r)
if err != nil {
log.Printf("SecretsGetMultipleSafes error parsing sql record : '%s'\n", err)
return secretResults, err
}
// Decrypt the secret
_, err = r.DecryptSecret()
if err != nil {
log.Printf("SecretsGetMultipleSafes unable to decrypt stored secret : '%s'\n", err)
rows.Close()
return secretResults, err
} else {
secretResults = append(secretResults, r)
}
}
log.Printf("SecretsGetMultipleSafes retrieved '%d' results\n", len(secretResults))
}
return secretResults, nil
}
func (s *Secret) UpdateSecret() (*Secret, error) {
var err error
// Populate timestamp field if not already set
if s.LastUpdated.IsZero() {
s.LastUpdated = time.Now().UTC()
}
log.Printf("UpdateSecret storing values '%v'\n", s)
if s.SecretId == 0 {
err = errors.New("UpdateSecret unable to locate secret with empty secretId field")
log.Printf("UpdateSecret error in pre-check : '%s'\n", err)
return s, err
}
result, err := db.NamedExec((`UPDATE secrets SET DeviceName = :DeviceName, DeviceCategory = :DeviceCategory, UserName = :UserName, Secret = :Secret, LastUpdated = :LastUpdated WHERE SecretId = :SecretId`), s)
if err != nil {
log.Printf("UpdateSecret error executing sql record : '%s'\n", err)
return &Secret{}, err
} else {
affected, _ := result.RowsAffected()
id, _ := result.LastInsertId()
log.Printf("UpdateSecret insert returned result id '%d' affecting %d row(s).\n", id, affected)
}
return s, nil
}
func (s *Secret) DeleteSecret() (*Secret, error) {
var err error
log.Printf("DeleteSecret deleting record with values '%v'\n", s)
if s.SecretId == 0 {
err = errors.New("unable to locate secret with empty secretId field")
log.Printf("DeleteSecret error in pre-check : '%s'\n", err)
return s, err
}
result, err := db.NamedExec((`DELETE FROM secrets WHERE SecretId = :SecretId`), s)
if err != nil {
log.Printf("DeleteSecret error executing sql record : '%s'\n", err)
return &Secret{}, err
} else {
affected, _ := result.RowsAffected()
id, _ := result.LastInsertId()
log.Printf("DeleteSecret delete returned result id '%d' affecting %d row(s).\n", id, affected)
}
return s, nil
}
// startCipher does the initial setup of the AES256 GCM mode cipher
func startCipher() (cipher.AEAD, error) {
key, err := ProvideKey()
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
log.Printf("startCipher NewCipher error '%s'\n", err)
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
log.Printf("startCipher NewGCM error '%s'\n", err)
return nil, err
}
return aesgcm, nil
}
func (s *Secret) EncryptSecret() (*Secret, error) {
//keyString := os.Getenv("SECRETS_KEY")
//keyString := secretKey
// The key argument should be the AES key, either 16 or 32 bytes
// to select AES-128 or AES-256.
//key := []byte(keyString)
/*
key, err := ProvideKey()
if err != nil {
return s, err
}
*/
plaintext := []byte(s.Secret)
// TODO : move block and aesgcm generation to separate function since the identical code is used for encrypt and decrypt
/*
log.Printf("EncryptSecret applying key '%v' of length '%d' to plaintext secret '%s'\n", key, len(key), s.Secret)
block, err := aes.NewCipher(key)
if err != nil {
log.Printf("EncryptSecret NewCipher error '%s'\n", err)
return s, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
log.Printf("EncryptSecret NewGCM error '%s'\n", err)
return s, err
}
*/
aesgcm, err := startCipher()
if err != nil {
log.Printf("EncryptSecret error commencing GCM cipher '%s'\n", err)
return s, err
}
// Never use more than 2^32 random nonces with a given key because of the risk of a repeat.
nonce := make([]byte, nonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
log.Printf("EncryptSecret nonce generation error '%s'\n", err)
return s, err
}
//log.Printf("EncryptSecret random nonce value is '%x'\n", nonce)
ciphertext := aesgcm.Seal(nil, nonce, plaintext, nil)
//log.Printf("EncryptSecret generated ciphertext '%x''\n", ciphertext)
// Create a new slice to store nonce at the start and then the resulting ciphertext
// Nonce is always 12 bytes
combinedText := append(nonce, ciphertext...)
//log.Printf("EncryptSecret combined secret value is now '%x'\n", combinedText)
// Store the value back into the struct ready for database operations
s.Secret = hex.EncodeToString(combinedText)
return s, nil
//return string(ciphertext[:]), nil
}
func (s *Secret) DecryptSecret() (*Secret, error) {
// The key argument should be the AES key, either 16 or 32 bytes
// to select AES-128 or AES-256.
//keyString := os.Getenv("SECRETS_KEY")
//keyString := secretKey
//key := []byte(keyString)
/*
key, err := ProvideKey()
if err != nil {
return s, err
}
*/
if len(s.Secret) < nonceSize {
log.Printf("DecryptSecret ciphertext is too short to decrypt\n")
return s, errors.New("ciphertext is too short")
}
crypted, err := hex.DecodeString(s.Secret)
if err != nil {
log.Printf("DecryptSecret unable to convert hex encoded string due to error '%s'\n", err)
return s, err
}
//log.Printf("DecryptSecret processing secret '%x'\n", crypted)
// The nonce is the first 12 bytes from the ciphertext
nonce := crypted[:nonceSize]
ciphertext := crypted[nonceSize:]
/*
log.Printf("DecryptSecret applying key '%v' and nonce '%x' to ciphertext '%x'\n", key, nonce, ciphertext)
block, err := aes.NewCipher(key)
if err != nil {
log.Printf("DecryptSecret NewCipher error '%s'\n", err)
return s, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
log.Printf("DecryptSecret NewGCM error '%s'\n", err)
return s, err
}
*/
aesgcm, err := startCipher()
if err != nil {
log.Printf("DecryptSecret error commencing GCM cipher '%s'\n", err)
return s, err
}
plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
log.Printf("DecryptSecret Open error '%s'\n", err)
return s, err
}
//log.Printf("DecryptSecret plaintext is '%s'\n", plaintext)
s.Secret = string(plaintext)
return s, nil
}

View File

@@ -1,527 +0,0 @@
package models
import (
"errors"
"fmt"
"log"
"os"
"reflect"
"smt/utils"
"github.com/jmoiron/sqlx"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
)
var db *sqlx.DB
const (
sqlFile = "smt.db"
)
const createUsers string = `
CREATE TABLE IF NOT EXISTS users (
UserId INTEGER PRIMARY KEY AUTOINCREMENT,
GroupId INTEGER,
UserName VARCHAR,
Password VARCHAR,
Admin BOOLEAN DEFAULT 0,
LdapUser BOOLEAN DEFAULT 0
);
`
const createSafes string = `
CREATE TABLE IF NOT EXISTS safes (
SafeId INTEGER PRIMARY KEY AUTOINCREMENT,
SafeName VARCHAR
);
`
const createGroups string = `
CREATE TABLE IF NOT EXISTS groups (
GroupId INTEGER PRIMARY KEY AUTOINCREMENT,
GroupName VARCHAR,
LdapGroup BOOLEAN DEFAULT 0,
LdapDn VARCHAR DEFAULT '',
Admin BOOLEAN DEFAULT 0
);
`
const createPermissions = `
CREATE TABLE IF NOT EXISTS permissions (
PermissionId INTEGER PRIMARY KEY AUTOINCREMENT,
Description VARCHAR DEFAULT '',
ReadOnly BOOLEAN DEFAULT 0,
SafeId INTEGER,
UserId INTEGER DEFAULT 0,
GroupId INTEGER DEFAULT 0,
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`
const createSecrets string = `
CREATE TABLE IF NOT EXISTS secrets (
SecretId INTEGER PRIMARY KEY AUTOINCREMENT,
SafeId INTEGER,
DeviceName VARCHAR,
DeviceCategory VARCHAR,
UserName VARCHAR,
Secret VARCHAR,
LastUpdated datetime DEFAULT (datetime('1970-01-01 00:00:00')),
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
`
const createSchema string = `
CREATE TABLE IF NOT EXISTS schema (
Version INTEGER
);
`
const createAudit string = `
CREATE TABLE IF NOT EXISTS audit (
AuditId INTEGER PRIMARY KEY AUTOINCREMENT,
UserId INTEGER DEFAULT 0,
SecretId INTEGER DEFAULT 0,
EventText VARCHAR,
EventTime datetime
);
`
// Establish connection to sqlite database
func ConnectDatabase() {
var err error
// Try using sqlite as our database
sqlPath := utils.GetFilePath(sqlFile)
db, err = sqlx.Open("sqlite", sqlPath)
if err != nil {
log.Printf("Error opening sqlite database connection to file '%s' : '%s'\n", sqlPath, err)
os.Exit(1)
} else {
log.Printf("Connected to sqlite database file '%s'\n", sqlPath)
}
//sqlx.NameMapper = func(s string) string { return s }
// Make sure our tables exist
CreateTables()
//defer db.Close()
}
func DisconnectDatabase() {
log.Printf("DisconnectDatabase called")
defer db.Close()
}
func CreateTables() {
var err error
var rowCount int
// Create database tables if it doesn't exist
/*
// Roles table should go first since other tables refer to it
if _, err = db.Exec(createRoles); err != nil {
log.Printf("Error checking roles table : '%s'", err)
os.Exit(1)
}
rowCount, _ = CheckCount("roles")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO roles VALUES(1, 'Admin', false);"); err != nil {
log.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO roles VALUES(2, 'UserRole', false);"); err != nil {
log.Printf("Error adding initial user role : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO roles VALUES(3, 'GuestRole', true);"); err != nil {
log.Printf("Error adding initial guest role : '%s'", err)
os.Exit(1)
}
}
*/
// groups table
if _, err = db.Exec(createGroups); err != nil {
log.Printf("Error checking groups table : '%s'", err)
os.Exit(1)
}
// Add initial groups
rowCount, _ = CheckCount("groups")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(1, 'Administrators', 1);"); err != nil {
log.Printf("Error adding initial group entry id 1 : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO groups (GroupId, GroupName, Admin) VALUES(2, 'Users', 0);"); err != nil {
log.Printf("Error adding initial group entry id 2 : '%s'", err)
os.Exit(1)
}
}
// Users table
if _, err = db.Exec(createUsers); err != nil {
log.Printf("Error checking users table : '%s'", err)
os.Exit(1)
}
rowCount, _ = CheckCount("users")
if rowCount == 0 {
// Check if there was an initial password defined in the .env file
initialPassword := os.Getenv("INITIAL_PASSWORD")
if initialPassword == "" {
initialPassword = "password"
} else if initialPassword[:4] == "$2a$" {
log.Printf("CreateTables inital admin password is already a hash")
} else {
cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
initialPassword = string(cryptText)
}
if _, err = db.Exec("INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(1, 1, 'Administrator', ?, false, true);", initialPassword); err != nil {
log.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO users (UserId, GroupId, UserName, Password, LdapUser, Admin) VALUES(2, 2, 'User', ?, false, false);", initialPassword); err != nil {
log.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
}
// Safes table
if _, err = db.Exec(createSafes); err != nil {
log.Printf("Error checking safes table : '%s'", err)
os.Exit(1)
}
// Create an initial safe
rowCount, _ = CheckCount("safes")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO safes VALUES(1, 'Default Safe');"); err != nil {
log.Printf("Error adding initial safe entry : '%s'", err)
os.Exit(1)
}
}
// Secrets table
if _, err = db.Exec(createSecrets); err != nil {
log.Printf("Error checking secrets table : '%s'", err)
os.Exit(1)
}
// permissions table
if _, err = db.Exec(createPermissions); err != nil {
log.Printf("Error checking permissions table : '%s'", err)
os.Exit(1)
}
// Add initial permissions
rowCount, _ = CheckCount("permissions")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO permissions (Description, ReadOnly, GroupId, SafeId) VALUES('Default Admin Group Permission', false, 1, 1);"); err != nil {
log.Printf("Error adding initial permissions entry userid 1 : '%s'", err)
os.Exit(1)
}
if _, err = db.Exec("INSERT INTO permissions (Description, ReadOnly, SafeId, GroupId) VALUES('Default User Group Permission', false, 1, 2);"); err != nil {
log.Printf("Error adding initial permissions entry userid 2 : '%s'", err)
os.Exit(1)
}
}
// Schema table should go last so we know if the database has a value in the schema table then everything was created properly
if _, err = db.Exec(createSchema); err != nil {
log.Printf("Error checking schema table : '%s'", err)
os.Exit(1)
}
schemaCheck, _ := CheckColumnExists("schema", "Version")
if !schemaCheck {
if _, err = db.Exec("INSERT INTO schema VALUES(2);"); err != nil {
log.Printf("Error adding initial schema version : '%s'", err)
os.Exit(1)
}
}
// Audit log table
if _, err = db.Exec(createAudit); err != nil {
log.Printf("Error checking audit table : '%s'", err)
os.Exit(1)
}
// Remove users RoleId column
userRoleIdCheck, _ := CheckColumnExists("users", "RoleId")
if userRoleIdCheck {
//_, err := db.Exec("ALTER TABLE users DROP COLUMN RoleId;")
_, err := db.Exec(`
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE users RENAME TO _users_old;
CREATE TABLE users
(
UserId INTEGER PRIMARY KEY AUTOINCREMENT,
GroupId INTEGER,
UserName VARCHAR,
Password VARCHAR,
Admin BOOLEAN DEFAULT 0,
LdapUser BOOLEAN DEFAULT 0
);
INSERT INTO users SELECT * FROM _users_old;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _users_old;
`)
if err != nil {
log.Printf("Error altering users table to drop RoleId column : '%s'\n", err)
os.Exit(1)
}
}
// Set any unassigned secrets to the default safe id
if _, err = db.Exec("UPDATE users SET LdapUser = 0 WHERE LdapUser is null;"); err != nil {
log.Printf("Error setting LdapUser flag to false for existing users : '%s'", err)
os.Exit(1)
}
// Remove LdapGroup column from roles table
ldapCheck, _ := CheckColumnExists("roles", "LdapGroup")
if ldapCheck {
_, err := db.Exec("ALTER TABLE roles DROP COLUMN LdapGroup;")
if err != nil {
log.Printf("Error altering roles table to renmove LdapGroup column : '%s'\n", err)
os.Exit(1)
}
}
// Add SafeId column to secrets table
safeIdCheck, _ := CheckColumnExists("secrets", "SafeId")
if !safeIdCheck {
// Add the column for LdapGroup in the roles table
_, err := db.Exec("ALTER TABLE secrets ADD COLUMN SafeId INTEGER REFERENCES safes(SafeId);")
if err != nil {
log.Printf("Error altering secrets table to add SafeId column : '%s'\n", err)
os.Exit(1)
}
}
// Set any unassigned secrets to the default safe id
if _, err = db.Exec("UPDATE secrets SET SafeId = 1 WHERE SafeId is null;"); err != nil {
log.Printf("Error setting safe ID of existing secrets : '%s'", err)
os.Exit(1)
}
// Remove RoleId column from secrets table
secretsRoleIdCheck, _ := CheckColumnExists("secrets", "RoleId")
if secretsRoleIdCheck {
_, err := db.Exec(`
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE secrets RENAME TO _secrets_old;
CREATE TABLE secrets
(
SecretId INTEGER PRIMARY KEY AUTOINCREMENT,
RoleId INTEGER,
SafeId INTEGER,
DeviceName VARCHAR,
DeviceCategory VARCHAR,
UserName VARCHAR,
Secret VARCHAR,
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
INSERT INTO secrets SELECT SecretId, RoleId, SafeId, DeviceName, DeviceCategory, UserName, Secret FROM _secrets_old;
ALTER TABLE secrets DROP COLUMN RoleId;
ALTER TABLE secrets ADD COLUMN LastUpdated datetime;
UPDATE secrets SET LastUpdated = (datetime('1970-01-01 00:00:00')) WHERE LastUpdated is null;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _secrets_old;
`)
if err != nil {
log.Printf("Error altering secrets table to remove RoleId column : '%s'\n", err)
os.Exit(1)
}
}
// Remove the Admin column from roles table
rolesAdminCheck, _ := CheckColumnExists("roles", "Admin")
if rolesAdminCheck {
_, err := db.Exec("ALTER TABLE roles DROP COLUMN Admin;")
if err != nil {
log.Printf("Error altering roles table to remove Admin column : '%s'\n", err)
os.Exit(1)
}
}
// Remove the RoleId from permissiosn table
permissionsRoleIdCheck, _ := CheckColumnExists("permissions", "RoleId")
if permissionsRoleIdCheck {
_, err := db.Exec(`
PRAGMA foreign_keys=off;
BEGIN TRANSACTION;
ALTER TABLE permissions RENAME TO _permissions_old;
CREATE TABLE permissions
(
PermissionId INTEGER PRIMARY KEY AUTOINCREMENT,
Description VARCHAR DEFAULT '',
ReadOnly BOOLEAN DEFAULT 0,
SafeId INTEGER,
UserId INTEGER DEFAULT 0,
GroupId INTEGER DEFAULT 0,
FOREIGN KEY (SafeId) REFERENCES safes(SafeId)
);
INSERT INTO permissions SELECT PermissionId, SafeId, UserId, GroupId, '' AS Description, 0 as ReadOnly FROM _permissions_old;
UPDATE permissions SET ReadOnly = 0 WHERE ReadOnly is null;
UPDATE permissions SET Description = '' WHERE Description is null;
UPDATE permissions SET UserId = 0 WHERE UserId is null;
UPDATE permissions SET GroupId = 0 WHERE GroupId is null;
COMMIT;
PRAGMA foreign_keys=on;
DROP TABLE _permissions_old;
`)
if err != nil {
log.Printf("Error altering permissions table to remove RoleId column : '%s'\n", err)
os.Exit(1)
}
}
secretsLastUpdatedCheck, _ := CheckColumnExists("secrets", "LastUpdated")
if !secretsLastUpdatedCheck {
// Add the column for LastUpdated in the secrets table
_, err := db.Exec("ALTER TABLE secrets ADD COLUMN LastUpdated datetime;")
if err != nil {
log.Printf("Error altering secrets table to add LastUpdated column : '%s'\n", err)
os.Exit(1)
}
// Set the default value
if _, err = db.Exec("UPDATE secrets SET LastUpdated = (datetime('1970-01-01 00:00:00')) WHERE LastUpdated is null;"); err != nil {
log.Printf("Error setting LastUpdated of existing secrets : '%s'", err)
os.Exit(1)
}
}
/*
// Database updates added after initial version released
ldapCheck, _ := CheckColumnExists("roles", "LdapGroup")
if !ldapCheck {
// Add the column for LdapGroup in the roles table
_, err := db.Exec("ALTER TABLE roles ADD COLUMN LdapGroup VARCHAR DEFAULT '';")
if err != nil {
log.Printf("Error altering roles table to add LdapGroup column : '%s'\n", err)
os.Exit(1)
}
}
// Add the two LDAP columns to the users table if they weren't there
ldapUserCheck, _ := CheckColumnExists("users", "LdapUser")
if !ldapUserCheck {
log.Printf("CreateTables creating ldap columns in user table")
_, err := db.Exec("ALTER TABLE users ADD COLUMN LdapUser BOOLEAN DEFAULT 0;")
if err != nil {
log.Printf("Error altering users table to add LdapUser column : '%s'\n", err)
os.Exit(1)
}
_, err = db.Exec("ALTER TABLE users ADD COLUMN LdapDn VARCHAR DEFAULT '';")
if err != nil {
log.Printf("Error altering users table to add LdapDn column : '%s'\n", err)
os.Exit(1)
}
}
*/
}
// Count the number of records in the sqlite database
// Borrowed from https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592
func CheckCount(tablename string) (int, error) {
var count int
stmt, err := db.Prepare("SELECT COUNT(*) as count FROM " + tablename)
if err != nil {
log.Printf("CheckCount error preparing sqlite statement : '%s'\n", err)
return 0, err
}
err = stmt.QueryRow().Scan(&count)
if err != nil {
log.Printf("CheckCount error querying database record count : '%s'\n", err)
return 0, err
}
stmt.Close() // or use defer rows.Close(), idc
return count, nil
}
// From https://stackoverflow.com/a/60100045
func GenerateInsertMethod(q interface{}) (string, error) {
if reflect.ValueOf(q).Kind() == reflect.Struct {
query := fmt.Sprintf("INSERT INTO %s", reflect.TypeOf(q).Name())
fieldNames := ""
fieldValues := ""
v := reflect.ValueOf(q)
for i := 0; i < v.NumField(); i++ {
if i == 0 {
fieldNames = fmt.Sprintf("%s%s", fieldNames, v.Type().Field(i).Name)
} else {
fieldNames = fmt.Sprintf("%s, %s", fieldNames, v.Type().Field(i).Name)
}
switch v.Field(i).Kind() {
case reflect.Int:
if i == 0 {
fieldValues = fmt.Sprintf("%s%d", fieldValues, v.Field(i).Int())
} else {
fieldValues = fmt.Sprintf("%s, %d", fieldValues, v.Field(i).Int())
}
case reflect.String:
if i == 0 {
fieldValues = fmt.Sprintf("%s\"%s\"", fieldValues, v.Field(i).String())
} else {
fieldValues = fmt.Sprintf("%s, \"%s\"", fieldValues, v.Field(i).String())
}
case reflect.Bool:
var boolSet int8
if v.Field(i).Bool() {
boolSet = 1
}
if i == 0 {
fieldValues = fmt.Sprintf("%s%d", fieldValues, boolSet)
} else {
fieldValues = fmt.Sprintf("%s, %d", fieldValues, boolSet)
}
default:
log.Printf("Unsupported type '%s'\n", v.Field(i).Kind())
}
}
query = fmt.Sprintf("%s(%s) VALUES (%s)", query, fieldNames, fieldValues)
return query, nil
}
return "", errors.New("SqlGenerationError")
}
func CheckColumnExists(table string, column string) (bool, error) {
var count int64
rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM pragma_table_info('" + table + "') WHERE name='" + column + "';")
if err != nil {
log.Printf("CheckColumnExists error querying database for existence of column '%s' : '%s'\n", column, err)
return false, err
}
defer rows.Close()
for rows.Next() {
// cols is an []interface{} of all of the column results
cols, _ := rows.SliceScan()
log.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column)
count = cols[0].(int64)
if count == 1 {
return true, nil
} else {
return false, nil
}
}
err = rows.Err()
if err != nil {
log.Printf("CheckColumnExists error getting results : '%s'\n", err)
return false, err
}
return false, nil
}

View File

@@ -7,18 +7,19 @@ import (
"log" "log"
"smt/utils" "smt/utils"
"smt/utils/token" "smt/utils/token"
"time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
type User struct { type User struct {
UserId int `db:"UserId" json:"userId"` UserId int `db:"UserId" json:"userId"`
GroupId int `db:"GroupId" json:"groupId"` GroupId int `db:"GroupId" json:"groupId"`
UserName string `db:"UserName" json:"userName"` UserName string `db:"UserName" json:"userName"`
Password string `db:"Password" json:"-"` Password string `db:"Password" json:"-"`
LdapUser bool `db:"LdapUser" json:"ldapUser"` LdapUser bool `db:"LdapUser" json:"ldapUser"`
Admin bool `db:"Admin"` Admin bool `db:"Admin"`
//LdapDn string `db:"LdapDn" json:"ldapDn"` LastLogin time.Time `db:"LastLogin" json:"lastLogin"`
} }
type UserRole struct { type UserRole struct {
@@ -207,6 +208,8 @@ func LoginCheck(username string, password string) (string, error) {
return "", err return "", err
} }
u.UserSetLastLogin()
return token, nil return token, nil
} }
@@ -286,15 +289,23 @@ func UserLdapNewLoginCheck(username string, password string) (User, error) {
return u, nil return u, nil
} }
/* func (u *User) UserSetLastLogin() error {
// StoreLdapUser creates a user record in the database and returns the corresponding userId
func StoreLdapUser(u *User) error {
// TODO u.LastLogin = time.Now().UTC()
result, err := db.NamedExec((`UPDATE users SET LastLogin = :LastLogin WHERE UserId = :UserId`), u)
if err != nil {
log.Printf("UserSetLastLogin error executing sql update : '%s'\n", err)
return err
} else {
affected, _ := result.RowsAffected()
id, _ := result.LastInsertId()
log.Printf("UserSetLastLogin returned result id '%d' affecting %d row(s).\n", id, affected)
}
return nil return nil
} }
*/
func UserGetByID(uid uint) (User, error) { func UserGetByID(uid uint) (User, error) {