improve checks

This commit is contained in:
2023-04-03 08:33:31 +10:00
parent b45e276df5
commit 748f4251e1
9 changed files with 139 additions and 43 deletions

2
.gitignore vendored
View File

@@ -1,4 +1,4 @@
api\ tests.txt api\ tests.txt
ccsecrets ccsecrets
ccsecrets.db ccsecrets.*
.env .env

View File

@@ -36,6 +36,19 @@ func Register(c *gin.Context) {
u.UserName = input.Username u.UserName = input.Username
u.Password = input.Password u.Password = input.Password
//remove spaces in username
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
// Check if user already exists
testUser, _ := models.GetUserByName(u.UserName)
fmt.Printf("Register checking if user already exists : '%v'\n", testUser)
if (models.User{} == testUser) {
fmt.Printf("Register confirmed no existing username\n")
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": "Attempt to register conflicting username"})
return
}
//turn password into hash //turn password into hash
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
if err != nil { if err != nil {
@@ -46,9 +59,6 @@ func Register(c *gin.Context) {
} }
u.Password = string(hashedPassword) u.Password = string(hashedPassword)
//remove spaces in username
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
_, err = u.SaveUser() _, err = u.SaveUser()
if err != nil { if err != nil {

View File

@@ -3,6 +3,7 @@ package controllers
import ( import (
"ccsecrets/models" "ccsecrets/models"
"ccsecrets/utils/token" "ccsecrets/utils/token"
"errors"
"fmt" "fmt"
"net/http" "net/http"
@@ -17,6 +18,52 @@ type RetrieveInput struct {
func RetrieveSecret(c *gin.Context) { func RetrieveSecret(c *gin.Context) {
var input RetrieveInput var input RetrieveInput
// Validate the input matches our struct
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
fmt.Printf("RetrieveSecret received JSON input '%v'\n", input)
// Get the user and role id of the requestor
user_id, err := token.ExtractTokenID(c)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
u, err := models.GetUserRoleByID(user_id)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Populate fields
s := models.Secret{}
s.RoleId = u.RoleId
s.DeviceName = input.DeviceName
s.DeviceCategory = input.DeviceCategory
results, err := models.GetSecrets(&s)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(results) > 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": errors.New("found multiple matching secrets, use retrieveMultiple instead")})
return
}
// output results as json
c.JSON(http.StatusOK, gin.H{"message": "success", "data": results})
}
func RetrieveMultpleSecrets(c *gin.Context) {
var input RetrieveInput
// Validate the input matches our struct // Validate the input matches our struct
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View File

@@ -41,6 +41,19 @@ func StoreSecret(c *gin.Context) {
s.RoleId = 1 s.RoleId = 1
} }
// If this secret already exists in the database then generate an error
checkExists, err := models.GetSecrets(&s)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(checkExists) > 0 {
fmt.Printf("StoreSecret not storing secret with '%d' already matching secrets.\n", len(checkExists))
c.JSON(http.StatusBadRequest, gin.H{"error": "StoreSecret attempting to store secret already defined. API calls for update/delete don't yet exist"})
return
}
// Encrypt secret // Encrypt secret
s.Secret = input.SecretValue s.Secret = input.SecretValue
_, err = s.EncryptSecret() _, err = s.EncryptSecret()
@@ -49,14 +62,6 @@ func StoreSecret(c *gin.Context) {
return return
} }
// This is just here for testing to make sure that decryption works
/*
_, err = s.DecryptSecret()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"Error decrypting secret": err.Error()})
return
}
*/
_, err = s.SaveSecret() _, err = s.SaveSecret()
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"Error saving secret": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"Error saving secret": err.Error()})

55
main.go
View File

@@ -8,6 +8,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io"
"log" "log"
"net/http" "net/http"
"os" "os"
@@ -18,7 +19,22 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// For build numbers, from https://blog.kowalczyk.info/article/vEja/embedding-build-number-in-go-executable.html
var sha1ver string // sha1 revision used to build the program
var buildTime string // when the executable was built
func main() { func main() {
// Open connection to logfile
// From https://ispycode.com/GO/Logging/Logging-to-multiple-destinations
logFile := os.Getenv("LOG_FILE")
if logFile == "" {
logFile = "./ccsecrets.log"
}
logfileWriter, err := os.OpenFile(logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Println("Unable to write logfile", err)
os.Exit(1)
}
// Initiate connection to sqlite and make sure our schema is up to date // Initiate connection to sqlite and make sure our schema is up to date
models.ConnectDatabase() models.ConnectDatabase()
@@ -27,10 +43,22 @@ func main() {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop() defer stop()
router := gin.Default() // Creates a router without any middleware by default
router := gin.New()
// Global middleware
// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
// By default gin.DefaultWriter = os.Stdout
gin.DefaultWriter = io.MultiWriter(logfileWriter, os.Stdout)
router.Use(gin.Logger())
// Recovery middleware recovers from any panics and writes a 500 if there was one.
router.Use(gin.Recovery())
// TODO - think of a better default landing page
router.GET("/", func(c *gin.Context) { router.GET("/", func(c *gin.Context) {
//time.Sleep(10 * time.Second) //time.Sleep(10 * time.Second)
c.String(http.StatusOK, "Hello World.") c.String(http.StatusOK, fmt.Sprintf("Built on %s from sha1 %s\n", buildTime, sha1ver))
}) })
// Set some options for TLS // Set some options for TLS
@@ -90,6 +118,7 @@ func main() {
protected := router.Group("/api/secret") protected := router.Group("/api/secret")
protected.Use(middlewares.JwtAuthMiddleware()) protected.Use(middlewares.JwtAuthMiddleware())
protected.GET("/retrieve", controllers.RetrieveSecret) protected.GET("/retrieve", controllers.RetrieveSecret)
protected.GET("/retrieveMultiple", controllers.RetrieveMultpleSecrets)
protected.POST("/store", controllers.StoreSecret) protected.POST("/store", controllers.StoreSecret)
// Initializing the server in a goroutine so that // Initializing the server in a goroutine so that
@@ -118,26 +147,4 @@ func main() {
} }
log.Println("Server exiting") log.Println("Server exiting")
/*
r := gin.Default()
// Define our routes underneath /api
public := r.Group("/api")
public.POST("/register", controllers.Register)
public.POST("/login", controllers.Login)
// This is just PoC really, we can get rid of it
//protected := r.Group("/api/admin")
//protected.Use(middlewares.JwtAuthMiddleware())
//protected.GET("/user", controllers.CurrentUser)
// Get secrets
protected := r.Group("/api/secret")
protected.Use(middlewares.JwtAuthMiddleware())
protected.GET("/device", controllers.CurrentUser)
r.Run(":8443")
*/
} }

View File

@@ -44,6 +44,7 @@ func (s *Secret) SaveSecret() (*Secret, error) {
return s, nil return s, nil
} }
// Returns all matching secrets, up to caller to determine how to deal with multiple results
func GetSecrets(s *Secret) ([]Secret, error) { func GetSecrets(s *Secret) ([]Secret, error) {
var err error var err error
var rows *sqlx.Rows var rows *sqlx.Rows
@@ -52,6 +53,7 @@ func GetSecrets(s *Secret) ([]Secret, error) {
fmt.Printf("GetSecret querying values '%v'\n", s) fmt.Printf("GetSecret querying values '%v'\n", s)
// Determine whether to query for a specific device or a category of devices // Determine whether to query for a specific device or a category of devices
// Prefer querying device name than category
if s.DeviceName != "" { if s.DeviceName != "" {
rows, err = db.Queryx("SELECT * FROM secrets WHERE DeviceName LIKE ? AND RoleId = ?", s.DeviceName, s.RoleId) rows, err = db.Queryx("SELECT * FROM secrets WHERE DeviceName LIKE ? AND RoleId = ?", s.DeviceName, s.RoleId)
} else if s.DeviceCategory != "" { } else if s.DeviceCategory != "" {
@@ -62,8 +64,6 @@ func GetSecrets(s *Secret) ([]Secret, error) {
return secretResults, err return secretResults, err
} }
// TODO - do we want to generate an error if the query returns more than one result?
if err != nil { if err != nil {
fmt.Printf("GetSecret error executing sql record : '%s'\n", err) fmt.Printf("GetSecret error executing sql record : '%s'\n", err)
return secretResults, err return secretResults, err

View File

@@ -11,6 +11,7 @@ import (
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
@@ -124,7 +125,17 @@ func CreateTables() {
} }
rowCount, _ = CheckCount("users") rowCount, _ = CheckCount("users")
if rowCount == 0 { if rowCount == 0 {
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', '$2a$10$k1qldm.bWqZsQWrKPdahR.Pfz5LxkMUka2.8INEeSD7euzkiznIR.');"); err != nil { // Check if there was an initial password defined in the .env file
initialPassword := os.Getenv("INITIAL_PASSWORD")
if initialPassword == "" {
initialPassword = "password"
} else if initialPassword[:4] == "$2a$" {
fmt.Printf("CreateTables inital admin password is already a hash")
} else {
cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
initialPassword = string(cryptText)
}
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', ?);", initialPassword); err != nil {
fmt.Printf("Error adding initial admin role : '%s'", err) fmt.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1) os.Exit(1)
} }
@@ -223,7 +234,7 @@ func CheckColumnExists(table string, column string) (bool, error) {
for rows.Next() { for rows.Next() {
// cols is an []interface{} of all of the column results // cols is an []interface{} of all of the column results
cols, _ := rows.SliceScan() cols, _ := rows.SliceScan()
fmt.Printf("CheckColumnExists Value is '%v'\n", cols[0].(int64)) fmt.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column)
count = cols[0].(int64) count = cols[0].(int64)
if count == 1 { if count == 1 {

View File

@@ -12,7 +12,7 @@ type User struct {
UserId int `db:"UserId"` UserId int `db:"UserId"`
RoleId int `db:"RoleId"` RoleId int `db:"RoleId"`
UserName string `db:"UserName"` UserName string `db:"UserName"`
Password string `db:"Password"` Password string `db:"Password" json:"-"`
} }
type UserRole struct { type UserRole struct {
@@ -87,7 +87,7 @@ func GetUserByID(uid uint) (User, error) {
var u User var u User
// Query database for matching user object // Query database for matching user object
err := db.QueryRowx("SELECT * FROM users INNER JOIN roles ON users.RoleId = roles.RoleId WHERE UserId=?", uid).StructScan(&u) err := db.QueryRowx("SELECT * FROM users WHERE UserId=?", uid).StructScan(&u)
if err != nil { if err != nil {
return u, errors.New("user not found") return u, errors.New("user not found")
} }
@@ -96,12 +96,25 @@ func GetUserByID(uid uint) (User, error) {
return u, errors.New("User not found!") return u, errors.New("User not found!")
} }
*/ */
u.PrepareGive() //u.PrepareGive()
return u, nil return u, nil
} }
func GetUserByName(username string) (User, error) {
var u User
// Query database for matching user object
err := db.QueryRowx("SELECT * FROM users WHERE UserName=?", username).StructScan(&u)
if err != nil {
return u, errors.New("user not found")
}
return u, nil
}
func GetUserRoleByID(uid uint) (UserRole, error) { func GetUserRoleByID(uid uint) (UserRole, error) {
var ur UserRole var ur UserRole
@@ -118,6 +131,8 @@ func GetUserRoleByID(uid uint) (UserRole, error) {
} }
/*
func (u *User) PrepareGive() { func (u *User) PrepareGive() {
u.Password = "" u.Password = ""
} }
*/

View File

@@ -62,6 +62,7 @@ func ExtractTokenID(c *gin.Context) (uint, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
} }
// Why return the secret??
return []byte(os.Getenv("API_SECRET")), nil return []byte(os.Getenv("API_SECRET")), nil
}) })
if err != nil { if err != nil {