diff --git a/service.go b/service.go index 8d9e941..9f5f30e 100644 --- a/service.go +++ b/service.go @@ -45,6 +45,7 @@ type Service struct { records map[string]NameRecords reverseIndex map[string][]string treeRoot string + hsdClient *HSDClient mainchainAliasClient *MainchainAliasClient discoverer func() (map[string]NameRecords, error) fallbackDiscoverer func() (map[string]NameRecords, error) @@ -60,6 +61,7 @@ type ServiceOptions struct { Discoverer func() (map[string]NameRecords, error) FallbackDiscoverer func() (map[string]NameRecords, error) MainchainAliasClient *MainchainAliasClient + HSDClient *HSDClient ChainAliasDiscoverer func(context.Context) ([]string, error) FallbackChainAliasDiscoverer func(context.Context) ([]string, error) TreeRootCheckInterval time.Duration @@ -80,6 +82,7 @@ func NewService(options ServiceOptions) *Service { records: cached, reverseIndex: buildReverseIndex(cached), treeRoot: treeRoot, + hsdClient: options.HSDClient, mainchainAliasClient: options.MainchainAliasClient, discoverer: options.Discoverer, fallbackDiscoverer: options.FallbackDiscoverer, @@ -89,9 +92,20 @@ func NewService(options ServiceOptions) *Service { } } +func (service *Service) resolveHSDClient(client *HSDClient) (*HSDClient, error) { + if client != nil { + return client, nil + } + if service.hsdClient == nil { + return nil, fmt.Errorf("hsd client is required") + } + return service.hsdClient, nil +} + func (service *Service) DiscoverFromChainAliases(ctx context.Context, client *HSDClient) error { - if client == nil { - return fmt.Errorf("hsd client is required") + resolved, err := service.resolveHSDClient(client) + if err != nil { + return err } aliases, err := discoverAliasesWithFallback( @@ -106,7 +120,7 @@ func (service *Service) DiscoverFromChainAliases(ctx context.Context, client *HS if aliases == nil { return nil } - return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, client) + return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, resolved) } func discoverAliasesWithFallback( @@ -161,6 +175,11 @@ func discoverAliasesWithFallback( // URL: "http://127.0.0.1:14037", // })) func (service *Service) DiscoverFromMainchainAliases(ctx context.Context, chainClient *MainchainAliasClient, hsdClient *HSDClient) error { + resolvedHSDClient, err := service.resolveHSDClient(hsdClient) + if err != nil { + return err + } + aliases, err := discoverAliases( ctx, func(ctx context.Context) ([]string, error) { @@ -188,7 +207,7 @@ func (service *Service) DiscoverFromMainchainAliases(ctx context.Context, chainC if len(aliases) == 0 { return nil } - return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, hsdClient) + return service.discoverFromChainAliasesUsingTreeRoot(ctx, aliases, resolvedHSDClient) } func (service *Service) discoverFromChainAliasesUsingTreeRoot(ctx context.Context, aliases []string, client *HSDClient) error { diff --git a/service_test.go b/service_test.go index d895ef1..a00b54b 100644 --- a/service_test.go +++ b/service_test.go @@ -458,6 +458,69 @@ func TestServiceDiscoverFromChainAliasesUsesFallbackWhenPrimaryFails(t *testing. } } +func TestServiceDiscoverFromChainAliasesUsesConfiguredHSDClient(t *testing.T) { + var treeRootCalls int32 + var nameResourceCalls int32 + + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + Params []any `json:"params"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + + switch payload.Method { + case "getblockchaininfo": + atomic.AddInt32(&treeRootCalls, 1) + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "tree_root": "root-1", + }, + }) + case "getnameresource": + atomic.AddInt32(&nameResourceCalls, 1) + switch payload.Params[0] { + case "gateway.charon.lthn": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + }, + }) + default: + t.Fatalf("unexpected alias lookup: %#v", payload.Params) + } + default: + t.Fatalf("unexpected method: %s", payload.Method) + } + })) + defer server.Close() + + service := NewService(ServiceOptions{ + ChainAliasDiscoverer: func(_ context.Context) ([]string, error) { + return []string{"gateway.charon.lthn"}, nil + }, + HSDClient: NewHSDClient(HSDClientOptions{URL: server.URL}), + }) + + if err := service.DiscoverFromChainAliases(context.Background(), nil); err != nil { + t.Fatalf("expected chain alias discovery to complete: %v", err) + } + + if atomic.LoadInt32(&treeRootCalls) != 1 { + t.Fatalf("expected one tree-root call, got %d", atomic.LoadInt32(&treeRootCalls)) + } + if atomic.LoadInt32(&nameResourceCalls) != 1 { + t.Fatalf("expected one name-resource call, got %d", atomic.LoadInt32(&nameResourceCalls)) + } + + gateway, ok := service.Resolve("gateway.charon.lthn") + if !ok || len(gateway.A) != 1 || gateway.A[0] != "10.10.10.10" { + t.Fatalf("expected gateway A record, got %#v (ok=%t)", gateway, ok) + } +} + func TestServiceDiscoverFromChainAliasesFallsBackToMainchainClientWhenDiscovererFails(t *testing.T) { var chainAliasCalls int32 var treeRootCalls int32