From 35e66a1ba82c1ff60c78294b13b1e2b0f7c58868 Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 21:06:45 +0000 Subject: [PATCH] Add chain alias action discovery hook --- service.go | 44 +++++++++++------------------- service_test.go | 72 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 28 deletions(-) diff --git a/service.go b/service.go index b496532..43a5546 100644 --- a/service.go +++ b/service.go @@ -48,6 +48,7 @@ type Service struct { zoneApex string hsdClient *HSDClient mainchainAliasClient *MainchainAliasClient + chainAliasAction func(context.Context) ([]string, error) discoverer func() (map[string]NameRecords, error) fallbackDiscoverer func() (map[string]NameRecords, error) chainAliasDiscoverer func(context.Context) ([]string, error) @@ -63,6 +64,7 @@ type ServiceOptions struct { FallbackDiscoverer func() (map[string]NameRecords, error) MainchainAliasClient *MainchainAliasClient HSDClient *HSDClient + ChainAliasAction func(context.Context) ([]string, error) ChainAliasDiscoverer func(context.Context) ([]string, error) FallbackChainAliasDiscoverer func(context.Context) ([]string, error) TreeRootCheckInterval time.Duration @@ -102,6 +104,7 @@ func NewService(options ServiceOptions) *Service { zoneApex: computeZoneApex(cached), hsdClient: options.HSDClient, mainchainAliasClient: options.MainchainAliasClient, + chainAliasAction: options.ChainAliasAction, discoverer: options.Discoverer, fallbackDiscoverer: options.FallbackDiscoverer, chainAliasDiscoverer: options.ChainAliasDiscoverer, @@ -126,8 +129,9 @@ func (service *Service) DiscoverFromChainAliases(ctx context.Context, client *HS return err } - aliases, err := discoverAliasesWithFallback( + aliases, err := service.discoverAliasesFromSources( ctx, + service.chainAliasAction, service.chainAliasDiscoverer, service.fallbackChainAliasDiscoverer, service.mainchainAliasClient, @@ -141,12 +145,20 @@ func (service *Service) DiscoverFromChainAliases(ctx context.Context, client *HS return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, resolved) } -func discoverAliasesWithFallback( +func (service *Service) discoverAliasesFromSources( ctx context.Context, + action func(context.Context) ([]string, error), discoverer func(context.Context) ([]string, error), fallback func(context.Context) ([]string, error), mainchainClient *MainchainAliasClient, ) ([]string, error) { + if action != nil { + aliases, err := action(ctx) + if err == nil { + return aliases, nil + } + } + if discoverer == nil { if fallback != nil { aliases, err := fallback(ctx) @@ -203,8 +215,9 @@ func (service *Service) DiscoverFromMainchainAliases(ctx context.Context, chainC effectiveChainClient = service.mainchainAliasClient } - aliases, err := discoverAliasesWithFallback( + aliases, err := service.discoverAliasesFromSources( ctx, + nil, func(ctx context.Context) ([]string, error) { if service.chainAliasDiscoverer != nil { return service.chainAliasDiscoverer(ctx) @@ -294,31 +307,6 @@ func (service *Service) recordTreeRootState(now time.Time, treeRoot string) { service.chainTreeRoot = treeRoot } -func discoverAliases(ctx context.Context, discoverer func(context.Context) ([]string, error), fallback func(context.Context) ([]string, error)) ([]string, error) { - if discoverer == nil { - if fallback == nil { - return nil, nil - } - aliases, err := fallback(ctx) - if err != nil { - return nil, err - } - return aliases, nil - } - - aliases, err := discoverer(ctx) - if err != nil { - if fallback == nil { - return nil, err - } - aliases, err = fallback(ctx) - if err != nil { - return nil, err - } - } - return aliases, nil -} - func (service *Service) Discover() error { discoverer := service.discoverer fallback := service.fallbackDiscoverer diff --git a/service_test.go b/service_test.go index 6a92599..4456dc9 100644 --- a/service_test.go +++ b/service_test.go @@ -516,6 +516,78 @@ func TestServiceDiscoverAliasesUsesConfiguredAliasDiscovery(t *testing.T) { } } +func TestServiceDiscoverAliasesUsesConfiguredChainAliasAction(t *testing.T) { + var treeRootCalls int32 + var nameResourceCalls int32 + actionCalled := false + discovererCalled := false + + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + Params []any `json:"params"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + + switch payload.Method { + case "getblockchaininfo": + atomic.AddInt32(&treeRootCalls, 1) + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "tree_root": "action-root-1", + }, + }) + case "getnameresource": + atomic.AddInt32(&nameResourceCalls, 1) + if len(payload.Params) != 1 || payload.Params[0] != "gateway.charon.lthn" { + t.Fatalf("unexpected alias lookup: %#v", payload.Params) + } + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + }, + }) + default: + t.Fatalf("unexpected method: %s", payload.Method) + } + })) + defer server.Close() + + service := NewService(ServiceOptions{ + ChainAliasAction: func(_ context.Context) ([]string, error) { + actionCalled = true + return []string{"gateway.charon.lthn"}, nil + }, + ChainAliasDiscoverer: func(_ context.Context) ([]string, error) { + discovererCalled = true + return nil, errors.New("discoverer should not be used when action succeeds") + }, + HSDClient: NewHSDClient(HSDClientOptions{URL: server.URL}), + }) + + if err := service.DiscoverAliases(context.Background()); err != nil { + t.Fatalf("expected DiscoverAliases to complete through chain alias action: %v", err) + } + if !actionCalled { + t.Fatal("expected chain alias action to be called") + } + if discovererCalled { + t.Fatal("expected chain alias discoverer to be skipped after action success") + } + + record, ok := service.Resolve("gateway.charon.lthn") + if !ok || len(record.A) != 1 || record.A[0] != "10.10.10.10" { + t.Fatalf("expected discovered gateway record, got %#v (ok=%t)", record, ok) + } + if atomic.LoadInt32(&treeRootCalls) != 1 || atomic.LoadInt32(&nameResourceCalls) != 1 { + t.Fatalf("expected one tree-root and one name-resource RPC call, got treeRoot=%d nameResource=%d", atomic.LoadInt32(&treeRootCalls), atomic.LoadInt32(&nameResourceCalls)) + } +} + func TestServiceDiscoverFallsBackWhenPrimaryDiscovererFails(t *testing.T) { primaryCalled := false fallbackCalled := false -- 2.45.3