+24
-1
@@ -1,6 +1,7 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -104,13 +105,14 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st
|
||||
} else {
|
||||
mux.Handle("/swagger/", middleware.CacheMiddleware(http.StripPrefix("/swagger/", http.FileServer(http.FS(swaggerSub)))))
|
||||
}
|
||||
swaggerRuntimeSpec := buildRuntimeSwaggerSpec(logger, swaggerSpec, settings.Values.Settings.BindDisableTLS)
|
||||
mux.HandleFunc("/swagger", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/swagger/", http.StatusPermanentRedirect)
|
||||
})
|
||||
mux.Handle("/swagger.json", middleware.CacheMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(swaggerSpec)
|
||||
_, _ = w.Write(swaggerRuntimeSpec)
|
||||
})))
|
||||
|
||||
// Register pprof handlers only when enabled, and gate them behind admin auth.
|
||||
@@ -124,3 +126,24 @@ func New(logger *slog.Logger, database db.Database, buildTime string, sha1ver st
|
||||
|
||||
return middleware.NewLoggingMiddleware(logger, mux)
|
||||
}
|
||||
|
||||
func buildRuntimeSwaggerSpec(logger *slog.Logger, baseSpec []byte, bindDisableTLS bool) []byte {
|
||||
scheme := "https"
|
||||
if bindDisableTLS {
|
||||
scheme = "http"
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(baseSpec, &parsed); err != nil {
|
||||
logger.Warn("failed to parse embedded swagger spec; serving original", "error", err)
|
||||
return baseSpec
|
||||
}
|
||||
parsed["schemes"] = []string{scheme}
|
||||
|
||||
updated, err := json.Marshal(parsed)
|
||||
if err != nil {
|
||||
logger.Warn("failed to render runtime swagger spec; serving original", "error", err)
|
||||
return baseSpec
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"vctp/version"
|
||||
@@ -112,3 +114,51 @@ func TestStaticResourcesAreCacheableInReleaseMode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerJSONDefaultsToHTTPSWhenTLSEnabled(t *testing.T) {
|
||||
cfg := testRouterSettings(t, false)
|
||||
cfg.Values.Settings.BindDisableTLS = false
|
||||
app := testRouter(t, cfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/swagger.json", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
app.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
var spec struct {
|
||||
Schemes []string `json:"schemes"`
|
||||
}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &spec); err != nil {
|
||||
t.Fatalf("failed to decode swagger spec: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(spec.Schemes, []string{"https"}) {
|
||||
t.Fatalf("unexpected schemes: got %v want %v", spec.Schemes, []string{"https"})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerJSONDefaultsToHTTPWhenTLSDisabled(t *testing.T) {
|
||||
cfg := testRouterSettings(t, false)
|
||||
cfg.Values.Settings.BindDisableTLS = true
|
||||
app := testRouter(t, cfg)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/swagger.json", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
app.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
var spec struct {
|
||||
Schemes []string `json:"schemes"`
|
||||
}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &spec); err != nil {
|
||||
t.Fatalf("failed to decode swagger spec: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(spec.Schemes, []string{"http"}) {
|
||||
t.Fatalf("unexpected schemes: got %v want %v", spec.Schemes, []string{"http"})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user