diff --git a/hooks.go b/hooks.go index a19a47c..7a8b0f4 100644 --- a/hooks.go +++ b/hooks.go @@ -132,13 +132,19 @@ func loadLocaleProvider(svc *Service, provider localeProviderRegistration) { // OnMissingKey registers a handler for missing translation keys. func OnMissingKey(h MissingKeyHandler) { + SetMissingKeyHandlers(h) +} + +// SetMissingKeyHandlers replaces the full missing-key handler chain. +func SetMissingKeyHandlers(handlers ...MissingKeyHandler) { missingKeyHandlerMu.Lock() defer missingKeyHandlerMu.Unlock() - if h == nil { + handlers = filterNilMissingKeyHandlers(handlers) + if len(handlers) == 0 { missingKeyHandler.Store(missingKeyHandlersState{}) return } - missingKeyHandler.Store(missingKeyHandlersState{handlers: []MissingKeyHandler{h}}) + missingKeyHandler.Store(missingKeyHandlersState{handlers: handlers}) } // ClearMissingKeyHandlers removes all registered missing-key handlers. @@ -161,6 +167,22 @@ func AddMissingKeyHandler(h MissingKeyHandler) { missingKeyHandler.Store(current) } +func filterNilMissingKeyHandlers(handlers []MissingKeyHandler) []MissingKeyHandler { + if len(handlers) == 0 { + return nil + } + filtered := make([]MissingKeyHandler, 0, len(handlers)) + for _, h := range handlers { + if h != nil { + filtered = append(filtered, h) + } + } + if len(filtered) == 0 { + return nil + } + return filtered +} + func missingKeyHandlers() missingKeyHandlersState { v := missingKeyHandler.Load() if v == nil { diff --git a/hooks_test.go b/hooks_test.go index fa829d3..d1beff1 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -457,6 +457,54 @@ func TestAddMissingKeyHandler_Good(t *testing.T) { assert.Equal(t, 1, second) } +func TestSetMissingKeyHandlers_Good(t *testing.T) { + svc, err := New() + require.NoError(t, err) + prev := Default() + SetDefault(svc) + prevHandlers := missingKeyHandlers() + t.Cleanup(func() { + missingKeyHandler.Store(prevHandlers) + SetDefault(prev) + }) + svc.SetMode(ModeCollect) + + var first, second int + SetMissingKeyHandlers( + nil, + func(MissingKey) { first++ }, + func(MissingKey) { second++ }, + ) + + _ = T("missing.set.handlers") + + assert.Equal(t, 1, first) + assert.Equal(t, 1, second) + assert.Len(t, missingKeyHandlers().handlers, 2) +} + +func TestSetMissingKeyHandlers_Good_Clear(t *testing.T) { + svc, err := New() + require.NoError(t, err) + prev := Default() + SetDefault(svc) + prevHandlers := missingKeyHandlers() + t.Cleanup(func() { + missingKeyHandler.Store(prevHandlers) + SetDefault(prev) + }) + svc.SetMode(ModeCollect) + + var called int + SetMissingKeyHandlers(func(MissingKey) { called++ }) + SetMissingKeyHandlers(nil) + + _ = T("missing.set.handlers.clear") + + assert.Equal(t, 0, called) + assert.Empty(t, missingKeyHandlers().handlers) +} + func TestAddMissingKeyHandler_Good_Concurrent(t *testing.T) { svc, err := New() require.NoError(t, err)