diff --git a/action.go b/action.go new file mode 100644 index 0000000..302ed61 --- /dev/null +++ b/action.go @@ -0,0 +1,143 @@ +package dns + +import ( + "context" + "errors" + "fmt" +) + +const ( + ActionResolve = "dns.resolve" + ActionResolveTXT = "dns.resolve.txt" + ActionResolveAll = "dns.resolve.all" + ActionReverse = "dns.reverse" + ActionServe = "dns.serve" + ActionHealth = "dns.health" + ActionDiscover = "dns.discover" +) + +var ( + errActionNotFound = errors.New("dns action not found") + errActionMissingValue = errors.New("dns action missing required value") +) + +// HandleAction executes a DNS action by name. +// +// service.HandleAction("dns.resolve", map[string]any{"name": "gateway.charon.lthn"}) +func (service *Service) HandleAction(name string, values map[string]any) (any, bool, error) { + switch name { + case ActionResolve: + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveAddress(host) + if !ok { + return nil, false, nil + } + return result, true, nil + case ActionResolveTXT: + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveTXTRecords(host) + if !ok { + return nil, false, nil + } + return result, true, nil + case ActionResolveAll: + host, err := stringActionValue(values, "name") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveAll(host) + if !ok { + return nil, false, nil + } + return result, true, nil + case ActionReverse: + ip, err := stringActionValue(values, "ip") + if err != nil { + return nil, false, err + } + result, ok := service.ResolveReverseNames(ip) + if !ok { + return nil, false, nil + } + return result, true, nil + case ActionServe: + bind, _ := stringActionValueOptional(values, "bind") + port, err := intActionValue(values, "port") + if err != nil { + return nil, false, err + } + result, err := service.Serve(bind, port) + if err != nil { + return nil, false, err + } + return result, true, nil + case ActionHealth: + return service.Health(), true, nil + case ActionDiscover: + if err := service.DiscoverAliases(context.Background()); err != nil { + return nil, false, err + } + return service.Health(), true, nil + default: + return nil, false, errActionNotFound + } +} + +func stringActionValue(values map[string]any, key string) (string, error) { + if values == nil { + return "", errActionMissingValue + } + raw, exists := values[key] + if !exists { + return "", errActionMissingValue + } + if value, ok := raw.(string); ok { + return value, nil + } + return "", errActionMissingValue +} + +func stringActionValueOptional(values map[string]any, key string) (string, error) { + if values == nil { + return "", nil + } + raw, exists := values[key] + if !exists { + return "", nil + } + value, ok := raw.(string) + if !ok { + return "", fmt.Errorf("%w: %s", errActionMissingValue, key) + } + return value, nil +} + +func intActionValue(values map[string]any, key string) (int, error) { + if values == nil { + return 0, errActionMissingValue + } + raw, exists := values[key] + if !exists { + return 0, errActionMissingValue + } + switch value := raw.(type) { + case int: + return value, nil + case int32: + return int(value), nil + case int64: + return int(value), nil + case float64: + return int(value), nil + case float32: + return int(value), nil + default: + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } +} diff --git a/service_test.go b/service_test.go index d25b8e0..bb4dc65 100644 --- a/service_test.go +++ b/service_test.go @@ -928,3 +928,179 @@ func TestServiceServeReturnsNXDOMAINWhenMissing(t *testing.T) { t.Fatalf("expected NXDOMAIN, got %d", response.Rcode) } } + +func TestServiceHandleActionResolveAndTXTAndAll(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + AAAA: []string{"2600:1f1c:7f0:4f01::1"}, + TXT: []string{"v=lthn1 type=gateway"}, + NS: []string{"ns.charon.lthn"}, + }, + }, + }) + + addresses, ok, err := service.HandleAction(ActionResolve, map[string]any{ + "name": "gateway.charon.lthn", + }) + if err != nil { + t.Fatalf("unexpected resolve action error: %v", err) + } + if !ok { + t.Fatal("expected resolve action to return a record") + } + payload, ok := addresses.(ResolveAddressResult) + if !ok { + t.Fatalf("expected ResolveAddressResult payload, got %T", addresses) + } + if len(payload.Addresses) != 2 || payload.Addresses[0] != "10.10.10.10" || payload.Addresses[1] != "2600:1f1c:7f0:4f01::1" { + t.Fatalf("unexpected resolve result: %#v", payload.Addresses) + } + + txtPayload, ok, err := service.HandleAction(ActionResolveTXT, map[string]any{ + "name": "gateway.charon.lthn", + }) + if err != nil { + t.Fatalf("unexpected txt action error: %v", err) + } + if !ok { + t.Fatal("expected txt action to return a record") + } + txts, ok := txtPayload.(ResolveTXTResult) + if !ok { + t.Fatalf("expected ResolveTXTResult payload, got %T", txtPayload) + } + if len(txts.TXT) != 1 || txts.TXT[0] != "v=lthn1 type=gateway" { + t.Fatalf("unexpected txt result: %#v", txts.TXT) + } + + allPayload, ok, err := service.HandleAction(ActionResolveAll, map[string]any{ + "name": "gateway.charon.lthn", + }) + if err != nil { + t.Fatalf("unexpected resolve.all action error: %v", err) + } + if !ok { + t.Fatal("expected resolve.all action to return a record") + } + all, ok := allPayload.(ResolveAllResult) + if !ok { + t.Fatalf("expected ResolveAllResult payload, got %T", allPayload) + } + if len(all.NS) != 1 || all.NS[0] != "ns.charon.lthn" { + t.Fatalf("unexpected resolve.all result: %#v", all) + } +} + +func TestServiceHandleActionReverseHealthServeAndDiscover(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + Params []any `json:"params"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + switch payload.Method { + case "getblockchaininfo": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "tree_root": "discover-root", + }, + }) + case "getnameresource": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + }, + }) + default: + t.Fatalf("unexpected method: %s", payload.Method) + } + })) + defer server.Close() + + service := NewService(ServiceOptions{ + ChainAliasDiscoverer: func(_ context.Context) ([]string, error) { + return []string{"gateway.charon.lthn"}, nil + }, + HSDClient: NewHSDClient(HSDClientOptions{ + URL: server.URL, + }), + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.20"}, + }, + }, + }) + + reversePayload, ok, err := service.HandleAction(ActionReverse, map[string]any{ + "ip": "10.10.10.20", + }) + if err != nil { + t.Fatalf("unexpected reverse action error: %v", err) + } + if !ok { + t.Fatal("expected reverse action to return a record") + } + reverse, ok := reversePayload.(ReverseLookupResult) + if !ok { + t.Fatalf("expected ReverseLookupResult payload, got %T", reversePayload) + } + if len(reverse.Names) != 1 || reverse.Names[0] != "gateway.charon.lthn" { + t.Fatalf("unexpected reverse result: %#v", reverse.Names) + } + + healthPayload, ok, err := service.HandleAction(ActionHealth, nil) + if err != nil { + t.Fatalf("unexpected health action error: %v", err) + } + if !ok { + t.Fatal("expected health action payload") + } + health, ok := healthPayload.(map[string]any) + if !ok { + t.Fatalf("expected health map payload, got %T", healthPayload) + } + if health["status"] != "ready" { + t.Fatalf("unexpected health payload: %#v", health) + } + + srvPayload, ok, err := service.HandleAction(ActionServe, map[string]any{ + "bind": "127.0.0.1", + "port": 0, + }) + if err != nil { + t.Fatalf("unexpected serve action error: %v", err) + } + if !ok { + t.Fatal("expected serve action to start server") + } + dnsServer, ok := srvPayload.(*DNSServer) + if !ok { + t.Fatalf("expected DNSServer payload, got %T", srvPayload) + } + if dnsServer.Address() == "" { + t.Fatal("expected server address from serve action") + } + _ = dnsServer.Close() + + discoverPayload, ok, err := service.HandleAction(ActionDiscover, nil) + if err != nil { + t.Fatalf("unexpected discover action error: %v", err) + } + if discoverPayload == nil || !ok { + t.Fatal("expected discover action payload") + } + if !ok { + t.Fatal("expected discover action to succeed") + } + discoverHealth, ok := discoverPayload.(map[string]any) + if !ok { + t.Fatalf("expected discover action payload map, got %T", discoverPayload) + } + if discoverHealth["tree_root"] != "discover-root" { + t.Fatalf("expected discover to refresh tree root, got %#v", discoverHealth["tree_root"]) + } +}