package db

import (
	"context"
	"database/sql"
	"fmt"
	"log/slog"
	"sort"
	"strconv"
	"strings"

	"github.com/jmoiron/sqlx"
	_ "modernc.org/sqlite"
)

// SQLiteImportStats summarizes a one-shot SQLite-to-Postgres import.
//
// SourceDSN is the whitespace-trimmed DSN that was passed in. TablesSkipped
// counts source tables filtered out by shouldSkipSQLiteImportTable;
// TablesImported and RowsImported count what was copied.
type SQLiteImportStats struct {
	SourceDSN      string
	TablesImported int
	TablesSkipped  int
	RowsImported   int64
}

// postgresColumn is one row of information_schema.columns for a destination table.
type postgresColumn struct {
	Name     string `db:"column_name"`
	DataType string `db:"data_type"`
}

// importColumn pairs a source (SQLite) column with the destination (Postgres)
// column it is copied into, along with the destination data type used for
// value coercion.
type importColumn struct {
	SourceName      string
	DestinationName string
	DestinationType string
}

// ImportSQLiteIntoPostgres imports all supported tables from a SQLite database into a configured Postgres database.
//
// A nil ctx is replaced with context.Background() and a nil logger with
// slog.Default(). The destination must be non-nil and report a "pgx" or
// "postgres" driver name. Source tables rejected by
// shouldSkipSQLiteImportTable are counted as skipped; every remaining table
// must already exist in the destination or be creatable by
// ensureDestinationImportTable, which runs against the destination handle
// before the import transaction is opened. All table copies then run inside a
// single destination transaction that is rolled back if any copy fails.
//
// The returned stats reflect the progress made up to the point of return,
// including when an error is returned.
func ImportSQLiteIntoPostgres(ctx context.Context, logger *slog.Logger, destination *sqlx.DB, sqliteDSN string) (SQLiteImportStats, error) {
	stats := SQLiteImportStats{SourceDSN: strings.TrimSpace(sqliteDSN)}
	if ctx == nil {
		ctx = context.Background()
	}
	if logger == nil {
		logger = slog.Default()
	}
	if destination == nil {
		return stats, fmt.Errorf("destination database is nil")
	}
	driver := strings.ToLower(strings.TrimSpace(destination.DriverName()))
	if driver != "pgx" && driver != "postgres" {
		return stats, fmt.Errorf("sqlite import requires postgres destination; got %s", destination.DriverName())
	}
	if stats.SourceDSN == "" {
		return stats, fmt.Errorf("sqlite source path/DSN is required")
	}

	source, err := sqlx.Open("sqlite", normalizeSqliteDSN(sqliteDSN))
	if err != nil {
		return stats, fmt.Errorf("failed to open sqlite source: %w", err)
	}
	defer source.Close()
	if err := source.PingContext(ctx); err != nil {
		return stats, fmt.Errorf("failed to connect to sqlite source: %w", err)
	}

	tables, err := listSQLiteUserTables(ctx, source)
	if err != nil {
		return stats, err
	}
	sort.Strings(tables)
	if len(tables) == 0 {
		logger.Warn("sqlite import source has no user tables")
		return stats, nil
	}

	importTables := make([]string, 0, len(tables))
	for _, tableName := range tables {
		if shouldSkipSQLiteImportTable(tableName) {
			stats.TablesSkipped++
			continue
		}
		if err := ensureDestinationImportTable(ctx, destination, tableName); err != nil {
			return stats, err
		}
		importTables = append(importTables, tableName)
	}
	if len(importTables) == 0 {
		logger.Warn("sqlite import found no tables to import after filtering")
		return stats, nil
	}

	tx, err := destination.BeginTxx(ctx, nil)
	if err != nil {
		return stats, fmt.Errorf("failed to start postgres import transaction: %w", err)
	}
	// Rollback after a successful Commit only reports sql.ErrTxDone, so its
	// result is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	for _, tableName := range importTables {
		rowsCopied, err := copySQLiteTableIntoPostgres(ctx, source, tx, tableName)
		if err != nil {
			return stats, err
		}
		stats.TablesImported++
		stats.RowsImported += rowsCopied
		logger.Info("sqlite import copied table", "table", tableName, "rows", rowsCopied)
	}

	if err := tx.Commit(); err != nil {
		return stats, fmt.Errorf("failed to commit sqlite import transaction: %w", err)
	}
	return stats, nil
}

// listSQLiteUserTables returns the names of all entries of type 'table' in
// sqlite_master, ordered by name. Filtering of internal tables is left to the
// caller.
func listSQLiteUserTables(ctx context.Context, source *sqlx.DB) ([]string, error) {
	rows, err := source.QueryxContext(ctx, `
		SELECT name
		FROM sqlite_master
		WHERE type = 'table'
		ORDER BY name
	`)
	if err != nil {
		return nil, fmt.Errorf("failed to list sqlite tables: %w", err)
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("failed to scan sqlite table name: %w", err)
		}
		tables = append(tables, name)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed reading sqlite table names: %w", err)
	}
	return tables, nil
}

// shouldSkipSQLiteImportTable reports whether tableName must be excluded from
// the import. Blank names, names starting with "sqlite_" and the goose
// version table are skipped; the comparison is case-insensitive and ignores
// surrounding whitespace.
func shouldSkipSQLiteImportTable(tableName string) bool {
	normalized := strings.ToLower(strings.TrimSpace(tableName))
	if normalized == "" {
		return true
	}
	if strings.HasPrefix(normalized, "sqlite_") {
		return true
	}
	// Destination migration state is managed independently by goose.
	if normalized == "goose_db_version" {
		return true
	}
	return false
}

// ensureDestinationImportTable makes sure tableName is present in the
// destination. It returns nil when TableExists already reports the table;
// otherwise it dispatches on the table name to the matching Ensure* helper
// (defined elsewhere in this package). A name with no matching helper yields
// an error.
func ensureDestinationImportTable(ctx context.Context, destination *sqlx.DB, tableName string) error {
	if TableExists(ctx, destination, tableName) {
		return nil
	}
	switch {
	case strings.HasPrefix(tableName, "inventory_hourly_"):
		return EnsureSnapshotTable(ctx, destination, tableName)
	case strings.HasPrefix(tableName, "inventory_daily_summary_"),
		strings.HasPrefix(tableName, "inventory_monthly_summary_"):
		return EnsureSummaryTable(ctx, destination, tableName)
	case tableName == "snapshot_runs":
		return EnsureSnapshotRunTable(ctx, destination)
	case tableName == "vm_hourly_stats":
		return EnsureVmHourlyStats(ctx, destination)
	case tableName == "vm_lifecycle_cache":
		return EnsureVmLifecycleCache(ctx, destination)
	case tableName == "vm_daily_rollup":
		return EnsureVmDailyRollup(ctx, destination)
	case tableName == "vm_identity", tableName == "vm_renames":
		return EnsureVmIdentityTables(ctx, destination)
	case tableName == "vcenter_totals":
		return EnsureVcenterTotalsTable(ctx, destination)
	case tableName == "vcenter_latest_totals":
		return EnsureVcenterLatestTotalsTable(ctx, destination)
	case tableName == "vcenter_aggregate_totals":
		return EnsureVcenterAggregateTotalsTable(ctx, destination)
	default:
		return fmt.Errorf("source table %q does not exist in postgres and cannot be auto-created", tableName)
	}
}

// copySQLiteTableIntoPostgres replaces the contents of the destination table
// with the rows of the same-named source table and returns the number of rows
// inserted.
//
// Only columns present on both sides (matched case-insensitively) are copied.
// The destination table is truncated (RESTART IDENTITY CASCADE) inside
// destinationTX before rows are inserted one at a time through a prepared
// statement; afterwards sequences backing the table's columns are re-aligned
// by resetPostgresSerialColumns. On error the count of rows inserted so far is
// returned alongside the error.
func copySQLiteTableIntoPostgres(ctx context.Context, source *sqlx.DB, destinationTX *sqlx.Tx, tableName string) (int64, error) {
	sourceColumns, err := listSQLiteTableColumns(ctx, source, tableName)
	if err != nil {
		return 0, err
	}
	destinationColumns, err := listPostgresTableColumns(ctx, destinationTX, tableName)
	if err != nil {
		return 0, err
	}
	columns := intersectImportColumns(sourceColumns, destinationColumns)
	if len(columns) == 0 {
		return 0, fmt.Errorf("no overlapping columns between sqlite and postgres table %q", tableName)
	}

	truncateSQL := fmt.Sprintf(`TRUNCATE TABLE %s RESTART IDENTITY CASCADE`, quoteIdentifier(tableName))
	if _, err := destinationTX.ExecContext(ctx, truncateSQL); err != nil {
		return 0, fmt.Errorf("failed to truncate destination table %q: %w", tableName, err)
	}

	sourceColumnNames := make([]string, 0, len(columns))
	destinationColumnNames := make([]string, 0, len(columns))
	for _, col := range columns {
		sourceColumnNames = append(sourceColumnNames, col.SourceName)
		destinationColumnNames = append(destinationColumnNames, col.DestinationName)
	}

	selectSQL := fmt.Sprintf(
		`SELECT %s FROM %s`,
		joinQuotedIdentifiers(sourceColumnNames),
		quoteIdentifier(tableName),
	)
	rows, err := source.QueryxContext(ctx, selectSQL)
	if err != nil {
		return 0, fmt.Errorf("failed to query source table %q: %w", tableName, err)
	}
	defer rows.Close()

	insertSQL := fmt.Sprintf(
		`INSERT INTO %s (%s) VALUES (%s)`,
		quoteIdentifier(tableName),
		joinQuotedIdentifiers(destinationColumnNames),
		postgresPlaceholders(len(columns)),
	)
	stmt, err := destinationTX.PreparexContext(ctx, insertSQL)
	if err != nil {
		return 0, fmt.Errorf("failed to prepare insert for table %q: %w", tableName, err)
	}
	defer stmt.Close()

	// The scan and argument buffers are allocated once and reused for every
	// row: Scan overwrites each destination (including with nil for NULL), and
	// every args slot is reassigned before each Exec.
	rawValues := make([]any, len(columns))
	scanTargets := make([]any, len(columns))
	for i := range rawValues {
		scanTargets[i] = &rawValues[i]
	}
	args := make([]any, len(columns))

	var rowsCopied int64
	for rows.Next() {
		if err := rows.Scan(scanTargets...); err != nil {
			return rowsCopied, fmt.Errorf("failed to scan row from sqlite table %q: %w", tableName, err)
		}
		for i, col := range columns {
			args[i] = coerceSQLiteValueForPostgres(rawValues[i], col.DestinationType)
		}
		if _, err := stmt.ExecContext(ctx, args...); err != nil {
			return rowsCopied, fmt.Errorf("failed to insert row into postgres table %q: %w", tableName, err)
		}
		rowsCopied++
	}
	if err := rows.Err(); err != nil {
		return rowsCopied, fmt.Errorf("failed to read rows from sqlite table %q: %w", tableName, err)
	}

	if err := resetPostgresSerialColumns(ctx, destinationTX, tableName); err != nil {
		return rowsCopied, fmt.Errorf("failed to reset postgres sequences for table %q: %w", tableName, err)
	}
	return rowsCopied, nil
}

// listSQLiteTableColumns returns the column names of tableName in the order
// reported by PRAGMA table_info.
func listSQLiteTableColumns(ctx context.Context, source *sqlx.DB, tableName string) ([]string, error) {
	rows, err := source.QueryxContext(ctx, fmt.Sprintf(`PRAGMA table_info(%s)`, quoteIdentifier(tableName)))
	if err != nil {
		return nil, fmt.Errorf("failed to inspect sqlite table %q: %w", tableName, err)
	}
	defer rows.Close()

	var columns []string
	for rows.Next() {
		// All six result columns must be scanned even though only the name is kept.
		var (
			cid        int
			name       string
			columnType string
			notNull    int
			defaultVal sql.NullString
			pk         int
		)
		if err := rows.Scan(&cid, &name, &columnType, &notNull, &defaultVal, &pk); err != nil {
			return nil, fmt.Errorf("failed to scan sqlite columns for %q: %w", tableName, err)
		}
		columns = append(columns, name)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed reading sqlite columns for %q: %w", tableName, err)
	}
	return columns, nil
}

// listPostgresTableColumns returns the name and data type of every column of
// tableName in the public schema, in ordinal order. It returns an error when
// the lookup yields no columns.
func listPostgresTableColumns(ctx context.Context, destinationTX *sqlx.Tx, tableName string) ([]postgresColumn, error) {
	var columns []postgresColumn
	if err := destinationTX.SelectContext(ctx, &columns, `
		SELECT column_name, data_type
		FROM information_schema.columns
		WHERE table_schema = 'public' AND table_name = $1
		ORDER BY ordinal_position
	`, tableName); err != nil {
		return nil, fmt.Errorf("failed to inspect postgres columns for %q: %w", tableName, err)
	}
	if len(columns) == 0 {
		return nil, fmt.Errorf("postgres table %q has no columns", tableName)
	}
	return columns, nil
}

// intersectImportColumns returns, in destination column order, the columns
// that exist in both the source and the destination. Names are matched
// case-insensitively; each side keeps its own original spelling.
func intersectImportColumns(sourceColumns []string, destinationColumns []postgresColumn) []importColumn {
	sourceByLower := make(map[string]string, len(sourceColumns))
	for _, sourceColumn := range sourceColumns {
		sourceByLower[strings.ToLower(sourceColumn)] = sourceColumn
	}

	columns := make([]importColumn, 0, len(destinationColumns))
	for _, destinationColumn := range destinationColumns {
		sourceColumn, exists := sourceByLower[strings.ToLower(destinationColumn.Name)]
		if !exists {
			continue
		}
		columns = append(columns, importColumn{
			SourceName:      sourceColumn,
			DestinationName: destinationColumn.Name,
			DestinationType: destinationColumn.DataType,
		})
	}
	return columns
}

// resetPostgresSerialColumns re-aligns the sequences behind tableName's
// columns whose default starts with "nextval(". For each such column that
// pg_get_serial_sequence resolves to a sequence, the sequence is set so that
// its next value is MAX(column)+1 (or 1 when the table is empty).
func resetPostgresSerialColumns(ctx context.Context, destinationTX *sqlx.Tx, tableName string) error {
	type serialColumn struct {
		Name string `db:"column_name"`
	}

	var columns []serialColumn
	if err := destinationTX.SelectContext(ctx, &columns, `
		SELECT column_name
		FROM information_schema.columns
		WHERE table_schema = 'public'
		  AND table_name = $1
		  AND column_default LIKE 'nextval(%'
		ORDER BY ordinal_position
	`, tableName); err != nil {
		return err
	}

	tableRef := fmt.Sprintf("public.%s", quoteIdentifier(tableName))
	for _, column := range columns {
		var sequence sql.NullString
		if err := destinationTX.GetContext(ctx, &sequence, `SELECT pg_get_serial_sequence($1, $2)`, tableRef, column.Name); err != nil {
			return err
		}
		// No sequence is reported for this column; nothing to reset.
		if !sequence.Valid || strings.TrimSpace(sequence.String) == "" {
			continue
		}
		// is_called=false makes the next nextval() return exactly the value set here.
		setvalSQL := fmt.Sprintf(
			`SELECT setval($1, COALESCE((SELECT MAX(%s) FROM %s), 0) + 1, false)`,
			quoteIdentifier(column.Name),
			quoteIdentifier(tableName),
		)
		if _, err := destinationTX.ExecContext(ctx, setvalSQL, sequence.String); err != nil {
			return err
		}
	}
	return nil
}

// coerceSQLiteValueForPostgres adapts a value scanned from SQLite to the
// destination column type. nil stays nil; []byte becomes string unless the
// destination type is "bytea"; for "boolean" destinations the value is
// converted with coerceBoolValue when possible. Anything else is passed
// through unchanged.
func coerceSQLiteValueForPostgres(value any, destinationType string) any {
	if value == nil {
		return nil
	}
	destinationType = strings.ToLower(strings.TrimSpace(destinationType))
	if bytesValue, ok := value.([]byte); ok && destinationType != "bytea" {
		value = string(bytesValue)
	}
	if destinationType == "boolean" {
		if boolValue, ok := coerceBoolValue(value); ok {
			return boolValue
		}
	}
	return value
}

// coerceBoolValue converts value to a bool. The second result is false when
// the value's type or content cannot be interpreted as a boolean.
func coerceBoolValue(value any) (bool, bool) {
	switch v := value.(type) {
	case bool:
		return v, true
	case int64:
		return v != 0, true
	case int:
		return v != 0, true
	case float64:
		return v != 0, true
	case string:
		return parseBoolString(v)
	case []byte:
		return parseBoolString(string(v))
	default:
		return false, false
	}
}

// parseBoolString interprets raw as a boolean, ignoring case and surrounding
// whitespace. It accepts 1/t/true/y/yes/on and 0/f/false/n/no/off, and
// otherwise any base-10 integer (non-zero meaning true). The second result is
// false when raw matches none of these.
func parseBoolString(raw string) (bool, bool) {
	normalized := strings.ToLower(strings.TrimSpace(raw))
	switch normalized {
	case "1", "t", "true", "y", "yes", "on":
		return true, true
	case "0", "f", "false", "n", "no", "off":
		return false, true
	}
	if parsedInt, err := strconv.ParseInt(normalized, 10, 64); err == nil {
		return parsedInt != 0, true
	}
	return false, false
}

// quoteIdentifier wraps identifier in double quotes, doubling any embedded
// double quote.
func quoteIdentifier(identifier string) string {
	return `"` + strings.ReplaceAll(identifier, `"`, `""`) + `"`
}

// joinQuotedIdentifiers quotes each identifier with quoteIdentifier and joins
// them with ", ". An empty input yields the empty string.
func joinQuotedIdentifiers(identifiers []string) string {
	if len(identifiers) == 0 {
		return ""
	}
	quoted := make([]string, 0, len(identifiers))
	for _, identifier := range identifiers {
		quoted = append(quoted, quoteIdentifier(identifier))
	}
	return strings.Join(quoted, ", ")
}

// postgresPlaceholders returns "$1, $2, ..., $count", or the empty string
// when count is not positive.
func postgresPlaceholders(count int) string {
	if count <= 0 {
		return ""
	}
	placeholders := make([]string, 0, count)
	for i := 1; i <= count; i++ {
		placeholders = append(placeholders, "$"+strconv.Itoa(i))
	}
	return strings.Join(placeholders, ", ")
}