Preserve context in action registration

This commit is contained in:
Virgil 2026-04-03 22:54:47 +00:00
parent 4e5bb7e398
commit 7dc86bb44c
2 changed files with 92 additions and 0 deletions

View file

@ -36,6 +36,18 @@ type ActionRegistrar interface {
RegisterAction(name string, invoke func(map[string]any) (any, bool, error))
}
// ActionContextRegistrar publishes DNS actions while preserving caller context.
//
// registrar.RegisterActionContext(
// ActionDiscover,
// func(ctx context.Context, values map[string]any) (any, bool, error) {
// return service.HandleActionContext(ctx, ActionDiscover, values)
// },
// )
type ActionContextRegistrar interface {
RegisterActionContext(name string, invoke func(context.Context, map[string]any) (any, bool, error))
}
// ActionCaller resolves named actions from another Core surface.
//
// aliases, ok, err := caller.CallAction(
@ -148,6 +160,20 @@ func (service *Service) RegisterActions(registrar ActionRegistrar) {
if registrar == nil {
return
}
if contextRegistrar, ok := registrar.(ActionContextRegistrar); ok {
for _, definition := range service.ActionDefinitions() {
invoke := definition.InvokeContext
if invoke == nil {
invoke = func(ctx context.Context, values map[string]any) (any, bool, error) {
return definition.Invoke(values)
}
}
contextRegistrar.RegisterActionContext(definition.Name, invoke)
}
return
}
for _, definition := range service.ActionDefinitions() {
registrar.RegisterAction(definition.Name, definition.Invoke)
}

View file

@ -2145,6 +2145,49 @@ func TestServiceRegisterActionsPublishesAllActionsInOrder(t *testing.T) {
}
}
func TestServiceRegisterActionsUsesContextAwareRegistrarWhenAvailable(t *testing.T) {
type ctxKey string
registrar := &actionContextRecorder{}
service := NewService(ServiceOptions{
ChainAliasDiscoverer: func(ctx context.Context) ([]string, error) {
value, ok := ctx.Value(ctxKey("discover-token")).(string)
if !ok {
t.Fatal("expected discover context to be preserved")
}
if value != "preserved" {
t.Fatalf("unexpected discover context value: %q", value)
}
return []string{"gateway.charon.lthn"}, nil
},
HSDClient: NewHSDClient(HSDClientOptions{
URL: "http://127.0.0.1:1",
}),
})
service.RegisterActions(registrar)
invoke, ok := registrar.contextHandlers[ActionDiscover]
if !ok {
t.Fatal("expected context-aware registrar to receive discover action")
}
ctx := context.WithValue(context.Background(), ctxKey("discover-token"), "preserved")
payload, succeeded, err := invoke(ctx, nil)
if err == nil {
t.Fatal("expected discover action to fail without an HSD endpoint")
}
if succeeded {
t.Fatal("expected discover action to report failure")
}
if payload != nil {
t.Fatalf("expected no payload on failure, got %#v", payload)
}
if !strings.Contains(err.Error(), "connection refused") && !strings.Contains(err.Error(), "hsd rpc request failed") {
t.Fatalf("expected discover action to propagate the HSD client error, got %v", err)
}
}
func TestNewServiceWithRegistrarBuildsAndRegistersInOneStep(t *testing.T) {
registrar := &actionRecorder{}
service := NewServiceWithRegistrar(ServiceOptions{
@ -2411,6 +2454,29 @@ func (recorder *actionRecorder) RegisterAction(name string, invoke func(map[stri
recorder.handlers[name] = invoke
}
type actionContextRecorder struct {
names []string
contextHandlers map[string]func(context.Context, map[string]any) (any, bool, error)
}
func (recorder *actionContextRecorder) RegisterAction(name string, invoke func(map[string]any) (any, bool, error)) {
if recorder.contextHandlers == nil {
recorder.contextHandlers = map[string]func(context.Context, map[string]any) (any, bool, error){}
}
recorder.names = append(recorder.names, name)
recorder.contextHandlers[name] = func(ctx context.Context, values map[string]any) (any, bool, error) {
return invoke(values)
}
}
func (recorder *actionContextRecorder) RegisterActionContext(name string, invoke func(context.Context, map[string]any) (any, bool, error)) {
if recorder.contextHandlers == nil {
recorder.contextHandlers = map[string]func(context.Context, map[string]any) (any, bool, error){}
}
recorder.names = append(recorder.names, name)
recorder.contextHandlers[name] = invoke
}
type actionCallerFunc func(context.Context, string, map[string]any) (any, bool, error)
func (caller actionCallerFunc) CallAction(ctx context.Context, name string, values map[string]any) (any, bool, error) {