From 9a561f3b0785a83f4aa4e81e7d199e3c38d42f71 Mon Sep 17 00:00:00 2001 From: Nathan Coad Date: Fri, 20 Mar 2026 13:21:15 +1100 Subject: [PATCH] cleanups and code fixes incl templ --- .drone.yml | 2 +- Dockerfile | 11 +- components/core/{header.tmpl => header.templ} | 0 internal/settings/settings.go | 12 +- internal/settings/settings_strict_test.go | 55 +++++++ main.go | 10 +- server/handler/method_guards_test.go | 146 ++++++++++++++++++ server/handler/request_context.go | 22 +++ server/handler/snapshotAggregate.go | 5 + server/handler/snapshotForceHourly.go | 5 + server/handler/snapshotMigrate.go | 5 + server/handler/snapshotRegenerateHourly.go | 5 + server/handler/updateCleanup.go | 5 + server/handler/vcCleanup.go | 5 + server/handler/vmCleanup.go | 5 + server/handler/vmCreateEvent.go | 18 +-- server/handler/vmDeleteEvent.go | 19 +-- server/handler/vmImport.go | 27 ++-- server/handler/vmModifyEvent.go | 47 +++--- server/handler/vmMoveEvent.go | 22 +-- server/handler/vmUpdateDetails.go | 34 +++- server/server.go | 93 ++++++----- src/postinstall.sh | 9 ++ src/preinstall.sh | 4 +- 24 files changed, 425 insertions(+), 141 deletions(-) rename components/core/{header.tmpl => header.templ} (100%) create mode 100644 internal/settings/settings_strict_test.go create mode 100644 server/handler/method_guards_test.go diff --git a/.drone.yml b/.drone.yml index d0e97b8..fa6cdcb 100644 --- a/.drone.yml +++ b/.drone.yml @@ -34,7 +34,7 @@ steps: path: /shared commands: - export PATH=/drone/src/pkg.tools:$PATH - - go install github.com/a-h/templ/cmd/templ@v0.3.977 + - go install github.com/a-h/templ/cmd/templ@v0.3.1001 - go install github.com/sqlc-dev/sqlc/cmd/sqlc@v1.29.0 - go install github.com/swaggo/swag/cmd/swag@v1.16.6 # - go install github.com/goreleaser/nfpm/v2/cmd/nfpm@latest diff --git a/Dockerfile b/Dockerfile index fee72ea..01b6ef5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,16 +1,19 @@ ## Build -FROM golang:1.26-alpine AS build +FROM golang:1.26.0-alpine AS build ARG VERSION='dev' +ARG TAILWIND_VERSION='v3.4.17' +ARG TEMPL_VERSION='v0.3.1001' +ARG SQLC_VERSION='v1.29.0' RUN apk update && apk add --no-cache curl -RUN curl -sLO https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 \ +RUN curl -fsSLo tailwindcss-linux-x64 https://github.com/tailwindlabs/tailwindcss/releases/download/${TAILWIND_VERSION}/tailwindcss-linux-x64 \ && chmod +x tailwindcss-linux-x64 \ && mv tailwindcss-linux-x64 /usr/local/bin/tailwindcss -RUN go install github.com/a-h/templ/cmd/templ@v0.2.663 \ - && go install github.com/sqlc-dev/sqlc/cmd/sqlc@latest +RUN go install github.com/a-h/templ/cmd/templ@${TEMPL_VERSION} \ + && go install github.com/sqlc-dev/sqlc/cmd/sqlc@${SQLC_VERSION} WORKDIR /app diff --git a/components/core/header.tmpl b/components/core/header.templ similarity index 100% rename from components/core/header.tmpl rename to components/core/header.templ diff --git a/internal/settings/settings.go b/internal/settings/settings.go index 0439bd0..9e8dd1e 100644 --- a/internal/settings/settings.go +++ b/internal/settings/settings.go @@ -106,6 +106,7 @@ func (s *Settings) ReadYMLSettings() error { // Init new YAML decode d := yaml.NewDecoder(file) + d.KnownFields(true) // Start YAML decoding from file if err := d.Decode(&settings); err != nil { @@ -149,9 +150,9 @@ func (s *Settings) WriteYMLSettings() error { return fmt.Errorf("unable to encode settings file: %w", err) } - mode := os.FileMode(0o644) + mode := os.FileMode(0o600) if info, err := os.Stat(s.SettingsPath); err == nil { - mode = info.Mode().Perm() + mode = secureSettingsFileMode(info.Mode().Perm()) } dir := filepath.Dir(s.SettingsPath) @@ -181,3 +182,10 @@ func (s *Settings) WriteYMLSettings() error { return nil } + +func secureSettingsFileMode(mode os.FileMode) os.FileMode { + // Ensure owner read/write, strip world permissions and all execute bits. + secured := mode & 0o660 + secured |= 0o600 + return secured +} diff --git a/internal/settings/settings_strict_test.go b/internal/settings/settings_strict_test.go new file mode 100644 index 0000000..4ee54cb --- /dev/null +++ b/internal/settings/settings_strict_test.go @@ -0,0 +1,55 @@ +package settings + +import ( + "io" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestReadYMLSettingsRejectsUnknownField(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "vctp.yml") + content := `settings: + log_level: "info" + unknown_field: true +` + if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil { + t.Fatalf("failed to write settings file: %v", err) + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + s := New(logger, settingsPath) + err := s.ReadYMLSettings() + if err == nil { + t.Fatal("expected unknown field decode error") + } + if !strings.Contains(strings.ToLower(err.Error()), "unknown_field") { + t.Fatalf("expected error to mention unknown field, got: %v", err) + } +} + +func TestSecureSettingsFileMode(t *testing.T) { + cases := []struct { + name string + in os.FileMode + want os.FileMode + }{ + {name: "already strict", in: 0o600, want: 0o600}, + {name: "group read allowed", in: 0o640, want: 0o640}, + {name: "too open world", in: 0o666, want: 0o660}, + {name: "exec bits stripped", in: 0o755, want: 0o640}, + {name: "no perms gets owner rw", in: 0o000, want: 0o600}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := secureSettingsFileMode(tc.in) + if got != tc.want { + t.Fatalf("unexpected mode conversion: in=%#o got=%#o want=%#o", tc.in, got, tc.want) + } + }) + } +} diff --git a/main.go b/main.go index e013955..3a59068 100644 --- a/main.go +++ b/main.go @@ -285,7 +285,10 @@ func main() { // One-shot mode: run a single inventory snapshot across all configured vCenters and exit. if *runInventory { logger.Info("Running one-shot inventory snapshot across all vCenters") - ct.RunVcenterSnapshotHourly(ctx, logger, true) + if err := ct.RunVcenterSnapshotHourly(ctx, logger, true); err != nil { + logger.Error("One-shot inventory snapshot failed", "error", err) + os.Exit(1) + } logger.Info("One-shot inventory snapshot complete; exiting") return } @@ -403,7 +406,10 @@ func main() { ) //logger.Debug("Server configured", "object", svr) - svr.StartAndWait() + if err := svr.StartAndWait(); err != nil { + logger.Error("server terminated with error", "error", err) + os.Exit(1) + } os.Exit(0) } diff --git a/server/handler/method_guards_test.go b/server/handler/method_guards_test.go new file mode 100644 index 0000000..2e98309 --- /dev/null +++ b/server/handler/method_guards_test.go @@ -0,0 +1,146 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "vctp/server/models" +) + +func TestMutatingHandlersRejectWrongMethod(t *testing.T) { + h := &Handler{Logger: newTestLogger()} + + tests := []struct { + name string + path string + call func(*Handler, *httptest.ResponseRecorder, *http.Request) + }{ + { + name: "snapshot force hourly", + path: "/api/snapshots/hourly/force", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.SnapshotForceHourly(rr, req) + }, + }, + { + name: "snapshot aggregate", + path: "/api/snapshots/aggregate", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.SnapshotAggregateForce(rr, req) + }, + }, + { + name: "snapshot migrate", + path: "/api/snapshots/migrate", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.SnapshotMigrate(rr, req) + }, + }, + { + name: "snapshot regenerate hourly", + path: "/api/snapshots/regenerate-hourly-reports", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.SnapshotRegenerateHourlyReports(rr, req) + }, + }, + { + name: "vm create event", + path: "/api/event/vm/create", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.VmCreateEvent(rr, req) + }, + }, + { + name: "vm modify event", + path: "/api/event/vm/modify", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.VmModifyEvent(rr, req) + }, + }, + { + name: "vm move event", + path: "/api/event/vm/move", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.VmMoveEvent(rr, req) + }, + }, + { + name: "vm delete event", + path: "/api/event/vm/delete", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.VmDeleteEvent(rr, req) + }, + }, + { + name: "vm import", + path: "/api/import/vm", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.VmImport(rr, req) + }, + }, + { + name: "vm update details", + path: "/api/inventory/vm/update", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.VmUpdateDetails(rr, req) + }, + }, + { + name: "vm cleanup", + path: "/api/inventory/vm/delete", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.VmCleanup(rr, req) + }, + }, + { + name: "vcenter cleanup", + path: "/api/cleanup/vcenter", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.VcCleanup(rr, req) + }, + }, + { + name: "update cleanup", + path: "/api/cleanup/updates", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.UpdateCleanup(rr, req) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.path, strings.NewReader("{}")) + rr := httptest.NewRecorder() + tc.call(h, rr, req) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if !strings.Contains(rr.Body.String(), "method not allowed") { + t.Fatalf("expected method not allowed response, got: %s", rr.Body.String()) + } + }) + } +} + +func TestVcenterLoginFailuresAreHandled(t *testing.T) { + h := &Handler{Logger: newTestLogger()} + event := models.CloudEventReceived{} + event.CloudEvent.Source = "https://invalid.local/sdk" + + disk := h.calculateNewDiskSize(context.Background(), event) + if disk != 0 { + t.Fatalf("expected disk size 0 on login failure, got %f", disk) + } + + id, err := h.AddVmToInventory(event, context.Background(), 0) + if err == nil { + t.Fatal("expected error on login failure") + } + if id != 0 { + t.Fatalf("expected id 0 on login failure, got %d", id) + } +} diff --git a/server/handler/request_context.go b/server/handler/request_context.go index 818fb3e..aa61f29 100644 --- a/server/handler/request_context.go +++ b/server/handler/request_context.go @@ -2,6 +2,9 @@ package handler import ( "context" + "encoding/json" + "errors" + "io" "net/http" "time" ) @@ -10,6 +13,7 @@ const ( defaultRequestTimeout = 2 * time.Minute reportRequestTimeout = 10 * time.Minute longRunningRequestTimeout = 2 * time.Hour + defaultJSONBodyLimitBytes = 1 << 20 // 1 MiB ) func withRequestTimeout(r *http.Request, timeout time.Duration) (context.Context, context.CancelFunc) { @@ -22,3 +26,21 @@ func withRequestTimeout(r *http.Request, timeout time.Duration) (context.Context } return context.WithTimeout(base, timeout) } + +func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) error { + if r == nil || r.Body == nil { + return errors.New("request body is required") + } + decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, defaultJSONBodyLimitBytes)) + if err := decoder.Decode(dst); err != nil { + return err + } + var trailing any + if err := decoder.Decode(&trailing); err != io.EOF { + if err == nil { + return errors.New("request body must contain only one JSON object") + } + return err + } + return nil +} diff --git a/server/handler/snapshotAggregate.go b/server/handler/snapshotAggregate.go index 8fa31cc..ce19515 100644 --- a/server/handler/snapshotAggregate.go +++ b/server/handler/snapshotAggregate.go @@ -21,6 +21,11 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/snapshots/aggregate [post] func (h *Handler) SnapshotAggregateForce(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + snapshotType := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("type"))) dateValue := strings.TrimSpace(r.URL.Query().Get("date")) granularity := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("granularity"))) diff --git a/server/handler/snapshotForceHourly.go b/server/handler/snapshotForceHourly.go index fd4e5b9..64045a5 100644 --- a/server/handler/snapshotForceHourly.go +++ b/server/handler/snapshotForceHourly.go @@ -19,6 +19,11 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/snapshots/hourly/force [post] func (h *Handler) SnapshotForceHourly(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + confirm := strings.TrimSpace(r.URL.Query().Get("confirm")) if strings.ToUpper(confirm) != "FORCE" { writeJSONError(w, http.StatusBadRequest, "confirm must be 'FORCE'") diff --git a/server/handler/snapshotMigrate.go b/server/handler/snapshotMigrate.go index 016f2cf..df840f5 100644 --- a/server/handler/snapshotMigrate.go +++ b/server/handler/snapshotMigrate.go @@ -15,6 +15,11 @@ import ( // @Failure 500 {object} models.SnapshotMigrationResponse "Server error" // @Router /api/snapshots/migrate [post] func (h *Handler) SnapshotMigrate(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + ctx, cancel := withRequestTimeout(r, reportRequestTimeout) defer cancel() stats, err := report.MigrateSnapshotRegistry(ctx, h.Database) diff --git a/server/handler/snapshotRegenerateHourly.go b/server/handler/snapshotRegenerateHourly.go index 726624a..2b7886b 100644 --- a/server/handler/snapshotRegenerateHourly.go +++ b/server/handler/snapshotRegenerateHourly.go @@ -19,6 +19,11 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/snapshots/regenerate-hourly-reports [post] func (h *Handler) SnapshotRegenerateHourlyReports(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + ctx := r.Context() reportsDir := strings.TrimSpace(h.Settings.Values.Settings.ReportsDir) if reportsDir == "" { diff --git a/server/handler/updateCleanup.go b/server/handler/updateCleanup.go index d128d69..8a9bf2d 100644 --- a/server/handler/updateCleanup.go +++ b/server/handler/updateCleanup.go @@ -15,6 +15,11 @@ import ( // @Failure 500 {string} string "Server error" // @Router /api/cleanup/updates [delete] func (h *Handler) UpdateCleanup(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/cleanup/updates") { return } diff --git a/server/handler/vcCleanup.go b/server/handler/vcCleanup.go index 9153933..345130b 100644 --- a/server/handler/vcCleanup.go +++ b/server/handler/vcCleanup.go @@ -18,6 +18,11 @@ import ( // @Failure 400 {object} models.ErrorResponse "Invalid request" // @Router /api/cleanup/vcenter [delete] func (h *Handler) VcCleanup(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/cleanup/vcenter") { return } diff --git a/server/handler/vmCleanup.go b/server/handler/vmCleanup.go index 6ed878d..f1fe887 100644 --- a/server/handler/vmCleanup.go +++ b/server/handler/vmCleanup.go @@ -20,6 +20,11 @@ import ( // @Failure 400 {object} models.ErrorResponse "Invalid request" // @Router /api/inventory/vm/delete [delete] func (h *Handler) VmCleanup(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/inventory/vm/delete") { return } diff --git a/server/handler/vmCreateEvent.go b/server/handler/vmCreateEvent.go index f134c05..7652d8c 100644 --- a/server/handler/vmCreateEvent.go +++ b/server/handler/vmCreateEvent.go @@ -4,7 +4,6 @@ import ( "database/sql" "encoding/json" "fmt" - "io" "net/http" "runtime" "strconv" @@ -26,6 +25,11 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/event/vm/create [post] func (h *Handler) VmCreateEvent(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/event/vm/create") { return } @@ -39,18 +43,8 @@ func (h *Handler) VmCreateEvent(w http.ResponseWriter, r *http.Request) { //datacenter string ) - reqBody, err := io.ReadAll(r.Body) - if err != nil { - h.Logger.Error("Invalid data received", "error", err) - writeJSONError(w, http.StatusInternalServerError, "Invalid data received") - return - } else { - h.Logger.Debug("received input data", "length", len(reqBody)) - } - - // Decode the JSON body into CloudEventReceived struct var event models.CloudEventReceived - if err := json.Unmarshal(reqBody, &event); err != nil { + if err := decodeJSONBody(w, r, &event); err != nil { h.Logger.Error("unable to decode json", "error", err) writeJSONError(w, http.StatusBadRequest, "Invalid JSON body") return diff --git a/server/handler/vmDeleteEvent.go b/server/handler/vmDeleteEvent.go index cf5ede3..6e77cde 100644 --- a/server/handler/vmDeleteEvent.go +++ b/server/handler/vmDeleteEvent.go @@ -2,9 +2,7 @@ package handler import ( "database/sql" - "encoding/json" "fmt" - "io" "net/http" "time" "vctp/db/queries" @@ -24,6 +22,11 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/event/vm/delete [post] func (h *Handler) VmDeleteEvent(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/event/vm/delete") { return } @@ -34,18 +37,8 @@ func (h *Handler) VmDeleteEvent(w http.ResponseWriter, r *http.Request) { deletedTimestamp int64 ) - reqBody, err := io.ReadAll(r.Body) - if err != nil { - h.Logger.Error("Invalid data received", "error", err) - writeJSONError(w, http.StatusInternalServerError, "Invalid data received") - return - } else { - //h.Logger.Debug("received input data", "length", len(reqBody)) - } - - // Decode the JSON body into CloudEventReceived struct var event models.CloudEventReceived - if err := json.Unmarshal(reqBody, &event); err != nil { + if err := decodeJSONBody(w, r, &event); err != nil { h.Logger.Error("unable to decode json", "error", err) prettyPrint(event) writeJSONError(w, http.StatusBadRequest, "Invalid JSON body") diff --git a/server/handler/vmImport.go b/server/handler/vmImport.go index 8aed7c0..3e00d80 100644 --- a/server/handler/vmImport.go +++ b/server/handler/vmImport.go @@ -2,10 +2,8 @@ package handler import ( "database/sql" - "encoding/json" "errors" "fmt" - "io" "net/http" "strings" "vctp/db" @@ -25,26 +23,19 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/import/vm [post] func (h *Handler) VmImport(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/import/vm") { return } - // Read request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - h.Logger.Error("Invalid data received", "length", len(reqBody), "error", err) - writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Invalid data received: '%s'", err)) - return - - } else { - h.Logger.Debug("received input data", "length", len(reqBody)) - } - - // Decode the JSON body into CloudEventReceived struct var inData models.ImportReceived - if err := json.Unmarshal(reqBody, &inData); err != nil { - h.Logger.Error("Unable to decode json request body", "length", len(reqBody), "error", err) - writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Unable to decode json request body: '%s'", err)) + if err := decodeJSONBody(w, r, &inData); err != nil { + h.Logger.Error("Unable to decode json request body", "error", err) + writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("Unable to decode json request body: '%s'", err)) return } else { //h.Logger.Debug("successfully decoded JSON") @@ -66,7 +57,7 @@ func (h *Handler) VmImport(w http.ResponseWriter, r *http.Request) { VmId: sql.NullString{String: inData.VmId, Valid: inData.VmId != ""}, DatacenterName: sql.NullString{String: inData.Datacenter, Valid: inData.Datacenter != ""}, } - _, err = h.Database.Queries().GetInventoryVmId(ctx, invParams) + _, err := h.Database.Queries().GetInventoryVmId(ctx, invParams) if err != nil { if errors.Is(err, sql.ErrNoRows) { diff --git a/server/handler/vmModifyEvent.go b/server/handler/vmModifyEvent.go index 6b7226a..c1992c3 100644 --- a/server/handler/vmModifyEvent.go +++ b/server/handler/vmModifyEvent.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "regexp" "strconv" @@ -32,6 +31,11 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/event/vm/modify [post] func (h *Handler) VmModifyEvent(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/event/vm/modify") { return } @@ -44,18 +48,10 @@ func (h *Handler) VmModifyEvent(w http.ResponseWriter, r *http.Request) { ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) defer cancel() - reqBody, err := io.ReadAll(r.Body) - if err != nil { - h.Logger.Error("Invalid data received", "length", len(reqBody), "error", err) - writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Invalid data received: '%s'", err)) - return - } - - // Decode the JSON body into CloudEventReceived struct var event models.CloudEventReceived - if err := json.Unmarshal(reqBody, &event); err != nil { - h.Logger.Error("Unable to decode json request body", "length", len(reqBody), "error", err) - writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Unable to decode json request body: '%s'", err)) + if err := decodeJSONBody(w, r, &event); err != nil { + h.Logger.Error("Unable to decode json request body", "error", err) + writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("Unable to decode json request body: '%s'", err)) return } else { //h.Logger.Debug("successfully decoded JSON") @@ -339,7 +335,15 @@ func (h *Handler) calculateNewDiskSize(ctx context.Context, event models.CloudEv var totalDiskBytes int64 h.Logger.Debug("connecting to vcenter") vc := vcenter.New(h.Logger, h.VcCreds) - vc.Login(event.CloudEvent.Source) + if err := vc.Login(event.CloudEvent.Source); err != nil { + h.Logger.Error("unable to connect to vcenter while calculating disk size", "source", event.CloudEvent.Source, "error", err) + return 0 + } + defer func() { + logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + _ = vc.Logout(logoutCtx) + }() vmObject, err := vc.FindVMByIDWithDatacenter(event.CloudEvent.Data.VM.VM.Value, event.CloudEvent.Data.Datacenter.Datacenter.Value) @@ -369,10 +373,6 @@ func (h *Handler) calculateNewDiskSize(ctx context.Context, event models.CloudEv } } - logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) - defer cancel() - _ = vc.Logout(logoutCtx) - h.Logger.Debug("Calculated new disk size", "value", diskSize) return diskSize @@ -394,7 +394,15 @@ func (h *Handler) AddVmToInventory(evt models.CloudEventReceived, ctx context.Co ) //c.Logger.Debug("connecting to vcenter") vc := vcenter.New(h.Logger, h.VcCreds) - vc.Login(evt.CloudEvent.Source) + if err := vc.Login(evt.CloudEvent.Source); err != nil { + h.Logger.Error("unable to connect to vcenter", "source", evt.CloudEvent.Source, "error", err) + return 0, err + } + defer func() { + logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + _ = vc.Logout(logoutCtx) + }() //datacenter = evt.DatacenterName.String vmObject, err := vc.FindVMByIDWithDatacenter(evt.CloudEvent.Data.VM.VM.Value, evt.CloudEvent.Data.Datacenter.Datacenter.Value) @@ -410,7 +418,6 @@ func (h *Handler) AddVmToInventory(evt models.CloudEventReceived, ctx context.Co if strings.HasPrefix(vmObject.Name, "vCLS-") { h.Logger.Info("Skipping internal vCLS VM", "vm_name", vmObject.Name) - _ = vc.Logout(ctx) return 0, nil } @@ -484,8 +491,6 @@ func (h *Handler) AddVmToInventory(evt models.CloudEventReceived, ctx context.Co poweredOn = "TRUE" } - _ = vc.Logout(ctx) - if foundVm { e := evt.CloudEvent h.Logger.Debug("Adding to Inventory table", "vm_name", e.Data.VM.Name, "vcpus", numVcpus, "ram", numRam, "dc", e.Data.Datacenter.Datacenter.Value) diff --git a/server/handler/vmMoveEvent.go b/server/handler/vmMoveEvent.go index 976e0f0..c68c95f 100644 --- a/server/handler/vmMoveEvent.go +++ b/server/handler/vmMoveEvent.go @@ -2,10 +2,8 @@ package handler import ( "database/sql" - "encoding/json" "errors" "fmt" - "io" "net/http" "strconv" "time" @@ -26,6 +24,11 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/event/vm/move [post] func (h *Handler) VmMoveEvent(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/event/vm/move") { return } @@ -36,21 +39,10 @@ func (h *Handler) VmMoveEvent(w http.ResponseWriter, r *http.Request) { ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) defer cancel() - reqBody, err := io.ReadAll(r.Body) - if err != nil { - h.Logger.Error("Invalid data received", "error", err) - writeJSONError(w, http.StatusInternalServerError, "Invalid data received") - return - } else { - //h.Logger.Debug("received input data", "length", len(reqBody)) - } - - // Decode the JSON body into CloudEventReceived struct var event models.CloudEventReceived - if err := json.Unmarshal(reqBody, &event); err != nil { + if err := decodeJSONBody(w, r, &event); err != nil { h.Logger.Error("unable to unmarshal json", "error", err) - prettyPrint(reqBody) - writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Unable to unmarshal JSON in request body: '%s'", err)) + writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("Unable to unmarshal JSON in request body: '%s'", err)) return } else { h.Logger.Debug("successfully decoded JSON") diff --git a/server/handler/vmUpdateDetails.go b/server/handler/vmUpdateDetails.go index 93fa89c..852983d 100644 --- a/server/handler/vmUpdateDetails.go +++ b/server/handler/vmUpdateDetails.go @@ -1,8 +1,10 @@ package handler import ( + "context" "database/sql" "net/http" + "time" "vctp/db/queries" "vctp/internal/vcenter" ) @@ -17,6 +19,11 @@ import ( // @Failure 500 {object} models.ErrorResponse "Server error" // @Router /api/inventory/vm/update [post] func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h.denyLegacyAPI(w, "/api/inventory/vm/update") { return } @@ -31,20 +38,42 @@ func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) { defer cancel() // reload settings in case vcenter list has changed - h.Settings.ReadYMLSettings() + if err := h.Settings.ReadYMLSettings(); err != nil { + h.Logger.Error("unable to reload settings", "error", err) + writeJSONError(w, http.StatusInternalServerError, "Unable to reload settings") + return + } for _, url := range h.Settings.Values.Settings.VcenterAddresses { h.Logger.Debug("connecting to vcenter", "url", url) vc := vcenter.New(h.Logger, h.VcCreds) - vc.Login(url) + if err := vc.Login(url); err != nil { + h.Logger.Error("unable to connect to vcenter", "url", url, "error", err) + writeJSONError(w, http.StatusInternalServerError, "Unable to connect to vcenter") + return + } + logout := func() { + logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + if err := vc.Logout(logoutCtx); err != nil { + h.Logger.Warn("vcenter logout failed", "url", url, "error", err) + } + } // Get list of VMs from vcenter vms, err := vc.GetAllVmReferences() + if err != nil { + logout() + h.Logger.Error("Unable to query vcenter VM references", "url", url, "error", err) + writeJSONError(w, http.StatusInternalServerError, "Unable to query vcenter VM references") + return + } // Get list of VMs from inventory table h.Logger.Debug("Querying inventory table") results, err := h.Database.Queries().GetInventoryByVcenter(ctx, url) if err != nil { + logout() h.Logger.Error("Unable to query inventory table", "error", err) writeJSONError(w, http.StatusInternalServerError, "Unable to query inventory table") return @@ -116,6 +145,7 @@ func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) { } } + logout() } h.Logger.Debug("Processed vm update successfully") diff --git a/server/server.go b/server/server.go index c839ca9..2bdd18f 100644 --- a/server/server.go +++ b/server/server.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/signal" + "syscall" "time" "github.com/go-co-op/gocron/v2" @@ -18,6 +19,7 @@ type Server struct { logger *slog.Logger cron gocron.Scheduler cancel context.CancelFunc + startErr chan error disableTls bool tlsCertFilename string tlsKeyFilename string @@ -29,33 +31,26 @@ func New(logger *slog.Logger, cron gocron.Scheduler, cancel context.CancelFunc, // Set some options for TLS tlsConfig := &tls.Config{ - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, - PreferServerCipherSuites: true, - InsecureSkipVerify: true, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - }, + MinVersion: tls.VersionTLS12, } srv := &http.Server{ - Addr: addr, - //WriteTimeout: 120 * time.Second, - WriteTimeout: 0, - ReadTimeout: 30 * time.Second, - TLSConfig: tlsConfig, - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), + Addr: addr, + WriteTimeout: 2 * time.Minute, + ReadTimeout: 30 * time.Second, + ReadHeaderTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + TLSConfig: tlsConfig, + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), } // Set the initial server values server := &Server{ - srv: srv, - logger: logger, - cron: cron, - cancel: cancel, + srv: srv, + logger: logger, + cron: cron, + cancel: cancel, + startErr: make(chan error, 1), } // Apply any options @@ -120,60 +115,60 @@ func SetPrivateKey(tlsKeyFilename string) Option { } // StartAndWait starts the server and waits for a signal to shut down. -func (s *Server) StartAndWait() { +func (s *Server) StartAndWait() error { s.Start() - s.GracefulShutdown() + return s.GracefulShutdown() } // Start starts the server. func (s *Server) Start() { - go func() { if s.disableTls { s.logger.Info("starting server", "port", s.srv.Addr) - if err := s.srv.ListenAndServe(); err != nil { - s.logger.Error("failed to start server", "error", err) - os.Exit(1) + if err := s.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + s.startErr <- err } } else { s.logger.Info("starting TLS server", "port", s.srv.Addr, "cert", s.tlsCertFilename, "key", s.tlsKeyFilename) if err := s.srv.ListenAndServeTLS(s.tlsCertFilename, s.tlsKeyFilename); err != nil && err != http.ErrServerClosed { - s.logger.Error("failed to start server", "error", err) - os.Exit(1) + s.startErr <- err } } }() } // GracefulShutdown shuts down the server gracefully. -func (s *Server) GracefulShutdown() { - c := make(chan os.Signal, 1) - // We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) - // SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught. - signal.Notify(c, os.Interrupt) +func (s *Server) GracefulShutdown() error { + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(signals) - // Block until we receive our signal. - <-c + var startErr error + select { + case sig := <-signals: + s.logger.Info("shutdown signal received", "signal", sig.String()) - // Create a deadline to wait for. - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - // Doesn't block if no connections, but will otherwise wait - // until the timeout deadline. - _ = s.srv.Shutdown(ctx) + // Create a deadline to wait for. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Doesn't block if no connections, but will otherwise wait + // until the timeout deadline. + if err := s.srv.Shutdown(ctx); err != nil { + s.logger.Warn("http server shutdown returned error", "error", err) + } + case err := <-s.startErr: + s.logger.Error("http server exited with startup/runtime error", "error", err) + startErr = err + } - s.logger.Info("runing cron shutdown") - err := s.cron.Shutdown() - if err != nil { + s.logger.Info("running cron shutdown") + if err := s.cron.Shutdown(); err != nil { s.logger.Error("error shutting cron", "error", err) } - s.logger.Info("runing cancel") + s.logger.Info("running cancel") s.cancel() - // Optionally, you could run srv.Shutdown in a goroutine and block on - // <-ctx.Done() if your application should wait for other services - // to finalize based on context cancellation. s.logger.Info("shutting down") - //os.Exit(0) + return startErr } diff --git a/src/postinstall.sh b/src/postinstall.sh index 776a275..45f0fec 100644 --- a/src/postinstall.sh +++ b/src/postinstall.sh @@ -100,6 +100,15 @@ merge_missing_settings_from_rpmnew() { merge_missing_settings_from_rpmnew "$TARGET_CFG" "$SOURCE_CFG" || : +if [ -f "$TARGET_CFG" ]; then + chown root:dtms "$TARGET_CFG" || : + chmod 640 "$TARGET_CFG" || : +fi +if [ -f "$SOURCE_CFG" ]; then + chown root:dtms "$SOURCE_CFG" || : + chmod 640 "$SOURCE_CFG" || : +fi + if command -v systemctl >/dev/null 2>&1; then systemctl daemon-reload || : if [ "${1:-0}" -eq 1 ]; then diff --git a/src/preinstall.sh b/src/preinstall.sh index a87c395..f336775 100644 --- a/src/preinstall.sh +++ b/src/preinstall.sh @@ -15,10 +15,10 @@ getent passwd "$USER" >/dev/null || useradd -r -g "$GROUP" -m -s /bin/bash -c "v [ -d /etc/dtms ] || mkdir -p /etc/dtms # set group ownership on vctp config directory if not already done -[ "$(stat -c "%G" /etc/dtms)" = "$GROUP" ] || chgrp -R "$GROUP" /etc/dtms +[ "$(stat -c "%G" /etc/dtms)" = "$GROUP" ] || chgrp "$GROUP" /etc/dtms # set permissions on vctp config directory if not already done -[ "$(stat -c "%a" /etc/dtms)" = "774" ] || chmod -R 774 /etc/dtms +[ "$(stat -c "%a" /etc/dtms)" = "750" ] || chmod 750 /etc/dtms # create vctp data directory if it doesn't exist [ -d /var/lib/vctp ] || mkdir -p /var/lib/vctp