From 0f3b6e2f81ba0da7a43eff5b5bd7e207262759eb Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 20:45:14 +0000 Subject: [PATCH] fix(dns): restrict soa answers to zone apex Co-Authored-By: Virgil --- serve.go | 22 ++++++++++++---------- service.go | 41 +++++++++++++++++++++++++++++++++++++++++ service_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 10 deletions(-) diff --git a/serve.go b/serve.go index 6a8b6f5..6002568 100644 --- a/serve.go +++ b/serve.go @@ -188,16 +188,18 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr if !found { goto noRecord } - reply.Answer = append(reply.Answer, &dnsprotocol.SOA{ - Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeSOA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, - Ns: normalizeName("ns."+name) + ".", - Mbox: "hostmaster." + name + ".", - Serial: uint32(time.Now().Unix()), - Refresh: 86400, - Retry: 3600, - Expire: 3600, - Minttl: 300, - }) + if normalizeName(name) == handler.service.ZoneApex() { + reply.Answer = append(reply.Answer, &dnsprotocol.SOA{ + Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeSOA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Ns: normalizeName("ns."+name) + ".", + Mbox: "hostmaster." + name + ".", + Serial: uint32(time.Now().Unix()), + Refresh: 86400, + Retry: 3600, + Expire: 3600, + Minttl: 300, + }) + } default: reply.SetRcode(request, dnsprotocol.RcodeNotImplemented) _ = responseWriter.WriteMsg(reply) diff --git a/service.go b/service.go index 3568463..ab10631 100644 --- a/service.go +++ b/service.go @@ -45,6 +45,7 @@ type Service struct { records map[string]NameRecords reverseIndex map[string][]string treeRoot string + zoneApex string hsdClient *HSDClient mainchainAliasClient *MainchainAliasClient discoverer func() (map[string]NameRecords, error) @@ -82,6 +83,7 @@ func NewService(options ServiceOptions) *Service { records: cached, reverseIndex: buildReverseIndex(cached), treeRoot: treeRoot, + zoneApex: computeZoneApex(cached), hsdClient: options.HSDClient, mainchainAliasClient: options.MainchainAliasClient, discoverer: options.Discoverer, @@ -353,6 +355,7 @@ func (service *Service) replaceRecords(discovered map[string]NameRecords) { service.records = cached service.reverseIndex = buildReverseIndex(service.records) service.treeRoot = computeTreeRoot(service.records) + service.zoneApex = computeZoneApex(service.records) } func (service *Service) SetRecord(name string, record NameRecords) { @@ -361,6 +364,7 @@ func (service *Service) SetRecord(name string, record NameRecords) { service.records[normalizeName(name)] = record service.reverseIndex = buildReverseIndex(service.records) service.treeRoot = computeTreeRoot(service.records) + service.zoneApex = computeZoneApex(service.records) } func (service *Service) RemoveRecord(name string) { @@ -369,6 +373,7 @@ func (service *Service) RemoveRecord(name string) { delete(service.records, normalizeName(name)) service.reverseIndex = buildReverseIndex(service.records) service.treeRoot = computeTreeRoot(service.records) + service.zoneApex = computeZoneApex(service.records) } func (service *Service) Resolve(name string) (ResolveAllResult, bool) { @@ -481,6 +486,12 @@ func (service *Service) Health() map[string]any { } } +func (service *Service) ZoneApex() string { + service.mu.RLock() + defer service.mu.RUnlock() + return service.zoneApex +} + func (service *Service) ResolveReverseNames(ip string) (ReverseLookupResult, bool) { names, ok := service.ResolveReverse(ip) if !ok { @@ -590,6 +601,36 @@ func computeTreeRoot(records map[string]NameRecords) string { return hex.EncodeToString(sum[:]) } +func computeZoneApex(records map[string]NameRecords) string { + names := make([]string, 0, len(records)) + for name := range records { + if strings.HasPrefix(name, "*.") { + continue + } + names = append(names, name) + } + if len(names) == 0 { + return "" + } + + commonLabels := strings.Split(names[0], ".") + for _, name := range names[1:] { + labels := strings.Split(name, ".") + commonSuffixLength := 0 + for commonSuffixLength < len(commonLabels) && commonSuffixLength < len(labels) { + if commonLabels[len(commonLabels)-1-commonSuffixLength] != labels[len(labels)-1-commonSuffixLength] { + break + } + commonSuffixLength++ + } + if commonSuffixLength == 0 { + return "" + } + commonLabels = commonLabels[len(commonLabels)-commonSuffixLength:] + } + return strings.Join(commonLabels, ".") +} + func serializeRecordValues(values []string) string { copied := append([]string(nil), values...) slices.Sort(copied) diff --git a/service_test.go b/service_test.go index 1ae0d18..83a7f19 100644 --- a/service_test.go +++ b/service_test.go @@ -957,6 +957,49 @@ func TestServiceServeResolvesWildcardAndPTRRecords(t *testing.T) { } } +func TestServiceServeAnswersSOAOnlyForZoneApex(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "charon.lthn": { + NS: []string{"ns1.charon.lthn"}, + }, + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + }, + }, + }) + + srv, err := service.Serve("127.0.0.1", 0) + if err != nil { + t.Fatalf("expected server to start: %v", err) + } + defer func() { _ = srv.Close() }() + + client := dnsprotocol.Client{} + apexRequest := new(dnsprotocol.Msg) + apexRequest.SetQuestion("charon.lthn.", dnsprotocol.TypeSOA) + apexResponse := exchangeWithRetry(t, client, apexRequest, srv.Address()) + if apexResponse.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("expected SOA query for apex to succeed, got %d", apexResponse.Rcode) + } + if len(apexResponse.Answer) != 1 { + t.Fatalf("expected one SOA answer for apex, got %d", len(apexResponse.Answer)) + } + if _, ok := apexResponse.Answer[0].(*dnsprotocol.SOA); !ok { + t.Fatalf("expected SOA answer for apex, got %#v", apexResponse.Answer[0]) + } + + subdomainRequest := new(dnsprotocol.Msg) + subdomainRequest.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeSOA) + subdomainResponse := exchangeWithRetry(t, client, subdomainRequest, srv.Address()) + if subdomainResponse.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("expected SOA query for non-apex existing name to succeed, got %d", subdomainResponse.Rcode) + } + if len(subdomainResponse.Answer) != 0 { + t.Fatalf("expected no SOA answer for non-apex name, got %#v", subdomainResponse.Answer) + } +} + func TestServiceServeReturnsNXDOMAINWhenMissing(t *testing.T) { service := NewService(ServiceOptions{}) -- 2.45.3