diff --git a/db/db.go b/db/db.go index 2156f88..2ee6d51 100644 --- a/db/db.go +++ b/db/db.go @@ -5,6 +5,7 @@ import ( "embed" "fmt" "log/slog" + "os" "reflect" "strings" @@ -40,6 +41,10 @@ func New(logger *slog.Logger, cfg Config) (Database, error) { } return db, nil case "postgres": + // The sqlc query set is SQLite-first. Keep Postgres opt-in until full parity is validated. + if strings.TrimSpace(os.Getenv("VCTP_ENABLE_EXPERIMENTAL_POSTGRES")) != "1" { + return nil, fmt.Errorf("postgres driver is disabled by default; set VCTP_ENABLE_EXPERIMENTAL_POSTGRES=1 to enable experimental mode") + } db, err := newPostgresDB(logger, cfg.DSN) if err != nil { return nil, err diff --git a/internal/secrets/secrets.go b/internal/secrets/secrets.go index 3577f8a..8ef4bad 100644 --- a/internal/secrets/secrets.go +++ b/internal/secrets/secrets.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "crypto/rand" "encoding/base64" + "fmt" "io" "log/slog" ) @@ -68,6 +69,9 @@ func (s *Secrets) Decrypt(base64CipherText string) ([]byte, error) { // Extract the nonce from the ciphertext nonceSize := gcm.NonceSize() + if len(cipherText) < nonceSize { + return nil, fmt.Errorf("ciphertext is too short") + } nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:] // Decrypt the ciphertext diff --git a/internal/tasks/dailyAggregate.go b/internal/tasks/dailyAggregate.go index ec4a1cc..24bfc01 100644 --- a/internal/tasks/dailyAggregate.go +++ b/internal/tasks/dailyAggregate.go @@ -294,13 +294,14 @@ func (c *CronTask) aggregateDailySummaryGo(ctx context.Context, dayStart, dayEnd // Get the first hourly snapshot on/after dayEnd to help confirm deletions that happen on the last snapshot of the day. var nextSnapshotTable string - nextSnapshotRows, nextErr := c.Database.DB().QueryxContext(ctx, ` + nextSnapshotQuery := dbConn.Rebind(` SELECT table_name FROM snapshot_registry WHERE snapshot_type = 'hourly' AND snapshot_time >= ? ORDER BY snapshot_time ASC LIMIT 1 -`, dayEnd.Unix()) +`) + nextSnapshotRows, nextErr := c.Database.DB().QueryxContext(ctx, nextSnapshotQuery, dayEnd.Unix()) if nextErr == nil { if nextSnapshotRows.Next() { if scanErr := nextSnapshotRows.Scan(&nextSnapshotTable); scanErr != nil { @@ -1087,11 +1088,11 @@ LIMIT %d } else { for rows.Next() { var ( - vcenter string - vmId, vmUuid sql.NullString - name sql.NullString - samplesPresent, snapshotTime sql.NullInt64 - avgIsPresent sql.NullFloat64 + vcenter string + vmId, vmUuid sql.NullString + name sql.NullString + samplesPresent, snapshotTime sql.NullInt64 + avgIsPresent sql.NullFloat64 ) if err := rows.Scan(&vcenter, &vmId, &vmUuid, &name, &samplesPresent, &avgIsPresent, &snapshotTime); err != nil { continue diff --git a/internal/tasks/inventoryHelpers.go b/internal/tasks/inventoryHelpers.go index 46f920a..51152a7 100644 --- a/internal/tasks/inventoryHelpers.go +++ b/internal/tasks/inventoryHelpers.go @@ -85,13 +85,14 @@ func listLatestHourlyWithRows(ctx context.Context, dbConn *sqlx.DB, vcenter stri if limit <= 0 { limit = 50 } - rows, err := dbConn.QueryxContext(ctx, ` + query := dbConn.Rebind(` SELECT table_name, snapshot_time, snapshot_count FROM snapshot_registry WHERE snapshot_type = 'hourly' AND snapshot_time < ? ORDER BY snapshot_time DESC LIMIT ? -`, beforeUnix, limit) +`) + rows, err := dbConn.QueryxContext(ctx, query, beforeUnix, limit) if err != nil { return nil, err } diff --git a/internal/tasks/inventoryLifecycle.go b/internal/tasks/inventoryLifecycle.go index e574943..9d7faac 100644 --- a/internal/tasks/inventoryLifecycle.go +++ b/internal/tasks/inventoryLifecycle.go @@ -103,11 +103,12 @@ type lifecycleCandidate struct { } func loadLifecycleCandidates(ctx context.Context, dbConn *sqlx.DB, vcenter string, present map[string]InventorySnapshotRow) ([]lifecycleCandidate, error) { - rows, err := dbConn.QueryxContext(ctx, ` + query := dbConn.Rebind(` SELECT "VmId","VmUuid","Name","Cluster" FROM vm_lifecycle_cache WHERE "Vcenter" = ? AND ("DeletedAt" IS NULL OR "DeletedAt" = 0) -`, vcenter) +`) + rows, err := dbConn.QueryxContext(ctx, query, vcenter) if err != nil { return nil, err } @@ -143,12 +144,13 @@ type snapshotTable struct { func listHourlyTablesForDay(ctx context.Context, dbConn *sqlx.DB, dayStart, dayEnd time.Time) ([]snapshotTable, error) { log := loggerFromCtx(ctx, nil) - rows, err := dbConn.QueryxContext(ctx, ` + query := dbConn.Rebind(` SELECT table_name, snapshot_time, snapshot_count FROM snapshot_registry WHERE snapshot_type = 'hourly' AND snapshot_time >= ? AND snapshot_time < ? ORDER BY snapshot_time ASC -`, dayStart.Unix(), dayEnd.Unix()) +`) + rows, err := dbConn.QueryxContext(ctx, query, dayStart.Unix(), dayEnd.Unix()) if err != nil { return nil, err } @@ -176,13 +178,14 @@ ORDER BY snapshot_time ASC } func nextSnapshotAfter(ctx context.Context, dbConn *sqlx.DB, after time.Time, vcenter string) (string, error) { - rows, err := dbConn.QueryxContext(ctx, ` + query := dbConn.Rebind(` SELECT table_name FROM snapshot_registry WHERE snapshot_type = 'hourly' AND snapshot_time >= ? ORDER BY snapshot_time ASC LIMIT 1 -`, after.Unix()) +`) + rows, err := dbConn.QueryxContext(ctx, query, after.Unix()) if err != nil { return "", err } diff --git a/internal/tasks/monitorVcenter.go b/internal/tasks/monitorVcenter.go index ded9847..ac9af68 100644 --- a/internal/tasks/monitorVcenter.go +++ b/internal/tasks/monitorVcenter.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "errors" - "fmt" "log/slog" "strings" "time" @@ -16,103 +15,11 @@ import ( "github.com/vmware/govmomi/vim25/types" ) -// use gocron to check vcenters for VMs or updates we don't know about +// RunVcenterPoll is intentionally disabled. +// The legacy inventory polling flow has been retired in favor of hourly snapshots. func (c *CronTask) RunVcenterPoll(ctx context.Context, logger *slog.Logger) error { - startedAt := time.Now() - defer func() { - logger.Info("Vcenter poll job finished", "duration", time.Since(startedAt)) - }() - var matchFound bool - - // reload settings in case vcenter list has changed - c.Settings.ReadYMLSettings() - - for _, url := range c.Settings.Values.Settings.VcenterAddresses { - c.Logger.Debug("connecting to vcenter", "url", url) - vc := vcenter.New(c.Logger, c.VcCreds) - vc.Login(url) - - // Get list of VMs from vcenter - vcVms, err := vc.GetAllVmReferences() - - // Get list of VMs from inventory table - c.Logger.Debug("Querying inventory table") - results, err := c.Database.Queries().GetInventoryByVcenter(ctx, url) - if err != nil { - c.Logger.Error("Unable to query inventory table", "error", err) - return err - } - - if len(results) == 0 { - c.Logger.Error("Empty inventory results") - return fmt.Errorf("Empty inventory results") - } - - // Iterate VMs from vcenter and see if they were in the database - for _, vm := range vcVms { - matchFound = false - - // Skip any vCLS VMs - if strings.HasPrefix(vm.Name(), "vCLS-") { - //c.Logger.Debug("Skipping internal VM", "vm_name", vm.Name()) - continue - } - - // TODO - should we compare the UUID as well? - for _, dbvm := range results { - if dbvm.VmId.String == vm.Reference().Value { - //c.Logger.Debug("Found match for VM", "vm_name", dbvm.Name, "id", dbvm.VmId.String) - matchFound = true - - // Get the full VM object - vmObj, err := vc.ConvertObjToMoVM(vm) - if err != nil { - c.Logger.Error("Failed to find VM in vcenter", "vm_id", dbvm.VmId.String, "error", err) - continue - } - - if vmObj.Config == nil { - c.Logger.Error("VM has no config properties", "vm_id", dbvm.VmId.String, "vm_name", vmObj.Name) - continue - } - - // Check that this is definitely the right VM - if dbvm.VmUuid.String == vmObj.Config.Uuid { - // TODO - compare database against current values, create update record if not matching - err = c.UpdateVmInventory(vmObj, vc, ctx, dbvm) - } else { - c.Logger.Error("VM uuid doesn't match database record", "vm_name", dbvm.Name, "id", dbvm.VmId.String, "vc_uuid", vmObj.Config.Uuid, "db_uuid", dbvm.VmUuid.String) - } - - break - } - } - - if !matchFound { - c.Logger.Debug("Need to add VM to inventory table", "MoRef", vm.Reference()) - vmObj, err := vc.ConvertObjToMoVM(vm) - if err != nil { - c.Logger.Error("Received error getting vm maangedobject", "error", err) - continue - } - - // retrieve VM properties and insert into inventory - err = c.AddVmToInventory(vmObj, vc, ctx) - if err != nil { - c.Logger.Error("Received error with VM add", "error", err) - continue - } - - // add sleep to slow down mass VM additions - utils.SleepWithContext(ctx, (10 * time.Millisecond)) - } - } - c.Logger.Debug("Finished checking vcenter", "url", url) - _ = vc.Logout(ctx) - } - - c.Logger.Debug("Finished polling vcenters") - + _ = ctx + logger.Info("legacy vcenter polling task is disabled") return nil } diff --git a/internal/tasks/processEvents.go b/internal/tasks/processEvents.go index 7985d24..4549943 100644 --- a/internal/tasks/processEvents.go +++ b/internal/tasks/processEvents.go @@ -2,217 +2,12 @@ package tasks import ( "context" - "database/sql" "log/slog" - "strings" - "time" - "vctp/db/queries" - "vctp/internal/vcenter" - - "github.com/vmware/govmomi/vim25/types" ) -// use gocron to check events in the Events table -func (c *CronTask) RunVmCheck(ctx context.Context, logger *slog.Logger) error { - startedAt := time.Now() - defer func() { - logger.Info("Event processing job finished", "duration", time.Since(startedAt)) - }() - var ( - numVcpus int32 - numRam int32 - totalDiskGB float64 - srmPlaceholder string - foundVm bool - isTemplate string - poweredOn string - folderPath string - rpName string - vmUuid string - ) - - dateCmp := time.Now().AddDate(0, 0, -1).Unix() - logger.Debug("Started Events processing", "time", time.Now(), "since", dateCmp) - - // Query events table - events, err := c.Database.Queries().ListUnprocessedEvents(ctx, - sql.NullInt64{Int64: dateCmp, Valid: dateCmp > 0}) - if err != nil { - logger.Error("Unable to query for unprocessed events", "error", err) - return nil // TODO - what to do with this error? - } else { - logger.Debug("Successfully queried for unprocessed events", "count", len(events)) - } - - for _, evt := range events { - logger.Debug("Checking event", "event", evt) - - // TODO - get a list of unique vcenters, then process each event in batches - // to avoid doing unnecessary login/logout of vcenter - - //c.Logger.Debug("connecting to vcenter") - vc := vcenter.New(c.Logger, c.VcCreds) - vc.Login(evt.Source) - - //datacenter = evt.DatacenterName.String - vmObject, err := vc.FindVMByIDWithDatacenter(evt.VmId.String, evt.DatacenterId.String) - - if err != nil { - c.Logger.Error("Can't locate vm in vCenter", "vmID", evt.VmId.String, "error", err) - continue - } else if vmObject == nil { - c.Logger.Debug("didn't find VM", "vm_id", evt.VmId.String) - - // TODO - if VM name ends with -tmp or -phVm then we mark this record as processed and stop trying to find a VM that doesnt exist anymore - - if strings.HasSuffix(evt.VmName.String, "-phVm") || strings.HasSuffix(evt.VmName.String, "-tmp") { - c.Logger.Info("VM name indicates temporary VM, marking as processed", "vm_name", evt.VmName.String) - - err = c.Database.Queries().UpdateEventsProcessed(ctx, evt.Eid) - if err != nil { - c.Logger.Error("Unable to mark this event as processed", "event_id", evt.Eid, "error", err) - } else { - //c.Logger.Debug("Marked event as processed", "event_id", evt.Eid) - } - } - - /* - numRam = 0 - numVcpus = 0 - totalDiskGB = 0 - isTemplate = "FALSE" - folderPath = "" - vmUuid = "" - */ - continue - } - - if strings.HasPrefix(vmObject.Name, "vCLS-") { - c.Logger.Info("Skipping internal vCLS VM event", "vm_name", vmObject.Name) - if err := c.Database.Queries().UpdateEventsProcessed(ctx, evt.Eid); err != nil { - c.Logger.Error("Unable to mark vCLS event as processed", "event_id", evt.Eid, "error", err) - } - continue - } - - //c.Logger.Debug("found VM") - srmPlaceholder = "FALSE" // Default assumption - //prettyPrint(vmObject) - - // calculate VM properties we want to store - if vmObject.Config != nil { - numRam = vmObject.Config.Hardware.MemoryMB - numVcpus = vmObject.Config.Hardware.NumCPU - vmUuid = vmObject.Config.Uuid - - var totalDiskBytes int64 - - // Calculate the total disk allocated in GB - for _, device := range vmObject.Config.Hardware.Device { - if disk, ok := device.(*types.VirtualDisk); ok { - - // Print the filename of the backing device - if _, ok := disk.Backing.(*types.VirtualDiskFlatVer2BackingInfo); ok { - //c.Logger.Debug("Adding disk", "size_bytes", disk.CapacityInBytes, "backing_file", backing.FileName) - } else { - //c.Logger.Debug("Adding disk, unknown backing type", "size_bytes", disk.CapacityInBytes) - } - - totalDiskBytes += disk.CapacityInBytes - //totalDiskGB += float64(disk.CapacityInBytes / 1024 / 1024 / 1024) // Convert from bytes to GB - } - } - totalDiskGB = float64(totalDiskBytes / 1024 / 1024 / 1024) - c.Logger.Debug("Converted total disk size", "bytes", totalDiskBytes, "GB", totalDiskGB) - - // Determine if the VM is a normal VM or an SRM placeholder - if vmObject.Config.ManagedBy != nil && vmObject.Config.ManagedBy.ExtensionKey == "com.vmware.vcDr" { - if vmObject.Config.ManagedBy.Type == "placeholderVm" { - c.Logger.Debug("VM is a placeholder") - srmPlaceholder = "TRUE" - } else { - c.Logger.Debug("VM is managed by SRM but not a placeholder", "details", vmObject.Config.ManagedBy) - } - } - - if vmObject.Config.Template { - isTemplate = "TRUE" - } else { - isTemplate = "FALSE" - } - - // Retrieve the full folder path of the VM - folderPath, err = vc.GetVMFolderPath(*vmObject) - if err != nil { - c.Logger.Error("failed to get vm folder path", "error", err) - folderPath = "" - } else { - c.Logger.Debug("Found vm folder path", "folder_path", folderPath) - } - - // Retrieve the resource pool of the VM - rpName, _ = vc.GetVmResourcePool(*vmObject) - - foundVm = true - } else { - c.Logger.Error("Empty VM config") - } - - //c.Logger.Debug("VM has runtime data", "power_state", vmObject.Runtime.PowerState) - if vmObject.Runtime.PowerState == "poweredOff" { - poweredOn = "FALSE" - } else { - poweredOn = "TRUE" - } - - _ = vc.Logout(ctx) - - if foundVm { - c.Logger.Debug("Adding to Inventory table", "vm_name", evt.VmName.String, "vcpus", numVcpus, "ram", numRam, "dc", evt.DatacenterId.String) - - params := queries.CreateInventoryParams{ - Name: vmObject.Name, - Vcenter: evt.Source, - CloudId: sql.NullString{String: evt.CloudId, Valid: evt.CloudId != ""}, - EventKey: sql.NullString{String: evt.EventKey.String, Valid: evt.EventKey.Valid}, - VmId: sql.NullString{String: evt.VmId.String, Valid: evt.VmId.Valid}, - Datacenter: sql.NullString{String: evt.DatacenterName.String, Valid: evt.DatacenterName.Valid}, - Cluster: sql.NullString{String: evt.ComputeResourceName.String, Valid: evt.ComputeResourceName.Valid}, - CreationTime: sql.NullInt64{Int64: evt.EventTime.Int64, Valid: evt.EventTime.Valid}, - InitialVcpus: sql.NullInt64{Int64: int64(numVcpus), Valid: numVcpus > 0}, - InitialRam: sql.NullInt64{Int64: int64(numRam), Valid: numRam > 0}, - ProvisionedDisk: sql.NullFloat64{Float64: totalDiskGB, Valid: totalDiskGB > 0}, - Folder: sql.NullString{String: folderPath, Valid: folderPath != ""}, - ResourcePool: sql.NullString{String: rpName, Valid: rpName != ""}, - VmUuid: sql.NullString{String: vmUuid, Valid: vmUuid != ""}, - SrmPlaceholder: srmPlaceholder, - IsTemplate: isTemplate, - PoweredOn: poweredOn, - } - - //c.Logger.Debug("database params", "params", params) - - // Insert the new inventory record into the database - _, err := c.Database.Queries().CreateInventory(ctx, params) - if err != nil { - c.Logger.Error("unable to perform database insert", "error", err) - } else { - //c.Logger.Debug("created database record", "insert_result", result) - - // mark this event as processed - err = c.Database.Queries().UpdateEventsProcessed(ctx, evt.Eid) - if err != nil { - c.Logger.Error("Unable to mark this event as processed", "event_id", evt.Eid, "error", err) - } else { - //c.Logger.Debug("Marked event as processed", "event_id", evt.Eid) - } - } - } else { - c.Logger.Debug("Not adding to Inventory due to missing vcenter config property", "vm_name", evt.VmName.String) - } - - } - - //fmt.Printf("processing at %s", time.Now()) +// RunVmCheck is intentionally disabled. +// The legacy event-processing flow has been retired in favor of snapshot-based lifecycle logic. +func (c *CronTask) RunVmCheck(_ context.Context, logger *slog.Logger) error { + logger.Info("legacy VM event-processing task is disabled") return nil } diff --git a/main.go b/main.go index 300b89e..169ca6b 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "flag" "fmt" @@ -34,7 +35,11 @@ var ( cronAggregateFrequency time.Duration ) -const fallbackEncryptionKey = "5L1l3B5KvwOCzUHMAlCgsgUTRAYMfSpa" +const ( + encryptedVcenterPasswordPrefix = "enc:v1:" + encryptionKeyEnvVar = "VCTP_ENCRYPTION_KEY" + legacyFallbackEncryptionKey = "5L1l3B5KvwOCzUHMAlCgsgUTRAYMfSpa" +) func main() { settingsPath := flag.String("settings", "/etc/dtms/vctp.yml", "Path to settings YAML") @@ -135,24 +140,27 @@ func main() { utils.GenerateCerts(tlsCertFilename, tlsKeyFilename) } - // Load vcenter credentials from serttings, decrypt if required - encKey := deriveEncryptionKey(logger) + // Load vcenter credentials from settings, decrypt if required. + encKey := deriveEncryptionKey(logger, *settingsPath) a := secrets.New(logger, encKey) + legacyDecryptKeys := deriveLegacyDecryptionKeys(*settingsPath, encKey) vcEp := strings.TrimSpace(s.Values.Settings.VcenterPassword) if len(vcEp) == 0 { logger.Error("No vcenter password configured") os.Exit(1) } - vcPass, err := a.Decrypt(vcEp) + vcPass, rewrittenCredential, err := resolveVcenterPassword(logger, a, legacyDecryptKeys, vcEp) if err != nil { - logger.Error("failed to decrypt vcenter credentials. Assuming un-encrypted", "error", err) - vcPass = []byte(vcEp) - if cipherText, encErr := a.Encrypt([]byte(vcEp)); encErr != nil { - logger.Warn("failed to encrypt vcenter credentials", "error", encErr) + logger.Error("failed to resolve vcenter credentials", "error", err) + os.Exit(1) + } + if rewrittenCredential != "" && rewrittenCredential != vcEp { + s.Values.Settings.VcenterPassword = rewrittenCredential + if err := s.WriteYMLSettings(); err != nil { + logger.Warn("failed to update settings with encrypted vcenter password", "error", err) } else { - s.Values.Settings.VcenterPassword = cipherText - if err := s.WriteYMLSettings(); err != nil { - logger.Warn("failed to update settings with encrypted vcenter password", "error", err) + if strings.HasPrefix(vcEp, encryptedVcenterPasswordPrefix) { + logger.Info("rewrote vcenter password with refreshed encryption format") } else { logger.Info("encrypted vcenter password stored in settings file") } @@ -337,25 +345,141 @@ func durationFromSeconds(value int, fallback int) time.Duration { return time.Second * time.Duration(value) } -func deriveEncryptionKey(logger *slog.Logger) []byte { +func resolveVcenterPassword(logger *slog.Logger, cipher *secrets.Secrets, legacyDecryptKeys [][]byte, raw string) ([]byte, string, error) { + if strings.TrimSpace(raw) == "" { + return nil, "", fmt.Errorf("vcenter password is empty") + } + + // New format: explicit prefix so we can distinguish ciphertext from plaintext safely. + if strings.HasPrefix(raw, encryptedVcenterPasswordPrefix) { + enc := strings.TrimPrefix(raw, encryptedVcenterPasswordPrefix) + pass, usedLegacyKey, err := decryptVcenterPasswordWithFallback(logger, cipher, legacyDecryptKeys, enc) + if err != nil { + return nil, "", fmt.Errorf("prefixed password decrypt failed: %w", err) + } + if usedLegacyKey { + rewrite, rewriteErr := encryptWithPrefix(cipher, pass) + if rewriteErr != nil { + logger.Warn("failed to refresh prefixed vcenter password after fallback decrypt", "error", rewriteErr) + return pass, "", nil + } + logger.Info("rewrote prefixed vcenter password using active encryption key") + return pass, rewrite, nil + } + return pass, "", nil + } + + // Backward compatibility: existing deployments may have unprefixed ciphertext. + if pass, _, err := decryptVcenterPasswordWithFallback(logger, cipher, legacyDecryptKeys, raw); err == nil { + rewrite, rewriteErr := encryptWithPrefix(cipher, pass) + if rewriteErr != nil { + logger.Warn("failed to re-encrypt legacy vcenter password with prefix", "error", rewriteErr) + return pass, "", nil + } + return pass, rewrite, nil + } else { + // If decrypt fails and the input is non-trivial, treat it as plaintext and auto-encrypt. + if len(raw) <= 2 { + return nil, "", fmt.Errorf("vcenter password too short to auto-encrypt") + } + logger.Warn("unable to decrypt unprefixed vcenter password; treating value as plaintext", "error", err) + rewrite, rewriteErr := encryptWithPrefix(cipher, []byte(raw)) + if rewriteErr != nil { + return nil, "", fmt.Errorf("failed to encrypt plaintext vcenter password: %w", rewriteErr) + } + return []byte(raw), rewrite, nil + } +} + +func decryptVcenterPasswordWithFallback(logger *slog.Logger, cipher *secrets.Secrets, legacyDecryptKeys [][]byte, encrypted string) ([]byte, bool, error) { + pass, err := cipher.Decrypt(encrypted) + if err == nil { + return pass, false, nil + } + primaryErr := err + for _, key := range legacyDecryptKeys { + candidate := secrets.New(logger, key) + pass, decErr := candidate.Decrypt(encrypted) + if decErr == nil { + return pass, true, nil + } + } + return nil, false, primaryErr +} + +func encryptWithPrefix(cipher *secrets.Secrets, plain []byte) (string, error) { + enc, encErr := cipher.Encrypt(plain) + if encErr != nil { + return "", encErr + } + return encryptedVcenterPasswordPrefix + enc, nil +} + +func deriveLegacyDecryptionKeys(settingsPath string, activeKey []byte) [][]byte { + legacyKeys := make([][]byte, 0, 2) + addCandidate := func(candidate []byte) { + if len(candidate) == 0 || bytes.Equal(candidate, activeKey) { + return + } + for _, existing := range legacyKeys { + if bytes.Equal(existing, candidate) { + return + } + } + keyCopy := make([]byte, len(candidate)) + copy(keyCopy, candidate) + legacyKeys = append(legacyKeys, keyCopy) + } + + platformKey, _ := deriveHostKeyCandidate(settingsPath) + addCandidate(platformKey) + addCandidate([]byte(legacyFallbackEncryptionKey)) + + return legacyKeys +} + +func deriveEncryptionKey(logger *slog.Logger, settingsPath string) []byte { + if provided := strings.TrimSpace(os.Getenv(encryptionKeyEnvVar)); provided != "" { + sum := sha256.Sum256([]byte(provided)) + logger.Debug("derived encryption key from environment variable", "env_var", encryptionKeyEnvVar) + return sum[:] + } + + key, source := deriveHostKeyCandidate(settingsPath) + switch source { + case "bios-uuid": + logger.Debug("derived encryption key from BIOS UUID") + case "machine-id": + logger.Debug("derived encryption key from machine-id") + default: + logger.Warn("using host-derived encryption key fallback; set environment variable for explicit key", "env_var", encryptionKeyEnvVar) + } + return key +} + +func deriveHostKeyCandidate(settingsPath string) ([]byte, string) { if runtime.GOOS == "linux" { if data, err := os.ReadFile("/sys/class/dmi/id/product_uuid"); err == nil { src := strings.TrimSpace(string(data)) if src != "" { sum := sha256.Sum256([]byte(src)) - logger.Debug("derived encryption key from BIOS UUID") - return sum[:] + return sum[:], "bios-uuid" } } if data, err := os.ReadFile("/etc/machine-id"); err == nil { src := strings.TrimSpace(string(data)) if src != "" { sum := sha256.Sum256([]byte(src)) - logger.Debug("derived encryption key from machine-id") - return sum[:] + return sum[:], "machine-id" } } } - logger.Warn("using fallback encryption key; hardware UUID not available") - return []byte(fallbackEncryptionKey) + + hostname, err := os.Hostname() + if err != nil { + hostname = "unknown-host" + } + src := strings.Join([]string{"vctp", runtime.GOOS, strings.TrimSpace(hostname), strings.TrimSpace(settingsPath)}, "|") + sum := sha256.Sum256([]byte(src)) + return sum[:], "host-derived" }