From ae3e2be89a514a064d1b1892afe2ac129114168c Mon Sep 17 00:00:00 2001 From: Nathan Coad Date: Fri, 17 Apr 2026 13:19:08 +1000 Subject: [PATCH] add auth support --- README.md | 51 +++ components/views/index.templ | 8 +- components/views/index_templ.go | 4 +- go.mod | 23 +- go.sum | 30 ++ internal/auth/jwt.go | 292 +++++++++++++++ internal/auth/jwt_test.go | 247 +++++++++++++ internal/auth/ldap.go | 354 +++++++++++++++++++ internal/auth/ldap_test.go | 39 ++ internal/settings/settings.go | 173 +++++++++ internal/settings/settings_redaction_test.go | 35 +- internal/settings/settings_strict_test.go | 128 +++++++ server/handler/auth.go | 146 ++++++++ server/handler/auth_test.go | 219 ++++++++++++ server/handler/method_guards_test.go | 7 + server/middleware/auth.go | 206 +++++++++++ server/middleware/auth_test.go | 201 +++++++++++ server/models/api_responses.go | 13 + server/router/router.go | 63 ++-- src/postinstall.sh | 109 ++++++ src/vctp.yml | 15 + todo.md | 156 ++++++++ 22 files changed, 2479 insertions(+), 40 deletions(-) create mode 100644 internal/auth/jwt.go create mode 100644 internal/auth/jwt_test.go create mode 100644 internal/auth/ldap.go create mode 100644 internal/auth/ldap_test.go create mode 100644 server/handler/auth.go create mode 100644 server/handler/auth_test.go create mode 100644 server/middleware/auth.go create mode 100644 server/middleware/auth_test.go create mode 100644 todo.md diff --git a/README.md b/README.md index ed7961b..098f0e8 100644 --- a/README.md +++ b/README.md @@ -208,6 +208,39 @@ These optional flags are read from the process environment (for example via `/et - `DAILY_AGG_GO`: set to `1` (default in `src/vctp.default`) to use the Go daily aggregation path. - `MONTHLY_AGG_GO`: set to `1` (default in `src/vctp.default`) to use the Go monthly aggregation path. +## Authentication and Authorization +Authentication uses LDAP bind + JWT bearer tokens. + +Login flow: +1. Call `POST /api/auth/login` with JSON body: +```json +{ "username": "your-user", "password": "your-password" } +``` +2. On success, use returned `access_token` as: +```http +Authorization: Bearer +``` + +Auth modes: +- `settings.auth_mode: disabled`: middleware bypassed. +- `settings.auth_mode: optional`: protected endpoints accept missing token, but validate any provided token. +- `settings.auth_mode: required`: protected endpoints require a valid bearer token. + +Role policy: +- `viewer`: read/report APIs (for example `/api/report/*`, `/api/diagnostics/daily-creation`). +- `admin`: mutating/admin APIs (for example `/api/snapshots/*` mutating endpoints, `/api/event/*`, `/api/import/vm`, `/api/encrypt`, `/api/vcenters/cache/rebuild`). +- `admin` implies `viewer` access. + +Public endpoints: +- UI pages (`/`, `/vcenters`, `/snapshots/*`, `/vm/trace`) +- Swagger UI/docs (`/swagger`, `/swagger/`, `/swagger.json`) +- Metrics (`/metrics`) +- Login (`/api/auth/login`) + +Debug endpoints: +- `/debug/pprof/*` handlers are only registered when `settings.enable_pprof: true`. +- When enabled, they require an authenticated `admin` token. + ## Credential Encryption Lifecycle At startup, vCTP resolves `settings.vcenter_password` using this order: @@ -256,6 +289,24 @@ HTTP/TLS: - `settings.tls_cert_filename`: PEM certificate path (TLS mode) - `settings.tls_key_filename`: PEM private key path (TLS mode) +Authentication: +- `settings.auth_enabled`: enables LDAP/JWT auth components. +- `settings.auth_mode`: `disabled`, `optional`, or `required`. +- `settings.auth_jwt_signing_key`: base64 signing key for JWTs. + - RPM postinstall auto-generates and writes this key to `/etc/dtms/vctp.yml` if it is missing/empty. +- `settings.auth_token_lifespan_minutes`: JWT access token lifetime. +- `settings.auth_jwt_issuer`: expected JWT issuer. +- `settings.auth_jwt_audience`: expected JWT audience. +- `settings.auth_clock_skew_seconds`: allowed clock skew for token validation. +- `settings.auth_group_role_mappings`: map of LDAP group DN -> role (`viewer` or `admin`). +- `settings.ldap_groups`: optional allowlist of LDAP group DNs required for login. +- `settings.ldap_bind_address`: LDAP/LDAPS URL used for authentication. +- `settings.ldap_base_dn`: LDAP base DN for user/group lookups. +- `settings.ldap_trust_cert_file`: optional CA cert file for LDAP TLS. +- `settings.ldap_disable_validation`: disables LDAP TLS cert validation. +- `settings.ldap_insecure`: insecure LDAP TLS mode. +- `settings.enable_pprof`: enables `/debug/pprof/*` routes (still admin-gated). + vCenter: - `settings.encryption_key`: optional explicit key source for credential encryption/decryption. If unset, vCTP derives a host key from hardware/host identity. diff --git a/components/views/index.templ b/components/views/index.templ index 2b0a911..9dbf3a2 100644 --- a/components/views/index.templ +++ b/components/views/index.templ @@ -21,7 +21,7 @@ templ Index(info BuildInfo) {
vCTP Console

Chargeback Intelligence Dashboard

-

Point in time snapshots of consumption.

+

Point in time snapshots of consumption with LDAP/JWT protected API access.

Hourly Snapshots @@ -58,6 +58,12 @@ templ Index(info BuildInfo) {

Use fast vCenter totals views (Daily Aggregated and Hourly Detail 45d) and VM Trace views (Hourly Detail and Daily Aggregated) to move between long-range trends and granular timelines.

+

+ When authentication is enabled, obtain a token from POST /api/auth/login and send it as Authorization: Bearer <token>. +

+

+ Role policy: viewer role covers read/report APIs, and admin role covers mutating/admin APIs (admin also grants viewer access). UI pages and /metrics remain public. +

Snapshots and Reports

diff --git a/components/views/index_templ.go b/components/views/index_templ.go index b2d212b..6d1ab82 100644 --- a/components/views/index_templ.go +++ b/components/views/index_templ.go @@ -47,7 +47,7 @@ func Index(info BuildInfo) templ.Component { if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } - templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "
vCTP Console

Chargeback Intelligence Dashboard

Point in time snapshots of consumption.

Build Time

") + templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "

vCTP Console

Chargeback Intelligence Dashboard

Point in time snapshots of consumption with LDAP/JWT protected API access.

Build Time

") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } @@ -86,7 +86,7 @@ func Index(info BuildInfo) templ.Component { if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } - templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 5, "

Overview

vCTP is a vSphere Chargeback Tracking Platform.

Use fast vCenter totals views (Daily Aggregated and Hourly Detail 45d) and VM Trace views (Hourly Detail and Daily Aggregated) to move between long-range trends and granular timelines.

Snapshots and Reports

Hourly snapshots capture inventory per vCenter (concurrency via hourly_snapshot_concurrency), then daily and monthly summaries are derived from those snapshots.

Hourly tracks: VM identity (InventoryId, Name, VmId, VmUuid, Vcenter, EventKey, CloudId), lifecycle (CreationTime, DeletionTime, SnapshotTime), placement (Datacenter, Cluster, Folder, ResourcePool), and sizing/state (VcpuCount, RamGB, ProvisionedDisk, PoweredOn, IsTemplate, SrmPlaceholder).

Daily tracks: SamplesPresent, TotalSamples, AvgIsPresent, AvgVcpuCount, AvgRamGB, AvgProvisionedDisk, PoolTinPct, PoolBronzePct, PoolSilverPct, PoolGoldPct, plus chargeback totals columns Tin, Bronze, Silver, Gold.

Monthly tracks: the same daily aggregate fields, with monthly values weighted by per-day sample volume so partial-day VMs and config changes stay proportional.

Snapshots are registered in snapshot_registry so regeneration via /api/snapshots/aggregate can locate the correct tables (fallback scanning is also supported).

vCenter totals pages are accelerated by compact cache tables: vcenter_latest_totals and vcenter_aggregate_totals.

VM Trace daily mode uses the vm_daily_rollup cache when available, and falls back to daily summary tables if needed.

Reports (XLSX with totals/charts) are generated automatically after hourly, daily, and monthly jobs and written to a reports directory.

Hourly totals are interval-based: each row represents [HH:00, HH+1:00) and uses the first snapshot at or after the hour end (including cross-day snapshots) to prorate VM presence.

Monthly aggregation reports include a Daily Totals sheet with full-day interval labels (YYYY-MM-DD to YYYY-MM-DD) and prorated totals.

Prorating and Aggregation

SamplesPresent is the count of snapshots in which the VM appears; TotalSamples is the count of unique snapshot times for that vCenter/day.

AvgIsPresent = SamplesPresent / TotalSamples (0 when TotalSamples is 0).

Daily AvgVcpuCount, AvgRamGB, and AvgProvisionedDisk are per-sample sums divided by TotalSamples (time-weighted).

Daily pool percentages (PoolTinPct/PoolBronzePct/PoolSilverPct/PoolGoldPct) use pool-hit counts divided by SamplesPresent.

Monthly aggregation converts each day into weighted sums using sample volume, then recomputes monthly averages and pool percentages from those weighted totals.

CreationTime is only set when vCenter provides it; otherwise it remains 0.

") + templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 5, "

Overview

vCTP is a vSphere Chargeback Tracking Platform.

Use fast vCenter totals views (Daily Aggregated and Hourly Detail 45d) and VM Trace views (Hourly Detail and Daily Aggregated) to move between long-range trends and granular timelines.

When authentication is enabled, obtain a token from POST /api/auth/login and send it as Authorization: Bearer <token>.

Role policy: viewer role covers read/report APIs, and admin role covers mutating/admin APIs (admin also grants viewer access). UI pages and /metrics remain public.

Snapshots and Reports

Hourly snapshots capture inventory per vCenter (concurrency via hourly_snapshot_concurrency), then daily and monthly summaries are derived from those snapshots.

Hourly tracks: VM identity (InventoryId, Name, VmId, VmUuid, Vcenter, EventKey, CloudId), lifecycle (CreationTime, DeletionTime, SnapshotTime), placement (Datacenter, Cluster, Folder, ResourcePool), and sizing/state (VcpuCount, RamGB, ProvisionedDisk, PoweredOn, IsTemplate, SrmPlaceholder).

Daily tracks: SamplesPresent, TotalSamples, AvgIsPresent, AvgVcpuCount, AvgRamGB, AvgProvisionedDisk, PoolTinPct, PoolBronzePct, PoolSilverPct, PoolGoldPct, plus chargeback totals columns Tin, Bronze, Silver, Gold.

Monthly tracks: the same daily aggregate fields, with monthly values weighted by per-day sample volume so partial-day VMs and config changes stay proportional.

Snapshots are registered in snapshot_registry so regeneration via /api/snapshots/aggregate can locate the correct tables (fallback scanning is also supported).

vCenter totals pages are accelerated by compact cache tables: vcenter_latest_totals and vcenter_aggregate_totals.

VM Trace daily mode uses the vm_daily_rollup cache when available, and falls back to daily summary tables if needed.

Reports (XLSX with totals/charts) are generated automatically after hourly, daily, and monthly jobs and written to a reports directory.

Hourly totals are interval-based: each row represents [HH:00, HH+1:00) and uses the first snapshot at or after the hour end (including cross-day snapshots) to prorate VM presence.

Monthly aggregation reports include a Daily Totals sheet with full-day interval labels (YYYY-MM-DD to YYYY-MM-DD) and prorated totals.

Prorating and Aggregation

SamplesPresent is the count of snapshots in which the VM appears; TotalSamples is the count of unique snapshot times for that vCenter/day.

AvgIsPresent = SamplesPresent / TotalSamples (0 when TotalSamples is 0).

Daily AvgVcpuCount, AvgRamGB, and AvgProvisionedDisk are per-sample sums divided by TotalSamples (time-weighted).

Daily pool percentages (PoolTinPct/PoolBronzePct/PoolSilverPct/PoolGoldPct) use pool-hit counts divided by SamplesPresent.

Monthly aggregation converts each day into weighted sums using sample volume, then recomputes monthly averages and pool percentages from those weighted totals.

CreationTime is only set when vCenter provides it; otherwise it remains 0.

") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } diff --git a/go.mod b/go.mod index e2ca096..fcccbaf 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,12 @@ module vctp -go 1.26.1 +go 1.26.2 require ( github.com/a-h/templ v0.3.1001 - github.com/go-co-op/gocron/v2 v2.19.1 - github.com/jackc/pgx/v5 v5.8.0 + github.com/go-co-op/gocron/v2 v2.21.0 + github.com/go-ldap/ldap/v3 v3.4.13 + github.com/jackc/pgx/v5 v5.9.1 github.com/jmoiron/sqlx v1.4.0 github.com/pressly/goose/v3 v3.27.0 github.com/prometheus/client_golang v1.23.2 @@ -13,16 +14,18 @@ require ( github.com/vmware/govmomi v0.53.0 github.com/xuri/excelize/v2 v2.10.1 gopkg.in/yaml.v3 v3.0.1 - modernc.org/sqlite v1.47.0 + modernc.org/sqlite v1.48.2 ) require ( + github.com/Azure/go-ntlmssp v0.1.0 // indirect github.com/KyleBanks/depth v1.2.1 // indirect github.com/PuerkitoBio/purell v1.2.1 // indirect github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-openapi/jsonpointer v0.22.5 // indirect github.com/go-openapi/jsonreference v0.21.5 // indirect github.com/go-openapi/spec v0.22.4 // indirect @@ -45,7 +48,7 @@ require ( github.com/jonboulle/clockwork v0.5.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.9.2 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-isatty v0.0.21 // indirect github.com/mfridman/interpolate v0.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect @@ -63,17 +66,17 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/crypto v0.49.0 // indirect + golang.org/x/crypto v0.50.0 // indirect golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect golang.org/x/mod v0.34.0 // indirect - golang.org/x/net v0.52.0 // indirect + golang.org/x/net v0.53.0 // indirect golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.42.0 // indirect - golang.org/x/text v0.35.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect golang.org/x/tools v0.43.0 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - modernc.org/libc v1.70.0 // indirect + modernc.org/libc v1.72.0 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 1de9840..b69b5f9 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,10 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= +github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A= +github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk= github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI= @@ -25,10 +29,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-co-op/gocron/v2 v2.19.0 h1:OKf2y6LXPs/BgBI2fl8PxUpNAI1DA9Mg+hSeGOS38OU= github.com/go-co-op/gocron/v2 v2.19.0/go.mod h1:5lEiCKk1oVJV39Zg7/YG10OnaVrDAV5GGR6O0663k6U= github.com/go-co-op/gocron/v2 v2.19.1 h1:B4iLeA0NB/2iO3EKQ7NfKn5KsQgZfjb2fkvoZJU3yBI= github.com/go-co-op/gocron/v2 v2.19.1/go.mod h1:5lEiCKk1oVJV39Zg7/YG10OnaVrDAV5GGR6O0663k6U= +github.com/go-co-op/gocron/v2 v2.21.0 h1:e1nt9AEFglarRH9/9y9q0V5sblwxlknpHPjttEajrwQ= +github.com/go-co-op/gocron/v2 v2.21.0/go.mod h1:5lEiCKk1oVJV39Zg7/YG10OnaVrDAV5GGR6O0663k6U= +github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= +github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= +github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ= +github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= @@ -116,6 +128,8 @@ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7Ulw github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= @@ -143,6 +157,8 @@ github.com/mailru/easyjson v0.9.2 h1:dX8U45hQsZpxd80nLvDGihsQ/OxlvTkVUXH2r/8cb2M github.com/mailru/easyjson v0.9.2/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= +github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= @@ -227,6 +243,8 @@ golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= @@ -248,6 +266,8 @@ golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= @@ -261,6 +281,8 @@ golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -270,6 +292,8 @@ golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= @@ -295,10 +319,12 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U= modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= modernc.org/ccgo/v4 v4.30.2 h1:4yPaaq9dXYXZ2V8s1UgrC3KIj580l2N4ClrLwnbv2so= modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= +modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU= modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= @@ -315,6 +341,8 @@ modernc.org/libc v1.68.0 h1:PJ5ikFOV5pwpW+VqCK1hKJuEWsonkIJhhIXyuF/91pQ= modernc.org/libc v1.68.0/go.mod h1:NnKCYeoYgsEqnY3PgvNgAeaJnso968ygU8Z0DxjoEc0= modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= +modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c= +modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= @@ -329,6 +357,8 @@ modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk= modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= +modernc.org/sqlite v1.48.2 h1:5CnW4uP8joZtA0LedVqLbZV5GD7F/0x91AXeSyjoh5c= +modernc.org/sqlite v1.48.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go new file mode 100644 index 0000000..dba9bb3 --- /dev/null +++ b/internal/auth/jwt.go @@ -0,0 +1,292 @@ +package auth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strings" + "time" +) + +const ( + jwtAlgHS256 = "HS256" + jwtTyp = "JWT" +) + +var ( + ErrInvalidJWTConfig = errors.New("invalid jwt config") + ErrInvalidJWTToken = errors.New("invalid jwt token") + ErrInvalidJWTClaims = errors.New("invalid jwt claims") + ErrExpiredJWTToken = errors.New("jwt token expired") + ErrNotYetValidJWTToken = errors.New("jwt token is not yet valid") +) + +type JWTConfig struct { + SigningKeyBase64 string + Issuer string + Audience string + TokenLifespan time.Duration + ClockSkew time.Duration +} + +type Claims struct { + Subject string `json:"sub"` + Roles []string `json:"roles,omitempty"` + Groups []string `json:"groups,omitempty"` + Issuer string `json:"iss"` + Audience string `json:"aud"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp"` + NotBefore int64 `json:"nbf"` + ID string `json:"jti"` +} + +type JWTService struct { + signingKey []byte + issuer string + audience string + tokenLifespan time.Duration + clockSkew time.Duration + now func() time.Time +} + +type jwtHeader struct { + Algorithm string `json:"alg"` + Type string `json:"typ"` +} + +func NewJWTService(cfg JWTConfig) (*JWTService, error) { + issuer := strings.TrimSpace(cfg.Issuer) + audience := strings.TrimSpace(cfg.Audience) + if issuer == "" { + return nil, fmt.Errorf("%w: issuer is required", ErrInvalidJWTConfig) + } + if audience == "" { + return nil, fmt.Errorf("%w: audience is required", ErrInvalidJWTConfig) + } + if cfg.TokenLifespan <= 0 { + return nil, fmt.Errorf("%w: token lifespan must be greater than zero", ErrInvalidJWTConfig) + } + if cfg.ClockSkew < 0 { + return nil, fmt.Errorf("%w: clock skew cannot be negative", ErrInvalidJWTConfig) + } + + signingKey, err := decodeBase64Key(strings.TrimSpace(cfg.SigningKeyBase64)) + if err != nil { + return nil, fmt.Errorf("%w: signing key must be valid base64", ErrInvalidJWTConfig) + } + if len(signingKey) == 0 { + return nil, fmt.Errorf("%w: signing key cannot be empty", ErrInvalidJWTConfig) + } + + return &JWTService{ + signingKey: signingKey, + issuer: issuer, + audience: audience, + tokenLifespan: cfg.TokenLifespan, + clockSkew: cfg.ClockSkew, + now: time.Now, + }, nil +} + +func (s *JWTService) IssueToken(subject string, roles []string, groups []string) (string, Claims, error) { + subject = strings.TrimSpace(subject) + if subject == "" { + return "", Claims{}, fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims) + } + + now := s.now().UTC() + claims := Claims{ + Subject: subject, + Roles: compactTrimmedStrings(roles), + Groups: compactTrimmedStrings(groups), + Issuer: s.issuer, + Audience: s.audience, + IssuedAt: now.Unix(), + ExpiresAt: now.Add(s.tokenLifespan).Unix(), + NotBefore: now.Unix(), + ID: newTokenID(), + } + if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil { + return "", Claims{}, err + } + + token, err := encodeSignedJWT(claims, s.signingKey) + if err != nil { + return "", Claims{}, err + } + return token, claims, nil +} + +func (s *JWTService) VerifyToken(token string) (Claims, error) { + header, claims, signingInput, signature, err := parseJWT(token) + if err != nil { + return Claims{}, err + } + if header.Algorithm != jwtAlgHS256 { + return Claims{}, fmt.Errorf("%w: unsupported algorithm", ErrInvalidJWTToken) + } + if header.Type != "" && header.Type != jwtTyp { + return Claims{}, fmt.Errorf("%w: invalid token type", ErrInvalidJWTToken) + } + + expected := signPayload(signingInput, s.signingKey) + if !hmac.Equal(signature, expected) { + return Claims{}, fmt.Errorf("%w: signature mismatch", ErrInvalidJWTToken) + } + + now := s.now().UTC() + if err := validateClaims(claims, now, s.issuer, s.audience, s.clockSkew); err != nil { + return Claims{}, err + } + return claims, nil +} + +func validateClaims(claims Claims, now time.Time, expectedIssuer string, expectedAudience string, clockSkew time.Duration) error { + if strings.TrimSpace(claims.Subject) == "" { + return fmt.Errorf("%w: subject is required", ErrInvalidJWTClaims) + } + if strings.TrimSpace(claims.ID) == "" { + return fmt.Errorf("%w: jti is required", ErrInvalidJWTClaims) + } + if claims.Issuer != expectedIssuer { + return fmt.Errorf("%w: issuer mismatch", ErrInvalidJWTClaims) + } + if claims.Audience != expectedAudience { + return fmt.Errorf("%w: audience mismatch", ErrInvalidJWTClaims) + } + if claims.IssuedAt <= 0 { + return fmt.Errorf("%w: iat is required", ErrInvalidJWTClaims) + } + if claims.NotBefore <= 0 { + return fmt.Errorf("%w: nbf is required", ErrInvalidJWTClaims) + } + if claims.ExpiresAt <= 0 { + return fmt.Errorf("%w: exp is required", ErrInvalidJWTClaims) + } + if claims.ExpiresAt <= claims.IssuedAt { + return fmt.Errorf("%w: exp must be greater than iat", ErrInvalidJWTClaims) + } + if claims.NotBefore > claims.ExpiresAt { + return fmt.Errorf("%w: nbf cannot be greater than exp", ErrInvalidJWTClaims) + } + + unixNow := now.Unix() + skewSeconds := int64(clockSkew / time.Second) + if claims.IssuedAt > unixNow+skewSeconds { + return fmt.Errorf("%w: iat is in the future", ErrInvalidJWTClaims) + } + if claims.NotBefore > unixNow+skewSeconds { + return ErrNotYetValidJWTToken + } + if unixNow > claims.ExpiresAt+skewSeconds { + return ErrExpiredJWTToken + } + return nil +} + +func encodeSignedJWT(claims Claims, signingKey []byte) (string, error) { + headerJSON, err := json.Marshal(jwtHeader{Algorithm: jwtAlgHS256, Type: jwtTyp}) + if err != nil { + return "", fmt.Errorf("marshal jwt header: %w", err) + } + claimsJSON, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshal jwt claims: %w", err) + } + + headerPart := base64.RawURLEncoding.EncodeToString(headerJSON) + payloadPart := base64.RawURLEncoding.EncodeToString(claimsJSON) + signingInput := headerPart + "." + payloadPart + signature := signPayload(signingInput, signingKey) + signaturePart := base64.RawURLEncoding.EncodeToString(signature) + + return signingInput + "." + signaturePart, nil +} + +func parseJWT(token string) (jwtHeader, Claims, string, []byte, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken) + } + if parts[0] == "" || parts[1] == "" || parts[2] == "" { + return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: malformed token", ErrInvalidJWTToken) + } + + headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header encoding", ErrInvalidJWTToken) + } + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid payload encoding", ErrInvalidJWTToken) + } + signature, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid signature encoding", ErrInvalidJWTToken) + } + + var header jwtHeader + if err := json.Unmarshal(headerBytes, &header); err != nil { + return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid header json", ErrInvalidJWTToken) + } + var claims Claims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return jwtHeader{}, Claims{}, "", nil, fmt.Errorf("%w: invalid claims json", ErrInvalidJWTToken) + } + + return header, claims, parts[0] + "." + parts[1], signature, nil +} + +func signPayload(payload string, signingKey []byte) []byte { + mac := hmac.New(sha256.New, signingKey) + mac.Write([]byte(payload)) + return mac.Sum(nil) +} + +func newTokenID() string { + raw := make([]byte, 16) + if _, err := rand.Read(raw); err != nil { + return fmt.Sprintf("fallback-%d", time.Now().UTC().UnixNano()) + } + return hex.EncodeToString(raw) +} + +func decodeBase64Key(value string) ([]byte, error) { + encodings := []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } + for _, encoding := range encodings { + decoded, err := encoding.DecodeString(value) + if err == nil { + return decoded, nil + } + } + return nil, errors.New("invalid base64 encoding") +} + +func compactTrimmedStrings(values []string) []string { + if len(values) == 0 { + return nil + } + out := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + out = append(out, trimmed) + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/auth/jwt_test.go b/internal/auth/jwt_test.go new file mode 100644 index 0000000..6bfa906 --- /dev/null +++ b/internal/auth/jwt_test.go @@ -0,0 +1,247 @@ +package auth + +import ( + "encoding/base64" + "errors" + "strings" + "testing" + "time" +) + +func TestNewJWTServiceRejectsBadConfig(t *testing.T) { + _, err := NewJWTService(JWTConfig{ + SigningKeyBase64: "!!!", + Issuer: "vctp", + Audience: "vctp-api", + TokenLifespan: time.Hour, + ClockSkew: time.Minute, + }) + if err == nil { + t.Fatal("expected invalid base64 signing key to fail") + } + if !errors.Is(err, ErrInvalidJWTConfig) { + t.Fatalf("expected ErrInvalidJWTConfig, got: %v", err) + } +} + +func TestIssueAndVerifyTokenRoundTrip(t *testing.T) { + now := time.Unix(1_700_000_000, 0).UTC() + svc := mustJWTService(t) + svc.now = func() time.Time { return now } + + token, issuedClaims, err := svc.IssueToken("alice", []string{"admin", " viewer "}, []string{"cn=vctp-admins,dc=example,dc=com"}) + if err != nil { + t.Fatalf("IssueToken returned error: %v", err) + } + if token == "" { + t.Fatal("expected non-empty token") + } + if issuedClaims.Subject != "alice" { + t.Fatalf("expected subject alice, got %q", issuedClaims.Subject) + } + if issuedClaims.Issuer != "vctp" { + t.Fatalf("expected issuer vctp, got %q", issuedClaims.Issuer) + } + if issuedClaims.Audience != "vctp-api" { + t.Fatalf("expected audience vctp-api, got %q", issuedClaims.Audience) + } + if issuedClaims.IssuedAt != now.Unix() { + t.Fatalf("unexpected iat: %d", issuedClaims.IssuedAt) + } + if issuedClaims.NotBefore != now.Unix() { + t.Fatalf("unexpected nbf: %d", issuedClaims.NotBefore) + } + if issuedClaims.ExpiresAt != now.Add(2*time.Hour).Unix() { + t.Fatalf("unexpected exp: %d", issuedClaims.ExpiresAt) + } + if issuedClaims.ID == "" { + t.Fatal("expected jti to be populated") + } + + verifiedClaims, err := svc.VerifyToken(token) + if err != nil { + t.Fatalf("VerifyToken returned error: %v", err) + } + if verifiedClaims.Subject != issuedClaims.Subject { + t.Fatalf("subject mismatch: got %q want %q", verifiedClaims.Subject, issuedClaims.Subject) + } + if verifiedClaims.ID != issuedClaims.ID { + t.Fatalf("jti mismatch: got %q want %q", verifiedClaims.ID, issuedClaims.ID) + } +} + +func TestVerifyTokenRejectsInvalidSignature(t *testing.T) { + svc := mustJWTService(t) + svc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() } + + token, _, err := svc.IssueToken("alice", []string{"admin"}, nil) + if err != nil { + t.Fatalf("IssueToken returned error: %v", err) + } + + other := mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("a different secret key"))) + other.now = svc.now + + _, err = other.VerifyToken(token) + if err == nil { + t.Fatal("expected signature mismatch to fail") + } + if !errors.Is(err, ErrInvalidJWTToken) { + t.Fatalf("expected ErrInvalidJWTToken, got: %v", err) + } +} + +func TestVerifyTokenRejectsIssuerAndAudienceMismatch(t *testing.T) { + issuerSvc := mustJWTService(t) + issuerSvc.now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() } + token, _, err := issuerSvc.IssueToken("alice", nil, nil) + if err != nil { + t.Fatalf("IssueToken returned error: %v", err) + } + + wrongIssuer, err := NewJWTService(JWTConfig{ + SigningKeyBase64: base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")), + Issuer: "other-issuer", + Audience: "vctp-api", + TokenLifespan: 2 * time.Hour, + ClockSkew: time.Minute, + }) + if err != nil { + t.Fatalf("failed to create verifier with wrong issuer: %v", err) + } + wrongIssuer.now = issuerSvc.now + + _, err = wrongIssuer.VerifyToken(token) + if err == nil { + t.Fatal("expected issuer mismatch to fail") + } + if !errors.Is(err, ErrInvalidJWTClaims) { + t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err) + } + if !strings.Contains(strings.ToLower(err.Error()), "issuer") { + t.Fatalf("expected issuer mismatch error, got: %v", err) + } + + wrongAudience, err := NewJWTService(JWTConfig{ + SigningKeyBase64: base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key")), + Issuer: "vctp", + Audience: "other-audience", + TokenLifespan: 2 * time.Hour, + ClockSkew: time.Minute, + }) + if err != nil { + t.Fatalf("failed to create verifier with wrong audience: %v", err) + } + wrongAudience.now = issuerSvc.now + + _, err = wrongAudience.VerifyToken(token) + if err == nil { + t.Fatal("expected audience mismatch to fail") + } + if !errors.Is(err, ErrInvalidJWTClaims) { + t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err) + } + if !strings.Contains(strings.ToLower(err.Error()), "audience") { + t.Fatalf("expected audience mismatch error, got: %v", err) + } +} + +func TestVerifyTokenRejectsExpiredNotBeforeAndFutureIssuedAt(t *testing.T) { + base := time.Unix(1_700_000_000, 0).UTC() + svc := mustJWTService(t) + svc.now = func() time.Time { return base } + + token, claims, err := svc.IssueToken("alice", nil, nil) + if err != nil { + t.Fatalf("IssueToken returned error: %v", err) + } + + svc.now = func() time.Time { return base.Add(3 * time.Hour) } + _, err = svc.VerifyToken(token) + if !errors.Is(err, ErrExpiredJWTToken) { + t.Fatalf("expected ErrExpiredJWTToken, got: %v", err) + } + + notBeforeClaims := claims + notBeforeClaims.NotBefore = base.Add(10 * time.Minute).Unix() + notBeforeClaims.IssuedAt = base.Unix() + notBeforeClaims.ExpiresAt = base.Add(2 * time.Hour).Unix() + notBeforeClaims.ID = "forced-jti-1" + notBeforeToken, err := encodeSignedJWT(notBeforeClaims, svc.signingKey) + if err != nil { + t.Fatalf("failed to create token with future nbf: %v", err) + } + svc.now = func() time.Time { return base } + _, err = svc.VerifyToken(notBeforeToken) + if !errors.Is(err, ErrNotYetValidJWTToken) { + t.Fatalf("expected ErrNotYetValidJWTToken, got: %v", err) + } + + futureIatClaims := claims + futureIatClaims.IssuedAt = base.Add(20 * time.Minute).Unix() + futureIatClaims.NotBefore = base.Unix() + futureIatClaims.ExpiresAt = base.Add(3 * time.Hour).Unix() + futureIatClaims.ID = "forced-jti-2" + futureIatToken, err := encodeSignedJWT(futureIatClaims, svc.signingKey) + if err != nil { + t.Fatalf("failed to create token with future iat: %v", err) + } + _, err = svc.VerifyToken(futureIatToken) + if err == nil { + t.Fatal("expected future iat validation to fail") + } + if !errors.Is(err, ErrInvalidJWTClaims) { + t.Fatalf("expected ErrInvalidJWTClaims for future iat, got: %v", err) + } +} + +func TestVerifyTokenRejectsMissingJTI(t *testing.T) { + base := time.Unix(1_700_000_000, 0).UTC() + svc := mustJWTService(t) + svc.now = func() time.Time { return base } + + token, claims, err := svc.IssueToken("alice", nil, nil) + if err != nil { + t.Fatalf("IssueToken returned error: %v", err) + } + if token == "" { + t.Fatal("expected non-empty token") + } + + claims.ID = "" + customToken, err := encodeSignedJWT(claims, svc.signingKey) + if err != nil { + t.Fatalf("failed to create token without jti: %v", err) + } + + _, err = svc.VerifyToken(customToken) + if err == nil { + t.Fatal("expected missing jti token to fail") + } + if !errors.Is(err, ErrInvalidJWTClaims) { + t.Fatalf("expected ErrInvalidJWTClaims, got: %v", err) + } + if !strings.Contains(strings.ToLower(err.Error()), "jti") { + t.Fatalf("expected jti validation error, got: %v", err) + } +} + +func mustJWTService(t *testing.T) *JWTService { + t.Helper() + return mustJWTServiceWithKey(t, base64.StdEncoding.EncodeToString([]byte("super-secret-signing-key"))) +} + +func mustJWTServiceWithKey(t *testing.T, keyBase64 string) *JWTService { + t.Helper() + svc, err := NewJWTService(JWTConfig{ + SigningKeyBase64: keyBase64, + Issuer: "vctp", + Audience: "vctp-api", + TokenLifespan: 2 * time.Hour, + ClockSkew: time.Minute, + }) + if err != nil { + t.Fatalf("failed to create jwt service: %v", err) + } + return svc +} diff --git a/internal/auth/ldap.go b/internal/auth/ldap.go new file mode 100644 index 0000000..09b5aaf --- /dev/null +++ b/internal/auth/ldap.go @@ -0,0 +1,354 @@ +package auth + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "net/url" + "os" + "sort" + "strings" + "time" + + "github.com/go-ldap/ldap/v3" +) + +var ( + ErrInvalidLDAPConfig = errors.New("invalid ldap config") + ErrLDAPInvalidCredentials = errors.New("invalid ldap credentials") + ErrLDAPOperationFailed = errors.New("ldap operation failed") +) + +type LDAPConfig struct { + BindAddress string + BaseDN string + TrustCertFile string + DisableValidation bool + Insecure bool + DialTimeout time.Duration +} + +type LDAPIdentity struct { + Username string + UserDN string + Groups []string +} + +type LDAPAuthenticator struct { + bindAddress string + baseDN string + trustCertFile string + disableValidation bool + insecure bool + dialTimeout time.Duration +} + +func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) { + bindAddress := strings.TrimSpace(cfg.BindAddress) + baseDN := strings.TrimSpace(cfg.BaseDN) + trustCertFile := strings.TrimSpace(cfg.TrustCertFile) + + if bindAddress == "" { + return nil, fmt.Errorf("%w: bind address is required", ErrInvalidLDAPConfig) + } + if baseDN == "" { + return nil, fmt.Errorf("%w: base DN is required", ErrInvalidLDAPConfig) + } + if _, err := url.ParseRequestURI(bindAddress); err != nil { + return nil, fmt.Errorf("%w: bind address must be a valid URL: %v", ErrInvalidLDAPConfig, err) + } + + dialTimeout := cfg.DialTimeout + if dialTimeout <= 0 { + dialTimeout = 10 * time.Second + } + + return &LDAPAuthenticator{ + bindAddress: bindAddress, + baseDN: baseDN, + trustCertFile: trustCertFile, + disableValidation: cfg.DisableValidation, + insecure: cfg.Insecure, + dialTimeout: dialTimeout, + }, nil +} + +func (a *LDAPAuthenticator) AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (LDAPIdentity, error) { + username = strings.TrimSpace(username) + if username == "" || password == "" { + return LDAPIdentity{}, ErrLDAPInvalidCredentials + } + if err := ctxErr(ctx); err != nil { + return LDAPIdentity{}, err + } + + conn, err := a.connect() + if err != nil { + return LDAPIdentity{}, err + } + defer conn.Close() + + if err := conn.Bind(username, password); err != nil { + if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) { + return LDAPIdentity{}, ErrLDAPInvalidCredentials + } + return LDAPIdentity{}, fmt.Errorf("%w: bind failed: %v", ErrLDAPOperationFailed, err) + } + if err := ctxErr(ctx); err != nil { + return LDAPIdentity{}, err + } + + identity := LDAPIdentity{ + Username: username, + UserDN: username, + } + + entry, err := a.lookupUserEntry(conn, username) + if err != nil { + return LDAPIdentity{}, err + } + if entry != nil { + if strings.TrimSpace(entry.DN) != "" { + identity.UserDN = entry.DN + } + if v := firstNonEmpty( + entry.GetAttributeValue("uid"), + entry.GetAttributeValue("sAMAccountName"), + entry.GetAttributeValue("userPrincipalName"), + entry.GetAttributeValue("cn"), + ); v != "" { + identity.Username = v + } + } + + groupSet := make(map[string]struct{}) + if entry != nil { + for _, groupDN := range entry.GetAttributeValues("memberOf") { + groupDN = strings.TrimSpace(groupDN) + if groupDN == "" { + continue + } + groupSet[groupDN] = struct{}{} + } + } + + groupEntries, err := conn.Search(ldap.NewSearchRequest( + a.baseDN, + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 0, + 0, + false, + fmt.Sprintf("(|(member=%s)(uniqueMember=%s)(memberUid=%s))", + ldap.EscapeFilter(identity.UserDN), + ldap.EscapeFilter(identity.UserDN), + ldap.EscapeFilter(username), + ), + []string{"dn"}, + nil, + )) + if err == nil { + for _, e := range groupEntries.Entries { + if dn := strings.TrimSpace(e.DN); dn != "" { + groupSet[dn] = struct{}{} + } + } + } + + identity.Groups = mapKeysSorted(groupSet) + return identity, nil +} + +func ResolveRoles(groupDNs []string, groupRoleMappings map[string]string) []string { + if len(groupDNs) == 0 || len(groupRoleMappings) == 0 { + return nil + } + + normalizedMappings := make(map[string]string, len(groupRoleMappings)) + for groupDN, role := range groupRoleMappings { + groupDN = normalizeDN(groupDN) + role = strings.ToLower(strings.TrimSpace(role)) + if groupDN == "" || role == "" { + continue + } + normalizedMappings[groupDN] = role + } + + roleSet := make(map[string]struct{}) + for _, groupDN := range groupDNs { + if role, ok := normalizedMappings[normalizeDN(groupDN)]; ok { + roleSet[role] = struct{}{} + } + } + return mapKeysSorted(roleSet) +} + +func HasAnyGroup(groupDNs []string, requiredGroupDNs []string) bool { + requiredGroupDNs = compactTrimmedStrings(requiredGroupDNs) + if len(requiredGroupDNs) == 0 { + return true + } + if len(groupDNs) == 0 { + return false + } + + required := make(map[string]struct{}, len(requiredGroupDNs)) + for _, groupDN := range requiredGroupDNs { + required[normalizeDN(groupDN)] = struct{}{} + } + for _, groupDN := range groupDNs { + if _, ok := required[normalizeDN(groupDN)]; ok { + return true + } + } + return false +} + +func (a *LDAPAuthenticator) connect() (*ldap.Conn, error) { + tlsConfig, err := a.buildTLSConfig() + if err != nil { + return nil, err + } + + parsedURL, err := url.Parse(a.bindAddress) + if err != nil { + return nil, fmt.Errorf("%w: invalid bind address: %v", ErrInvalidLDAPConfig, err) + } + + options := []ldap.DialOpt{ + ldap.DialWithDialer(&net.Dialer{Timeout: a.dialTimeout}), + ldap.DialWithTLSConfig(tlsConfig), + } + conn, err := ldap.DialURL(a.bindAddress, options...) + if err != nil { + return nil, fmt.Errorf("%w: unable to connect: %v", ErrLDAPOperationFailed, err) + } + conn.SetTimeout(a.dialTimeout) + + // For ldap://, opportunistically upgrade to TLS unless explicitly configured as insecure. + if parsedURL.Scheme == "ldap" && !a.insecure { + if err := conn.StartTLS(tlsConfig); err != nil { + conn.Close() + return nil, fmt.Errorf("%w: starttls failed: %v", ErrLDAPOperationFailed, err) + } + } + + return conn, nil +} + +func (a *LDAPAuthenticator) buildTLSConfig() (*tls.Config, error) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: a.insecure || a.disableValidation, //nolint:gosec // controlled by explicit config flags + } + + if a.trustCertFile == "" { + return tlsConfig, nil + } + + caPEM, err := os.ReadFile(a.trustCertFile) + if err != nil { + return nil, fmt.Errorf("%w: failed to read ldap trust cert file: %v", ErrInvalidLDAPConfig, err) + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(caPEM) { + return nil, fmt.Errorf("%w: ldap trust cert file contains no valid certificates", ErrInvalidLDAPConfig) + } + tlsConfig.RootCAs = roots + return tlsConfig, nil +} + +func (a *LDAPAuthenticator) lookupUserEntry(conn *ldap.Conn, username string) (*ldap.Entry, error) { + if looksLikeDN(username) { + searchRes, err := conn.Search(ldap.NewSearchRequest( + username, + ldap.ScopeBaseObject, + ldap.NeverDerefAliases, + 1, + 0, + false, + "(objectClass=*)", + []string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"}, + nil, + )) + if err != nil { + return nil, fmt.Errorf("%w: unable to load user entry: %v", ErrLDAPOperationFailed, err) + } + if len(searchRes.Entries) == 0 { + return nil, nil + } + return searchRes.Entries[0], nil + } + + searchRes, err := conn.Search(ldap.NewSearchRequest( + a.baseDN, + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 2, + 0, + false, + fmt.Sprintf("(|(uid=%s)(cn=%s)(sAMAccountName=%s)(userPrincipalName=%s))", + ldap.EscapeFilter(username), + ldap.EscapeFilter(username), + ldap.EscapeFilter(username), + ldap.EscapeFilter(username), + ), + []string{"uid", "sAMAccountName", "userPrincipalName", "cn", "memberOf"}, + nil, + )) + if err != nil { + return nil, fmt.Errorf("%w: user lookup failed: %v", ErrLDAPOperationFailed, err) + } + if len(searchRes.Entries) == 0 { + return nil, nil + } + return searchRes.Entries[0], nil +} + +func normalizeDN(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func mapKeysSorted[K ~string, V any](m map[K]V) []K { + if len(m) == 0 { + return nil + } + out := make([]K, 0, len(m)) + for key := range m { + out = append(out, key) + } + sort.Slice(out, func(i, j int) bool { + return out[i] < out[j] + }) + return out +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + return value + } + } + return "" +} + +func looksLikeDN(value string) bool { + value = strings.TrimSpace(value) + return strings.Contains(value, "=") && strings.Contains(value, ",") +} + +func ctxErr(ctx context.Context) error { + if ctx == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} diff --git a/internal/auth/ldap_test.go b/internal/auth/ldap_test.go new file mode 100644 index 0000000..586c75b --- /dev/null +++ b/internal/auth/ldap_test.go @@ -0,0 +1,39 @@ +package auth + +import "testing" + +func TestResolveRoles(t *testing.T) { + roles := ResolveRoles( + []string{ + "cn=vctp-admins,ou=groups,dc=example,dc=com", + " CN=VCTP-VIEWERS,OU=GROUPS,DC=EXAMPLE,DC=COM ", + }, + map[string]string{ + "cn=vctp-admins,ou=groups,dc=example,dc=com": "admin", + "cn=vctp-viewers,ou=groups,dc=example,dc=com": "viewer", + }, + ) + + if len(roles) != 2 { + t.Fatalf("expected 2 roles, got %d (%#v)", len(roles), roles) + } + if roles[0] != "admin" || roles[1] != "viewer" { + t.Fatalf("unexpected resolved roles: %#v", roles) + } +} + +func TestHasAnyGroup(t *testing.T) { + groups := []string{ + "cn=vctp-admins,ou=groups,dc=example,dc=com", + } + + if !HasAnyGroup(groups, []string{" cn=vctp-admins,ou=groups,dc=example,dc=com "}) { + t.Fatal("expected group intersection to match") + } + if HasAnyGroup(groups, []string{"cn=vctp-operators,ou=groups,dc=example,dc=com"}) { + t.Fatal("expected no intersection") + } + if !HasAnyGroup(groups, nil) { + t.Fatal("expected empty required groups to allow") + } +} diff --git a/internal/settings/settings.go b/internal/settings/settings.go index 9e8dd1e..75ed68e 100644 --- a/internal/settings/settings.go +++ b/internal/settings/settings.go @@ -1,6 +1,7 @@ package settings import ( + "encoding/base64" "errors" "fmt" "log/slog" @@ -18,6 +19,20 @@ var ( postgresKVPasswordPattern = regexp.MustCompile(`(?i)(\bpassword\s*=\s*)(?:'[^']*'|"[^"]*"|[^\s]+)`) ) +const ( + authModeDisabled = "disabled" + authModeOptional = "optional" + authModeRequired = "required" + + authRoleAdmin = "admin" + authRoleViewer = "viewer" + + defaultAuthTokenLifespanMinutes = 120 + defaultAuthJWTIssuer = "vctp" + defaultAuthJWTAudience = "vctp-api" + defaultAuthClockSkewSeconds = 60 +) + type Settings struct { SettingsPath string Logger *slog.Logger @@ -50,6 +65,21 @@ type SettingsYML struct { VcenterPassword string `yaml:"vcenter_password"` VcenterInsecure bool `yaml:"vcenter_insecure"` EnableLegacyAPI bool `yaml:"enable_legacy_api"` + AuthEnabled bool `yaml:"auth_enabled"` + AuthMode string `yaml:"auth_mode"` + AuthJWTSigningKey string `yaml:"auth_jwt_signing_key"` + AuthTokenLifespanMinutes int `yaml:"auth_token_lifespan_minutes"` + AuthJWTIssuer string `yaml:"auth_jwt_issuer"` + AuthJWTAudience string `yaml:"auth_jwt_audience"` + AuthClockSkewSeconds int `yaml:"auth_clock_skew_seconds"` + AuthGroupRoleMappings map[string]string `yaml:"auth_group_role_mappings"` + LDAPGroups []string `yaml:"ldap_groups"` + LDAPBindAddress string `yaml:"ldap_bind_address"` + LDAPBaseDN string `yaml:"ldap_base_dn"` + LDAPTrustCertFile string `yaml:"ldap_trust_cert_file"` + LDAPDisableValidation bool `yaml:"ldap_disable_validation"` + LDAPInsecure bool `yaml:"ldap_insecure"` + EnablePprof bool `yaml:"enable_pprof"` VcenterEventPollingSeconds int `yaml:"vcenter_event_polling_seconds"` VcenterInventoryPollingSeconds int `yaml:"vcenter_inventory_polling_seconds"` VcenterInventorySnapshotSeconds int `yaml:"vcenter_inventory_snapshot_seconds"` @@ -112,6 +142,9 @@ func (s *Settings) ReadYMLSettings() error { if err := d.Decode(&settings); err != nil { return fmt.Errorf("unable to decode settings file : '%s'", err) } + if err := applyDefaultsAndValidateSettings(&settings); err != nil { + return fmt.Errorf("invalid settings file: %w", err) + } // Avoid logging sensitive fields (e.g., credentials). redacted := settings @@ -119,6 +152,9 @@ func (s *Settings) ReadYMLSettings() error { if redacted.Settings.EncryptionKey != "" { redacted.Settings.EncryptionKey = "REDACTED" } + if redacted.Settings.AuthJWTSigningKey != "" { + redacted.Settings.AuthJWTSigningKey = "REDACTED" + } if redacted.Settings.DatabaseURL != "" { redacted.Settings.DatabaseURL = redactDatabaseURL(redacted.Settings.DatabaseURL) } @@ -189,3 +225,140 @@ func secureSettingsFileMode(mode os.FileMode) os.FileMode { secured |= 0o600 return secured } + +func applyDefaultsAndValidateSettings(cfg *SettingsYML) error { + if cfg == nil { + return errors.New("settings config is nil") + } + s := &cfg.Settings + + s.AuthMode = strings.ToLower(strings.TrimSpace(s.AuthMode)) + if s.AuthMode == "" { + s.AuthMode = authModeDisabled + } + if s.AuthTokenLifespanMinutes == 0 { + s.AuthTokenLifespanMinutes = defaultAuthTokenLifespanMinutes + } + s.AuthJWTIssuer = strings.TrimSpace(s.AuthJWTIssuer) + if s.AuthJWTIssuer == "" { + s.AuthJWTIssuer = defaultAuthJWTIssuer + } + s.AuthJWTAudience = strings.TrimSpace(s.AuthJWTAudience) + if s.AuthJWTAudience == "" { + s.AuthJWTAudience = defaultAuthJWTAudience + } + if s.AuthClockSkewSeconds == 0 { + s.AuthClockSkewSeconds = defaultAuthClockSkewSeconds + } + s.AuthJWTSigningKey = strings.TrimSpace(s.AuthJWTSigningKey) + s.LDAPBindAddress = strings.TrimSpace(s.LDAPBindAddress) + s.LDAPBaseDN = strings.TrimSpace(s.LDAPBaseDN) + s.LDAPTrustCertFile = strings.TrimSpace(s.LDAPTrustCertFile) + s.LDAPGroups = compactTrimmedStrings(s.LDAPGroups) + + if !isValidAuthMode(s.AuthMode) { + return fmt.Errorf("settings.auth_mode must be one of %q, %q, %q", authModeDisabled, authModeOptional, authModeRequired) + } + if s.AuthTokenLifespanMinutes <= 0 { + return errors.New("settings.auth_token_lifespan_minutes must be greater than 0") + } + if s.AuthClockSkewSeconds < 0 { + return errors.New("settings.auth_clock_skew_seconds must be >= 0") + } + + if len(s.AuthGroupRoleMappings) > 0 { + normalized := make(map[string]string, len(s.AuthGroupRoleMappings)) + for groupDN, role := range s.AuthGroupRoleMappings { + groupDN = strings.TrimSpace(groupDN) + role = strings.ToLower(strings.TrimSpace(role)) + if groupDN == "" { + return errors.New("settings.auth_group_role_mappings contains an empty group DN key") + } + if !isValidAuthRole(role) { + return fmt.Errorf("settings.auth_group_role_mappings[%q] has unsupported role %q", groupDN, role) + } + normalized[groupDN] = role + } + s.AuthGroupRoleMappings = normalized + } + + if !s.AuthEnabled { + return nil + } + if s.AuthMode == authModeDisabled { + return errors.New("settings.auth_mode must be optional or required when settings.auth_enabled=true") + } + if s.AuthJWTSigningKey == "" { + return errors.New("settings.auth_jwt_signing_key is required when settings.auth_enabled=true") + } + decodedKey, err := decodeBase64(s.AuthJWTSigningKey) + if err != nil { + return errors.New("settings.auth_jwt_signing_key must be valid base64") + } + if len(decodedKey) == 0 { + return errors.New("settings.auth_jwt_signing_key cannot decode to an empty value") + } + if s.LDAPBindAddress == "" { + return errors.New("settings.ldap_bind_address is required when settings.auth_enabled=true") + } + if s.LDAPBaseDN == "" { + return errors.New("settings.ldap_base_dn is required when settings.auth_enabled=true") + } + if len(s.AuthGroupRoleMappings) == 0 { + return errors.New("settings.auth_group_role_mappings must define at least one mapping when settings.auth_enabled=true") + } + + return nil +} + +func isValidAuthMode(mode string) bool { + switch mode { + case authModeDisabled, authModeOptional, authModeRequired: + return true + default: + return false + } +} + +func isValidAuthRole(role string) bool { + switch role { + case authRoleAdmin, authRoleViewer: + return true + default: + return false + } +} + +func decodeBase64(value string) ([]byte, error) { + encodings := []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } + for _, encoding := range encodings { + decoded, err := encoding.DecodeString(value) + if err == nil { + return decoded, nil + } + } + return nil, errors.New("invalid base64 encoding") +} + +func compactTrimmedStrings(values []string) []string { + if len(values) == 0 { + return nil + } + out := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + out = append(out, trimmed) + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/settings/settings_redaction_test.go b/internal/settings/settings_redaction_test.go index 9661ab5..8800328 100644 --- a/internal/settings/settings_redaction_test.go +++ b/internal/settings/settings_redaction_test.go @@ -1,6 +1,13 @@ package settings -import "testing" +import ( + "bytes" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" +) func TestRedactDatabaseURL_PostgresURI(t *testing.T) { input := "postgres://vctp_user:Secr3tP%40ss@db-host:5432/vctp?sslmode=disable" @@ -27,3 +34,29 @@ func TestRedactDatabaseURL_UnchangedWhenNoPassword(t *testing.T) { t.Fatalf("expected input to remain unchanged\nwant: %s\ngot: %s", input, got) } } + +func TestReadYMLSettingsRedactsAuthJWTSigningKey(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "vctp.yml") + content := `settings: + auth_jwt_signing_key: "c2VjcmV0" +` + if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil { + t.Fatalf("failed to write settings file: %v", err) + } + + var output bytes.Buffer + logger := slog.New(slog.NewTextHandler(&output, &slog.HandlerOptions{Level: slog.LevelDebug})) + s := New(logger, settingsPath) + if err := s.ReadYMLSettings(); err != nil { + t.Fatalf("expected settings to load, got error: %v", err) + } + + logged := output.String() + if strings.Contains(logged, "c2VjcmV0") { + t.Fatalf("expected auth_jwt_signing_key to be redacted in logs, got log output: %s", logged) + } + if !strings.Contains(logged, "REDACTED") { + t.Fatalf("expected redacted marker in logs, got log output: %s", logged) + } +} diff --git a/internal/settings/settings_strict_test.go b/internal/settings/settings_strict_test.go index 4ee54cb..be58a16 100644 --- a/internal/settings/settings_strict_test.go +++ b/internal/settings/settings_strict_test.go @@ -31,6 +31,134 @@ func TestReadYMLSettingsRejectsUnknownField(t *testing.T) { } } +func TestReadYMLSettingsAppliesAuthDefaults(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "vctp.yml") + content := `settings: + log_level: "info" +` + if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil { + t.Fatalf("failed to write settings file: %v", err) + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + s := New(logger, settingsPath) + if err := s.ReadYMLSettings(); err != nil { + t.Fatalf("expected settings to load, got error: %v", err) + } + + got := s.Values.Settings + if got.AuthMode != authModeDisabled { + t.Fatalf("expected default auth_mode=%q, got %q", authModeDisabled, got.AuthMode) + } + if got.AuthTokenLifespanMinutes != defaultAuthTokenLifespanMinutes { + t.Fatalf("expected default auth_token_lifespan_minutes=%d, got %d", defaultAuthTokenLifespanMinutes, got.AuthTokenLifespanMinutes) + } + if got.AuthJWTIssuer != defaultAuthJWTIssuer { + t.Fatalf("expected default auth_jwt_issuer=%q, got %q", defaultAuthJWTIssuer, got.AuthJWTIssuer) + } + if got.AuthJWTAudience != defaultAuthJWTAudience { + t.Fatalf("expected default auth_jwt_audience=%q, got %q", defaultAuthJWTAudience, got.AuthJWTAudience) + } + if got.AuthClockSkewSeconds != defaultAuthClockSkewSeconds { + t.Fatalf("expected default auth_clock_skew_seconds=%d, got %d", defaultAuthClockSkewSeconds, got.AuthClockSkewSeconds) + } +} + +func TestReadYMLSettingsRejectsInvalidAuthMode(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "vctp.yml") + content := `settings: + auth_mode: "sometimes" +` + if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil { + t.Fatalf("failed to write settings file: %v", err) + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + s := New(logger, settingsPath) + err := s.ReadYMLSettings() + if err == nil { + t.Fatal("expected invalid auth_mode to fail") + } + if !strings.Contains(strings.ToLower(err.Error()), "auth_mode") { + t.Fatalf("expected error to mention auth_mode, got: %v", err) + } +} + +func TestReadYMLSettingsRejectsAuthEnabledWithoutSigningKey(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "vctp.yml") + content := `settings: + auth_enabled: true + auth_mode: "required" + ldap_bind_address: "ldaps://ldap.example.com:636" + ldap_base_dn: "dc=example,dc=com" + auth_group_role_mappings: + "cn=vctp-admin,ou=groups,dc=example,dc=com": "admin" +` + if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil { + t.Fatalf("failed to write settings file: %v", err) + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + s := New(logger, settingsPath) + err := s.ReadYMLSettings() + if err == nil { + t.Fatal("expected auth_enabled=true without signing key to fail") + } + if !strings.Contains(strings.ToLower(err.Error()), "auth_jwt_signing_key") { + t.Fatalf("expected error to mention auth_jwt_signing_key, got: %v", err) + } +} + +func TestReadYMLSettingsAcceptsValidAuthConfigAndNormalizesMappings(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "vctp.yml") + content := `settings: + auth_enabled: true + auth_mode: "REQUIRED" + auth_jwt_signing_key: "c2VjcmV0" + auth_token_lifespan_minutes: 90 + auth_jwt_issuer: " custom-issuer " + auth_jwt_audience: " custom-audience " + auth_clock_skew_seconds: 15 + ldap_bind_address: "ldaps://ldap.example.com:636" + ldap_base_dn: "dc=example,dc=com" + ldap_groups: + - " cn=vctp-viewers,ou=groups,dc=example,dc=com " + auth_group_role_mappings: + " cn=vctp-admins,ou=groups,dc=example,dc=com ": " ADMIN " + "cn=vctp-viewers,ou=groups,dc=example,dc=com": "viewer" +` + if err := os.WriteFile(settingsPath, []byte(content), 0o600); err != nil { + t.Fatalf("failed to write settings file: %v", err) + } + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + s := New(logger, settingsPath) + if err := s.ReadYMLSettings(); err != nil { + t.Fatalf("expected valid auth config, got error: %v", err) + } + + got := s.Values.Settings + if got.AuthMode != authModeRequired { + t.Fatalf("expected normalized auth_mode=%q, got %q", authModeRequired, got.AuthMode) + } + if got.AuthJWTIssuer != "custom-issuer" { + t.Fatalf("expected trimmed auth_jwt_issuer, got %q", got.AuthJWTIssuer) + } + if got.AuthJWTAudience != "custom-audience" { + t.Fatalf("expected trimmed auth_jwt_audience, got %q", got.AuthJWTAudience) + } + if len(got.LDAPGroups) != 1 || got.LDAPGroups[0] != "cn=vctp-viewers,ou=groups,dc=example,dc=com" { + t.Fatalf("expected ldap_groups to be compacted+trimmed, got %#v", got.LDAPGroups) + } + if got.AuthGroupRoleMappings["cn=vctp-admins,ou=groups,dc=example,dc=com"] != authRoleAdmin { + t.Fatalf("expected admin mapping to normalize role to %q, got %#v", authRoleAdmin, got.AuthGroupRoleMappings) + } +} + func TestSecureSettingsFileMode(t *testing.T) { cases := []struct { name string diff --git a/server/handler/auth.go b/server/handler/auth.go new file mode 100644 index 0000000..3da4328 --- /dev/null +++ b/server/handler/auth.go @@ -0,0 +1,146 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + "vctp/internal/auth" + "vctp/server/models" +) + +const ( + authLoginFailureMessage = "invalid username or password" + authLoginRequestTimeout = 30 * time.Second +) + +type ldapAuthenticator interface { + AuthenticateAndFetchGroups(ctx context.Context, username string, password string) (auth.LDAPIdentity, error) +} + +type jwtService interface { + IssueToken(subject string, roles []string, groups []string) (string, auth.Claims, error) +} + +var newLDAPAuthenticator = func(cfg auth.LDAPConfig) (ldapAuthenticator, error) { + return auth.NewLDAPAuthenticator(cfg) +} + +var newJWTService = func(cfg auth.JWTConfig) (jwtService, error) { + return auth.NewJWTService(cfg) +} + +// AuthLogin authenticates a user against LDAP and returns a signed JWT. +// @Summary Login +// @Description Authenticates a username/password against LDAP and returns a signed access token. +// @Tags auth +// @Accept json +// @Produce json +// @Param payload body models.AuthLoginRequest true "Login credentials" +// @Success 200 {object} models.AuthLoginResponse "Login success" +// @Failure 400 {object} models.ErrorResponse "Invalid request" +// @Failure 401 {object} models.ErrorResponse "Invalid credentials" +// @Failure 500 {object} models.ErrorResponse "Server error" +// @Failure 503 {object} models.ErrorResponse "Authentication disabled" +// @Router /api/auth/login [post] +func (h *Handler) AuthLogin(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if h == nil || h.Settings == nil || h.Settings.Values == nil { + writeJSONError(w, http.StatusInternalServerError, "authentication is not configured") + return + } + + cfg := h.Settings.Values.Settings + if !cfg.AuthEnabled { + writeJSONError(w, http.StatusServiceUnavailable, "authentication is disabled") + return + } + + var req models.AuthLoginRequest + if err := decodeJSONBody(w, r, &req); err != nil { + h.Logger.Error("unable to decode auth login request", "error", err) + writeJSONError(w, http.StatusBadRequest, "invalid JSON body") + return + } + username := strings.TrimSpace(req.Username) + password := req.Password + if username == "" || strings.TrimSpace(password) == "" { + writeJSONError(w, http.StatusBadRequest, "username and password are required") + return + } + + ldapAuth, err := newLDAPAuthenticator(auth.LDAPConfig{ + BindAddress: cfg.LDAPBindAddress, + BaseDN: cfg.LDAPBaseDN, + TrustCertFile: cfg.LDAPTrustCertFile, + DisableValidation: cfg.LDAPDisableValidation, + Insecure: cfg.LDAPInsecure, + DialTimeout: authLoginRequestTimeout, + }) + if err != nil { + h.Logger.Error("failed to initialize ldap authenticator", "error", err) + writeJSONError(w, http.StatusInternalServerError, "authentication service unavailable") + return + } + + ctx, cancel := withRequestTimeout(r, authLoginRequestTimeout) + defer cancel() + identity, err := ldapAuth.AuthenticateAndFetchGroups(ctx, username, password) + if err != nil { + if errors.Is(err, auth.ErrLDAPInvalidCredentials) { + h.Logger.Warn("auth login rejected", "username", username, "reason", "invalid_credentials") + writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage) + return + } + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + h.Logger.Warn("auth login ldap timeout", "username", username, "error", err) + writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage) + return + } + h.Logger.Warn("auth login ldap failure", "username", username, "error", err) + writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage) + return + } + + roles := auth.ResolveRoles(identity.Groups, cfg.AuthGroupRoleMappings) + if !auth.HasAnyGroup(identity.Groups, cfg.LDAPGroups) || len(roles) == 0 { + h.Logger.Warn("auth login rejected", "username", username, "reason", "group_or_role_denied") + writeJSONError(w, http.StatusUnauthorized, authLoginFailureMessage) + return + } + + jwtSvc, err := newJWTService(auth.JWTConfig{ + SigningKeyBase64: cfg.AuthJWTSigningKey, + Issuer: cfg.AuthJWTIssuer, + Audience: cfg.AuthJWTAudience, + TokenLifespan: time.Duration(cfg.AuthTokenLifespanMinutes) * time.Minute, + ClockSkew: time.Duration(cfg.AuthClockSkewSeconds) * time.Second, + }) + if err != nil { + h.Logger.Error("failed to initialize jwt service", "error", err) + writeJSONError(w, http.StatusInternalServerError, "authentication service unavailable") + return + } + + subject := strings.TrimSpace(identity.Username) + if subject == "" { + subject = username + } + token, claims, err := jwtSvc.IssueToken(subject, roles, identity.Groups) + if err != nil { + h.Logger.Error("failed to issue auth token", "username", username, "error", err) + writeJSONError(w, http.StatusInternalServerError, "failed to issue access token") + return + } + + h.Logger.Info("auth login successful", "username", subject, "roles", roles) + writeJSON(w, http.StatusOK, models.AuthLoginResponse{ + AccessToken: token, + ExpiresAt: claims.ExpiresAt, + TokenType: "Bearer", + }) +} diff --git a/server/handler/auth_test.go b/server/handler/auth_test.go new file mode 100644 index 0000000..eb6c903 --- /dev/null +++ b/server/handler/auth_test.go @@ -0,0 +1,219 @@ +package handler + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + "vctp/internal/auth" + "vctp/internal/settings" + "vctp/server/models" +) + +type stubLDAPAuthenticator struct { + identity auth.LDAPIdentity + err error +} + +func (s *stubLDAPAuthenticator) AuthenticateAndFetchGroups(_ context.Context, _ string, _ string) (auth.LDAPIdentity, error) { + return s.identity, s.err +} + +type stubJWTService struct { + token string + claims auth.Claims + err error +} + +func (s *stubJWTService) IssueToken(_ string, _ []string, _ []string) (string, auth.Claims, error) { + return s.token, s.claims, s.err +} + +func TestAuthLoginAuthDisabled(t *testing.T) { + h := &Handler{ + Logger: newTestLogger(), + Settings: &settings.Settings{Values: &settings.SettingsYML{}}, + } + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`)) + rr := httptest.NewRecorder() + h.AuthLogin(rr, req) + + if rr.Code != http.StatusServiceUnavailable { + t.Fatalf("expected status %d, got %d", http.StatusServiceUnavailable, rr.Code) + } +} + +func TestAuthLoginInvalidCredentials(t *testing.T) { + restoreFactories := swapAuthFactoriesForTest( + func(_ auth.LDAPConfig) (ldapAuthenticator, error) { + return &stubLDAPAuthenticator{err: auth.ErrLDAPInvalidCredentials}, nil + }, + func(_ auth.JWTConfig) (jwtService, error) { + return &stubJWTService{}, nil + }, + ) + defer restoreFactories() + + h := &Handler{ + Logger: newTestLogger(), + Settings: testAuthEnabledSettings(), + } + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`)) + rr := httptest.NewRecorder() + h.AuthLogin(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) + } + var payload models.ErrorResponse + if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if payload.Message != authLoginFailureMessage { + t.Fatalf("unexpected error message: %q", payload.Message) + } +} + +func TestAuthLoginRejectsUnmappedRoles(t *testing.T) { + restoreFactories := swapAuthFactoriesForTest( + func(_ auth.LDAPConfig) (ldapAuthenticator, error) { + return &stubLDAPAuthenticator{ + identity: auth.LDAPIdentity{ + Username: "alice", + Groups: []string{"cn=other-group,ou=groups,dc=example,dc=com"}, + }, + }, nil + }, + func(_ auth.JWTConfig) (jwtService, error) { + return &stubJWTService{}, nil + }, + ) + defer restoreFactories() + + h := &Handler{ + Logger: newTestLogger(), + Settings: testAuthEnabledSettings(), + } + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`)) + rr := httptest.NewRecorder() + h.AuthLogin(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) + } +} + +func TestAuthLoginSuccess(t *testing.T) { + restoreFactories := swapAuthFactoriesForTest( + func(_ auth.LDAPConfig) (ldapAuthenticator, error) { + return &stubLDAPAuthenticator{ + identity: auth.LDAPIdentity{ + Username: "alice", + UserDN: "cn=alice,ou=users,dc=example,dc=com", + Groups: []string{"cn=vctp-admins,ou=groups,dc=example,dc=com"}, + }, + }, nil + }, + func(_ auth.JWTConfig) (jwtService, error) { + return &stubJWTService{ + token: "issued-token", + claims: auth.Claims{ + ExpiresAt: time.Unix(1_700_000_000, 0).Unix(), + }, + }, nil + }, + ) + defer restoreFactories() + + h := &Handler{ + Logger: newTestLogger(), + Settings: testAuthEnabledSettings(), + } + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`)) + rr := httptest.NewRecorder() + h.AuthLogin(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, rr.Code, rr.Body.String()) + } + var payload models.AuthLoginResponse + if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if payload.AccessToken != "issued-token" { + t.Fatalf("unexpected token: %q", payload.AccessToken) + } + if payload.TokenType != "Bearer" { + t.Fatalf("unexpected token type: %q", payload.TokenType) + } +} + +func TestAuthLoginJWTFactoryFailure(t *testing.T) { + restoreFactories := swapAuthFactoriesForTest( + func(_ auth.LDAPConfig) (ldapAuthenticator, error) { + return &stubLDAPAuthenticator{ + identity: auth.LDAPIdentity{ + Username: "alice", + Groups: []string{"cn=vctp-admins,ou=groups,dc=example,dc=com"}, + }, + }, nil + }, + func(_ auth.JWTConfig) (jwtService, error) { + return nil, errors.New("jwt init failed") + }, + ) + defer restoreFactories() + + h := &Handler{ + Logger: newTestLogger(), + Settings: testAuthEnabledSettings(), + } + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"alice","password":"pw"}`)) + rr := httptest.NewRecorder() + h.AuthLogin(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Fatalf("expected status %d, got %d", http.StatusInternalServerError, rr.Code) + } +} + +func testAuthEnabledSettings() *settings.Settings { + cfg := &settings.Settings{Values: &settings.SettingsYML{}} + cfg.Values.Settings.AuthEnabled = true + cfg.Values.Settings.AuthMode = "required" + cfg.Values.Settings.AuthJWTSigningKey = base64.StdEncoding.EncodeToString([]byte("test-signing-key")) + cfg.Values.Settings.AuthTokenLifespanMinutes = 120 + cfg.Values.Settings.AuthJWTIssuer = "vctp" + cfg.Values.Settings.AuthJWTAudience = "vctp-api" + cfg.Values.Settings.AuthClockSkewSeconds = 60 + cfg.Values.Settings.LDAPBindAddress = "ldaps://ldap.example.com:636" + cfg.Values.Settings.LDAPBaseDN = "dc=example,dc=com" + cfg.Values.Settings.AuthGroupRoleMappings = map[string]string{ + "cn=vctp-admins,ou=groups,dc=example,dc=com": "admin", + } + return cfg +} + +func swapAuthFactoriesForTest( + ldapFactory func(auth.LDAPConfig) (ldapAuthenticator, error), + jwtFactory func(auth.JWTConfig) (jwtService, error), +) func() { + origLDAPFactory := newLDAPAuthenticator + origJWTFactory := newJWTService + newLDAPAuthenticator = ldapFactory + newJWTService = jwtFactory + return func() { + newLDAPAuthenticator = origLDAPFactory + newJWTService = origJWTFactory + } +} diff --git a/server/handler/method_guards_test.go b/server/handler/method_guards_test.go index 2e98309..cab8d32 100644 --- a/server/handler/method_guards_test.go +++ b/server/handler/method_guards_test.go @@ -18,6 +18,13 @@ func TestMutatingHandlersRejectWrongMethod(t *testing.T) { path string call func(*Handler, *httptest.ResponseRecorder, *http.Request) }{ + { + name: "auth login", + path: "/api/auth/login", + call: func(h *Handler, rr *httptest.ResponseRecorder, req *http.Request) { + h.AuthLogin(rr, req) + }, + }, { name: "snapshot force hourly", path: "/api/snapshots/hourly/force", diff --git a/server/middleware/auth.go b/server/middleware/auth.go new file mode 100644 index 0000000..c760ad9 --- /dev/null +++ b/server/middleware/auth.go @@ -0,0 +1,206 @@ +package middleware + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "strings" + "time" + "vctp/internal/auth" + "vctp/internal/settings" +) + +const ( + authModeDisabled = "disabled" + authModeOptional = "optional" + authModeRequired = "required" + + RoleViewer = "viewer" + RoleAdmin = "admin" +) + +type authClaimsContextKey struct{} + +// ClaimsFromContext returns validated JWT claims injected by RequireAuth. +func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) { + if ctx == nil { + return auth.Claims{}, false + } + claims, ok := ctx.Value(authClaimsContextKey{}).(auth.Claims) + return claims, ok +} + +// RequireAuth validates Bearer tokens according to settings.auth_mode: +// - disabled: auth bypassed +// - optional: missing token allowed, provided token must be valid +// - required: token required and must be valid +func RequireAuth(logger *slog.Logger, cfg *settings.Settings) Handler { + if logger == nil { + logger = slog.Default() + } + if cfg == nil || cfg.Values == nil { + return defaultHandler + } + + values := cfg.Values.Settings + mode := strings.ToLower(strings.TrimSpace(values.AuthMode)) + if mode == "" { + mode = authModeDisabled + } + if !values.AuthEnabled || mode == authModeDisabled { + return defaultHandler + } + + jwtSvc, err := auth.NewJWTService(auth.JWTConfig{ + SigningKeyBase64: values.AuthJWTSigningKey, + Issuer: values.AuthJWTIssuer, + Audience: values.AuthJWTAudience, + TokenLifespan: time.Duration(values.AuthTokenLifespanMinutes) * time.Minute, + ClockSkew: time.Duration(values.AuthClockSkewSeconds) * time.Second, + }) + if err != nil { + logger.Error("auth middleware init failed", "error", err) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSONAuthError(w, http.StatusServiceUnavailable, "authentication service unavailable") + }) + } + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, hasHeader, parseOK := extractBearerToken(r.Header.Get("Authorization")) + if !hasHeader { + if mode == authModeRequired { + writeJSONAuthError(w, http.StatusUnauthorized, "missing bearer token") + return + } + next.ServeHTTP(w, r) + return + } + if !parseOK { + writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token") + return + } + + claims, err := jwtSvc.VerifyToken(token) + if err != nil { + logger.Warn("auth middleware token validation failed", "path", r.URL.Path, "error", err) + writeJSONAuthError(w, http.StatusUnauthorized, "invalid bearer token") + return + } + + ctx := context.WithValue(r.Context(), authClaimsContextKey{}, claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// RequireRole checks JWT claims injected by RequireAuth and enforces role policy. +// Returns: +// - 401 when no validated auth claims are present +// - 403 when claims are present but missing required role(s) +func RequireRole(requiredRoles ...string) Handler { + normalizedRequired := normalizeRoles(requiredRoles) + if len(normalizedRequired) == 0 { + return defaultHandler + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, ok := ClaimsFromContext(r.Context()) + if !ok { + writeJSONAuthError(w, http.StatusUnauthorized, "missing authentication context") + return + } + if !hasAnyRequiredRole(claims.Roles, normalizedRequired) { + writeJSONAuthError(w, http.StatusForbidden, "insufficient role") + return + } + next.ServeHTTP(w, r) + }) + } +} + +func writeJSONAuthError(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "ERROR", + "message": message, + }) +} + +func extractBearerToken(headerValue string) (token string, hasHeader bool, ok bool) { + headerValue = strings.TrimSpace(headerValue) + if headerValue == "" { + return "", false, false + } + parts := strings.Fields(headerValue) + if len(parts) != 2 { + return "", true, false + } + if !strings.EqualFold(parts[0], "Bearer") { + return "", true, false + } + token = strings.TrimSpace(parts[1]) + if token == "" { + return "", true, false + } + return token, true, true +} + +func normalizeRoles(roles []string) []string { + if len(roles) == 0 { + return nil + } + seen := make(map[string]struct{}, len(roles)) + out := make([]string, 0, len(roles)) + for _, role := range roles { + role = strings.ToLower(strings.TrimSpace(role)) + if role == "" { + continue + } + if _, ok := seen[role]; ok { + continue + } + seen[role] = struct{}{} + out = append(out, role) + } + if len(out) == 0 { + return nil + } + return out +} + +func hasAnyRequiredRole(userRoles []string, requiredRoles []string) bool { + if len(requiredRoles) == 0 { + return true + } + userRoleSet := make(map[string]struct{}, len(userRoles)) + for _, role := range normalizeRoles(userRoles) { + userRoleSet[role] = struct{}{} + } + if len(userRoleSet) == 0 { + return false + } + for _, requiredRole := range requiredRoles { + if hasRoleWithHierarchy(userRoleSet, requiredRole) { + return true + } + } + return false +} + +func hasRoleWithHierarchy(userRoleSet map[string]struct{}, requiredRole string) bool { + if _, ok := userRoleSet[requiredRole]; ok { + return true + } + // Admin implies viewer access. + if requiredRole == RoleViewer { + _, ok := userRoleSet[RoleAdmin] + return ok + } + return false +} diff --git a/server/middleware/auth_test.go b/server/middleware/auth_test.go new file mode 100644 index 0000000..d69cd85 --- /dev/null +++ b/server/middleware/auth_test.go @@ -0,0 +1,201 @@ +package middleware + +import ( + "encoding/base64" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + "vctp/internal/auth" + "vctp/internal/settings" +) + +func TestRequireAuthRequiredRejectsMissingToken(t *testing.T) { + cfg := testAuthSettings(true, authModeRequired) + mw := RequireAuth(testLogger(), cfg) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) + mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) + } +} + +func TestRequireAuthRequiredAcceptsValidTokenAndInjectsClaims(t *testing.T) { + cfg := testAuthSettings(true, authModeRequired) + token := mustTokenForConfig(t, cfg, "alice") + mw := RequireAuth(testLogger(), cfg) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) + req.Header.Set("Authorization", "Bearer "+token) + + var gotSubject string + mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, ok := ClaimsFromContext(r.Context()) + if !ok { + t.Fatal("expected claims in request context") + } + gotSubject = claims.Subject + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) + } + if gotSubject != "alice" { + t.Fatalf("expected subject alice, got %q", gotSubject) + } +} + +func TestRequireAuthOptionalAllowsNoToken(t *testing.T) { + cfg := testAuthSettings(true, authModeOptional) + mw := RequireAuth(testLogger(), cfg) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) + mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })).ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("expected status %d, got %d", http.StatusNoContent, rr.Code) + } +} + +func TestRequireAuthOptionalRejectsInvalidProvidedToken(t *testing.T) { + cfg := testAuthSettings(true, authModeOptional) + mw := RequireAuth(testLogger(), cfg) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) + req.Header.Set("Authorization", "Bearer not-a-jwt") + mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })).ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) + } +} + +func TestRequireAuthDisabledBypassesMiddleware(t *testing.T) { + cfg := testAuthSettings(false, authModeDisabled) + mw := RequireAuth(testLogger(), cfg) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) + mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + })).ServeHTTP(rr, req) + + if rr.Code != http.StatusAccepted { + t.Fatalf("expected status %d, got %d", http.StatusAccepted, rr.Code) + } +} + +func TestRequireRoleRejectsMissingAuthContext(t *testing.T) { + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) + RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) + } +} + +func TestRequireRoleRejectsInsufficientRole(t *testing.T) { + cfg := testAuthSettings(true, authModeRequired) + token := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer}) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) + req.Header.Set("Authorization", "Bearer "+token) + + protected := RequireAuth(testLogger(), cfg)( + RequireRole(RoleAdmin)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })), + ) + protected.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("expected status %d, got %d", http.StatusForbidden, rr.Code) + } +} + +func TestRequireRoleViewerAllowsViewerAndAdmin(t *testing.T) { + cfg := testAuthSettings(true, authModeRequired) + viewerToken := mustTokenForConfigWithRoles(t, cfg, "alice", []string{RoleViewer}) + adminToken := mustTokenForConfigWithRoles(t, cfg, "bob", []string{RoleAdmin}) + + for name, token := range map[string]string{ + "viewer": viewerToken, + "admin": adminToken, + } { + t.Run(name, func(t *testing.T) { + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sensitive", nil) + req.Header.Set("Authorization", "Bearer "+token) + + protected := RequireAuth(testLogger(), cfg)( + RequireRole(RoleViewer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })), + ) + protected.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) + } + }) + } +} + +func mustTokenForConfig(t *testing.T, cfg *settings.Settings, subject string) string { + t.Helper() + return mustTokenForConfigWithRoles(t, cfg, subject, []string{"admin"}) +} + +func mustTokenForConfigWithRoles(t *testing.T, cfg *settings.Settings, subject string, roles []string) string { + t.Helper() + svc, err := auth.NewJWTService(auth.JWTConfig{ + SigningKeyBase64: cfg.Values.Settings.AuthJWTSigningKey, + Issuer: cfg.Values.Settings.AuthJWTIssuer, + Audience: cfg.Values.Settings.AuthJWTAudience, + TokenLifespan: time.Duration(cfg.Values.Settings.AuthTokenLifespanMinutes) * time.Minute, + ClockSkew: time.Duration(cfg.Values.Settings.AuthClockSkewSeconds) * time.Second, + }) + if err != nil { + t.Fatalf("failed to create jwt service: %v", err) + } + token, _, err := svc.IssueToken(subject, roles, nil) + if err != nil { + t.Fatalf("failed to issue token: %v", err) + } + return token +} + +func testAuthSettings(enabled bool, mode string) *settings.Settings { + cfg := &settings.Settings{Values: &settings.SettingsYML{}} + cfg.Values.Settings.AuthEnabled = enabled + cfg.Values.Settings.AuthMode = mode + cfg.Values.Settings.AuthJWTSigningKey = base64.StdEncoding.EncodeToString([]byte("middleware-test-signing-key")) + cfg.Values.Settings.AuthTokenLifespanMinutes = 120 + cfg.Values.Settings.AuthJWTIssuer = "vctp" + cfg.Values.Settings.AuthJWTAudience = "vctp-api" + cfg.Values.Settings.AuthClockSkewSeconds = 60 + return cfg +} + +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} diff --git a/server/models/api_responses.go b/server/models/api_responses.go index 0d18987..a783f09 100644 --- a/server/models/api_responses.go +++ b/server/models/api_responses.go @@ -17,6 +17,19 @@ type ErrorResponse struct { Message string `json:"message"` } +// AuthLoginRequest represents login payload for LDAP/JWT authentication. +type AuthLoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// AuthLoginResponse represents successful auth login response. +type AuthLoginResponse struct { + AccessToken string `json:"access_token"` + ExpiresAt int64 `json:"expires_at"` + TokenType string `json:"token_type"` +} + // SnapshotMigrationStats mirrors the snapshot registry migration stats payload. type SnapshotMigrationStats struct { HourlyRenamed int `json:"HourlyRenamed"` diff --git a/server/router/router.go b/server/router/router.go index 2d4fd6a..9dcf3c2 100644 --- a/server/router/router.go +++ b/server/router/router.go @@ -29,6 +29,14 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st } mux := http.NewServeMux() + requireAuth := middleware.RequireAuth(logger, settings) + withAuthRole := func(next http.HandlerFunc, roles ...string) http.Handler { + wrapped := http.Handler(http.HandlerFunc(next)) + if len(roles) > 0 { + wrapped = middleware.RequireRole(roles...)(wrapped) + } + return requireAuth(wrapped) + } reportsDir := settings.Values.Settings.ReportsDir if reportsDir == "" { @@ -44,37 +52,38 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st mux.Handle("/favicon-32x32.png", middleware.CacheMiddleware(http.FileServer(http.FS(dist.AssetsDir)))) mux.Handle("/reports/", http.StripPrefix("/reports/", http.FileServer(http.Dir(filepath.Clean(reportsDir))))) mux.HandleFunc("/", h.Home) - mux.HandleFunc("/api/event/vm/create", h.VmCreateEvent) - mux.HandleFunc("/api/event/vm/modify", h.VmModifyEvent) - mux.HandleFunc("/api/event/vm/move", h.VmMoveEvent) - mux.HandleFunc("/api/event/vm/delete", h.VmDeleteEvent) - mux.HandleFunc("/api/import/vm", h.VmImport) + mux.Handle("/api/event/vm/create", withAuthRole(h.VmCreateEvent, middleware.RoleAdmin)) + mux.Handle("/api/event/vm/modify", withAuthRole(h.VmModifyEvent, middleware.RoleAdmin)) + mux.Handle("/api/event/vm/move", withAuthRole(h.VmMoveEvent, middleware.RoleAdmin)) + mux.Handle("/api/event/vm/delete", withAuthRole(h.VmDeleteEvent, middleware.RoleAdmin)) + mux.Handle("/api/import/vm", withAuthRole(h.VmImport, middleware.RoleAdmin)) // Use this when we need to manually remove a VM from the database to clean up - mux.HandleFunc("/api/inventory/vm/delete", h.VmCleanup) + mux.Handle("/api/inventory/vm/delete", withAuthRole(h.VmCleanup, middleware.RoleAdmin)) // add missing data to VMs - mux.HandleFunc("/api/inventory/vm/update", h.VmUpdateDetails) + mux.Handle("/api/inventory/vm/update", withAuthRole(h.VmUpdateDetails, middleware.RoleAdmin)) // Legacy/maintenance endpoints are gated by settings.enable_legacy_api. - mux.HandleFunc("/api/cleanup/updates", h.UpdateCleanup) + mux.Handle("/api/cleanup/updates", withAuthRole(h.UpdateCleanup, middleware.RoleAdmin)) //mux.HandleFunc("/api/cleanup/vcenter", h.VcCleanup) - mux.HandleFunc("/api/report/inventory", h.InventoryReportDownload) - mux.HandleFunc("/api/report/updates", h.UpdateReportDownload) - mux.HandleFunc("/api/report/snapshot", h.SnapshotReportDownload) - mux.HandleFunc("/api/snapshots/aggregate", h.SnapshotAggregateForce) - mux.HandleFunc("/api/snapshots/hourly/force", h.SnapshotForceHourly) - mux.HandleFunc("/api/snapshots/migrate", h.SnapshotMigrate) - mux.HandleFunc("/api/snapshots/repair", h.SnapshotRepair) - mux.HandleFunc("/api/snapshots/repair/all", h.SnapshotRepairSuite) - mux.HandleFunc("/api/snapshots/regenerate-hourly-reports", h.SnapshotRegenerateHourlyReports) - mux.HandleFunc("/api/diagnostics/daily-creation", h.DailyCreationDiagnostics) + mux.Handle("/api/report/inventory", withAuthRole(h.InventoryReportDownload, middleware.RoleViewer)) + mux.Handle("/api/report/updates", withAuthRole(h.UpdateReportDownload, middleware.RoleViewer)) + mux.Handle("/api/report/snapshot", withAuthRole(h.SnapshotReportDownload, middleware.RoleViewer)) + mux.Handle("/api/snapshots/aggregate", withAuthRole(h.SnapshotAggregateForce, middleware.RoleAdmin)) + mux.Handle("/api/snapshots/hourly/force", withAuthRole(h.SnapshotForceHourly, middleware.RoleAdmin)) + mux.Handle("/api/snapshots/migrate", withAuthRole(h.SnapshotMigrate, middleware.RoleAdmin)) + mux.Handle("/api/snapshots/repair", withAuthRole(h.SnapshotRepair, middleware.RoleAdmin)) + mux.Handle("/api/snapshots/repair/all", withAuthRole(h.SnapshotRepairSuite, middleware.RoleAdmin)) + mux.Handle("/api/snapshots/regenerate-hourly-reports", withAuthRole(h.SnapshotRegenerateHourlyReports, middleware.RoleAdmin)) + mux.Handle("/api/diagnostics/daily-creation", withAuthRole(h.DailyCreationDiagnostics, middleware.RoleViewer)) + mux.HandleFunc("/api/auth/login", h.AuthLogin) mux.HandleFunc("/vm/trace", h.VmTrace) mux.HandleFunc("/vcenters", h.VcenterList) mux.HandleFunc("/vcenters/totals", h.VcenterTotals) mux.HandleFunc("/vcenters/totals/daily", h.VcenterTotalsDaily) mux.HandleFunc("/vcenters/totals/hourly", h.VcenterTotalsHourlyDetailed) - mux.HandleFunc("/api/vcenters/cache/rebuild", h.VcenterCacheRebuild) + mux.Handle("/api/vcenters/cache/rebuild", withAuthRole(h.VcenterCacheRebuild, middleware.RoleAdmin)) mux.HandleFunc("/metrics", h.Metrics) mux.HandleFunc("/snapshots/hourly", h.SnapshotHourlyList) @@ -82,7 +91,7 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st mux.HandleFunc("/snapshots/monthly", h.SnapshotMonthlyList) // endpoint for encrypting vcenter credential - mux.HandleFunc("/api/encrypt", h.EncryptData) + mux.Handle("/api/encrypt", withAuthRole(h.EncryptData, middleware.RoleAdmin)) // serve swagger related components from the embedded fs swaggerSub, err := fs.Sub(swaggerUI, "swagger-ui-dist") @@ -100,12 +109,14 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st w.Write(swaggerSpec) }))) - // Register pprof handlers - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + // Register pprof handlers only when enabled, and gate them behind admin auth. + if settings.Values.Settings.EnablePprof { + mux.Handle("/debug/pprof/", withAuthRole(pprof.Index, middleware.RoleAdmin)) + mux.Handle("/debug/pprof/cmdline", withAuthRole(pprof.Cmdline, middleware.RoleAdmin)) + mux.Handle("/debug/pprof/profile", withAuthRole(pprof.Profile, middleware.RoleAdmin)) + mux.Handle("/debug/pprof/symbol", withAuthRole(pprof.Symbol, middleware.RoleAdmin)) + mux.Handle("/debug/pprof/trace", withAuthRole(pprof.Trace, middleware.RoleAdmin)) + } return middleware.NewLoggingMiddleware(logger, mux) } diff --git a/src/postinstall.sh b/src/postinstall.sh index 45f0fec..2a8bc39 100644 --- a/src/postinstall.sh +++ b/src/postinstall.sh @@ -98,7 +98,116 @@ merge_missing_settings_from_rpmnew() { rm -f "$src_pairs" "$target_pairs" "$missing_lines" "$merged_file" } +generate_random_auth_jwt_key() { + if command -v openssl >/dev/null 2>&1; then + openssl rand -base64 32 2>/dev/null | tr -d '\n' + return 0 + fi + if command -v base64 >/dev/null 2>&1; then + head -c 32 /dev/urandom | base64 | tr -d '\n' + return 0 + fi + return 1 +} + +auth_jwt_key_is_set() { + local target="$1" + [ -f "$target" ] || return 1 + + local extracted + extracted="$(awk ' + /^settings:[[:space:]]*$/ { in_settings = 1; next } + in_settings && /^[^[:space:]]/ { in_settings = 0 } + in_settings && $0 ~ /^ auth_jwt_signing_key:[[:space:]]*/ { + value = $0 + sub(/^[[:space:]]*auth_jwt_signing_key:[[:space:]]*/, "", value) + sub(/[[:space:]]*#.*/, "", value) + gsub(/^[[:space:]]+|[[:space:]]+$/, "", value) + gsub(/^["'\'']|["'\'']$/, "", value) + print value + exit + } + ' "$target")" + + [ -n "$extracted" ] +} + +set_auth_jwt_key() { + local target="$1" + local jwt_key="$2" + local updated_file + + [ -f "$target" ] || return 1 + updated_file="$(mktemp /tmp/vctp-postinstall-authkey-XXXXXX)" || return 1 + + if awk -v new_key="$jwt_key" ' + BEGIN { in_settings = 0; replaced = 0; inserted = 0 } + { + if ($0 ~ /^settings:[[:space:]]*$/) { + in_settings = 1 + print + next + } + + if (in_settings && $0 ~ /^ auth_jwt_signing_key:[[:space:]]*/) { + print " auth_jwt_signing_key: \"" new_key "\"" + replaced = 1 + next + } + + if (in_settings && $0 ~ /^[^[:space:]]/) { + if (!replaced && !inserted) { + print " auth_jwt_signing_key: \"" new_key "\"" + inserted = 1 + } + in_settings = 0 + } + + print + } + END { + if (in_settings && !replaced && !inserted) { + print " auth_jwt_signing_key: \"" new_key "\"" + } + } + ' "$target" > "$updated_file"; then + cat "$updated_file" > "$target" + rm -f "$updated_file" + return 0 + fi + + rm -f "$updated_file" + return 1 +} + +ensure_auth_jwt_key_in_settings() { + local target="$1" + [ -f "$target" ] || return 0 + + if auth_jwt_key_is_set "$target"; then + return 0 + fi + + local generated + generated="$(generate_random_auth_jwt_key)" || { + echo "vCTP postinstall: unable to generate auth_jwt_signing_key (openssl/base64 unavailable)" + return 0 + } + + if [ -z "$generated" ]; then + echo "vCTP postinstall: unable to generate auth_jwt_signing_key (empty key)" + return 0 + fi + + if set_auth_jwt_key "$target" "$generated"; then + echo "vCTP postinstall: generated and set settings.auth_jwt_signing_key in ${target}" + else + echo "vCTP postinstall: failed to write settings.auth_jwt_signing_key in ${target}" + fi +} + merge_missing_settings_from_rpmnew "$TARGET_CFG" "$SOURCE_CFG" || : +ensure_auth_jwt_key_in_settings "$TARGET_CFG" || : if [ -f "$TARGET_CFG" ]; then chown root:dtms "$TARGET_CFG" || : diff --git a/src/vctp.yml b/src/vctp.yml index cd0f34c..af49ba1 100644 --- a/src/vctp.yml +++ b/src/vctp.yml @@ -19,6 +19,21 @@ settings: vcenter_insecure: false # Legacy API endpoints are disabled by default. enable_legacy_api: false + auth_enabled: false + auth_mode: "disabled" # disabled | optional | required + auth_jwt_signing_key: "" # base64-encoded key, required when auth_enabled=true + auth_token_lifespan_minutes: 120 + auth_jwt_issuer: "vctp" + auth_jwt_audience: "vctp-api" + auth_clock_skew_seconds: 60 + auth_group_role_mappings: {} + ldap_groups: [] + ldap_bind_address: "" + ldap_base_dn: "" + ldap_trust_cert_file: "" + ldap_disable_validation: false + ldap_insecure: false + enable_pprof: false # Deprecated (ignored): legacy event poller vcenter_event_polling_seconds: 0 # Deprecated (ignored): legacy inventory poller diff --git a/todo.md b/todo.md new file mode 100644 index 0000000..31fa32e --- /dev/null +++ b/todo.md @@ -0,0 +1,156 @@ +# VCTP Auth Design TODO (LDAP + JWT) + +## 1. Goal +Add authentication and authorization to VCTP for sensitive endpoints, using the LDAP bind + JWT pattern from `cbs2` as a reference, adapted to VCTP's `net/http` architecture. + +## 2. Reference Findings from `cbs2` + +### 2.1 Where auth lives in `cbs2` +- Login handler: `/tmp/cbs2/server/handler/auth.go` +- JWT middleware: `/tmp/cbs2/server/handler/middlewares.go` +- Token utilities: `/tmp/cbs2/utils/token/token.go` +- LDAP bind + group lookup: `/tmp/cbs2/internal/ldap/ldap.go` +- Route protection split (public vs protected): `/tmp/cbs2/server/router/router.go` +- Settings fields for LDAP/JWT: `/tmp/cbs2/internal/settings/settings.go` and `/tmp/cbs2/src/cbs.yml` + +### 2.2 Pattern to reuse +- LDAP username/password bind for authentication. +- LDAP group membership check for authorization at login. +- Signed JWT access token with expiry. +- Middleware that validates `Authorization: Bearer `. +- Router-level grouping for protected routes. + +### 2.3 Things to improve (do not copy as-is) +- Avoid hardcoded fallback JWT secret (present in `cbs2` middleware). +- Avoid logging sensitive token/key values. +- Avoid weak/ambiguous claim model; use explicit issuer/audience/subject/exp/iat. +- Keep strict method/endpoint policy in one place instead of ad-hoc checks. + +## 3. Proposed VCTP Auth Architecture + +### 3.1 New packages/modules +1. `internal/auth/ldap.go` +- LDAP setup/connection helpers. +- `AuthenticateAndFetchGroups(username, password) ([]string, error)`. + +2. `internal/auth/jwt.go` +- JWT issue and verify. +- Claims struct with: + - `sub` (username) + - `roles` (derived roles) + - `groups` (optional raw LDAP groups) + - `iss`, `aud`, `iat`, `exp`, `nbf`, `jti` + +3. `server/middleware/auth.go` +- `RequireAuth(...)` for token validation. +- `RequireRole(...)` for endpoint authorization. +- Context injection for user identity and roles. + +4. `server/handler/auth.go` +- `POST /api/auth/login` +- Optional `GET /api/auth/me` for debugging/whoami. + +### 3.2 Route protection model +Define a central policy map in router startup (single source of truth): + +- Public (no auth): + - `/assets/*`, `/favicon*`, `/swagger*` (optional decision), `/` (optional decision) +- Authenticated read-only (viewer role): + - `/vcenters*`, `/snapshots/*`, `/vm/trace`, `/api/report/*`, `/metrics` (optional decision) +- Privileged write/admin (admin role): + - `/api/snapshots/*` mutating endpoints + - `/api/vcenters/cache/rebuild` + - `/api/encrypt` + - legacy mutating endpoints (`/api/event/*`, `/api/import/vm`, `/api/inventory/vm/*`, `/api/cleanup/*`) +- Debug endpoints (`/debug/pprof/*`): + - disabled by default via config + - if enabled, require admin + +## 4. VCTP Settings Additions +Add under `settings:` in `internal/settings/settings.go` and `src/vctp.yml`: + +1. `auth_enabled: false` +2. `auth_mode: "disabled"` (`disabled|optional|required`) +3. `auth_jwt_signing_key: ""` (base64-encoded, required when auth enabled) +4. `auth_token_lifespan_minutes: 120` +5. `auth_jwt_issuer: "vctp"` +6. `auth_jwt_audience: "vctp-api"` +7. `auth_clock_skew_seconds: 60` +8. `ldap_groups: []` +9. `ldap_bind_address: ""` +10. `ldap_base_dn: ""` +11. `ldap_trust_cert_file: ""` +12. `ldap_disable_validation: false` +13. `ldap_insecure: false` +14. `enable_pprof: false` + +## 5. Role Mapping +Use LDAP group DN mapping to roles (config-driven): + +- `auth_group_role_mappings` map/list, e.g.: + - group DN -> `admin` + - group DN -> `viewer` + +Default behavior: +- No mapped group: deny login. +- Multiple matches: union roles. + +## 6. API Contract + +### 6.1 Login +`POST /api/auth/login` +- Request: `{ "username": "...", "password": "..." }` +- Success: `{ "access_token": "...", "expires_at": , "token_type": "Bearer" }` +- Failure: `401` with generic message (no user/group leakage) + +### 6.2 Auth header +- `Authorization: Bearer ` + +### 6.3 Error behavior +- Missing/invalid token: `401` +- Valid token but insufficient role: `403` + +## 7. Rollout Plan + +### Phase 1: Foundation +1. Implement settings fields and validation. +2. Implement LDAP and JWT services. +3. Add `/api/auth/login`. +4. Add unit tests for token generation/validation and LDAP auth abstraction. + +### Phase 2: Middleware + policy +1. Add auth middleware for `net/http`. +2. Protect sensitive routes via central policy map. +3. Keep `auth_mode=optional` initially for safe rollout. + +### Phase 3: Enforce + harden +1. Switch production to `auth_mode=required`. +2. Gate/disable pprof by config. +3. Add structured audit logs for auth events. +4. Update Swagger security docs and README. + +## 8. Validation and Tests +1. Unit tests +- JWT: expired token, wrong signature, wrong issuer/audience, clock skew. +- Role extraction and mapping. + +2. Integration tests (handler/middleware) +- Unauthenticated access blocked on protected endpoints. +- Viewer can read but cannot mutate. +- Admin can mutate. + +3. Regression checks +- Existing legacy endpoint gating (`enable_legacy_api`) still behaves correctly after auth layering. + +## 9. Open Decisions +1. Should `/metrics` require auth in your deployment? No there is no need for auth for this endpoint +2. Should UI pages (`/`, `/vcenters`, `/snapshots/*`, `/vm/trace`) require login or stay public? These should stay public +3. Should Swagger UI be public, authenticated, or disabled in production? These should stay public +4. Do you want short-lived access tokens only, or access + refresh token flow? Short lived access tokens please, 2 hours is good enough + +## 10. Implementation Notes for VCTP +1. Reuse `cbs2` LDAP flow shape, but avoid its fallback secret behavior. +2. Keep all secret material redacted in logs. +3. Validate required auth settings at startup when `auth_enabled=true`. +4. Prefer fail-closed: if auth is enabled and misconfigured, abort startup. +