package main import ( "bytes" "context" "crypto/tls" "embed" "fmt" "io/fs" "log" "net/http" "os" "os/signal" "smt/controllers" "smt/middlewares" "smt/models" "smt/utils" "strings" "syscall" "time" "github.com/gin-gonic/gin" "github.com/joho/godotenv" ) // For build numbers, from https://blog.kowalczyk.info/article/vEja/embedding-build-number-in-go-executable.html var sha1ver string // sha1 revision used to build the program var buildTime string // when the executable was built type Replacements map[string]string var replacements = make(Replacements) //go:embed www/* var staticContent embed.FS func getAllFilenames(efs *embed.FS) (files []string, err error) { if err := fs.WalkDir(efs, ".", func(path string, d fs.DirEntry, err error) error { if d.IsDir() { return nil } files = append(files, path) return nil }); err != nil { return nil, err } return files, nil } // staticFileServer serves files from the provided fs.FS func staticFileServer(content embed.FS) gin.HandlerFunc { return func(c *gin.Context) { // Get the requested file name fileName := c.Request.URL.Path // Root request should load index.htm if len(fileName) == 1 { //log.Printf("staticFileServer replacing root request with index.html") fileName = "www/index.html" } else { fileName = strings.TrimLeft(fileName, "/") } //log.Printf("staticFileServer attempting to load filename '%s'\n", fileName) // Try to open the file from the embedded FS file, err := content.Open(fileName) if err != nil { log.Printf("staticFileServer error opening '%s' : '%s'\n", fileName, err) c.String(http.StatusNotFound, "File not found") return } defer file.Close() // Serve the file contents fileStat, _ := file.Stat() data := make([]byte, fileStat.Size()) _, err = file.Read(data) if err != nil { c.String(http.StatusInternalServerError, "Error reading file") return } // parse the file and perform text replacements as necessary for key, element := range replacements { //log.Printf("Searching for '%s' to replace\n", key) data = bytes.Replace(data, []byte(key), []byte(element), -1) } // Set the proper Content-Type header based on file extension c.Header("Content-Type", http.DetectContentType(data)) c.Data(http.StatusOK, "", data) } } func main() { // These replacements are for the embedded html generated from README.md replacements["{SHA1VER}"] = sha1ver replacements["{BUILDTIME}"] = buildTime // Load data from environment file envFilename := utils.GetFilePath(".env") err := godotenv.Load(envFilename) if err != nil { panic("Error loading .env file") } // Open connection to logfile // From https://ispycode.com/GO/Logging/Logging-to-multiple-destinations logFile := os.Getenv("LOG_FILE") if logFile == "" { logFile = "./smt.log" } logfileWriter, err := os.OpenFile(logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) if err != nil { fmt.Printf("Unable to write logfile '%s' : '%s'\n", logFile, err) os.Exit(1) } log.SetOutput(logfileWriter) log.Printf("SMT starting execution. Built on %s from sha1 %s\n", buildTime, sha1ver) // for debugging, list all the files that we embedded at compile time /* files, err := getAllFilenames(&staticContent) if err != nil { log.Printf("Unable to access embedded fs : '%s'\n", err) } for i := range files { log.Printf("Embedded file : '%s'\n", files[i]) } */ // Initiate connection to sqlite and make sure our schema is up to date models.ConnectDatabase() // Set secrets key from .env file keyString := os.Getenv("SECRETS_KEY") if keyString != "" { // Key was defined in environment variable, let the models package know our secrets key log.Println("Found secret key in environment variable") models.ReceiveKey(keyString) } // If LDAP configuration is defined, prepare connection ldapServer := os.Getenv("LDAP_BIND_ADDRESS") if ldapServer != "" { models.LdapSetup() } // Create context that listens for the interrupt signal from the OS. ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() // Creates a router without any middleware by default gin.SetMode(gin.ReleaseMode) router := gin.New() // Global middleware // Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release. // By default gin.DefaultWriter = os.Stdout // log to file only gin.DefaultWriter = logfileWriter // log to file and stdout //gin.DefaultWriter = io.MultiWriter(logfileWriter, os.Stdout) router.Use(gin.Logger()) // Recovery middleware recovers from any panics and writes a 500 if there was one. router.Use(gin.Recovery()) /* // TODO - think of a better default landing page router.GET("/", func(c *gin.Context) { c.String(http.StatusOK, fmt.Sprintf("SMT Built on %s from sha1 %s\n", buildTime, sha1ver)) }) */ // Set some options for TLS tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, InsecureSkipVerify: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_256_CBC_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, }, } // Determine bind IP bindIP := os.Getenv("BIND_IP") if bindIP == "" { bindIP = utils.GetOutboundIP().String() } // Determine bind port bindPort := os.Getenv("BIND_PORT") if bindPort == "" { bindIP = "8443" } bindAddress := fmt.Sprint(bindIP, ":", bindPort) log.Printf("Will listen on address 'https://%s'\n", bindAddress) // Get file names for TLS cert/key tlsCertFilename := os.Getenv("TLS_CERT_FILE") if tlsCertFilename != "" { tlsCertFilename = utils.GetFilePath(tlsCertFilename) } else { tlsCertFilename = "./cert.pem" } tlsKeyFilename := os.Getenv("TLS_KEY_FILE") if tlsKeyFilename != "" { tlsKeyFilename = utils.GetFilePath(tlsKeyFilename) } else { tlsKeyFilename = "./privkey.pem" } // Generate certificate if required if !(utils.FileExists(tlsCertFilename) && utils.FileExists(tlsKeyFilename)) { log.Printf("Specified TLS certificate (%s) or private key (%s) do not exist.\n", tlsCertFilename, tlsKeyFilename) utils.GenerateCerts(tlsCertFilename, tlsKeyFilename) } srv := &http.Server{ Addr: bindAddress, Handler: router, TLSConfig: tlsConfig, } // Set the default readme page //router.Use(EmbedReact("/", "static_files", staticDir)) //router.Use(static.Serve("/", static.LocalFile("./static_files", true))) // Serve the embedded HTML file if no other routes match router.NoRoute(staticFileServer(staticContent)) // Register our routes public := router.Group("/api") public.POST("/login", controllers.Login) //public.POST("/unlock", controllers.Unlock) // API calls that only an administrator can make adminOnly := router.Group("/api/admin") adminOnly.Use(middlewares.JwtAuthAdminMiddleware()) adminOnly.POST("/user/delete", controllers.DeleteUser) adminOnly.POST("/user/register", controllers.RegisterUser) // TODO deprecate adminOnly.POST("/user/add", controllers.RegisterUser) adminOnly.GET("/roles", controllers.GetRoles) adminOnly.POST("/role/add", controllers.AddRole) adminOnly.GET("/users", controllers.GetUsers) // TODO Make unlock an admin only function adminOnly.POST("/unlock", controllers.Unlock) // Get secrets protected := router.Group("/api/secret") protected.Use(middlewares.JwtAuthMiddleware()) protected.POST("/retrieve", controllers.RetrieveSecret) protected.GET("/list", controllers.ListSecrets) protected.POST("/retrieveMultiple", controllers.RetrieveMultpleSecrets) protected.POST("/store", controllers.StoreSecret) protected.POST("/update", controllers.UpdateSecret) // TODO //protected.POST("/delete", controllers.DeleteSecret) // Support parameters in path // See https://gin-gonic.com/docs/examples/param-in-path/ protected.GET("/retrieve/name/:devicename", controllers.RetrieveSecretByDevicename) protected.GET("/retrieve/category/:devicecategory", controllers.RetrieveSecretByDevicecategory) // Initializing the server in a goroutine so that // it won't block the graceful shutdown handling below go func() { if err := srv.ListenAndServeTLS(tlsCertFilename, tlsKeyFilename); err != nil && err != http.ErrServerClosed { log.Fatalf("listen: %s\n", err) } }() // Listen for the interrupt signal. <-ctx.Done() // Restore default behavior on the interrupt signal and notify user of shutdown. stop() log.Println("shutting down gracefully, press Ctrl+C again to force") models.DisconnectDatabase() // The context is used to inform the server it has 5 seconds to finish // the request it is currently handling ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := srv.Shutdown(ctx); err != nil { log.Fatal("Server forced to shutdown: ", err) } log.Println("Server exiting") }