diff --git a/action.go b/action.go index 42301df..b43c8d7 100644 --- a/action.go +++ b/action.go @@ -22,8 +22,9 @@ var ( ) type ActionDefinition struct { - Name string - Invoke func(map[string]any) (any, bool, error) + Name string + Invoke func(map[string]any) (any, bool, error) + InvokeContext func(context.Context, map[string]any) (any, bool, error) } // ActionRegistrar publishes DNS actions into another Core surface. @@ -53,72 +54,46 @@ func (service *Service) ActionDefinitions() []ActionDefinition { { Name: ActionResolve, Invoke: func(values map[string]any) (any, bool, error) { - host, err := stringActionValue(values, "name") - if err != nil { - return nil, false, err - } - result, ok := service.ResolveAddress(host) - if !ok { - return nil, false, nil - } - return result, true, nil + return service.handleResolveAddress(context.Background(), values) + }, + InvokeContext: func(ctx context.Context, values map[string]any) (any, bool, error) { + return service.handleResolveAddress(ctx, values) }, }, { Name: ActionResolveTXT, Invoke: func(values map[string]any) (any, bool, error) { - host, err := stringActionValue(values, "name") - if err != nil { - return nil, false, err - } - result, ok := service.ResolveTXTRecords(host) - if !ok { - return nil, false, nil - } - return result, true, nil + return service.handleResolveTXTRecords(context.Background(), values) + }, + InvokeContext: func(ctx context.Context, values map[string]any) (any, bool, error) { + return service.handleResolveTXTRecords(ctx, values) }, }, { Name: ActionResolveAll, Invoke: func(values map[string]any) (any, bool, error) { - host, err := stringActionValue(values, "name") - if err != nil { - return nil, false, err - } - result, ok := service.ResolveAll(host) - if !ok { - return nil, false, nil - } - return result, true, nil + return service.handleResolveAll(context.Background(), values) + }, + InvokeContext: func(ctx context.Context, values map[string]any) (any, bool, error) { + return service.handleResolveAll(ctx, values) }, }, { Name: ActionReverse, Invoke: func(values map[string]any) (any, bool, error) { - ip, err := stringActionValue(values, "ip") - if err != nil { - return nil, false, err - } - result, ok := service.ResolveReverseNames(ip) - if !ok { - return nil, false, nil - } - return result, true, nil + return service.handleReverseLookup(context.Background(), values) + }, + InvokeContext: func(ctx context.Context, values map[string]any) (any, bool, error) { + return service.handleReverseLookup(ctx, values) }, }, { Name: ActionServe, Invoke: func(values map[string]any) (any, bool, error) { - bind, _ := stringActionValueOptional(values, "bind") - port, err := intActionValue(values, "port") - if err != nil { - return nil, false, err - } - result, err := service.Serve(bind, port) - if err != nil { - return nil, false, err - } - return result, true, nil + return service.handleServe(context.Background(), values) + }, + InvokeContext: func(ctx context.Context, values map[string]any) (any, bool, error) { + return service.handleServe(ctx, values) }, }, { @@ -126,6 +101,9 @@ func (service *Service) ActionDefinitions() []ActionDefinition { Invoke: func(map[string]any) (any, bool, error) { return service.Health(), true, nil }, + InvokeContext: func(_ context.Context, _ map[string]any) (any, bool, error) { + return service.Health(), true, nil + }, }, { Name: ActionDiscover, @@ -135,6 +113,12 @@ func (service *Service) ActionDefinitions() []ActionDefinition { } return service.Health(), true, nil }, + InvokeContext: func(ctx context.Context, _ map[string]any) (any, bool, error) { + if err := service.DiscoverAliases(ctx); err != nil { + return nil, false, err + } + return service.Health(), true, nil + }, }, } } @@ -182,14 +166,95 @@ func NewServiceWithRegistrar(options ServiceOptions, registrar ActionRegistrar) // "name": "gateway.charon.lthn", // }) func (service *Service) HandleAction(name string, values map[string]any) (any, bool, error) { + return service.HandleActionContext(context.Background(), name, values) +} + +// HandleActionContext executes a DNS action with the supplied context. +// +// payload, ok, err := service.HandleActionContext(ctx, ActionResolve, map[string]any{ +// "name": "gateway.charon.lthn", +// }) +func (service *Service) HandleActionContext(ctx context.Context, name string, values map[string]any) (any, bool, error) { + if ctx == nil { + ctx = context.Background() + } for _, definition := range service.ActionDefinitions() { if definition.Name == name { + if definition.InvokeContext != nil { + return definition.InvokeContext(ctx, values) + } return definition.Invoke(values) } } return nil, false, errActionNotFound } +func (service *Service) handleResolveAddress(ctx context.Context, values map[string]any) (any, bool, error) { + _ = ctx + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveAddress(host) + if !ok { + return nil, false, nil + } + return result, true, nil +} + +func (service *Service) handleResolveTXTRecords(ctx context.Context, values map[string]any) (any, bool, error) { + _ = ctx + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveTXTRecords(host) + if !ok { + return nil, false, nil + } + return result, true, nil +} + +func (service *Service) handleResolveAll(ctx context.Context, values map[string]any) (any, bool, error) { + _ = ctx + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveAll(host) + if !ok { + return nil, false, nil + } + return result, true, nil +} + +func (service *Service) handleReverseLookup(ctx context.Context, values map[string]any) (any, bool, error) { + _ = ctx + ip, err := stringActionValue(values, "ip") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveReverseNames(ip) + if !ok { + return nil, false, nil + } + return result, true, nil +} + +func (service *Service) handleServe(ctx context.Context, values map[string]any) (any, bool, error) { + _ = ctx + bind, _ := stringActionValueOptional(values, "bind") + port, err := intActionValue(values, "port") + if err != nil { + return nil, false, err + } + result, err := service.Serve(bind, port) + if err != nil { + return nil, false, err + } + return result, true, nil +} + func stringActionValue(values map[string]any, key string) (string, error) { if values == nil { return "", errActionMissingValue diff --git a/service_test.go b/service_test.go index 47b9d0f..a22cebf 100644 --- a/service_test.go +++ b/service_test.go @@ -2369,6 +2369,35 @@ func TestServiceHandleActionReverseHealthServeAndDiscover(t *testing.T) { } } +func TestServiceHandleActionContextPassesThroughToDiscover(t *testing.T) { + service := NewService(ServiceOptions{ + ChainAliasDiscoverer: func(ctx context.Context) ([]string, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + HSDClient: NewHSDClient(HSDClientOptions{ + URL: "http://127.0.0.1:1", + }), + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + payload, ok, err := service.HandleActionContext(ctx, ActionDiscover, nil) + if err == nil { + t.Fatal("expected discover action to fail for a canceled context") + } + if ok { + t.Fatal("expected discover action to report failure") + } + if payload != nil { + t.Fatalf("expected no payload on context cancellation, got %#v", payload) + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation error, got %v", err) + } +} + type actionRecorder struct { names []string handlers map[string]func(map[string]any) (any, bool, error)