diff --git a/apis/apis_utils_test.go b/apis/apis_utils_test.go index 9cf1987..a6bcf50 100644 --- a/apis/apis_utils_test.go +++ b/apis/apis_utils_test.go @@ -12,3 +12,17 @@ func MockAuthMiddleware(user db.User) gin.HandlerFunc { c.Set("session", session) } } + +func InitUser(username string, password string) (*db.User, error) { + user := users.GetUser(username) + if user == nil { + user, _ = users.CreateUser(username) + } + passHash, _ := users.HashPassword(password) + user.PasswordHash = &passHash + res := db.Connection.Save(&user) + if res.Error != nil { + return nil, res.Error + } + return user, nil +} diff --git a/apis/auth_endpoints.go b/apis/auth_endpoints.go index 037f69c..56e8e33 100644 --- a/apis/auth_endpoints.go +++ b/apis/auth_endpoints.go @@ -55,6 +55,6 @@ func getMe(c *gin.Context) { } c.JSON(200, gin.H{ "loggedIn": true, - "user": session.(db.UserSession).User, + "user": session.(*db.UserSession).User, }) } diff --git a/apis/auth_endpoints_test.go b/apis/auth_endpoints_test.go index d5309ca..bacdcfd 100644 --- a/apis/auth_endpoints_test.go +++ b/apis/auth_endpoints_test.go @@ -15,6 +15,8 @@ import ( ) func TestMain(m *testing.M) { + // Set up test database file + os.Setenv("CLORTHO_DB_FILE", "test_clortho.db") // Global setup fmt.Println("Setting up resources...") @@ -50,21 +52,18 @@ func TestInitAuthEndpoints_authLogin(t *testing.T) { } func TestInitAuthEndpoints_getMe(t *testing.T) { - _, err := users.InitAdminUser() + user, err := InitUser("admin", "password") if err != nil { t.Fatal(err) } - - admin := users.GetUser("admin") r := gin.Default() - SetupRouter(r, MockAuthMiddleware(*admin)) + SetupRouter(r, MockAuthMiddleware(*user)) w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/private/auth/me", nil) r.ServeHTTP(w, req) - user := admin strUser, _ := json.Marshal(user) assert.Equal(t, 200, w.Code) - assert.JSONEq(t, fmt.Sprintf(`{"valid": true, "user": %s}`, strUser), w.Body.String()) + assert.JSONEq(t, fmt.Sprintf(`{"loggedIn": true, "user": %s}`, strUser), w.Body.String()) } diff --git a/cmd/server/main.go b/cmd/server/main.go index 9735432..1b1d864 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -14,7 +14,7 @@ func main() { log.Fatal("Could not initialize connection to the DB", err) } - adminPass, err := users.InitAdminUser() + adminPass, err := users.InitAdminUser(nil) if err != nil { log.Println(err) } else { diff --git a/users/users.go b/users/users.go index e103962..91a5944 100644 --- a/users/users.go +++ b/users/users.go @@ -46,6 +46,15 @@ func GetUser(username string) *db.User { } } +func CreateUser(username string) (*db.User, error) { + user := db.User{Username: username} + result := db.Connection.Create(&user) + if result.Error != nil { + return nil, result.Error + } + return &user, nil +} + func GetUsers() []db.User { var users []db.User db.Connection.Find(&users) @@ -136,7 +145,7 @@ func GenerateJwt(sessionId uint) (string, error) { } func NewSession(user db.User) *db.UserSession { - session := db.UserSession{UserID: user.ID} + session := db.UserSession{User: user} db.Connection.Create(&session) return &session }