package db import ( "database/sql" "embed" "fmt" "log/slog" "reflect" "strings" "github.com/jmoiron/sqlx" "github.com/pressly/goose/v3" ) //go:embed migrations migrations_postgres var migrations embed.FS type Database interface { DB() *sqlx.DB Queries() Querier Logger() *slog.Logger Close() error } type Config struct { Driver string DSN string EnableExperimentalPostgres bool } func New(logger *slog.Logger, cfg Config) (Database, error) { driver := normalizeDriver(cfg.Driver) switch driver { case "sqlite": db, err := newLocalDB(logger, cfg.DSN) if err != nil { return nil, err } if err = db.db.Ping(); err != nil { return nil, err } return db, nil case "postgres": // The sqlc query set is SQLite-first. Keep Postgres opt-in until full parity is validated. if !cfg.EnableExperimentalPostgres { return nil, fmt.Errorf("postgres driver is disabled by default; set settings.enable_experimental_postgres=true to enable experimental mode") } db, err := newPostgresDB(logger, cfg.DSN) if err != nil { return nil, err } if err = db.db.Ping(); err != nil { return nil, err } return db, nil default: return nil, fmt.Errorf("unsupported database driver: %s", cfg.Driver) } } // Migrate runs the migrations on the database. 
func Migrate(db Database, driver string) error { driver = normalizeDriver(driver) goose.SetBaseFS(migrations) switch driver { case "sqlite": if err := goose.SetDialect("sqlite3"); err != nil { return fmt.Errorf("failed to set sqlite dialect: %w", err) } if err := goose.Up(db.DB().DB, "migrations"); err != nil { return fmt.Errorf("failed to run sqlite migrations: %w", err) } case "postgres": if err := goose.SetDialect("postgres"); err != nil { return fmt.Errorf("failed to set postgres dialect: %w", err) } if err := goose.Up(db.DB().DB, "migrations_postgres"); err != nil { return fmt.Errorf("failed to run postgres migrations: %w", err) } default: return fmt.Errorf("unsupported database driver: %s", driver) } // TODO - replace with goose /* driver, err := sqlite3.WithInstance(db.DB(), &sqlite3.Config{}) if err != nil { return fmt.Errorf("failed to create database driver: %w", err) } iofsDriver, err := iofs.New(migrations, "migrations") if err != nil { return fmt.Errorf("failed to create iofs: %w", err) } defer iofsDriver.Close() m, err := migrate.NewWithInstance("iofs", iofsDriver, "sqlite3", driver) if err != nil { return fmt.Errorf("failed to create migration: %w", err) } return m.Up() */ return nil } func normalizeDriver(driver string) string { normalized := strings.ToLower(strings.TrimSpace(driver)) switch normalized { case "", "sqlite3": return "sqlite" case "postgresql": return "postgres" default: return normalized } } // ResolveDriver determines the effective database driver. // If driver is unset and DSN looks like PostgreSQL, it infers postgres. 
func ResolveDriver(configuredDriver, dsn string) (driver string, inferredFromDSN bool, err error) {
	normalized := strings.ToLower(strings.TrimSpace(configuredDriver))
	switch normalized {
	case "":
		if looksLikePostgresDSN(dsn) {
			return "postgres", true, nil
		}
		return "sqlite", false, nil
	case "sqlite3":
		normalized = "sqlite"
	case "postgresql":
		normalized = "postgres"
	}

	if normalized == "sqlite" && looksLikePostgresDSN(dsn) {
		return "", false, fmt.Errorf("database_driver is sqlite but database_url looks like a postgres DSN; set settings.database_driver=postgres")
	}
	return normalized, false, nil
}

// looksLikePostgresDSN reports whether dsn appears to be a PostgreSQL
// connection string. It matches a postgres:// or postgresql:// URL, or a
// key=value DSN that contains at least one of these pairs of keys:
// host+user, host+dbname, user+dbname, or host+sslmode. The check is a
// case-insensitive substring heuristic, not a full parse.
func looksLikePostgresDSN(dsn string) bool {
	trimmed := strings.ToLower(strings.TrimSpace(dsn))
	if trimmed == "" {
		return false
	}
	if strings.HasPrefix(trimmed, "postgres://") || strings.HasPrefix(trimmed, "postgresql://") {
		return true
	}

	// Key=value style PostgreSQL DSNs.
	hasHost := strings.Contains(trimmed, "host=")
	hasUser := strings.Contains(trimmed, "user=")
	hasDB := strings.Contains(trimmed, "dbname=")
	hasSSL := strings.Contains(trimmed, "sslmode=")
	return (hasHost && hasUser) || (hasHost && hasDB) || (hasUser && hasDB) || (hasHost && hasSSL)
}

// Reflect types of the output field kinds ConvertToSQLParams knows how to fill.
var (
	sqlParamNullStringType  = reflect.TypeOf(sql.NullString{})
	sqlParamNullInt64Type   = reflect.TypeOf(sql.NullInt64{})
	sqlParamNullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
	sqlParamStringType      = reflect.TypeOf("")
	sqlParamInt64Type       = reflect.TypeOf(int64(0))
	sqlParamFloat64Type     = reflect.TypeOf(float64(0))
)

// ConvertToSQLParams copies fields from input into output, typically a
// sqlc-generated params struct. Fields are matched by name.
//
// output must be a non-nil pointer to a struct; input may be a struct or a
// non-nil pointer to one. Either condition being violated causes a panic.
//
// Only output fields of type sql.NullString, sql.NullInt64, sql.NullFloat64,
// string, int64 and float64 are filled. Output fields that are unexported,
// missing from input, or of any other type are left untouched.
//
// Input fields may be pointers:
//   - a nil pointer yields an invalid (NULL) sql.Null* value, or the zero
//     value for plain fields;
//   - a non-nil pointer is dereferenced.
//
// For non-pointer numeric inputs, 0 is stored as NULL in sql.NullInt64 and
// sql.NullFloat64. To store an explicit zero, use a pointer field.
//
// An input whose kind does not fit the output field (for example an int
// field copied into a string) leaves the output field unchanged.
func ConvertToSQLParams(input any, output any) {
	inputVal := reflect.Indirect(reflect.ValueOf(input))
	outputVal := reflect.ValueOf(output).Elem()
	outputType := outputVal.Type()

	for i := 0; i < outputVal.NumField(); i++ {
		outputField := outputVal.Field(i)
		inputField := inputVal.FieldByName(outputType.Field(i).Name)
		if !inputField.IsValid() || !outputField.CanSet() {
			continue
		}
		setSQLParam(outputField, inputField)
	}
}

// setSQLParam assigns in to the settable field out, converting to the
// sql.Null* wrapper or plain type that out declares.
func setSQLParam(out, in reflect.Value) {
	// A non-pointer numeric zero means "unset" (historical behaviour); a
	// non-nil pointer lets callers store an explicit zero.
	zeroIsNull := in.Kind() != reflect.Ptr
	val, isNil := derefSQLParam(in)

	switch out.Type() {
	case sqlParamNullStringType:
		if isNil {
			out.Set(reflect.ValueOf(sql.NullString{}))
		} else if s, ok := sqlParamString(val); ok {
			out.Set(reflect.ValueOf(sql.NullString{String: s, Valid: true}))
		}
	case sqlParamNullInt64Type:
		n, ok := sqlParamInt(val)
		switch {
		case isNil || (ok && n == 0 && zeroIsNull):
			out.Set(reflect.ValueOf(sql.NullInt64{}))
		case ok:
			out.Set(reflect.ValueOf(sql.NullInt64{Int64: n, Valid: true}))
		}
	case sqlParamNullFloat64Type:
		f, ok := sqlParamFloat(val)
		switch {
		case isNil || (ok && f == 0 && zeroIsNull):
			out.Set(reflect.ValueOf(sql.NullFloat64{}))
		case ok:
			out.Set(reflect.ValueOf(sql.NullFloat64{Float64: f, Valid: true}))
		}
	case sqlParamStringType:
		if isNil {
			out.SetString("")
		} else if s, ok := sqlParamString(val); ok {
			out.SetString(s)
		}
	case sqlParamInt64Type:
		if isNil {
			out.SetInt(0)
		} else if n, ok := sqlParamInt(val); ok {
			out.SetInt(n)
		}
	case sqlParamFloat64Type:
		if isNil {
			out.SetFloat(0)
		} else if f, ok := sqlParamFloat(val); ok {
			out.SetFloat(f)
		}
	}
}

// derefSQLParam follows pointers and interfaces to the underlying value.
// The boolean reports whether a nil pointer or nil interface was reached.
func derefSQLParam(v reflect.Value) (reflect.Value, bool) {
	for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
		if v.IsNil() {
			return reflect.Value{}, true
		}
		v = v.Elem()
	}
	return v, false
}

// sqlParamString extracts a string from a value of string kind.
func sqlParamString(v reflect.Value) (string, bool) {
	if v.Kind() != reflect.String {
		return "", false
	}
	return v.String(), true
}

// sqlParamInt extracts an int64 from any integer kind. It rejects unsigned
// values that do not fit in an int64.
func sqlParamInt(v reflect.Value) (int64, bool) {
	switch v.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int(), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		u := v.Uint()
		if u > 1<<63-1 {
			return 0, false
		}
		return int64(u), true
	}
	return 0, false
}

// sqlParamFloat extracts a float64 from any float or integer kind.
func sqlParamFloat(v reflect.Value) (float64, bool) {
	switch v.Kind() {
	case reflect.Float32, reflect.Float64:
		return v.Float(), true
	}
	if n, ok := sqlParamInt(v); ok {
		return float64(n), true
	}
	return 0, false
}