From 7dc86bb44cde6b24488e9e7a3eee402f63d3afcc Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 22:54:47 +0000 Subject: [PATCH] Preserve context in action registration --- action.go | 26 +++++++++++++++++++ service_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/action.go b/action.go index 1a281af..6eebd74 100644 --- a/action.go +++ b/action.go @@ -36,6 +36,18 @@ type ActionRegistrar interface { RegisterAction(name string, invoke func(map[string]any) (any, bool, error)) } +// ActionContextRegistrar publishes DNS actions while preserving caller context. +// +// registrar.RegisterActionContext( +// ActionDiscover, +// func(ctx context.Context, values map[string]any) (any, bool, error) { +// return service.HandleActionContext(ctx, ActionDiscover, values) +// }, +// ) +type ActionContextRegistrar interface { + RegisterActionContext(name string, invoke func(context.Context, map[string]any) (any, bool, error)) +} + // ActionCaller resolves named actions from another Core surface. // // aliases, ok, err := caller.CallAction( @@ -148,6 +160,20 @@ func (service *Service) RegisterActions(registrar ActionRegistrar) { if registrar == nil { return } + + if contextRegistrar, ok := registrar.(ActionContextRegistrar); ok { + for _, definition := range service.ActionDefinitions() { + invoke := definition.InvokeContext + if invoke == nil { + invoke = func(ctx context.Context, values map[string]any) (any, bool, error) { + return definition.Invoke(values) + } + } + contextRegistrar.RegisterActionContext(definition.Name, invoke) + } + return + } + for _, definition := range service.ActionDefinitions() { registrar.RegisterAction(definition.Name, definition.Invoke) } diff --git a/service_test.go b/service_test.go index a22cebf..71aa1e7 100644 --- a/service_test.go +++ b/service_test.go @@ -2145,6 +2145,49 @@ func TestServiceRegisterActionsPublishesAllActionsInOrder(t *testing.T) { } } +func TestServiceRegisterActionsUsesContextAwareRegistrarWhenAvailable(t *testing.T) { + type ctxKey string + + registrar := &actionContextRecorder{} + service := NewService(ServiceOptions{ + ChainAliasDiscoverer: func(ctx context.Context) ([]string, error) { + value, ok := ctx.Value(ctxKey("discover-token")).(string) + if !ok { + t.Fatal("expected discover context to be preserved") + } + if value != "preserved" { + t.Fatalf("unexpected discover context value: %q", value) + } + return []string{"gateway.charon.lthn"}, nil + }, + HSDClient: NewHSDClient(HSDClientOptions{ + URL: "http://127.0.0.1:1", + }), + }) + + service.RegisterActions(registrar) + + invoke, ok := registrar.contextHandlers[ActionDiscover] + if !ok { + t.Fatal("expected context-aware registrar to receive discover action") + } + + ctx := context.WithValue(context.Background(), ctxKey("discover-token"), "preserved") + payload, succeeded, err := invoke(ctx, nil) + if err == nil { + t.Fatal("expected discover action to fail without an HSD endpoint") + } + if succeeded { + t.Fatal("expected discover action to report failure") + } + if payload != nil { + t.Fatalf("expected no payload on failure, got %#v", payload) + } + if !strings.Contains(err.Error(), "connection refused") && !strings.Contains(err.Error(), "hsd rpc request failed") { + t.Fatalf("expected discover action to propagate the HSD client error, got %v", err) + } +} + func TestNewServiceWithRegistrarBuildsAndRegistersInOneStep(t *testing.T) { registrar := &actionRecorder{} service := NewServiceWithRegistrar(ServiceOptions{ @@ -2411,6 +2454,29 @@ func (recorder *actionRecorder) RegisterAction(name string, invoke func(map[stri recorder.handlers[name] = invoke } +type actionContextRecorder struct { + names []string + contextHandlers map[string]func(context.Context, map[string]any) (any, bool, error) +} + +func (recorder *actionContextRecorder) RegisterAction(name string, invoke func(map[string]any) (any, bool, error)) { + if recorder.contextHandlers == nil { + recorder.contextHandlers = map[string]func(context.Context, map[string]any) (any, bool, error){} + } + recorder.names = append(recorder.names, name) + recorder.contextHandlers[name] = func(ctx context.Context, values map[string]any) (any, bool, error) { + return invoke(values) + } +} + +func (recorder *actionContextRecorder) RegisterActionContext(name string, invoke func(context.Context, map[string]any) (any, bool, error)) { + if recorder.contextHandlers == nil { + recorder.contextHandlers = map[string]func(context.Context, map[string]any) (any, bool, error){} + } + recorder.names = append(recorder.names, name) + recorder.contextHandlers[name] = invoke +} + type actionCallerFunc func(context.Context, string, map[string]any) (any, bool, error) func (caller actionCallerFunc) CallAction(ctx context.Context, name string, values map[string]any) (any, bool, error) { -- 2.45.3