improve checks
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,4 +1,4 @@
|
||||
api\ tests.txt
|
||||
ccsecrets
|
||||
ccsecrets.db
|
||||
ccsecrets.*
|
||||
.env
|
@@ -36,6 +36,19 @@ func Register(c *gin.Context) {
|
||||
u.UserName = input.Username
|
||||
u.Password = input.Password
|
||||
|
||||
//remove spaces in username
|
||||
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
|
||||
|
||||
// Check if user already exists
|
||||
testUser, _ := models.GetUserByName(u.UserName)
|
||||
fmt.Printf("Register checking if user already exists : '%v'\n", testUser)
|
||||
if (models.User{} == testUser) {
|
||||
fmt.Printf("Register confirmed no existing username\n")
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Attempt to register conflicting username"})
|
||||
return
|
||||
}
|
||||
|
||||
//turn password into hash
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
@@ -46,9 +59,6 @@ func Register(c *gin.Context) {
|
||||
}
|
||||
u.Password = string(hashedPassword)
|
||||
|
||||
//remove spaces in username
|
||||
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
|
||||
|
||||
_, err = u.SaveUser()
|
||||
|
||||
if err != nil {
|
||||
|
@@ -3,6 +3,7 @@ package controllers
|
||||
import (
|
||||
"ccsecrets/models"
|
||||
"ccsecrets/utils/token"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -17,6 +18,52 @@ type RetrieveInput struct {
|
||||
func RetrieveSecret(c *gin.Context) {
|
||||
var input RetrieveInput
|
||||
|
||||
// Validate the input matches our struct
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
fmt.Printf("RetrieveSecret received JSON input '%v'\n", input)
|
||||
|
||||
// Get the user and role id of the requestor
|
||||
user_id, err := token.ExtractTokenID(c)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
u, err := models.GetUserRoleByID(user_id)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Populate fields
|
||||
s := models.Secret{}
|
||||
s.RoleId = u.RoleId
|
||||
s.DeviceName = input.DeviceName
|
||||
s.DeviceCategory = input.DeviceCategory
|
||||
|
||||
results, err := models.GetSecrets(&s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(results) > 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errors.New("found multiple matching secrets, use retrieveMultiple instead")})
|
||||
return
|
||||
}
|
||||
|
||||
// output results as json
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success", "data": results})
|
||||
}
|
||||
|
||||
func RetrieveMultpleSecrets(c *gin.Context) {
|
||||
var input RetrieveInput
|
||||
|
||||
// Validate the input matches our struct
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
@@ -41,6 +41,19 @@ func StoreSecret(c *gin.Context) {
|
||||
s.RoleId = 1
|
||||
}
|
||||
|
||||
// If this secret already exists in the database then generate an error
|
||||
checkExists, err := models.GetSecrets(&s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(checkExists) > 0 {
|
||||
fmt.Printf("StoreSecret not storing secret with '%d' already matching secrets.\n", len(checkExists))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "StoreSecret attempting to store secret already defined. API calls for update/delete don't yet exist"})
|
||||
return
|
||||
}
|
||||
|
||||
// Encrypt secret
|
||||
s.Secret = input.SecretValue
|
||||
_, err = s.EncryptSecret()
|
||||
@@ -49,14 +62,6 @@ func StoreSecret(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// This is just here for testing to make sure that decryption works
|
||||
/*
|
||||
_, err = s.DecryptSecret()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"Error decrypting secret": err.Error()})
|
||||
return
|
||||
}
|
||||
*/
|
||||
_, err = s.SaveSecret()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"Error saving secret": err.Error()})
|
||||
|
55
main.go
55
main.go
@@ -8,6 +8,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -18,7 +19,22 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// For build numbers, from https://blog.kowalczyk.info/article/vEja/embedding-build-number-in-go-executable.html
|
||||
var sha1ver string // sha1 revision used to build the program
|
||||
var buildTime string // when the executable was built
|
||||
|
||||
func main() {
|
||||
// Open connection to logfile
|
||||
// From https://ispycode.com/GO/Logging/Logging-to-multiple-destinations
|
||||
logFile := os.Getenv("LOG_FILE")
|
||||
if logFile == "" {
|
||||
logFile = "./ccsecrets.log"
|
||||
}
|
||||
logfileWriter, err := os.OpenFile(logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Println("Unable to write logfile", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Initiate connection to sqlite and make sure our schema is up to date
|
||||
models.ConnectDatabase()
|
||||
@@ -27,10 +43,22 @@ func main() {
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
router := gin.Default()
|
||||
// Creates a router without any middleware by default
|
||||
router := gin.New()
|
||||
|
||||
// Global middleware
|
||||
// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
|
||||
// By default gin.DefaultWriter = os.Stdout
|
||||
gin.DefaultWriter = io.MultiWriter(logfileWriter, os.Stdout)
|
||||
router.Use(gin.Logger())
|
||||
|
||||
// Recovery middleware recovers from any panics and writes a 500 if there was one.
|
||||
router.Use(gin.Recovery())
|
||||
|
||||
// TODO - think of a better default landing page
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
//time.Sleep(10 * time.Second)
|
||||
c.String(http.StatusOK, "Hello World.")
|
||||
c.String(http.StatusOK, fmt.Sprintf("Built on %s from sha1 %s\n", buildTime, sha1ver))
|
||||
})
|
||||
|
||||
// Set some options for TLS
|
||||
@@ -90,6 +118,7 @@ func main() {
|
||||
protected := router.Group("/api/secret")
|
||||
protected.Use(middlewares.JwtAuthMiddleware())
|
||||
protected.GET("/retrieve", controllers.RetrieveSecret)
|
||||
protected.GET("/retrieveMultiple", controllers.RetrieveMultpleSecrets)
|
||||
protected.POST("/store", controllers.StoreSecret)
|
||||
|
||||
// Initializing the server in a goroutine so that
|
||||
@@ -118,26 +147,4 @@ func main() {
|
||||
}
|
||||
|
||||
log.Println("Server exiting")
|
||||
|
||||
/*
|
||||
r := gin.Default()
|
||||
|
||||
// Define our routes underneath /api
|
||||
public := r.Group("/api")
|
||||
public.POST("/register", controllers.Register)
|
||||
public.POST("/login", controllers.Login)
|
||||
|
||||
// This is just PoC really, we can get rid of it
|
||||
//protected := r.Group("/api/admin")
|
||||
//protected.Use(middlewares.JwtAuthMiddleware())
|
||||
//protected.GET("/user", controllers.CurrentUser)
|
||||
|
||||
// Get secrets
|
||||
protected := r.Group("/api/secret")
|
||||
protected.Use(middlewares.JwtAuthMiddleware())
|
||||
protected.GET("/device", controllers.CurrentUser)
|
||||
|
||||
r.Run(":8443")
|
||||
*/
|
||||
|
||||
}
|
||||
|
@@ -44,6 +44,7 @@ func (s *Secret) SaveSecret() (*Secret, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Returns all matching secrets, up to caller to determine how to deal with multiple results
|
||||
func GetSecrets(s *Secret) ([]Secret, error) {
|
||||
var err error
|
||||
var rows *sqlx.Rows
|
||||
@@ -52,6 +53,7 @@ func GetSecrets(s *Secret) ([]Secret, error) {
|
||||
fmt.Printf("GetSecret querying values '%v'\n", s)
|
||||
|
||||
// Determine whether to query for a specific device or a category of devices
|
||||
// Prefer querying device name than category
|
||||
if s.DeviceName != "" {
|
||||
rows, err = db.Queryx("SELECT * FROM secrets WHERE DeviceName LIKE ? AND RoleId = ?", s.DeviceName, s.RoleId)
|
||||
} else if s.DeviceCategory != "" {
|
||||
@@ -62,8 +64,6 @@ func GetSecrets(s *Secret) ([]Secret, error) {
|
||||
return secretResults, err
|
||||
}
|
||||
|
||||
// TODO - do we want to generate an error if the query returns more than one result?
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("GetSecret error executing sql record : '%s'\n", err)
|
||||
return secretResults, err
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/joho/godotenv"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
@@ -124,7 +125,17 @@ func CreateTables() {
|
||||
}
|
||||
rowCount, _ = CheckCount("users")
|
||||
if rowCount == 0 {
|
||||
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', '$2a$10$k1qldm.bWqZsQWrKPdahR.Pfz5LxkMUka2.8INEeSD7euzkiznIR.');"); err != nil {
|
||||
// Check if there was an initial password defined in the .env file
|
||||
initialPassword := os.Getenv("INITIAL_PASSWORD")
|
||||
if initialPassword == "" {
|
||||
initialPassword = "password"
|
||||
} else if initialPassword[:4] == "$2a$" {
|
||||
fmt.Printf("CreateTables inital admin password is already a hash")
|
||||
} else {
|
||||
cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
|
||||
initialPassword = string(cryptText)
|
||||
}
|
||||
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', ?);", initialPassword); err != nil {
|
||||
fmt.Printf("Error adding initial admin role : '%s'", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -223,7 +234,7 @@ func CheckColumnExists(table string, column string) (bool, error) {
|
||||
for rows.Next() {
|
||||
// cols is an []interface{} of all of the column results
|
||||
cols, _ := rows.SliceScan()
|
||||
fmt.Printf("CheckColumnExists Value is '%v'\n", cols[0].(int64))
|
||||
fmt.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column)
|
||||
count = cols[0].(int64)
|
||||
|
||||
if count == 1 {
|
||||
|
@@ -12,7 +12,7 @@ type User struct {
|
||||
UserId int `db:"UserId"`
|
||||
RoleId int `db:"RoleId"`
|
||||
UserName string `db:"UserName"`
|
||||
Password string `db:"Password"`
|
||||
Password string `db:"Password" json:"-"`
|
||||
}
|
||||
|
||||
type UserRole struct {
|
||||
@@ -87,7 +87,7 @@ func GetUserByID(uid uint) (User, error) {
|
||||
var u User
|
||||
|
||||
// Query database for matching user object
|
||||
err := db.QueryRowx("SELECT * FROM users INNER JOIN roles ON users.RoleId = roles.RoleId WHERE UserId=?", uid).StructScan(&u)
|
||||
err := db.QueryRowx("SELECT * FROM users WHERE UserId=?", uid).StructScan(&u)
|
||||
if err != nil {
|
||||
return u, errors.New("user not found")
|
||||
}
|
||||
@@ -96,12 +96,25 @@ func GetUserByID(uid uint) (User, error) {
|
||||
return u, errors.New("User not found!")
|
||||
}
|
||||
*/
|
||||
u.PrepareGive()
|
||||
//u.PrepareGive()
|
||||
|
||||
return u, nil
|
||||
|
||||
}
|
||||
|
||||
func GetUserByName(username string) (User, error) {
|
||||
|
||||
var u User
|
||||
|
||||
// Query database for matching user object
|
||||
err := db.QueryRowx("SELECT * FROM users WHERE UserName=?", username).StructScan(&u)
|
||||
if err != nil {
|
||||
return u, errors.New("user not found")
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func GetUserRoleByID(uid uint) (UserRole, error) {
|
||||
|
||||
var ur UserRole
|
||||
@@ -118,6 +131,8 @@ func GetUserRoleByID(uid uint) (UserRole, error) {
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
func (u *User) PrepareGive() {
|
||||
u.Password = ""
|
||||
}
|
||||
*/
|
||||
|
@@ -62,6 +62,7 @@ func ExtractTokenID(c *gin.Context) (uint, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
// Why return the secret??
|
||||
return []byte(os.Getenv("API_SECRET")), nil
|
||||
})
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user