package db import ( "context" "database/sql" "fmt" "log/slog" "regexp" "strings" "vctp/db/queries" _ "github.com/jackc/pgx/v5/stdlib" "github.com/jmoiron/sqlx" ) type PostgresDB struct { logger *slog.Logger db *sqlx.DB queries *queries.Queries } var _ Database = (*PostgresDB)(nil) func (d *PostgresDB) DB() *sqlx.DB { return d.db } func (d *PostgresDB) Queries() Querier { return d.queries } func (d *PostgresDB) Logger() *slog.Logger { return d.logger } func (d *PostgresDB) Close() error { return d.db.Close() } func newPostgresDB(logger *slog.Logger, dsn string) (*PostgresDB, error) { if strings.TrimSpace(dsn) == "" { return nil, fmt.Errorf("postgres DSN is required") } db, err := sqlx.Open("pgx", dsn) if err != nil { return nil, err } db.SetMaxOpenConns(10) rebindDB := rebindDBTX{db: db} return &PostgresDB{logger: logger, db: db, queries: queries.New(rebindDB)}, nil } type rebindDBTX struct { db *sqlx.DB } func (r rebindDBTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { return r.db.ExecContext(ctx, rebindQuery(query), args...) } func (r rebindDBTX) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { return r.db.PrepareContext(ctx, rebindQuery(query)) } func (r rebindDBTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { return r.db.QueryContext(ctx, rebindQuery(query), args...) } func (r rebindDBTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { return r.db.QueryRowContext(ctx, rebindQuery(query), args...) } var numberedPlaceholderRe = regexp.MustCompile(`\?\d+`) func rebindQuery(query string) string { unindexed := numberedPlaceholderRe.ReplaceAllString(query, "?") return sqlx.Rebind(sqlx.DOLLAR, unindexed) }