From 56f021590d134a934e6ef92abf9b7d0be2f5345d Mon Sep 17 00:00:00 2001 From: Nathan Coad Date: Wed, 14 Jan 2026 17:00:40 +1100 Subject: [PATCH] work on optimising vcenter queries --- db/helpers.go | 169 ++++++++++++++++++++++ db/queries/models.go | 17 +++ db/queries/query.sql | 10 ++ db/queries/query.sql.go | 26 ++++ db/schema.sql | 19 +++ internal/report/snapshots.go | 46 ++---- internal/tasks/inventorySnapshots.go | 206 +++++++-------------------- internal/vcenter/vcenter.go | 119 ++++++++++++++++ 8 files changed, 419 insertions(+), 193 deletions(-) create mode 100644 db/helpers.go diff --git a/db/helpers.go b/db/helpers.go new file mode 100644 index 0000000..eafcc53 --- /dev/null +++ b/db/helpers.go @@ -0,0 +1,169 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "vctp/db/queries" + + "github.com/jmoiron/sqlx" +) + +// SnapshotTotals summarizes counts and allocations for snapshot tables. +type SnapshotTotals struct { + VmCount int64 `db:"vm_count"` + VcpuTotal int64 `db:"vcpu_total"` + RamTotal int64 `db:"ram_total"` + DiskTotal float64 `db:"disk_total"` +} + +// ValidateTableName ensures table identifiers are safe for interpolation. +func ValidateTableName(name string) error { + if name == "" { + return fmt.Errorf("table name is empty") + } + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' { + continue + } + return fmt.Errorf("invalid table name: %s", name) + } + return nil +} + +// SafeTableName returns the name if it passes validation. +func SafeTableName(name string) (string, error) { + if err := ValidateTableName(name); err != nil { + return "", err + } + return name, nil +} + +// TableHasRows returns true when a table contains at least one row. +func TableHasRows(ctx context.Context, dbConn *sqlx.DB, table string) (bool, error) { + if err := ValidateTableName(table); err != nil { + return false, err + } + query := fmt.Sprintf(`SELECT 1 FROM %s LIMIT 1`, table) + var exists int + if err := dbConn.GetContext(ctx, &exists, query); err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +// TableExists checks if a table exists in the current schema. +func TableExists(ctx context.Context, dbConn *sqlx.DB, table string) bool { + driver := strings.ToLower(dbConn.DriverName()) + switch driver { + case "sqlite": + q := queries.New(dbConn) + count, err := q.SqliteTableExists(ctx, sql.NullString{String: table, Valid: table != ""}) + return err == nil && count > 0 + case "pgx", "postgres": + var count int + err := dbConn.GetContext(ctx, &count, ` +SELECT COUNT(1) +FROM pg_catalog.pg_tables +WHERE schemaname = 'public' AND tablename = $1 +`, table) + return err == nil && count > 0 + default: + return false + } +} + +// ColumnExists checks if a column exists in a table. +func ColumnExists(ctx context.Context, dbConn *sqlx.DB, tableName string, columnName string) (bool, error) { + driver := strings.ToLower(dbConn.DriverName()) + switch driver { + case "sqlite": + if _, err := SafeTableName(tableName); err != nil { + return false, err + } + query := fmt.Sprintf(`PRAGMA table_info("%s")`, tableName) + rows, err := dbConn.QueryxContext(ctx, query) + if err != nil { + return false, err + } + defer rows.Close() + for rows.Next() { + var ( + cid int + name string + colType string + notNull int + defaultVal sql.NullString + pk int + ) + if err := rows.Scan(&cid, &name, &colType, ¬Null, &defaultVal, &pk); err != nil { + return false, err + } + if strings.EqualFold(name, columnName) { + return true, nil + } + } + return false, rows.Err() + case "pgx", "postgres": + var count int + err := dbConn.GetContext(ctx, &count, ` +SELECT COUNT(1) +FROM information_schema.columns +WHERE table_name = $1 AND column_name = $2 +`, tableName, strings.ToLower(columnName)) + if err != nil { + return false, err + } + return count > 0, nil + default: + return false, fmt.Errorf("unsupported driver for column lookup: %s", driver) + } +} + +// SnapshotTotalsForTable returns totals for a snapshot table. +func SnapshotTotalsForTable(ctx context.Context, dbConn *sqlx.DB, table string) (SnapshotTotals, error) { + if _, err := SafeTableName(table); err != nil { + return SnapshotTotals{}, err + } + query := fmt.Sprintf(` +SELECT + COUNT(DISTINCT "VmId") AS vm_count, + COALESCE(SUM(CASE WHEN "VcpuCount" IS NOT NULL THEN "VcpuCount" ELSE 0 END), 0) AS vcpu_total, + COALESCE(SUM(CASE WHEN "RamGB" IS NOT NULL THEN "RamGB" ELSE 0 END), 0) AS ram_total, + COALESCE(SUM(CASE WHEN "ProvisionedDisk" IS NOT NULL THEN "ProvisionedDisk" ELSE 0 END), 0) AS disk_total +FROM %s +WHERE "IsPresent" = 'TRUE' +`, table) + + var totals SnapshotTotals + if err := dbConn.GetContext(ctx, &totals, query); err != nil { + return SnapshotTotals{}, err + } + return totals, nil +} + +// SnapshotTotalsForUnion returns totals for a union query of snapshots. +func SnapshotTotalsForUnion(ctx context.Context, dbConn *sqlx.DB, unionQuery string) (SnapshotTotals, error) { + query := fmt.Sprintf(` +SELECT + COUNT(DISTINCT "VmId") AS vm_count, + COALESCE(SUM(CASE WHEN "VcpuCount" IS NOT NULL THEN "VcpuCount" ELSE 0 END), 0) AS vcpu_total, + COALESCE(SUM(CASE WHEN "RamGB" IS NOT NULL THEN "RamGB" ELSE 0 END), 0) AS ram_total, + COALESCE(SUM(CASE WHEN "ProvisionedDisk" IS NOT NULL THEN "ProvisionedDisk" ELSE 0 END), 0) AS disk_total +FROM ( +%s +) snapshots +WHERE "IsPresent" = 'TRUE' +`, unionQuery) + + var totals SnapshotTotals + if err := dbConn.GetContext(ctx, &totals, query); err != nil { + return SnapshotTotals{}, err + } + return totals, nil +} diff --git a/db/queries/models.go b/db/queries/models.go index 9062ca9..d6904d2 100644 --- a/db/queries/models.go +++ b/db/queries/models.go @@ -59,6 +59,15 @@ type InventoryHistory struct { PreviousProvisionedDisk sql.NullFloat64 `db:"PreviousProvisionedDisk" json:"PreviousProvisionedDisk"` } +type PragmaTableInfo struct { + Cid sql.NullInt64 `db:"cid" json:"cid"` + Name sql.NullString `db:"name" json:"name"` + Type sql.NullString `db:"type" json:"type"` + Notnull sql.NullInt64 `db:"notnull" json:"notnull"` + DfltValue sql.NullString `db:"dflt_value" json:"dflt_value"` + Pk sql.NullInt64 `db:"pk" json:"pk"` +} + type SnapshotRegistry struct { ID int64 `db:"id" json:"id"` SnapshotType string `db:"snapshot_type" json:"snapshot_type"` @@ -66,6 +75,14 @@ type SnapshotRegistry struct { SnapshotTime int64 `db:"snapshot_time" json:"snapshot_time"` } +type SqliteMaster struct { + Type sql.NullString `db:"type" json:"type"` + Name sql.NullString `db:"name" json:"name"` + TblName sql.NullString `db:"tbl_name" json:"tbl_name"` + Rootpage sql.NullInt64 `db:"rootpage" json:"rootpage"` + Sql sql.NullString `db:"sql" json:"sql"` +} + type Update struct { Uid int64 `db:"Uid" json:"Uid"` InventoryId sql.NullInt64 `db:"InventoryId" json:"InventoryId"` diff --git a/db/queries/query.sql b/db/queries/query.sql index 63542cd..b8b71ad 100644 --- a/db/queries/query.sql +++ b/db/queries/query.sql @@ -119,3 +119,13 @@ INSERT INTO inventory_history ( ?, ?, ?, ?, ?, ?, ? ) RETURNING *; + +-- name: SqliteTableExists :one +SELECT COUNT(1) AS count +FROM sqlite_master +WHERE type = 'table' AND name = sqlc.arg('table_name'); + +-- name: SqliteColumnExists :one +SELECT COUNT(1) AS count +FROM pragma_table_info +WHERE name = sqlc.arg('column_name'); diff --git a/db/queries/query.sql.go b/db/queries/query.sql.go index 08f0497..1a120f6 100644 --- a/db/queries/query.sql.go +++ b/db/queries/query.sql.go @@ -876,6 +876,32 @@ func (q *Queries) ListUnprocessedEvents(ctx context.Context, eventtime sql.NullI return items, nil } +const sqliteColumnExists = `-- name: SqliteColumnExists :one +SELECT COUNT(1) AS count +FROM pragma_table_info +WHERE name = ?1 +` + +func (q *Queries) SqliteColumnExists(ctx context.Context, columnName sql.NullString) (int64, error) { + row := q.db.QueryRowContext(ctx, sqliteColumnExists, columnName) + var count int64 + err := row.Scan(&count) + return count, err +} + +const sqliteTableExists = `-- name: SqliteTableExists :one +SELECT COUNT(1) AS count +FROM sqlite_master +WHERE type = 'table' AND name = ?1 +` + +func (q *Queries) SqliteTableExists(ctx context.Context, tableName sql.NullString) (int64, error) { + row := q.db.QueryRowContext(ctx, sqliteTableExists, tableName) + var count int64 + err := row.Scan(&count) + return count, err +} + const updateEventsProcessed = `-- name: UpdateEventsProcessed :exec UPDATE events SET "Processed" = 1 diff --git a/db/schema.sql b/db/schema.sql index 25eeebb..538581d 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -72,3 +72,22 @@ CREATE TABLE IF NOT EXISTS snapshot_registry ( "table_name" TEXT NOT NULL UNIQUE, "snapshot_time" INTEGER NOT NULL ); + +-- The following tables are declared for sqlc type-checking only. +-- Do not apply this file as a migration. +CREATE TABLE sqlite_master ( + "type" TEXT, + "name" TEXT, + "tbl_name" TEXT, + "rootpage" INTEGER, + "sql" TEXT +); + +CREATE TABLE pragma_table_info ( + "cid" INTEGER, + "name" TEXT, + "type" TEXT, + "notnull" INTEGER, + "dflt_value" TEXT, + "pk" INTEGER +); diff --git a/internal/report/snapshots.go b/internal/report/snapshots.go index 529b479..fa08cc1 100644 --- a/internal/report/snapshots.go +++ b/internal/report/snapshots.go @@ -407,7 +407,7 @@ func FormatSnapshotLabel(snapshotType string, snapshotTime time.Time, tableName } func CreateTableReport(logger *slog.Logger, Database db.Database, ctx context.Context, tableName string) ([]byte, error) { - if err := validateTableName(tableName); err != nil { + if err := db.ValidateTableName(tableName); err != nil { return nil, err } @@ -651,34 +651,6 @@ func addTotalsChartSheet(logger *slog.Logger, database db.Database, ctx context. } } -func validateTableName(name string) error { - if name == "" { - return fmt.Errorf("table name is empty") - } - for _, r := range name { - if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' { - continue - } - return fmt.Errorf("invalid table name: %s", name) - } - return nil -} - -func tableHasRows(ctx context.Context, dbConn *sqlx.DB, table string) (bool, error) { - if err := validateTableName(table); err != nil { - return false, err - } - query := fmt.Sprintf(`SELECT 1 FROM %s LIMIT 1`, table) - var exists int - if err := dbConn.GetContext(ctx, &exists, query); err != nil { - if err == sql.ErrNoRows { - return false, nil - } - return false, err - } - return true, nil -} - func tableColumns(ctx context.Context, dbConn *sqlx.DB, tableName string) ([]string, error) { driver := strings.ToLower(dbConn.DriverName()) switch driver { @@ -777,7 +749,7 @@ type columnDef struct { } func ensureSummaryReportColumns(ctx context.Context, dbConn *sqlx.DB, tableName string) error { - if err := validateTableName(tableName); err != nil { + if err := db.ValidateTableName(tableName); err != nil { return err } columns, err := tableColumns(ctx, dbConn, tableName) @@ -920,10 +892,10 @@ type totalsPoint struct { func buildHourlyTotals(ctx context.Context, dbConn *sqlx.DB, records []SnapshotRecord) ([]totalsPoint, error) { points := make([]totalsPoint, 0, len(records)) for _, record := range records { - if err := validateTableName(record.TableName); err != nil { + if err := db.ValidateTableName(record.TableName); err != nil { return nil, err } - if rowsExist, err := tableHasRows(ctx, dbConn, record.TableName); err != nil || !rowsExist { + if rowsExist, err := db.TableHasRows(ctx, dbConn, record.TableName); err != nil || !rowsExist { continue } query := fmt.Sprintf(` @@ -970,10 +942,10 @@ WHERE %s func buildDailyTotals(ctx context.Context, dbConn *sqlx.DB, records []SnapshotRecord) ([]totalsPoint, error) { points := make([]totalsPoint, 0, len(records)) for _, record := range records { - if err := validateTableName(record.TableName); err != nil { + if err := db.ValidateTableName(record.TableName); err != nil { return nil, err } - if rowsExist, err := tableHasRows(ctx, dbConn, record.TableName); err != nil || !rowsExist { + if rowsExist, err := db.TableHasRows(ctx, dbConn, record.TableName); err != nil || !rowsExist { continue } query := fmt.Sprintf(` @@ -1124,10 +1096,10 @@ func formatEpochHuman(value interface{}) string { } func renameTable(ctx context.Context, dbConn *sqlx.DB, oldName string, newName string) error { - if err := validateTableName(oldName); err != nil { + if err := db.ValidateTableName(oldName); err != nil { return err } - if err := validateTableName(newName); err != nil { + if err := db.ValidateTableName(newName); err != nil { return err } _, err := dbConn.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME TO %s`, oldName, newName)) @@ -1138,7 +1110,7 @@ func renameTable(ctx context.Context, dbConn *sqlx.DB, oldName string, newName s } func latestSnapshotTime(ctx context.Context, dbConn *sqlx.DB, tableName string) (time.Time, error) { - if err := validateTableName(tableName); err != nil { + if err := db.ValidateTableName(tableName); err != nil { return time.Time{}, err } query := fmt.Sprintf(`SELECT MAX("SnapshotTime") FROM %s`, tableName) diff --git a/internal/tasks/inventorySnapshots.go b/internal/tasks/inventorySnapshots.go index 70a00e3..b038daf 100644 --- a/internal/tasks/inventorySnapshots.go +++ b/internal/tasks/inventorySnapshots.go @@ -10,6 +10,7 @@ import ( "sync" "sync/atomic" "time" + "vctp/db" "vctp/db/queries" "vctp/internal/report" "vctp/internal/vcenter" @@ -43,6 +44,8 @@ type inventorySnapshotRow struct { IsPresent string } +type snapshotTotals = db.SnapshotTotals + // RunVcenterSnapshotHourly records hourly inventory snapshots into a daily table. func (c *CronTask) RunVcenterSnapshotHourly(ctx context.Context, logger *slog.Logger) error { startedAt := time.Now() @@ -148,7 +151,7 @@ func (c *CronTask) aggregateDailySummary(ctx context.Context, targetTime time.Ti if err := report.EnsureSnapshotRegistry(ctx, c.Database); err != nil { return err } - if rowsExist, err := tableHasRows(ctx, dbConn, summaryTable); err != nil { + if rowsExist, err := db.TableHasRows(ctx, dbConn, summaryTable); err != nil { return err } else if rowsExist && !force { c.Logger.Debug("Daily summary already exists, skipping aggregation", "summary_table", summaryTable) @@ -158,7 +161,7 @@ func (c *CronTask) aggregateDailySummary(ctx context.Context, targetTime time.Ti return err } } - if rowsExist, err := tableHasRows(ctx, dbConn, summaryTable); err != nil { + if rowsExist, err := db.TableHasRows(ctx, dbConn, summaryTable); err != nil { return err } else if rowsExist { c.Logger.Debug("Daily summary already exists, skipping aggregation", "summary_table", summaryTable) @@ -185,7 +188,7 @@ func (c *CronTask) aggregateDailySummary(ctx context.Context, targetTime time.Ti `"SrmPlaceholder"`, `"VmUuid"`, `"SnapshotTime"`, `"IsPresent"`, }, templateExclusionFilter()) - currentTotals, err := snapshotTotalsForUnion(ctx, dbConn, unionQuery) + currentTotals, err := db.SnapshotTotalsForUnion(ctx, dbConn, unionQuery) if err != nil { c.Logger.Warn("unable to calculate daily totals", "error", err, "date", dayStart.Format("2006-01-02")) } else { @@ -213,7 +216,7 @@ func (c *CronTask) aggregateDailySummary(ctx context.Context, targetTime time.Ti `"ProvisionedDisk"`, `"VcpuCount"`, `"RamGB"`, `"IsTemplate"`, `"PoweredOn"`, `"SrmPlaceholder"`, `"VmUuid"`, `"SnapshotTime"`, `"IsPresent"`, }, templateExclusionFilter()) - prevTotals, err := snapshotTotalsForUnion(ctx, dbConn, prevUnion) + prevTotals, err := db.SnapshotTotalsForUnion(ctx, dbConn, prevUnion) if err != nil { c.Logger.Warn("unable to calculate previous day totals", "error", err, "date", prevStart.Format("2006-01-02")) } else { @@ -337,7 +340,7 @@ func (c *CronTask) aggregateMonthlySummary(ctx context.Context, targetMonth time if err := ensureMonthlySummaryTable(ctx, dbConn, monthlyTable); err != nil { return err } - if rowsExist, err := tableHasRows(ctx, dbConn, monthlyTable); err != nil { + if rowsExist, err := db.TableHasRows(ctx, dbConn, monthlyTable); err != nil { return err } else if rowsExist && !force { c.Logger.Debug("Monthly summary already exists, skipping aggregation", "summary_table", monthlyTable) @@ -347,7 +350,7 @@ func (c *CronTask) aggregateMonthlySummary(ctx context.Context, targetMonth time return err } } - if rowsExist, err := tableHasRows(ctx, dbConn, monthlyTable); err != nil { + if rowsExist, err := db.TableHasRows(ctx, dbConn, monthlyTable); err != nil { return err } else if rowsExist { c.Logger.Debug("Monthly summary already exists, skipping aggregation", "summary_table", monthlyTable) @@ -368,7 +371,7 @@ func (c *CronTask) aggregateMonthlySummary(ctx context.Context, targetMonth time return fmt.Errorf("no valid daily snapshot tables found for %s", targetMonth.Format("2006-01")) } - monthlyTotals, err := snapshotTotalsForUnion(ctx, dbConn, unionQuery) + monthlyTotals, err := db.SnapshotTotalsForUnion(ctx, dbConn, unionQuery) if err != nil { c.Logger.Warn("unable to calculate monthly totals", "error", err, "month", targetMonth.Format("2006-01")) } else { @@ -518,25 +521,15 @@ func (c *CronTask) RunSnapshotCleanup(ctx context.Context, logger *slog.Logger) } func hourlyInventoryTableName(t time.Time) (string, error) { - return safeTableName(fmt.Sprintf("inventory_hourly_%d", t.Unix())) + return db.SafeTableName(fmt.Sprintf("inventory_hourly_%d", t.Unix())) } func dailySummaryTableName(t time.Time) (string, error) { - return safeTableName(fmt.Sprintf("inventory_daily_summary_%s", t.Format("20060102"))) + return db.SafeTableName(fmt.Sprintf("inventory_daily_summary_%s", t.Format("20060102"))) } func monthlySummaryTableName(t time.Time) (string, error) { - return safeTableName(fmt.Sprintf("inventory_monthly_summary_%s", t.Format("200601"))) -} - -func safeTableName(name string) (string, error) { - for _, r := range name { - if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' { - continue - } - return "", fmt.Errorf("invalid table name: %s", name) - } - return name, nil + return db.SafeTableName(fmt.Sprintf("inventory_monthly_summary_%s", t.Format("200601"))) } func ensureDailyInventoryTable(ctx context.Context, dbConn *sqlx.DB, tableName string) error { @@ -820,7 +813,7 @@ func buildUnionQuery(tables []string, columns []string, whereClause string) stri queries := make([]string, 0, len(tables)) columnList := strings.Join(columns, ", ") for _, table := range tables { - if _, err := safeTableName(table); err != nil { + if _, err := db.SafeTableName(table); err != nil { continue } query := fmt.Sprintf("SELECT %s FROM %s", columnList, table) @@ -860,7 +853,7 @@ func truncateDate(t time.Time) time.Time { } func dropSnapshotTable(ctx context.Context, dbConn *sqlx.DB, table string) error { - if _, err := safeTableName(table); err != nil { + if _, err := db.SafeTableName(table); err != nil { return err } _, err := dbConn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", table)) @@ -868,7 +861,7 @@ func dropSnapshotTable(ctx context.Context, dbConn *sqlx.DB, table string) error } func clearTable(ctx context.Context, dbConn *sqlx.DB, table string) error { - if _, err := safeTableName(table); err != nil { + if _, err := db.SafeTableName(table); err != nil { return err } _, err := dbConn.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s", table)) @@ -878,45 +871,23 @@ func clearTable(ctx context.Context, dbConn *sqlx.DB, table string) error { return nil } -func tableHasRows(ctx context.Context, dbConn *sqlx.DB, table string) (bool, error) { - if _, err := safeTableName(table); err != nil { - return false, err - } - query := fmt.Sprintf(`SELECT 1 FROM %s LIMIT 1`, table) - var exists int - if err := dbConn.GetContext(ctx, &exists, query); err != nil { - if err == sql.ErrNoRows { - return false, nil - } - return false, err - } - return true, nil -} - func filterSnapshotsWithRows(ctx context.Context, dbConn *sqlx.DB, snapshots []report.SnapshotRecord) []report.SnapshotRecord { filtered := snapshots[:0] for _, snapshot := range snapshots { - if rowsExist, err := tableHasRows(ctx, dbConn, snapshot.TableName); err == nil && rowsExist { + if rowsExist, err := db.TableHasRows(ctx, dbConn, snapshot.TableName); err == nil && rowsExist { filtered = append(filtered, snapshot) } } return filtered } -type snapshotTotals struct { - VmCount int64 `db:"vm_count"` - VcpuTotal int64 `db:"vcpu_total"` - RamTotal int64 `db:"ram_total"` - DiskTotal float64 `db:"disk_total"` -} - type columnDef struct { Name string Type string } func ensureSnapshotColumns(ctx context.Context, dbConn *sqlx.DB, tableName string, columns []columnDef) error { - if _, err := safeTableName(tableName); err != nil { + if _, err := db.SafeTableName(tableName); err != nil { return err } for _, column := range columns { @@ -968,7 +939,7 @@ func ensureSnapshotRowID(ctx context.Context, dbConn *sqlx.DB, tableName string) driver := strings.ToLower(dbConn.DriverName()) switch driver { case "pgx", "postgres": - hasColumn, err := columnExists(ctx, dbConn, tableName, "RowId") + hasColumn, err := db.ColumnExists(ctx, dbConn, tableName, "RowId") if err != nil { return err } @@ -994,112 +965,8 @@ func ensureSnapshotRowID(ctx context.Context, dbConn *sqlx.DB, tableName string) return nil } -func columnExists(ctx context.Context, dbConn *sqlx.DB, tableName string, columnName string) (bool, error) { - driver := strings.ToLower(dbConn.DriverName()) - switch driver { - case "sqlite": - query := fmt.Sprintf(`PRAGMA table_info("%s")`, tableName) - rows, err := dbConn.QueryxContext(ctx, query) - if err != nil { - return false, err - } - defer rows.Close() - for rows.Next() { - var ( - cid int - name string - colType string - notNull int - defaultVal sql.NullString - pk int - ) - if err := rows.Scan(&cid, &name, &colType, ¬Null, &defaultVal, &pk); err != nil { - return false, err - } - if strings.EqualFold(name, columnName) { - return true, nil - } - } - return false, rows.Err() - case "pgx", "postgres": - var count int - err := dbConn.GetContext(ctx, &count, ` -SELECT COUNT(1) -FROM information_schema.columns -WHERE table_name = $1 AND column_name = $2 -`, tableName, strings.ToLower(columnName)) - if err != nil { - return false, err - } - return count > 0, nil - default: - return false, fmt.Errorf("unsupported driver for column lookup: %s", driver) - } -} - -func snapshotTotalsForTable(ctx context.Context, dbConn *sqlx.DB, table string) (snapshotTotals, error) { - if _, err := safeTableName(table); err != nil { - return snapshotTotals{}, err - } - query := fmt.Sprintf(` -SELECT - COUNT(DISTINCT "VmId") AS vm_count, - COALESCE(SUM(CASE WHEN "VcpuCount" IS NOT NULL THEN "VcpuCount" ELSE 0 END), 0) AS vcpu_total, - COALESCE(SUM(CASE WHEN "RamGB" IS NOT NULL THEN "RamGB" ELSE 0 END), 0) AS ram_total, - COALESCE(SUM(CASE WHEN "ProvisionedDisk" IS NOT NULL THEN "ProvisionedDisk" ELSE 0 END), 0) AS disk_total -FROM %s -WHERE "IsPresent" = 'TRUE' -`, table) - - var totals snapshotTotals - if err := dbConn.GetContext(ctx, &totals, query); err != nil { - return snapshotTotals{}, err - } - return totals, nil -} - -func snapshotTotalsForUnion(ctx context.Context, dbConn *sqlx.DB, unionQuery string) (snapshotTotals, error) { - query := fmt.Sprintf(` -SELECT - COUNT(DISTINCT "VmId") AS vm_count, - COALESCE(SUM(CASE WHEN "VcpuCount" IS NOT NULL THEN "VcpuCount" ELSE 0 END), 0) AS vcpu_total, - COALESCE(SUM(CASE WHEN "RamGB" IS NOT NULL THEN "RamGB" ELSE 0 END), 0) AS ram_total, - COALESCE(SUM(CASE WHEN "ProvisionedDisk" IS NOT NULL THEN "ProvisionedDisk" ELSE 0 END), 0) AS disk_total -FROM ( -%s -) snapshots -WHERE "IsPresent" = 'TRUE' -`, unionQuery) - - var totals snapshotTotals - if err := dbConn.GetContext(ctx, &totals, query); err != nil { - return snapshotTotals{}, err - } - return totals, nil -} - func tableExists(ctx context.Context, dbConn *sqlx.DB, table string) bool { - driver := strings.ToLower(dbConn.DriverName()) - switch driver { - case "sqlite": - var count int - err := dbConn.GetContext(ctx, &count, ` -SELECT COUNT(1) -FROM sqlite_master -WHERE type = 'table' AND name = ? -`, table) - return err == nil && count > 0 - case "pgx", "postgres": - var count int - err := dbConn.GetContext(ctx, &count, ` -SELECT COUNT(1) -FROM pg_catalog.pg_tables -WHERE schemaname = 'public' AND tablename = $1 -`, table) - return err == nil && count > 0 - default: - return false - } + return db.TableExists(ctx, dbConn, table) } func nullInt64ToInt(value sql.NullInt64) int64 { @@ -1142,7 +1009,7 @@ func normalizeResourcePool(value string) string { } } -func snapshotFromVM(vmObject *mo.VirtualMachine, vc *vcenter.Vcenter, snapshotTime time.Time, inv *queries.Inventory) (inventorySnapshotRow, error) { +func snapshotFromVM(vmObject *mo.VirtualMachine, vc *vcenter.Vcenter, snapshotTime time.Time, inv *queries.Inventory, hostLookup map[string]vcenter.HostLookup, folderLookup vcenter.FolderLookup) (inventorySnapshotRow, error) { if vmObject == nil { return inventorySnapshotRow{}, fmt.Errorf("missing VM object") } @@ -1238,8 +1105,21 @@ func snapshotFromVM(vmObject *mo.VirtualMachine, vc *vcenter.Vcenter, snapshotTi } if row.Folder.String == "" { - if folderPath, err := vc.GetVMFolderPath(*vmObject); err == nil { + if folderPath, ok := vc.GetVMFolderPathFromLookup(*vmObject, folderLookup); ok { row.Folder = sql.NullString{String: folderPath, Valid: folderPath != ""} + } else if folderPath, err := vc.GetVMFolderPath(*vmObject); err == nil { + row.Folder = sql.NullString{String: folderPath, Valid: folderPath != ""} + } + } + + if vmObject.Runtime.Host != nil && hostLookup != nil { + if lookup, ok := hostLookup[vmObject.Runtime.Host.Value]; ok { + if row.Cluster.String == "" && lookup.Cluster != "" { + row.Cluster = sql.NullString{String: lookup.Cluster, Valid: true} + } + if row.Datacenter.String == "" && lookup.Datacenter != "" { + row.Datacenter = sql.NullString{String: lookup.Datacenter, Valid: true} + } } } @@ -1337,6 +1217,20 @@ func (c *CronTask) captureHourlySnapshotForVcenter(ctx context.Context, startTim if !canDetectMissing { c.Logger.Warn("no VMs returned from vcenter; skipping missing VM detection", "url", url) } + hostLookup, err := vc.BuildHostLookup() + if err != nil { + c.Logger.Warn("failed to build host lookup", "url", url, "error", err) + hostLookup = nil + } else { + c.Logger.Debug("built host lookup", "url", url, "hosts", len(hostLookup)) + } + folderLookup, err := vc.BuildFolderPathLookup() + if err != nil { + c.Logger.Warn("failed to build folder lookup", "url", url, "error", err) + folderLookup = nil + } else { + c.Logger.Debug("built folder lookup", "url", url, "folders", len(folderLookup)) + } inventoryRows, err := c.Database.Queries().GetInventoryByVcenter(ctx, url) if err != nil { @@ -1373,7 +1267,7 @@ func (c *CronTask) captureHourlySnapshotForVcenter(ctx context.Context, startTim inv = &existingCopy } - row, err := snapshotFromVM(vmObj, vc, startTime, inv) + row, err := snapshotFromVM(vmObj, vc, startTime, inv, hostLookup, folderLookup) if err != nil { c.Logger.Error("unable to build snapshot for VM", "vm_id", vm.Reference().Value, "error", err) continue diff --git a/internal/vcenter/vcenter.go b/internal/vcenter/vcenter.go index 8a641fb..fe96468 100644 --- a/internal/vcenter/vcenter.go +++ b/internal/vcenter/vcenter.go @@ -37,6 +37,13 @@ type VmProperties struct { ResourcePool string } +type HostLookup struct { + Cluster string + Datacenter string +} + +type FolderLookup map[string]string + // New creates a new Vcenter with the given logger func New(logger *slog.Logger, creds *VcenterLogin) *Vcenter { @@ -143,6 +150,118 @@ func (v *Vcenter) GetAllVmReferences() ([]*object.VirtualMachine, error) { return results, err } +func (v *Vcenter) BuildHostLookup() (map[string]HostLookup, error) { + finder := find.NewFinder(v.client.Client, true) + datacenters, err := finder.DatacenterList(v.ctx, "*") + if err != nil { + return nil, fmt.Errorf("failed to list datacenters: %w", err) + } + + lookup := make(map[string]HostLookup) + clusterCache := make(map[string]string) + + for _, dc := range datacenters { + finder.SetDatacenter(dc) + hosts, err := finder.HostSystemList(v.ctx, "*") + if err != nil { + v.Logger.Warn("failed to list hosts for datacenter", "datacenter", dc.Name(), "error", err) + continue + } + + for _, host := range hosts { + ref := host.Reference() + var moHost mo.HostSystem + if err := v.client.RetrieveOne(v.ctx, ref, []string{"parent"}, &moHost); err != nil { + v.Logger.Warn("failed to retrieve host info", "host", host.Name(), "error", err) + continue + } + + clusterName := "" + if moHost.Parent != nil { + if cached, ok := clusterCache[moHost.Parent.Value]; ok { + clusterName = cached + } else { + var moCompute mo.ComputeResource + if err := v.client.RetrieveOne(v.ctx, *moHost.Parent, []string{"name"}, &moCompute); err == nil { + clusterName = moCompute.Name + clusterCache[moHost.Parent.Value] = clusterName + } + } + } + + lookup[ref.Value] = HostLookup{ + Cluster: clusterName, + Datacenter: dc.Name(), + } + } + } + + return lookup, nil +} + +func (v *Vcenter) BuildFolderPathLookup() (FolderLookup, error) { + m := view.NewManager(v.client.Client) + folders, err := m.CreateContainerView(v.ctx, v.client.ServiceContent.RootFolder, []string{"Folder"}, true) + if err != nil { + return nil, err + } + defer folders.Destroy(v.ctx) + + var results []mo.Folder + if err := folders.Retrieve(v.ctx, []string{"Folder"}, []string{"name", "parent"}, &results); err != nil { + return nil, err + } + + nameByID := make(map[string]string, len(results)) + parentByID := make(map[string]*types.ManagedObjectReference, len(results)) + for _, folder := range results { + nameByID[folder.Reference().Value] = folder.Name + parentByID[folder.Reference().Value] = folder.Parent + } + + paths := make(FolderLookup, len(results)) + var buildPath func(id string) string + buildPath = func(id string) string { + if pathValue, ok := paths[id]; ok { + return pathValue + } + name, ok := nameByID[id] + if !ok { + return "" + } + parent := parentByID[id] + if parent == nil || parent.Type == "Datacenter" { + paths[id] = path.Join("/", name) + return paths[id] + } + if parent.Type != "Folder" { + paths[id] = path.Join("/", name) + return paths[id] + } + parentPath := buildPath(parent.Value) + if parentPath == "" { + paths[id] = path.Join("/", name) + return paths[id] + } + paths[id] = path.Join(parentPath, name) + return paths[id] + } + + for id := range nameByID { + _ = buildPath(id) + } + + return paths, nil +} + +func (v *Vcenter) GetVMFolderPathFromLookup(vm mo.VirtualMachine, lookup FolderLookup) (string, bool) { + if vm.Parent == nil || lookup == nil { + return "", false + } + pathValue, ok := lookup[vm.Parent.Value] + return pathValue, ok +} + func (v *Vcenter) ConvertObjToMoVM(vmObj *object.VirtualMachine) (*mo.VirtualMachine, error) { // Use the InventoryPath to extract the datacenter name and VM path inventoryPath := vmObj.InventoryPath