Files
vctp2/db/db.go
Nathan Coad ea1eeb5c21
Some checks failed
continuous-integration/drone Build is passing
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / End-to-End (push) Has been cancelled
CI / Publish Docker (push) Has been cancelled
update to support postgresql and add godocs
2026-01-13 17:05:14 +11:00

175 lines
4.4 KiB
Go

package db
import (
	"database/sql"
	"embed"
	"errors"
	"fmt"
	"log/slog"
	"reflect"
	"strings"

	"github.com/jmoiron/sqlx"
	"github.com/pressly/goose/v3"
)
// migrations holds the embedded goose migration files: the "migrations"
// directory is applied for SQLite and "migrations_postgres" for PostgreSQL
// (see Migrate).
//go:embed migrations migrations_postgres
var migrations embed.FS
type Database interface {
DB() *sqlx.DB
Queries() Querier
Logger() *slog.Logger
Close() error
}
// Config describes which database New should open and how to reach it.
type Config struct {
	// Driver names the backend. It is normalized (trimmed, lower-cased) by
	// New: "sqlite", "sqlite3" or "" select SQLite, while "postgres" or
	// "postgresql" select PostgreSQL.
	Driver string

	// DSN is the data source name passed to the selected backend's
	// constructor.
	DSN string
}
// New opens the database described by cfg and verifies it is reachable.
//
// cfg.Driver is normalized with normalizeDriver: "sqlite", "sqlite3" and ""
// open SQLite via newLocalDB, while "postgres" and "postgresql" open
// PostgreSQL via newPostgresDB. Any other driver returns an
// "unsupported database driver" error.
//
// After opening, the connection is pinged. If the ping fails, the freshly
// opened database is closed before returning so the handle is not leaked; the
// returned error wraps the ping error (and any error from Close).
func New(logger *slog.Logger, cfg Config) (Database, error) {
	switch normalizeDriver(cfg.Driver) {
	case "sqlite":
		db, err := newLocalDB(logger, cfg.DSN)
		if err != nil {
			return nil, err
		}
		if err := db.db.Ping(); err != nil {
			// The handle is already open; release it rather than leak it.
			return nil, errors.Join(fmt.Errorf("pinging sqlite database: %w", err), db.Close())
		}
		return db, nil
	case "postgres":
		db, err := newPostgresDB(logger, cfg.DSN)
		if err != nil {
			return nil, err
		}
		if err := db.db.Ping(); err != nil {
			// The handle is already open; release it rather than leak it.
			return nil, errors.Join(fmt.Errorf("pinging postgres database: %w", err), db.Close())
		}
		return db, nil
	default:
		return nil, fmt.Errorf("unsupported database driver: %s", cfg.Driver)
	}
}
// Migrate applies all pending embedded goose migrations to db.
//
// driver is normalized with normalizeDriver and selects both the goose
// dialect and the embedded migration directory:
//
//   - "sqlite" (also "sqlite3" or ""): dialect "sqlite3", directory "migrations"
//   - "postgres" (also "postgresql"): dialect "postgres", directory "migrations_postgres"
//
// Any other driver returns an "unsupported database driver" error without
// touching goose's configuration.
//
// goose.SetBaseFS and goose.SetDialect configure package-level state inside
// goose, so Migrate must not run concurrently with itself or with other goose
// users in the same process.
func Migrate(db Database, driver string) error {
	driver = normalizeDriver(driver)

	// Resolve the dialect and directory first so an unsupported driver is
	// rejected before any goose global state is modified.
	var dialect, dir string
	switch driver {
	case "sqlite":
		dialect, dir = "sqlite3", "migrations"
	case "postgres":
		dialect, dir = "postgres", "migrations_postgres"
	default:
		return fmt.Errorf("unsupported database driver: %s", driver)
	}

	goose.SetBaseFS(migrations)
	if err := goose.SetDialect(dialect); err != nil {
		return fmt.Errorf("failed to set %s dialect: %w", driver, err)
	}
	if err := goose.Up(db.DB().DB, dir); err != nil {
		return fmt.Errorf("failed to run %s migrations: %w", driver, err)
	}
	return nil
}
// normalizeDriver canonicalizes a driver name so New and Migrate can compare
// it against a fixed set. Surrounding whitespace is trimmed and the name is
// lower-cased; the empty string and "sqlite3" become "sqlite", "postgresql"
// becomes "postgres", and every other name is returned in its cleaned form.
func normalizeDriver(driver string) string {
	name := strings.TrimSpace(driver)
	name = strings.ToLower(name)

	if name == "" || name == "sqlite3" {
		return "sqlite"
	}
	if name == "postgresql" {
		return "postgres"
	}
	return name
}
// ConvertToSQLParams copies fields from input into output by matching field
// names, converting between plain Go values and the database/sql Null*
// wrappers used by sqlc-generated parameter structs.
//
// input and output must both be pointers to structs; reflect panics
// otherwise. Output fields that are unexported, or that have no same-named
// field in input, are left untouched. Pointer-typed input fields are
// dereferenced once. Conversion rules by output field type:
//
//   - sql.NullString: a nil pointer input yields NULL; any other input
//     (string or non-nil *string) is stored as Valid.
//   - sql.NullInt64 / sql.NullFloat64: a nil pointer input yields NULL; a
//     non-nil pointer is stored as Valid even when it points at zero; a
//     non-pointer zero value is treated as NULL.
//   - string / int64 / float64: a nil pointer input yields the zero value;
//     otherwise the (dereferenced) value is copied.
//
// Output fields of any other type are ignored. Input values must be of a
// compatible kind (string kinds for string outputs, signed-integer kinds for
// int64 outputs, float kinds for float64 outputs); reflect panics on
// mismatched numeric kinds.
func ConvertToSQLParams(input any, output any) {
	inputVal := reflect.ValueOf(input).Elem()
	outputVal := reflect.ValueOf(output).Elem()
	outputType := outputVal.Type()

	for i := 0; i < outputVal.NumField(); i++ {
		outputField := outputVal.Field(i)
		inputField := inputVal.FieldByName(outputType.Field(i).Name)
		if !inputField.IsValid() || !outputField.CanSet() {
			continue
		}

		// Follow pointer inputs one level: reflect.Value.String on a pointer
		// returns a "<*T Value>" placeholder, and Int/Float panic on pointers.
		isPtr := inputField.Kind() == reflect.Pointer
		isNil := isPtr && inputField.IsNil()
		if isPtr && !isNil {
			inputField = inputField.Elem()
		}

		switch outputField.Type() {
		case reflect.TypeOf(sql.NullString{}):
			if isNil {
				outputField.Set(reflect.ValueOf(sql.NullString{Valid: false}))
			} else {
				outputField.Set(reflect.ValueOf(sql.NullString{String: inputField.String(), Valid: true}))
			}
		case reflect.TypeOf(sql.NullInt64{}):
			// A non-pointer zero means "unset" (NULL); a non-nil pointer
			// explicitly marks the value as present, even when it is zero.
			if isNil || (!isPtr && inputField.Int() == 0) {
				outputField.Set(reflect.ValueOf(sql.NullInt64{Valid: false}))
			} else {
				outputField.Set(reflect.ValueOf(sql.NullInt64{Int64: inputField.Int(), Valid: true}))
			}
		case reflect.TypeOf(sql.NullFloat64{}):
			// Same presence rule as sql.NullInt64 above.
			if isNil || (!isPtr && inputField.Float() == 0) {
				outputField.Set(reflect.ValueOf(sql.NullFloat64{Valid: false}))
			} else {
				outputField.Set(reflect.ValueOf(sql.NullFloat64{Float64: inputField.Float(), Valid: true}))
			}
		case reflect.TypeOf(""):
			if isNil {
				outputField.SetString("")
			} else {
				outputField.SetString(inputField.String())
			}
		case reflect.TypeOf(int64(0)):
			if isNil {
				outputField.SetInt(0)
			} else {
				outputField.SetInt(inputField.Int())
			}
		case reflect.TypeOf(float64(0)):
			if isNil {
				outputField.SetFloat(0)
			} else {
				outputField.SetFloat(inputField.Float())
			}
		}
	}
}