diff --git a/action.go b/action.go index 13de62d..42301df 100644 --- a/action.go +++ b/action.go @@ -170,9 +170,10 @@ func (service *Service) RegisterActions(registrar ActionRegistrar) { // service := dns.NewServiceWithRegistrar(dns.ServiceOptions{}, registrar) // // registrar now exposes dns.resolve, dns.resolve.txt, dns.resolve.all, dns.reverse, dns.serve, dns.health, dns.discover func NewServiceWithRegistrar(options ServiceOptions, registrar ActionRegistrar) *Service { - service := NewService(options) - service.RegisterActions(registrar) - return service + if registrar != nil { + options.ActionRegistrar = registrar + } + return NewService(options) } // HandleAction executes a DNS action by name. diff --git a/service.go b/service.go index 19c5f77..ad32a83 100644 --- a/service.go +++ b/service.go @@ -109,6 +109,7 @@ type ServiceOptions struct { ChainAliasDiscoverer func(context.Context) ([]string, error) FallbackChainAliasDiscoverer func(context.Context) ([]string, error) TreeRootCheckInterval time.Duration + ActionRegistrar ActionRegistrar } // Options is the documented constructor shape for NewService. @@ -156,7 +157,7 @@ func NewService(options ServiceOptions) *Service { cached[normalizeName(name)] = record } treeRoot := computeTreeRoot(cached) - return &Service{ + service := &Service{ records: cached, reverseIndex: buildReverseIndexCache(cached), treeRoot: treeRoot, @@ -171,6 +172,12 @@ func NewService(options ServiceOptions) *Service { fallbackChainAliasDiscoverer: options.FallbackChainAliasDiscoverer, treeRootCheckInterval: checkInterval, } + + if options.ActionRegistrar != nil { + service.RegisterActions(options.ActionRegistrar) + } + + return service } func (service *Service) resolveHSDClient(client *HSDClient) (*HSDClient, error) { diff --git a/service_test.go b/service_test.go index 47aeed6..41607c8 100644 --- a/service_test.go +++ b/service_test.go @@ -1942,6 +1942,32 @@ func TestNewServiceWithRegistrarBuildsAndRegistersInOneStep(t *testing.T) { } } +func TestNewServiceAutoRegistersActionsWhenRegistrarIsConfigured(t *testing.T) { + registrar := &actionRecorder{} + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + }, + }, + ActionRegistrar: registrar, + }) + + if service == nil { + t.Fatal("expected service to be built") + } + + expected := service.ActionNames() + if len(registrar.names) != len(expected) { + t.Fatalf("expected constructor to auto-register %d actions, got %d: %#v", len(expected), len(registrar.names), registrar.names) + } + for index, name := range expected { + if registrar.names[index] != name { + t.Fatalf("unexpected auto-registered action at %d: got %q want %q", index, registrar.names[index], name) + } + } +} + func TestServiceActionDefinitionsHaveInvokers(t *testing.T) { service := NewService(ServiceOptions{ Records: map[string]NameRecords{