improve postgres support
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
441
db/sqlite_import.go
Normal file
441
db/sqlite_import.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// SQLiteImportStats summarizes a one-shot SQLite-to-Postgres import.
type SQLiteImportStats struct {
	SourceDSN      string // trimmed path/DSN of the SQLite source database
	TablesImported int    // number of tables whose rows were copied
	TablesSkipped  int    // tables filtered out (sqlite_* internals, goose state)
	RowsImported   int64  // total rows inserted across all imported tables
}

// postgresColumn mirrors one row of information_schema.columns for the
// destination database.
type postgresColumn struct {
	Name     string `db:"column_name"`
	DataType string `db:"data_type"`
}

// importColumn pairs a SQLite source column with its Postgres destination
// column; DestinationType drives per-value coercion during the copy.
type importColumn struct {
	SourceName      string
	DestinationName string
	DestinationType string
}
|
||||
|
||||
// ImportSQLiteIntoPostgres imports all supported tables from a SQLite database into a configured Postgres database.
//
// Eligible source tables are truncated-and-reloaded in the destination inside
// a single transaction, so the row copy is all-or-nothing. SQLite-internal
// tables and goose migration state are skipped; destination tables that do
// not exist yet are created via the schema Ensure* helpers before the
// transaction starts. Returns summary statistics plus the first error hit.
func ImportSQLiteIntoPostgres(ctx context.Context, logger *slog.Logger, destination *sqlx.DB, sqliteDSN string) (SQLiteImportStats, error) {
	stats := SQLiteImportStats{SourceDSN: strings.TrimSpace(sqliteDSN)}
	// Defensive defaults so callers passing nil ctx/logger still work.
	if ctx == nil {
		ctx = context.Background()
	}
	if logger == nil {
		logger = slog.Default()
	}
	if destination == nil {
		return stats, fmt.Errorf("destination database is nil")
	}
	// Refuse to run against anything but a Postgres destination.
	driver := strings.ToLower(strings.TrimSpace(destination.DriverName()))
	if driver != "pgx" && driver != "postgres" {
		return stats, fmt.Errorf("sqlite import requires postgres destination; got %s", destination.DriverName())
	}
	if strings.TrimSpace(sqliteDSN) == "" {
		return stats, fmt.Errorf("sqlite source path/DSN is required")
	}

	source, err := sqlx.Open("sqlite", normalizeSqliteDSN(sqliteDSN))
	if err != nil {
		return stats, fmt.Errorf("failed to open sqlite source: %w", err)
	}
	defer source.Close()

	if err := source.PingContext(ctx); err != nil {
		return stats, fmt.Errorf("failed to connect to sqlite source: %w", err)
	}

	tables, err := listSQLiteUserTables(ctx, source)
	if err != nil {
		return stats, err
	}
	sort.Strings(tables)
	if len(tables) == 0 {
		logger.Warn("sqlite import source has no user tables")
		return stats, nil
	}

	// First pass: drop ineligible tables and make sure every remaining table
	// exists in Postgres. DDL happens here, outside the import transaction.
	importTables := make([]string, 0, len(tables))
	for _, tableName := range tables {
		if shouldSkipSQLiteImportTable(tableName) {
			stats.TablesSkipped++
			continue
		}
		if err := ensureDestinationImportTable(ctx, destination, tableName); err != nil {
			return stats, err
		}
		importTables = append(importTables, tableName)
	}
	if len(importTables) == 0 {
		logger.Warn("sqlite import found no tables to import after filtering")
		return stats, nil
	}

	// Second pass: copy every table inside one transaction so a failure on
	// any table leaves the destination data untouched.
	tx, err := destination.BeginTxx(ctx, nil)
	if err != nil {
		return stats, fmt.Errorf("failed to start postgres import transaction: %w", err)
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	for _, tableName := range importTables {
		rowsCopied, err := copySQLiteTableIntoPostgres(ctx, source, tx, tableName)
		if err != nil {
			return stats, err
		}
		stats.TablesImported++
		stats.RowsImported += rowsCopied
		logger.Info("sqlite import copied table", "table", tableName, "rows", rowsCopied)
	}

	if err := tx.Commit(); err != nil {
		return stats, fmt.Errorf("failed to commit sqlite import transaction: %w", err)
	}

	return stats, nil
}
|
||||
|
||||
func listSQLiteUserTables(ctx context.Context, source *sqlx.DB) ([]string, error) {
|
||||
rows, err := source.QueryxContext(ctx, `
|
||||
SELECT name
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table'
|
||||
ORDER BY name
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list sqlite tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan sqlite table name: %w", err)
|
||||
}
|
||||
tables = append(tables, name)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed reading sqlite table names: %w", err)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// shouldSkipSQLiteImportTable reports whether a source table must not be
// copied into Postgres: blank names, SQLite's internal sqlite_* tables, and
// the goose migration bookkeeping table are all excluded. The check is
// case-insensitive and ignores surrounding whitespace.
func shouldSkipSQLiteImportTable(tableName string) bool {
	name := strings.ToLower(strings.TrimSpace(tableName))
	switch {
	case name == "":
		return true
	case strings.HasPrefix(name, "sqlite_"):
		return true
	case name == "goose_db_version":
		// Destination migration state is managed independently by goose.
		return true
	default:
		return false
	}
}
|
||||
|
||||
func ensureDestinationImportTable(ctx context.Context, destination *sqlx.DB, tableName string) error {
|
||||
if TableExists(ctx, destination, tableName) {
|
||||
return nil
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(tableName, "inventory_hourly_"):
|
||||
return EnsureSnapshotTable(ctx, destination, tableName)
|
||||
case strings.HasPrefix(tableName, "inventory_daily_summary_"), strings.HasPrefix(tableName, "inventory_monthly_summary_"):
|
||||
return EnsureSummaryTable(ctx, destination, tableName)
|
||||
case tableName == "snapshot_runs":
|
||||
return EnsureSnapshotRunTable(ctx, destination)
|
||||
case tableName == "vm_hourly_stats":
|
||||
return EnsureVmHourlyStats(ctx, destination)
|
||||
case tableName == "vm_lifecycle_cache":
|
||||
return EnsureVmLifecycleCache(ctx, destination)
|
||||
case tableName == "vm_daily_rollup":
|
||||
return EnsureVmDailyRollup(ctx, destination)
|
||||
case tableName == "vm_identity", tableName == "vm_renames":
|
||||
return EnsureVmIdentityTables(ctx, destination)
|
||||
case tableName == "vcenter_totals":
|
||||
return EnsureVcenterTotalsTable(ctx, destination)
|
||||
case tableName == "vcenter_latest_totals":
|
||||
return EnsureVcenterLatestTotalsTable(ctx, destination)
|
||||
case tableName == "vcenter_aggregate_totals":
|
||||
return EnsureVcenterAggregateTotalsTable(ctx, destination)
|
||||
default:
|
||||
return fmt.Errorf("source table %q does not exist in postgres and cannot be auto-created", tableName)
|
||||
}
|
||||
}
|
||||
|
||||
func copySQLiteTableIntoPostgres(ctx context.Context, source *sqlx.DB, destinationTX *sqlx.Tx, tableName string) (int64, error) {
|
||||
sourceColumns, err := listSQLiteTableColumns(ctx, source, tableName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
destinationColumns, err := listPostgresTableColumns(ctx, destinationTX, tableName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
columns := intersectImportColumns(sourceColumns, destinationColumns)
|
||||
if len(columns) == 0 {
|
||||
return 0, fmt.Errorf("no overlapping columns between sqlite and postgres table %q", tableName)
|
||||
}
|
||||
|
||||
if _, err := destinationTX.ExecContext(ctx, fmt.Sprintf(`TRUNCATE TABLE %s RESTART IDENTITY CASCADE`, quoteIdentifier(tableName))); err != nil {
|
||||
return 0, fmt.Errorf("failed to truncate destination table %q: %w", tableName, err)
|
||||
}
|
||||
|
||||
sourceColumnNames := make([]string, 0, len(columns))
|
||||
destinationColumnNames := make([]string, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
sourceColumnNames = append(sourceColumnNames, col.SourceName)
|
||||
destinationColumnNames = append(destinationColumnNames, col.DestinationName)
|
||||
}
|
||||
|
||||
selectSQL := fmt.Sprintf(
|
||||
`SELECT %s FROM %s`,
|
||||
joinQuotedIdentifiers(sourceColumnNames),
|
||||
quoteIdentifier(tableName),
|
||||
)
|
||||
rows, err := source.QueryxContext(ctx, selectSQL)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to query source table %q: %w", tableName, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
insertSQL := fmt.Sprintf(
|
||||
`INSERT INTO %s (%s) VALUES (%s)`,
|
||||
quoteIdentifier(tableName),
|
||||
joinQuotedIdentifiers(destinationColumnNames),
|
||||
postgresPlaceholders(len(columns)),
|
||||
)
|
||||
stmt, err := destinationTX.PreparexContext(ctx, insertSQL)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to prepare insert for table %q: %w", tableName, err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
var rowsCopied int64
|
||||
for rows.Next() {
|
||||
rawValues := make([]interface{}, len(columns))
|
||||
scanTargets := make([]interface{}, len(columns))
|
||||
for i := range rawValues {
|
||||
scanTargets[i] = &rawValues[i]
|
||||
}
|
||||
if err := rows.Scan(scanTargets...); err != nil {
|
||||
return rowsCopied, fmt.Errorf("failed to scan row from sqlite table %q: %w", tableName, err)
|
||||
}
|
||||
args := make([]interface{}, len(columns))
|
||||
for i, col := range columns {
|
||||
args[i] = coerceSQLiteValueForPostgres(rawValues[i], col.DestinationType)
|
||||
}
|
||||
if _, err := stmt.ExecContext(ctx, args...); err != nil {
|
||||
return rowsCopied, fmt.Errorf("failed to insert row into postgres table %q: %w", tableName, err)
|
||||
}
|
||||
rowsCopied++
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return rowsCopied, fmt.Errorf("failed to read rows from sqlite table %q: %w", tableName, err)
|
||||
}
|
||||
|
||||
if err := resetPostgresSerialColumns(ctx, destinationTX, tableName); err != nil {
|
||||
return rowsCopied, fmt.Errorf("failed to reset postgres sequences for table %q: %w", tableName, err)
|
||||
}
|
||||
return rowsCopied, nil
|
||||
}
|
||||
|
||||
func listSQLiteTableColumns(ctx context.Context, source *sqlx.DB, tableName string) ([]string, error) {
|
||||
rows, err := source.QueryxContext(ctx, fmt.Sprintf(`PRAGMA table_info(%s)`, quoteIdentifier(tableName)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to inspect sqlite table %q: %w", tableName, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var (
|
||||
cid int
|
||||
name string
|
||||
columnType string
|
||||
notNull int
|
||||
defaultVal sql.NullString
|
||||
pk int
|
||||
)
|
||||
if err := rows.Scan(&cid, &name, &columnType, ¬Null, &defaultVal, &pk); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan sqlite columns for %q: %w", tableName, err)
|
||||
}
|
||||
columns = append(columns, name)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed reading sqlite columns for %q: %w", tableName, err)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func listPostgresTableColumns(ctx context.Context, destinationTX *sqlx.Tx, tableName string) ([]postgresColumn, error) {
|
||||
var columns []postgresColumn
|
||||
if err := destinationTX.SelectContext(ctx, &columns, `
|
||||
SELECT column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public' AND table_name = $1
|
||||
ORDER BY ordinal_position
|
||||
`, tableName); err != nil {
|
||||
return nil, fmt.Errorf("failed to inspect postgres columns for %q: %w", tableName, err)
|
||||
}
|
||||
if len(columns) == 0 {
|
||||
return nil, fmt.Errorf("postgres table %q has no columns", tableName)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func intersectImportColumns(sourceColumns []string, destinationColumns []postgresColumn) []importColumn {
|
||||
sourceByLower := make(map[string]string, len(sourceColumns))
|
||||
for _, sourceColumn := range sourceColumns {
|
||||
sourceByLower[strings.ToLower(sourceColumn)] = sourceColumn
|
||||
}
|
||||
|
||||
columns := make([]importColumn, 0, len(destinationColumns))
|
||||
for _, destinationColumn := range destinationColumns {
|
||||
sourceColumn, exists := sourceByLower[strings.ToLower(destinationColumn.Name)]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, importColumn{
|
||||
SourceName: sourceColumn,
|
||||
DestinationName: destinationColumn.Name,
|
||||
DestinationType: destinationColumn.DataType,
|
||||
})
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// resetPostgresSerialColumns advances every serial/identity sequence of
// tableName so the next generated value is MAX(column)+1, preventing
// duplicate-key errors on inserts that follow a bulk import of explicit IDs.
func resetPostgresSerialColumns(ctx context.Context, destinationTX *sqlx.Tx, tableName string) error {
	type serialColumn struct {
		Name string `db:"column_name"`
	}
	// Serial columns are identified by a nextval(...) default expression.
	var columns []serialColumn
	if err := destinationTX.SelectContext(ctx, &columns, `
		SELECT column_name
		FROM information_schema.columns
		WHERE table_schema = 'public'
		  AND table_name = $1
		  AND column_default LIKE 'nextval(%'
		ORDER BY ordinal_position
	`, tableName); err != nil {
		return err
	}

	tableRef := fmt.Sprintf("public.%s", quoteIdentifier(tableName))
	for _, column := range columns {
		// Resolve the sequence backing this column. A NULL/empty result
		// means no owned sequence could be resolved, so there is nothing
		// to reset for this column.
		var sequence sql.NullString
		if err := destinationTX.GetContext(ctx, &sequence, `SELECT pg_get_serial_sequence($1, $2)`, tableRef, column.Name); err != nil {
			return err
		}
		if !sequence.Valid || strings.TrimSpace(sequence.String) == "" {
			continue
		}
		// setval(seq, max+1, false): with is_called=false the next
		// nextval() returns exactly MAX(column)+1 (or 1 for an empty table).
		// Only the sequence name is a bind parameter; the column and table
		// identifiers must be spliced in (quoted) since identifiers cannot
		// be parameterized.
		setvalSQL := fmt.Sprintf(
			`SELECT setval($1, COALESCE((SELECT MAX(%s) FROM %s), 0) + 1, false)`,
			quoteIdentifier(column.Name),
			quoteIdentifier(tableName),
		)
		if _, err := destinationTX.ExecContext(ctx, setvalSQL, sequence.String); err != nil {
			return err
		}
	}
	return nil
}
|
||||
|
||||
func coerceSQLiteValueForPostgres(value interface{}, destinationType string) interface{} {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
destinationType = strings.ToLower(strings.TrimSpace(destinationType))
|
||||
|
||||
if bytesValue, ok := value.([]byte); ok && destinationType != "bytea" {
|
||||
value = string(bytesValue)
|
||||
}
|
||||
|
||||
if destinationType == "boolean" {
|
||||
if boolValue, ok := coerceBoolValue(value); ok {
|
||||
return boolValue
|
||||
}
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func coerceBoolValue(value interface{}) (bool, bool) {
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
return v, true
|
||||
case int64:
|
||||
return v != 0, true
|
||||
case int:
|
||||
return v != 0, true
|
||||
case float64:
|
||||
return v != 0, true
|
||||
case string:
|
||||
return parseBoolString(v)
|
||||
case []byte:
|
||||
return parseBoolString(string(v))
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
// parseBoolString interprets common textual boolean spellings ("t"/"f",
// "yes"/"no", "on"/"off", "1"/"0", ...) plus arbitrary base-10 integers,
// where any non-zero value is true. Matching is case-insensitive and
// whitespace-trimmed; the second return reports whether the string was
// recognized at all.
func parseBoolString(raw string) (bool, bool) {
	normalized := strings.ToLower(strings.TrimSpace(raw))
	switch normalized {
	case "1", "t", "true", "y", "yes", "on":
		return true, true
	case "0", "f", "false", "n", "no", "off":
		return false, true
	}

	// Fall back to numeric interpretation: any integer other than zero
	// counts as true.
	parsed, err := strconv.ParseInt(normalized, 10, 64)
	if err != nil {
		return false, false
	}
	return parsed != 0, true
}
|
||||
|
||||
// quoteIdentifier wraps identifier in double quotes, doubling any embedded
// quote characters per the SQL standard so the result is injection-safe.
func quoteIdentifier(identifier string) string {
	var b strings.Builder
	b.Grow(len(identifier) + 2)
	b.WriteByte('"')
	for _, r := range identifier {
		if r == '"' {
			b.WriteString(`""`)
		} else {
			b.WriteRune(r)
		}
	}
	b.WriteByte('"')
	return b.String()
}
|
||||
|
||||
func joinQuotedIdentifiers(identifiers []string) string {
|
||||
if len(identifiers) == 0 {
|
||||
return ""
|
||||
}
|
||||
quoted := make([]string, 0, len(identifiers))
|
||||
for _, identifier := range identifiers {
|
||||
quoted = append(quoted, quoteIdentifier(identifier))
|
||||
}
|
||||
return strings.Join(quoted, ", ")
|
||||
}
|
||||
|
||||
// postgresPlaceholders renders the positional bind list "$1, $2, ..., $count"
// for a VALUES clause. A non-positive count yields the empty string.
func postgresPlaceholders(count int) string {
	if count <= 0 {
		return ""
	}
	var b strings.Builder
	for i := 1; i <= count; i++ {
		if i > 1 {
			b.WriteString(", ")
		}
		b.WriteByte('$')
		b.WriteString(strconv.Itoa(i))
	}
	return b.String()
}
|
||||
Reference in New Issue
Block a user