diff --git a/.gitignore b/.gitignore index 586fa3f..7936183 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ api\ tests.txt ccsecrets -ccsecrets.db +ccsecrets.* .env \ No newline at end of file diff --git a/controllers/auth.go b/controllers/auth.go index 592cce3..75d71a0 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -36,6 +36,19 @@ func Register(c *gin.Context) { u.UserName = input.Username u.Password = input.Password + //remove spaces in username + u.UserName = html.EscapeString(strings.TrimSpace(u.UserName)) + + // Check if user already exists + testUser, _ := models.GetUserByName(u.UserName) + fmt.Printf("Register checking if user already exists : '%v'\n", testUser) + if (models.User{} == testUser) { + fmt.Printf("Register confirmed no existing username\n") + } else { + c.JSON(http.StatusBadRequest, gin.H{"error": "Attempt to register conflicting username"}) + return + } + //turn password into hash hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost) if err != nil { @@ -46,9 +59,6 @@ func Register(c *gin.Context) { } u.Password = string(hashedPassword) - //remove spaces in username - u.UserName = html.EscapeString(strings.TrimSpace(u.UserName)) - _, err = u.SaveUser() if err != nil { diff --git a/controllers/retrieve_secrets.go b/controllers/retrieve_secrets.go index f76f7fc..506c51c 100644 --- a/controllers/retrieve_secrets.go +++ b/controllers/retrieve_secrets.go @@ -3,6 +3,7 @@ package controllers import ( "ccsecrets/models" "ccsecrets/utils/token" + "errors" "fmt" "net/http" @@ -17,6 +18,52 @@ type RetrieveInput struct { func RetrieveSecret(c *gin.Context) { var input RetrieveInput + // Validate the input matches our struct + if err := c.ShouldBindJSON(&input); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + fmt.Printf("RetrieveSecret received JSON input '%v'\n", input) + + // Get the user and role id of the requestor + user_id, err := token.ExtractTokenID(c) + + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + u, err := models.GetUserRoleByID(user_id) + + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Populate fields + s := models.Secret{} + s.RoleId = u.RoleId + s.DeviceName = input.DeviceName + s.DeviceCategory = input.DeviceCategory + + results, err := models.GetSecrets(&s) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if len(results) > 1 { + c.JSON(http.StatusBadRequest, gin.H{"error": errors.New("found multiple matching secrets, use retrieveMultiple instead")}) + return + } + + // output results as json + c.JSON(http.StatusOK, gin.H{"message": "success", "data": results}) +} + +func RetrieveMultpleSecrets(c *gin.Context) { + var input RetrieveInput + // Validate the input matches our struct if err := c.ShouldBindJSON(&input); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/controllers/store_secrets.go b/controllers/store_secrets.go index 808e2f0..23f983f 100644 --- a/controllers/store_secrets.go +++ b/controllers/store_secrets.go @@ -41,6 +41,19 @@ func StoreSecret(c *gin.Context) { s.RoleId = 1 } + // If this secret already exists in the database then generate an error + checkExists, err := models.GetSecrets(&s) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if len(checkExists) > 0 { + fmt.Printf("StoreSecret not storing secret with '%d' already matching secrets.\n", len(checkExists)) + c.JSON(http.StatusBadRequest, gin.H{"error": "StoreSecret attempting to store secret already defined. API calls for update/delete don't yet exist"}) + return + } + // Encrypt secret s.Secret = input.SecretValue _, err = s.EncryptSecret() @@ -49,14 +62,6 @@ func StoreSecret(c *gin.Context) { return } - // This is just here for testing to make sure that decryption works - /* - _, err = s.DecryptSecret() - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"Error decrypting secret": err.Error()}) - return - } - */ _, err = s.SaveSecret() if err != nil { c.JSON(http.StatusBadRequest, gin.H{"Error saving secret": err.Error()}) diff --git a/main.go b/main.go index 8cfbcb3..5d230e0 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "log" "net/http" "os" @@ -18,7 +19,22 @@ import ( "github.com/gin-gonic/gin" ) +// For build numbers, from https://blog.kowalczyk.info/article/vEja/embedding-build-number-in-go-executable.html +var sha1ver string // sha1 revision used to build the program +var buildTime string // when the executable was built + func main() { + // Open connection to logfile + // From https://ispycode.com/GO/Logging/Logging-to-multiple-destinations + logFile := os.Getenv("LOG_FILE") + if logFile == "" { + logFile = "./ccsecrets.log" + } + logfileWriter, err := os.OpenFile(logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + fmt.Println("Unable to write logfile", err) + os.Exit(1) + } // Initiate connection to sqlite and make sure our schema is up to date models.ConnectDatabase() @@ -27,10 +43,22 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - router := gin.Default() + // Creates a router without any middleware by default + router := gin.New() + + // Global middleware + // Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release. + // By default gin.DefaultWriter = os.Stdout + gin.DefaultWriter = io.MultiWriter(logfileWriter, os.Stdout) + router.Use(gin.Logger()) + + // Recovery middleware recovers from any panics and writes a 500 if there was one. + router.Use(gin.Recovery()) + + // TODO - think of a better default landing page router.GET("/", func(c *gin.Context) { //time.Sleep(10 * time.Second) - c.String(http.StatusOK, "Hello World.") + c.String(http.StatusOK, fmt.Sprintf("Built on %s from sha1 %s\n", buildTime, sha1ver)) }) // Set some options for TLS @@ -90,6 +118,7 @@ func main() { protected := router.Group("/api/secret") protected.Use(middlewares.JwtAuthMiddleware()) protected.GET("/retrieve", controllers.RetrieveSecret) + protected.GET("/retrieveMultiple", controllers.RetrieveMultpleSecrets) protected.POST("/store", controllers.StoreSecret) // Initializing the server in a goroutine so that @@ -118,26 +147,4 @@ func main() { } log.Println("Server exiting") - - /* - r := gin.Default() - - // Define our routes underneath /api - public := r.Group("/api") - public.POST("/register", controllers.Register) - public.POST("/login", controllers.Login) - - // This is just PoC really, we can get rid of it - //protected := r.Group("/api/admin") - //protected.Use(middlewares.JwtAuthMiddleware()) - //protected.GET("/user", controllers.CurrentUser) - - // Get secrets - protected := r.Group("/api/secret") - protected.Use(middlewares.JwtAuthMiddleware()) - protected.GET("/device", controllers.CurrentUser) - - r.Run(":8443") - */ - } diff --git a/models/secret.go b/models/secret.go index bc9b7f3..4132bfb 100644 --- a/models/secret.go +++ b/models/secret.go @@ -44,6 +44,7 @@ func (s *Secret) SaveSecret() (*Secret, error) { return s, nil } +// Returns all matching secrets, up to caller to determine how to deal with multiple results func GetSecrets(s *Secret) ([]Secret, error) { var err error var rows *sqlx.Rows @@ -52,6 +53,7 @@ func GetSecrets(s *Secret) ([]Secret, error) { fmt.Printf("GetSecret querying values '%v'\n", s) // Determine whether to query for a specific device or a category of devices + // Prefer querying device name than category if s.DeviceName != "" { rows, err = db.Queryx("SELECT * FROM secrets WHERE DeviceName LIKE ? AND RoleId = ?", s.DeviceName, s.RoleId) } else if s.DeviceCategory != "" { @@ -62,8 +64,6 @@ func GetSecrets(s *Secret) ([]Secret, error) { return secretResults, err } - // TODO - do we want to generate an error if the query returns more than one result? - if err != nil { fmt.Printf("GetSecret error executing sql record : '%s'\n", err) return secretResults, err diff --git a/models/setup.go b/models/setup.go index 64bdba5..aa12eb9 100644 --- a/models/setup.go +++ b/models/setup.go @@ -11,6 +11,7 @@ import ( "github.com/jmoiron/sqlx" "github.com/joho/godotenv" + "golang.org/x/crypto/bcrypt" _ "modernc.org/sqlite" ) @@ -124,7 +125,17 @@ func CreateTables() { } rowCount, _ = CheckCount("users") if rowCount == 0 { - if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', '$2a$10$k1qldm.bWqZsQWrKPdahR.Pfz5LxkMUka2.8INEeSD7euzkiznIR.');"); err != nil { + // Check if there was an initial password defined in the .env file + initialPassword := os.Getenv("INITIAL_PASSWORD") + if initialPassword == "" { + initialPassword = "password" + } else if initialPassword[:4] == "$2a$" { + fmt.Printf("CreateTables inital admin password is already a hash") + } else { + cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost) + initialPassword = string(cryptText) + } + if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', ?);", initialPassword); err != nil { fmt.Printf("Error adding initial admin role : '%s'", err) os.Exit(1) } @@ -223,7 +234,7 @@ func CheckColumnExists(table string, column string) (bool, error) { for rows.Next() { // cols is an []interface{} of all of the column results cols, _ := rows.SliceScan() - fmt.Printf("CheckColumnExists Value is '%v'\n", cols[0].(int64)) + fmt.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column) count = cols[0].(int64) if count == 1 { diff --git a/models/user.go b/models/user.go index 4eb8bac..c3693f0 100644 --- a/models/user.go +++ b/models/user.go @@ -12,7 +12,7 @@ type User struct { UserId int `db:"UserId"` RoleId int `db:"RoleId"` UserName string `db:"UserName"` - Password string `db:"Password"` + Password string `db:"Password" json:"-"` } type UserRole struct { @@ -87,7 +87,7 @@ func GetUserByID(uid uint) (User, error) { var u User // Query database for matching user object - err := db.QueryRowx("SELECT * FROM users INNER JOIN roles ON users.RoleId = roles.RoleId WHERE UserId=?", uid).StructScan(&u) + err := db.QueryRowx("SELECT * FROM users WHERE UserId=?", uid).StructScan(&u) if err != nil { return u, errors.New("user not found") } @@ -96,12 +96,25 @@ func GetUserByID(uid uint) (User, error) { return u, errors.New("User not found!") } */ - u.PrepareGive() + //u.PrepareGive() return u, nil } +func GetUserByName(username string) (User, error) { + + var u User + + // Query database for matching user object + err := db.QueryRowx("SELECT * FROM users WHERE UserName=?", username).StructScan(&u) + if err != nil { + return u, errors.New("user not found") + } + + return u, nil +} + func GetUserRoleByID(uid uint) (UserRole, error) { var ur UserRole @@ -118,6 +131,8 @@ func GetUserRoleByID(uid uint) (UserRole, error) { } +/* func (u *User) PrepareGive() { u.Password = "" } +*/ diff --git a/utils/token/token.go b/utils/token/token.go index 0dbf05b..3c36d03 100644 --- a/utils/token/token.go +++ b/utils/token/token.go @@ -62,6 +62,7 @@ func ExtractTokenID(c *gin.Context) (uint, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } + // Why return the secret?? return []byte(os.Getenv("API_SECRET")), nil }) if err != nil {