324 lines
9.7 KiB
Go
324 lines
9.7 KiB
Go
package main
|
|
|
|
import (
	"bytes"
	"context"
	"crypto/tls"
	"embed"
	"fmt"
	"io/fs"
	"log"
	"mime"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path"
	"smt/controllers"
	"smt/middlewares"
	"smt/models"
	"smt/utils"
	"syscall"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/joho/godotenv"
)
|
|
|
|
// Build metadata written to the startup log and substituted into the embedded
// HTML via the {SHA1VER} and {BUILDTIME} placeholders. Neither variable is
// assigned in this file; following
// https://blog.kowalczyk.info/article/vEja/embedding-build-number-in-go-executable.html
// they are presumably set at link time (e.g. -ldflags "-X main.sha1ver=...")
// — confirm against the build script.
var (
	sha1ver   string // sha1 revision the program was built from
	buildTime string // moment the executable was built
)
|
|
|
|
// Replacements maps a placeholder string to the text that should replace it.
type Replacements map[string]string

// replacements is filled in by main at startup. staticFileServer replaces
// each key with its value in every file it serves.
var replacements = Replacements{}
|
|
|
|
// staticContent holds the web UI files matched by the go:embed pattern below,
// embedded into the binary at compile time. main passes it to
// staticFileServer, which looks files up under the "www/" prefix (a request
// for "/" maps to "www/index.html").
//
//go:embed www/*
var staticContent embed.FS
|
|
|
|
// getAllFilenames walks fsys from its root (".") and returns the path of every
// entry that is not a directory, in the order fs.WalkDir visits them. It is
// used for debugging, to list the files embedded at compile time.
//
// fsys is any fs.FS, so existing callers passing &staticContent (*embed.FS)
// keep working.
//
// Any error reported during the walk — including failure to stat the root —
// stops the walk and is returned together with a nil slice.
func getAllFilenames(fsys fs.FS) ([]string, error) {
	var files []string

	walkErr := fs.WalkDir(fsys, ".", func(name string, d fs.DirEntry, err error) error {
		// WalkDir passes a nil DirEntry together with a non-nil err when it
		// cannot stat the root, so err must be checked before d is used.
		if err != nil {
			return err
		}
		if !d.IsDir() {
			files = append(files, name)
		}
		return nil
	})
	if walkErr != nil {
		return nil, walkErr
	}

	return files, nil
}
|
|
|
|
// staticFileServer serves files from the provided fs.FS
|
|
func staticFileServer(content embed.FS) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Get the requested file name
|
|
fileName := c.Request.URL.Path
|
|
|
|
// Root request should load index.htm
|
|
if len(fileName) == 1 {
|
|
//log.Printf("staticFileServer replacing root request with index.html")
|
|
fileName = "www/index.html"
|
|
} else {
|
|
//fileName = strings.TrimLeft(fileName, "/")
|
|
fileName = "www" + fileName
|
|
}
|
|
|
|
//log.Printf("staticFileServer attempting to load filename '%s'\n", fileName)
|
|
|
|
// Try to open the file from the embedded FS
|
|
file, err := content.Open(fileName)
|
|
if err != nil {
|
|
log.Printf("staticFileServer error opening '%s' : '%s'\n", fileName, err)
|
|
c.String(http.StatusNotFound, "File not found")
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
// Serve the file contents
|
|
fileStat, _ := file.Stat()
|
|
data := make([]byte, fileStat.Size())
|
|
_, err = file.Read(data)
|
|
if err != nil {
|
|
c.String(http.StatusInternalServerError, "Error reading file")
|
|
return
|
|
}
|
|
|
|
// parse the file and perform text replacements as necessary
|
|
for key, element := range replacements {
|
|
//log.Printf("Searching for '%s' to replace\n", key)
|
|
data = bytes.Replace(data, []byte(key), []byte(element), -1)
|
|
}
|
|
|
|
// Set the proper Content-Type header based on file extension
|
|
c.Header("Content-Type", http.DetectContentType(data))
|
|
c.Data(http.StatusOK, "", data)
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
// These replacements are for the embedded html generated from README.md
|
|
replacements["{SHA1VER}"] = sha1ver
|
|
replacements["{BUILDTIME}"] = buildTime
|
|
|
|
// Load data from environment file
|
|
envFilename := utils.GetFilePath(".env")
|
|
err := godotenv.Load(envFilename)
|
|
if err != nil {
|
|
panic("Error loading .env file")
|
|
}
|
|
|
|
// Open connection to logfile
|
|
// From https://ispycode.com/GO/Logging/Logging-to-multiple-destinations
|
|
logFile := os.Getenv("LOG_FILE")
|
|
if logFile == "" {
|
|
logFile = "./smt.log"
|
|
}
|
|
logfileWriter, err := os.OpenFile(logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
|
if err != nil {
|
|
fmt.Printf("Unable to write logfile '%s' : '%s'\n", logFile, err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
log.SetOutput(logfileWriter)
|
|
log.Printf("SMT starting execution. Built on %s from sha1 %s\n", buildTime, sha1ver)
|
|
|
|
/*
|
|
// for debugging, list all the files that we embedded at compile time
|
|
files, err := getAllFilenames(&staticContent)
|
|
if err != nil {
|
|
log.Printf("Unable to access embedded fs : '%s'\n", err)
|
|
}
|
|
|
|
for i := range files {
|
|
log.Printf("Embedded file : '%s'\n", files[i])
|
|
}
|
|
*/
|
|
|
|
// Initiate connection to sqlite and make sure our schema is up to date
|
|
models.ConnectDatabase()
|
|
|
|
// Set secrets key from .env file
|
|
keyString := os.Getenv("SECRETS_KEY")
|
|
|
|
if keyString != "" {
|
|
// Key was defined in environment variable, let the models package know our secrets key
|
|
log.Println("Found secret key in environment variable")
|
|
models.ReceiveKey(keyString)
|
|
}
|
|
|
|
// If LDAP configuration is defined, prepare connection
|
|
ldapServer := os.Getenv("LDAP_BIND_ADDRESS")
|
|
if ldapServer != "" {
|
|
models.LdapSetup()
|
|
}
|
|
|
|
// Create context that listens for the interrupt signal from the OS.
|
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
|
defer stop()
|
|
|
|
// Creates a router without any middleware by default
|
|
gin.SetMode(gin.ReleaseMode)
|
|
router := gin.New()
|
|
|
|
// Global middleware
|
|
// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
|
|
// By default gin.DefaultWriter = os.Stdout
|
|
|
|
// log to file only
|
|
gin.DefaultWriter = logfileWriter
|
|
|
|
// log to file and stdout
|
|
//gin.DefaultWriter = io.MultiWriter(logfileWriter, os.Stdout)
|
|
router.Use(gin.Logger())
|
|
|
|
// Recovery middleware recovers from any panics and writes a 500 if there was one.
|
|
router.Use(gin.Recovery())
|
|
|
|
// Set some options for TLS
|
|
tlsConfig := &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256},
|
|
PreferServerCipherSuites: true,
|
|
InsecureSkipVerify: true,
|
|
CipherSuites: []uint16{
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
},
|
|
}
|
|
|
|
// Determine bind IP
|
|
bindIP := os.Getenv("BIND_IP")
|
|
if bindIP == "" {
|
|
bindIP = utils.GetOutboundIP().String()
|
|
}
|
|
// Determine bind port
|
|
bindPort := os.Getenv("BIND_PORT")
|
|
if bindPort == "" {
|
|
bindIP = "8443"
|
|
}
|
|
bindAddress := fmt.Sprint(bindIP, ":", bindPort)
|
|
log.Printf("Will listen on address 'https://%s'\n", bindAddress)
|
|
|
|
// Get file names for TLS cert/key
|
|
tlsCertFilename := os.Getenv("TLS_CERT_FILE")
|
|
if tlsCertFilename != "" {
|
|
tlsCertFilename = utils.GetFilePath(tlsCertFilename)
|
|
} else {
|
|
tlsCertFilename = "./cert.pem"
|
|
}
|
|
|
|
tlsKeyFilename := os.Getenv("TLS_KEY_FILE")
|
|
if tlsKeyFilename != "" {
|
|
tlsKeyFilename = utils.GetFilePath(tlsKeyFilename)
|
|
} else {
|
|
tlsKeyFilename = "./privkey.pem"
|
|
}
|
|
|
|
// Generate certificate if required
|
|
if !(utils.FileExists(tlsCertFilename) && utils.FileExists(tlsKeyFilename)) {
|
|
log.Printf("Specified TLS certificate (%s) or private key (%s) do not exist.\n", tlsCertFilename, tlsKeyFilename)
|
|
utils.GenerateCerts(tlsCertFilename, tlsKeyFilename)
|
|
}
|
|
|
|
srv := &http.Server{
|
|
Addr: bindAddress,
|
|
Handler: router,
|
|
TLSConfig: tlsConfig,
|
|
}
|
|
|
|
// Serve the embedded HTML file if no other routes match
|
|
router.NoRoute(staticFileServer(staticContent))
|
|
|
|
// Register our routes
|
|
public := router.Group("/api")
|
|
public.POST("/login", controllers.Login)
|
|
|
|
// API calls that only an administrator can make
|
|
adminOnly := router.Group("/api/admin")
|
|
adminOnly.Use(middlewares.JwtAuthAdminMiddleware())
|
|
|
|
// User functions for admin
|
|
adminOnly.POST("/user/delete", controllers.DeleteUser)
|
|
adminOnly.POST("/user/register", controllers.AddUser) // TODO deprecate
|
|
adminOnly.POST("/user/add", controllers.AddUser)
|
|
adminOnly.GET("/users", controllers.GetUsers)
|
|
// TODO
|
|
//adminOnly.POST("/user/update", controllers.UpdateUser)
|
|
|
|
// Group functions for admin
|
|
adminOnly.GET("/groups", controllers.GetGroupsHandler)
|
|
adminOnly.POST("/group/add", controllers.AddGroupHandler)
|
|
// TODO
|
|
//adminOnly.POST("/group/update", controllers.UpdateGroup)
|
|
adminOnly.POST("/group/delete", controllers.DeleteGroupHandler)
|
|
|
|
// Permission functions for admin
|
|
adminOnly.GET("/permissions", controllers.GetPermissionsHandler)
|
|
adminOnly.POST("/permission/add", controllers.AddPermissionHandler)
|
|
adminOnly.POST("/permission/delete", controllers.DeletePermissionHandler)
|
|
|
|
// Safe functions for admin
|
|
adminOnly.GET("/safe/listall", controllers.GetAllSafesHandler)
|
|
adminOnly.POST("/safe/add", controllers.AddSafeHandler)
|
|
adminOnly.POST("/safe/delete", controllers.DeleteSafeHandler)
|
|
|
|
// Other functions for admin
|
|
adminOnly.POST("/unlock", controllers.Unlock)
|
|
|
|
// Get secrets
|
|
secretRoutes := router.Group("/api/secret")
|
|
secretRoutes.Use(middlewares.JwtAuthMiddleware())
|
|
secretRoutes.POST("/retrieve", controllers.RetrieveSecret) // TODO deprecate, replace retrieve with get
|
|
secretRoutes.POST("/get", controllers.RetrieveSecret)
|
|
secretRoutes.GET("/list", controllers.ListSecrets)
|
|
//secretRoutes.POST("/retrieveMultiple", controllers.RetrieveMultpleSecrets) // TODO is this still required?
|
|
secretRoutes.POST("/store", controllers.StoreSecret) // TODO deprecate, replace store with add
|
|
secretRoutes.POST("/add", controllers.StoreSecret)
|
|
|
|
secretRoutes.POST("/update", controllers.UpdateSecret)
|
|
// TODO
|
|
secretRoutes.POST("/delete", controllers.DeleteSecret)
|
|
|
|
// Get Safes (only those user allowed to access)
|
|
safeRoutes := router.Group("/api/safe")
|
|
safeRoutes.Use(middlewares.JwtAuthMiddleware())
|
|
safeRoutes.GET("/list", controllers.GetSafesHandler)
|
|
|
|
// Support parameters in path
|
|
// See https://gin-gonic.com/docs/examples/param-in-path/
|
|
secretRoutes.GET("/retrieve/name/:devicename", controllers.RetrieveSecretByDevicename)
|
|
secretRoutes.GET("/retrieve/category/:devicecategory", controllers.RetrieveSecretByDevicecategory)
|
|
|
|
// Initializing the server in a goroutine so that
|
|
// it won't block the graceful shutdown handling below
|
|
go func() {
|
|
if err := srv.ListenAndServeTLS(tlsCertFilename, tlsKeyFilename); err != nil && err != http.ErrServerClosed {
|
|
log.Fatalf("listen: %s\n", err)
|
|
}
|
|
}()
|
|
|
|
// Listen for the interrupt signal.
|
|
<-ctx.Done()
|
|
|
|
// Restore default behavior on the interrupt signal and notify user of shutdown.
|
|
stop()
|
|
log.Println("shutting down gracefully, press Ctrl+C again to force")
|
|
|
|
models.DisconnectDatabase()
|
|
|
|
// The context is used to inform the server it has 5 seconds to finish
|
|
// the request it is currently handling
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := srv.Shutdown(ctx); err != nil {
|
|
log.Fatal("Server forced to shutdown: ", err)
|
|
}
|
|
|
|
log.Println("Server exiting")
|
|
}
|