diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..97d0c99 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +api\ tests.txt +ccsecrets +ccsecrets.db \ No newline at end of file diff --git a/controllers/auth.go b/controllers/auth.go index 22e7497..e842119 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -1,11 +1,51 @@ package controllers import ( + "html" "net/http" + "strings" + + "ccsecrets/models" "github.com/gin-gonic/gin" + "golang.org/x/crypto/bcrypt" ) -func Register(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"data": "hello from controller!"}) +type RegisterInput struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` +} + +func Register(c *gin.Context) { + var input RegisterInput + + if err := c.ShouldBindJSON(&input); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + u := models.User{} + u.RoleId = 1 + u.UserName = input.Username + u.Password = input.Password + + //turn password into hash + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"Error hashing password": err.Error()}) + return + } + u.Password = string(hashedPassword) + + //remove spaces in username + u.UserName = html.EscapeString(strings.TrimSpace(u.UserName)) + + _, err = u.SaveUser() + + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"Error saving user": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "registration success"}) } diff --git a/models/setup.go b/models/setup.go index ed94f67..5b8320a 100644 --- a/models/setup.go +++ b/models/setup.go @@ -1,8 +1,10 @@ package models import ( + "errors" "fmt" "os" + "reflect" "ccsecrets/utils" @@ -30,7 +32,7 @@ const createUsers string = ` UserId INTEGER PRIMARY KEY ASC, RoleId INTEGER, UserName VARCHAR, - UserPass VARCHAR, + Password VARCHAR, AccessToken varchar, FOREIGN KEY (RoleId) REFERENCES roles(RoleId) ); @@ -68,10 +70,12 @@ func ConnectDatabase() { fmt.Printf("Connected to sqlite database file '%s'\n", sqlPath) } + //sqlx.NameMapper = func(s string) string { return s } + // Make sure our tables exist CreateTables() - defer db.Close() + //defer db.Close() } func CreateTables() { @@ -82,11 +86,20 @@ func CreateTables() { fmt.Printf("Error checking roles table : '%s'", err) os.Exit(1) } + if _, err = db.Exec("INSERT INTO roles VALUES(1, 'Admin', false, true);"); err != nil { + fmt.Printf("Error adding initial admin role : '%s'", err) + os.Exit(1) + } + // Users table if _, err = db.Exec(createUsers); err != nil { fmt.Printf("Error checking users table : '%s'", err) os.Exit(1) } + if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', 'password', 'token');"); err != nil { + fmt.Printf("Error adding initial admin role : '%s'", err) + os.Exit(1) + } // Secrets table if _, err = db.Exec(createSecrets); err != nil { fmt.Printf("Error checking secrets table : '%s'", err) @@ -106,6 +119,52 @@ func CreateTables() { } } +// From https://stackoverflow.com/a/60100045 +func GenerateInsertMethod(q interface{}) (string, error) { + if reflect.ValueOf(q).Kind() == reflect.Struct { + query := fmt.Sprintf("INSERT INTO %s", reflect.TypeOf(q).Name()) + fieldNames := "" + fieldValues := "" + v := reflect.ValueOf(q) + for i := 0; i < v.NumField(); i++ { + if i == 0 { + fieldNames = fmt.Sprintf("%s%s", fieldNames, v.Type().Field(i).Name) + } else { + fieldNames = fmt.Sprintf("%s, %s", fieldNames, v.Type().Field(i).Name) + } + switch v.Field(i).Kind() { + case reflect.Int: + if i == 0 { + fieldValues = fmt.Sprintf("%s%d", fieldValues, v.Field(i).Int()) + } else { + fieldValues = fmt.Sprintf("%s, %d", fieldValues, v.Field(i).Int()) + } + case reflect.String: + if i == 0 { + fieldValues = fmt.Sprintf("%s\"%s\"", fieldValues, v.Field(i).String()) + } else { + fieldValues = fmt.Sprintf("%s, \"%s\"", fieldValues, v.Field(i).String()) + } + case reflect.Bool: + var boolSet int8 + if v.Field(i).Bool() { + boolSet = 1 + } + if i == 0 { + fieldValues = fmt.Sprintf("%s%d", fieldValues, boolSet) + } else { + fieldValues = fmt.Sprintf("%s, %d", fieldValues, boolSet) + } + default: + fmt.Printf("Unsupported type '%s'\n", v.Field(i).Kind()) + } + } + query = fmt.Sprintf("%s(%s) VALUES (%s)", query, fieldNames, fieldValues) + return query, nil + } + return "", errors.New("SqlGenerationError") +} + func CheckColumnExists(table string, column string) (bool, error) { var count int64 rows, err := db.Queryx("SELECT COUNT(*) AS CNTREC FROM pragma_table_info('" + table + "') WHERE name='" + column + "';") diff --git a/models/user.go b/models/user.go index 7c1435c..9885231 100644 --- a/models/user.go +++ b/models/user.go @@ -1,9 +1,47 @@ package models +import "fmt" + type User struct { - UserId int `db:UserId` - RoleId int `db:RoleId` - UserName string `db:UserName` - UserPass string `db:UserPass` - AccessToken string `db:AccessToken` + UserId int `db:"UserId"` + RoleId int `db:"RoleId"` + UserName string `db:"UserName"` + Password string `db:"Password"` + AccessToken string `db:"AccessToken"` +} + +func (u *User) SaveUser() (*User, error) { + + var err error + + /* + sql, err := GenerateInsertMethod(&u) + if err != nil { + fmt.Printf("SaveUser error generating sql record : '%s'\n", err) + return &User{}, err + } else { + fmt.Println(sql) + } + result, err := db.Exec(sql) + */ + + fmt.Printf("SaveUser received object '%v'\n", u.RoleId) + + result, err := db.NamedExec((`INSERT INTO users (RoleId, UserName, Password, AccessToken) VALUES (:RoleId, :UserName, :Password, :AccessToken)`), u) + + if err != nil { + fmt.Printf("SaveUser error executing sql record : '%s'\n", err) + return &User{}, err + } else { + affected, _ := result.RowsAffected() + id, _ := result.LastInsertId() + fmt.Printf("SaveUser insert returned result id '%d' affecting %d row(s).\n", id, affected) + } + /* + err = CreateUser(&u).Error + if err != nil { + return &User{}, err + } + */ + return u, nil } diff --git a/utils/crypt.go b/utils/crypt.go new file mode 100644 index 0000000..037d1cd --- /dev/null +++ b/utils/crypt.go @@ -0,0 +1,349 @@ +package utils + +import ( + "bytes" + "crypto/rand" + "crypto/sha512" + "errors" + "strconv" +) + +// code in this file taken from https://github.com/tredoe/osutil/blob/master/v2/userutil/crypt/sha512_crypt/sha512_crypt.go + +var ( + ErrSaltPrefix = errors.New("invalid magic prefix") + ErrSaltFormat = errors.New("invalid salt format") + ErrSaltRounds = errors.New("invalid rounds") +) + +// Salt represents a salt. +type Salt struct { + MagicPrefix []byte + + SaltLenMin int + SaltLenMax int + + RoundsMin int + RoundsMax int + RoundsDefault int +} + +type crypter struct{ Salt Salt } + +var _rounds = []byte("rounds=") + +const ( + MagicPrefix = "$6$" + SaltLenMin = 1 + SaltLenMax = 16 + RoundsMin = 1000 + RoundsMax = 999999999 + RoundsDefault = 5000 + alphabet = "./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +) + +// Generate generates a random salt of a given length. +// +// The length is set thus: +// +// length > SaltLenMax: length = SaltLenMax +// length < SaltLenMin: length = SaltLenMin +func (s *Salt) Generate(length int) []byte { + if length > s.SaltLenMax { + length = s.SaltLenMax + } else if length < s.SaltLenMin { + length = s.SaltLenMin + } + + saltLen := (length * 6 / 8) + if (length*6)%8 != 0 { + saltLen++ + } + salt := make([]byte, saltLen) + rand.Read(salt) + + out := make([]byte, len(s.MagicPrefix)+length) + copy(out, s.MagicPrefix) + copy(out[len(s.MagicPrefix):], Base64_24Bit(salt)) + return out +} + +// GenerateWRounds creates a random salt with the random bytes being of the +// length provided, and the rounds parameter set as specified. +// +// The parameters are set thus: +// +// length > SaltLenMax: length = SaltLenMax +// length < SaltLenMin: length = SaltLenMin +// +// rounds < 0: rounds = RoundsDefault +// rounds < RoundsMin: rounds = RoundsMin +// rounds > RoundsMax: rounds = RoundsMax +// +// If rounds is equal to RoundsDefault, then the "rounds=" part of the salt is +// removed. +func (s *Salt) GenerateWRounds(length, rounds int) []byte { + if length > s.SaltLenMax { + length = s.SaltLenMax + } else if length < s.SaltLenMin { + length = s.SaltLenMin + } + if rounds < 0 { + rounds = s.RoundsDefault + } else if rounds < s.RoundsMin { + rounds = s.RoundsMin + } else if rounds > s.RoundsMax { + rounds = s.RoundsMax + } + + //fmt.Printf("GenerateWRounds length is %d and rounds is %d\n", length, rounds) + + saltLen := (length * 6 / 8) + if (length*6)%8 != 0 { + saltLen++ + } + salt := make([]byte, saltLen) + rand.Read(salt) + + roundsText := "" + if rounds != s.RoundsDefault { + roundsText = "rounds=" + strconv.Itoa(rounds) + "$" + } + + out := make([]byte, len(s.MagicPrefix)+len(roundsText)+length) + copy(out, s.MagicPrefix) + //fmt.Printf("GenerateWRounds copy 1 : '%v'\n", out) + copy(out[len(s.MagicPrefix):], []byte(roundsText)) + //fmt.Printf("GenerateWRounds copy 2 : '%v'\n", out) + copy(out[len(s.MagicPrefix)+len(roundsText):], Base64_24Bit(salt)) + //fmt.Printf("GenerateWRounds copy 3 : '%v'\n", out) + return out +} + +func Base64_24Bit(src []byte) (hash []byte) { + if len(src) == 0 { + return []byte{} // TODO: return nil + } + + hashSize := (len(src) * 8) / 6 + if (len(src) % 6) != 0 { + hashSize += 1 + } + hash = make([]byte, hashSize) + + dst := hash + for len(src) > 0 { + switch len(src) { + default: + dst[0] = alphabet[src[0]&0x3f] + dst[1] = alphabet[((src[0]>>6)|(src[1]<<2))&0x3f] + dst[2] = alphabet[((src[1]>>4)|(src[2]<<4))&0x3f] + dst[3] = alphabet[(src[2]>>2)&0x3f] + src = src[3:] + dst = dst[4:] + case 2: + dst[0] = alphabet[src[0]&0x3f] + dst[1] = alphabet[((src[0]>>6)|(src[1]<<2))&0x3f] + dst[2] = alphabet[(src[1]>>4)&0x3f] + src = src[2:] + dst = dst[3:] + case 1: + dst[0] = alphabet[src[0]&0x3f] + dst[1] = alphabet[(src[0]>>6)&0x3f] + src = src[1:] + dst = dst[2:] + } + } + + return +} + +func Generate(key, salt []byte) (string, error) { + var rounds int + var isRoundsDef bool + + var c crypter + c.Salt = GetSalt() + + if len(salt) == 0 { + salt = c.Salt.GenerateWRounds(SaltLenMax, RoundsDefault) + //fmt.Printf("Generate created salt with value '%v'\n", salt) + } + if !bytes.HasPrefix(salt, c.Salt.MagicPrefix) { + //fmt.Printf("Generate salt '%v' has no magic prefix\n", salt) + return "", ErrSaltPrefix + } + + saltToks := bytes.Split(salt, []byte{'$'}) + if len(saltToks) < 3 { + return "", ErrSaltFormat + } + + if bytes.HasPrefix(saltToks[2], _rounds) { + isRoundsDef = true + pr, err := strconv.ParseInt(string(saltToks[2][7:]), 10, 32) + if err != nil { + return "", ErrSaltRounds + } + rounds = int(pr) + if rounds < RoundsMin { + rounds = RoundsMin + } else if rounds > RoundsMax { + rounds = RoundsMax + } + salt = saltToks[3] + } else { + rounds = RoundsDefault + salt = saltToks[2] + } + + if len(salt) > SaltLenMax { + salt = salt[0:SaltLenMax] + } + + // Compute alternate SHA512 sum with input KEY, SALT, and KEY. + Alternate := sha512.New() + Alternate.Write(key) + Alternate.Write(salt) + Alternate.Write(key) + AlternateSum := Alternate.Sum(nil) // 64 bytes + + A := sha512.New() + A.Write(key) + A.Write(salt) + // Add for any character in the key one byte of the alternate sum. + i := len(key) + for ; i > 64; i -= 64 { + A.Write(AlternateSum) + } + A.Write(AlternateSum[0:i]) + + // Take the binary representation of the length of the key and for every add + // the alternate sum, for every 0 the key. + for i = len(key); i > 0; i >>= 1 { + if (i & 1) != 0 { + A.Write(AlternateSum) + } else { + A.Write(key) + } + } + Asum := A.Sum(nil) + + // Start computation of P byte sequence. + P := sha512.New() + // For every character in the password add the entire password. + for i = 0; i < len(key); i++ { + P.Write(key) + } + Psum := P.Sum(nil) + // Create byte sequence P. + Pseq := make([]byte, 0, len(key)) + for i = len(key); i > 64; i -= 64 { + Pseq = append(Pseq, Psum...) + } + Pseq = append(Pseq, Psum[0:i]...) + + // Start computation of S byte sequence. + S := sha512.New() + for i = 0; i < (16 + int(Asum[0])); i++ { + S.Write(salt) + } + Ssum := S.Sum(nil) + // Create byte sequence S. + Sseq := make([]byte, 0, len(salt)) + for i = len(salt); i > 64; i -= 64 { + Sseq = append(Sseq, Ssum...) + } + Sseq = append(Sseq, Ssum[0:i]...) + + Csum := Asum + + // Repeatedly run the collected hash value through SHA512 to burn CPU cycles. + for i = 0; i < rounds; i++ { + C := sha512.New() + + // Add key or last result. + if (i & 1) != 0 { + C.Write(Pseq) + } else { + C.Write(Csum) + } + // Add salt for numbers not divisible by 3. + if (i % 3) != 0 { + C.Write(Sseq) + } + // Add key for numbers not divisible by 7. + if (i % 7) != 0 { + C.Write(Pseq) + } + // Add key or last result. + if (i & 1) != 0 { + C.Write(Csum) + } else { + C.Write(Pseq) + } + + Csum = C.Sum(nil) + } + + out := make([]byte, 0, 123) + out = append(out, MagicPrefix...) + if isRoundsDef { + out = append(out, []byte("rounds="+strconv.Itoa(rounds)+"$")...) + } + out = append(out, salt...) + out = append(out, '$') + out = append(out, Base64_24Bit([]byte{ + Csum[42], Csum[21], Csum[0], + Csum[1], Csum[43], Csum[22], + Csum[23], Csum[2], Csum[44], + Csum[45], Csum[24], Csum[3], + Csum[4], Csum[46], Csum[25], + Csum[26], Csum[5], Csum[47], + Csum[48], Csum[27], Csum[6], + Csum[7], Csum[49], Csum[28], + Csum[29], Csum[8], Csum[50], + Csum[51], Csum[30], Csum[9], + Csum[10], Csum[52], Csum[31], + Csum[32], Csum[11], Csum[53], + Csum[54], Csum[33], Csum[12], + Csum[13], Csum[55], Csum[34], + Csum[35], Csum[14], Csum[56], + Csum[57], Csum[36], Csum[15], + Csum[16], Csum[58], Csum[37], + Csum[38], Csum[17], Csum[59], + Csum[60], Csum[39], Csum[18], + Csum[19], Csum[61], Csum[40], + Csum[41], Csum[20], Csum[62], + Csum[63], + })...) + + // Clean sensitive data. + A.Reset() + Alternate.Reset() + P.Reset() + for i = 0; i < len(Asum); i++ { + Asum[i] = 0 + } + for i = 0; i < len(AlternateSum); i++ { + AlternateSum[i] = 0 + } + for i = 0; i < len(Pseq); i++ { + Pseq[i] = 0 + } + + return string(out), nil +} + +func (c *crypter) SetSalt(salt Salt) { c.Salt = salt } + +func GetSalt() Salt { + return Salt{ + MagicPrefix: []byte(MagicPrefix), + SaltLenMin: SaltLenMin, + SaltLenMax: SaltLenMax, + RoundsDefault: RoundsDefault, + RoundsMin: RoundsMin, + RoundsMax: RoundsMax, + } +}