diff --git a/controllers/retrieve_secrets.go b/controllers/retrieve_secrets.go index e87842b..b0b639f 100644 --- a/controllers/retrieve_secrets.go +++ b/controllers/retrieve_secrets.go @@ -11,7 +11,7 @@ type RetrieveInput struct { DeviceCategory string `json:"deviceCategory"` } -func Retrieve(c *gin.Context) { +func RetrieveSecret(c *gin.Context) { var input RetrieveInput if err := c.ShouldBindJSON(&input); err != nil { diff --git a/controllers/store_secrets.go b/controllers/store_secrets.go index 5985a20..05cde36 100644 --- a/controllers/store_secrets.go +++ b/controllers/store_secrets.go @@ -43,6 +43,7 @@ func StoreSecret(c *gin.Context) { // Encrypt secret s.Secret = input.SecretValue + s.EncryptSecret() _, err = s.SaveSecret() diff --git a/main.go b/main.go index 6a6556d..c95b232 100644 --- a/main.go +++ b/main.go @@ -46,7 +46,7 @@ func main() { // Get secrets protected := router.Group("/api/secret") protected.Use(middlewares.JwtAuthMiddleware()) - protected.GET("/retrieve", controllers.Retrieve) + protected.GET("/retrieve", controllers.RetrieveSecret) protected.POST("/store", controllers.StoreSecret) // Initializing the server in a goroutine so that diff --git a/models/secret.go b/models/secret.go index 9ab7408..e5b7a52 100644 --- a/models/secret.go +++ b/models/secret.go @@ -1,6 +1,13 @@ package models -import "fmt" +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" + "os" +) type Secret struct { SecretId int `db:"SecretId"` @@ -29,3 +36,38 @@ func (s *Secret) SaveSecret() (*Secret, error) { return s, nil } + +func (s *Secret) EncryptSecret() (*Secret, error) { + + keyString := os.Getenv("SECRETS_KEY") + // The key argument should be the AES key, either 16 or 32 bytes + // to select AES-128 or AES-256. + key := []byte(keyString) + plaintext := []byte(s.Secret) + + fmt.Printf("EncryptSecret applying key '%v' to plaintext secret '%s'\n", keyString, s.Secret) + + block, err := aes.NewCipher(key) + if err != nil { + panic(err.Error()) + } + + // Never use more than 2^32 random nonces with a given key because of the risk of a repeat. + nonce := make([]byte, 12) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + panic(err.Error()) + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + panic(err.Error()) + } + + ciphertext := aesgcm.Seal(nil, nonce, plaintext, nil) + fmt.Printf("EncryptSecret generated ciphertext '%x'\n", ciphertext) + + s.Secret = string(ciphertext) + return s, nil + + //return string(ciphertext[:]), nil +} diff --git a/utils/crypt.go b/utils/crypt.go deleted file mode 100644 index 037d1cd..0000000 --- a/utils/crypt.go +++ /dev/null @@ -1,349 +0,0 @@ -package utils - -import ( - "bytes" - "crypto/rand" - "crypto/sha512" - "errors" - "strconv" -) - -// code in this file taken from https://github.com/tredoe/osutil/blob/master/v2/userutil/crypt/sha512_crypt/sha512_crypt.go - -var ( - ErrSaltPrefix = errors.New("invalid magic prefix") - ErrSaltFormat = errors.New("invalid salt format") - ErrSaltRounds = errors.New("invalid rounds") -) - -// Salt represents a salt. -type Salt struct { - MagicPrefix []byte - - SaltLenMin int - SaltLenMax int - - RoundsMin int - RoundsMax int - RoundsDefault int -} - -type crypter struct{ Salt Salt } - -var _rounds = []byte("rounds=") - -const ( - MagicPrefix = "$6$" - SaltLenMin = 1 - SaltLenMax = 16 - RoundsMin = 1000 - RoundsMax = 999999999 - RoundsDefault = 5000 - alphabet = "./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" -) - -// Generate generates a random salt of a given length. -// -// The length is set thus: -// -// length > SaltLenMax: length = SaltLenMax -// length < SaltLenMin: length = SaltLenMin -func (s *Salt) Generate(length int) []byte { - if length > s.SaltLenMax { - length = s.SaltLenMax - } else if length < s.SaltLenMin { - length = s.SaltLenMin - } - - saltLen := (length * 6 / 8) - if (length*6)%8 != 0 { - saltLen++ - } - salt := make([]byte, saltLen) - rand.Read(salt) - - out := make([]byte, len(s.MagicPrefix)+length) - copy(out, s.MagicPrefix) - copy(out[len(s.MagicPrefix):], Base64_24Bit(salt)) - return out -} - -// GenerateWRounds creates a random salt with the random bytes being of the -// length provided, and the rounds parameter set as specified. -// -// The parameters are set thus: -// -// length > SaltLenMax: length = SaltLenMax -// length < SaltLenMin: length = SaltLenMin -// -// rounds < 0: rounds = RoundsDefault -// rounds < RoundsMin: rounds = RoundsMin -// rounds > RoundsMax: rounds = RoundsMax -// -// If rounds is equal to RoundsDefault, then the "rounds=" part of the salt is -// removed. -func (s *Salt) GenerateWRounds(length, rounds int) []byte { - if length > s.SaltLenMax { - length = s.SaltLenMax - } else if length < s.SaltLenMin { - length = s.SaltLenMin - } - if rounds < 0 { - rounds = s.RoundsDefault - } else if rounds < s.RoundsMin { - rounds = s.RoundsMin - } else if rounds > s.RoundsMax { - rounds = s.RoundsMax - } - - //fmt.Printf("GenerateWRounds length is %d and rounds is %d\n", length, rounds) - - saltLen := (length * 6 / 8) - if (length*6)%8 != 0 { - saltLen++ - } - salt := make([]byte, saltLen) - rand.Read(salt) - - roundsText := "" - if rounds != s.RoundsDefault { - roundsText = "rounds=" + strconv.Itoa(rounds) + "$" - } - - out := make([]byte, len(s.MagicPrefix)+len(roundsText)+length) - copy(out, s.MagicPrefix) - //fmt.Printf("GenerateWRounds copy 1 : '%v'\n", out) - copy(out[len(s.MagicPrefix):], []byte(roundsText)) - //fmt.Printf("GenerateWRounds copy 2 : '%v'\n", out) - copy(out[len(s.MagicPrefix)+len(roundsText):], Base64_24Bit(salt)) - //fmt.Printf("GenerateWRounds copy 3 : '%v'\n", out) - return out -} - -func Base64_24Bit(src []byte) (hash []byte) { - if len(src) == 0 { - return []byte{} // TODO: return nil - } - - hashSize := (len(src) * 8) / 6 - if (len(src) % 6) != 0 { - hashSize += 1 - } - hash = make([]byte, hashSize) - - dst := hash - for len(src) > 0 { - switch len(src) { - default: - dst[0] = alphabet[src[0]&0x3f] - dst[1] = alphabet[((src[0]>>6)|(src[1]<<2))&0x3f] - dst[2] = alphabet[((src[1]>>4)|(src[2]<<4))&0x3f] - dst[3] = alphabet[(src[2]>>2)&0x3f] - src = src[3:] - dst = dst[4:] - case 2: - dst[0] = alphabet[src[0]&0x3f] - dst[1] = alphabet[((src[0]>>6)|(src[1]<<2))&0x3f] - dst[2] = alphabet[(src[1]>>4)&0x3f] - src = src[2:] - dst = dst[3:] - case 1: - dst[0] = alphabet[src[0]&0x3f] - dst[1] = alphabet[(src[0]>>6)&0x3f] - src = src[1:] - dst = dst[2:] - } - } - - return -} - -func Generate(key, salt []byte) (string, error) { - var rounds int - var isRoundsDef bool - - var c crypter - c.Salt = GetSalt() - - if len(salt) == 0 { - salt = c.Salt.GenerateWRounds(SaltLenMax, RoundsDefault) - //fmt.Printf("Generate created salt with value '%v'\n", salt) - } - if !bytes.HasPrefix(salt, c.Salt.MagicPrefix) { - //fmt.Printf("Generate salt '%v' has no magic prefix\n", salt) - return "", ErrSaltPrefix - } - - saltToks := bytes.Split(salt, []byte{'$'}) - if len(saltToks) < 3 { - return "", ErrSaltFormat - } - - if bytes.HasPrefix(saltToks[2], _rounds) { - isRoundsDef = true - pr, err := strconv.ParseInt(string(saltToks[2][7:]), 10, 32) - if err != nil { - return "", ErrSaltRounds - } - rounds = int(pr) - if rounds < RoundsMin { - rounds = RoundsMin - } else if rounds > RoundsMax { - rounds = RoundsMax - } - salt = saltToks[3] - } else { - rounds = RoundsDefault - salt = saltToks[2] - } - - if len(salt) > SaltLenMax { - salt = salt[0:SaltLenMax] - } - - // Compute alternate SHA512 sum with input KEY, SALT, and KEY. - Alternate := sha512.New() - Alternate.Write(key) - Alternate.Write(salt) - Alternate.Write(key) - AlternateSum := Alternate.Sum(nil) // 64 bytes - - A := sha512.New() - A.Write(key) - A.Write(salt) - // Add for any character in the key one byte of the alternate sum. - i := len(key) - for ; i > 64; i -= 64 { - A.Write(AlternateSum) - } - A.Write(AlternateSum[0:i]) - - // Take the binary representation of the length of the key and for every add - // the alternate sum, for every 0 the key. - for i = len(key); i > 0; i >>= 1 { - if (i & 1) != 0 { - A.Write(AlternateSum) - } else { - A.Write(key) - } - } - Asum := A.Sum(nil) - - // Start computation of P byte sequence. - P := sha512.New() - // For every character in the password add the entire password. - for i = 0; i < len(key); i++ { - P.Write(key) - } - Psum := P.Sum(nil) - // Create byte sequence P. - Pseq := make([]byte, 0, len(key)) - for i = len(key); i > 64; i -= 64 { - Pseq = append(Pseq, Psum...) - } - Pseq = append(Pseq, Psum[0:i]...) - - // Start computation of S byte sequence. - S := sha512.New() - for i = 0; i < (16 + int(Asum[0])); i++ { - S.Write(salt) - } - Ssum := S.Sum(nil) - // Create byte sequence S. - Sseq := make([]byte, 0, len(salt)) - for i = len(salt); i > 64; i -= 64 { - Sseq = append(Sseq, Ssum...) - } - Sseq = append(Sseq, Ssum[0:i]...) - - Csum := Asum - - // Repeatedly run the collected hash value through SHA512 to burn CPU cycles. - for i = 0; i < rounds; i++ { - C := sha512.New() - - // Add key or last result. - if (i & 1) != 0 { - C.Write(Pseq) - } else { - C.Write(Csum) - } - // Add salt for numbers not divisible by 3. - if (i % 3) != 0 { - C.Write(Sseq) - } - // Add key for numbers not divisible by 7. - if (i % 7) != 0 { - C.Write(Pseq) - } - // Add key or last result. - if (i & 1) != 0 { - C.Write(Csum) - } else { - C.Write(Pseq) - } - - Csum = C.Sum(nil) - } - - out := make([]byte, 0, 123) - out = append(out, MagicPrefix...) - if isRoundsDef { - out = append(out, []byte("rounds="+strconv.Itoa(rounds)+"$")...) - } - out = append(out, salt...) - out = append(out, '$') - out = append(out, Base64_24Bit([]byte{ - Csum[42], Csum[21], Csum[0], - Csum[1], Csum[43], Csum[22], - Csum[23], Csum[2], Csum[44], - Csum[45], Csum[24], Csum[3], - Csum[4], Csum[46], Csum[25], - Csum[26], Csum[5], Csum[47], - Csum[48], Csum[27], Csum[6], - Csum[7], Csum[49], Csum[28], - Csum[29], Csum[8], Csum[50], - Csum[51], Csum[30], Csum[9], - Csum[10], Csum[52], Csum[31], - Csum[32], Csum[11], Csum[53], - Csum[54], Csum[33], Csum[12], - Csum[13], Csum[55], Csum[34], - Csum[35], Csum[14], Csum[56], - Csum[57], Csum[36], Csum[15], - Csum[16], Csum[58], Csum[37], - Csum[38], Csum[17], Csum[59], - Csum[60], Csum[39], Csum[18], - Csum[19], Csum[61], Csum[40], - Csum[41], Csum[20], Csum[62], - Csum[63], - })...) - - // Clean sensitive data. - A.Reset() - Alternate.Reset() - P.Reset() - for i = 0; i < len(Asum); i++ { - Asum[i] = 0 - } - for i = 0; i < len(AlternateSum); i++ { - AlternateSum[i] = 0 - } - for i = 0; i < len(Pseq); i++ { - Pseq[i] = 0 - } - - return string(out), nil -} - -func (c *crypter) SetSalt(salt Salt) { c.Salt = salt } - -func GetSalt() Salt { - return Salt{ - MagicPrefix: []byte(MagicPrefix), - SaltLenMin: SaltLenMin, - SaltLenMax: SaltLenMax, - RoundsDefault: RoundsDefault, - RoundsMin: RoundsMin, - RoundsMax: RoundsMax, - } -}