feat(dns): support ANY DNS queries
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
b27160536d
commit
2c5ca56bfb
2 changed files with 112 additions and 0 deletions
67
serve.go
67
serve.go
|
|
@ -270,6 +270,14 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr
|
|||
if !found {
|
||||
goto noRecord
|
||||
}
|
||||
case dnsprotocol.TypeANY:
|
||||
if found {
|
||||
appendAnyAnswers(reply, question.Name, name, record)
|
||||
} else if normalizeName(name) == handler.service.ZoneApex() && handler.service.ZoneApex() != "" {
|
||||
appendAnyAnswers(reply, question.Name, name, NameRecords{})
|
||||
} else {
|
||||
goto noRecord
|
||||
}
|
||||
default:
|
||||
reply.SetRcode(request, dnsprotocol.RcodeNotImplemented)
|
||||
_ = responseWriter.WriteMsg(reply)
|
||||
|
|
@ -340,3 +348,62 @@ func parsePTRIP(name string) (string, bool) {
|
|||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, name string, record NameRecords) {
|
||||
for _, value := range record.A {
|
||||
parsedIP := net.ParseIP(value)
|
||||
if parsedIP == nil || parsedIP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
reply.Answer = append(reply.Answer, &dnsprotocol.A{
|
||||
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL},
|
||||
A: parsedIP.To4(),
|
||||
})
|
||||
}
|
||||
|
||||
for _, value := range record.AAAA {
|
||||
parsedIP := net.ParseIP(value)
|
||||
if parsedIP == nil || parsedIP.To4() != nil {
|
||||
continue
|
||||
}
|
||||
reply.Answer = append(reply.Answer, &dnsprotocol.AAAA{
|
||||
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeAAAA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL},
|
||||
AAAA: parsedIP.To16(),
|
||||
})
|
||||
}
|
||||
|
||||
for _, value := range record.TXT {
|
||||
reply.Answer = append(reply.Answer, &dnsprotocol.TXT{
|
||||
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeTXT, Class: dnsprotocol.ClassINET, Ttl: dnsTTL},
|
||||
Txt: []string{value},
|
||||
})
|
||||
}
|
||||
|
||||
if len(record.NS) > 0 {
|
||||
for _, value := range record.NS {
|
||||
reply.Answer = append(reply.Answer, &dnsprotocol.NS{
|
||||
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: dnsTTL},
|
||||
Ns: normalizeName(value) + ".",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if normalizeName(name) == normalizeName(questionName) && normalizeName(name) != "" {
|
||||
if len(record.NS) == 0 {
|
||||
reply.Answer = append(reply.Answer, &dnsprotocol.NS{
|
||||
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: dnsTTL},
|
||||
Ns: normalizeName("ns."+name) + ".",
|
||||
})
|
||||
}
|
||||
reply.Answer = append(reply.Answer, &dnsprotocol.SOA{
|
||||
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeSOA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL},
|
||||
Ns: normalizeName("ns."+name) + ".",
|
||||
Mbox: "hostmaster." + name + ".",
|
||||
Serial: uint32(time.Now().Unix()),
|
||||
Refresh: 86400,
|
||||
Retry: 3600,
|
||||
Expire: 3600,
|
||||
Minttl: 300,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1143,6 +1143,51 @@ func TestServiceServeResolvesAAndAAAARecords(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestServiceServeAnswersANYWithAllRecordTypes(t *testing.T) {
|
||||
service := NewService(ServiceOptions{
|
||||
Records: map[string]NameRecords{
|
||||
"gateway.charon.lthn": {
|
||||
A: []string{"10.10.10.10"},
|
||||
AAAA: []string{"2600:1f1c:7f0:4f01::1"},
|
||||
TXT: []string{"v=lthn1 type=gateway"},
|
||||
NS: []string{"ns.gateway.charon.lthn"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
srv, err := service.Serve("127.0.0.1", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("expected server to start: %v", err)
|
||||
}
|
||||
defer func() { _ = srv.Close() }()
|
||||
|
||||
client := dnsprotocol.Client{}
|
||||
request := new(dnsprotocol.Msg)
|
||||
request.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeANY)
|
||||
response := exchangeWithRetry(t, client, request, srv.Address())
|
||||
if response.Rcode != dnsprotocol.RcodeSuccess {
|
||||
t.Fatalf("unexpected ANY rcode: %d", response.Rcode)
|
||||
}
|
||||
|
||||
var sawA, sawAAAA, sawTXT, sawNS bool
|
||||
for _, answer := range response.Answer {
|
||||
switch rr := answer.(type) {
|
||||
case *dnsprotocol.A:
|
||||
sawA = rr.A.String() == "10.10.10.10"
|
||||
case *dnsprotocol.AAAA:
|
||||
sawAAAA = rr.AAAA.String() == "2600:1f1c:7f0:4f01::1"
|
||||
case *dnsprotocol.TXT:
|
||||
sawTXT = len(rr.Txt) == 1 && rr.Txt[0] == "v=lthn1 type=gateway"
|
||||
case *dnsprotocol.NS:
|
||||
sawNS = rr.Ns == "ns.gateway.charon.lthn."
|
||||
}
|
||||
}
|
||||
|
||||
if !sawA || !sawAAAA || !sawTXT || !sawNS {
|
||||
t.Fatalf("expected ANY answer to include A, AAAA, TXT, and NS records, got %#v", response.Answer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceServeResolvesWildcardAndPTRRecords(t *testing.T) {
|
||||
service := NewService(ServiceOptions{
|
||||
Records: map[string]NameRecords{
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue