diff --git a/service.go b/service.go index cff31cd..546869e 100644 --- a/service.go +++ b/service.go @@ -189,6 +189,13 @@ func NewService(options ServiceOptions) *Service { checkInterval = DefaultTreeRootCheckInterval } + chainAliasActionCaller := options.ChainAliasActionCaller + if chainAliasActionCaller == nil { + if actionCaller, ok := options.ActionRegistrar.(ActionCaller); ok { + chainAliasActionCaller = actionCaller + } + } + hsdClient := options.HSDClient if hsdClient == nil { hsdPassword := options.HSDPassword @@ -242,7 +249,7 @@ func NewService(options ServiceOptions) *Service { recordTTL: options.RecordTTL, hsdClient: hsdClient, mainchainAliasClient: mainchainClient, - chainAliasActionCaller: options.ChainAliasActionCaller, + chainAliasActionCaller: chainAliasActionCaller, chainAliasAction: options.ChainAliasAction, recordDiscoverer: options.RecordDiscoverer, fallbackRecordDiscoverer: options.FallbackRecordDiscoverer, diff --git a/service_test.go b/service_test.go index 0ea2ac4..95d12bc 100644 --- a/service_test.go +++ b/service_test.go @@ -899,6 +899,81 @@ func TestServiceDiscoverAliasesUsesConfiguredActionCaller(t *testing.T) { } } +func TestNewServiceInfersChainAliasActionCallerFromRegistrar(t *testing.T) { + var actionCalls int32 + var chainAliasCalls int32 + + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + Params []any `json:"params"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + + switch payload.Method { + case "getblockchaininfo": + atomic.AddInt32(&chainAliasCalls, 1) + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "tree_root": "action-registry-root", + }, + }) + case "getnameresource": + atomic.AddInt32(&chainAliasCalls, 1) + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + }, + }) + default: + t.Fatalf("unexpected method: %s", payload.Method) + } + })) + defer server.Close() + + registrar := &actionRegistrarAndCaller{ + actionRecorder: actionRecorder{}, + } + registrar.handlers = map[string]func(map[string]any) (any, bool, error){} + service := NewService(ServiceOptions{ + ActionRegistrar: registrar, + HSDClient: NewHSDClient(HSDClientOptions{ + URL: server.URL, + }), + RecordTTL: time.Minute, + }) + + registrar.actionCall = func(ctx context.Context, action string, values map[string]any) (any, bool, error) { + atomic.AddInt32(&actionCalls, 1) + if action != "blockchain.chain.aliases" { + t.Fatalf("unexpected action name: %s", action) + } + return map[string]any{ + "aliases": []any{"gateway.charon.lthn"}, + }, true, nil + } + + if err := service.DiscoverAliases(context.Background()); err != nil { + t.Fatalf("expected discover to use registrar action caller: %v", err) + } + + record, ok := service.Resolve("gateway.charon.lthn") + if !ok || len(record.A) != 1 || record.A[0] != "10.10.10.10" { + t.Fatalf("expected discovered gateway record, got %#v (ok=%t)", record, ok) + } + + if atomic.LoadInt32(&actionCalls) != 1 { + t.Fatalf("expected one action-caller invocation, got %d", atomic.LoadInt32(&actionCalls)) + } + if atomic.LoadInt32(&chainAliasCalls) != 2 { + t.Fatalf("expected two HSD calls (tree root + resource), got %d", atomic.LoadInt32(&chainAliasCalls)) + } +} + func TestServiceDiscoverAliasesTreatsNilActionResponseAsNoAliases(t *testing.T) { var actionCalled bool var fallbackCallCount int32 @@ -3302,3 +3377,15 @@ type actionCallerFunc func(context.Context, string, map[string]any) (any, bool, func (caller actionCallerFunc) CallAction(ctx context.Context, name string, values map[string]any) (any, bool, error) { return caller(ctx, name, values) } + +type actionRegistrarAndCaller struct { + actionRecorder + actionCall func(context.Context, string, map[string]any) (any, bool, error) +} + +func (registrar *actionRegistrarAndCaller) CallAction(ctx context.Context, name string, values map[string]any) (any, bool, error) { + if registrar.actionCall == nil { + return nil, false, nil + } + return registrar.actionCall(ctx, name, values) +}