working registration
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
api\ tests.txt
|
||||
ccsecrets
|
||||
ccsecrets.db
|
@@ -1,11 +1,51 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"html"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ccsecrets/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func Register(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"data": "hello from controller!"})
|
||||
type RegisterInput struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func Register(c *gin.Context) {
|
||||
var input RegisterInput
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
u := models.User{}
|
||||
u.RoleId = 1
|
||||
u.UserName = input.Username
|
||||
u.Password = input.Password
|
||||
|
||||
//turn password into hash
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"Error hashing password": err.Error()})
|
||||
return
|
||||
}
|
||||
u.Password = string(hashedPassword)
|
||||
|
||||
//remove spaces in username
|
||||
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
|
||||
|
||||
_, err = u.SaveUser()
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"Error saving user": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "registration success"})
|
||||
}
|
||||
|
@@ -1,8 +1,10 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"ccsecrets/utils"
|
||||
|
||||
@@ -30,7 +32,7 @@ const createUsers string = `
|
||||
UserId INTEGER PRIMARY KEY ASC,
|
||||
RoleId INTEGER,
|
||||
UserName VARCHAR,
|
||||
UserPass VARCHAR,
|
||||
Password VARCHAR,
|
||||
AccessToken varchar,
|
||||
FOREIGN KEY (RoleId) REFERENCES roles(RoleId)
|
||||
);
|
||||
@@ -68,10 +70,12 @@ func ConnectDatabase() {
|
||||
fmt.Printf("Connected to sqlite database file '%s'\n", sqlPath)
|
||||
}
|
||||
|
||||
//sqlx.NameMapper = func(s string) string { return s }
|
||||
|
||||
// Make sure our tables exist
|
||||
CreateTables()
|
||||
|
||||
defer db.Close()
|
||||
//defer db.Close()
|
||||
}
|
||||
|
||||
func CreateTables() {
|
||||
@@ -82,11 +86,20 @@ func CreateTables() {
|
||||
fmt.Printf("Error checking roles table : '%s'", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if _, err = db.Exec("INSERT INTO roles VALUES(1, 'Admin', false, true);"); err != nil {
|
||||
fmt.Printf("Error adding initial admin role : '%s'", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Users table
|
||||
if _, err = db.Exec(createUsers); err != nil {
|
||||
fmt.Printf("Error checking users table : '%s'", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', 'password', 'token');"); err != nil {
|
||||
fmt.Printf("Error adding initial admin role : '%s'", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Secrets table
|
||||
if _, err = db.Exec(createSecrets); err != nil {
|
||||
fmt.Printf("Error checking secrets table : '%s'", err)
|
||||
@@ -106,6 +119,52 @@ func CreateTables() {
|
||||
}
|
||||
}
|
||||
|
||||
// From https://stackoverflow.com/a/60100045
|
||||
func GenerateInsertMethod(q interface{}) (string, error) {
|
||||
if reflect.ValueOf(q).Kind() == reflect.Struct {
|
||||
query := fmt.Sprintf("INSERT INTO %s", reflect.TypeOf(q).Name())
|
||||
fieldNames := ""
|
||||
fieldValues := ""
|
||||
v := reflect.ValueOf(q)
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if i == 0 {
|
||||
fieldNames = fmt.Sprintf("%s%s", fieldNames, v.Type().Field(i).Name)
|
||||
} else {
|
||||
fieldNames = fmt.Sprintf("%s, %s", fieldNames, v.Type().Field(i).Name)
|
||||
}
|
||||
switch v.Field(i).Kind() {
|
||||
case reflect.Int:
|
||||
if i == 0 {
|
||||
fieldValues = fmt.Sprintf("%s%d", fieldValues, v.Field(i).Int())
|
||||
} else {
|
||||
fieldValues = fmt.Sprintf("%s, %d", fieldValues, v.Field(i).Int())
|
||||
}
|
||||
case reflect.String:
|
||||
if i == 0 {
|
||||
fieldValues = fmt.Sprintf("%s\"%s\"", fieldValues, v.Field(i).String())
|
||||
} else {
|
||||
fieldValues = fmt.Sprintf("%s, \"%s\"", fieldValues, v.Field(i).String())
|
||||
}
|
||||
case reflect.Bool:
|
||||
var boolSet int8
|
||||
if v.Field(i).Bool() {
|
||||
boolSet = 1
|
||||
}
|
||||
if i == 0 {
|
||||
fieldValues = fmt.Sprintf("%s%d", fieldValues, boolSet)
|
||||
} else {
|
||||
fieldValues = fmt.Sprintf("%s, %d", fieldValues, boolSet)
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unsupported type '%s'\n", v.Field(i).Kind())
|
||||
}
|
||||
}
|
||||
query = fmt.Sprintf("%s(%s) VALUES (%s)", query, fieldNames, fieldValues)
|
||||
return query, nil
|
||||
}
|
||||
return "", errors.New("SqlGenerationError")
|
||||
}
|
||||
|
||||
func CheckColumnExists(table string, column string) (bool, error) {
|
||||
var count int64
|
||||
rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM pragma_table_info('" + table + "') WHERE name='" + column + "';")
|
||||
|
@@ -1,9 +1,47 @@
|
||||
package models
|
||||
|
||||
import "fmt"
|
||||
|
||||
type User struct {
|
||||
UserId int `db:UserId`
|
||||
RoleId int `db:RoleId`
|
||||
UserName string `db:UserName`
|
||||
UserPass string `db:UserPass`
|
||||
AccessToken string `db:AccessToken`
|
||||
UserId int `db:"UserId"`
|
||||
RoleId int `db:"RoleId"`
|
||||
UserName string `db:"UserName"`
|
||||
Password string `db:"Password"`
|
||||
AccessToken string `db:"AccessToken"`
|
||||
}
|
||||
|
||||
func (u *User) SaveUser() (*User, error) {
|
||||
|
||||
var err error
|
||||
|
||||
/*
|
||||
sql, err := GenerateInsertMethod(&u)
|
||||
if err != nil {
|
||||
fmt.Printf("SaveUser error generating sql record : '%s'\n", err)
|
||||
return &User{}, err
|
||||
} else {
|
||||
fmt.Println(sql)
|
||||
}
|
||||
result, err := db.Exec(sql)
|
||||
*/
|
||||
|
||||
fmt.Printf("SaveUser received object '%v'\n", u.RoleId)
|
||||
|
||||
result, err := db.NamedExec((`INSERT INTO users (RoleId, UserName, Password, AccessToken) VALUES (:RoleId, :UserName, :Password, :AccessToken)`), u)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("SaveUser error executing sql record : '%s'\n", err)
|
||||
return &User{}, err
|
||||
} else {
|
||||
affected, _ := result.RowsAffected()
|
||||
id, _ := result.LastInsertId()
|
||||
fmt.Printf("SaveUser insert returned result id '%d' affecting %d row(s).\n", id, affected)
|
||||
}
|
||||
/*
|
||||
err = CreateUser(&u).Error
|
||||
if err != nil {
|
||||
return &User{}, err
|
||||
}
|
||||
*/
|
||||
return u, nil
|
||||
}
|
||||
|
349
utils/crypt.go
Normal file
349
utils/crypt.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"errors"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// code in this file taken from https://github.com/tredoe/osutil/blob/master/v2/userutil/crypt/sha512_crypt/sha512_crypt.go
|
||||
|
||||
var (
|
||||
ErrSaltPrefix = errors.New("invalid magic prefix")
|
||||
ErrSaltFormat = errors.New("invalid salt format")
|
||||
ErrSaltRounds = errors.New("invalid rounds")
|
||||
)
|
||||
|
||||
// Salt represents a salt.
|
||||
type Salt struct {
|
||||
MagicPrefix []byte
|
||||
|
||||
SaltLenMin int
|
||||
SaltLenMax int
|
||||
|
||||
RoundsMin int
|
||||
RoundsMax int
|
||||
RoundsDefault int
|
||||
}
|
||||
|
||||
type crypter struct{ Salt Salt }
|
||||
|
||||
var _rounds = []byte("rounds=")
|
||||
|
||||
const (
|
||||
MagicPrefix = "$6$"
|
||||
SaltLenMin = 1
|
||||
SaltLenMax = 16
|
||||
RoundsMin = 1000
|
||||
RoundsMax = 999999999
|
||||
RoundsDefault = 5000
|
||||
alphabet = "./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
)
|
||||
|
||||
// Generate generates a random salt of a given length.
|
||||
//
|
||||
// The length is set thus:
|
||||
//
|
||||
// length > SaltLenMax: length = SaltLenMax
|
||||
// length < SaltLenMin: length = SaltLenMin
|
||||
func (s *Salt) Generate(length int) []byte {
|
||||
if length > s.SaltLenMax {
|
||||
length = s.SaltLenMax
|
||||
} else if length < s.SaltLenMin {
|
||||
length = s.SaltLenMin
|
||||
}
|
||||
|
||||
saltLen := (length * 6 / 8)
|
||||
if (length*6)%8 != 0 {
|
||||
saltLen++
|
||||
}
|
||||
salt := make([]byte, saltLen)
|
||||
rand.Read(salt)
|
||||
|
||||
out := make([]byte, len(s.MagicPrefix)+length)
|
||||
copy(out, s.MagicPrefix)
|
||||
copy(out[len(s.MagicPrefix):], Base64_24Bit(salt))
|
||||
return out
|
||||
}
|
||||
|
||||
// GenerateWRounds creates a random salt with the random bytes being of the
|
||||
// length provided, and the rounds parameter set as specified.
|
||||
//
|
||||
// The parameters are set thus:
|
||||
//
|
||||
// length > SaltLenMax: length = SaltLenMax
|
||||
// length < SaltLenMin: length = SaltLenMin
|
||||
//
|
||||
// rounds < 0: rounds = RoundsDefault
|
||||
// rounds < RoundsMin: rounds = RoundsMin
|
||||
// rounds > RoundsMax: rounds = RoundsMax
|
||||
//
|
||||
// If rounds is equal to RoundsDefault, then the "rounds=" part of the salt is
|
||||
// removed.
|
||||
func (s *Salt) GenerateWRounds(length, rounds int) []byte {
|
||||
if length > s.SaltLenMax {
|
||||
length = s.SaltLenMax
|
||||
} else if length < s.SaltLenMin {
|
||||
length = s.SaltLenMin
|
||||
}
|
||||
if rounds < 0 {
|
||||
rounds = s.RoundsDefault
|
||||
} else if rounds < s.RoundsMin {
|
||||
rounds = s.RoundsMin
|
||||
} else if rounds > s.RoundsMax {
|
||||
rounds = s.RoundsMax
|
||||
}
|
||||
|
||||
//fmt.Printf("GenerateWRounds length is %d and rounds is %d\n", length, rounds)
|
||||
|
||||
saltLen := (length * 6 / 8)
|
||||
if (length*6)%8 != 0 {
|
||||
saltLen++
|
||||
}
|
||||
salt := make([]byte, saltLen)
|
||||
rand.Read(salt)
|
||||
|
||||
roundsText := ""
|
||||
if rounds != s.RoundsDefault {
|
||||
roundsText = "rounds=" + strconv.Itoa(rounds) + "$"
|
||||
}
|
||||
|
||||
out := make([]byte, len(s.MagicPrefix)+len(roundsText)+length)
|
||||
copy(out, s.MagicPrefix)
|
||||
//fmt.Printf("GenerateWRounds copy 1 : '%v'\n", out)
|
||||
copy(out[len(s.MagicPrefix):], []byte(roundsText))
|
||||
//fmt.Printf("GenerateWRounds copy 2 : '%v'\n", out)
|
||||
copy(out[len(s.MagicPrefix)+len(roundsText):], Base64_24Bit(salt))
|
||||
//fmt.Printf("GenerateWRounds copy 3 : '%v'\n", out)
|
||||
return out
|
||||
}
|
||||
|
||||
func Base64_24Bit(src []byte) (hash []byte) {
|
||||
if len(src) == 0 {
|
||||
return []byte{} // TODO: return nil
|
||||
}
|
||||
|
||||
hashSize := (len(src) * 8) / 6
|
||||
if (len(src) % 6) != 0 {
|
||||
hashSize += 1
|
||||
}
|
||||
hash = make([]byte, hashSize)
|
||||
|
||||
dst := hash
|
||||
for len(src) > 0 {
|
||||
switch len(src) {
|
||||
default:
|
||||
dst[0] = alphabet[src[0]&0x3f]
|
||||
dst[1] = alphabet[((src[0]>>6)|(src[1]<<2))&0x3f]
|
||||
dst[2] = alphabet[((src[1]>>4)|(src[2]<<4))&0x3f]
|
||||
dst[3] = alphabet[(src[2]>>2)&0x3f]
|
||||
src = src[3:]
|
||||
dst = dst[4:]
|
||||
case 2:
|
||||
dst[0] = alphabet[src[0]&0x3f]
|
||||
dst[1] = alphabet[((src[0]>>6)|(src[1]<<2))&0x3f]
|
||||
dst[2] = alphabet[(src[1]>>4)&0x3f]
|
||||
src = src[2:]
|
||||
dst = dst[3:]
|
||||
case 1:
|
||||
dst[0] = alphabet[src[0]&0x3f]
|
||||
dst[1] = alphabet[(src[0]>>6)&0x3f]
|
||||
src = src[1:]
|
||||
dst = dst[2:]
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func Generate(key, salt []byte) (string, error) {
|
||||
var rounds int
|
||||
var isRoundsDef bool
|
||||
|
||||
var c crypter
|
||||
c.Salt = GetSalt()
|
||||
|
||||
if len(salt) == 0 {
|
||||
salt = c.Salt.GenerateWRounds(SaltLenMax, RoundsDefault)
|
||||
//fmt.Printf("Generate created salt with value '%v'\n", salt)
|
||||
}
|
||||
if !bytes.HasPrefix(salt, c.Salt.MagicPrefix) {
|
||||
//fmt.Printf("Generate salt '%v' has no magic prefix\n", salt)
|
||||
return "", ErrSaltPrefix
|
||||
}
|
||||
|
||||
saltToks := bytes.Split(salt, []byte{'$'})
|
||||
if len(saltToks) < 3 {
|
||||
return "", ErrSaltFormat
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(saltToks[2], _rounds) {
|
||||
isRoundsDef = true
|
||||
pr, err := strconv.ParseInt(string(saltToks[2][7:]), 10, 32)
|
||||
if err != nil {
|
||||
return "", ErrSaltRounds
|
||||
}
|
||||
rounds = int(pr)
|
||||
if rounds < RoundsMin {
|
||||
rounds = RoundsMin
|
||||
} else if rounds > RoundsMax {
|
||||
rounds = RoundsMax
|
||||
}
|
||||
salt = saltToks[3]
|
||||
} else {
|
||||
rounds = RoundsDefault
|
||||
salt = saltToks[2]
|
||||
}
|
||||
|
||||
if len(salt) > SaltLenMax {
|
||||
salt = salt[0:SaltLenMax]
|
||||
}
|
||||
|
||||
// Compute alternate SHA512 sum with input KEY, SALT, and KEY.
|
||||
Alternate := sha512.New()
|
||||
Alternate.Write(key)
|
||||
Alternate.Write(salt)
|
||||
Alternate.Write(key)
|
||||
AlternateSum := Alternate.Sum(nil) // 64 bytes
|
||||
|
||||
A := sha512.New()
|
||||
A.Write(key)
|
||||
A.Write(salt)
|
||||
// Add for any character in the key one byte of the alternate sum.
|
||||
i := len(key)
|
||||
for ; i > 64; i -= 64 {
|
||||
A.Write(AlternateSum)
|
||||
}
|
||||
A.Write(AlternateSum[0:i])
|
||||
|
||||
// Take the binary representation of the length of the key and for every add
|
||||
// the alternate sum, for every 0 the key.
|
||||
for i = len(key); i > 0; i >>= 1 {
|
||||
if (i & 1) != 0 {
|
||||
A.Write(AlternateSum)
|
||||
} else {
|
||||
A.Write(key)
|
||||
}
|
||||
}
|
||||
Asum := A.Sum(nil)
|
||||
|
||||
// Start computation of P byte sequence.
|
||||
P := sha512.New()
|
||||
// For every character in the password add the entire password.
|
||||
for i = 0; i < len(key); i++ {
|
||||
P.Write(key)
|
||||
}
|
||||
Psum := P.Sum(nil)
|
||||
// Create byte sequence P.
|
||||
Pseq := make([]byte, 0, len(key))
|
||||
for i = len(key); i > 64; i -= 64 {
|
||||
Pseq = append(Pseq, Psum...)
|
||||
}
|
||||
Pseq = append(Pseq, Psum[0:i]...)
|
||||
|
||||
// Start computation of S byte sequence.
|
||||
S := sha512.New()
|
||||
for i = 0; i < (16 + int(Asum[0])); i++ {
|
||||
S.Write(salt)
|
||||
}
|
||||
Ssum := S.Sum(nil)
|
||||
// Create byte sequence S.
|
||||
Sseq := make([]byte, 0, len(salt))
|
||||
for i = len(salt); i > 64; i -= 64 {
|
||||
Sseq = append(Sseq, Ssum...)
|
||||
}
|
||||
Sseq = append(Sseq, Ssum[0:i]...)
|
||||
|
||||
Csum := Asum
|
||||
|
||||
// Repeatedly run the collected hash value through SHA512 to burn CPU cycles.
|
||||
for i = 0; i < rounds; i++ {
|
||||
C := sha512.New()
|
||||
|
||||
// Add key or last result.
|
||||
if (i & 1) != 0 {
|
||||
C.Write(Pseq)
|
||||
} else {
|
||||
C.Write(Csum)
|
||||
}
|
||||
// Add salt for numbers not divisible by 3.
|
||||
if (i % 3) != 0 {
|
||||
C.Write(Sseq)
|
||||
}
|
||||
// Add key for numbers not divisible by 7.
|
||||
if (i % 7) != 0 {
|
||||
C.Write(Pseq)
|
||||
}
|
||||
// Add key or last result.
|
||||
if (i & 1) != 0 {
|
||||
C.Write(Csum)
|
||||
} else {
|
||||
C.Write(Pseq)
|
||||
}
|
||||
|
||||
Csum = C.Sum(nil)
|
||||
}
|
||||
|
||||
out := make([]byte, 0, 123)
|
||||
out = append(out, MagicPrefix...)
|
||||
if isRoundsDef {
|
||||
out = append(out, []byte("rounds="+strconv.Itoa(rounds)+"$")...)
|
||||
}
|
||||
out = append(out, salt...)
|
||||
out = append(out, '$')
|
||||
out = append(out, Base64_24Bit([]byte{
|
||||
Csum[42], Csum[21], Csum[0],
|
||||
Csum[1], Csum[43], Csum[22],
|
||||
Csum[23], Csum[2], Csum[44],
|
||||
Csum[45], Csum[24], Csum[3],
|
||||
Csum[4], Csum[46], Csum[25],
|
||||
Csum[26], Csum[5], Csum[47],
|
||||
Csum[48], Csum[27], Csum[6],
|
||||
Csum[7], Csum[49], Csum[28],
|
||||
Csum[29], Csum[8], Csum[50],
|
||||
Csum[51], Csum[30], Csum[9],
|
||||
Csum[10], Csum[52], Csum[31],
|
||||
Csum[32], Csum[11], Csum[53],
|
||||
Csum[54], Csum[33], Csum[12],
|
||||
Csum[13], Csum[55], Csum[34],
|
||||
Csum[35], Csum[14], Csum[56],
|
||||
Csum[57], Csum[36], Csum[15],
|
||||
Csum[16], Csum[58], Csum[37],
|
||||
Csum[38], Csum[17], Csum[59],
|
||||
Csum[60], Csum[39], Csum[18],
|
||||
Csum[19], Csum[61], Csum[40],
|
||||
Csum[41], Csum[20], Csum[62],
|
||||
Csum[63],
|
||||
})...)
|
||||
|
||||
// Clean sensitive data.
|
||||
A.Reset()
|
||||
Alternate.Reset()
|
||||
P.Reset()
|
||||
for i = 0; i < len(Asum); i++ {
|
||||
Asum[i] = 0
|
||||
}
|
||||
for i = 0; i < len(AlternateSum); i++ {
|
||||
AlternateSum[i] = 0
|
||||
}
|
||||
for i = 0; i < len(Pseq); i++ {
|
||||
Pseq[i] = 0
|
||||
}
|
||||
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func (c *crypter) SetSalt(salt Salt) { c.Salt = salt }
|
||||
|
||||
func GetSalt() Salt {
|
||||
return Salt{
|
||||
MagicPrefix: []byte(MagicPrefix),
|
||||
SaltLenMin: SaltLenMin,
|
||||
SaltLenMax: SaltLenMax,
|
||||
RoundsDefault: RoundsDefault,
|
||||
RoundsMin: RoundsMin,
|
||||
RoundsMax: RoundsMax,
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user