feat(dns): parse hsd record fields case-insensitively

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-04 00:09:50 +00:00
parent 08e0d201e1
commit 32543b2e12
2 changed files with 70 additions and 15 deletions

41
hsd.go
View file

@ -209,37 +209,49 @@ func parseHSDNameResource(raw json.RawMessage) (NameRecords, error) {
return NameRecords{}, errors.New("unable to parse getnameresource result")
}
func getCaseInsensitiveRecordField(fields map[string]json.RawMessage, key string) json.RawMessage {
if fields == nil {
return nil
}
for candidate, value := range fields {
if strings.EqualFold(candidate, key) {
return value
}
}
return nil
}
func parseHSDNameResourceRecords(raw json.RawMessage) (NameRecords, error) {
var fields map[string]json.RawMessage
if err := json.Unmarshal(raw, &fields); err != nil {
return NameRecords{}, err
}
a, err := parseHSDRecordValue(fields["a"])
a, err := parseHSDRecordValue(getCaseInsensitiveRecordField(fields, "a"))
if err != nil {
return NameRecords{}, err
}
aaaa, err := parseHSDRecordValue(fields["aaaa"])
aaaa, err := parseHSDRecordValue(getCaseInsensitiveRecordField(fields, "aaaa"))
if err != nil {
return NameRecords{}, err
}
txt, err := parseHSDRecordValue(fields["txt"])
txt, err := parseHSDRecordValue(getCaseInsensitiveRecordField(fields, "txt"))
if err != nil {
return NameRecords{}, err
}
ns, err := parseHSDRecordValue(fields["ns"])
ns, err := parseHSDRecordValue(getCaseInsensitiveRecordField(fields, "ns"))
if err != nil {
return NameRecords{}, err
}
ds, err := parseHSDRecordValue(fields["ds"])
ds, err := parseHSDRecordValue(getCaseInsensitiveRecordField(fields, "ds"))
if err != nil {
return NameRecords{}, err
}
dnsKey, err := parseHSDRecordValue(fields["dnskey"])
dnsKey, err := parseHSDRecordValue(getCaseInsensitiveRecordField(fields, "dnskey"))
if err != nil {
return NameRecords{}, err
}
rrSig, err := parseHSDRecordValue(fields["rrsig"])
rrSig, err := parseHSDRecordValue(getCaseInsensitiveRecordField(fields, "rrsig"))
if err != nil {
return NameRecords{}, err
}
@ -260,14 +272,13 @@ func hasHSDNameResourceField(fields map[string]json.RawMessage) bool {
return false
}
_, hasA := fields["a"]
_, hasAAAA := fields["aaaa"]
_, hasTXT := fields["txt"]
_, hasNS := fields["ns"]
_, hasDS := fields["ds"]
_, hasDNSKEY := fields["dnskey"]
_, hasRRSIG := fields["rrsig"]
return hasA || hasAAAA || hasTXT || hasNS || hasDS || hasDNSKEY || hasRRSIG
return getCaseInsensitiveRecordField(fields, "a") != nil ||
getCaseInsensitiveRecordField(fields, "aaaa") != nil ||
getCaseInsensitiveRecordField(fields, "txt") != nil ||
getCaseInsensitiveRecordField(fields, "ns") != nil ||
getCaseInsensitiveRecordField(fields, "ds") != nil ||
getCaseInsensitiveRecordField(fields, "dnskey") != nil ||
getCaseInsensitiveRecordField(fields, "rrsig") != nil
}
func parseHSDRecordValue(raw json.RawMessage) ([]string, error) {

View file

@ -92,6 +92,50 @@ func TestHSDClientGetNameResourceParsesWrappedRecords(t *testing.T) {
}
}
func TestHSDClientGetNameResourceParsesCaseInsensitiveFields(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) {
var payload struct {
Method string `json:"method"`
}
if err := json.NewDecoder(request.Body).Decode(&payload); err != nil {
t.Fatalf("unexpected request payload: %v", err)
}
if payload.Method != "getnameresource" {
t.Fatalf("expected method getnameresource, got %s", payload.Method)
}
responseWriter.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(responseWriter).Encode(map[string]any{
"result": map[string]any{
"records": map[string]any{
"A": []string{"10.10.10.10"},
"Txt": []string{"v=lthn1 type=case"},
"NS": []string{"ns.example.lthn"},
},
},
})
}))
defer server.Close()
client := NewHSDClient(HSDClientOptions{
URL: server.URL,
})
record, err := client.GetNameResource(context.Background(), "case.lthn")
if err != nil {
t.Fatalf("unexpected getnameresource error: %v", err)
}
if len(record.A) != 1 || record.A[0] != "10.10.10.10" {
t.Fatalf("unexpected wrapped A result with mixed-case key: %#v", record.A)
}
if len(record.TXT) != 1 || record.TXT[0] != "v=lthn1 type=case" {
t.Fatalf("unexpected wrapped TXT result with mixed-case key: %#v", record.TXT)
}
if len(record.NS) != 1 || record.NS[0] != "ns.example.lthn" {
t.Fatalf("unexpected wrapped NS result with mixed-case key: %#v", record.NS)
}
}
func TestHSDClientGetNameResourceParsesDSRecords(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) {
var payload struct {