diff --git a/service.go b/service.go index 6047f56..8130ba3 100644 --- a/service.go +++ b/service.go @@ -9,8 +9,11 @@ import ( "slices" "strings" "sync" + "time" ) +const defaultTreeRootCheckInterval = 15 * time.Second + type NameRecords struct { A []string `json:"a"` AAAA []string `json:"aaaa"` @@ -38,6 +41,9 @@ type Service struct { fallbackDiscoverer func() (map[string]NameRecords, error) chainAliasDiscoverer func(context.Context) ([]string, error) fallbackChainAliasDiscoverer func(context.Context) ([]string, error) + lastTreeRootCheck time.Time + chainTreeRoot string + treeRootCheckInterval time.Duration } type ServiceOptions struct { @@ -46,9 +52,15 @@ type ServiceOptions struct { FallbackDiscoverer func() (map[string]NameRecords, error) ChainAliasDiscoverer func(context.Context) ([]string, error) FallbackChainAliasDiscoverer func(context.Context) ([]string, error) + TreeRootCheckInterval time.Duration } func NewService(options ServiceOptions) *Service { + checkInterval := options.TreeRootCheckInterval + if checkInterval <= 0 { + checkInterval = defaultTreeRootCheckInterval + } + cached := make(map[string]NameRecords, len(options.Records)) for name, record := range options.Records { cached[normalizeName(name)] = record @@ -62,6 +74,7 @@ func NewService(options ServiceOptions) *Service { fallbackDiscoverer: options.FallbackDiscoverer, chainAliasDiscoverer: options.ChainAliasDiscoverer, fallbackChainAliasDiscoverer: options.FallbackChainAliasDiscoverer, + treeRootCheckInterval: checkInterval, } } @@ -77,7 +90,67 @@ func (service *Service) DiscoverFromChainAliases(ctx context.Context, client *HS if aliases == nil { return nil } - return service.DiscoverWithHSD(ctx, aliases, client) + return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, client) +} + +func (service *Service) discoverFromChainAliasesUsingTreeRoot(ctx context.Context, aliases []string, client *HSDClient) error { + if len(aliases) == 0 { + return nil + } + + now := time.Now() + if service.shouldUseCachedTreeRoot(now) { + return nil + } + + info, err := client.GetBlockchainInfo(ctx) + if err != nil { + return err + } + + cachedRoot := service.getChainTreeRoot() + if cachedRoot != "" && cachedRoot == info.TreeRoot { + service.recordTreeRootCheck(now) + return nil + } + + if err := service.DiscoverWithHSD(ctx, aliases, client); err != nil { + return err + } + + service.recordTreeRootState(now, info.TreeRoot) + return nil +} + +func (service *Service) shouldUseCachedTreeRoot(now time.Time) bool { + service.mu.RLock() + defer service.mu.RUnlock() + if service.lastTreeRootCheck.IsZero() { + return false + } + if service.treeRootCheckInterval <= 0 { + return false + } + return now.Sub(service.lastTreeRootCheck) < service.treeRootCheckInterval +} + +func (service *Service) getChainTreeRoot() string { + service.mu.RLock() + defer service.mu.RUnlock() + return service.chainTreeRoot +} + +func (service *Service) recordTreeRootCheck(now time.Time) { + service.mu.Lock() + defer service.mu.Unlock() + service.lastTreeRootCheck = now +} + +func (service *Service) recordTreeRootState(now time.Time, treeRoot string) { + service.mu.Lock() + defer service.mu.Unlock() + service.lastTreeRootCheck = now + service.chainTreeRoot = treeRoot } func discoverAliases(ctx context.Context, discoverer func(context.Context) ([]string, error), fallback func(context.Context) ([]string, error)) ([]string, error) { diff --git a/service_test.go b/service_test.go index 1029ffe..7328195 100644 --- a/service_test.go +++ b/service_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" @@ -339,6 +340,8 @@ func TestServiceDiscoverUsesFallbackOnlyWhenPrimaryMissing(t *testing.T) { func TestServiceDiscoverFromChainAliasesUsesFallbackWhenPrimaryFails(t *testing.T) { primaryCalled := false fallbackCalled := false + var treeRootCalls int32 + var nameResourceCalls int32 server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { var payload struct { @@ -348,21 +351,35 @@ func TestServiceDiscoverFromChainAliasesUsesFallbackWhenPrimaryFails(t *testing. if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { t.Fatalf("unexpected request payload: %v", err) } - switch payload.Params[0] { - case "gateway.charon.lthn": + + switch payload.Method { + case "getblockchaininfo": + atomic.AddInt32(&treeRootCalls, 1) _ = json.NewEncoder(responseWriter).Encode(map[string]any{ "result": map[string]any{ - "a": []string{"10.10.10.10"}, - }, - }) - case "node.charon.lthn": - _ = json.NewEncoder(responseWriter).Encode(map[string]any{ - "result": map[string]any{ - "aaaa": []string{"2600:1f1c:7f0:4f01::2"}, + "tree_root": "root-1", }, }) + case "getnameresource": + atomic.AddInt32(&nameResourceCalls, 1) + switch payload.Params[0] { + case "gateway.charon.lthn": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + }, + }) + case "node.charon.lthn": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "aaaa": []string{"2600:1f1c:7f0:4f01::2"}, + }, + }) + default: + t.Fatalf("unexpected alias query: %#v", payload.Params) + } default: - t.Fatalf("unexpected alias query: %#v", payload.Params) + t.Fatalf("unexpected method: %s", payload.Method) } })) defer server.Close() @@ -399,6 +416,74 @@ func TestServiceDiscoverFromChainAliasesUsesFallbackWhenPrimaryFails(t *testing. if !ok || len(node.AAAA) != 1 || node.AAAA[0] != "2600:1f1c:7f0:4f01::2" { t.Fatalf("expected node AAAA record, got %#v (ok=%t)", node, ok) } + + if atomic.LoadInt32(&treeRootCalls) != 1 || atomic.LoadInt32(&nameResourceCalls) != 2 { + t.Fatalf("expected one tree-root and two name-resource RPC calls, got treeRoot=%d nameResource=%d", atomic.LoadInt32(&treeRootCalls), atomic.LoadInt32(&nameResourceCalls)) + } +} + +func TestServiceDiscoverFromChainAliasesSkipsRefreshWhenTreeRootUnchanged(t *testing.T) { + var treeRootCalls int32 + var nameResourceCalls int32 + + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + Params []any `json:"params"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + + switch payload.Method { + case "getblockchaininfo": + atomic.AddInt32(&treeRootCalls, 1) + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "tree_root": "same-root", + }, + }) + case "getnameresource": + atomic.AddInt32(&nameResourceCalls, 1) + switch payload.Params[0] { + case "gateway.charon.lthn": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + }, + }) + default: + t.Fatalf("unexpected alias query: %#v", payload.Params) + } + default: + t.Fatalf("unexpected method: %s", payload.Method) + } + })) + defer server.Close() + + service := NewService(ServiceOptions{ + TreeRootCheckInterval: 5 * time.Second, + ChainAliasDiscoverer: func(_ context.Context) ([]string, error) { + return []string{"gateway.charon.lthn"}, nil + }, + }) + + client := NewHSDClient(HSDClientOptions{ + URL: server.URL, + }) + if err := service.DiscoverFromChainAliases(context.Background(), client); err != nil { + t.Fatalf("expected first chain alias discovery to run: %v", err) + } + if err := service.DiscoverFromChainAliases(context.Background(), client); err != nil { + t.Fatalf("expected second chain alias discovery to skip refresh: %v", err) + } + + if atomic.LoadInt32(&treeRootCalls) != 1 { + t.Fatalf("expected one tree_root check in interval window, got %d", atomic.LoadInt32(&treeRootCalls)) + } + if atomic.LoadInt32(&nameResourceCalls) != 1 { + t.Fatalf("expected one name-resource query while refreshing, got %d", atomic.LoadInt32(&nameResourceCalls)) + } } func TestServiceDiscoverFromChainAliasesIgnoresMissingDiscoverers(t *testing.T) {