improve checks

This commit is contained in:
2023-04-03 08:33:31 +10:00
parent b45e276df5
commit 748f4251e1
9 changed files with 139 additions and 43 deletions

2
.gitignore vendored
View File

@@ -1,4 +1,4 @@
api\ tests.txt
ccsecrets
ccsecrets.db
ccsecrets.*
.env

View File

@@ -36,6 +36,19 @@ func Register(c *gin.Context) {
u.UserName = input.Username
u.Password = input.Password
//remove spaces in username
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
// Check if user already exists
testUser, _ := models.GetUserByName(u.UserName)
fmt.Printf("Register checking if user already exists : '%v'\n", testUser)
if (models.User{} == testUser) {
fmt.Printf("Register confirmed no existing username\n")
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": "Attempt to register conflicting username"})
return
}
//turn password into hash
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
if err != nil {
@@ -46,9 +59,6 @@ func Register(c *gin.Context) {
}
u.Password = string(hashedPassword)
//remove spaces in username
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
_, err = u.SaveUser()
if err != nil {

View File

@@ -3,6 +3,7 @@ package controllers
import (
"ccsecrets/models"
"ccsecrets/utils/token"
"errors"
"fmt"
"net/http"
@@ -17,6 +18,52 @@ type RetrieveInput struct {
func RetrieveSecret(c *gin.Context) {
var input RetrieveInput
// Validate the input matches our struct
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
fmt.Printf("RetrieveSecret received JSON input '%v'\n", input)
// Get the user and role id of the requestor
user_id, err := token.ExtractTokenID(c)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
u, err := models.GetUserRoleByID(user_id)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Populate fields
s := models.Secret{}
s.RoleId = u.RoleId
s.DeviceName = input.DeviceName
s.DeviceCategory = input.DeviceCategory
results, err := models.GetSecrets(&s)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(results) > 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": errors.New("found multiple matching secrets, use retrieveMultiple instead")})
return
}
// output results as json
c.JSON(http.StatusOK, gin.H{"message": "success", "data": results})
}
func RetrieveMultpleSecrets(c *gin.Context) {
var input RetrieveInput
// Validate the input matches our struct
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View File

@@ -41,6 +41,19 @@ func StoreSecret(c *gin.Context) {
s.RoleId = 1
}
// If this secret already exists in the database then generate an error
checkExists, err := models.GetSecrets(&s)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(checkExists) > 0 {
fmt.Printf("StoreSecret not storing secret with '%d' already matching secrets.\n", len(checkExists))
c.JSON(http.StatusBadRequest, gin.H{"error": "StoreSecret attempting to store secret already defined. API calls for update/delete don't yet exist"})
return
}
// Encrypt secret
s.Secret = input.SecretValue
_, err = s.EncryptSecret()
@@ -49,14 +62,6 @@ func StoreSecret(c *gin.Context) {
return
}
// This is just here for testing to make sure that decryption works
/*
_, err = s.DecryptSecret()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"Error decrypting secret": err.Error()})
return
}
*/
_, err = s.SaveSecret()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"Error saving secret": err.Error()})

55
main.go
View File

@@ -8,6 +8,7 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"log"
"net/http"
"os"
@@ -18,7 +19,22 @@ import (
"github.com/gin-gonic/gin"
)
// For build numbers, from https://blog.kowalczyk.info/article/vEja/embedding-build-number-in-go-executable.html
var sha1ver string // sha1 revision used to build the program
var buildTime string // when the executable was built
func main() {
// Open connection to logfile
// From https://ispycode.com/GO/Logging/Logging-to-multiple-destinations
logFile := os.Getenv("LOG_FILE")
if logFile == "" {
logFile = "./ccsecrets.log"
}
logfileWriter, err := os.OpenFile(logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Println("Unable to write logfile", err)
os.Exit(1)
}
// Initiate connection to sqlite and make sure our schema is up to date
models.ConnectDatabase()
@@ -27,10 +43,22 @@ func main() {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
router := gin.Default()
// Creates a router without any middleware by default
router := gin.New()
// Global middleware
// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
// By default gin.DefaultWriter = os.Stdout
gin.DefaultWriter = io.MultiWriter(logfileWriter, os.Stdout)
router.Use(gin.Logger())
// Recovery middleware recovers from any panics and writes a 500 if there was one.
router.Use(gin.Recovery())
// TODO - think of a better default landing page
router.GET("/", func(c *gin.Context) {
//time.Sleep(10 * time.Second)
c.String(http.StatusOK, "Hello World.")
c.String(http.StatusOK, fmt.Sprintf("Built on %s from sha1 %s\n", buildTime, sha1ver))
})
// Set some options for TLS
@@ -90,6 +118,7 @@ func main() {
protected := router.Group("/api/secret")
protected.Use(middlewares.JwtAuthMiddleware())
protected.GET("/retrieve", controllers.RetrieveSecret)
protected.GET("/retrieveMultiple", controllers.RetrieveMultpleSecrets)
protected.POST("/store", controllers.StoreSecret)
// Initializing the server in a goroutine so that
@@ -118,26 +147,4 @@ func main() {
}
log.Println("Server exiting")
/*
r := gin.Default()
// Define our routes underneath /api
public := r.Group("/api")
public.POST("/register", controllers.Register)
public.POST("/login", controllers.Login)
// This is just PoC really, we can get rid of it
//protected := r.Group("/api/admin")
//protected.Use(middlewares.JwtAuthMiddleware())
//protected.GET("/user", controllers.CurrentUser)
// Get secrets
protected := r.Group("/api/secret")
protected.Use(middlewares.JwtAuthMiddleware())
protected.GET("/device", controllers.CurrentUser)
r.Run(":8443")
*/
}

View File

@@ -44,6 +44,7 @@ func (s *Secret) SaveSecret() (*Secret, error) {
return s, nil
}
// Returns all matching secrets, up to caller to determine how to deal with multiple results
func GetSecrets(s *Secret) ([]Secret, error) {
var err error
var rows *sqlx.Rows
@@ -52,6 +53,7 @@ func GetSecrets(s *Secret) ([]Secret, error) {
fmt.Printf("GetSecret querying values '%v'\n", s)
// Determine whether to query for a specific device or a category of devices
// Prefer querying device name than category
if s.DeviceName != "" {
rows, err = db.Queryx("SELECT * FROM secrets WHERE DeviceName LIKE ? AND RoleId = ?", s.DeviceName, s.RoleId)
} else if s.DeviceCategory != "" {
@@ -62,8 +64,6 @@ func GetSecrets(s *Secret) ([]Secret, error) {
return secretResults, err
}
// TODO - do we want to generate an error if the query returns more than one result?
if err != nil {
fmt.Printf("GetSecret error executing sql record : '%s'\n", err)
return secretResults, err

View File

@@ -11,6 +11,7 @@ import (
"github.com/jmoiron/sqlx"
"github.com/joho/godotenv"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
)
@@ -124,7 +125,17 @@ func CreateTables() {
}
rowCount, _ = CheckCount("users")
if rowCount == 0 {
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', '$2a$10$k1qldm.bWqZsQWrKPdahR.Pfz5LxkMUka2.8INEeSD7euzkiznIR.');"); err != nil {
// Check if there was an initial password defined in the .env file
initialPassword := os.Getenv("INITIAL_PASSWORD")
if initialPassword == "" {
initialPassword = "password"
} else if initialPassword[:4] == "$2a$" {
fmt.Printf("CreateTables inital admin password is already a hash")
} else {
cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
initialPassword = string(cryptText)
}
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', ?);", initialPassword); err != nil {
fmt.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
@@ -223,7 +234,7 @@ func CheckColumnExists(table string, column string) (bool, error) {
for rows.Next() {
// cols is an []interface{} of all of the column results
cols, _ := rows.SliceScan()
fmt.Printf("CheckColumnExists Value is '%v'\n", cols[0].(int64))
fmt.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column)
count = cols[0].(int64)
if count == 1 {

View File

@@ -12,7 +12,7 @@ type User struct {
UserId int `db:"UserId"`
RoleId int `db:"RoleId"`
UserName string `db:"UserName"`
Password string `db:"Password"`
Password string `db:"Password" json:"-"`
}
type UserRole struct {
@@ -87,7 +87,7 @@ func GetUserByID(uid uint) (User, error) {
var u User
// Query database for matching user object
err := db.QueryRowx("SELECT * FROM users INNER JOIN roles ON users.RoleId = roles.RoleId WHERE UserId=?", uid).StructScan(&u)
err := db.QueryRowx("SELECT * FROM users WHERE UserId=?", uid).StructScan(&u)
if err != nil {
return u, errors.New("user not found")
}
@@ -96,12 +96,25 @@ func GetUserByID(uid uint) (User, error) {
return u, errors.New("User not found!")
}
*/
u.PrepareGive()
//u.PrepareGive()
return u, nil
}
func GetUserByName(username string) (User, error) {
var u User
// Query database for matching user object
err := db.QueryRowx("SELECT * FROM users WHERE UserName=?", username).StructScan(&u)
if err != nil {
return u, errors.New("user not found")
}
return u, nil
}
func GetUserRoleByID(uid uint) (UserRole, error) {
var ur UserRole
@@ -118,6 +131,8 @@ func GetUserRoleByID(uid uint) (UserRole, error) {
}
/*
func (u *User) PrepareGive() {
u.Password = ""
}
*/

View File

@@ -62,6 +62,7 @@ func ExtractTokenID(c *gin.Context) (uint, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// Why return the secret??
return []byte(os.Getenv("API_SECRET")), nil
})
if err != nil {