improve checks
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,4 +1,4 @@
|
|||||||
api\ tests.txt
|
api\ tests.txt
|
||||||
ccsecrets
|
ccsecrets
|
||||||
ccsecrets.db
|
ccsecrets.*
|
||||||
.env
|
.env
|
@@ -36,6 +36,19 @@ func Register(c *gin.Context) {
|
|||||||
u.UserName = input.Username
|
u.UserName = input.Username
|
||||||
u.Password = input.Password
|
u.Password = input.Password
|
||||||
|
|
||||||
|
//remove spaces in username
|
||||||
|
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
|
||||||
|
|
||||||
|
// Check if user already exists
|
||||||
|
testUser, _ := models.GetUserByName(u.UserName)
|
||||||
|
fmt.Printf("Register checking if user already exists : '%v'\n", testUser)
|
||||||
|
if (models.User{} == testUser) {
|
||||||
|
fmt.Printf("Register confirmed no existing username\n")
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Attempt to register conflicting username"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
//turn password into hash
|
//turn password into hash
|
||||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -46,9 +59,6 @@ func Register(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
u.Password = string(hashedPassword)
|
u.Password = string(hashedPassword)
|
||||||
|
|
||||||
//remove spaces in username
|
|
||||||
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
|
|
||||||
|
|
||||||
_, err = u.SaveUser()
|
_, err = u.SaveUser()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -3,6 +3,7 @@ package controllers
|
|||||||
import (
|
import (
|
||||||
"ccsecrets/models"
|
"ccsecrets/models"
|
||||||
"ccsecrets/utils/token"
|
"ccsecrets/utils/token"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@@ -17,6 +18,52 @@ type RetrieveInput struct {
|
|||||||
func RetrieveSecret(c *gin.Context) {
|
func RetrieveSecret(c *gin.Context) {
|
||||||
var input RetrieveInput
|
var input RetrieveInput
|
||||||
|
|
||||||
|
// Validate the input matches our struct
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("RetrieveSecret received JSON input '%v'\n", input)
|
||||||
|
|
||||||
|
// Get the user and role id of the requestor
|
||||||
|
user_id, err := token.ExtractTokenID(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := models.GetUserRoleByID(user_id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate fields
|
||||||
|
s := models.Secret{}
|
||||||
|
s.RoleId = u.RoleId
|
||||||
|
s.DeviceName = input.DeviceName
|
||||||
|
s.DeviceCategory = input.DeviceCategory
|
||||||
|
|
||||||
|
results, err := models.GetSecrets(&s)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) > 1 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": errors.New("found multiple matching secrets, use retrieveMultiple instead")})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// output results as json
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "success", "data": results})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RetrieveMultpleSecrets(c *gin.Context) {
|
||||||
|
var input RetrieveInput
|
||||||
|
|
||||||
// Validate the input matches our struct
|
// Validate the input matches our struct
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
@@ -41,6 +41,19 @@ func StoreSecret(c *gin.Context) {
|
|||||||
s.RoleId = 1
|
s.RoleId = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If this secret already exists in the database then generate an error
|
||||||
|
checkExists, err := models.GetSecrets(&s)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(checkExists) > 0 {
|
||||||
|
fmt.Printf("StoreSecret not storing secret with '%d' already matching secrets.\n", len(checkExists))
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "StoreSecret attempting to store secret already defined. API calls for update/delete don't yet exist"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Encrypt secret
|
// Encrypt secret
|
||||||
s.Secret = input.SecretValue
|
s.Secret = input.SecretValue
|
||||||
_, err = s.EncryptSecret()
|
_, err = s.EncryptSecret()
|
||||||
@@ -49,14 +62,6 @@ func StoreSecret(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is just here for testing to make sure that decryption works
|
|
||||||
/*
|
|
||||||
_, err = s.DecryptSecret()
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"Error decrypting secret": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
_, err = s.SaveSecret()
|
_, err = s.SaveSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"Error saving secret": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"Error saving secret": err.Error()})
|
||||||
|
55
main.go
55
main.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -18,7 +19,22 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// For build numbers, from https://blog.kowalczyk.info/article/vEja/embedding-build-number-in-go-executable.html
|
||||||
|
var sha1ver string // sha1 revision used to build the program
|
||||||
|
var buildTime string // when the executable was built
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
// Open connection to logfile
|
||||||
|
// From https://ispycode.com/GO/Logging/Logging-to-multiple-destinations
|
||||||
|
logFile := os.Getenv("LOG_FILE")
|
||||||
|
if logFile == "" {
|
||||||
|
logFile = "./ccsecrets.log"
|
||||||
|
}
|
||||||
|
logfileWriter, err := os.OpenFile(logFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Unable to write logfile", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
// Initiate connection to sqlite and make sure our schema is up to date
|
// Initiate connection to sqlite and make sure our schema is up to date
|
||||||
models.ConnectDatabase()
|
models.ConnectDatabase()
|
||||||
@@ -27,10 +43,22 @@ func main() {
|
|||||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
defer stop()
|
defer stop()
|
||||||
|
|
||||||
router := gin.Default()
|
// Creates a router without any middleware by default
|
||||||
|
router := gin.New()
|
||||||
|
|
||||||
|
// Global middleware
|
||||||
|
// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
|
||||||
|
// By default gin.DefaultWriter = os.Stdout
|
||||||
|
gin.DefaultWriter = io.MultiWriter(logfileWriter, os.Stdout)
|
||||||
|
router.Use(gin.Logger())
|
||||||
|
|
||||||
|
// Recovery middleware recovers from any panics and writes a 500 if there was one.
|
||||||
|
router.Use(gin.Recovery())
|
||||||
|
|
||||||
|
// TODO - think of a better default landing page
|
||||||
router.GET("/", func(c *gin.Context) {
|
router.GET("/", func(c *gin.Context) {
|
||||||
//time.Sleep(10 * time.Second)
|
//time.Sleep(10 * time.Second)
|
||||||
c.String(http.StatusOK, "Hello World.")
|
c.String(http.StatusOK, fmt.Sprintf("Built on %s from sha1 %s\n", buildTime, sha1ver))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Set some options for TLS
|
// Set some options for TLS
|
||||||
@@ -90,6 +118,7 @@ func main() {
|
|||||||
protected := router.Group("/api/secret")
|
protected := router.Group("/api/secret")
|
||||||
protected.Use(middlewares.JwtAuthMiddleware())
|
protected.Use(middlewares.JwtAuthMiddleware())
|
||||||
protected.GET("/retrieve", controllers.RetrieveSecret)
|
protected.GET("/retrieve", controllers.RetrieveSecret)
|
||||||
|
protected.GET("/retrieveMultiple", controllers.RetrieveMultpleSecrets)
|
||||||
protected.POST("/store", controllers.StoreSecret)
|
protected.POST("/store", controllers.StoreSecret)
|
||||||
|
|
||||||
// Initializing the server in a goroutine so that
|
// Initializing the server in a goroutine so that
|
||||||
@@ -118,26 +147,4 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Server exiting")
|
log.Println("Server exiting")
|
||||||
|
|
||||||
/*
|
|
||||||
r := gin.Default()
|
|
||||||
|
|
||||||
// Define our routes underneath /api
|
|
||||||
public := r.Group("/api")
|
|
||||||
public.POST("/register", controllers.Register)
|
|
||||||
public.POST("/login", controllers.Login)
|
|
||||||
|
|
||||||
// This is just PoC really, we can get rid of it
|
|
||||||
//protected := r.Group("/api/admin")
|
|
||||||
//protected.Use(middlewares.JwtAuthMiddleware())
|
|
||||||
//protected.GET("/user", controllers.CurrentUser)
|
|
||||||
|
|
||||||
// Get secrets
|
|
||||||
protected := r.Group("/api/secret")
|
|
||||||
protected.Use(middlewares.JwtAuthMiddleware())
|
|
||||||
protected.GET("/device", controllers.CurrentUser)
|
|
||||||
|
|
||||||
r.Run(":8443")
|
|
||||||
*/
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -44,6 +44,7 @@ func (s *Secret) SaveSecret() (*Secret, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns all matching secrets, up to caller to determine how to deal with multiple results
|
||||||
func GetSecrets(s *Secret) ([]Secret, error) {
|
func GetSecrets(s *Secret) ([]Secret, error) {
|
||||||
var err error
|
var err error
|
||||||
var rows *sqlx.Rows
|
var rows *sqlx.Rows
|
||||||
@@ -52,6 +53,7 @@ func GetSecrets(s *Secret) ([]Secret, error) {
|
|||||||
fmt.Printf("GetSecret querying values '%v'\n", s)
|
fmt.Printf("GetSecret querying values '%v'\n", s)
|
||||||
|
|
||||||
// Determine whether to query for a specific device or a category of devices
|
// Determine whether to query for a specific device or a category of devices
|
||||||
|
// Prefer querying device name than category
|
||||||
if s.DeviceName != "" {
|
if s.DeviceName != "" {
|
||||||
rows, err = db.Queryx("SELECT * FROM secrets WHERE DeviceName LIKE ? AND RoleId = ?", s.DeviceName, s.RoleId)
|
rows, err = db.Queryx("SELECT * FROM secrets WHERE DeviceName LIKE ? AND RoleId = ?", s.DeviceName, s.RoleId)
|
||||||
} else if s.DeviceCategory != "" {
|
} else if s.DeviceCategory != "" {
|
||||||
@@ -62,8 +64,6 @@ func GetSecrets(s *Secret) ([]Secret, error) {
|
|||||||
return secretResults, err
|
return secretResults, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - do we want to generate an error if the query returns more than one result?
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("GetSecret error executing sql record : '%s'\n", err)
|
fmt.Printf("GetSecret error executing sql record : '%s'\n", err)
|
||||||
return secretResults, err
|
return secretResults, err
|
||||||
|
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -124,7 +125,17 @@ func CreateTables() {
|
|||||||
}
|
}
|
||||||
rowCount, _ = CheckCount("users")
|
rowCount, _ = CheckCount("users")
|
||||||
if rowCount == 0 {
|
if rowCount == 0 {
|
||||||
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', '$2a$10$k1qldm.bWqZsQWrKPdahR.Pfz5LxkMUka2.8INEeSD7euzkiznIR.');"); err != nil {
|
// Check if there was an initial password defined in the .env file
|
||||||
|
initialPassword := os.Getenv("INITIAL_PASSWORD")
|
||||||
|
if initialPassword == "" {
|
||||||
|
initialPassword = "password"
|
||||||
|
} else if initialPassword[:4] == "$2a$" {
|
||||||
|
fmt.Printf("CreateTables inital admin password is already a hash")
|
||||||
|
} else {
|
||||||
|
cryptText, _ := bcrypt.GenerateFromPassword([]byte(initialPassword), bcrypt.DefaultCost)
|
||||||
|
initialPassword = string(cryptText)
|
||||||
|
}
|
||||||
|
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', ?);", initialPassword); err != nil {
|
||||||
fmt.Printf("Error adding initial admin role : '%s'", err)
|
fmt.Printf("Error adding initial admin role : '%s'", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@@ -223,7 +234,7 @@ func CheckColumnExists(table string, column string) (bool, error) {
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
// cols is an []interface{} of all of the column results
|
// cols is an []interface{} of all of the column results
|
||||||
cols, _ := rows.SliceScan()
|
cols, _ := rows.SliceScan()
|
||||||
fmt.Printf("CheckColumnExists Value is '%v'\n", cols[0].(int64))
|
fmt.Printf("CheckColumnExists Value is '%v' for table '%s' and column '%s'\n", cols[0].(int64), table, column)
|
||||||
count = cols[0].(int64)
|
count = cols[0].(int64)
|
||||||
|
|
||||||
if count == 1 {
|
if count == 1 {
|
||||||
|
@@ -12,7 +12,7 @@ type User struct {
|
|||||||
UserId int `db:"UserId"`
|
UserId int `db:"UserId"`
|
||||||
RoleId int `db:"RoleId"`
|
RoleId int `db:"RoleId"`
|
||||||
UserName string `db:"UserName"`
|
UserName string `db:"UserName"`
|
||||||
Password string `db:"Password"`
|
Password string `db:"Password" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRole struct {
|
type UserRole struct {
|
||||||
@@ -87,7 +87,7 @@ func GetUserByID(uid uint) (User, error) {
|
|||||||
var u User
|
var u User
|
||||||
|
|
||||||
// Query database for matching user object
|
// Query database for matching user object
|
||||||
err := db.QueryRowx("SELECT * FROM users INNER JOIN roles ON users.RoleId = roles.RoleId WHERE UserId=?", uid).StructScan(&u)
|
err := db.QueryRowx("SELECT * FROM users WHERE UserId=?", uid).StructScan(&u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return u, errors.New("user not found")
|
return u, errors.New("user not found")
|
||||||
}
|
}
|
||||||
@@ -96,12 +96,25 @@ func GetUserByID(uid uint) (User, error) {
|
|||||||
return u, errors.New("User not found!")
|
return u, errors.New("User not found!")
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
u.PrepareGive()
|
//u.PrepareGive()
|
||||||
|
|
||||||
return u, nil
|
return u, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetUserByName(username string) (User, error) {
|
||||||
|
|
||||||
|
var u User
|
||||||
|
|
||||||
|
// Query database for matching user object
|
||||||
|
err := db.QueryRowx("SELECT * FROM users WHERE UserName=?", username).StructScan(&u)
|
||||||
|
if err != nil {
|
||||||
|
return u, errors.New("user not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
func GetUserRoleByID(uid uint) (UserRole, error) {
|
func GetUserRoleByID(uid uint) (UserRole, error) {
|
||||||
|
|
||||||
var ur UserRole
|
var ur UserRole
|
||||||
@@ -118,6 +131,8 @@ func GetUserRoleByID(uid uint) (UserRole, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
func (u *User) PrepareGive() {
|
func (u *User) PrepareGive() {
|
||||||
u.Password = ""
|
u.Password = ""
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
@@ -62,6 +62,7 @@ func ExtractTokenID(c *gin.Context) (uint, error) {
|
|||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
}
|
}
|
||||||
|
// Why return the secret??
|
||||||
return []byte(os.Getenv("API_SECRET")), nil
|
return []byte(os.Getenv("API_SECRET")), nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Reference in New Issue
Block a user