Files
bamort/backend/registry/registry_test.go

360 lines
9.2 KiB
Go

package registry
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// resetRegistry sets every package-level registry variable (the route,
// public-route, base-route, migration and initializer callback lists, the auth
// provider, and the registered model instances) back to nil so that each test
// starts from an empty registry. It lives in a _test.go file, so only tests
// inside the registry package can call it.
func resetRegistry() {
	// Callback lists.
	routeFuncs, publicRouteFuncs, baseRouteFuncs = nil, nil, nil
	migrateFuncs, initializerFuncs = nil, nil

	// Single-value provider and model state.
	authProvider = nil
	modelInstances = nil
}
// newTestEngine switches gin into test mode and returns a bare engine created
// with gin.New (no middleware pre-installed).
func newTestEngine() *gin.Engine {
	gin.SetMode(gin.TestMode)
	engine := gin.New()
	return engine
}
// newTestDB opens a fresh in-memory SQLite database through gorm for a single
// test. The database is closed automatically when the test finishes.
//
// Every connection SQLite opens to ":memory:" gets its own private database.
// The pool is therefore pinned to one connection; otherwise a table created on
// one pooled connection (for example by AutoMigrate) could be missing when a
// later query runs on a different connection.
func newTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	require.NoError(t, err)

	sqlDB, err := db.DB()
	require.NoError(t, err)
	sqlDB.SetMaxOpenConns(1)
	t.Cleanup(func() {
		// Best-effort close: a failure here happens after the test body has
		// finished and has no bearing on its result.
		_ = sqlDB.Close()
	})
	return db
}
// ---------------------------------------------------------------------------
// RegisterRoutes / RunAllRoutes
// ---------------------------------------------------------------------------
// TestRegisterRoutes_SingleFunc_IsCalled checks two things. First, a function
// passed to RegisterRoutes is invoked by RunAllRoutes. Second, the route it
// mounts is reachable under the router group's prefix.
func TestRegisterRoutes_SingleFunc_IsCalled(t *testing.T) {
	resetRegistry()
	called := false
	RegisterRoutes(func(r *gin.RouterGroup) {
		called = true
		r.GET("/ping", func(c *gin.Context) { c.Status(http.StatusOK) })
	})

	engine := newTestEngine()
	RunAllRoutes(engine.Group("/api"))
	assert.True(t, called)

	w := httptest.NewRecorder()
	// httptest.NewRequest panics on malformed input instead of returning an
	// error, so no error value is silently discarded.
	req := httptest.NewRequest(http.MethodGet, "/api/ping", nil)
	engine.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code)
}
func TestRegisterRoutes_MultipleFuncs_AllCalled(t *testing.T) {
resetRegistry()
count := 0
for i := 0; i < 3; i++ {
RegisterRoutes(func(r *gin.RouterGroup) { count++ })
}
RunAllRoutes(newTestEngine().Group("/api"))
assert.Equal(t, 3, count)
}
// TestRunAllRoutes_NoFuncs_DoesNotPanic ensures that RunAllRoutes tolerates an
// empty registry.
func TestRunAllRoutes_NoFuncs_DoesNotPanic(t *testing.T) {
	resetRegistry()
	run := func() {
		group := newTestEngine().Group("/api")
		RunAllRoutes(group)
	}
	assert.NotPanics(t, run)
}
// ---------------------------------------------------------------------------
// RegisterPublicRoutes / RunAllPublicRoutes
// ---------------------------------------------------------------------------
// TestRegisterPublicRoutes_SingleFunc_IsCalled checks two things. First, a
// function passed to RegisterPublicRoutes is invoked by RunAllPublicRoutes with
// the engine itself. Second, the route it mounts is served at the engine root.
func TestRegisterPublicRoutes_SingleFunc_IsCalled(t *testing.T) {
	resetRegistry()
	called := false
	RegisterPublicRoutes(func(r *gin.Engine) {
		called = true
		r.GET("/health", func(c *gin.Context) { c.Status(http.StatusOK) })
	})

	engine := newTestEngine()
	RunAllPublicRoutes(engine)
	assert.True(t, called)

	w := httptest.NewRecorder()
	// httptest.NewRequest panics on malformed input instead of returning an
	// error, so no error value is silently discarded.
	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	engine.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code)
}
// TestRunAllPublicRoutes_NoFuncs_DoesNotPanic ensures that RunAllPublicRoutes
// tolerates an empty registry.
func TestRunAllPublicRoutes_NoFuncs_DoesNotPanic(t *testing.T) {
	resetRegistry()
	run := func() {
		engine := newTestEngine()
		RunAllPublicRoutes(engine)
	}
	assert.NotPanics(t, run)
}
// ---------------------------------------------------------------------------
// RegisterBaseRoutes / RunAllBaseRoutes
// ---------------------------------------------------------------------------
// TestRegisterBaseRoutes_SingleFunc_IsCalled checks two things. First, a
// function passed to RegisterBaseRoutes is invoked by RunAllBaseRoutes. Second,
// the POST route it mounts on the engine is reachable.
func TestRegisterBaseRoutes_SingleFunc_IsCalled(t *testing.T) {
	resetRegistry()
	called := false
	RegisterBaseRoutes(func(r *gin.Engine) {
		called = true
		r.POST("/login", func(c *gin.Context) { c.Status(http.StatusOK) })
	})

	engine := newTestEngine()
	RunAllBaseRoutes(engine)
	assert.True(t, called)

	w := httptest.NewRecorder()
	// httptest.NewRequest panics on malformed input instead of returning an
	// error, so no error value is silently discarded.
	req := httptest.NewRequest(http.MethodPost, "/login", nil)
	engine.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code)
}
// TestRunAllBaseRoutes_NoFuncs_DoesNotPanic ensures that RunAllBaseRoutes
// tolerates an empty registry.
func TestRunAllBaseRoutes_NoFuncs_DoesNotPanic(t *testing.T) {
	resetRegistry()
	run := func() {
		engine := newTestEngine()
		RunAllBaseRoutes(engine)
	}
	assert.NotPanics(t, run)
}
// ---------------------------------------------------------------------------
// SetAuthMiddleware / GetAuthMiddleware
// ---------------------------------------------------------------------------
// TestGetAuthMiddleware_NoProvider_ReturnsPassThrough verifies that, with no
// provider installed, GetAuthMiddleware still returns a non-nil handler, and
// that the handler lets a request reach its route.
func TestGetAuthMiddleware_NoProvider_ReturnsPassThrough(t *testing.T) {
	resetRegistry()
	mw := GetAuthMiddleware()
	require.NotNil(t, mw)

	engine := newTestEngine()
	engine.Use(mw)
	engine.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) })

	w := httptest.NewRecorder()
	// httptest.NewRequest panics on malformed input instead of returning an
	// error, so no error value is silently discarded.
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	engine.ServeHTTP(w, req)
	assert.Equal(t, http.StatusOK, w.Code, "pass-through should not block requests")
}
// TestSetAuthMiddleware_ProviderIsUsed installs a provider and checks three
// things:
//   - GetAuthMiddleware calls the provider.
//   - The returned middleware runs on a request (it sets X-Auth).
//   - The middleware hands control on to the route handler.
func TestSetAuthMiddleware_ProviderIsUsed(t *testing.T) {
	resetRegistry()
	providerCalled := false
	SetAuthMiddleware(func() gin.HandlerFunc {
		providerCalled = true
		return func(c *gin.Context) {
			c.Header("X-Auth", "ok")
			c.Next()
		}
	})

	mw := GetAuthMiddleware()
	assert.True(t, providerCalled)

	engine := newTestEngine()
	engine.Use(mw)
	engine.GET("/secure", func(c *gin.Context) { c.Status(http.StatusOK) })

	w := httptest.NewRecorder()
	// httptest.NewRequest panics on malformed input instead of returning an
	// error, so no error value is silently discarded.
	req := httptest.NewRequest(http.MethodGet, "/secure", nil)
	engine.ServeHTTP(w, req)
	assert.Equal(t, "ok", w.Header().Get("X-Auth"))
	// The header alone does not prove the request got past the middleware.
	assert.Equal(t, http.StatusOK, w.Code)
}
// TestSetAuthMiddleware_OverwritesPreviousProvider registers two providers in a
// row. It checks that the middleware returned by GetAuthMiddleware comes from
// the second one ("last registration wins").
func TestSetAuthMiddleware_OverwritesPreviousProvider(t *testing.T) {
	resetRegistry()
	SetAuthMiddleware(func() gin.HandlerFunc {
		return func(c *gin.Context) {
			c.Header("X-Version", "first")
			c.Next()
		}
	})
	SetAuthMiddleware(func() gin.HandlerFunc {
		return func(c *gin.Context) {
			c.Header("X-Version", "second")
			c.Next()
		}
	})

	engine := newTestEngine()
	engine.Use(GetAuthMiddleware())
	engine.GET("/v", func(c *gin.Context) { c.Status(http.StatusOK) })

	w := httptest.NewRecorder()
	// httptest.NewRequest panics on malformed input instead of returning an
	// error, so no error value is silently discarded.
	req := httptest.NewRequest(http.MethodGet, "/v", nil)
	engine.ServeHTTP(w, req)
	assert.Equal(t, "second", w.Header().Get("X-Version"), "last registration wins")
	assert.Equal(t, http.StatusOK, w.Code)
}
// ---------------------------------------------------------------------------
// RegisterMigration / RunAllMigrations
// ---------------------------------------------------------------------------
// TestRunAllMigrations_NoFuncs_ReturnsNil expects a nil error when no
// migrations have been registered.
func TestRunAllMigrations_NoFuncs_ReturnsNil(t *testing.T) {
	resetRegistry()
	db := newTestDB(t)
	err := RunAllMigrations(db)
	assert.NoError(t, err)
}
func TestRunAllMigrations_SingleFunc_IsCalled(t *testing.T) {
resetRegistry()
called := false
RegisterMigration(func(d ...*gorm.DB) error {
called = true
return nil
})
require.NoError(t, RunAllMigrations(newTestDB(t)))
assert.True(t, called)
}
// TestRunAllMigrations_ReceivesCorrectDB checks that the first DB handed to a
// migration is the same *gorm.DB that was passed to RunAllMigrations.
func TestRunAllMigrations_ReceivesCorrectDB(t *testing.T) {
	resetRegistry()
	db := newTestDB(t)

	var got *gorm.DB
	RegisterMigration(func(dbs ...*gorm.DB) error {
		if len(dbs) == 0 {
			return nil
		}
		got = dbs[0]
		return nil
	})

	require.NoError(t, RunAllMigrations(db))
	assert.Same(t, db, got, "RunAllMigrations should pass the provided DB instance")
}
func TestRunAllMigrations_MultipleFuncs_AllCalled(t *testing.T) {
resetRegistry()
count := 0
for i := 0; i < 3; i++ {
RegisterMigration(func(d ...*gorm.DB) error { count++; return nil })
}
require.NoError(t, RunAllMigrations(newTestDB(t)))
assert.Equal(t, 3, count)
}
func TestRunAllMigrations_FirstErrorAbortsChain(t *testing.T) {
resetRegistry()
secondCalled := false
RegisterMigration(func(d ...*gorm.DB) error { return errors.New("migration failed") })
RegisterMigration(func(d ...*gorm.DB) error { secondCalled = true; return nil })
err := RunAllMigrations(newTestDB(t))
assert.Error(t, err)
assert.False(t, secondCalled, "subsequent migrations must not run after an error")
}
// TestRunAllMigrations_AutoMigratesModel checks that a registered migration can
// call AutoMigrate on the DB it receives, and that the resulting table exists
// on the caller's DB.
func TestRunAllMigrations_AutoMigratesModel(t *testing.T) {
	resetRegistry()
	db := newTestDB(t)

	// SampleModel is a throwaway schema used only by this test.
	type SampleModel struct {
		gorm.Model
		Name string
	}

	RegisterMigration(func(dbs ...*gorm.DB) error {
		target := dbs[0]
		return target.AutoMigrate(&SampleModel{})
	})
	require.NoError(t, RunAllMigrations(db))

	exists := db.Migrator().HasTable(&SampleModel{})
	assert.True(t, exists, "table should exist after migration")
}
// ---------------------------------------------------------------------------
// RegisterInitializer / RunAllInitializers
// ---------------------------------------------------------------------------
// TestRunAllInitializers_NoFuncs_DoesNotPanic ensures that RunAllInitializers
// tolerates an empty registry.
func TestRunAllInitializers_NoFuncs_DoesNotPanic(t *testing.T) {
	resetRegistry()
	run := func() {
		db := newTestDB(t)
		RunAllInitializers(db)
	}
	assert.NotPanics(t, run)
}
func TestRegisterInitializer_SingleFunc_IsCalled(t *testing.T) {
resetRegistry()
called := false
RegisterInitializer(func(db *gorm.DB) {
called = true
})
RunAllInitializers(newTestDB(t))
assert.True(t, called)
}
// TestRunAllInitializers_ReceivesCorrectDB checks that an initializer receives
// the same *gorm.DB that was passed to RunAllInitializers.
func TestRunAllInitializers_ReceivesCorrectDB(t *testing.T) {
	resetRegistry()
	db := newTestDB(t)

	var got *gorm.DB
	RegisterInitializer(func(passed *gorm.DB) { got = passed })

	RunAllInitializers(db)
	assert.Same(t, db, got, "RunAllInitializers should pass the provided DB instance")
}
func TestRunAllInitializers_MultipleFuncs_AllCalled(t *testing.T) {
resetRegistry()
count := 0
for i := 0; i < 3; i++ {
RegisterInitializer(func(db *gorm.DB) { count++ })
}
RunAllInitializers(newTestDB(t))
assert.Equal(t, 3, count)
}
// ---------------------------------------------------------------------------
// RegisterModel / GetModels
// ---------------------------------------------------------------------------
// TestGetModels_NoModels_ReturnsNil expects GetModels to return nil when no
// model has been registered.
func TestGetModels_NoModels_ReturnsNil(t *testing.T) {
	resetRegistry()
	models := GetModels()
	assert.Nil(t, models)
}
// TestRegisterModel_SingleModel_IsReturned registers one model value. It
// expects GetModels to return exactly that one entry, with the same dynamic
// type.
func TestRegisterModel_SingleModel_IsReturned(t *testing.T) {
	resetRegistry()
	type Foo struct{ Name string }

	RegisterModel(Foo{})

	got := GetModels()
	require.Len(t, got, 1)
	first := got[0]
	assert.IsType(t, Foo{}, first)
}
// TestRegisterModel_MultipleModels_AllReturned registers two distinct model
// types. It expects GetModels to return both, in registration order.
func TestRegisterModel_MultipleModels_AllReturned(t *testing.T) {
	resetRegistry()
	type A struct{ X int }
	type B struct{ Y string }

	RegisterModel(A{})
	RegisterModel(B{})

	// Expected dynamic types, indexed by registration order.
	expected := []interface{}{A{}, B{}}
	models := GetModels()
	require.Len(t, models, len(expected))
	for i, want := range expected {
		assert.IsType(t, want, models[i])
	}
}