Refactor code to use 'any' type and improve context handling
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2026-02-18 16:16:27 +11:00
parent 6517a30fa2
commit f2d6b3158b
36 changed files with 197 additions and 175 deletions

View File

@@ -168,7 +168,7 @@ func looksLikePostgresDSN(dsn string) bool {
}
// ConvertToSQLParams is a utility function that generically converts a struct to a corresponding sqlc-generated struct
func ConvertToSQLParams(input interface{}, output interface{}) {
func ConvertToSQLParams(input any, output any) {
inputVal := reflect.ValueOf(input).Elem()
outputVal := reflect.ValueOf(output).Elem()
@@ -182,15 +182,15 @@ func ConvertToSQLParams(input interface{}, output interface{}) {
// Handle fields of type sql.NullString, sql.NullInt64, and normal string/int64 fields
switch outputField.Type() {
case reflect.TypeOf(sql.NullString{}):
case reflect.TypeFor[sql.NullString]():
// Handle sql.NullString
if inputField.Kind() == reflect.Ptr && inputField.IsNil() {
if inputField.Kind() == reflect.Pointer && inputField.IsNil() {
outputField.Set(reflect.ValueOf(sql.NullString{Valid: false}))
} else {
outputField.Set(reflect.ValueOf(sql.NullString{String: inputField.String(), Valid: true}))
}
case reflect.TypeOf(sql.NullInt64{}):
case reflect.TypeFor[sql.NullInt64]():
// Handle sql.NullInt64
if inputField.Int() == 0 {
outputField.Set(reflect.ValueOf(sql.NullInt64{Valid: false}))
@@ -198,7 +198,7 @@ func ConvertToSQLParams(input interface{}, output interface{}) {
outputField.Set(reflect.ValueOf(sql.NullInt64{Int64: inputField.Int(), Valid: true}))
}
case reflect.TypeOf(sql.NullFloat64{}):
case reflect.TypeFor[sql.NullFloat64]():
// Handle sql.NullFloat64
if inputField.Float() == 0 {
outputField.Set(reflect.ValueOf(sql.NullFloat64{Valid: false}))
@@ -206,19 +206,19 @@ func ConvertToSQLParams(input interface{}, output interface{}) {
outputField.Set(reflect.ValueOf(sql.NullFloat64{Float64: inputField.Float(), Valid: true}))
}
case reflect.TypeOf(""):
case reflect.TypeFor[string]():
// Handle normal string fields
if inputField.Kind() == reflect.Ptr && inputField.IsNil() {
if inputField.Kind() == reflect.Pointer && inputField.IsNil() {
outputField.SetString("") // Set to empty string if input is nil
} else {
outputField.SetString(inputField.String())
}
case reflect.TypeOf(int64(0)):
case reflect.TypeFor[int64]():
// Handle normal int64 fields
outputField.SetInt(inputField.Int())
case reflect.TypeOf(float64(0)):
case reflect.TypeFor[float64]():
// Handle normal float64 fields
outputField.SetFloat(inputField.Float())

View File

@@ -45,8 +45,32 @@ type ensureOnceState struct {
done bool
}
type loggerContextKey struct{}
var ensureOnceRegistry sync.Map
// WithLoggerContext stores a logger in context for downstream DB helper logging.
func WithLoggerContext(ctx context.Context, logger *slog.Logger) context.Context {
if ctx == nil {
ctx = context.Background()
}
if logger == nil {
return ctx
}
return context.WithValue(ctx, loggerContextKey{}, logger)
}
// LoggerFromContext returns a logger previously stored via WithLoggerContext.
func LoggerFromContext(ctx context.Context) *slog.Logger {
if ctx == nil {
return nil
}
if logger, ok := ctx.Value(loggerContextKey{}).(*slog.Logger); ok && logger != nil {
return logger
}
return nil
}
// ensureOncePerDB runs fn once per DB connection for a given logical key.
// The function is considered complete only when fn returns nil.
func ensureOncePerDB(dbConn *sqlx.DB, name string, fn func() error) error {
@@ -98,7 +122,7 @@ func EnsureColumns(ctx context.Context, dbConn *sqlx.DB, tableName string, colum
return nil
}
func execLog(ctx context.Context, dbConn *sqlx.DB, query string, args ...interface{}) (sql.Result, error) {
func execLog(ctx context.Context, dbConn *sqlx.DB, query string, args ...any) (sql.Result, error) {
res, err := dbConn.ExecContext(ctx, query, args...)
if err != nil {
q := strings.TrimSpace(query)
@@ -120,7 +144,7 @@ func execLog(ctx context.Context, dbConn *sqlx.DB, query string, args ...interfa
return res, err
}
func getLog(ctx context.Context, dbConn *sqlx.DB, dest interface{}, query string, args ...interface{}) error {
func getLog(ctx context.Context, dbConn *sqlx.DB, dest any, query string, args ...any) error {
err := dbConn.GetContext(ctx, dest, query, args...)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
@@ -137,7 +161,7 @@ func getLog(ctx context.Context, dbConn *sqlx.DB, dest interface{}, query string
return err
}
func selectLog(ctx context.Context, dbConn *sqlx.DB, dest interface{}, query string, args ...interface{}) error {
func selectLog(ctx context.Context, dbConn *sqlx.DB, dest any, query string, args ...any) error {
err := dbConn.SelectContext(ctx, dest, query, args...)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
@@ -602,7 +626,7 @@ func ApplySQLiteTuning(ctx context.Context, dbConn *sqlx.DB) {
pragmaCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
_, err := execLog(pragmaCtx, dbConn, pragma)
cancel()
if logger, ok := ctx.Value("logger").(*slog.Logger); ok && logger != nil {
if logger := LoggerFromContext(ctx); logger != nil {
logger.Debug("Applied SQLite tuning pragma", "pragma", pragma, "error", err)
}
}
@@ -783,7 +807,7 @@ ON CONFLICT ("Vcenter","VmId","VmUuid") DO UPDATE SET
"DeletedAt"=NULL
`
query = sqlx.Rebind(bindType, query)
args := []interface{}{vcenter, vmID, vmUUID, name, cluster, firstSeen, seen.Unix()}
args := []any{vcenter, vmID, vmUUID, name, cluster, firstSeen, seen.Unix()}
_, err := dbConn.ExecContext(ctx, query, args...)
if err != nil {
slog.Warn("lifecycle upsert exec failed", "vcenter", vcenter, "vm_id", vmID, "vm_uuid", vmUUID, "driver", driver, "args_len", len(args), "args", fmt.Sprint(args), "query", strings.TrimSpace(query), "error", err)
@@ -814,7 +838,7 @@ ON CONFLICT ("Vcenter","VmId","VmUuid") DO UPDATE SET
"Cluster"=COALESCE(NULLIF(vm_lifecycle_cache."Cluster", ''), EXCLUDED."Cluster")
`
query = sqlx.Rebind(bindType, query)
args := []interface{}{vcenter, vmID, vmUUID, name, cluster, deletedAt, deletedAt, deletedAt}
args := []any{vcenter, vmID, vmUUID, name, cluster, deletedAt, deletedAt, deletedAt}
_, err := dbConn.ExecContext(ctx, query, args...)
if err != nil {
slog.Warn("lifecycle delete exec failed", "vcenter", vcenter, "vm_id", vmID, "vm_uuid", vmUUID, "driver", driver, "args_len", len(args), "args", fmt.Sprint(args), "query", strings.TrimSpace(query), "error", err)
@@ -845,7 +869,7 @@ ON CONFLICT ("Vcenter","VmId","VmUuid") DO UPDATE SET
"Cluster"=COALESCE(NULLIF(vm_lifecycle_cache."Cluster", ''), EXCLUDED."Cluster")
`
query = sqlx.Rebind(bindType, query)
args := []interface{}{vcenter, vmID, vmUUID, name, cluster, deletedAt, deletedAt, deletedAt}
args := []any{vcenter, vmID, vmUUID, name, cluster, deletedAt, deletedAt, deletedAt}
_, err := dbConn.ExecContext(ctx, query, args...)
if err != nil {
slog.Warn("lifecycle delete event exec failed", "vcenter", vcenter, "vm_id", vmID, "vm_uuid", vmUUID, "driver", driver, "args_len", len(args), "args", fmt.Sprint(args), "query", strings.TrimSpace(query), "error", err)
@@ -988,7 +1012,7 @@ ON CONFLICT ("Date","Vcenter","VmId","VmUuid") DO UPDATE SET
"PoweredOn"=$25,
"SrmPlaceholder"=$26
`
args := []interface{}{
args := []any{
day, v.Vcenter, v.VmId, v.VmUuid, v.Name, v.CreationTime, v.DeletionTime, v.SamplesPresent, v.TotalSamples,
v.SumVcpu, v.SumRam, v.SumDisk, v.TinHits, v.BronzeHits, v.SilverHits, v.GoldHits,
v.LastResourcePool, v.LastDatacenter, v.LastCluster, v.LastFolder, v.LastProvisionedDisk, v.LastVcpuCount, v.LastRamGB, v.IsTemplate, v.PoweredOn, v.SrmPlaceholder,
@@ -1469,7 +1493,7 @@ WHERE "Vcenter" = $4 AND "VmId" = $5 AND "VmUuid" = $6
return err
}
func nullString(val sql.NullString) interface{} {
func nullString(val sql.NullString) any {
if val.Valid {
return val.String
}
@@ -2088,17 +2112,17 @@ type VmLifecycleDiagnostics struct {
FinalLifecycle VmLifecycle
}
func vmLookupPredicate(vmID, vmUUID, name string) (string, []interface{}, bool) {
func vmLookupPredicate(vmID, vmUUID, name string) (string, []any, bool) {
vmID = strings.TrimSpace(vmID)
vmUUID = strings.TrimSpace(vmUUID)
name = strings.TrimSpace(name)
switch {
case vmID != "":
return `"VmId" = ?`, []interface{}{vmID}, true
return `"VmId" = ?`, []any{vmID}, true
case vmUUID != "":
return `"VmUuid" = ?`, []interface{}{vmUUID}, true
return `"VmUuid" = ?`, []any{vmUUID}, true
case name != "":
return `lower("Name") = ?`, []interface{}{strings.ToLower(name)}, true
return `lower("Name") = ?`, []any{strings.ToLower(name)}, true
default:
return "", nil, false
}
@@ -2960,7 +2984,7 @@ SET "AvgIsPresent" = CASE
END
`, summaryTable, endExpr, startExpr, endExpr, startExpr)
query = dbConn.Rebind(query)
args := []interface{}{
args := []any{
windowEnd, windowEnd,
windowStart, windowStart,
windowEnd, windowEnd,
@@ -3530,7 +3554,7 @@ FROM snapshot_runs
WHERE "Success" = 'FALSE' AND "Attempts" < ?
ORDER BY "LastAttempt" ASC
`
args := []interface{}{maxAttempts}
args := []any{maxAttempts}
if driver == "pgx" || driver == "postgres" {
query = `
SELECT "Vcenter","SnapshotTime","Attempts"

View File

@@ -131,7 +131,7 @@ INSERT INTO vm_hourly_stats (
"Datacenter","Cluster","Folder","ProvisionedDisk","VcpuCount","RamGB","IsTemplate","PoweredOn","SrmPlaceholder"
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
`
rows := [][]interface{}{
rows := [][]any{
{int64(1000), "vc-a", "vm-1", "uuid-1", "demo-vm", int64(900), int64(0), "Tin", "dc", "cluster", "folder", 100.0, int64(2), int64(4), "FALSE", "TRUE", "FALSE"},
{int64(2000), "vc-a", "vm-1", "uuid-1", "demo-vm", int64(900), int64(0), "Gold", "dc", "cluster", "folder", 150.0, int64(4), int64(8), "FALSE", "TRUE", "FALSE"},
}
@@ -371,7 +371,7 @@ INSERT INTO vm_hourly_stats (
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
`
// First row carries an old deletion marker, later row proves VM is still present.
rows := [][]interface{}{
rows := [][]any{
{int64(1700000000), "vc-a", "vm-1", "uuid-1", "demo-vm", int64(1699999000), int64(1700003600), "Tin", "dc", "cluster", "folder", 100.0, int64(2), int64(4), "FALSE", "TRUE", "FALSE"},
{int64(1700100000), "vc-a", "vm-1", "uuid-1", "demo-vm", int64(1699999000), int64(0), "Gold", "dc", "cluster", "folder", 120.0, int64(4), int64(8), "FALSE", "TRUE", "FALSE"},
}
@@ -563,7 +563,7 @@ CREATE TABLE %s (
INSERT INTO %s ("Vcenter","Name","VmId","VmUuid","AvgVcpuCount","AvgRamGB")
VALUES (?,?,?,?,?,?)
`, summaryTable)
rows := [][]interface{}{
rows := [][]any{
{"vc-a", "vm-1", "1", "u1", 2.0, 4.0},
{"vc-a", "vm-2", "2", "u2", 3.0, 5.0},
{"vc-b", "vm-3", "3", "u3", 1.0, 2.0},
@@ -633,7 +633,7 @@ CREATE TABLE %s (
INSERT INTO %s ("Vcenter","Name","VmId","VmUuid","AvgVcpuCount","AvgRamGB")
VALUES (?,?,?,?,?,?)
`, summaryTable)
for _, args := range [][]interface{}{
for _, args := range [][]any{
{"vc-a", "vm-1", "1", "u1", 4.0, 8.0},
{"vc-a", "vm-2", "2", "u2", 2.0, 6.0},
} {
@@ -697,7 +697,7 @@ CREATE TABLE %s (
}
insert1 := fmt.Sprintf(`INSERT INTO %s ("Vcenter","Name","VmId","VmUuid","AvgVcpuCount","AvgRamGB") VALUES (?,?,?,?,?,?)`, table1)
insert2 := fmt.Sprintf(`INSERT INTO %s ("Vcenter","Name","VmId","VmUuid","AvgVcpuCount","AvgRamGB") VALUES (?,?,?,?,?,?)`, table2)
for _, args := range [][]interface{}{
for _, args := range [][]any{
{"vc-a", "vm-1", "1", "u1", 2.0, 4.0},
{"vc-b", "vm-2", "2", "u2", 3.0, 5.0},
} {

View File

@@ -55,7 +55,7 @@ type rebindDBTX struct {
db *sqlx.DB
}
func (r rebindDBTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
func (r rebindDBTX) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
return r.db.ExecContext(ctx, rebindQuery(query), args...)
}
@@ -63,11 +63,11 @@ func (r rebindDBTX) PrepareContext(ctx context.Context, query string) (*sql.Stmt
return r.db.PrepareContext(ctx, rebindQuery(query))
}
func (r rebindDBTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
func (r rebindDBTX) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return r.db.QueryContext(ctx, rebindQuery(query), args...)
}
func (r rebindDBTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
func (r rebindDBTX) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
return r.db.QueryRowContext(ctx, rebindQuery(query), args...)
}

View File

@@ -237,15 +237,15 @@ func copySQLiteTableIntoPostgres(ctx context.Context, source *sqlx.DB, destinati
var rowsCopied int64
for rows.Next() {
rawValues := make([]interface{}, len(columns))
scanTargets := make([]interface{}, len(columns))
rawValues := make([]any, len(columns))
scanTargets := make([]any, len(columns))
for i := range rawValues {
scanTargets[i] = &rawValues[i]
}
if err := rows.Scan(scanTargets...); err != nil {
return rowsCopied, fmt.Errorf("failed to scan row from sqlite table %q: %w", tableName, err)
}
args := make([]interface{}, len(columns))
args := make([]any, len(columns))
for i, col := range columns {
args[i] = coerceSQLiteValueForPostgres(rawValues[i], col.DestinationType)
}
@@ -366,7 +366,7 @@ ORDER BY ordinal_position
return nil
}
func coerceSQLiteValueForPostgres(value interface{}, destinationType string) interface{} {
func coerceSQLiteValueForPostgres(value any, destinationType string) any {
if value == nil {
return nil
}
@@ -385,7 +385,7 @@ func coerceSQLiteValueForPostgres(value interface{}, destinationType string) int
return value
}
func coerceBoolValue(value interface{}) (bool, bool) {
func coerceBoolValue(value any) (bool, bool) {
switch v := value.(type) {
case bool:
return v, true

View File

@@ -51,9 +51,9 @@ func TestIntersectImportColumns(t *testing.T) {
func TestCoerceSQLiteValueForPostgresBoolean(t *testing.T) {
tests := []struct {
name string
input interface{}
input any
destinationType string
want interface{}
want any
}{
{name: "string true", input: "true", destinationType: "boolean", want: true},
{name: "string false", input: "0", destinationType: "boolean", want: false},