From f2f1d65c65e4345d3f3ef9a90d0975ffffb1ace8 Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 21:41:00 +0000 Subject: [PATCH] Refresh DNS cache when aliases change --- service.go | 48 +++++++++++++++++++++++----- service_test.go | 84 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 7 deletions(-) diff --git a/service.go b/service.go index 3d30013..1ec7fdb 100644 --- a/service.go +++ b/service.go @@ -57,6 +57,7 @@ type Service struct { reverseIndex *cache.Cache treeRoot string zoneApex string + lastAliasFingerprint string hsdClient *HSDClient mainchainAliasClient *MainchainAliasClient chainAliasActionCaller ActionCaller @@ -160,6 +161,19 @@ func (service *Service) DiscoverFromChainAliases(ctx context.Context, client *HS return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, resolved) } +func aliasFingerprint(aliases []string) string { + normalized := normalizeAliasList(aliases) + builder := strings.Builder{} + for index, alias := range normalized { + if index > 0 { + builder.WriteByte('\n') + } + builder.WriteString(alias) + } + sum := sha256.Sum256([]byte(builder.String())) + return hex.EncodeToString(sum[:]) +} + func (service *Service) discoverAliasesFromSources( ctx context.Context, actionCaller ActionCaller, @@ -369,7 +383,8 @@ func (service *Service) discoverFromChainAliasesUsingTreeRoot(ctx context.Contex } now := time.Now() - if service.shouldUseCachedTreeRoot(now) { + fingerprint := aliasFingerprint(aliases) + if service.shouldUseCachedTreeRoot(now, fingerprint) { return nil } @@ -379,8 +394,8 @@ func (service *Service) discoverFromChainAliasesUsingTreeRoot(ctx context.Contex } cachedRoot := service.getChainTreeRoot() - if cachedRoot != "" && cachedRoot == info.TreeRoot { - service.recordTreeRootCheck(now) + if cachedRoot != "" && cachedRoot == info.TreeRoot && service.getLastAliasFingerprint() == fingerprint { + service.recordTreeRootCheck(now, fingerprint) return nil } @@ -388,13 +403,16 @@ func (service *Service) discoverFromChainAliasesUsingTreeRoot(ctx context.Contex return err } - service.recordTreeRootState(now, info.TreeRoot) + service.recordTreeRootState(now, info.TreeRoot, fingerprint) return nil } -func (service *Service) shouldUseCachedTreeRoot(now time.Time) bool { +func (service *Service) shouldUseCachedTreeRoot(now time.Time, aliasFingerprint string) bool { service.mu.RLock() defer service.mu.RUnlock() + if service.lastAliasFingerprint != aliasFingerprint { + return false + } if service.lastTreeRootCheck.IsZero() { return false } @@ -410,17 +428,31 @@ func (service *Service) getChainTreeRoot() string { return service.chainTreeRoot } -func (service *Service) recordTreeRootCheck(now time.Time) { +func (service *Service) getLastAliasFingerprint() string { + service.mu.RLock() + defer service.mu.RUnlock() + return service.lastAliasFingerprint +} + +func (service *Service) recordAliasFingerprint(aliasFingerprint string) { + service.mu.Lock() + defer service.mu.Unlock() + service.lastAliasFingerprint = aliasFingerprint +} + +func (service *Service) recordTreeRootCheck(now time.Time, aliasFingerprint string) { service.mu.Lock() defer service.mu.Unlock() service.lastTreeRootCheck = now + service.lastAliasFingerprint = aliasFingerprint } -func (service *Service) recordTreeRootState(now time.Time, treeRoot string) { +func (service *Service) recordTreeRootState(now time.Time, treeRoot string, aliasFingerprint string) { service.mu.Lock() defer service.mu.Unlock() service.lastTreeRootCheck = now service.chainTreeRoot = treeRoot + service.lastAliasFingerprint = aliasFingerprint } // Discover refreshes the cache from the configured discoverer or fallback. @@ -553,6 +585,7 @@ func (service *Service) DiscoverWithHSD(ctx context.Context, aliases []string, c return fmt.Errorf("hsd client is required") } + fingerprint := aliasFingerprint(aliases) resolved := make(map[string]NameRecords, len(aliases)) for _, alias := range aliases { normalized := normalizeName(alias) @@ -568,6 +601,7 @@ func (service *Service) DiscoverWithHSD(ctx context.Context, aliases []string, c } service.replaceRecords(resolved) + service.recordAliasFingerprint(fingerprint) return nil } diff --git a/service_test.go b/service_test.go index ecd6f5a..95e60b0 100644 --- a/service_test.go +++ b/service_test.go @@ -765,6 +765,90 @@ func TestServiceDiscoverAliasesParsesAliasDetailRecordsFromActionCaller(t *testi } } +func TestServiceDiscoverAliasesRefreshesWhenAliasListChangesBeforeTreeRootIntervalExpires(t *testing.T) { + var treeRootCalls int32 + var nameResourceCalls int32 + aliasListIndex := 0 + + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + Params []any `json:"params"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + + switch payload.Method { + case "getblockchaininfo": + atomic.AddInt32(&treeRootCalls, 1) + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "tree_root": "shared-tree-root", + }, + }) + case "getnameresource": + atomic.AddInt32(&nameResourceCalls, 1) + responseWriter.Header().Set("Content-Type", "application/json") + switch payload.Params[0] { + case "gateway.charon.lthn": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + }, + }) + case "node.charon.lthn": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "aaaa": []string{"2600:1f1c:7f0:4f01::2"}, + }, + }) + default: + t.Fatalf("unexpected alias lookup: %#v", payload.Params) + } + default: + t.Fatalf("unexpected method: %s", payload.Method) + } + })) + defer server.Close() + + service := NewService(ServiceOptions{ + TreeRootCheckInterval: time.Hour, + ChainAliasDiscoverer: func(_ context.Context) ([]string, error) { + defer func() { aliasListIndex++ }() + if aliasListIndex == 0 { + return []string{"gateway.charon.lthn"}, nil + } + return []string{"gateway.charon.lthn", "node.charon.lthn"}, nil + }, + HSDClient: NewHSDClient(HSDClientOptions{URL: server.URL}), + }) + + if err := service.DiscoverAliases(context.Background()); err != nil { + t.Fatalf("expected first DiscoverAliases call to succeed: %v", err) + } + if atomic.LoadInt32(&treeRootCalls) != 1 || atomic.LoadInt32(&nameResourceCalls) != 1 { + t.Fatalf("expected first discovery to query tree root and one alias, got treeRoot=%d nameResource=%d", atomic.LoadInt32(&treeRootCalls), atomic.LoadInt32(&nameResourceCalls)) + } + + if err := service.DiscoverAliases(context.Background()); err != nil { + t.Fatalf("expected second DiscoverAliases call to refresh changed aliases: %v", err) + } + if atomic.LoadInt32(&treeRootCalls) != 2 || atomic.LoadInt32(&nameResourceCalls) != 3 { + t.Fatalf("expected alias change to force refresh, got treeRoot=%d nameResource=%d", atomic.LoadInt32(&treeRootCalls), atomic.LoadInt32(&nameResourceCalls)) + } + + gateway, ok := service.Resolve("gateway.charon.lthn") + if !ok || len(gateway.A) != 1 || gateway.A[0] != "10.10.10.10" { + t.Fatalf("expected refreshed gateway record, got %#v (ok=%t)", gateway, ok) + } + node, ok := service.Resolve("node.charon.lthn") + if !ok || len(node.AAAA) != 1 || node.AAAA[0] != "2600:1f1c:7f0:4f01::2" { + t.Fatalf("expected refreshed node record, got %#v (ok=%t)", node, ok) + } +} + func TestServiceDiscoverFallsBackWhenPrimaryDiscovererFails(t *testing.T) { primaryCalled := false fallbackCalled := false -- 2.45.3