diff --git a/serve.go b/serve.go index 9d8670d..138f973 100644 --- a/serve.go +++ b/serve.go @@ -217,15 +217,23 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr }) } case dnsprotocol.TypeNS: - if !found { - goto noRecord + if found { + for _, value := range record.NS { + reply.Answer = append(reply.Answer, &dnsprotocol.NS{ + Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Ns: normalizeName(value) + ".", + }) + } } - for _, value := range record.NS { + if len(reply.Answer) == 0 && normalizeName(name) == handler.service.ZoneApex() && handler.service.ZoneApex() != "" { reply.Answer = append(reply.Answer, &dnsprotocol.NS{ Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, - Ns: normalizeName(value) + ".", + Ns: normalizeName("ns."+name) + ".", }) } + if len(reply.Answer) == 0 && !found { + goto noRecord + } case dnsprotocol.TypePTR: ip, ok := parsePTRIP(name) if !ok { diff --git a/service_test.go b/service_test.go index b5d88d7..9bb1a6a 100644 --- a/service_test.go +++ b/service_test.go @@ -1099,6 +1099,43 @@ func TestServiceServeAnswersSOAForDerivedZoneApexWithoutExactRecord(t *testing.T } } +func TestServiceServeAnswersNSForDerivedZoneApexWithoutExactRecord(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + }, + "node.charon.lthn": { + A: []string{"10.10.10.11"}, + }, + }, + }) + + srv, err := service.Serve("127.0.0.1", 0) + if err != nil { + t.Fatalf("expected server to start: %v", err) + } + defer func() { _ = srv.Close() }() + + client := dnsprotocol.Client{} + request := new(dnsprotocol.Msg) + request.SetQuestion("charon.lthn.", dnsprotocol.TypeNS) + response := exchangeWithRetry(t, client, request, srv.Address()) + if response.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("expected NS query for derived apex to succeed, got %d", response.Rcode) + } + if len(response.Answer) != 1 { + t.Fatalf("expected one NS answer for derived apex, got %d", len(response.Answer)) + } + ns, ok := response.Answer[0].(*dnsprotocol.NS) + if !ok { + t.Fatalf("expected NS answer for derived apex, got %#v", response.Answer[0]) + } + if ns.Ns != "ns.charon.lthn." { + t.Fatalf("expected synthesized apex NS, got %q", ns.Ns) + } +} + func TestServiceServeReturnsNXDOMAINWhenMissing(t *testing.T) { service := NewService(ServiceOptions{})