All checks were successful
continuous-integration/drone/push Build is passing
228 lines
6.1 KiB
Go
228 lines
6.1 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"log/slog"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/pressly/goose/v3"
|
|
)
|
|
|
|
// migrations embeds the goose migration directories used by Migrate:
// "migrations" holds the SQLite migrations and "migrations_postgres" holds
// the PostgreSQL ones.
//
//go:embed migrations migrations_postgres
|
|
var migrations embed.FS
|
|
|
|
type Database interface {
|
|
DB() *sqlx.DB
|
|
Queries() Querier
|
|
Logger() *slog.Logger
|
|
Close() error
|
|
}
|
|
|
|
// Config describes which database backend New should open and how to reach it.
type Config struct {
	// Driver selects the backend and DSN is its connection string. New
	// normalizes Driver case-insensitively: "", "sqlite" and "sqlite3" mean
	// SQLite, while "postgres" and "postgresql" mean PostgreSQL.
	Driver, DSN string

	// EnableExperimentalPostgres must be true for New to accept the postgres
	// driver; otherwise New returns an error.
	EnableExperimentalPostgres bool
}
|
|
|
|
// New opens the database described by cfg and verifies connectivity with a
// ping before returning it.
//
// cfg.Driver is normalized with normalizeDriver, so "", "sqlite" and
// "sqlite3" select SQLite while "postgres" and "postgresql" select
// PostgreSQL. PostgreSQL is rejected unless cfg.EnableExperimentalPostgres is
// set, and any other driver name yields an "unsupported database driver"
// error.
//
// If the ping fails, the freshly opened database is closed before the error
// is returned, so a failed New does not leak the handle.
func New(logger *slog.Logger, cfg Config) (Database, error) {
	switch normalizeDriver(cfg.Driver) {
	case "sqlite":
		db, err := newLocalDB(logger, cfg.DSN)
		if err != nil {
			return nil, err
		}
		if err = db.db.Ping(); err != nil {
			// Best effort: the ping failure is the error worth reporting,
			// so a secondary close error is intentionally dropped.
			_ = db.Close()
			return nil, fmt.Errorf("failed to ping sqlite database: %w", err)
		}
		return db, nil

	case "postgres":
		// The sqlc query set is SQLite-first. Keep Postgres opt-in until full parity is validated.
		if !cfg.EnableExperimentalPostgres {
			return nil, fmt.Errorf("postgres driver is disabled by default; set settings.enable_experimental_postgres=true to enable experimental mode")
		}
		db, err := newPostgresDB(logger, cfg.DSN)
		if err != nil {
			return nil, err
		}
		if err = db.db.Ping(); err != nil {
			// Best effort: the ping failure is the error worth reporting,
			// so a secondary close error is intentionally dropped.
			_ = db.Close()
			return nil, fmt.Errorf("failed to ping postgres database: %w", err)
		}
		return db, nil

	default:
		return nil, fmt.Errorf("unsupported database driver: %s", cfg.Driver)
	}
}
|
|
|
|
// Migrate runs the migrations on the database.
|
|
func Migrate(db Database, driver string) error {
|
|
driver = normalizeDriver(driver)
|
|
|
|
goose.SetBaseFS(migrations)
|
|
|
|
switch driver {
|
|
case "sqlite":
|
|
if err := goose.SetDialect("sqlite3"); err != nil {
|
|
return fmt.Errorf("failed to set sqlite dialect: %w", err)
|
|
}
|
|
if err := goose.Up(db.DB().DB, "migrations"); err != nil {
|
|
return fmt.Errorf("failed to run sqlite migrations: %w", err)
|
|
}
|
|
case "postgres":
|
|
if err := goose.SetDialect("postgres"); err != nil {
|
|
return fmt.Errorf("failed to set postgres dialect: %w", err)
|
|
}
|
|
if err := goose.Up(db.DB().DB, "migrations_postgres"); err != nil {
|
|
return fmt.Errorf("failed to run postgres migrations: %w", err)
|
|
}
|
|
default:
|
|
return fmt.Errorf("unsupported database driver: %s", driver)
|
|
}
|
|
|
|
// TODO - replace with goose
|
|
/*
|
|
driver, err := sqlite3.WithInstance(db.DB(), &sqlite3.Config{})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create database driver: %w", err)
|
|
}
|
|
|
|
iofsDriver, err := iofs.New(migrations, "migrations")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create iofs: %w", err)
|
|
}
|
|
defer iofsDriver.Close()
|
|
|
|
m, err := migrate.NewWithInstance("iofs", iofsDriver, "sqlite3", driver)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create migration: %w", err)
|
|
}
|
|
|
|
return m.Up()
|
|
*/
|
|
|
|
return nil
|
|
}
|
|
|
|
// normalizeDriver canonicalizes a configured driver name. Surrounding
// whitespace is trimmed and the name is lower-cased. "" and "sqlite3" both
// map to "sqlite", and "postgresql" maps to "postgres". Any other value is
// returned in its trimmed, lower-cased form so callers can reject it.
func normalizeDriver(driver string) string {
	name := strings.TrimSpace(driver)
	name = strings.ToLower(name)

	if name == "" || name == "sqlite3" {
		return "sqlite"
	}
	if name == "postgresql" {
		return "postgres"
	}
	return name
}
|
|
|
|
// ResolveDriver determines the effective database driver.
//
// When configuredDriver is blank, the driver is inferred from dsn. A DSN that
// looks like PostgreSQL (per looksLikePostgresDSN) yields "postgres" with
// inferredFromDSN set to true; anything else yields "sqlite".
//
// Otherwise configuredDriver is normalized with normalizeDriver, the same
// normalization New and Migrate apply, so "sqlite3" and "postgresql" are
// accepted aliases. An explicit sqlite driver paired with a PostgreSQL-looking
// DSN is reported as a configuration error. Other driver names are returned
// normalized but otherwise unchecked; New rejects unsupported ones.
func ResolveDriver(configuredDriver, dsn string) (driver string, inferredFromDSN bool, err error) {
	// Check for blank before normalizing: normalizeDriver maps "" to
	// "sqlite", which would hide the "infer from DSN" case.
	if strings.TrimSpace(configuredDriver) == "" {
		if looksLikePostgresDSN(dsn) {
			return "postgres", true, nil
		}
		return "sqlite", false, nil
	}

	normalized := normalizeDriver(configuredDriver)
	if normalized == "sqlite" && looksLikePostgresDSN(dsn) {
		return "", false, fmt.Errorf("database_driver is sqlite but database_url looks like a postgres DSN; set settings.database_driver=postgres")
	}

	return normalized, false, nil
}
|
|
|
|
func looksLikePostgresDSN(dsn string) bool {
|
|
trimmed := strings.ToLower(strings.TrimSpace(dsn))
|
|
if trimmed == "" {
|
|
return false
|
|
}
|
|
if strings.HasPrefix(trimmed, "postgres://") || strings.HasPrefix(trimmed, "postgresql://") {
|
|
return true
|
|
}
|
|
|
|
// Also support key=value style PostgreSQL DSNs.
|
|
if strings.Contains(trimmed, "=") {
|
|
hasHost := strings.Contains(trimmed, "host=")
|
|
hasUser := strings.Contains(trimmed, "user=")
|
|
hasDB := strings.Contains(trimmed, "dbname=")
|
|
hasSSL := strings.Contains(trimmed, "sslmode=")
|
|
if (hasHost && hasUser) || (hasHost && hasDB) || (hasUser && hasDB) || (hasHost && hasSSL) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// ConvertToSQLParams is a utility function that generically converts a struct to a corresponding sqlc-generated struct
|
|
func ConvertToSQLParams(input any, output any) {
|
|
inputVal := reflect.ValueOf(input).Elem()
|
|
outputVal := reflect.ValueOf(output).Elem()
|
|
|
|
for i := 0; i < outputVal.NumField(); i++ {
|
|
outputField := outputVal.Field(i)
|
|
inputField := inputVal.FieldByName(outputVal.Type().Field(i).Name)
|
|
|
|
if !inputField.IsValid() || !outputField.CanSet() {
|
|
continue
|
|
}
|
|
|
|
// Handle fields of type sql.NullString, sql.NullInt64, and normal string/int64 fields
|
|
switch outputField.Type() {
|
|
case reflect.TypeFor[sql.NullString]():
|
|
// Handle sql.NullString
|
|
if inputField.Kind() == reflect.Pointer && inputField.IsNil() {
|
|
outputField.Set(reflect.ValueOf(sql.NullString{Valid: false}))
|
|
} else {
|
|
outputField.Set(reflect.ValueOf(sql.NullString{String: inputField.String(), Valid: true}))
|
|
}
|
|
|
|
case reflect.TypeFor[sql.NullInt64]():
|
|
// Handle sql.NullInt64
|
|
if inputField.Int() == 0 {
|
|
outputField.Set(reflect.ValueOf(sql.NullInt64{Valid: false}))
|
|
} else {
|
|
outputField.Set(reflect.ValueOf(sql.NullInt64{Int64: inputField.Int(), Valid: true}))
|
|
}
|
|
|
|
case reflect.TypeFor[sql.NullFloat64]():
|
|
// Handle sql.NullFloat64
|
|
if inputField.Float() == 0 {
|
|
outputField.Set(reflect.ValueOf(sql.NullFloat64{Valid: false}))
|
|
} else {
|
|
outputField.Set(reflect.ValueOf(sql.NullFloat64{Float64: inputField.Float(), Valid: true}))
|
|
}
|
|
|
|
case reflect.TypeFor[string]():
|
|
// Handle normal string fields
|
|
if inputField.Kind() == reflect.Pointer && inputField.IsNil() {
|
|
outputField.SetString("") // Set to empty string if input is nil
|
|
} else {
|
|
outputField.SetString(inputField.String())
|
|
}
|
|
|
|
case reflect.TypeFor[int64]():
|
|
// Handle normal int64 fields
|
|
outputField.SetInt(inputField.Int())
|
|
|
|
case reflect.TypeFor[float64]():
|
|
// Handle normal float64 fields
|
|
outputField.SetFloat(inputField.Float())
|
|
|
|
}
|
|
}
|
|
}
|