From 2c5ca56bfb2e68569aac84a1f124d6a39fbab5e0 Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 21:26:13 +0000 Subject: [PATCH] feat(dns): support ANY DNS queries Co-Authored-By: Virgil --- serve.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++ service_test.go | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/serve.go b/serve.go index ad6da16..0be39f6 100644 --- a/serve.go +++ b/serve.go @@ -270,6 +270,14 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr if !found { goto noRecord } + case dnsprotocol.TypeANY: + if found { + appendAnyAnswers(reply, question.Name, name, record) + } else if normalizeName(name) == handler.service.ZoneApex() && handler.service.ZoneApex() != "" { + appendAnyAnswers(reply, question.Name, name, NameRecords{}) + } else { + goto noRecord + } default: reply.SetRcode(request, dnsprotocol.RcodeNotImplemented) _ = responseWriter.WriteMsg(reply) @@ -340,3 +348,62 @@ func parsePTRIP(name string) (string, bool) { } return "", false } + +func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, name string, record NameRecords) { + for _, value := range record.A { + parsedIP := net.ParseIP(value) + if parsedIP == nil || parsedIP.To4() == nil { + continue + } + reply.Answer = append(reply.Answer, &dnsprotocol.A{ + Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + A: parsedIP.To4(), + }) + } + + for _, value := range record.AAAA { + parsedIP := net.ParseIP(value) + if parsedIP == nil || parsedIP.To4() != nil { + continue + } + reply.Answer = append(reply.Answer, &dnsprotocol.AAAA{ + Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeAAAA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + AAAA: parsedIP.To16(), + }) + } + + for _, value := range record.TXT { + reply.Answer = append(reply.Answer, &dnsprotocol.TXT{ + Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeTXT, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Txt: []string{value}, + }) + } + + if len(record.NS) > 0 { + for _, value := range record.NS { + reply.Answer = append(reply.Answer, &dnsprotocol.NS{ + Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Ns: normalizeName(value) + ".", + }) + } + } + + if normalizeName(name) == normalizeName(questionName) && normalizeName(name) != "" { + if len(record.NS) == 0 { + reply.Answer = append(reply.Answer, &dnsprotocol.NS{ + Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Ns: normalizeName("ns."+name) + ".", + }) + } + reply.Answer = append(reply.Answer, &dnsprotocol.SOA{ + Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeSOA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Ns: normalizeName("ns."+name) + ".", + Mbox: "hostmaster." + name + ".", + Serial: uint32(time.Now().Unix()), + Refresh: 86400, + Retry: 3600, + Expire: 3600, + Minttl: 300, + }) + } +} diff --git a/service_test.go b/service_test.go index 173c1c9..7f719b6 100644 --- a/service_test.go +++ b/service_test.go @@ -1143,6 +1143,51 @@ func TestServiceServeResolvesAAndAAAARecords(t *testing.T) { } } +func TestServiceServeAnswersANYWithAllRecordTypes(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + AAAA: []string{"2600:1f1c:7f0:4f01::1"}, + TXT: []string{"v=lthn1 type=gateway"}, + NS: []string{"ns.gateway.charon.lthn"}, + }, + }, + }) + + srv, err := service.Serve("127.0.0.1", 0) + if err != nil { + t.Fatalf("expected server to start: %v", err) + } + defer func() { _ = srv.Close() }() + + client := dnsprotocol.Client{} + request := new(dnsprotocol.Msg) + request.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeANY) + response := exchangeWithRetry(t, client, request, srv.Address()) + if response.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("unexpected ANY rcode: %d", response.Rcode) + } + + var sawA, sawAAAA, sawTXT, sawNS bool + for _, answer := range response.Answer { + switch rr := answer.(type) { + case *dnsprotocol.A: + sawA = rr.A.String() == "10.10.10.10" + case *dnsprotocol.AAAA: + sawAAAA = rr.AAAA.String() == "2600:1f1c:7f0:4f01::1" + case *dnsprotocol.TXT: + sawTXT = len(rr.Txt) == 1 && rr.Txt[0] == "v=lthn1 type=gateway" + case *dnsprotocol.NS: + sawNS = rr.Ns == "ns.gateway.charon.lthn." + } + } + + if !sawA || !sawAAAA || !sawTXT || !sawNS { + t.Fatalf("expected ANY answer to include A, AAAA, TXT, and NS records, got %#v", response.Answer) + } +} + func TestServiceServeResolvesWildcardAndPTRRecords(t *testing.T) { service := NewService(ServiceOptions{ Records: map[string]NameRecords{