diff --git a/db/db.go b/db/db.go index 25ae36c..b0c76bf 100644 --- a/db/db.go +++ b/db/db.go @@ -168,7 +168,7 @@ func looksLikePostgresDSN(dsn string) bool { } // ConvertToSQLParams is a utility function that generically converts a struct to a corresponding sqlc-generated struct -func ConvertToSQLParams(input interface{}, output interface{}) { +func ConvertToSQLParams(input any, output any) { inputVal := reflect.ValueOf(input).Elem() outputVal := reflect.ValueOf(output).Elem() @@ -182,15 +182,15 @@ func ConvertToSQLParams(input interface{}, output interface{}) { // Handle fields of type sql.NullString, sql.NullInt64, and normal string/int64 fields switch outputField.Type() { - case reflect.TypeOf(sql.NullString{}): + case reflect.TypeFor[sql.NullString](): // Handle sql.NullString - if inputField.Kind() == reflect.Ptr && inputField.IsNil() { + if inputField.Kind() == reflect.Pointer && inputField.IsNil() { outputField.Set(reflect.ValueOf(sql.NullString{Valid: false})) } else { outputField.Set(reflect.ValueOf(sql.NullString{String: inputField.String(), Valid: true})) } - case reflect.TypeOf(sql.NullInt64{}): + case reflect.TypeFor[sql.NullInt64](): // Handle sql.NullInt64 if inputField.Int() == 0 { outputField.Set(reflect.ValueOf(sql.NullInt64{Valid: false})) @@ -198,7 +198,7 @@ func ConvertToSQLParams(input interface{}, output interface{}) { outputField.Set(reflect.ValueOf(sql.NullInt64{Int64: inputField.Int(), Valid: true})) } - case reflect.TypeOf(sql.NullFloat64{}): + case reflect.TypeFor[sql.NullFloat64](): // Handle sql.NullFloat64 if inputField.Float() == 0 { outputField.Set(reflect.ValueOf(sql.NullFloat64{Valid: false})) @@ -206,19 +206,19 @@ func ConvertToSQLParams(input interface{}, output interface{}) { outputField.Set(reflect.ValueOf(sql.NullFloat64{Float64: inputField.Float(), Valid: true})) } - case reflect.TypeOf(""): + case reflect.TypeFor[string](): // Handle normal string fields - if inputField.Kind() == reflect.Ptr && inputField.IsNil() { + if inputField.Kind() == reflect.Pointer && inputField.IsNil() { outputField.SetString("") // Set to empty string if input is nil } else { outputField.SetString(inputField.String()) } - case reflect.TypeOf(int64(0)): + case reflect.TypeFor[int64](): // Handle normal int64 fields outputField.SetInt(inputField.Int()) - case reflect.TypeOf(float64(0)): + case reflect.TypeFor[float64](): // Handle normal float64 fields outputField.SetFloat(inputField.Float()) diff --git a/db/helpers.go b/db/helpers.go index ceac960..bf00f58 100644 --- a/db/helpers.go +++ b/db/helpers.go @@ -45,8 +45,32 @@ type ensureOnceState struct { done bool } +type loggerContextKey struct{} + var ensureOnceRegistry sync.Map +// WithLoggerContext stores a logger in context for downstream DB helper logging. +func WithLoggerContext(ctx context.Context, logger *slog.Logger) context.Context { + if ctx == nil { + ctx = context.Background() + } + if logger == nil { + return ctx + } + return context.WithValue(ctx, loggerContextKey{}, logger) +} + +// LoggerFromContext returns a logger previously stored via WithLoggerContext. +func LoggerFromContext(ctx context.Context) *slog.Logger { + if ctx == nil { + return nil + } + if logger, ok := ctx.Value(loggerContextKey{}).(*slog.Logger); ok && logger != nil { + return logger + } + return nil +} + // ensureOncePerDB runs fn once per DB connection for a given logical key. // The function is considered complete only when fn returns nil. func ensureOncePerDB(dbConn *sqlx.DB, name string, fn func() error) error { @@ -98,7 +122,7 @@ func EnsureColumns(ctx context.Context, dbConn *sqlx.DB, tableName string, colum return nil } -func execLog(ctx context.Context, dbConn *sqlx.DB, query string, args ...interface{}) (sql.Result, error) { +func execLog(ctx context.Context, dbConn *sqlx.DB, query string, args ...any) (sql.Result, error) { res, err := dbConn.ExecContext(ctx, query, args...) if err != nil { q := strings.TrimSpace(query) @@ -120,7 +144,7 @@ func execLog(ctx context.Context, dbConn *sqlx.DB, query string, args ...interfa return res, err } -func getLog(ctx context.Context, dbConn *sqlx.DB, dest interface{}, query string, args ...interface{}) error { +func getLog(ctx context.Context, dbConn *sqlx.DB, dest any, query string, args ...any) error { err := dbConn.GetContext(ctx, dest, query, args...) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -137,7 +161,7 @@ func getLog(ctx context.Context, dbConn *sqlx.DB, dest interface{}, query string return err } -func selectLog(ctx context.Context, dbConn *sqlx.DB, dest interface{}, query string, args ...interface{}) error { +func selectLog(ctx context.Context, dbConn *sqlx.DB, dest any, query string, args ...any) error { err := dbConn.SelectContext(ctx, dest, query, args...) if err != nil { if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { @@ -602,7 +626,7 @@ func ApplySQLiteTuning(ctx context.Context, dbConn *sqlx.DB) { pragmaCtx, cancel := context.WithTimeout(ctx, 2*time.Second) _, err := execLog(pragmaCtx, dbConn, pragma) cancel() - if logger, ok := ctx.Value("logger").(*slog.Logger); ok && logger != nil { + if logger := LoggerFromContext(ctx); logger != nil { logger.Debug("Applied SQLite tuning pragma", "pragma", pragma, "error", err) } } @@ -783,7 +807,7 @@ ON CONFLICT ("Vcenter","VmId","VmUuid") DO UPDATE SET "DeletedAt"=NULL ` query = sqlx.Rebind(bindType, query) - args := []interface{}{vcenter, vmID, vmUUID, name, cluster, firstSeen, seen.Unix()} + args := []any{vcenter, vmID, vmUUID, name, cluster, firstSeen, seen.Unix()} _, err := dbConn.ExecContext(ctx, query, args...) if err != nil { slog.Warn("lifecycle upsert exec failed", "vcenter", vcenter, "vm_id", vmID, "vm_uuid", vmUUID, "driver", driver, "args_len", len(args), "args", fmt.Sprint(args), "query", strings.TrimSpace(query), "error", err) @@ -814,7 +838,7 @@ ON CONFLICT ("Vcenter","VmId","VmUuid") DO UPDATE SET "Cluster"=COALESCE(NULLIF(vm_lifecycle_cache."Cluster", ''), EXCLUDED."Cluster") ` query = sqlx.Rebind(bindType, query) - args := []interface{}{vcenter, vmID, vmUUID, name, cluster, deletedAt, deletedAt, deletedAt} + args := []any{vcenter, vmID, vmUUID, name, cluster, deletedAt, deletedAt, deletedAt} _, err := dbConn.ExecContext(ctx, query, args...) if err != nil { slog.Warn("lifecycle delete exec failed", "vcenter", vcenter, "vm_id", vmID, "vm_uuid", vmUUID, "driver", driver, "args_len", len(args), "args", fmt.Sprint(args), "query", strings.TrimSpace(query), "error", err) @@ -845,7 +869,7 @@ ON CONFLICT ("Vcenter","VmId","VmUuid") DO UPDATE SET "Cluster"=COALESCE(NULLIF(vm_lifecycle_cache."Cluster", ''), EXCLUDED."Cluster") ` query = sqlx.Rebind(bindType, query) - args := []interface{}{vcenter, vmID, vmUUID, name, cluster, deletedAt, deletedAt, deletedAt} + args := []any{vcenter, vmID, vmUUID, name, cluster, deletedAt, deletedAt, deletedAt} _, err := dbConn.ExecContext(ctx, query, args...) if err != nil { slog.Warn("lifecycle delete event exec failed", "vcenter", vcenter, "vm_id", vmID, "vm_uuid", vmUUID, "driver", driver, "args_len", len(args), "args", fmt.Sprint(args), "query", strings.TrimSpace(query), "error", err) @@ -988,7 +1012,7 @@ ON CONFLICT ("Date","Vcenter","VmId","VmUuid") DO UPDATE SET "PoweredOn"=$25, "SrmPlaceholder"=$26 ` - args := []interface{}{ + args := []any{ day, v.Vcenter, v.VmId, v.VmUuid, v.Name, v.CreationTime, v.DeletionTime, v.SamplesPresent, v.TotalSamples, v.SumVcpu, v.SumRam, v.SumDisk, v.TinHits, v.BronzeHits, v.SilverHits, v.GoldHits, v.LastResourcePool, v.LastDatacenter, v.LastCluster, v.LastFolder, v.LastProvisionedDisk, v.LastVcpuCount, v.LastRamGB, v.IsTemplate, v.PoweredOn, v.SrmPlaceholder, @@ -1469,7 +1493,7 @@ WHERE "Vcenter" = $4 AND "VmId" = $5 AND "VmUuid" = $6 return err } -func nullString(val sql.NullString) interface{} { +func nullString(val sql.NullString) any { if val.Valid { return val.String } @@ -2088,17 +2112,17 @@ type VmLifecycleDiagnostics struct { FinalLifecycle VmLifecycle } -func vmLookupPredicate(vmID, vmUUID, name string) (string, []interface{}, bool) { +func vmLookupPredicate(vmID, vmUUID, name string) (string, []any, bool) { vmID = strings.TrimSpace(vmID) vmUUID = strings.TrimSpace(vmUUID) name = strings.TrimSpace(name) switch { case vmID != "": - return `"VmId" = ?`, []interface{}{vmID}, true + return `"VmId" = ?`, []any{vmID}, true case vmUUID != "": - return `"VmUuid" = ?`, []interface{}{vmUUID}, true + return `"VmUuid" = ?`, []any{vmUUID}, true case name != "": - return `lower("Name") = ?`, []interface{}{strings.ToLower(name)}, true + return `lower("Name") = ?`, []any{strings.ToLower(name)}, true default: return "", nil, false } @@ -2960,7 +2984,7 @@ SET "AvgIsPresent" = CASE END `, summaryTable, endExpr, startExpr, endExpr, startExpr) query = dbConn.Rebind(query) - args := []interface{}{ + args := []any{ windowEnd, windowEnd, windowStart, windowStart, windowEnd, windowEnd, @@ -3530,7 +3554,7 @@ FROM snapshot_runs WHERE "Success" = 'FALSE' AND "Attempts" < ? ORDER BY "LastAttempt" ASC ` - args := []interface{}{maxAttempts} + args := []any{maxAttempts} if driver == "pgx" || driver == "postgres" { query = ` SELECT "Vcenter","SnapshotTime","Attempts" diff --git a/db/helpers_cache_and_index_test.go b/db/helpers_cache_and_index_test.go index 3018fcb..9be14d7 100644 --- a/db/helpers_cache_and_index_test.go +++ b/db/helpers_cache_and_index_test.go @@ -131,7 +131,7 @@ INSERT INTO vm_hourly_stats ( "Datacenter","Cluster","Folder","ProvisionedDisk","VcpuCount","RamGB","IsTemplate","PoweredOn","SrmPlaceholder" ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) ` - rows := [][]interface{}{ + rows := [][]any{ {int64(1000), "vc-a", "vm-1", "uuid-1", "demo-vm", int64(900), int64(0), "Tin", "dc", "cluster", "folder", 100.0, int64(2), int64(4), "FALSE", "TRUE", "FALSE"}, {int64(2000), "vc-a", "vm-1", "uuid-1", "demo-vm", int64(900), int64(0), "Gold", "dc", "cluster", "folder", 150.0, int64(4), int64(8), "FALSE", "TRUE", "FALSE"}, } @@ -371,7 +371,7 @@ INSERT INTO vm_hourly_stats ( ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) ` // First row carries an old deletion marker, later row proves VM is still present. - rows := [][]interface{}{ + rows := [][]any{ {int64(1700000000), "vc-a", "vm-1", "uuid-1", "demo-vm", int64(1699999000), int64(1700003600), "Tin", "dc", "cluster", "folder", 100.0, int64(2), int64(4), "FALSE", "TRUE", "FALSE"}, {int64(1700100000), "vc-a", "vm-1", "uuid-1", "demo-vm", int64(1699999000), int64(0), "Gold", "dc", "cluster", "folder", 120.0, int64(4), int64(8), "FALSE", "TRUE", "FALSE"}, } @@ -563,7 +563,7 @@ CREATE TABLE %s ( INSERT INTO %s ("Vcenter","Name","VmId","VmUuid","AvgVcpuCount","AvgRamGB") VALUES (?,?,?,?,?,?) `, summaryTable) - rows := [][]interface{}{ + rows := [][]any{ {"vc-a", "vm-1", "1", "u1", 2.0, 4.0}, {"vc-a", "vm-2", "2", "u2", 3.0, 5.0}, {"vc-b", "vm-3", "3", "u3", 1.0, 2.0}, @@ -633,7 +633,7 @@ CREATE TABLE %s ( INSERT INTO %s ("Vcenter","Name","VmId","VmUuid","AvgVcpuCount","AvgRamGB") VALUES (?,?,?,?,?,?) `, summaryTable) - for _, args := range [][]interface{}{ + for _, args := range [][]any{ {"vc-a", "vm-1", "1", "u1", 4.0, 8.0}, {"vc-a", "vm-2", "2", "u2", 2.0, 6.0}, } { @@ -697,7 +697,7 @@ CREATE TABLE %s ( } insert1 := fmt.Sprintf(`INSERT INTO %s ("Vcenter","Name","VmId","VmUuid","AvgVcpuCount","AvgRamGB") VALUES (?,?,?,?,?,?)`, table1) insert2 := fmt.Sprintf(`INSERT INTO %s ("Vcenter","Name","VmId","VmUuid","AvgVcpuCount","AvgRamGB") VALUES (?,?,?,?,?,?)`, table2) - for _, args := range [][]interface{}{ + for _, args := range [][]any{ {"vc-a", "vm-1", "1", "u1", 2.0, 4.0}, {"vc-b", "vm-2", "2", "u2", 3.0, 5.0}, } { diff --git a/db/postgres.go b/db/postgres.go index 5e884c0..4efce76 100644 --- a/db/postgres.go +++ b/db/postgres.go @@ -55,7 +55,7 @@ type rebindDBTX struct { db *sqlx.DB } -func (r rebindDBTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (r rebindDBTX) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { return r.db.ExecContext(ctx, rebindQuery(query), args...) } @@ -63,11 +63,11 @@ func (r rebindDBTX) PrepareContext(ctx context.Context, query string) (*sql.Stmt return r.db.PrepareContext(ctx, rebindQuery(query)) } -func (r rebindDBTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (r rebindDBTX) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { return r.db.QueryContext(ctx, rebindQuery(query), args...) } -func (r rebindDBTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { +func (r rebindDBTX) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { return r.db.QueryRowContext(ctx, rebindQuery(query), args...) } diff --git a/db/sqlite_import.go b/db/sqlite_import.go index 85c9d8f..dbf34f2 100644 --- a/db/sqlite_import.go +++ b/db/sqlite_import.go @@ -237,15 +237,15 @@ func copySQLiteTableIntoPostgres(ctx context.Context, source *sqlx.DB, destinati var rowsCopied int64 for rows.Next() { - rawValues := make([]interface{}, len(columns)) - scanTargets := make([]interface{}, len(columns)) + rawValues := make([]any, len(columns)) + scanTargets := make([]any, len(columns)) for i := range rawValues { scanTargets[i] = &rawValues[i] } if err := rows.Scan(scanTargets...); err != nil { return rowsCopied, fmt.Errorf("failed to scan row from sqlite table %q: %w", tableName, err) } - args := make([]interface{}, len(columns)) + args := make([]any, len(columns)) for i, col := range columns { args[i] = coerceSQLiteValueForPostgres(rawValues[i], col.DestinationType) } @@ -366,7 +366,7 @@ ORDER BY ordinal_position return nil } -func coerceSQLiteValueForPostgres(value interface{}, destinationType string) interface{} { +func coerceSQLiteValueForPostgres(value any, destinationType string) any { if value == nil { return nil } @@ -385,7 +385,7 @@ func coerceSQLiteValueForPostgres(value interface{}, destinationType string) int return value } -func coerceBoolValue(value interface{}) (bool, bool) { +func coerceBoolValue(value any) (bool, bool) { switch v := value.(type) { case bool: return v, true diff --git a/db/sqlite_import_test.go b/db/sqlite_import_test.go index 69e5d59..32e88ab 100644 --- a/db/sqlite_import_test.go +++ b/db/sqlite_import_test.go @@ -51,9 +51,9 @@ func TestIntersectImportColumns(t *testing.T) { func TestCoerceSQLiteValueForPostgresBoolean(t *testing.T) { tests := []struct { name string - input interface{} + input any destinationType string - want interface{} + want any }{ {name: "string true", input: "true", destinationType: "boolean", want: true}, {name: "string false", input: "0", destinationType: "boolean", want: false}, diff --git a/go.mod b/go.mod index 8e5ab11..39acf7a 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/swaggo/swag v1.16.6 github.com/vmware/govmomi v0.52.0 github.com/xuri/excelize/v2 v2.10.0 - gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.0 ) @@ -58,6 +58,7 @@ require ( golang.org/x/text v0.33.0 // indirect golang.org/x/tools v0.41.0 // indirect google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect modernc.org/libc v1.67.4 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect diff --git a/internal/report/create.go b/internal/report/create.go index e05b4e4..45bc96b 100644 --- a/internal/report/create.go +++ b/internal/report/create.go @@ -238,7 +238,7 @@ func CreateUpdatesReport(logger *slog.Logger, Database db.Database, ctx context. } // Helper function to get the actual value of sql.Null types -func getFieldValue(field reflect.Value) interface{} { +func getFieldValue(field reflect.Value) any { switch field.Kind() { case reflect.Struct: // Handle sql.Null types based on their concrete type diff --git a/internal/report/snapshots.go b/internal/report/snapshots.go index 77d4a12..261f88e 100644 --- a/internal/report/snapshots.go +++ b/internal/report/snapshots.go @@ -9,6 +9,7 @@ import ( "math" "os" "path/filepath" + "slices" "sort" "strconv" "strings" @@ -844,8 +845,8 @@ func addTotalsChartSheet(logger *slog.Logger, database db.Database, ctx context. if logger == nil { logger = slog.Default() } - if strings.HasPrefix(tableName, "inventory_daily_summary_") { - suffix := strings.TrimPrefix(tableName, "inventory_daily_summary_") + if after, ok := strings.CutPrefix(tableName, "inventory_daily_summary_"); ok { + suffix := after dayStart, err := time.ParseInLocation("20060102", suffix, time.Local) if err != nil { logger.Debug("hourly totals skip: invalid daily summary suffix", "table", tableName, "suffix", suffix, "error", err) @@ -878,8 +879,8 @@ func addTotalsChartSheet(logger *slog.Logger, database db.Database, ctx context. return } - if strings.HasPrefix(tableName, "inventory_monthly_summary_") { - suffix := strings.TrimPrefix(tableName, "inventory_monthly_summary_") + if after, ok := strings.CutPrefix(tableName, "inventory_monthly_summary_"); ok { + suffix := after monthStart, err := time.ParseInLocation("200601", suffix, time.Local) if err != nil { logger.Debug("daily totals skip: invalid monthly summary suffix", "table", tableName, "suffix", suffix, "error", err) @@ -1001,10 +1002,7 @@ func titleCellFromPivotRange(pivotRange, fallback string) string { if err != nil { return fallback } - titleRow := row - 2 - if titleRow < 1 { - titleRow = 1 - } + titleRow := max(row-2, 1) cell, err := excelize.CoordinatesToCellName(col, titleRow) if err != nil { return fallback @@ -1356,16 +1354,16 @@ func reportTypeFromTable(tableName string) string { } func reportWindowFromTable(tableName string) (time.Time, time.Time, bool) { - if strings.HasPrefix(tableName, "inventory_daily_summary_") { - suffix := strings.TrimPrefix(tableName, "inventory_daily_summary_") + if after, ok := strings.CutPrefix(tableName, "inventory_daily_summary_"); ok { + suffix := after dayStart, err := time.ParseInLocation("20060102", suffix, time.Local) if err != nil { return time.Time{}, time.Time{}, false } return dayStart, dayStart.AddDate(0, 0, 1), true } - if strings.HasPrefix(tableName, "inventory_monthly_summary_") { - suffix := strings.TrimPrefix(tableName, "inventory_monthly_summary_") + if after, ok := strings.CutPrefix(tableName, "inventory_monthly_summary_"); ok { + suffix := after monthStart, err := time.ParseInLocation("200601", suffix, time.Local) if err != nil { return time.Time{}, time.Time{}, false @@ -1383,7 +1381,7 @@ func addReportMetadataSheet(logger *slog.Logger, xlsx *excelize.File, meta repor } rows := []struct { key string - value interface{} + value any }{ {"ReportTable", meta.TableName}, {"ReportType", meta.ReportType}, @@ -1398,28 +1396,28 @@ func addReportMetadataSheet(logger *slog.Logger, xlsx *excelize.File, meta repor rows = append(rows, struct { key string - value interface{} + value any }{"DataWindowStart", meta.WindowStart.Format(time.RFC3339)}, struct { key string - value interface{} + value any }{"DataWindowEnd", meta.WindowEnd.Format(time.RFC3339)}, struct { key string - value interface{} + value any }{"DataWindowTimezone", time.Local.String()}, ) } if meta.DBDriver != "" { rows = append(rows, struct { key string - value interface{} + value any }{"DatabaseDriver", meta.DBDriver}) } if meta.Duration > 0 && meta.RowCount > 0 { rows = append(rows, struct { key string - value interface{} + value any }{"RowsPerSecond", math.Round((float64(meta.RowCount)/meta.Duration.Seconds())*1000) / 1000}) } for i, row := range rows { @@ -1433,9 +1431,9 @@ func addReportMetadataSheet(logger *slog.Logger, xlsx *excelize.File, meta repor } } -func scanRowValues(rows *sqlx.Rows, columnCount int) ([]interface{}, error) { - rawValues := make([]interface{}, columnCount) - scanArgs := make([]interface{}, columnCount) +func scanRowValues(rows *sqlx.Rows, columnCount int) ([]any, error) { + rawValues := make([]any, columnCount) + scanArgs := make([]any, columnCount) for i := range rawValues { scanArgs[i] = &rawValues[i] } @@ -1445,7 +1443,7 @@ func scanRowValues(rows *sqlx.Rows, columnCount int) ([]interface{}, error) { return rawValues, nil } -func normalizeCellValue(value interface{}) interface{} { +func normalizeCellValue(value any) any { switch v := value.(type) { case nil: return "" @@ -1749,14 +1747,14 @@ FROM diag, agg_diag DeletedInInterval int64 `db:"deleted_in_interval"` PartialPresence int64 `db:"partial_presence"` } - overlapArgs := []interface{}{ + overlapArgs := []any{ hourEndUnix, hourEndUnix, hourStartUnix, hourStartUnix, hourEndUnix, hourEndUnix, hourStartUnix, hourStartUnix, durationSeconds, } - args := make([]interface{}, 0, len(overlapArgs)*3+6) + args := make([]any, 0, len(overlapArgs)*3+6) args = append(args, overlapArgs...) args = append(args, overlapArgs...) args = append(args, hourStartUnix, hourEndUnix) @@ -1847,7 +1845,7 @@ func estimateSnapshotInterval(records []SnapshotRecord) time.Duration { if len(diffs) == 0 { return time.Hour } - sort.Slice(diffs, func(i, j int) bool { return diffs[i] < diffs[j] }) + slices.Sort(diffs) median := diffs[len(diffs)/2] if median <= 0 { return time.Hour @@ -2032,7 +2030,7 @@ func writeTotalsChart(logger *slog.Logger, xlsx *excelize.File, sheetName string makeChart("K52", "F", "G", "H", "I") } -func formatEpochHuman(value interface{}) string { +func formatEpochHuman(value any) string { var epoch int64 switch v := value.(type) { case nil: diff --git a/internal/report/snapshots_pivot_test.go b/internal/report/snapshots_pivot_test.go index 4d9a234..8fe06fe 100644 --- a/internal/report/snapshots_pivot_test.go +++ b/internal/report/snapshots_pivot_test.go @@ -21,7 +21,7 @@ func TestAddSummaryPivotSheetCreatesPivotTables(t *testing.T) { t.Fatalf("SetSheetRow header failed: %v", err) } - row1 := []interface{}{"vm-1", "dc-1", "pool-1", 4.0, 16.0, 1.0} + row1 := []any{"vm-1", "dc-1", "pool-1", 4.0, 16.0, 1.0} if err := xlsx.SetSheetRow(dataSheet, "A2", &row1); err != nil { t.Fatalf("SetSheetRow data failed: %v", err) } diff --git a/internal/settings/settings.go b/internal/settings/settings.go index 53ef1ef..0439bd0 100644 --- a/internal/settings/settings.go +++ b/internal/settings/settings.go @@ -10,7 +10,7 @@ import ( "strings" "vctp/internal/utils" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) var ( diff --git a/internal/tasks/cronstatus.go b/internal/tasks/cronstatus.go index 8220c59..f6b8059 100644 --- a/internal/tasks/cronstatus.go +++ b/internal/tasks/cronstatus.go @@ -180,7 +180,7 @@ WHERE job_name = ? return err } -func nullableString(s string) interface{} { +func nullableString(s string) any { if s == "" { return nil } diff --git a/internal/tasks/dailyAggregate.go b/internal/tasks/dailyAggregate.go index 47e1e6f..9658872 100644 --- a/internal/tasks/dailyAggregate.go +++ b/internal/tasks/dailyAggregate.go @@ -8,7 +8,7 @@ import ( "log/slog" "os" "runtime" - "sort" + "slices" "strings" "sync" "time" @@ -295,7 +295,7 @@ func (c *CronTask) aggregateDailySummaryGo(ctx context.Context, dayStart, dayEnd for _, snap := range hourlySnapshots { snapTimes = append(snapTimes, snap.SnapshotTime.Unix()) } - sort.Slice(snapTimes, func(i, j int) bool { return snapTimes[i] < snapTimes[j] }) + slices.Sort(snapTimes) } lifecycleDeletions := c.applyLifecycleDeletions(ctx, aggMap, dayStart, dayEnd) @@ -353,7 +353,7 @@ LIMIT 1 for t := range set { times = append(times, t) } - sort.Slice(times, func(i, j int) bool { return times[i] < times[j] }) + slices.Sort(times) vcenterSnapTimes[vcenter] = times } @@ -843,20 +843,12 @@ func (c *CronTask) applyInventoryCreations(ctx context.Context, agg map[dailyAgg func (c *CronTask) scanHourlyTablesParallel(ctx context.Context, snapshots []report.SnapshotRecord) (map[dailyAggKey]*dailyAggVal, error) { agg := make(map[dailyAggKey]*dailyAggVal, 1024) mu := sync.Mutex{} - workers := runtime.NumCPU() - if workers < 2 { - workers = 2 - } - if workers > len(snapshots) { - workers = len(snapshots) - } + workers := min(max(runtime.NumCPU(), 2), len(snapshots)) jobs := make(chan report.SnapshotRecord, len(snapshots)) wg := sync.WaitGroup{} for i := 0; i < workers; i++ { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { for snap := range jobs { rows, err := c.scanHourlyTable(ctx, snap) if err != nil { @@ -873,7 +865,7 @@ func (c *CronTask) scanHourlyTablesParallel(ctx context.Context, snapshots []rep } mu.Unlock() } - }() + }) } for _, snap := range snapshots { jobs <- snap @@ -1114,7 +1106,7 @@ WHERE "SnapshotTime" >= ? AND "SnapshotTime" < ?` for t := range timeSet { snapTimes = append(snapTimes, t) } - sort.Slice(snapTimes, func(i, j int) bool { return snapTimes[i] < snapTimes[j] }) + slices.Sort(snapTimes) return agg, snapTimes, rows.Err() } @@ -1166,7 +1158,7 @@ INSERT INTO %s ( silverPct = float64(v.silverHits) * 100 / float64(v.samples) goldPct = float64(v.goldHits) * 100 / float64(v.samples) } - args := []interface{}{ + args := []any{ v.key.Name, v.key.Vcenter, nullIfEmpty(v.key.VmId), @@ -1214,7 +1206,7 @@ func int64OrZero(v sql.NullInt64) int64 { return 0 } -func nullIfEmpty(s string) interface{} { +func nullIfEmpty(s string) any { if strings.TrimSpace(s) == "" { return nil } @@ -1224,13 +1216,13 @@ func nullIfEmpty(s string) interface{} { func makePlaceholders(driver string, n int) string { if driver == "sqlite" { parts := make([]string, n) - for i := 0; i < n; i++ { + for i := range n { parts[i] = "?" } return strings.Join(parts, ",") } parts := make([]string, n) - for i := 0; i < n; i++ { + for i := range n { parts[i] = fmt.Sprintf("$%d", i+1) } return strings.Join(parts, ",") diff --git a/internal/tasks/inventoryDatabase.go b/internal/tasks/inventoryDatabase.go index 0ff02b1..251f29f 100644 --- a/internal/tasks/inventoryDatabase.go +++ b/internal/tasks/inventoryDatabase.go @@ -61,7 +61,7 @@ func insertHourlyCache(ctx context.Context, dbConn *sqlx.DB, rows []InventorySna defer stmt.Close() for _, r := range rows { - args := []interface{}{ + args := []any{ r.SnapshotTime, r.Vcenter, r.VmId, r.VmUuid, r.Name, r.CreationTime, r.DeletionTime, r.ResourcePool, r.Datacenter, r.Cluster, r.Folder, r.ProvisionedDisk, r.VcpuCount, r.RamGB, r.IsTemplate, r.PoweredOn, r.SrmPlaceholder, } @@ -105,7 +105,7 @@ func insertHourlyBatch(ctx context.Context, dbConn *sqlx.DB, tableName string, r } defer stmt.Close() for _, row := range rows { - args := []interface{}{ + args := []any{ row.InventoryId, row.Name, row.Vcenter, @@ -138,7 +138,7 @@ func insertHourlyBatch(ctx context.Context, dbConn *sqlx.DB, tableName string, r defer stmt.Close() for _, row := range rows { - args := []interface{}{ + args := []any{ row.InventoryId, row.Name, row.Vcenter, diff --git a/internal/tasks/inventoryHelpers.go b/internal/tasks/inventoryHelpers.go index 51152a7..a20c7a8 100644 --- a/internal/tasks/inventoryHelpers.go +++ b/internal/tasks/inventoryHelpers.go @@ -27,7 +27,7 @@ func acquireSnapshotProbe(ctx context.Context) (func(), error) { } } -func boolStringFromInterface(value interface{}) string { +func boolStringFromInterface(value any) string { switch v := value.(type) { case nil: return "" @@ -164,7 +164,7 @@ func SnapshotTooSoon(prevUnix, currUnix int64, expectedSeconds int64) bool { } // querySnapshotRows builds a SELECT with proper rebind for the given table/columns/where. -func querySnapshotRows(ctx context.Context, dbConn *sqlx.DB, table string, columns []string, where string, args ...interface{}) (*sqlx.Rows, error) { +func querySnapshotRows(ctx context.Context, dbConn *sqlx.DB, table string, columns []string, where string, args ...any) (*sqlx.Rows, error) { if err := db.ValidateTableName(table); err != nil { return nil, err } diff --git a/internal/tasks/inventorySnapshots.go b/internal/tasks/inventorySnapshots.go index 62a0b46..28cf44c 100644 --- a/internal/tasks/inventorySnapshots.go +++ b/internal/tasks/inventorySnapshots.go @@ -24,8 +24,6 @@ import ( "github.com/vmware/govmomi/vim25/types" ) -type ctxLoggerKey struct{} - type deletionCandidate struct { vmID string vmUUID string @@ -42,10 +40,7 @@ type vcenterResources struct { } func loggerFromCtx(ctx context.Context, fallback *slog.Logger) *slog.Logger { - if ctx == nil { - return fallback - } - if l, ok := ctx.Value(ctxLoggerKey{}).(*slog.Logger); ok && l != nil { + if l := db.LoggerFromContext(ctx); l != nil { return l } return fallback @@ -132,10 +127,7 @@ func (c *CronTask) RunVcenterSnapshotHourly(ctx context.Context, logger *slog.Lo if err != nil { return err } - minIntervalSeconds := intWithDefault(c.Settings.Values.Settings.VcenterInventorySnapshotSeconds, 3600) / 3 - if minIntervalSeconds < 1 { - minIntervalSeconds = 1 - } + minIntervalSeconds := max(intWithDefault(c.Settings.Values.Settings.VcenterInventorySnapshotSeconds, 3600)/3, 1) if !lastSnapshot.IsZero() && startTime.Sub(lastSnapshot) < time.Duration(minIntervalSeconds)*time.Second { c.Logger.Info("Skipping hourly snapshot, last snapshot too recent", "last_snapshot", lastSnapshot, @@ -217,7 +209,7 @@ func (c *CronTask) RunVcenterSnapshotHourly(ctx context.Context, logger *slog.Lo metrics.RecordHourlySnapshot(startTime, rowCount, err) var deferredTables []string - deferredReportTables.Range(func(key, _ interface{}) bool { + deferredReportTables.Range(func(key, _ any) bool { name, ok := key.(string) if ok && strings.TrimSpace(name) != "" && name != tableName { deferredTables = append(deferredTables, name) @@ -488,10 +480,7 @@ func buildUnionQuery(tables []string, columns []string, whereClause string) (str batches := make([]string, 0, (len(tables)/maxCompoundTerms)+1) batchIndex := 0 for start := 0; start < len(tables); start += maxCompoundTerms { - end := start + maxCompoundTerms - if end > len(tables) { - end = len(tables) - } + end := min(start+maxCompoundTerms, len(tables)) queries := make([]string, 0, end-start) for _, table := range tables[start:end] { safeName, err := db.SafeTableName(table) @@ -1337,7 +1326,7 @@ func (c *CronTask) initVcenterResources(ctx context.Context, log *slog.Logger, u func (c *CronTask) captureHourlySnapshotForVcenter(ctx context.Context, startTime time.Time, tableName string, url string, deferredReportTables *sync.Map) error { log := c.Logger.With("vcenter", url) - ctx = context.WithValue(ctx, ctxLoggerKey{}, log) + ctx = db.WithLoggerContext(ctx, log) started := time.Now() log.Debug("connecting to vcenter for hourly snapshot", "url", url) vc, resources, cleanup, err := c.initVcenterResources(ctx, log, url, startTime, started) diff --git a/internal/tasks/monthlyAggregate.go b/internal/tasks/monthlyAggregate.go index d2fa490..aa60ed5 100644 --- a/internal/tasks/monthlyAggregate.go +++ b/internal/tasks/monthlyAggregate.go @@ -7,7 +7,7 @@ import ( "log/slog" "os" "runtime" - "sort" + "slices" "strings" "sync" "time" @@ -246,7 +246,7 @@ func (c *CronTask) aggregateMonthlySummaryGoHourly(ctx context.Context, monthSta for _, snap := range hourlySnapshots { snapTimes = append(snapTimes, snap.SnapshotTime.Unix()) } - sort.Slice(snapTimes, func(i, j int) bool { return snapTimes[i] < snapTimes[j] }) + slices.Sort(snapTimes) } lifecycleDeletions := c.applyLifecycleDeletions(ctx, aggMap, monthStart, monthEnd) @@ -394,20 +394,12 @@ func (c *CronTask) aggregateMonthlySummaryGo(ctx context.Context, monthStart, mo func (c *CronTask) scanDailyTablesParallel(ctx context.Context, snapshots []report.SnapshotRecord) (map[monthlyAggKey]*monthlyAggVal, error) { agg := make(map[monthlyAggKey]*monthlyAggVal, 1024) mu := sync.Mutex{} - workers := runtime.NumCPU() - if workers < 2 { - workers = 2 - } - if workers > len(snapshots) { - workers = len(snapshots) - } + workers := min(max(runtime.NumCPU(), 2), len(snapshots)) jobs := make(chan report.SnapshotRecord, len(snapshots)) wg := sync.WaitGroup{} for i := 0; i < workers; i++ { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { for snap := range jobs { rows, err := c.scanDailyTable(ctx, snap) if err != nil { @@ -424,7 +416,7 @@ func (c *CronTask) scanDailyTablesParallel(ctx context.Context, snapshots []repo } mu.Unlock() } - }() + }) } for _, snap := range snapshots { jobs <- snap diff --git a/main.go b/main.go index d0bdb37..e013955 100644 --- a/main.go +++ b/main.go @@ -449,8 +449,8 @@ func resolveVcenterPassword(logger *slog.Logger, cipher *secrets.Secrets, legacy } // New format: explicit prefix so we can distinguish ciphertext from plaintext safely. - if strings.HasPrefix(raw, encryptedVcenterPasswordPrefix) { - enc := strings.TrimPrefix(raw, encryptedVcenterPasswordPrefix) + if after, ok := strings.CutPrefix(raw, encryptedVcenterPasswordPrefix); ok { + enc := after pass, usedLegacyKey, err := decryptVcenterPasswordWithFallback(logger, cipher, legacyDecryptKeys, enc) if err != nil { return nil, "", fmt.Errorf("prefixed password decrypt failed: %w", err) diff --git a/server/handler/dailyCreationDiagnostics.go b/server/handler/dailyCreationDiagnostics.go index d15fbd9..5047f84 100644 --- a/server/handler/dailyCreationDiagnostics.go +++ b/server/handler/dailyCreationDiagnostics.go @@ -1,7 +1,6 @@ package handler import ( - "context" "database/sql" "fmt" "net/http" @@ -42,7 +41,7 @@ func (h *Handler) DailyCreationDiagnostics(w http.ResponseWriter, r *http.Reques return } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := withRequestTimeout(r, 10*time.Second) defer cancel() dbConn := h.Database.DB() diff --git a/server/handler/reportDownload.go b/server/handler/reportDownload.go index da62a2c..1044c91 100644 --- a/server/handler/reportDownload.go +++ b/server/handler/reportDownload.go @@ -1,7 +1,6 @@ package handler import ( - "context" "fmt" "net/http" "vctp/internal/report" @@ -16,8 +15,8 @@ import ( // @Failure 500 {object} models.ErrorResponse "Report generation failed" // @Router /api/report/inventory [get] func (h *Handler) InventoryReportDownload(w http.ResponseWriter, r *http.Request) { - - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, reportRequestTimeout) + defer cancel() // Generate the XLSX report reportData, err := report.CreateInventoryReport(h.Logger, h.Database, ctx) @@ -45,8 +44,8 @@ func (h *Handler) InventoryReportDownload(w http.ResponseWriter, r *http.Request // @Failure 500 {object} models.ErrorResponse "Report generation failed" // @Router /api/report/updates [get] func (h *Handler) UpdateReportDownload(w http.ResponseWriter, r *http.Request) { - - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, reportRequestTimeout) + defer cancel() // Generate the XLSX report reportData, err := report.CreateUpdatesReport(h.Logger, h.Database, ctx) diff --git a/server/handler/request_context.go b/server/handler/request_context.go new file mode 100644 index 0000000..818fb3e --- /dev/null +++ b/server/handler/request_context.go @@ -0,0 +1,24 @@ +package handler + +import ( + "context" + "net/http" + "time" +) + +const ( + defaultRequestTimeout = 2 * time.Minute + reportRequestTimeout = 10 * time.Minute + longRunningRequestTimeout = 2 * time.Hour +) + +func withRequestTimeout(r *http.Request, timeout time.Duration) (context.Context, context.CancelFunc) { + base := context.Background() + if r != nil { + base = r.Context() + } + if timeout <= 0 { + return base, func() {} + } + return context.WithTimeout(base, timeout) +} diff --git a/server/handler/snapshotAggregate.go b/server/handler/snapshotAggregate.go index a4480bb..8fa31cc 100644 --- a/server/handler/snapshotAggregate.go +++ b/server/handler/snapshotAggregate.go @@ -1,7 +1,6 @@ package handler import ( - "context" "net/http" "strings" "time" @@ -50,7 +49,8 @@ func (h *Handler) SnapshotAggregateForce(w http.ResponseWriter, r *http.Request) return } - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, longRunningRequestTimeout) + defer cancel() settingsCopy := *h.Settings.Values if granularity != "" { settingsCopy.Settings.MonthlyAggregationGranularity = granularity diff --git a/server/handler/snapshotForceHourly.go b/server/handler/snapshotForceHourly.go index a643b28..fd4e5b9 100644 --- a/server/handler/snapshotForceHourly.go +++ b/server/handler/snapshotForceHourly.go @@ -1,7 +1,6 @@ package handler import ( - "context" "net/http" "strings" "time" @@ -26,7 +25,8 @@ func (h *Handler) SnapshotForceHourly(w http.ResponseWriter, r *http.Request) { return } - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, longRunningRequestTimeout) + defer cancel() ct := &tasks.CronTask{ Logger: h.Logger, Database: h.Database, diff --git a/server/handler/snapshotMigrate.go b/server/handler/snapshotMigrate.go index ae95a33..016f2cf 100644 --- a/server/handler/snapshotMigrate.go +++ b/server/handler/snapshotMigrate.go @@ -1,7 +1,6 @@ package handler import ( - "context" "net/http" "vctp/internal/report" "vctp/server/models" @@ -16,7 +15,8 @@ import ( // @Failure 500 {object} models.SnapshotMigrationResponse "Server error" // @Router /api/snapshots/migrate [post] func (h *Handler) SnapshotMigrate(w http.ResponseWriter, r *http.Request) { - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, reportRequestTimeout) + defer cancel() stats, err := report.MigrateSnapshotRegistry(ctx, h.Database) if err != nil { writeJSON(w, http.StatusInternalServerError, models.SnapshotMigrationResponse{ diff --git a/server/handler/snapshots.go b/server/handler/snapshots.go index 115b5b6..1afb78f 100644 --- a/server/handler/snapshots.go +++ b/server/handler/snapshots.go @@ -1,7 +1,6 @@ package handler import ( - "context" "fmt" "net/http" "net/url" @@ -58,7 +57,8 @@ func (h *Handler) SnapshotMonthlyList(w http.ResponseWriter, r *http.Request) { // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/report/snapshot [get] func (h *Handler) SnapshotReportDownload(w http.ResponseWriter, r *http.Request) { - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, reportRequestTimeout) + defer cancel() tableName := r.URL.Query().Get("table") if tableName == "" { writeJSONError(w, http.StatusBadRequest, "Missing table parameter") @@ -80,7 +80,8 @@ func (h *Handler) SnapshotReportDownload(w http.ResponseWriter, r *http.Request) } func (h *Handler) renderSnapshotList(w http.ResponseWriter, r *http.Request, snapshotType string, title string, renderer func([]views.SnapshotEntry) templ.Component) { - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() if err := report.EnsureSnapshotRegistry(ctx, h.Database); err != nil { h.Logger.Error("Failed to ensure snapshot registry", "error", err) w.WriteHeader(http.StatusInternalServerError) @@ -107,10 +108,7 @@ func (h *Handler) renderSnapshotList(w http.ResponseWriter, r *http.Request, sna case "monthly": group = record.SnapshotTime.Format("2006") } - count := record.SnapshotCount - if count < 0 { - count = 0 - } + count := max(record.SnapshotCount, 0) entries = append(entries, views.SnapshotEntry{ Label: label, Link: "/reports/" + url.PathEscape(record.TableName) + ".xlsx", diff --git a/server/handler/updateCleanup.go b/server/handler/updateCleanup.go index b4f9d3b..d128d69 100644 --- a/server/handler/updateCleanup.go +++ b/server/handler/updateCleanup.go @@ -1,7 +1,6 @@ package handler import ( - "context" "fmt" "net/http" ) @@ -19,6 +18,8 @@ func (h *Handler) UpdateCleanup(w http.ResponseWriter, r *http.Request) { if h.denyLegacyAPI(w, "/api/cleanup/updates") { return } + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() /* // Get the current time @@ -35,11 +36,11 @@ func (h *Handler) UpdateCleanup(w http.ResponseWriter, r *http.Request) { } h.Logger.Debug("database params", "params", params) - err := h.Database.Queries().CleanupUpdates(context.Background(), params) + err := h.Database.Queries().CleanupUpdates(ctx, params) */ - //err := h.Database.Queries().InventoryCleanupTemplates(context.Background()) - err := h.Database.Queries().CleanupUpdatesNullVm(context.Background()) + //err := h.Database.Queries().InventoryCleanupTemplates(ctx) + err := h.Database.Queries().CleanupUpdatesNullVm(ctx) if err != nil { h.Logger.Error("Error received cleaning updates table", "error", err) diff --git a/server/handler/vcCleanup.go b/server/handler/vcCleanup.go index 5095253..9153933 100644 --- a/server/handler/vcCleanup.go +++ b/server/handler/vcCleanup.go @@ -1,7 +1,6 @@ package handler import ( - "context" "database/sql" "errors" "fmt" @@ -23,7 +22,8 @@ func (h *Handler) VcCleanup(w http.ResponseWriter, r *http.Request) { return } - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() // Get the parameters vcUrl := r.URL.Query().Get("vc_url") diff --git a/server/handler/vcenterCacheRebuild.go b/server/handler/vcenterCacheRebuild.go index b217405..8f66914 100644 --- a/server/handler/vcenterCacheRebuild.go +++ b/server/handler/vcenterCacheRebuild.go @@ -108,7 +108,7 @@ func (h *Handler) rebuildOneVcenterCache(ctx context.Context, vcURL string) (int return 0, 0, 0, fmt.Errorf("unable to connect to vcenter: %w", err) } defer func() { - logoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) defer cancel() if err := vc.Logout(logoutCtx); err != nil { h.Logger.Warn("vcenter cache rebuild logout failed", "vcenter", vcURL, "error", err) diff --git a/server/handler/vmCleanup.go b/server/handler/vmCleanup.go index cbe574e..6ed878d 100644 --- a/server/handler/vmCleanup.go +++ b/server/handler/vmCleanup.go @@ -1,7 +1,6 @@ package handler import ( - "context" "database/sql" "errors" "fmt" @@ -25,7 +24,8 @@ func (h *Handler) VmCleanup(w http.ResponseWriter, r *http.Request) { return } - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() // Get the parameters vmId := r.URL.Query().Get("vm_id") diff --git a/server/handler/vmCreateEvent.go b/server/handler/vmCreateEvent.go index 218666b..f134c05 100644 --- a/server/handler/vmCreateEvent.go +++ b/server/handler/vmCreateEvent.go @@ -1,7 +1,6 @@ package handler import ( - "context" "database/sql" "encoding/json" "fmt" @@ -30,6 +29,8 @@ func (h *Handler) VmCreateEvent(w http.ResponseWriter, r *http.Request) { if h.denyLegacyAPI(w, "/api/event/vm/create") { return } + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() var ( unixTimestamp int64 @@ -96,7 +97,7 @@ func (h *Handler) VmCreateEvent(w http.ResponseWriter, r *http.Request) { h.Logger.Debug("database params", "params", params) // Insert the new inventory record into the database - result, err := h.Database.Queries().CreateEvent(context.Background(), params) + result, err := h.Database.Queries().CreateEvent(ctx, params) if err != nil { h.Logger.Error("unable to perform database insert", "error", err) writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Error: %v", err)) @@ -109,7 +110,7 @@ func (h *Handler) VmCreateEvent(w http.ResponseWriter, r *http.Request) { } // prettyPrint comes from https://gist.github.com/sfate/9d45f6c5405dc4c9bf63bf95fe6d1a7c -func prettyPrint(args ...interface{}) { +func prettyPrint(args ...any) { var caller string timeNow := time.Now().Format("01-02-2006 15:04:05") diff --git a/server/handler/vmDeleteEvent.go b/server/handler/vmDeleteEvent.go index 78affbf..cf5ede3 100644 --- a/server/handler/vmDeleteEvent.go +++ b/server/handler/vmDeleteEvent.go @@ -1,7 +1,6 @@ package handler import ( - "context" "database/sql" "encoding/json" "fmt" @@ -28,6 +27,8 @@ func (h *Handler) VmDeleteEvent(w http.ResponseWriter, r *http.Request) { if h.denyLegacyAPI(w, "/api/event/vm/delete") { return } + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() var ( deletedTimestamp int64 @@ -71,7 +72,7 @@ func (h *Handler) VmDeleteEvent(w http.ResponseWriter, r *http.Request) { DatacenterName: sql.NullString{String: event.CloudEvent.Data.Datacenter.Name, Valid: event.CloudEvent.Data.Datacenter.Name != ""}, } h.Logger.Debug("database params", "params", params) - err = h.Database.Queries().InventoryMarkDeleted(context.Background(), params) + err = h.Database.Queries().InventoryMarkDeleted(ctx, params) if err != nil { h.Logger.Error("Error received marking VM as deleted", "error", err) diff --git a/server/handler/vmImport.go b/server/handler/vmImport.go index e849089..8aed7c0 100644 --- a/server/handler/vmImport.go +++ b/server/handler/vmImport.go @@ -1,7 +1,6 @@ package handler import ( - "context" "database/sql" "encoding/json" "errors" @@ -58,7 +57,8 @@ func (h *Handler) VmImport(w http.ResponseWriter, r *http.Request) { return } - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() // Query Inventory table for this VM before adding it h.Logger.Debug("Checking inventory table for VM record") diff --git a/server/handler/vmModifyEvent.go b/server/handler/vmModifyEvent.go index d1594ea..6b7226a 100644 --- a/server/handler/vmModifyEvent.go +++ b/server/handler/vmModifyEvent.go @@ -41,7 +41,8 @@ func (h *Handler) VmModifyEvent(w http.ResponseWriter, r *http.Request) { var unixTimestamp int64 re := regexp.MustCompile(`/([^/]+)/[^/]+\.vmdk$`) - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() reqBody, err := io.ReadAll(r.Body) if err != nil { @@ -189,7 +190,7 @@ func (h *Handler) VmModifyEvent(w http.ResponseWriter, r *http.Request) { // If we found a disk change belonging to this VM then recalculate the disk size if diskChangeFound { params.UpdateType = "diskchange" - diskSize := h.calculateNewDiskSize(event) + diskSize := h.calculateNewDiskSize(ctx, event) params.NewProvisionedDisk = sql.NullFloat64{Float64: diskSize, Valid: diskSize > 0} } } @@ -333,7 +334,7 @@ func (h *Handler) processConfigChanges(configChanges string) []map[string]string return result } -func (h *Handler) calculateNewDiskSize(event models.CloudEventReceived) float64 { +func (h *Handler) calculateNewDiskSize(ctx context.Context, event models.CloudEventReceived) float64 { var diskSize float64 var totalDiskBytes int64 h.Logger.Debug("connecting to vcenter") @@ -368,7 +369,9 @@ func (h *Handler) calculateNewDiskSize(event models.CloudEventReceived) float64 } } - _ = vc.Logout(context.Background()) + logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + _ = vc.Logout(logoutCtx) h.Logger.Debug("Calculated new disk size", "value", diskSize) diff --git a/server/handler/vmMoveEvent.go b/server/handler/vmMoveEvent.go index ce40b21..976e0f0 100644 --- a/server/handler/vmMoveEvent.go +++ b/server/handler/vmMoveEvent.go @@ -1,7 +1,6 @@ package handler import ( - "context" "database/sql" "encoding/json" "errors" @@ -34,7 +33,8 @@ func (h *Handler) VmMoveEvent(w http.ResponseWriter, r *http.Request) { params := queries.CreateUpdateParams{} var unixTimestamp int64 - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) + defer cancel() reqBody, err := io.ReadAll(r.Body) if err != nil { diff --git a/server/handler/vmUpdateDetails.go b/server/handler/vmUpdateDetails.go index 1f9da03..93fa89c 100644 --- a/server/handler/vmUpdateDetails.go +++ b/server/handler/vmUpdateDetails.go @@ -1,7 +1,6 @@ package handler import ( - "context" "database/sql" "net/http" "vctp/db/queries" @@ -28,7 +27,8 @@ func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) { var vmUuid string var dbUuid string - ctx := context.Background() + ctx, cancel := withRequestTimeout(r, longRunningRequestTimeout) + defer cancel() // reload settings in case vcenter list has changed h.Settings.ReadYMLSettings() @@ -101,7 +101,7 @@ func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) { } h.Logger.Debug("database params", "params", params) - err := h.Database.Queries().InventoryUpdate(context.Background(), params) + err := h.Database.Queries().InventoryUpdate(ctx, params) if err != nil { h.Logger.Error("Error received updating inventory for VM", "name", vmObj.Name, "error", err) diff --git a/server/models/models.go b/server/models/models.go index e28292f..d591865 100644 --- a/server/models/models.go +++ b/server/models/models.go @@ -31,9 +31,9 @@ type CloudEventReceived struct { } `json:"Datacenter"` Name string `json:"Name"` } `json:"Datacenter"` - Ds interface{} `json:"Ds"` - Dvs interface{} `json:"Dvs"` - FullFormattedMessage string `json:"FullFormattedMessage"` + Ds any `json:"Ds"` + Dvs any `json:"Dvs"` + FullFormattedMessage string `json:"FullFormattedMessage"` Host struct { Host struct { Type string `json:"Type"` @@ -42,7 +42,7 @@ type CloudEventReceived struct { Name string `json:"Name"` } `json:"Host"` Key int `json:"Key"` - Net interface{} `json:"Net"` + Net any `json:"Net"` NewParent *CloudEventResourcePool `json:"NewParent"` OldParent *CloudEventResourcePool `json:"OldParent"` SrcTemplate *CloudEventVm `json:"SrcTemplate"` @@ -158,7 +158,7 @@ type ConfigSpec struct { } `json:"StorageIOAllocation"` VDiskID any `json:"VDiskId"` VFlashCacheConfigInfo any `json:"VFlashCacheConfigInfo"` - } `json:"Device,omitempty"` + } `json:"Device"` FileOperation string `json:"FileOperation"` Operation string `json:"Operation"` Profile []struct {