From 66983fa59d0ce993a1fe8e0e8bdcdbfc940f6be0 Mon Sep 17 00:00:00 2001 From: Nathan Coad Date: Wed, 29 Mar 2023 14:44:50 +1100 Subject: [PATCH] added login and test admin page --- .gitignore | 3 +- controllers/auth.go | 54 +++++++++++++++++++++++++- go.mod | 2 + go.sum | 4 ++ main.go | 8 ++++ middlewares/middlewares.go | 21 ++++++++++ models/setup.go | 46 +++++++++++++++++++--- models/user.go | 74 ++++++++++++++++++++++++++++++++++- utils/token/token.go | 79 ++++++++++++++++++++++++++++++++++++++ 9 files changed, 282 insertions(+), 9 deletions(-) create mode 100644 middlewares/middlewares.go create mode 100644 utils/token/token.go diff --git a/.gitignore b/.gitignore index 97d0c99..586fa3f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ api\ tests.txt ccsecrets -ccsecrets.db \ No newline at end of file +ccsecrets.db +.env \ No newline at end of file diff --git a/controllers/auth.go b/controllers/auth.go index f551cc6..592cce3 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -7,6 +7,7 @@ import ( "strings" "ccsecrets/models" + "ccsecrets/utils/token" "github.com/gin-gonic/gin" "golang.org/x/crypto/bcrypt" @@ -17,6 +18,11 @@ type RegisterInput struct { Password string `json:"password" binding:"required"` } +type LoginInput struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` +} + func Register(c *gin.Context) { var input RegisterInput @@ -36,7 +42,7 @@ func Register(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"Error hashing password": err.Error()}) return } else { - fmt.Printf("Hashed password value is '%s'\n", string(hashedPassword)) + fmt.Printf("Register generated hashed password value '%s' from '%s'\n", string(hashedPassword), input.Password) } u.Password = string(hashedPassword) @@ -52,3 +58,49 @@ func Register(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "registration success"}) } + +func Login(c *gin.Context) { + + var input LoginInput + + if err := c.ShouldBindJSON(&input); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + u := models.User{} + + u.UserName = input.Username + u.Password = input.Password + + fmt.Printf("Login checking username '%s' and password '%s'\n", u.UserName, u.Password) + + token, err := models.LoginCheck(u.UserName, u.Password) + + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "username or password is incorrect."}) + return + } + + c.JSON(http.StatusOK, gin.H{"token": token}) + +} + +func CurrentUser(c *gin.Context) { + + user_id, err := token.ExtractTokenID(c) + + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + u, err := models.GetUserByID(user_id) + + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "success", "data": u}) +} diff --git a/go.mod b/go.mod index a24d015..5f8a8f5 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,10 @@ module ccsecrets go 1.19 require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/gin-gonic/gin v1.9.0 github.com/jmoiron/sqlx v1.3.5 + github.com/joho/godotenv v1.5.1 golang.org/x/crypto v0.7.0 modernc.org/sqlite v1.21.0 ) diff --git a/go.sum b/go.sum index ba93119..0d71680 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583j github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -33,6 +35,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= diff --git a/main.go b/main.go index 660b947..55944ef 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "ccsecrets/controllers" + "ccsecrets/middlewares" "ccsecrets/models" "github.com/gin-gonic/gin" @@ -9,13 +10,20 @@ import ( func main() { + // Initiate connection to sqlite and make sure our schema is up to date models.ConnectDatabase() r := gin.Default() public := r.Group("/api") + // Define our routes underneath /api public.POST("/register", controllers.Register) + public.POST("/login", controllers.Login) + + protected := r.Group("/api/admin") + protected.Use(middlewares.JwtAuthMiddleware()) + protected.GET("/user", controllers.CurrentUser) r.Run(":8080") diff --git a/middlewares/middlewares.go b/middlewares/middlewares.go new file mode 100644 index 0000000..d199e64 --- /dev/null +++ b/middlewares/middlewares.go @@ -0,0 +1,21 @@ +package middlewares + +import ( + "net/http" + + "ccsecrets/utils/token" + + "github.com/gin-gonic/gin" +) + +func JwtAuthMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + err := token.TokenValid(c) + if err != nil { + c.String(http.StatusUnauthorized, "Unauthorized") + c.Abort() + return + } + c.Next() + } +} diff --git a/models/setup.go b/models/setup.go index 79f0204..bdc5400 100644 --- a/models/setup.go +++ b/models/setup.go @@ -3,12 +3,14 @@ package models import ( "errors" "fmt" + "log" "os" "reflect" "ccsecrets/utils" "github.com/jmoiron/sqlx" + "github.com/joho/godotenv" _ "modernc.org/sqlite" ) @@ -59,6 +61,13 @@ const createSchema string = ` func ConnectDatabase() { var err error + // Load data from environment file + err = godotenv.Load(".env") + + if err != nil { + log.Fatalf("Error loading .env file") + } + // Try using sqlite as our database sqlPath := utils.GetFilePath(sqlFile) db, err = sqlx.Open("sqlite", sqlPath) @@ -84,15 +93,19 @@ func DisconnectDatabase() { func CreateTables() { var err error + var rowCount int // Create database tables if it doesn't exist // Roles table should go first since other tables refer to it if _, err = db.Exec(createRoles); err != nil { fmt.Printf("Error checking roles table : '%s'", err) os.Exit(1) } - if _, err = db.Exec("INSERT INTO roles VALUES(1, 'Admin', false, true);"); err != nil { - fmt.Printf("Error adding initial admin role : '%s'", err) - os.Exit(1) + rowCount, _ = CheckCount("roles") + if rowCount == 0 { + if _, err = db.Exec("INSERT INTO roles VALUES(1, 'Admin', false, true);"); err != nil { + fmt.Printf("Error adding initial admin role : '%s'", err) + os.Exit(1) + } } // Users table @@ -100,9 +113,12 @@ func CreateTables() { fmt.Printf("Error checking users table : '%s'", err) os.Exit(1) } - if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', 'password', 'token');"); err != nil { - fmt.Printf("Error adding initial admin role : '%s'", err) - os.Exit(1) + rowCount, _ = CheckCount("users") + if rowCount == 0 { + if _, err = db.Exec("INSERT INTO users VALUES(1, 1, 'Administrator', 'password', 'token');"); err != nil { + fmt.Printf("Error adding initial admin role : '%s'", err) + os.Exit(1) + } } // Secrets table if _, err = db.Exec(createSecrets); err != nil { @@ -123,6 +139,24 @@ func CreateTables() { } } +// Count the number of records in the sqlite database +// Borrowed from https://gist.github.com/trkrameshkumar/f4f1c00ef5d578561c96?permalink_comment_id=2687592#gistcomment-2687592 +func CheckCount(tablename string) (int, error) { + var count int + stmt, err := db.Prepare("SELECT COUNT(*) as count FROM " + tablename) + if err != nil { + fmt.Printf("CheckCount error preparing sqlite statement : '%s'\n", err) + return 0, err + } + err = stmt.QueryRow().Scan(&count) + if err != nil { + fmt.Printf("CheckCount error querying database record count : '%s'\n", err) + return 0, err + } + stmt.Close() // or use defer rows.Close(), idc + return count, nil +} + // From https://stackoverflow.com/a/60100045 func GenerateInsertMethod(q interface{}) (string, error) { if reflect.ValueOf(q).Kind() == reflect.Struct { diff --git a/models/user.go b/models/user.go index d953d02..fbdae6f 100644 --- a/models/user.go +++ b/models/user.go @@ -1,6 +1,12 @@ package models -import "fmt" +import ( + "ccsecrets/utils/token" + "errors" + "fmt" + + "golang.org/x/crypto/bcrypt" +) type User struct { UserId int `db:"UserId"` @@ -27,3 +33,69 @@ func (u *User) SaveUser() (*User, error) { return u, nil } + +func VerifyPassword(password, hashedPassword string) error { + fmt.Printf("VerifyPassword comparing password vs hashed:\n'%s'\n'%s'\n", password, hashedPassword) + return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) +} + +func LoginCheck(username string, password string) (string, error) { + + var err error + + u := User{} + + // Query database for matching user object + err = db.QueryRowx("SELECT * FROM Users WHERE Username=?", username).StructScan(&u) + + fmt.Printf("LoginCheck retrieved user '%v' from database\n", u) + + //err = DB.Model(User{}).Where("username = ?", username).Take(&u).Error + + if err != nil { + return "", err + } + + err = VerifyPassword(password, u.Password) + + if err != nil && err == bcrypt.ErrMismatchedHashAndPassword { + fmt.Printf("LoginCheck says password doesn't match stored hash.\n") + return "", err + } else { + fmt.Printf("LoginCheck verified password against stored hash.\n") + } + + token, err := token.GenerateToken(uint(u.UserId)) + + if err != nil { + fmt.Printf("LoginCheck error generating token : '%s'\n", err) + return "", err + } + + return token, nil + +} + +func GetUserByID(uid uint) (User, error) { + + var u User + + // Query database for matching user object + err := db.QueryRowx("SELECT * FROM Users WHERE UserId=?", uid).StructScan(&u) + if err != nil { + return u, errors.New("user not found") + } + /* + if err := DB.First(&u, uid).Error; err != nil { + return u, errors.New("User not found!") + } + */ + u.PrepareGive() + + return u, nil + +} + +func (u *User) PrepareGive() { + u.Password = "" +} diff --git a/utils/token/token.go b/utils/token/token.go new file mode 100644 index 0000000..0dbf05b --- /dev/null +++ b/utils/token/token.go @@ -0,0 +1,79 @@ +package token + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" + + jwt "github.com/dgrijalva/jwt-go" + "github.com/gin-gonic/gin" +) + +func GenerateToken(user_id uint) (string, error) { + + token_lifespan, err := strconv.Atoi(os.Getenv("TOKEN_HOUR_LIFESPAN")) + if err != nil { + fmt.Printf("GenerateToken Error getting env value TOKEN_HOUR_LIFESPAN\n") + return "", err + } + + claims := jwt.MapClaims{} + claims["authorized"] = true + claims["user_id"] = user_id + claims["exp"] = time.Now().Add(time.Hour * time.Duration(token_lifespan)).Unix() + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + return token.SignedString([]byte(os.Getenv("API_SECRET"))) + +} + +func TokenValid(c *gin.Context) error { + tokenString := ExtractToken(c) + _, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(os.Getenv("API_SECRET")), nil + }) + if err != nil { + return err + } + return nil +} + +func ExtractToken(c *gin.Context) string { + token := c.Query("token") + if token != "" { + return token + } + bearerToken := c.Request.Header.Get("Authorization") + if len(strings.Split(bearerToken, " ")) == 2 { + return strings.Split(bearerToken, " ")[1] + } + return "" +} + +func ExtractTokenID(c *gin.Context) (uint, error) { + + tokenString := ExtractToken(c) + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(os.Getenv("API_SECRET")), nil + }) + if err != nil { + return 0, err + } + claims, ok := token.Claims.(jwt.MapClaims) + if ok && token.Valid { + uid, err := strconv.ParseUint(fmt.Sprintf("%.0f", claims["user_id"]), 10, 32) + if err != nil { + return 0, err + } + return uint(uid), nil + } + return 0, nil +}