cleanups and code fixes incl templ
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2026-03-20 13:21:15 +11:00
parent 4fbb2582e3
commit 9a561f3b07
24 changed files with 425 additions and 141 deletions

View File

@@ -34,7 +34,7 @@ steps:
path: /shared path: /shared
commands: commands:
- export PATH=/drone/src/pkg.tools:$PATH - export PATH=/drone/src/pkg.tools:$PATH
- go install github.com/a-h/templ/cmd/templ@v0.3.977 - go install github.com/a-h/templ/cmd/templ@v0.3.1001
- go install github.com/sqlc-dev/sqlc/cmd/sqlc@v1.29.0 - go install github.com/sqlc-dev/sqlc/cmd/sqlc@v1.29.0
- go install github.com/swaggo/swag/cmd/swag@v1.16.6 - go install github.com/swaggo/swag/cmd/swag@v1.16.6
# - go install github.com/goreleaser/nfpm/v2/cmd/nfpm@latest # - go install github.com/goreleaser/nfpm/v2/cmd/nfpm@latest

View File

@@ -1,16 +1,19 @@
## Build ## Build
FROM golang:1.26-alpine AS build FROM golang:1.26.0-alpine AS build
ARG VERSION='dev' ARG VERSION='dev'
ARG TAILWIND_VERSION='v3.4.17'
ARG TEMPL_VERSION='v0.3.1001'
ARG SQLC_VERSION='v1.29.0'
RUN apk update && apk add --no-cache curl RUN apk update && apk add --no-cache curl
RUN curl -sLO https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 \ RUN curl -fsSLo tailwindcss-linux-x64 https://github.com/tailwindlabs/tailwindcss/releases/download/${TAILWIND_VERSION}/tailwindcss-linux-x64 \
&& chmod +x tailwindcss-linux-x64 \ && chmod +x tailwindcss-linux-x64 \
&& mv tailwindcss-linux-x64 /usr/local/bin/tailwindcss && mv tailwindcss-linux-x64 /usr/local/bin/tailwindcss
RUN go install github.com/a-h/templ/cmd/templ@v0.2.663 \ RUN go install github.com/a-h/templ/cmd/templ@${TEMPL_VERSION} \
&& go install github.com/sqlc-dev/sqlc/cmd/sqlc@latest && go install github.com/sqlc-dev/sqlc/cmd/sqlc@${SQLC_VERSION}
WORKDIR /app WORKDIR /app

View File

@@ -106,6 +106,7 @@ func (s *Settings) ReadYMLSettings() error {
// Init new YAML decode // Init new YAML decode
d := yaml.NewDecoder(file) d := yaml.NewDecoder(file)
d.KnownFields(true)
// Start YAML decoding from file // Start YAML decoding from file
if err := d.Decode(&settings); err != nil { if err := d.Decode(&settings); err != nil {
@@ -149,9 +150,9 @@ func (s *Settings) WriteYMLSettings() error {
return fmt.Errorf("unable to encode settings file: %w", err) return fmt.Errorf("unable to encode settings file: %w", err)
} }
mode := os.FileMode(0o644) mode := os.FileMode(0o600)
if info, err := os.Stat(s.SettingsPath); err == nil { if info, err := os.Stat(s.SettingsPath); err == nil {
mode = info.Mode().Perm() mode = secureSettingsFileMode(info.Mode().Perm())
} }
dir := filepath.Dir(s.SettingsPath) dir := filepath.Dir(s.SettingsPath)
@@ -181,3 +182,10 @@ func (s *Settings) WriteYMLSettings() error {
return nil return nil
} }
func secureSettingsFileMode(mode os.FileMode) os.FileMode {
// Ensure owner read/write, strip world permissions and all execute bits.
secured := mode & 0o660
secured |= 0o600
return secured
}

View File

@@ -0,0 +1,55 @@
package settings
import (
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
)
func TestReadYMLSettingsRejectsUnknownField(t *testing.T) {
tmpDir := t.TempDir()
settingsPath := filepath.Join(tmpDir, "vctp.yml")
content := `settings:
log_level: "info"
unknown_field: true
`
if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil {
t.Fatalf("failed to write settings file: %v", err)
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
s := New(logger, settingsPath)
err := s.ReadYMLSettings()
if err == nil {
t.Fatal("expected unknown field decode error")
}
if !strings.Contains(strings.ToLower(err.Error()), "unknown_field") {
t.Fatalf("expected error to mention unknown field, got: %v", err)
}
}
func TestSecureSettingsFileMode(t *testing.T) {
cases := []struct {
name string
in os.FileMode
want os.FileMode
}{
{name: "already strict", in: 0o600, want: 0o600},
{name: "group read allowed", in: 0o640, want: 0o640},
{name: "too open world", in: 0o666, want: 0o660},
{name: "exec bits stripped", in: 0o755, want: 0o640},
{name: "no perms gets owner rw", in: 0o000, want: 0o600},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := secureSettingsFileMode(tc.in)
if got != tc.want {
t.Fatalf("unexpected mode conversion: in=%#o got=%#o want=%#o", tc.in, got, tc.want)
}
})
}
}

10
main.go
View File

@@ -285,7 +285,10 @@ func main() {
// One-shot mode: run a single inventory snapshot across all configured vCenters and exit. // One-shot mode: run a single inventory snapshot across all configured vCenters and exit.
if *runInventory { if *runInventory {
logger.Info("Running one-shot inventory snapshot across all vCenters") logger.Info("Running one-shot inventory snapshot across all vCenters")
ct.RunVcenterSnapshotHourly(ctx, logger, true) if err := ct.RunVcenterSnapshotHourly(ctx, logger, true); err != nil {
logger.Error("One-shot inventory snapshot failed", "error", err)
os.Exit(1)
}
logger.Info("One-shot inventory snapshot complete; exiting") logger.Info("One-shot inventory snapshot complete; exiting")
return return
} }
@@ -403,7 +406,10 @@ func main() {
) )
//logger.Debug("Server configured", "object", svr) //logger.Debug("Server configured", "object", svr)
svr.StartAndWait() if err := svr.StartAndWait(); err != nil {
logger.Error("server terminated with error", "error", err)
os.Exit(1)
}
os.Exit(0) os.Exit(0)
} }

View File

@@ -0,0 +1,146 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"vctp/server/models"
)
func TestMutatingHandlersRejectWrongMethod(t *testing.T) {
h := &Handler{Logger: newTestLogger()}
tests := []struct {
name string
path string
call func(*Handler, *httptest.ResponseRecorder, *http.Request)
}{
{
name: "snapshot force hourly",
path: "/api/snapshots/hourly/force",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.SnapshotForceHourly(rr, req)
},
},
{
name: "snapshot aggregate",
path: "/api/snapshots/aggregate",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.SnapshotAggregateForce(rr, req)
},
},
{
name: "snapshot migrate",
path: "/api/snapshots/migrate",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.SnapshotMigrate(rr, req)
},
},
{
name: "snapshot regenerate hourly",
path: "/api/snapshots/regenerate-hourly-reports",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.SnapshotRegenerateHourlyReports(rr, req)
},
},
{
name: "vm create event",
path: "/api/event/vm/create",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.VmCreateEvent(rr, req)
},
},
{
name: "vm modify event",
path: "/api/event/vm/modify",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.VmModifyEvent(rr, req)
},
},
{
name: "vm move event",
path: "/api/event/vm/move",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.VmMoveEvent(rr, req)
},
},
{
name: "vm delete event",
path: "/api/event/vm/delete",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.VmDeleteEvent(rr, req)
},
},
{
name: "vm import",
path: "/api/import/vm",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.VmImport(rr, req)
},
},
{
name: "vm update details",
path: "/api/inventory/vm/update",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.VmUpdateDetails(rr, req)
},
},
{
name: "vm cleanup",
path: "/api/inventory/vm/delete",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.VmCleanup(rr, req)
},
},
{
name: "vcenter cleanup",
path: "/api/cleanup/vcenter",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.VcCleanup(rr, req)
},
},
{
name: "update cleanup",
path: "/api/cleanup/updates",
call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) {
h.UpdateCleanup(rr, req)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.path, strings.NewReader("{}"))
rr := httptest.NewRecorder()
tc.call(h, rr, req)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
if !strings.Contains(rr.Body.String(), "method not allowed") {
t.Fatalf("expected method not allowed response, got: %s", rr.Body.String())
}
})
}
}
func TestVcenterLoginFailuresAreHandled(t *testing.T) {
h := &Handler{Logger: newTestLogger()}
event := models.CloudEventReceived{}
event.CloudEvent.Source = "https://invalid.local/sdk"
disk := h.calculateNewDiskSize(context.Background(), event)
if disk != 0 {
t.Fatalf("expected disk size 0 on login failure, got %f", disk)
}
id, err := h.AddVmToInventory(event, context.Background(), 0)
if err == nil {
t.Fatal("expected error on login failure")
}
if id != 0 {
t.Fatalf("expected id 0 on login failure, got %d", id)
}
}

View File

@@ -2,6 +2,9 @@ package handler
import ( import (
"context" "context"
"encoding/json"
"errors"
"io"
"net/http" "net/http"
"time" "time"
) )
@@ -10,6 +13,7 @@ const (
defaultRequestTimeout = 2 * time.Minute defaultRequestTimeout = 2 * time.Minute
reportRequestTimeout = 10 * time.Minute reportRequestTimeout = 10 * time.Minute
longRunningRequestTimeout = 2 * time.Hour longRunningRequestTimeout = 2 * time.Hour
defaultJSONBodyLimitBytes = 1 << 20 // 1 MiB
) )
func withRequestTimeout(r *http.Request, timeout time.Duration) (context.Context, context.CancelFunc) { func withRequestTimeout(r *http.Request, timeout time.Duration) (context.Context, context.CancelFunc) {
@@ -22,3 +26,21 @@ func withRequestTimeout(r *http.Request, timeout time.Duration) (context.Context
} }
return context.WithTimeout(base, timeout) return context.WithTimeout(base, timeout)
} }
func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) error {
if r == nil || r.Body == nil {
return errors.New("request body is required")
}
decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, defaultJSONBodyLimitBytes))
if err := decoder.Decode(dst); err != nil {
return err
}
var trailing any
if err := decoder.Decode(&trailing); err != io.EOF {
if err == nil {
return errors.New("request body must contain only one JSON object")
}
return err
}
return nil
}

View File

@@ -21,6 +21,11 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/snapshots/aggregate [post] // @Router /api/snapshots/aggregate [post]
func (h *Handler) SnapshotAggregateForce(w http.ResponseWriter, r *http.Request) { func (h *Handler) SnapshotAggregateForce(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
snapshotType := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("type"))) snapshotType := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("type")))
dateValue := strings.TrimSpace(r.URL.Query().Get("date")) dateValue := strings.TrimSpace(r.URL.Query().Get("date"))
granularity := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("granularity"))) granularity := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("granularity")))

View File

@@ -19,6 +19,11 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/snapshots/hourly/force [post] // @Router /api/snapshots/hourly/force [post]
func (h *Handler) SnapshotForceHourly(w http.ResponseWriter, r *http.Request) { func (h *Handler) SnapshotForceHourly(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
confirm := strings.TrimSpace(r.URL.Query().Get("confirm")) confirm := strings.TrimSpace(r.URL.Query().Get("confirm"))
if strings.ToUpper(confirm) != "FORCE" { if strings.ToUpper(confirm) != "FORCE" {
writeJSONError(w, http.StatusBadRequest, "confirm must be 'FORCE'") writeJSONError(w, http.StatusBadRequest, "confirm must be 'FORCE'")

View File

@@ -15,6 +15,11 @@ import (
// @Failure 500 {object} models.SnapshotMigrationResponse "Server error" // @Failure 500 {object} models.SnapshotMigrationResponse "Server error"
// @Router /api/snapshots/migrate [post] // @Router /api/snapshots/migrate [post]
func (h *Handler) SnapshotMigrate(w http.ResponseWriter, r *http.Request) { func (h *Handler) SnapshotMigrate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
ctx, cancel := withRequestTimeout(r, reportRequestTimeout) ctx, cancel := withRequestTimeout(r, reportRequestTimeout)
defer cancel() defer cancel()
stats, err := report.MigrateSnapshotRegistry(ctx, h.Database) stats, err := report.MigrateSnapshotRegistry(ctx, h.Database)

View File

@@ -19,6 +19,11 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/snapshots/regenerate-hourly-reports [post] // @Router /api/snapshots/regenerate-hourly-reports [post]
func (h *Handler) SnapshotRegenerateHourlyReports(w http.ResponseWriter, r *http.Request) { func (h *Handler) SnapshotRegenerateHourlyReports(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
ctx := r.Context() ctx := r.Context()
reportsDir := strings.TrimSpace(h.Settings.Values.Settings.ReportsDir) reportsDir := strings.TrimSpace(h.Settings.Values.Settings.ReportsDir)
if reportsDir == "" { if reportsDir == "" {

View File

@@ -15,6 +15,11 @@ import (
// @Failure 500 {string} string "Server error" // @Failure 500 {string} string "Server error"
// @Router /api/cleanup/updates [delete] // @Router /api/cleanup/updates [delete]
func (h *Handler) UpdateCleanup(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateCleanup(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/cleanup/updates") { if h.denyLegacyAPI(w, "/api/cleanup/updates") {
return return
} }

View File

@@ -18,6 +18,11 @@ import (
// @Failure 400 {object} models.ErrorResponse "Invalid request" // @Failure 400 {object} models.ErrorResponse "Invalid request"
// @Router /api/cleanup/vcenter [delete] // @Router /api/cleanup/vcenter [delete]
func (h *Handler) VcCleanup(w http.ResponseWriter, r *http.Request) { func (h *Handler) VcCleanup(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/cleanup/vcenter") { if h.denyLegacyAPI(w, "/api/cleanup/vcenter") {
return return
} }

View File

@@ -20,6 +20,11 @@ import (
// @Failure 400 {object} models.ErrorResponse "Invalid request" // @Failure 400 {object} models.ErrorResponse "Invalid request"
// @Router /api/inventory/vm/delete [delete] // @Router /api/inventory/vm/delete [delete]
func (h *Handler) VmCleanup(w http.ResponseWriter, r *http.Request) { func (h *Handler) VmCleanup(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/inventory/vm/delete") { if h.denyLegacyAPI(w, "/api/inventory/vm/delete") {
return return
} }

View File

@@ -4,7 +4,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"runtime" "runtime"
"strconv" "strconv"
@@ -26,6 +25,11 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/event/vm/create [post] // @Router /api/event/vm/create [post]
func (h *Handler) VmCreateEvent(w http.ResponseWriter, r *http.Request) { func (h *Handler) VmCreateEvent(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/event/vm/create") { if h.denyLegacyAPI(w, "/api/event/vm/create") {
return return
} }
@@ -39,18 +43,8 @@ func (h *Handler) VmCreateEvent(w http.ResponseWriter, r *http.Request) {
//datacenter string //datacenter string
) )
reqBody, err := io.ReadAll(r.Body)
if err != nil {
h.Logger.Error("Invalid data received", "error", err)
writeJSONError(w, http.StatusInternalServerError, "Invalid data received")
return
} else {
h.Logger.Debug("received input data", "length", len(reqBody))
}
// Decode the JSON body into CloudEventReceived struct
var event models.CloudEventReceived var event models.CloudEventReceived
if err := json.Unmarshal(reqBody, &event); err != nil { if err := decodeJSONBody(w, r, &event); err != nil {
h.Logger.Error("unable to decode json", "error", err) h.Logger.Error("unable to decode json", "error", err)
writeJSONError(w, http.StatusBadRequest, "Invalid JSON body") writeJSONError(w, http.StatusBadRequest, "Invalid JSON body")
return return

View File

@@ -2,9 +2,7 @@ package handler
import ( import (
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"time" "time"
"vctp/db/queries" "vctp/db/queries"
@@ -24,6 +22,11 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/event/vm/delete [post] // @Router /api/event/vm/delete [post]
func (h *Handler) VmDeleteEvent(w http.ResponseWriter, r *http.Request) { func (h *Handler) VmDeleteEvent(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/event/vm/delete") { if h.denyLegacyAPI(w, "/api/event/vm/delete") {
return return
} }
@@ -34,18 +37,8 @@ func (h *Handler) VmDeleteEvent(w http.ResponseWriter, r *http.Request) {
deletedTimestamp int64 deletedTimestamp int64
) )
reqBody, err := io.ReadAll(r.Body)
if err != nil {
h.Logger.Error("Invalid data received", "error", err)
writeJSONError(w, http.StatusInternalServerError, "Invalid data received")
return
} else {
//h.Logger.Debug("received input data", "length", len(reqBody))
}
// Decode the JSON body into CloudEventReceived struct
var event models.CloudEventReceived var event models.CloudEventReceived
if err := json.Unmarshal(reqBody, &event); err != nil { if err := decodeJSONBody(w, r, &event); err != nil {
h.Logger.Error("unable to decode json", "error", err) h.Logger.Error("unable to decode json", "error", err)
prettyPrint(event) prettyPrint(event)
writeJSONError(w, http.StatusBadRequest, "Invalid JSON body") writeJSONError(w, http.StatusBadRequest, "Invalid JSON body")

View File

@@ -2,10 +2,8 @@ package handler
import ( import (
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strings" "strings"
"vctp/db" "vctp/db"
@@ -25,26 +23,19 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/import/vm [post] // @Router /api/import/vm [post]
func (h *Handler) VmImport(w http.ResponseWriter, r *http.Request) { func (h *Handler) VmImport(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/import/vm") { if h.denyLegacyAPI(w, "/api/import/vm") {
return return
} }
// Read request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
h.Logger.Error("Invalid data received", "length", len(reqBody), "error", err)
writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Invalid data received: '%s'", err))
return
} else {
h.Logger.Debug("received input data", "length", len(reqBody))
}
// Decode the JSON body into CloudEventReceived struct
var inData models.ImportReceived var inData models.ImportReceived
if err := json.Unmarshal(reqBody, &inData); err != nil { if err := decodeJSONBody(w, r, &inData); err != nil {
h.Logger.Error("Unable to decode json request body", "length", len(reqBody), "error", err) h.Logger.Error("Unable to decode json request body", "error", err)
writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Unable to decode json request body: '%s'", err)) writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("Unable to decode json request body: '%s'", err))
return return
} else { } else {
//h.Logger.Debug("successfully decoded JSON") //h.Logger.Debug("successfully decoded JSON")
@@ -66,7 +57,7 @@ func (h *Handler) VmImport(w http.ResponseWriter, r *http.Request) {
VmId: sql.NullString{String: inData.VmId, Valid: inData.VmId != ""}, VmId: sql.NullString{String: inData.VmId, Valid: inData.VmId != ""},
DatacenterName: sql.NullString{String: inData.Datacenter, Valid: inData.Datacenter != ""}, DatacenterName: sql.NullString{String: inData.Datacenter, Valid: inData.Datacenter != ""},
} }
_, err = h.Database.Queries().GetInventoryVmId(ctx, invParams) _, err := h.Database.Queries().GetInventoryVmId(ctx, invParams)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {

View File

@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
@@ -32,6 +31,11 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/event/vm/modify [post] // @Router /api/event/vm/modify [post]
func (h *Handler) VmModifyEvent(w http.ResponseWriter, r *http.Request) { func (h *Handler) VmModifyEvent(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/event/vm/modify") { if h.denyLegacyAPI(w, "/api/event/vm/modify") {
return return
} }
@@ -44,18 +48,10 @@ func (h *Handler) VmModifyEvent(w http.ResponseWriter, r *http.Request) {
ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) ctx, cancel := withRequestTimeout(r, defaultRequestTimeout)
defer cancel() defer cancel()
reqBody, err := io.ReadAll(r.Body)
if err != nil {
h.Logger.Error("Invalid data received", "length", len(reqBody), "error", err)
writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Invalid data received: '%s'", err))
return
}
// Decode the JSON body into CloudEventReceived struct
var event models.CloudEventReceived var event models.CloudEventReceived
if err := json.Unmarshal(reqBody, &event); err != nil { if err := decodeJSONBody(w, r, &event); err != nil {
h.Logger.Error("Unable to decode json request body", "length", len(reqBody), "error", err) h.Logger.Error("Unable to decode json request body", "error", err)
writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Unable to decode json request body: '%s'", err)) writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("Unable to decode json request body: '%s'", err))
return return
} else { } else {
//h.Logger.Debug("successfully decoded JSON") //h.Logger.Debug("successfully decoded JSON")
@@ -339,7 +335,15 @@ func (h *Handler) calculateNewDiskSize(ctx context.Context, event models.CloudEv
var totalDiskBytes int64 var totalDiskBytes int64
h.Logger.Debug("connecting to vcenter") h.Logger.Debug("connecting to vcenter")
vc := vcenter.New(h.Logger, h.VcCreds) vc := vcenter.New(h.Logger, h.VcCreds)
vc.Login(event.CloudEvent.Source) if err := vc.Login(event.CloudEvent.Source); err != nil {
h.Logger.Error("unable to connect to vcenter while calculating disk size", "source", event.CloudEvent.Source, "error", err)
return 0
}
defer func() {
logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cancel()
_ = vc.Logout(logoutCtx)
}()
vmObject, err := vc.FindVMByIDWithDatacenter(event.CloudEvent.Data.VM.VM.Value, event.CloudEvent.Data.Datacenter.Datacenter.Value) vmObject, err := vc.FindVMByIDWithDatacenter(event.CloudEvent.Data.VM.VM.Value, event.CloudEvent.Data.Datacenter.Datacenter.Value)
@@ -369,10 +373,6 @@ func (h *Handler) calculateNewDiskSize(ctx context.Context, event models.CloudEv
} }
} }
logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cancel()
_ = vc.Logout(logoutCtx)
h.Logger.Debug("Calculated new disk size", "value", diskSize) h.Logger.Debug("Calculated new disk size", "value", diskSize)
return diskSize return diskSize
@@ -394,7 +394,15 @@ func (h *Handler) AddVmToInventory(evt models.CloudEventReceived, ctx context.Co
) )
//c.Logger.Debug("connecting to vcenter") //c.Logger.Debug("connecting to vcenter")
vc := vcenter.New(h.Logger, h.VcCreds) vc := vcenter.New(h.Logger, h.VcCreds)
vc.Login(evt.CloudEvent.Source) if err := vc.Login(evt.CloudEvent.Source); err != nil {
h.Logger.Error("unable to connect to vcenter", "source", evt.CloudEvent.Source, "error", err)
return 0, err
}
defer func() {
logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cancel()
_ = vc.Logout(logoutCtx)
}()
//datacenter = evt.DatacenterName.String //datacenter = evt.DatacenterName.String
vmObject, err := vc.FindVMByIDWithDatacenter(evt.CloudEvent.Data.VM.VM.Value, evt.CloudEvent.Data.Datacenter.Datacenter.Value) vmObject, err := vc.FindVMByIDWithDatacenter(evt.CloudEvent.Data.VM.VM.Value, evt.CloudEvent.Data.Datacenter.Datacenter.Value)
@@ -410,7 +418,6 @@ func (h *Handler) AddVmToInventory(evt models.CloudEventReceived, ctx context.Co
if strings.HasPrefix(vmObject.Name, "vCLS-") { if strings.HasPrefix(vmObject.Name, "vCLS-") {
h.Logger.Info("Skipping internal vCLS VM", "vm_name", vmObject.Name) h.Logger.Info("Skipping internal vCLS VM", "vm_name", vmObject.Name)
_ = vc.Logout(ctx)
return 0, nil return 0, nil
} }
@@ -484,8 +491,6 @@ func (h *Handler) AddVmToInventory(evt models.CloudEventReceived, ctx context.Co
poweredOn = "TRUE" poweredOn = "TRUE"
} }
_ = vc.Logout(ctx)
if foundVm { if foundVm {
e := evt.CloudEvent e := evt.CloudEvent
h.Logger.Debug("Adding to Inventory table", "vm_name", e.Data.VM.Name, "vcpus", numVcpus, "ram", numRam, "dc", e.Data.Datacenter.Datacenter.Value) h.Logger.Debug("Adding to Inventory table", "vm_name", e.Data.VM.Name, "vcpus", numVcpus, "ram", numRam, "dc", e.Data.Datacenter.Datacenter.Value)

View File

@@ -2,10 +2,8 @@ package handler
import ( import (
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@@ -26,6 +24,11 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/event/vm/move [post] // @Router /api/event/vm/move [post]
func (h *Handler) VmMoveEvent(w http.ResponseWriter, r *http.Request) { func (h *Handler) VmMoveEvent(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/event/vm/move") { if h.denyLegacyAPI(w, "/api/event/vm/move") {
return return
} }
@@ -36,21 +39,10 @@ func (h *Handler) VmMoveEvent(w http.ResponseWriter, r *http.Request) {
ctx, cancel := withRequestTimeout(r, defaultRequestTimeout) ctx, cancel := withRequestTimeout(r, defaultRequestTimeout)
defer cancel() defer cancel()
reqBody, err := io.ReadAll(r.Body)
if err != nil {
h.Logger.Error("Invalid data received", "error", err)
writeJSONError(w, http.StatusInternalServerError, "Invalid data received")
return
} else {
//h.Logger.Debug("received input data", "length", len(reqBody))
}
// Decode the JSON body into CloudEventReceived struct
var event models.CloudEventReceived var event models.CloudEventReceived
if err := json.Unmarshal(reqBody, &event); err != nil { if err := decodeJSONBody(w, r, &event); err != nil {
h.Logger.Error("unable to unmarshal json", "error", err) h.Logger.Error("unable to unmarshal json", "error", err)
prettyPrint(reqBody) writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("Unable to unmarshal JSON in request body: '%s'", err))
writeJSONError(w, http.StatusInternalServerError, fmt.Sprintf("Unable to unmarshal JSON in request body: '%s'", err))
return return
} else { } else {
h.Logger.Debug("successfully decoded JSON") h.Logger.Debug("successfully decoded JSON")

View File

@@ -1,8 +1,10 @@
package handler package handler
import ( import (
"context"
"database/sql" "database/sql"
"net/http" "net/http"
"time"
"vctp/db/queries" "vctp/db/queries"
"vctp/internal/vcenter" "vctp/internal/vcenter"
) )
@@ -17,6 +19,11 @@ import (
// @Failure 500 {object} models.ErrorResponse "Server error" // @Failure 500 {object} models.ErrorResponse "Server error"
// @Router /api/inventory/vm/update [post] // @Router /api/inventory/vm/update [post]
func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) { func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if h.denyLegacyAPI(w, "/api/inventory/vm/update") { if h.denyLegacyAPI(w, "/api/inventory/vm/update") {
return return
} }
@@ -31,20 +38,42 @@ func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) {
defer cancel() defer cancel()
// reload settings in case vcenter list has changed // reload settings in case vcenter list has changed
h.Settings.ReadYMLSettings() if err := h.Settings.ReadYMLSettings(); err != nil {
h.Logger.Error("unable to reload settings", "error", err)
writeJSONError(w, http.StatusInternalServerError, "Unable to reload settings")
return
}
for _, url := range h.Settings.Values.Settings.VcenterAddresses { for _, url := range h.Settings.Values.Settings.VcenterAddresses {
h.Logger.Debug("connecting to vcenter", "url", url) h.Logger.Debug("connecting to vcenter", "url", url)
vc := vcenter.New(h.Logger, h.VcCreds) vc := vcenter.New(h.Logger, h.VcCreds)
vc.Login(url) if err := vc.Login(url); err != nil {
h.Logger.Error("unable to connect to vcenter", "url", url, "error", err)
writeJSONError(w, http.StatusInternalServerError, "Unable to connect to vcenter")
return
}
logout := func() {
logoutCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cancel()
if err := vc.Logout(logoutCtx); err != nil {
h.Logger.Warn("vcenter logout failed", "url", url, "error", err)
}
}
// Get list of VMs from vcenter // Get list of VMs from vcenter
vms, err := vc.GetAllVmReferences() vms, err := vc.GetAllVmReferences()
if err != nil {
logout()
h.Logger.Error("Unable to query vcenter VM references", "url", url, "error", err)
writeJSONError(w, http.StatusInternalServerError, "Unable to query vcenter VM references")
return
}
// Get list of VMs from inventory table // Get list of VMs from inventory table
h.Logger.Debug("Querying inventory table") h.Logger.Debug("Querying inventory table")
results, err := h.Database.Queries().GetInventoryByVcenter(ctx, url) results, err := h.Database.Queries().GetInventoryByVcenter(ctx, url)
if err != nil { if err != nil {
logout()
h.Logger.Error("Unable to query inventory table", "error", err) h.Logger.Error("Unable to query inventory table", "error", err)
writeJSONError(w, http.StatusInternalServerError, "Unable to query inventory table") writeJSONError(w, http.StatusInternalServerError, "Unable to query inventory table")
return return
@@ -116,6 +145,7 @@ func (h *Handler) VmUpdateDetails(w http.ResponseWriter, r *http.Request) {
} }
} }
logout()
} }
h.Logger.Debug("Processed vm update successfully") h.Logger.Debug("Processed vm update successfully")

View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"syscall"
"time" "time"
"github.com/go-co-op/gocron/v2" "github.com/go-co-op/gocron/v2"
@@ -18,6 +19,7 @@ type Server struct {
logger *slog.Logger logger *slog.Logger
cron gocron.Scheduler cron gocron.Scheduler
cancel context.CancelFunc cancel context.CancelFunc
startErr chan error
disableTls bool disableTls bool
tlsCertFilename string tlsCertFilename string
tlsKeyFilename string tlsKeyFilename string
@@ -30,22 +32,14 @@ func New(logger *slog.Logger, cron gocron.Scheduler, cancel context.CancelFunc,
// Set some options for TLS // Set some options for TLS
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256},
PreferServerCipherSuites: true,
InsecureSkipVerify: true,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
},
} }
srv := &http.Server{ srv := &http.Server{
Addr: addr, Addr: addr,
//WriteTimeout: 120 * time.Second, WriteTimeout: 2 * time.Minute,
WriteTimeout: 0,
ReadTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second,
ReadHeaderTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
} }
@@ -56,6 +50,7 @@ func New(logger *slog.Logger, cron gocron.Scheduler, cancel context.CancelFunc,
logger: logger, logger: logger,
cron: cron, cron: cron,
cancel: cancel, cancel: cancel,
startErr: make(chan error, 1),
} }
// Apply any options // Apply any options
@@ -120,60 +115,60 @@ func SetPrivateKey(tlsKeyFilename string) Option {
} }
// StartAndWait starts the server and waits for a signal to shut down. // StartAndWait starts the server and waits for a signal to shut down.
func (s *Server) StartAndWait() { func (s *Server) StartAndWait() error {
s.Start() s.Start()
s.GracefulShutdown() return s.GracefulShutdown()
} }
// Start starts the server. // Start starts the server.
func (s *Server) Start() { func (s *Server) Start() {
go func() { go func() {
if s.disableTls { if s.disableTls {
s.logger.Info("starting server", "port", s.srv.Addr) s.logger.Info("starting server", "port", s.srv.Addr)
if err := s.srv.ListenAndServe(); err != nil { if err := s.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
s.logger.Error("failed to start server", "error", err) s.startErr <- err
os.Exit(1)
} }
} else { } else {
s.logger.Info("starting TLS server", "port", s.srv.Addr, "cert", s.tlsCertFilename, "key", s.tlsKeyFilename) s.logger.Info("starting TLS server", "port", s.srv.Addr, "cert", s.tlsCertFilename, "key", s.tlsKeyFilename)
if err := s.srv.ListenAndServeTLS(s.tlsCertFilename, s.tlsKeyFilename); err != nil && err != http.ErrServerClosed { if err := s.srv.ListenAndServeTLS(s.tlsCertFilename, s.tlsKeyFilename); err != nil && err != http.ErrServerClosed {
s.logger.Error("failed to start server", "error", err) s.startErr <- err
os.Exit(1)
} }
} }
}() }()
} }
// GracefulShutdown shuts down the server gracefully. // GracefulShutdown shuts down the server gracefully.
func (s *Server) GracefulShutdown() { func (s *Server) GracefulShutdown() error {
c := make(chan os.Signal, 1) signals := make(chan os.Signal, 1)
// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) signal.Notify(signals, os.Interrupt, syscall.SIGTERM)
// SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught. defer signal.Stop(signals)
signal.Notify(c, os.Interrupt)
// Block until we receive our signal. var startErr error
<-c select {
case sig := <-signals:
s.logger.Info("shutdown signal received", "signal", sig.String())
// Create a deadline to wait for. // Create a deadline to wait for.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
// Doesn't block if no connections, but will otherwise wait // Doesn't block if no connections, but will otherwise wait
// until the timeout deadline. // until the timeout deadline.
_ = s.srv.Shutdown(ctx) if err := s.srv.Shutdown(ctx); err != nil {
s.logger.Warn("http server shutdown returned error", "error", err)
}
case err := <-s.startErr:
s.logger.Error("http server exited with startup/runtime error", "error", err)
startErr = err
}
s.logger.Info("runing cron shutdown") s.logger.Info("running cron shutdown")
err := s.cron.Shutdown() if err := s.cron.Shutdown(); err != nil {
if err != nil {
s.logger.Error("error shutting cron", "error", err) s.logger.Error("error shutting cron", "error", err)
} }
s.logger.Info("runing cancel") s.logger.Info("running cancel")
s.cancel() s.cancel()
// Optionally, you could run srv.Shutdown in a goroutine and block on
// <-ctx.Done() if your application should wait for other services
// to finalize based on context cancellation.
s.logger.Info("shutting down") s.logger.Info("shutting down")
//os.Exit(0) return startErr
} }

View File

@@ -100,6 +100,15 @@ merge_missing_settings_from_rpmnew() {
merge_missing_settings_from_rpmnew "$TARGET_CFG" "$SOURCE_CFG" || : merge_missing_settings_from_rpmnew "$TARGET_CFG" "$SOURCE_CFG" || :
if [ -f "$TARGET_CFG" ]; then
chown root:dtms "$TARGET_CFG" || :
chmod 640 "$TARGET_CFG" || :
fi
if [ -f "$SOURCE_CFG" ]; then
chown root:dtms "$SOURCE_CFG" || :
chmod 640 "$SOURCE_CFG" || :
fi
if command -v systemctl >/dev/null 2>&1; then if command -v systemctl >/dev/null 2>&1; then
systemctl daemon-reload || : systemctl daemon-reload || :
if [ "${1:-0}" -eq 1 ]; then if [ "${1:-0}" -eq 1 ]; then

View File

@@ -15,10 +15,10 @@ getent passwd "$USER" >/dev/null || useradd -r -g "$GROUP" -m -s /bin/bash -c "v
[ -d /etc/dtms ] || mkdir -p /etc/dtms [ -d /etc/dtms ] || mkdir -p /etc/dtms
# set group ownership on vctp config directory if not already done # set group ownership on vctp config directory if not already done
[ "$(stat -c "%G" /etc/dtms)" = "$GROUP" ] || chgrp -R "$GROUP" /etc/dtms [ "$(stat -c "%G" /etc/dtms)" = "$GROUP" ] || chgrp "$GROUP" /etc/dtms
# set permissions on vctp config directory if not already done # set permissions on vctp config directory if not already done
[ "$(stat -c "%a" /etc/dtms)" = "774" ] || chmod -R 774 /etc/dtms [ "$(stat -c "%a" /etc/dtms)" = "750" ] || chmod 750 /etc/dtms
# create vctp data directory if it doesn't exist # create vctp data directory if it doesn't exist
[ -d /var/lib/vctp ] || mkdir -p /var/lib/vctp [ -d /var/lib/vctp ] || mkdir -p /var/lib/vctp