diff --git a/service.go b/service.go index 09e67c9..99f6d08 100644 --- a/service.go +++ b/service.go @@ -150,7 +150,7 @@ func (service *Service) DiscoverFromChainAliases(ctx context.Context, client *HS return err } - aliases, err := service.discoverAliasesFromSources( + aliases, found, err := service.discoverAliasesFromSources( ctx, service.chainAliasActionCaller, service.chainAliasAction, @@ -161,7 +161,14 @@ func (service *Service) DiscoverFromChainAliases(ctx context.Context, client *HS if err != nil { return err } - if aliases == nil { + if !found { + return nil + } + if len(aliases) == 0 { + now := time.Now() + fingerprint := aliasFingerprint(aliases) + service.replaceRecords(map[string]NameRecords{}) + service.recordTreeRootState(now, "", fingerprint) return nil } return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, resolved) @@ -187,15 +194,15 @@ func (service *Service) discoverAliasesFromSources( discoverer func(context.Context) ([]string, error), fallback func(context.Context) ([]string, error), mainchainClient *MainchainAliasClient, -) ([]string, error) { +) ([]string, bool, error) { if aliases, ok := service.discoverAliasesFromActionCaller(ctx, actionCaller); ok { - return aliases, nil + return aliases, true, nil } if action != nil { aliases, err := action(ctx) if err == nil { - return aliases, nil + return aliases, true, nil } } @@ -203,38 +210,42 @@ func (service *Service) discoverAliasesFromSources( if fallback != nil { aliases, err := fallback(ctx) if err == nil { - return aliases, nil + return aliases, true, nil } if mainchainClient == nil { - return nil, err + return nil, false, err } - return mainchainClient.GetAllAliasDetails(ctx) + aliases, err = mainchainClient.GetAllAliasDetails(ctx) + return aliases, true, err } if mainchainClient == nil { - return nil, nil + return nil, false, nil } - return mainchainClient.GetAllAliasDetails(ctx) + aliases, err := mainchainClient.GetAllAliasDetails(ctx) + return aliases, true, err } aliases, err := discoverer(ctx) if err == nil { - return aliases, nil + return aliases, true, nil } if fallback == nil { if mainchainClient == nil { - return nil, err + return nil, false, err } - return mainchainClient.GetAllAliasDetails(ctx) + aliases, err = mainchainClient.GetAllAliasDetails(ctx) + return aliases, true, err } fallbackAliases, fallbackErr := fallback(ctx) if fallbackErr == nil { - return fallbackAliases, nil + return fallbackAliases, true, nil } if mainchainClient == nil { - return nil, fallbackErr + return nil, false, fallbackErr } - return mainchainClient.GetAllAliasDetails(ctx) + aliases, err = mainchainClient.GetAllAliasDetails(ctx) + return aliases, true, err } func (service *Service) discoverAliasesFromActionCaller(ctx context.Context, actionCaller ActionCaller) ([]string, bool) { @@ -350,7 +361,7 @@ func (service *Service) DiscoverFromMainchainAliases(ctx context.Context, chainC effectiveChainClient = service.mainchainAliasClient } - aliases, err := service.discoverAliasesFromSources( + aliases, found, err := service.discoverAliasesFromSources( ctx, service.chainAliasActionCaller, nil, @@ -367,7 +378,7 @@ func (service *Service) DiscoverFromMainchainAliases(ctx context.Context, chainC if service.fallbackChainAliasDiscoverer != nil { return service.fallbackChainAliasDiscoverer(ctx) } - if effectiveChainClient != nil && service.chainAliasDiscoverer != nil { + if effectiveChainClient != nil { return effectiveChainClient.GetAllAliasDetails(ctx) } return nil, nil @@ -377,7 +388,14 @@ func (service *Service) DiscoverFromMainchainAliases(ctx context.Context, chainC if err != nil { return err } + if !found { + return nil + } if len(aliases) == 0 { + now := time.Now() + fingerprint := aliasFingerprint(aliases) + service.replaceRecords(map[string]NameRecords{}) + service.recordTreeRootState(now, "", fingerprint) return nil } return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, resolvedHSDClient) diff --git a/service_test.go b/service_test.go index 095e278..9ed3ea5 100644 --- a/service_test.go +++ b/service_test.go @@ -677,6 +677,45 @@ func TestServiceDiscoverAliasesUsesConfiguredChainAliasAction(t *testing.T) { } } +func TestServiceDiscoverAliasesClearsCacheWhenAliasListBecomesEmpty(t *testing.T) { + var hsdCalls int32 + + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + atomic.AddInt32(&hsdCalls, 1) + t.Fatalf("unexpected HSD request while clearing an empty alias list") + })) + defer server.Close() + + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "legacy.charon.lthn": { + A: []string{"10.11.11.11"}, + }, + }, + ChainAliasDiscoverer: func(_ context.Context) ([]string, error) { + return []string{}, nil + }, + HSDClient: NewHSDClient(HSDClientOptions{URL: server.URL}), + }) + + if err := service.DiscoverAliases(context.Background()); err != nil { + t.Fatalf("expected empty alias discovery to succeed: %v", err) + } + + if _, ok := service.Resolve("legacy.charon.lthn"); ok { + t.Fatal("expected stale records to be cleared when the alias list is empty") + } + + health := service.Health() + if health.NamesCached != 0 { + t.Fatalf("expected empty cache after clearing aliases, got %d", health.NamesCached) + } + + if atomic.LoadInt32(&hsdCalls) != 0 { + t.Fatalf("expected no HSD requests when alias discovery returns empty, got %d", atomic.LoadInt32(&hsdCalls)) + } +} + func TestServiceDiscoverAliasesParsesAliasDetailRecordsFromActionCaller(t *testing.T) { var treeRootCalls int32 var nameResourceCalls int32