diff --git a/serve.go b/serve.go index 34d7180..be60178 100644 --- a/serve.go +++ b/serve.go @@ -315,7 +315,7 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr if !found { goto noRecord } - for _, value := range record.A { + for _, value := range normalizeRecordValues(record.A) { parsedIP := net.ParseIP(value) if parsedIP == nil || parsedIP.To4() == nil { continue @@ -329,7 +329,7 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr if !found { goto noRecord } - for _, value := range record.AAAA { + for _, value := range normalizeRecordValues(record.AAAA) { parsedIP := net.ParseIP(value) if parsedIP == nil || parsedIP.To4() != nil { continue @@ -343,7 +343,7 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr if !found { goto noRecord } - for _, value := range record.TXT { + for _, value := range normalizeRecordValues(record.TXT) { reply.Answer = append(reply.Answer, &dnsprotocol.TXT{ Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeTXT, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL}, Txt: []string{value}, @@ -351,7 +351,7 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr } case dnsprotocol.TypeNS: if found { - for _, value := range record.NS { + for _, value := range normalizeRecordValues(record.NS) { reply.Answer = append(reply.Answer, &dnsprotocol.NS{ Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL}, Ns: normalizeName(value) + ".", @@ -413,17 +413,17 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr if !found { goto noRecord } - appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeDS, record.DS) + appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeDS, normalizeRecordValues(record.DS)) case dnsprotocol.TypeDNSKEY: if !found { goto noRecord } - appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeDNSKEY, record.DNSKEY) + appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeDNSKEY, normalizeRecordValues(record.DNSKEY)) case dnsprotocol.TypeRRSIG: if !found { goto noRecord } - appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeRRSIG, record.RRSIG) + appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeRRSIG, normalizeRecordValues(record.RRSIG)) default: reply.SetRcode(request, dnsprotocol.RcodeNotImplemented) _ = responseWriter.WriteMsg(reply) @@ -496,7 +496,7 @@ func parsePTRIP(name string) (string, bool) { } func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, lookupName string, record NameRecords, zoneApex string) { - for _, value := range record.A { + for _, value := range normalizeRecordValues(record.A) { parsedIP := net.ParseIP(value) if parsedIP == nil || parsedIP.To4() == nil { continue @@ -507,7 +507,7 @@ func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, lookupName st }) } - for _, value := range record.AAAA { + for _, value := range normalizeRecordValues(record.AAAA) { parsedIP := net.ParseIP(value) if parsedIP == nil || parsedIP.To4() != nil { continue @@ -518,7 +518,7 @@ func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, lookupName st }) } - for _, value := range record.TXT { + for _, value := range normalizeRecordValues(record.TXT) { reply.Answer = append(reply.Answer, &dnsprotocol.TXT{ Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeTXT, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL}, Txt: []string{value}, @@ -526,7 +526,7 @@ func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, lookupName st } if len(record.NS) > 0 { - for _, value := range record.NS { + for _, value := range normalizeRecordValues(record.NS) { reply.Answer = append(reply.Answer, &dnsprotocol.NS{ Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL}, Ns: normalizeName(value) + ".", @@ -534,15 +534,15 @@ func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, lookupName st } } - for _, value := range record.DNSKEY { + for _, value := range normalizeRecordValues(record.DNSKEY) { appendDNSSECResourceRecords(reply, questionName, dnsprotocol.TypeDNSKEY, []string{value}) } - for _, value := range record.DS { + for _, value := range normalizeRecordValues(record.DS) { appendDNSSECResourceRecords(reply, questionName, dnsprotocol.TypeDS, []string{value}) } - for _, value := range record.RRSIG { + for _, value := range normalizeRecordValues(record.RRSIG) { appendDNSSECResourceRecords(reply, questionName, dnsprotocol.TypeRRSIG, []string{value}) } diff --git a/service.go b/service.go index 4f1f3b3..406ce51 100644 --- a/service.go +++ b/service.go @@ -994,13 +994,13 @@ func (service *Service) findRecordWithMatch(name string) (NameRecords, bool, boo func resolveResult(record NameRecords) ResolveAllResult { return ResolveAllResult{ - A: cloneStrings(record.A), - AAAA: cloneStrings(record.AAAA), - TXT: cloneStrings(record.TXT), - NS: cloneStrings(record.NS), - DS: cloneStrings(record.DS), - DNSKEY: cloneStrings(record.DNSKEY), - RRSIG: cloneStrings(record.RRSIG), + A: normalizeRecordValues(record.A), + AAAA: normalizeRecordValues(record.AAAA), + TXT: normalizeRecordValues(record.TXT), + NS: normalizeRecordValues(record.NS), + DS: normalizeRecordValues(record.DS), + DNSKEY: normalizeRecordValues(record.DNSKEY), + RRSIG: normalizeRecordValues(record.RRSIG), } } @@ -1136,9 +1136,30 @@ func computeZoneApex(records map[string]NameRecords) string { } func serializeRecordValues(values []string) string { - copied := append([]string(nil), values...) - slices.Sort(copied) - return strings.Join(copied, ",") + normalized := normalizeRecordValues(values) + return strings.Join(normalized, ",") +} + +func normalizeRecordValues(values []string) []string { + if len(values) == 0 { + return []string{} + } + + seen := make(map[string]struct{}, len(values)) + normalized := make([]string, 0, len(values)) + for _, value := range values { + if value == "" { + continue + } + if _, exists := seen[value]; exists { + continue + } + seen[value] = struct{}{} + normalized = append(normalized, value) + } + + slices.Sort(normalized) + return normalized } func cloneStrings(values []string) []string { @@ -1231,17 +1252,9 @@ func (service *Service) String() string { // // values := MergeRecords([]string{"10.10.10.10"}, []string{"10.0.0.1", "10.10.10.10"}) func MergeRecords(values ...[]string) []string { - unique := []string{} - seen := map[string]bool{} + merged := []string{} for _, batch := range values { - for _, value := range batch { - if seen[value] { - continue - } - seen[value] = true - unique = append(unique, value) - } + merged = append(merged, batch...) } - slices.Sort(unique) - return unique + return normalizeRecordValues(merged) }