All checks were successful
continuous-integration/drone/push Build is passing
170 lines
4.7 KiB
Go
170 lines
4.7 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"vctp/db/queries"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
// SnapshotTotals holds the aggregate figures computed over one snapshot table
// or a union of snapshot tables: the number of distinct VMs plus the summed
// vCPU, RAM and provisioned-disk values. The db tags name the column aliases
// produced by the totals queries so sqlx can scan a result row directly.
type SnapshotTotals struct {
	VmCount   int64   `db:"vm_count"`
	VcpuTotal int64   `db:"vcpu_total"`
	RamTotal  int64   `db:"ram_total"`
	DiskTotal float64 `db:"disk_total"`
}
|
|
|
|
// ValidateTableName reports whether name may be spliced directly into SQL text
// as a table identifier. A valid name is non-empty and consists solely of
// lowercase ASCII letters, ASCII digits and underscores; anything else
// (uppercase, punctuation, whitespace, quotes, non-ASCII runes) is rejected.
func ValidateTableName(name string) error {
	if len(name) == 0 {
		return fmt.Errorf("table name is empty")
	}
	for _, ch := range name {
		switch {
		case ch >= 'a' && ch <= 'z':
		case ch >= '0' && ch <= '9':
		case ch == '_':
		default:
			return fmt.Errorf("invalid table name: %s", name)
		}
	}
	return nil
}
|
|
|
|
// SafeTableName returns the name if it passes validation.
|
|
func SafeTableName(name string) (string, error) {
|
|
if err := ValidateTableName(name); err != nil {
|
|
return "", err
|
|
}
|
|
return name, nil
|
|
}
|
|
|
|
// TableHasRows returns true when a table contains at least one row.
|
|
func TableHasRows(ctx context.Context, dbConn *sqlx.DB, table string) (bool, error) {
|
|
if err := ValidateTableName(table); err != nil {
|
|
return false, err
|
|
}
|
|
query := fmt.Sprintf(`SELECT 1 FROM %s LIMIT 1`, table)
|
|
var exists int
|
|
if err := dbConn.GetContext(ctx, &exists, query); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
// TableExists checks if a table exists in the current schema.
|
|
func TableExists(ctx context.Context, dbConn *sqlx.DB, table string) bool {
|
|
driver := strings.ToLower(dbConn.DriverName())
|
|
switch driver {
|
|
case "sqlite":
|
|
q := queries.New(dbConn)
|
|
count, err := q.SqliteTableExists(ctx, sql.NullString{String: table, Valid: table != ""})
|
|
return err == nil && count > 0
|
|
case "pgx", "postgres":
|
|
var count int
|
|
err := dbConn.GetContext(ctx, &count, `
|
|
SELECT COUNT(1)
|
|
FROM pg_catalog.pg_tables
|
|
WHERE schemaname = 'public' AND tablename = $1
|
|
`, table)
|
|
return err == nil && count > 0
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// ColumnExists checks if a column exists in a table.
|
|
func ColumnExists(ctx context.Context, dbConn *sqlx.DB, tableName string, columnName string) (bool, error) {
|
|
driver := strings.ToLower(dbConn.DriverName())
|
|
switch driver {
|
|
case "sqlite":
|
|
if _, err := SafeTableName(tableName); err != nil {
|
|
return false, err
|
|
}
|
|
query := fmt.Sprintf(`PRAGMA table_info("%s")`, tableName)
|
|
rows, err := dbConn.QueryxContext(ctx, query)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var (
|
|
cid int
|
|
name string
|
|
colType string
|
|
notNull int
|
|
defaultVal sql.NullString
|
|
pk int
|
|
)
|
|
if err := rows.Scan(&cid, &name, &colType, ¬Null, &defaultVal, &pk); err != nil {
|
|
return false, err
|
|
}
|
|
if strings.EqualFold(name, columnName) {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, rows.Err()
|
|
case "pgx", "postgres":
|
|
var count int
|
|
err := dbConn.GetContext(ctx, &count, `
|
|
SELECT COUNT(1)
|
|
FROM information_schema.columns
|
|
WHERE table_name = $1 AND column_name = $2
|
|
`, tableName, strings.ToLower(columnName))
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
default:
|
|
return false, fmt.Errorf("unsupported driver for column lookup: %s", driver)
|
|
}
|
|
}
|
|
|
|
// SnapshotTotalsForTable returns totals for a snapshot table.
|
|
func SnapshotTotalsForTable(ctx context.Context, dbConn *sqlx.DB, table string) (SnapshotTotals, error) {
|
|
if _, err := SafeTableName(table); err != nil {
|
|
return SnapshotTotals{}, err
|
|
}
|
|
query := fmt.Sprintf(`
|
|
SELECT
|
|
COUNT(DISTINCT "VmId") AS vm_count,
|
|
COALESCE(SUM(CASE WHEN "VcpuCount" IS NOT NULL THEN "VcpuCount" ELSE 0 END), 0) AS vcpu_total,
|
|
COALESCE(SUM(CASE WHEN "RamGB" IS NOT NULL THEN "RamGB" ELSE 0 END), 0) AS ram_total,
|
|
COALESCE(SUM(CASE WHEN "ProvisionedDisk" IS NOT NULL THEN "ProvisionedDisk" ELSE 0 END), 0) AS disk_total
|
|
FROM %s
|
|
WHERE "IsPresent" = 'TRUE'
|
|
`, table)
|
|
|
|
var totals SnapshotTotals
|
|
if err := dbConn.GetContext(ctx, &totals, query); err != nil {
|
|
return SnapshotTotals{}, err
|
|
}
|
|
return totals, nil
|
|
}
|
|
|
|
// SnapshotTotalsForUnion returns totals for a union query of snapshots.
|
|
func SnapshotTotalsForUnion(ctx context.Context, dbConn *sqlx.DB, unionQuery string) (SnapshotTotals, error) {
|
|
query := fmt.Sprintf(`
|
|
SELECT
|
|
COUNT(DISTINCT "VmId") AS vm_count,
|
|
COALESCE(SUM(CASE WHEN "VcpuCount" IS NOT NULL THEN "VcpuCount" ELSE 0 END), 0) AS vcpu_total,
|
|
COALESCE(SUM(CASE WHEN "RamGB" IS NOT NULL THEN "RamGB" ELSE 0 END), 0) AS ram_total,
|
|
COALESCE(SUM(CASE WHEN "ProvisionedDisk" IS NOT NULL THEN "ProvisionedDisk" ELSE 0 END), 0) AS disk_total
|
|
FROM (
|
|
%s
|
|
) snapshots
|
|
WHERE "IsPresent" = 'TRUE'
|
|
`, unionQuery)
|
|
|
|
var totals SnapshotTotals
|
|
if err := dbConn.GetContext(ctx, &totals, query); err != nil {
|
|
return SnapshotTotals{}, err
|
|
}
|
|
return totals, nil
|
|
}
|