diff --git a/README.md b/README.md index 8146763..730e314 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,18 @@ The backfill command: - Rebuilds hourly/latest vCenter totals caches. - Recomputes daily/monthly rows for `vcenter_aggregate_totals` from registered summary snapshots. +If you want a one-time SQLite-to-Postgres import and exit, use: + +```shell +vctp -settings /path/to/vctp.yml -import-sqlite /path/to/legacy.sqlite3 +``` + +The import command: +- Requires `settings.database_driver: postgres`. +- Copies data from the SQLite source into matching Postgres tables. +- Auto-creates runtime tables (hourly/daily/monthly snapshot tables and cache tables) when needed. +- Replaces existing data in imported Postgres tables during the run. + ## Database Configuration By default the app uses SQLite and creates/opens `db.sqlite3`. @@ -123,6 +135,56 @@ settings: database_url: postgres://user:pass@localhost:5432/vctp?sslmode=disable ``` +### Initial PostgreSQL Setup +Create a dedicated PostgreSQL role and database (run as a PostgreSQL superuser): + +```sql +CREATE ROLE vctp_user LOGIN PASSWORD 'change-this-password'; +CREATE DATABASE vctp OWNER vctp_user; +``` + +Connect to the new database and grant privileges required for migrations and runtime table/index management: + +```sql +\c vctp +GRANT CONNECT, TEMP ON DATABASE vctp TO vctp_user; +GRANT USAGE, CREATE ON SCHEMA public TO vctp_user; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO vctp_user; +GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO vctp_user; +ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO vctp_user; +ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO vctp_user; +``` + +Recommended auth/network configuration: + +- Ensure PostgreSQL is listening on the expected interface/port in `postgresql.conf` (for example, `listen_addresses` and `port`). +- Allow vCTP connections in `pg_hba.conf`. Example entries: + +```conf +# local socket +local vctp vctp_user scram-sha-256 +# TCP from application subnet +host vctp vctp_user 10.0.0.0/24 scram-sha-256 +``` + +- Reload/restart PostgreSQL after config changes (`SELECT pg_reload_conf();` or your service manager). +- Ensure host firewall/network ACLs allow traffic to PostgreSQL (default `5432`). + +Example `vctp.yml` database settings: + +```yaml +settings: + database_driver: postgres + enable_experimental_postgres: true + database_url: postgres://vctp_user:change-this-password@db-hostname:5432/vctp?sslmode=disable +``` + +Validate connectivity before starting vCTP: + +```shell +psql "postgres://vctp_user:change-this-password@db-hostname:5432/vctp?sslmode=disable" +``` + PostgreSQL migrations live in `db/migrations_postgres`, while SQLite migrations remain in `db/migrations`. diff --git a/db/db.go b/db/db.go index ea2bc2b..25ae36c 100644 --- a/db/db.go +++ b/db/db.go @@ -119,6 +119,54 @@ func normalizeDriver(driver string) string { } } +// ResolveDriver determines the effective database driver. +// If driver is unset and DSN looks like PostgreSQL, it infers postgres. +func ResolveDriver(configuredDriver, dsn string) (driver string, inferredFromDSN bool, err error) { + normalized := strings.ToLower(strings.TrimSpace(configuredDriver)) + switch normalized { + case "sqlite3": + normalized = "sqlite" + case "postgresql": + normalized = "postgres" + } + + if normalized == "" { + if looksLikePostgresDSN(dsn) { + return "postgres", true, nil + } + return "sqlite", false, nil + } + + if normalized == "sqlite" && looksLikePostgresDSN(dsn) { + return "", false, fmt.Errorf("database_driver is sqlite but database_url looks like a postgres DSN; set settings.database_driver=postgres") + } + + return normalized, false, nil +} + +func looksLikePostgresDSN(dsn string) bool { + trimmed := strings.ToLower(strings.TrimSpace(dsn)) + if trimmed == "" { + return false + } + if strings.HasPrefix(trimmed, "postgres://") || strings.HasPrefix(trimmed, "postgresql://") { + return true + } + + // Also support key=value style PostgreSQL DSNs. + if strings.Contains(trimmed, "=") { + hasHost := strings.Contains(trimmed, "host=") + hasUser := strings.Contains(trimmed, "user=") + hasDB := strings.Contains(trimmed, "dbname=") + hasSSL := strings.Contains(trimmed, "sslmode=") + if (hasHost && hasUser) || (hasHost && hasDB) || (hasUser && hasDB) || (hasHost && hasSSL) { + return true + } + } + + return false +} + // ConvertToSQLParams is a utility function that generically converts a struct to a corresponding sqlc-generated struct func ConvertToSQLParams(input interface{}, output interface{}) { inputVal := reflect.ValueOf(input).Elem() diff --git a/db/driver_resolution_test.go b/db/driver_resolution_test.go new file mode 100644 index 0000000..4d2a93f --- /dev/null +++ b/db/driver_resolution_test.go @@ -0,0 +1,78 @@ +package db + +import "testing" + +func TestResolveDriver(t *testing.T) { + tests := []struct { + name string + configuredDriver string + dsn string + wantDriver string + wantInferred bool + wantErr bool + }{ + { + name: "explicit postgres uri", + configuredDriver: "postgres", + dsn: "postgres://user:pass@localhost:5432/vctp?sslmode=disable", + wantDriver: "postgres", + }, + { + name: "postgresql alias", + configuredDriver: "postgresql", + dsn: "postgres://user:pass@localhost:5432/vctp?sslmode=disable", + wantDriver: "postgres", + }, + { + name: "infer postgres uri", + dsn: "postgres://user:pass@localhost:5432/vctp?sslmode=disable", + wantDriver: "postgres", + wantInferred: true, + }, + { + name: "infer postgres key value dsn", + dsn: "host=localhost port=5432 user=postgres password=secret dbname=vctp sslmode=disable", + wantDriver: "postgres", + wantInferred: true, + }, + { + name: "default sqlite", + dsn: "/var/lib/vctp/db.sqlite3", + wantDriver: "sqlite", + }, + { + name: "sqlite alias", + configuredDriver: "sqlite3", + dsn: "/var/lib/vctp/db.sqlite3", + wantDriver: "sqlite", + }, + { + name: "reject sqlite postgres mismatch", + configuredDriver: "sqlite", + dsn: "postgres://user:pass@localhost:5432/vctp?sslmode=disable", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + driver, inferred, err := ResolveDriver(tc.configuredDriver, tc.dsn) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if driver != tc.wantDriver { + t.Fatalf("driver mismatch: got %q want %q", driver, tc.wantDriver) + } + if inferred != tc.wantInferred { + t.Fatalf("inferred mismatch: got %t want %t", inferred, tc.wantInferred) + } + }) + } +} diff --git a/db/local.go b/db/local.go index 944adc2..6bfe40e 100644 --- a/db/local.go +++ b/db/local.go @@ -2,6 +2,7 @@ package db import ( "database/sql" + "fmt" "log/slog" "strings" "vctp/db/queries" @@ -42,6 +43,9 @@ func (d *LocalDB) Close() error { } func newLocalDB(logger *slog.Logger, dsn string) (*LocalDB, error) { + if looksLikePostgresDSN(dsn) { + return nil, fmt.Errorf("database_driver is sqlite but database_url looks like a postgres DSN; set settings.database_driver=postgres") + } // TODO - work out if https://kerkour.com/sqlite-for-servers is possible without using sqlx /* diff --git a/db/sqlite_import.go b/db/sqlite_import.go new file mode 100644 index 0000000..d716845 --- /dev/null +++ b/db/sqlite_import.go @@ -0,0 +1,441 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "sort" + "strconv" + "strings" + + "github.com/jmoiron/sqlx" + _ "modernc.org/sqlite" +) + +// SQLiteImportStats summarizes a one-shot SQLite-to-Postgres import. +type SQLiteImportStats struct { + SourceDSN string + TablesImported int + TablesSkipped int + RowsImported int64 +} + +type postgresColumn struct { + Name string `db:"column_name"` + DataType string `db:"data_type"` +} + +type importColumn struct { + SourceName string + DestinationName string + DestinationType string +} + +// ImportSQLiteIntoPostgres imports all supported tables from a SQLite database into a configured Postgres database. +func ImportSQLiteIntoPostgres(ctx context.Context, logger *slog.Logger, destination *sqlx.DB, sqliteDSN string) (SQLiteImportStats, error) { + stats := SQLiteImportStats{SourceDSN: strings.TrimSpace(sqliteDSN)} + if ctx == nil { + ctx = context.Background() + } + if logger == nil { + logger = slog.Default() + } + if destination == nil { + return stats, fmt.Errorf("destination database is nil") + } + driver := strings.ToLower(strings.TrimSpace(destination.DriverName())) + if driver != "pgx" && driver != "postgres" { + return stats, fmt.Errorf("sqlite import requires postgres destination; got %s", destination.DriverName()) + } + if strings.TrimSpace(sqliteDSN) == "" { + return stats, fmt.Errorf("sqlite source path/DSN is required") + } + + source, err := sqlx.Open("sqlite", normalizeSqliteDSN(sqliteDSN)) + if err != nil { + return stats, fmt.Errorf("failed to open sqlite source: %w", err) + } + defer source.Close() + + if err := source.PingContext(ctx); err != nil { + return stats, fmt.Errorf("failed to connect to sqlite source: %w", err) + } + + tables, err := listSQLiteUserTables(ctx, source) + if err != nil { + return stats, err + } + sort.Strings(tables) + if len(tables) == 0 { + logger.Warn("sqlite import source has no user tables") + return stats, nil + } + + importTables := make([]string, 0, len(tables)) + for _, tableName := range tables { + if shouldSkipSQLiteImportTable(tableName) { + stats.TablesSkipped++ + continue + } + if err := ensureDestinationImportTable(ctx, destination, tableName); err != nil { + return stats, err + } + importTables = append(importTables, tableName) + } + if len(importTables) == 0 { + logger.Warn("sqlite import found no tables to import after filtering") + return stats, nil + } + + tx, err := destination.BeginTxx(ctx, nil) + if err != nil { + return stats, fmt.Errorf("failed to start postgres import transaction: %w", err) + } + defer tx.Rollback() + + for _, tableName := range importTables { + rowsCopied, err := copySQLiteTableIntoPostgres(ctx, source, tx, tableName) + if err != nil { + return stats, err + } + stats.TablesImported++ + stats.RowsImported += rowsCopied + logger.Info("sqlite import copied table", "table", tableName, "rows", rowsCopied) + } + + if err := tx.Commit(); err != nil { + return stats, fmt.Errorf("failed to commit sqlite import transaction: %w", err) + } + + return stats, nil +} + +func listSQLiteUserTables(ctx context.Context, source *sqlx.DB) ([]string, error) { + rows, err := source.QueryxContext(ctx, ` +SELECT name +FROM sqlite_master +WHERE type = 'table' +ORDER BY name +`) + if err != nil { + return nil, fmt.Errorf("failed to list sqlite tables: %w", err) + } + defer rows.Close() + + var tables []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, fmt.Errorf("failed to scan sqlite table name: %w", err) + } + tables = append(tables, name) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed reading sqlite table names: %w", err) + } + + return tables, nil +} + +func shouldSkipSQLiteImportTable(tableName string) bool { + normalized := strings.ToLower(strings.TrimSpace(tableName)) + if normalized == "" { + return true + } + if strings.HasPrefix(normalized, "sqlite_") { + return true + } + // Destination migration state is managed independently by goose. + if normalized == "goose_db_version" { + return true + } + return false +} + +func ensureDestinationImportTable(ctx context.Context, destination *sqlx.DB, tableName string) error { + if TableExists(ctx, destination, tableName) { + return nil + } + switch { + case strings.HasPrefix(tableName, "inventory_hourly_"): + return EnsureSnapshotTable(ctx, destination, tableName) + case strings.HasPrefix(tableName, "inventory_daily_summary_"), strings.HasPrefix(tableName, "inventory_monthly_summary_"): + return EnsureSummaryTable(ctx, destination, tableName) + case tableName == "snapshot_runs": + return EnsureSnapshotRunTable(ctx, destination) + case tableName == "vm_hourly_stats": + return EnsureVmHourlyStats(ctx, destination) + case tableName == "vm_lifecycle_cache": + return EnsureVmLifecycleCache(ctx, destination) + case tableName == "vm_daily_rollup": + return EnsureVmDailyRollup(ctx, destination) + case tableName == "vm_identity", tableName == "vm_renames": + return EnsureVmIdentityTables(ctx, destination) + case tableName == "vcenter_totals": + return EnsureVcenterTotalsTable(ctx, destination) + case tableName == "vcenter_latest_totals": + return EnsureVcenterLatestTotalsTable(ctx, destination) + case tableName == "vcenter_aggregate_totals": + return EnsureVcenterAggregateTotalsTable(ctx, destination) + default: + return fmt.Errorf("source table %q does not exist in postgres and cannot be auto-created", tableName) + } +} + +func copySQLiteTableIntoPostgres(ctx context.Context, source *sqlx.DB, destinationTX *sqlx.Tx, tableName string) (int64, error) { + sourceColumns, err := listSQLiteTableColumns(ctx, source, tableName) + if err != nil { + return 0, err + } + destinationColumns, err := listPostgresTableColumns(ctx, destinationTX, tableName) + if err != nil { + return 0, err + } + columns := intersectImportColumns(sourceColumns, destinationColumns) + if len(columns) == 0 { + return 0, fmt.Errorf("no overlapping columns between sqlite and postgres table %q", tableName) + } + + if _, err := destinationTX.ExecContext(ctx, fmt.Sprintf(`TRUNCATE TABLE %s RESTART IDENTITY CASCADE`, quoteIdentifier(tableName))); err != nil { + return 0, fmt.Errorf("failed to truncate destination table %q: %w", tableName, err) + } + + sourceColumnNames := make([]string, 0, len(columns)) + destinationColumnNames := make([]string, 0, len(columns)) + for _, col := range columns { + sourceColumnNames = append(sourceColumnNames, col.SourceName) + destinationColumnNames = append(destinationColumnNames, col.DestinationName) + } + + selectSQL := fmt.Sprintf( + `SELECT %s FROM %s`, + joinQuotedIdentifiers(sourceColumnNames), + quoteIdentifier(tableName), + ) + rows, err := source.QueryxContext(ctx, selectSQL) + if err != nil { + return 0, fmt.Errorf("failed to query source table %q: %w", tableName, err) + } + defer rows.Close() + + insertSQL := fmt.Sprintf( + `INSERT INTO %s (%s) VALUES (%s)`, + quoteIdentifier(tableName), + joinQuotedIdentifiers(destinationColumnNames), + postgresPlaceholders(len(columns)), + ) + stmt, err := destinationTX.PreparexContext(ctx, insertSQL) + if err != nil { + return 0, fmt.Errorf("failed to prepare insert for table %q: %w", tableName, err) + } + defer stmt.Close() + + var rowsCopied int64 + for rows.Next() { + rawValues := make([]interface{}, len(columns)) + scanTargets := make([]interface{}, len(columns)) + for i := range rawValues { + scanTargets[i] = &rawValues[i] + } + if err := rows.Scan(scanTargets...); err != nil { + return rowsCopied, fmt.Errorf("failed to scan row from sqlite table %q: %w", tableName, err) + } + args := make([]interface{}, len(columns)) + for i, col := range columns { + args[i] = coerceSQLiteValueForPostgres(rawValues[i], col.DestinationType) + } + if _, err := stmt.ExecContext(ctx, args...); err != nil { + return rowsCopied, fmt.Errorf("failed to insert row into postgres table %q: %w", tableName, err) + } + rowsCopied++ + } + if err := rows.Err(); err != nil { + return rowsCopied, fmt.Errorf("failed to read rows from sqlite table %q: %w", tableName, err) + } + + if err := resetPostgresSerialColumns(ctx, destinationTX, tableName); err != nil { + return rowsCopied, fmt.Errorf("failed to reset postgres sequences for table %q: %w", tableName, err) + } + return rowsCopied, nil +} + +func listSQLiteTableColumns(ctx context.Context, source *sqlx.DB, tableName string) ([]string, error) { + rows, err := source.QueryxContext(ctx, fmt.Sprintf(`PRAGMA table_info(%s)`, quoteIdentifier(tableName))) + if err != nil { + return nil, fmt.Errorf("failed to inspect sqlite table %q: %w", tableName, err) + } + defer rows.Close() + + columns := make([]string, 0) + for rows.Next() { + var ( + cid int + name string + columnType string + notNull int + defaultVal sql.NullString + pk int + ) + if err := rows.Scan(&cid, &name, &columnType, ¬Null, &defaultVal, &pk); err != nil { + return nil, fmt.Errorf("failed to scan sqlite columns for %q: %w", tableName, err) + } + columns = append(columns, name) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed reading sqlite columns for %q: %w", tableName, err) + } + return columns, nil +} + +func listPostgresTableColumns(ctx context.Context, destinationTX *sqlx.Tx, tableName string) ([]postgresColumn, error) { + var columns []postgresColumn + if err := destinationTX.SelectContext(ctx, &columns, ` +SELECT column_name, data_type +FROM information_schema.columns +WHERE table_schema = 'public' AND table_name = $1 +ORDER BY ordinal_position +`, tableName); err != nil { + return nil, fmt.Errorf("failed to inspect postgres columns for %q: %w", tableName, err) + } + if len(columns) == 0 { + return nil, fmt.Errorf("postgres table %q has no columns", tableName) + } + return columns, nil +} + +func intersectImportColumns(sourceColumns []string, destinationColumns []postgresColumn) []importColumn { + sourceByLower := make(map[string]string, len(sourceColumns)) + for _, sourceColumn := range sourceColumns { + sourceByLower[strings.ToLower(sourceColumn)] = sourceColumn + } + + columns := make([]importColumn, 0, len(destinationColumns)) + for _, destinationColumn := range destinationColumns { + sourceColumn, exists := sourceByLower[strings.ToLower(destinationColumn.Name)] + if !exists { + continue + } + columns = append(columns, importColumn{ + SourceName: sourceColumn, + DestinationName: destinationColumn.Name, + DestinationType: destinationColumn.DataType, + }) + } + return columns +} + +func resetPostgresSerialColumns(ctx context.Context, destinationTX *sqlx.Tx, tableName string) error { + type serialColumn struct { + Name string `db:"column_name"` + } + var columns []serialColumn + if err := destinationTX.SelectContext(ctx, &columns, ` +SELECT column_name +FROM information_schema.columns +WHERE table_schema = 'public' + AND table_name = $1 + AND column_default LIKE 'nextval(%' +ORDER BY ordinal_position +`, tableName); err != nil { + return err + } + + tableRef := fmt.Sprintf("public.%s", quoteIdentifier(tableName)) + for _, column := range columns { + var sequence sql.NullString + if err := destinationTX.GetContext(ctx, &sequence, `SELECT pg_get_serial_sequence($1, $2)`, tableRef, column.Name); err != nil { + return err + } + if !sequence.Valid || strings.TrimSpace(sequence.String) == "" { + continue + } + setvalSQL := fmt.Sprintf( + `SELECT setval($1, COALESCE((SELECT MAX(%s) FROM %s), 0) + 1, false)`, + quoteIdentifier(column.Name), + quoteIdentifier(tableName), + ) + if _, err := destinationTX.ExecContext(ctx, setvalSQL, sequence.String); err != nil { + return err + } + } + return nil +} + +func coerceSQLiteValueForPostgres(value interface{}, destinationType string) interface{} { + if value == nil { + return nil + } + destinationType = strings.ToLower(strings.TrimSpace(destinationType)) + + if bytesValue, ok := value.([]byte); ok && destinationType != "bytea" { + value = string(bytesValue) + } + + if destinationType == "boolean" { + if boolValue, ok := coerceBoolValue(value); ok { + return boolValue + } + } + + return value +} + +func coerceBoolValue(value interface{}) (bool, bool) { + switch v := value.(type) { + case bool: + return v, true + case int64: + return v != 0, true + case int: + return v != 0, true + case float64: + return v != 0, true + case string: + return parseBoolString(v) + case []byte: + return parseBoolString(string(v)) + default: + return false, false + } +} + +func parseBoolString(raw string) (bool, bool) { + normalized := strings.ToLower(strings.TrimSpace(raw)) + switch normalized { + case "1", "t", "true", "y", "yes", "on": + return true, true + case "0", "f", "false", "n", "no", "off": + return false, true + } + if parsedInt, err := strconv.ParseInt(normalized, 10, 64); err == nil { + return parsedInt != 0, true + } + return false, false +} + +func quoteIdentifier(identifier string) string { + return `"` + strings.ReplaceAll(identifier, `"`, `""`) + `"` +} + +func joinQuotedIdentifiers(identifiers []string) string { + if len(identifiers) == 0 { + return "" + } + quoted := make([]string, 0, len(identifiers)) + for _, identifier := range identifiers { + quoted = append(quoted, quoteIdentifier(identifier)) + } + return strings.Join(quoted, ", ") +} + +func postgresPlaceholders(count int) string { + if count <= 0 { + return "" + } + placeholders := make([]string, 0, count) + for i := 1; i <= count; i++ { + placeholders = append(placeholders, fmt.Sprintf("$%d", i)) + } + return strings.Join(placeholders, ", ") +} diff --git a/db/sqlite_import_test.go b/db/sqlite_import_test.go new file mode 100644 index 0000000..69e5d59 --- /dev/null +++ b/db/sqlite_import_test.go @@ -0,0 +1,80 @@ +package db + +import ( + "reflect" + "testing" +) + +func TestShouldSkipSQLiteImportTable(t *testing.T) { + tests := []struct { + name string + tableName string + wantSkip bool + }{ + {name: "empty", tableName: "", wantSkip: true}, + {name: "sqlite sequence", tableName: "sqlite_sequence", wantSkip: true}, + {name: "goose table", tableName: "goose_db_version", wantSkip: true}, + {name: "normal table", tableName: "Inventory", wantSkip: false}, + {name: "snapshot table", tableName: "inventory_hourly_1700000000", wantSkip: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := shouldSkipSQLiteImportTable(tc.tableName) + if got != tc.wantSkip { + t.Fatalf("skip mismatch: got %t want %t", got, tc.wantSkip) + } + }) + } +} + +func TestIntersectImportColumns(t *testing.T) { + source := []string{"Iid", "Name", "Vcenter", "CreationTime"} + dest := []postgresColumn{ + {Name: "Iid", DataType: "bigint"}, + {Name: "Name", DataType: "text"}, + {Name: "Vcenter", DataType: "text"}, + {Name: "DeletionTime", DataType: "bigint"}, + } + + got := intersectImportColumns(source, dest) + want := []importColumn{ + {SourceName: "Iid", DestinationName: "Iid", DestinationType: "bigint"}, + {SourceName: "Name", DestinationName: "Name", DestinationType: "text"}, + {SourceName: "Vcenter", DestinationName: "Vcenter", DestinationType: "text"}, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("intersect mismatch:\n got: %#v\nwant: %#v", got, want) + } +} + +func TestCoerceSQLiteValueForPostgresBoolean(t *testing.T) { + tests := []struct { + name string + input interface{} + destinationType string + want interface{} + }{ + {name: "string true", input: "true", destinationType: "boolean", want: true}, + {name: "string false", input: "0", destinationType: "boolean", want: false}, + {name: "int true", input: int64(1), destinationType: "boolean", want: true}, + {name: "int false", input: int64(0), destinationType: "boolean", want: false}, + {name: "bytes text", input: []byte("hello"), destinationType: "text", want: "hello"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := coerceSQLiteValueForPostgres(tc.input, tc.destinationType) + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("coerce mismatch: got %#v want %#v", got, tc.want) + } + }) + } +} + +func TestPostgresPlaceholders(t *testing.T) { + got := postgresPlaceholders(3) + if got != "$1, $2, $3" { + t.Fatalf("unexpected placeholders: %q", got) + } +} diff --git a/main.go b/main.go index 88fc5c9..d0bdb37 100644 --- a/main.go +++ b/main.go @@ -44,6 +44,7 @@ func main() { runInventory := flag.Bool("run-inventory", false, "Run a single inventory snapshot across all configured vCenters and exit") dbCleanup := flag.Bool("db-cleanup", false, "Run a one-time cleanup to drop low-value hourly snapshot indexes and exit") backfillVcenterCache := flag.Bool("backfill-vcenter-cache", false, "Run a one-time backfill for vcenter latest+aggregate cache tables and exit") + importSQLite := flag.String("import-sqlite", "", "Import a SQLite database file/DSN into the configured Postgres database and exit") flag.Parse() bootstrapLogger := log.New(log.LevelInfo, log.OutputText) @@ -68,18 +69,19 @@ func main() { warnDeprecatedPollingSettings(logger, s.Values) // Configure database - dbDriver := strings.TrimSpace(s.Values.Settings.DatabaseDriver) - if dbDriver == "" { - dbDriver = "sqlite" - } - normalizedDriver := strings.ToLower(strings.TrimSpace(dbDriver)) - if normalizedDriver == "" || normalizedDriver == "sqlite3" { - normalizedDriver = "sqlite" - } dbURL := strings.TrimSpace(s.Values.Settings.DatabaseURL) + normalizedDriver, inferredFromDSN, err := db.ResolveDriver(s.Values.Settings.DatabaseDriver, dbURL) + if err != nil { + logger.Error("Invalid database configuration", "error", err) + os.Exit(1) + } + if inferredFromDSN { + logger.Warn("database_driver is unset; inferred postgres from database_url") + } if dbURL == "" && normalizedDriver == "sqlite" { dbURL = utils.GetFilePath("db.sqlite3") } + logger.Info("Effective database driver resolved", "driver", normalizedDriver) database, err := db.New(logger, db.Config{ Driver: normalizedDriver, @@ -97,6 +99,25 @@ func main() { logger.Error("failed to migrate database", "error", err) os.Exit(1) } + if strings.TrimSpace(*importSQLite) != "" { + if normalizedDriver != "postgres" { + logger.Error("sqlite import requires settings.database_driver=postgres") + os.Exit(1) + } + logger.Info("starting one-time sqlite import into postgres", "sqlite_source", strings.TrimSpace(*importSQLite)) + stats, err := db.ImportSQLiteIntoPostgres(ctx, logger, database.DB(), strings.TrimSpace(*importSQLite)) + if err != nil { + logger.Error("failed to import sqlite database into postgres", "error", err) + os.Exit(1) + } + logger.Info("completed sqlite import into postgres", + "sqlite_source", stats.SourceDSN, + "tables_imported", stats.TablesImported, + "tables_skipped", stats.TablesSkipped, + "rows_imported", stats.RowsImported, + ) + return + } if *dbCleanup { dropped, err := db.CleanupHourlySnapshotIndexes(ctx, database.DB()) if err != nil { @@ -332,7 +353,7 @@ func main() { gocron.CronJob(snapshotCleanupCron, false), gocron.NewTask(func() { ct.RunSnapshotCleanup(ctx, logger) - if strings.EqualFold(s.Values.Settings.DatabaseDriver, "sqlite") { + if normalizedDriver == "sqlite" { logger.Info("Performing sqlite VACUUM after snapshot cleanup") if _, err := ct.Database.DB().ExecContext(ctx, "VACUUM"); err != nil { logger.Warn("VACUUM failed after snapshot cleanup", "error", err)