diff --git a/action.go b/action.go index 302ed61..f7daee9 100644 --- a/action.go +++ b/action.go @@ -21,72 +21,127 @@ var ( errActionMissingValue = errors.New("dns action missing required value") ) +type ActionDefinition struct { + Name string + Invoke func(map[string]any) (any, bool, error) +} + +// ActionDefinitions returns the complete DNS action surface in registration order. +// +// service.ActionDefinitions() +func (service *Service) ActionDefinitions() []ActionDefinition { + return []ActionDefinition{ + { + Name: ActionResolve, + Invoke: func(values map[string]any) (any, bool, error) { + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveAddress(host) + if !ok { + return nil, false, nil + } + return result, true, nil + }, + }, + { + Name: ActionResolveTXT, + Invoke: func(values map[string]any) (any, bool, error) { + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveTXTRecords(host) + if !ok { + return nil, false, nil + } + return result, true, nil + }, + }, + { + Name: ActionResolveAll, + Invoke: func(values map[string]any) (any, bool, error) { + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveAll(host) + if !ok { + return nil, false, nil + } + return result, true, nil + }, + }, + { + Name: ActionReverse, + Invoke: func(values map[string]any) (any, bool, error) { + ip, err := stringActionValue(values, "ip") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveReverseNames(ip) + if !ok { + return nil, false, nil + } + return result, true, nil + }, + }, + { + Name: ActionServe, + Invoke: func(values map[string]any) (any, bool, error) { + bind, _ := stringActionValueOptional(values, "bind") + port, err := intActionValue(values, "port") + if err != nil { + return nil, false, err + } + result, err := service.Serve(bind, port) + if err != nil { + return nil, false, err + } + return result, true, nil + }, + }, + { + Name: ActionHealth, + Invoke: func(map[string]any) (any, bool, error) { + return service.Health(), true, nil + }, + }, + { + Name: ActionDiscover, + Invoke: func(map[string]any) (any, bool, error) { + if err := service.DiscoverAliases(context.Background()); err != nil { + return nil, false, err + } + return service.Health(), true, nil + }, + }, + } +} + +// ActionNames returns the names of the registered DNS actions. +// +// service.ActionNames() +func (service *Service) ActionNames() []string { + definitions := service.ActionDefinitions() + names := make([]string, 0, len(definitions)) + for _, definition := range definitions { + names = append(names, definition.Name) + } + return names +} + // HandleAction executes a DNS action by name. // // service.HandleAction("dns.resolve", map[string]any{"name": "gateway.charon.lthn"}) func (service *Service) HandleAction(name string, values map[string]any) (any, bool, error) { - switch name { - case ActionResolve: - host, err := stringActionValue(values, "name") - if err != nil { - return nil, false, err + for _, definition := range service.ActionDefinitions() { + if definition.Name == name { + return definition.Invoke(values) } - result, ok := service.ResolveAddress(host) - if !ok { - return nil, false, nil - } - return result, true, nil - case ActionResolveTXT: - host, err := stringActionValue(values, "name") - if err != nil { - return nil, false, err - } - result, ok := service.ResolveTXTRecords(host) - if !ok { - return nil, false, nil - } - return result, true, nil - case ActionResolveAll: - host, err := stringActionValue(values, "name") - if err != nil { - return nil, false, err - } - result, ok := service.ResolveAll(host) - if !ok { - return nil, false, nil - } - return result, true, nil - case ActionReverse: - ip, err := stringActionValue(values, "ip") - if err != nil { - return nil, false, err - } - result, ok := service.ResolveReverseNames(ip) - if !ok { - return nil, false, nil - } - return result, true, nil - case ActionServe: - bind, _ := stringActionValueOptional(values, "bind") - port, err := intActionValue(values, "port") - if err != nil { - return nil, false, err - } - result, err := service.Serve(bind, port) - if err != nil { - return nil, false, err - } - return result, true, nil - case ActionHealth: - return service.Health(), true, nil - case ActionDiscover: - if err := service.DiscoverAliases(context.Background()); err != nil { - return nil, false, err - } - return service.Health(), true, nil - default: - return nil, false, errActionNotFound } + return nil, false, errActionNotFound } func stringActionValue(values map[string]any, key string) (string, error) { diff --git a/hsd.go b/hsd.go index 7d8ac27..b741ef8 100644 --- a/hsd.go +++ b/hsd.go @@ -1,6 +1,7 @@ package dns import ( + "bytes" "context" "encoding/base64" "encoding/json" @@ -133,7 +134,7 @@ func (client *HSDClient) call(ctx context.Context, request HSDRPCRequest) (json. return nil, err } - httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, client.baseURL, io.NopCloser(io.Reader(strings.NewReader(string(body))))) + httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, client.baseURL, bytes.NewReader(body)) if err != nil { return nil, err } diff --git a/mainchain.go b/mainchain.go index cb64b9d..11bf46d 100644 --- a/mainchain.go +++ b/mainchain.go @@ -1,6 +1,7 @@ package dns import ( + "bytes" "context" "encoding/json" "errors" @@ -89,7 +90,7 @@ func (client *MainchainAliasClient) call(ctx context.Context, request MainchainR return nil, err } - httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, client.baseURL, io.NopCloser(io.Reader(strings.NewReader(string(body))))) + httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, client.baseURL, bytes.NewReader(body)) if err != nil { return nil, err } diff --git a/service_test.go b/service_test.go index d015c40..0593887 100644 --- a/service_test.go +++ b/service_test.go @@ -1039,6 +1039,82 @@ func TestServiceHandleActionResolveAndTXTAndAll(t *testing.T) { } } +func TestServiceActionNamesExposeAllRFCActions(t *testing.T) { + service := NewService(ServiceOptions{}) + + names := service.ActionNames() + expected := []string{ + ActionResolve, + ActionResolveTXT, + ActionResolveAll, + ActionReverse, + ActionServe, + ActionHealth, + ActionDiscover, + } + + if len(names) != len(expected) { + t.Fatalf("expected %d action names, got %d: %#v", len(expected), len(names), names) + } + for i, name := range expected { + if names[i] != name { + t.Fatalf("unexpected action name at %d: got %q want %q", i, names[i], name) + } + } +} + +func TestServiceActionDefinitionsHaveInvokers(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + }, + }, + }) + + definitions := service.ActionDefinitions() + if len(definitions) == 0 { + t.Fatal("expected action definitions") + } + + for _, definition := range definitions { + if definition.Name == "" { + t.Fatal("expected action definition name") + } + if definition.Invoke == nil { + t.Fatalf("expected action invoke for %s", definition.Name) + } + } + + resolveDefinition := definitions[0] + if resolveDefinition.Name != ActionResolve { + t.Fatalf("expected first action definition to be %s, got %s", ActionResolve, resolveDefinition.Name) + } + payload, ok, err := resolveDefinition.Invoke(map[string]any{ + "name": "gateway.charon.lthn", + }) + if err != nil { + t.Fatalf("unexpected action invoke error: %v", err) + } + if !ok { + t.Fatal("expected resolve action definition to return a record") + } + result, ok := payload.(ResolveAddressResult) + if !ok || len(result.Addresses) != 1 || result.Addresses[0] != "10.10.10.10" { + t.Fatalf("unexpected resolve payload: %#v", payload) + } + + handlePayload, handleOK, handleErr := service.HandleAction(ActionResolve, map[string]any{ + "name": "gateway.charon.lthn", + }) + if handleErr != nil || !handleOK { + t.Fatalf("unexpected handle action result: ok=%t err=%v", handleOK, handleErr) + } + if handleResult, ok := handlePayload.(ResolveAddressResult); !ok || len(handleResult.Addresses) != 1 || handleResult.Addresses[0] != "10.10.10.10" { + t.Fatalf("unexpected handle action payload: %#v", handlePayload) + } +} + func TestServiceHandleActionReverseHealthServeAndDiscover(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { var payload struct {