package main import ( "bytes" "context" "flag" "fmt" "os" "runtime" "strings" "time" "vctp/db" "vctp/internal/report" "vctp/internal/secrets" "vctp/internal/settings" "vctp/internal/tasks" utils "vctp/internal/utils" "vctp/internal/vcenter" "vctp/log" "vctp/server" "vctp/server/router" "crypto/sha256" "log/slog" "github.com/go-co-op/gocron/v2" ) var ( bindDisableTls bool sha1ver string // sha1 revision used to build the program buildTime string // when the executable was built cronSnapshotFrequency time.Duration cronAggregateFrequency time.Duration ) const ( encryptedVcenterPasswordPrefix = "enc:v1:" legacyFallbackEncryptionKey = "5L1l3B5KvwOCzUHMAlCgsgUTRAYMfSpa" ) func main() { settingsPath := flag.String("settings", "/etc/dtms/vctp.yml", "Path to settings YAML") runInventory := flag.Bool("run-inventory", false, "Run a single inventory snapshot across all configured vCenters and exit") dbCleanup := flag.Bool("db-cleanup", false, "Run a one-time cleanup to drop low-value hourly snapshot indexes and exit") backfillVcenterCache := flag.Bool("backfill-vcenter-cache", false, "Run a one-time backfill for vcenter latest+aggregate cache tables and exit") importSQLite := flag.String("import-sqlite", "", "Import a SQLite database file/DSN into the configured Postgres database and exit") flag.Parse() bootstrapLogger := log.New(log.LevelInfo, log.OutputText) ctx, cancel := context.WithCancel(context.Background()) // Load settings from yaml s := settings.New(bootstrapLogger, *settingsPath) err := s.ReadYMLSettings() if err != nil { bootstrapLogger.Error("failed to open yaml settings file", "error", err, "filename", *settingsPath) os.Exit(1) } logger := log.New( log.ToLevel(strings.ToLower(strings.TrimSpace(s.Values.Settings.LogLevel))), log.ToOutput(strings.ToLower(strings.TrimSpace(s.Values.Settings.LogOutput))), ) s.Logger = logger logger.Info("vCTP starting", "build_time", buildTime, "sha1_version", sha1ver, "go_version", runtime.Version(), "settings_file", *settingsPath) warnDeprecatedPollingSettings(logger, s.Values) // Configure database dbURL := strings.TrimSpace(s.Values.Settings.DatabaseURL) normalizedDriver, inferredFromDSN, err := db.ResolveDriver(s.Values.Settings.DatabaseDriver, dbURL) if err != nil { logger.Error("Invalid database configuration", "error", err) os.Exit(1) } if inferredFromDSN { logger.Warn("database_driver is unset; inferred postgres from database_url") } if dbURL == "" && normalizedDriver == "sqlite" { dbURL = utils.GetFilePath("db.sqlite3") } logger.Info("Effective database driver resolved", "driver", normalizedDriver) database, err := db.New(logger, db.Config{ Driver: normalizedDriver, DSN: dbURL, EnableExperimentalPostgres: s.Values.Settings.EnableExperimentalPostgres, }) if err != nil { logger.Error("Failed to create database", "error", err) os.Exit(1) } defer database.Close() //defer database.DB().Close() if err = db.Migrate(database, normalizedDriver); err != nil { logger.Error("failed to migrate database", "error", err) os.Exit(1) } if strings.TrimSpace(*importSQLite) != "" { if normalizedDriver != "postgres" { logger.Error("sqlite import requires settings.database_driver=postgres") os.Exit(1) } logger.Info("starting one-time sqlite import into postgres", "sqlite_source", strings.TrimSpace(*importSQLite)) stats, err := db.ImportSQLiteIntoPostgres(ctx, logger, database.DB(), strings.TrimSpace(*importSQLite)) if err != nil { logger.Error("failed to import sqlite database into postgres", "error", err) os.Exit(1) } logger.Info("completed sqlite import into postgres", "sqlite_source", stats.SourceDSN, "tables_imported", stats.TablesImported, "tables_skipped", stats.TablesSkipped, "rows_imported", stats.RowsImported, ) return } if *dbCleanup { dropped, err := db.CleanupHourlySnapshotIndexes(ctx, database.DB()) if err != nil { logger.Error("failed to cleanup hourly snapshot indexes", "error", err) os.Exit(1) } logger.Info("completed hourly snapshot index cleanup", "indexes_dropped", dropped) return } if *backfillVcenterCache { logger.Info("starting one-time vcenter cache backfill") if err := report.EnsureSnapshotRegistry(ctx, database); err != nil { logger.Error("failed to ensure snapshot registry", "error", err) os.Exit(1) } hourlyRecords, err := report.ListSnapshots(ctx, database, "hourly") if err != nil { logger.Error("failed to list hourly snapshots from registry", "error", err) os.Exit(1) } if len(hourlyRecords) == 0 { logger.Warn("snapshot registry has no hourly entries; attempting registry migration before cache backfill") stats, err := report.MigrateSnapshotRegistry(ctx, database) if err != nil { logger.Error("failed to migrate snapshot registry before cache backfill", "error", err) os.Exit(1) } logger.Info("snapshot registry migration complete", "hourly_renamed", stats.HourlyRenamed, "hourly_registered", stats.HourlyRegistered, "daily_registered", stats.DailyRegistered, "monthly_registered", stats.MonthlyRegistered, ) } if err := db.SyncVcenterTotalsFromSnapshots(ctx, database.DB()); err != nil { logger.Error("failed to backfill hourly vcenter totals cache", "error", err) os.Exit(1) } latestSynced, err := db.SyncVcenterLatestTotalsFromHistory(ctx, database.DB()) if err != nil { logger.Error("failed to backfill latest vcenter totals cache", "error", err) os.Exit(1) } dailySnapshots, dailyRows, dailyErr := db.SyncVcenterAggregateTotalsFromRegistry(ctx, database.DB(), "daily") if dailyErr != nil { logger.Warn("daily vcenter aggregate cache backfill completed with warnings", "error", dailyErr) } monthlySnapshots, monthlyRows, monthlyErr := db.SyncVcenterAggregateTotalsFromRegistry(ctx, database.DB(), "monthly") if monthlyErr != nil { logger.Warn("monthly vcenter aggregate cache backfill completed with warnings", "error", monthlyErr) } logger.Info("completed one-time vcenter cache backfill", "latest_rows_synced", latestSynced, "daily_snapshots_refreshed", dailySnapshots, "daily_rows_upserted", dailyRows, "monthly_snapshots_refreshed", monthlySnapshots, "monthly_rows_upserted", monthlyRows, ) return } // Determine bind IP bindIP := strings.TrimSpace(s.Values.Settings.BindIP) if bindIP == "" { bindIP = utils.GetOutboundIP().String() } // Determine bind port bindPort := s.Values.Settings.BindPort if bindPort == 0 { bindPort = 9443 } bindAddress := fmt.Sprint(bindIP, ":", bindPort) //logger.Info("Will listen on address", "ip", bindIP, "port", bindPort) // Determine bind disable TLS bindDisableTls = s.Values.Settings.BindDisableTLS // Get file names for TLS cert/key tlsCertFilename := strings.TrimSpace(s.Values.Settings.TLSCertFilename) if tlsCertFilename != "" { tlsCertFilename = utils.GetFilePath(tlsCertFilename) } else { tlsCertFilename = "./cert.pem" } tlsKeyFilename := strings.TrimSpace(s.Values.Settings.TLSKeyFilename) if tlsKeyFilename != "" { tlsKeyFilename = utils.GetFilePath(tlsKeyFilename) } else { tlsKeyFilename = "./privkey.pem" } // Generate certificate if required if !(utils.FileExists(tlsCertFilename) && utils.FileExists(tlsKeyFilename)) { logger.Warn("Specified TLS certificate or private key do not exist", "certificate", tlsCertFilename, "tls-key", tlsKeyFilename) if err := utils.GenerateCerts(tlsCertFilename, tlsKeyFilename); err != nil { logger.Error("failed to generate TLS cert/key", "error", err) os.Exit(1) } } // Load vcenter credentials from settings, decrypt if required. encKey := deriveEncryptionKey(logger, *settingsPath, s.Values.Settings.EncryptionKey) a := secrets.New(logger, encKey) legacyDecryptKeys := deriveLegacyDecryptionKeys(*settingsPath, encKey) vcEp := strings.TrimSpace(s.Values.Settings.VcenterPassword) if len(vcEp) == 0 { logger.Error("No vcenter password configured") os.Exit(1) } vcPass, rewrittenCredential, err := resolveVcenterPassword(logger, a, legacyDecryptKeys, vcEp) if err != nil { logger.Error("failed to resolve vcenter credentials", "error", err) os.Exit(1) } if rewrittenCredential != "" && rewrittenCredential != vcEp { s.Values.Settings.VcenterPassword = rewrittenCredential if err := s.WriteYMLSettings(); err != nil { logger.Warn("failed to update settings with encrypted vcenter password", "error", err) } else { if strings.HasPrefix(vcEp, encryptedVcenterPasswordPrefix) { logger.Info("rewrote vcenter password with refreshed encryption format") } else { logger.Info("encrypted vcenter password stored in settings file") } } } creds := vcenter.VcenterLogin{ Username: strings.TrimSpace(s.Values.Settings.VcenterUsername), Password: string(vcPass), Insecure: s.Values.Settings.VcenterInsecure, } if creds.Username == "" { logger.Error("No vcenter username configured") os.Exit(1) } // Set a recognizable User-Agent for vCenter sessions. ua := "vCTP" if sha1ver != "" { ua = fmt.Sprintf("vCTP/%s", sha1ver) } vcenter.SetUserAgent(ua) // Prepare the task scheduler c, err := gocron.NewScheduler() if err != nil { logger.Error("failed to create scheduler", "error", err) os.Exit(1) } // Pass useful information to the cron jobs ct := &tasks.CronTask{ Logger: logger, Database: database, Settings: s, VcCreds: &creds, FirstHourlySnapshotCheck: true, } // One-shot mode: run a single inventory snapshot across all configured vCenters and exit. if *runInventory { logger.Info("Running one-shot inventory snapshot across all vCenters") if err := ct.RunVcenterSnapshotHourly(ctx, logger, true); err != nil { logger.Error("One-shot inventory snapshot failed", "error", err) os.Exit(1) } logger.Info("One-shot inventory snapshot complete; exiting") return } cronSnapshotFrequency = durationFromSeconds(s.Values.Settings.VcenterInventorySnapshotSeconds, 3600) logger.Debug("Setting VM inventory snapshot cronjob frequency to", "frequency", cronSnapshotFrequency) cronAggregateFrequency = durationFromSeconds(s.Values.Settings.VcenterInventoryAggregateSeconds, 86400) logger.Debug("Setting VM inventory daily aggregation cronjob frequency to", "frequency", cronAggregateFrequency) startsAt3 := alignStart(time.Now(), cronSnapshotFrequency) job3, err := c.NewJob( gocron.DurationJob(cronSnapshotFrequency), gocron.NewTask(func() { ct.RunVcenterSnapshotHourly(ctx, logger, false) }), gocron.WithSingletonMode(gocron.LimitModeReschedule), gocron.WithStartAt(gocron.WithStartDateTime(startsAt3)), ) if err != nil { logger.Error("failed to start vcenter inventory snapshot cron job", "error", err) os.Exit(1) } logger.Debug("Created vcenter inventory snapshot cron job", "job", job3.ID(), "starting_at", startsAt3) startsAt4 := time.Now().Add(cronAggregateFrequency) if cronAggregateFrequency == time.Hour*24 { now := time.Now() startsAt4 = time.Date(now.Year(), now.Month(), now.Day()+1, 0, 10, 0, 0, now.Location()) } job4, err := c.NewJob( gocron.DurationJob(cronAggregateFrequency), gocron.NewTask(func() { ct.RunVcenterDailyAggregate(ctx, logger) }), gocron.WithSingletonMode(gocron.LimitModeReschedule), gocron.WithStartAt(gocron.WithStartDateTime(startsAt4)), ) if err != nil { logger.Error("failed to start vcenter inventory aggregation cron job", "error", err) os.Exit(1) } logger.Debug("Created vcenter inventory aggregation cron job", "job", job4.ID(), "starting_at", startsAt4) monthlyCron := strings.TrimSpace(s.Values.Settings.MonthlyAggregationCron) if monthlyCron == "" { monthlyCron = "10 3 1 * *" } logger.Debug("Setting monthly aggregation cron schedule", "cron", monthlyCron) job5, err := c.NewJob( gocron.CronJob(monthlyCron, false), gocron.NewTask(func() { ct.RunVcenterMonthlyAggregate(ctx, logger) }), gocron.WithSingletonMode(gocron.LimitModeReschedule), ) if err != nil { logger.Error("failed to start vcenter monthly aggregation cron job", "error", err) os.Exit(1) } logger.Debug("Created vcenter monthly aggregation cron job", "job", job5.ID()) snapshotCleanupCron := strings.TrimSpace(s.Values.Settings.SnapshotCleanupCron) if snapshotCleanupCron == "" { snapshotCleanupCron = "30 2 * * *" } job6, err := c.NewJob( gocron.CronJob(snapshotCleanupCron, false), gocron.NewTask(func() { ct.RunSnapshotCleanup(ctx, logger) if normalizedDriver == "sqlite" { logger.Info("Performing sqlite VACUUM after snapshot cleanup") if _, err := ct.Database.DB().ExecContext(ctx, "VACUUM"); err != nil { logger.Warn("VACUUM failed after snapshot cleanup", "error", err) } else { logger.Debug("VACUUM completed after snapshot cleanup") } } }), gocron.WithSingletonMode(gocron.LimitModeReschedule), ) if err != nil { logger.Error("failed to start snapshot cleanup cron job", "error", err) os.Exit(1) } logger.Debug("Created snapshot cleanup cron job", "job", job6.ID()) // Retry failed hourly snapshots retrySeconds := s.Values.Settings.HourlySnapshotRetrySeconds if retrySeconds <= 0 { retrySeconds = 300 } job7, err := c.NewJob( gocron.DurationJob(time.Duration(retrySeconds)*time.Second), gocron.NewTask(func() { ct.RunHourlySnapshotRetry(ctx, logger) }), gocron.WithSingletonMode(gocron.LimitModeReschedule), ) if err != nil { logger.Error("failed to start hourly snapshot retry cron job", "error", err) os.Exit(1) } logger.Debug("Created hourly snapshot retry cron job", "job", job7.ID(), "interval_seconds", retrySeconds) // start cron scheduler c.Start() // Start server r := router.New(logger, database, buildTime, sha1ver, runtime.Version(), &creds, a, s) svr := server.New( logger, c, cancel, bindAddress, server.WithRouter(r), server.SetTls(bindDisableTls), server.SetCertificate(tlsCertFilename), server.SetPrivateKey(tlsKeyFilename), ) //logger.Debug("Server configured", "object", svr) if err := svr.StartAndWait(); err != nil { logger.Error("server terminated with error", "error", err) os.Exit(1) } os.Exit(0) } // alignStart snaps the first run to a sensible boundary (hour or 15-minute block) when possible. func alignStart(now time.Time, freq time.Duration) time.Time { if freq == time.Hour { return now.Truncate(time.Hour).Add(time.Hour) } quarter := 15 * time.Minute if freq%quarter == 0 { return now.Truncate(quarter).Add(quarter) } return now.Add(freq) } func warnDeprecatedPollingSettings(logger *slog.Logger, cfg *settings.SettingsYML) { if cfg == nil { return } if cfg.Settings.VcenterEventPollingSeconds > 0 { logger.Warn("vcenter_event_polling_seconds is deprecated and ignored; snapshot lifecycle processing is used instead", "value", cfg.Settings.VcenterEventPollingSeconds, ) } if cfg.Settings.VcenterInventoryPollingSeconds > 0 { logger.Warn("vcenter_inventory_polling_seconds is deprecated and ignored; hourly snapshot jobs are used instead", "value", cfg.Settings.VcenterInventoryPollingSeconds, ) } } func durationFromSeconds(value int, fallback int) time.Duration { if value <= 0 { return time.Second * time.Duration(fallback) } return time.Second * time.Duration(value) } func resolveVcenterPassword(logger *slog.Logger, cipher *secrets.Secrets, legacyDecryptKeys [][]byte, raw string) ([]byte, string, error) { if strings.TrimSpace(raw) == "" { return nil, "", fmt.Errorf("vcenter password is empty") } // New format: explicit prefix so we can distinguish ciphertext from plaintext safely. if after, ok := strings.CutPrefix(raw, encryptedVcenterPasswordPrefix); ok { enc := after pass, usedLegacyKey, err := decryptVcenterPasswordWithFallback(logger, cipher, legacyDecryptKeys, enc) if err != nil { return nil, "", fmt.Errorf("prefixed password decrypt failed: %w", err) } if usedLegacyKey { rewrite, rewriteErr := encryptWithPrefix(cipher, pass) if rewriteErr != nil { logger.Warn("failed to refresh prefixed vcenter password after fallback decrypt", "error", rewriteErr) return pass, "", nil } logger.Info("rewrote prefixed vcenter password using active encryption key") return pass, rewrite, nil } return pass, "", nil } // Backward compatibility: existing deployments may have unprefixed ciphertext. if pass, _, err := decryptVcenterPasswordWithFallback(logger, cipher, legacyDecryptKeys, raw); err == nil { rewrite, rewriteErr := encryptWithPrefix(cipher, pass) if rewriteErr != nil { logger.Warn("failed to re-encrypt legacy vcenter password with prefix", "error", rewriteErr) return pass, "", nil } return pass, rewrite, nil } else { // If decrypt fails and the input is non-trivial, treat it as plaintext and auto-encrypt. if len(raw) <= 2 { return nil, "", fmt.Errorf("vcenter password too short to auto-encrypt") } logger.Warn("unable to decrypt unprefixed vcenter password; treating value as plaintext", "error", err) rewrite, rewriteErr := encryptWithPrefix(cipher, []byte(raw)) if rewriteErr != nil { return nil, "", fmt.Errorf("failed to encrypt plaintext vcenter password: %w", rewriteErr) } return []byte(raw), rewrite, nil } } func decryptVcenterPasswordWithFallback(logger *slog.Logger, cipher *secrets.Secrets, legacyDecryptKeys [][]byte, encrypted string) ([]byte, bool, error) { pass, err := cipher.Decrypt(encrypted) if err == nil { return pass, false, nil } primaryErr := err for _, key := range legacyDecryptKeys { candidate := secrets.New(logger, key) pass, decErr := candidate.Decrypt(encrypted) if decErr == nil { return pass, true, nil } } return nil, false, primaryErr } func encryptWithPrefix(cipher *secrets.Secrets, plain []byte) (string, error) { enc, encErr := cipher.Encrypt(plain) if encErr != nil { return "", encErr } return encryptedVcenterPasswordPrefix + enc, nil } func deriveLegacyDecryptionKeys(settingsPath string, activeKey []byte) [][]byte { legacyKeys := make([][]byte, 0, 2) addCandidate := func(candidate []byte) { if len(candidate) == 0 || bytes.Equal(candidate, activeKey) { return } for _, existing := range legacyKeys { if bytes.Equal(existing, candidate) { return } } keyCopy := make([]byte, len(candidate)) copy(keyCopy, candidate) legacyKeys = append(legacyKeys, keyCopy) } platformKey, _ := deriveHostKeyCandidate(settingsPath) addCandidate(platformKey) addCandidate([]byte(legacyFallbackEncryptionKey)) return legacyKeys } func deriveEncryptionKey(logger *slog.Logger, settingsPath string, configuredKey string) []byte { if provided := strings.TrimSpace(configuredKey); provided != "" { sum := sha256.Sum256([]byte(provided)) logger.Debug("derived encryption key from settings", "setting", "settings.encryption_key") return sum[:] } key, source := deriveHostKeyCandidate(settingsPath) switch source { case "bios-uuid": logger.Debug("derived encryption key from BIOS UUID") case "machine-id": logger.Debug("derived encryption key from machine-id") default: logger.Warn("using host-derived encryption key fallback; set settings.encryption_key for an explicit key") } return key } func deriveHostKeyCandidate(settingsPath string) ([]byte, string) { if runtime.GOOS == "linux" { if data, err := os.ReadFile("/sys/class/dmi/id/product_uuid"); err == nil { src := strings.TrimSpace(string(data)) if src != "" { sum := sha256.Sum256([]byte(src)) return sum[:], "bios-uuid" } } if data, err := os.ReadFile("/etc/machine-id"); err == nil { src := strings.TrimSpace(string(data)) if src != "" { sum := sha256.Sum256([]byte(src)) return sum[:], "machine-id" } } } hostname, err := os.Hostname() if err != nil { hostname = "unknown-host" } src := strings.Join([]string{"vctp", runtime.GOOS, strings.TrimSpace(hostname), strings.TrimSpace(settingsPath)}, "|") sum := sha256.Sum256([]byte(src)) return sum[:], "host-derived" }