working registration

This commit is contained in:
2023-03-28 23:00:18 +11:00
parent 8dc02a98bd
commit 7495a341cd
5 changed files with 498 additions and 9 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
api\ tests.txt
ccsecrets
ccsecrets.db

View File

@@ -1,11 +1,51 @@
package controllers package controllers
import ( import (
"html"
"net/http" "net/http"
"strings"
"ccsecrets/models"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
) )
func Register(c *gin.Context) { type RegisterInput struct {
c.JSON(http.StatusOK, gin.H{"data": "hello from controller!"}) Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
func Register(c *gin.Context) {
var input RegisterInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
u := models.User{}
u.RoleId = 1
u.UserName = input.Username
u.Password = input.Password
//turn password into hash
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"Error hashing password": err.Error()})
return
}
u.Password = string(hashedPassword)
//remove spaces in username
u.UserName = html.EscapeString(strings.TrimSpace(u.UserName))
_, err = u.SaveUser()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"Error saving user": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "registration success"})
} }

View File

@@ -1,8 +1,10 @@
package models package models
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"reflect"
"ccsecrets/utils" "ccsecrets/utils"
@@ -30,7 +32,7 @@ const createUsers string = `
UserId INTEGER PRIMARY KEY ASC, UserId INTEGER PRIMARY KEY ASC,
RoleId INTEGER, RoleId INTEGER,
UserName VARCHAR, UserName VARCHAR,
UserPass VARCHAR, Password VARCHAR,
AccessToken varchar, AccessToken varchar,
FOREIGN KEY (RoleId) REFERENCES roles(RoleId) FOREIGN KEY (RoleId) REFERENCES roles(RoleId)
); );
@@ -68,10 +70,12 @@ func ConnectDatabase() {
fmt.Printf("Connected to sqlite database file '%s'\n", sqlPath) fmt.Printf("Connected to sqlite database file '%s'\n", sqlPath)
} }
//sqlx.NameMapper = func(s string) string { return s }
// Make sure our tables exist // Make sure our tables exist
CreateTables() CreateTables()
defer db.Close() //defer db.Close()
} }
func CreateTables() { func CreateTables() {
@@ -82,11 +86,20 @@ func CreateTables() {
fmt.Printf("Error checking roles table : '%s'", err) fmt.Printf("Error checking roles table : '%s'", err)
os.Exit(1) os.Exit(1)
} }
if _, err = db.Exec("INSERT INTO roles VALUES(1, 'Admin', false, true);"); err != nil {
fmt.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
// Users table // Users table
if _, err = db.Exec(createUsers); err != nil { if _, err = db.Exec(createUsers); err != nil {
fmt.Printf("Error checking users table : '%s'", err) fmt.Printf("Error checking users table : '%s'", err)
os.Exit(1) os.Exit(1)
} }
if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', 'password', 'token');"); err != nil {
fmt.Printf("Error adding initial admin role : '%s'", err)
os.Exit(1)
}
// Secrets table // Secrets table
if _, err = db.Exec(createSecrets); err != nil { if _, err = db.Exec(createSecrets); err != nil {
fmt.Printf("Error checking secrets table : '%s'", err) fmt.Printf("Error checking secrets table : '%s'", err)
@@ -106,6 +119,52 @@ func CreateTables() {
} }
} }
// From https://stackoverflow.com/a/60100045
func GenerateInsertMethod(q interface{}) (string, error) {
if reflect.ValueOf(q).Kind() == reflect.Struct {
query := fmt.Sprintf("INSERT INTO %s", reflect.TypeOf(q).Name())
fieldNames := ""
fieldValues := ""
v := reflect.ValueOf(q)
for i := 0; i < v.NumField(); i++ {
if i == 0 {
fieldNames = fmt.Sprintf("%s%s", fieldNames, v.Type().Field(i).Name)
} else {
fieldNames = fmt.Sprintf("%s, %s", fieldNames, v.Type().Field(i).Name)
}
switch v.Field(i).Kind() {
case reflect.Int:
if i == 0 {
fieldValues = fmt.Sprintf("%s%d", fieldValues, v.Field(i).Int())
} else {
fieldValues = fmt.Sprintf("%s, %d", fieldValues, v.Field(i).Int())
}
case reflect.String:
if i == 0 {
fieldValues = fmt.Sprintf("%s\"%s\"", fieldValues, v.Field(i).String())
} else {
fieldValues = fmt.Sprintf("%s, \"%s\"", fieldValues, v.Field(i).String())
}
case reflect.Bool:
var boolSet int8
if v.Field(i).Bool() {
boolSet = 1
}
if i == 0 {
fieldValues = fmt.Sprintf("%s%d", fieldValues, boolSet)
} else {
fieldValues = fmt.Sprintf("%s, %d", fieldValues, boolSet)
}
default:
fmt.Printf("Unsupported type '%s'\n", v.Field(i).Kind())
}
}
query = fmt.Sprintf("%s(%s) VALUES (%s)", query, fieldNames, fieldValues)
return query, nil
}
return "", errors.New("SqlGenerationError")
}
func CheckColumnExists(table string, column string) (bool, error) { func CheckColumnExists(table string, column string) (bool, error) {
var count int64 var count int64
rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM pragma_table_info('" + table + "') WHERE name='" + column + "';") rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM pragma_table_info('" + table + "') WHERE name='" + column + "';")

View File

@@ -1,9 +1,47 @@
package models package models
import "fmt"
type User struct { type User struct {
UserId int `db:UserId` UserId int `db:"UserId"`
RoleId int `db:RoleId` RoleId int `db:"RoleId"`
UserName string `db:UserName` UserName string `db:"UserName"`
UserPass string `db:UserPass` Password string `db:"Password"`
AccessToken string `db:AccessToken` AccessToken string `db:"AccessToken"`
}
func (u *User) SaveUser() (*User, error) {
var err error
/*
sql, err := GenerateInsertMethod(&u)
if err != nil {
fmt.Printf("SaveUser error generating sql record : '%s'\n", err)
return &User{}, err
} else {
fmt.Println(sql)
}
result, err := db.Exec(sql)
*/
fmt.Printf("SaveUser received object '%v'\n", u.RoleId)
result, err := db.NamedExec((`INSERT INTO users (RoleId, UserName, Password, AccessToken) VALUES (:RoleId, :UserName, :Password, :AccessToken)`), u)
if err != nil {
fmt.Printf("SaveUser error executing sql record : '%s'\n", err)
return &User{}, err
} else {
affected, _ := result.RowsAffected()
id, _ := result.LastInsertId()
fmt.Printf("SaveUser insert returned result id '%d' affecting %d row(s).\n", id, affected)
}
/*
err = CreateUser(&u).Error
if err != nil {
return &User{}, err
}
*/
return u, nil
} }

349
utils/crypt.go Normal file
View File

@@ -0,0 +1,349 @@
package utils
import (
"bytes"
"crypto/rand"
"crypto/sha512"
"errors"
"strconv"
)
// code in this file taken from https://github.com/tredoe/osutil/blob/master/v2/userutil/crypt/sha512_crypt/sha512_crypt.go
var (
ErrSaltPrefix = errors.New("invalid magic prefix")
ErrSaltFormat = errors.New("invalid salt format")
ErrSaltRounds = errors.New("invalid rounds")
)
// Salt represents a salt.
type Salt struct {
MagicPrefix []byte
SaltLenMin int
SaltLenMax int
RoundsMin int
RoundsMax int
RoundsDefault int
}
type crypter struct{ Salt Salt }
var _rounds = []byte("rounds=")
const (
MagicPrefix = "$6$"
SaltLenMin = 1
SaltLenMax = 16
RoundsMin = 1000
RoundsMax = 999999999
RoundsDefault = 5000
alphabet = "./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
)
// Generate generates a random salt of a given length.
//
// The length is set thus:
//
// length > SaltLenMax: length = SaltLenMax
// length < SaltLenMin: length = SaltLenMin
func (s *Salt) Generate(length int) []byte {
if length > s.SaltLenMax {
length = s.SaltLenMax
} else if length < s.SaltLenMin {
length = s.SaltLenMin
}
saltLen := (length * 6 / 8)
if (length*6)%8 != 0 {
saltLen++
}
salt := make([]byte, saltLen)
rand.Read(salt)
out := make([]byte, len(s.MagicPrefix)+length)
copy(out, s.MagicPrefix)
copy(out[len(s.MagicPrefix):], Base64_24Bit(salt))
return out
}
// GenerateWRounds creates a random salt with the random bytes being of the
// length provided, and the rounds parameter set as specified.
//
// The parameters are set thus:
//
// length > SaltLenMax: length = SaltLenMax
// length < SaltLenMin: length = SaltLenMin
//
// rounds < 0: rounds = RoundsDefault
// rounds < RoundsMin: rounds = RoundsMin
// rounds > RoundsMax: rounds = RoundsMax
//
// If rounds is equal to RoundsDefault, then the "rounds=" part of the salt is
// removed.
func (s *Salt) GenerateWRounds(length, rounds int) []byte {
if length > s.SaltLenMax {
length = s.SaltLenMax
} else if length < s.SaltLenMin {
length = s.SaltLenMin
}
if rounds < 0 {
rounds = s.RoundsDefault
} else if rounds < s.RoundsMin {
rounds = s.RoundsMin
} else if rounds > s.RoundsMax {
rounds = s.RoundsMax
}
//fmt.Printf("GenerateWRounds length is %d and rounds is %d\n", length, rounds)
saltLen := (length * 6 / 8)
if (length*6)%8 != 0 {
saltLen++
}
salt := make([]byte, saltLen)
rand.Read(salt)
roundsText := ""
if rounds != s.RoundsDefault {
roundsText = "rounds=" + strconv.Itoa(rounds) + "$"
}
out := make([]byte, len(s.MagicPrefix)+len(roundsText)+length)
copy(out, s.MagicPrefix)
//fmt.Printf("GenerateWRounds copy 1 : '%v'\n", out)
copy(out[len(s.MagicPrefix):], []byte(roundsText))
//fmt.Printf("GenerateWRounds copy 2 : '%v'\n", out)
copy(out[len(s.MagicPrefix)+len(roundsText):], Base64_24Bit(salt))
//fmt.Printf("GenerateWRounds copy 3 : '%v'\n", out)
return out
}
func Base64_24Bit(src []byte) (hash []byte) {
if len(src) == 0 {
return []byte{} // TODO: return nil
}
hashSize := (len(src) * 8) / 6
if (len(src) % 6) != 0 {
hashSize += 1
}
hash = make([]byte, hashSize)
dst := hash
for len(src) > 0 {
switch len(src) {
default:
dst[0] = alphabet[src[0]&0x3f]
dst[1] = alphabet[((src[0]>>6)|(src[1]<<2))&0x3f]
dst[2] = alphabet[((src[1]>>4)|(src[2]<<4))&0x3f]
dst[3] = alphabet[(src[2]>>2)&0x3f]
src = src[3:]
dst = dst[4:]
case 2:
dst[0] = alphabet[src[0]&0x3f]
dst[1] = alphabet[((src[0]>>6)|(src[1]<<2))&0x3f]
dst[2] = alphabet[(src[1]>>4)&0x3f]
src = src[2:]
dst = dst[3:]
case 1:
dst[0] = alphabet[src[0]&0x3f]
dst[1] = alphabet[(src[0]>>6)&0x3f]
src = src[1:]
dst = dst[2:]
}
}
return
}
func Generate(key, salt []byte) (string, error) {
var rounds int
var isRoundsDef bool
var c crypter
c.Salt = GetSalt()
if len(salt) == 0 {
salt = c.Salt.GenerateWRounds(SaltLenMax, RoundsDefault)
//fmt.Printf("Generate created salt with value '%v'\n", salt)
}
if !bytes.HasPrefix(salt, c.Salt.MagicPrefix) {
//fmt.Printf("Generate salt '%v' has no magic prefix\n", salt)
return "", ErrSaltPrefix
}
saltToks := bytes.Split(salt, []byte{'$'})
if len(saltToks) < 3 {
return "", ErrSaltFormat
}
if bytes.HasPrefix(saltToks[2], _rounds) {
isRoundsDef = true
pr, err := strconv.ParseInt(string(saltToks[2][7:]), 10, 32)
if err != nil {
return "", ErrSaltRounds
}
rounds = int(pr)
if rounds < RoundsMin {
rounds = RoundsMin
} else if rounds > RoundsMax {
rounds = RoundsMax
}
salt = saltToks[3]
} else {
rounds = RoundsDefault
salt = saltToks[2]
}
if len(salt) > SaltLenMax {
salt = salt[0:SaltLenMax]
}
// Compute alternate SHA512 sum with input KEY, SALT, and KEY.
Alternate := sha512.New()
Alternate.Write(key)
Alternate.Write(salt)
Alternate.Write(key)
AlternateSum := Alternate.Sum(nil) // 64 bytes
A := sha512.New()
A.Write(key)
A.Write(salt)
// Add for any character in the key one byte of the alternate sum.
i := len(key)
for ; i > 64; i -= 64 {
A.Write(AlternateSum)
}
A.Write(AlternateSum[0:i])
// Take the binary representation of the length of the key and for every add
// the alternate sum, for every 0 the key.
for i = len(key); i > 0; i >>= 1 {
if (i & 1) != 0 {
A.Write(AlternateSum)
} else {
A.Write(key)
}
}
Asum := A.Sum(nil)
// Start computation of P byte sequence.
P := sha512.New()
// For every character in the password add the entire password.
for i = 0; i < len(key); i++ {
P.Write(key)
}
Psum := P.Sum(nil)
// Create byte sequence P.
Pseq := make([]byte, 0, len(key))
for i = len(key); i > 64; i -= 64 {
Pseq = append(Pseq, Psum...)
}
Pseq = append(Pseq, Psum[0:i]...)
// Start computation of S byte sequence.
S := sha512.New()
for i = 0; i < (16 + int(Asum[0])); i++ {
S.Write(salt)
}
Ssum := S.Sum(nil)
// Create byte sequence S.
Sseq := make([]byte, 0, len(salt))
for i = len(salt); i > 64; i -= 64 {
Sseq = append(Sseq, Ssum...)
}
Sseq = append(Sseq, Ssum[0:i]...)
Csum := Asum
// Repeatedly run the collected hash value through SHA512 to burn CPU cycles.
for i = 0; i < rounds; i++ {
C := sha512.New()
// Add key or last result.
if (i & 1) != 0 {
C.Write(Pseq)
} else {
C.Write(Csum)
}
// Add salt for numbers not divisible by 3.
if (i % 3) != 0 {
C.Write(Sseq)
}
// Add key for numbers not divisible by 7.
if (i % 7) != 0 {
C.Write(Pseq)
}
// Add key or last result.
if (i & 1) != 0 {
C.Write(Csum)
} else {
C.Write(Pseq)
}
Csum = C.Sum(nil)
}
out := make([]byte, 0, 123)
out = append(out, MagicPrefix...)
if isRoundsDef {
out = append(out, []byte("rounds="+strconv.Itoa(rounds)+"$")...)
}
out = append(out, salt...)
out = append(out, '$')
out = append(out, Base64_24Bit([]byte{
Csum[42], Csum[21], Csum[0],
Csum[1], Csum[43], Csum[22],
Csum[23], Csum[2], Csum[44],
Csum[45], Csum[24], Csum[3],
Csum[4], Csum[46], Csum[25],
Csum[26], Csum[5], Csum[47],
Csum[48], Csum[27], Csum[6],
Csum[7], Csum[49], Csum[28],
Csum[29], Csum[8], Csum[50],
Csum[51], Csum[30], Csum[9],
Csum[10], Csum[52], Csum[31],
Csum[32], Csum[11], Csum[53],
Csum[54], Csum[33], Csum[12],
Csum[13], Csum[55], Csum[34],
Csum[35], Csum[14], Csum[56],
Csum[57], Csum[36], Csum[15],
Csum[16], Csum[58], Csum[37],
Csum[38], Csum[17], Csum[59],
Csum[60], Csum[39], Csum[18],
Csum[19], Csum[61], Csum[40],
Csum[41], Csum[20], Csum[62],
Csum[63],
})...)
// Clean sensitive data.
A.Reset()
Alternate.Reset()
P.Reset()
for i = 0; i < len(Asum); i++ {
Asum[i] = 0
}
for i = 0; i < len(AlternateSum); i++ {
AlternateSum[i] = 0
}
for i = 0; i < len(Pseq); i++ {
Pseq[i] = 0
}
return string(out), nil
}
func (c *crypter) SetSalt(salt Salt) { c.Salt = salt }
func GetSalt() Salt {
return Salt{
MagicPrefix: []byte(MagicPrefix),
SaltLenMin: SaltLenMin,
SaltLenMax: SaltLenMax,
RoundsDefault: RoundsDefault,
RoundsMin: RoundsMin,
RoundsMax: RoundsMax,
}
}