78 lines
2.0 KiB
Go
78 lines
2.0 KiB
Go
package apis
|
|
|
|
import (
|
|
"clortho/lib/db"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
func TestAdminMiddleware(t *testing.T) {
|
|
// Create an admin user
|
|
adminUser, err := InitUser("admin_test", "password")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
adminUser.Admin = true
|
|
db.Connection.Save(&adminUser)
|
|
|
|
// Create a non-admin user
|
|
regularUser, err := InitUser("regular_test", "password")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
regularUser.Admin = false
|
|
db.Connection.Save(®ularUser)
|
|
|
|
// Test with admin user
|
|
t.Run("Admin user should have access", func(t *testing.T) {
|
|
router := gin.New()
|
|
router.Use(MockAuthMiddleware(*adminUser))
|
|
router.Use(AdminMiddleware())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), "success")
|
|
})
|
|
|
|
// Test with non-admin user
|
|
t.Run("Non-admin user should be forbidden", func(t *testing.T) {
|
|
router := gin.New()
|
|
router.Use(MockAuthMiddleware(*regularUser))
|
|
router.Use(AdminMiddleware())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusForbidden, w.Code)
|
|
assert.Contains(t, w.Body.String(), "admin access required")
|
|
})
|
|
|
|
// Test with no session
|
|
t.Run("No session should be unauthorized", func(t *testing.T) {
|
|
router := gin.New()
|
|
router.Use(AdminMiddleware())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
|
assert.Contains(t, w.Body.String(), "unauthorized")
|
|
})
|
|
} |