diff --git a/hsd.go b/hsd.go index 9920271..b5d6069 100644 --- a/hsd.go +++ b/hsd.go @@ -206,19 +206,23 @@ func parseHSDNameResource(raw json.RawMessage) (NameRecords, error) { } var wrapped struct { - A []string `json:"a"` - AAAA []string `json:"aaaa"` - TXT []string `json:"txt"` - NS []string `json:"ns"` - DS []string `json:"ds"` + A []string `json:"a"` + AAAA []string `json:"aaaa"` + TXT []string `json:"txt"` + NS []string `json:"ns"` + DS []string `json:"ds"` + DNSKEY []string `json:"dnskey"` + RRSIG []string `json:"rrsig"` } if err := json.Unmarshal(raw, &wrapped); err == nil { result = NameRecords{ - A: wrapped.A, - AAAA: wrapped.AAAA, - TXT: wrapped.TXT, - NS: wrapped.NS, - DS: wrapped.DS, + A: wrapped.A, + AAAA: wrapped.AAAA, + TXT: wrapped.TXT, + NS: wrapped.NS, + DS: wrapped.DS, + DNSKEY: wrapped.DNSKEY, + RRSIG: wrapped.RRSIG, } return result, nil } diff --git a/hsd_test.go b/hsd_test.go index 26f30fa..d5323ca 100644 --- a/hsd_test.go +++ b/hsd_test.go @@ -128,6 +128,46 @@ func TestHSDClientGetNameResourceParsesDSRecords(t *testing.T) { } } +func TestHSDClientGetNameResourceParsesDNSSECRecords(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + if payload.Method != "getnameresource" { + t.Fatalf("expected method getnameresource, got %s", payload.Method) + } + + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "records": map[string]any{ + "dnskey": []string{"257 3 13 AA=="}, + "rrsig": []string{"A 8 2 3600 20260101000000 20250101000000 12345 gateway.lthn. AA=="}, + }, + }, + }) + })) + defer server.Close() + + client := NewHSDClient(HSDClientOptions{ + URL: server.URL, + }) + + record, err := client.GetNameResource(context.Background(), "gateway.lthn") + if err != nil { + t.Fatalf("unexpected getnameresource error: %v", err) + } + if len(record.DNSKEY) != 1 || record.DNSKEY[0] != "257 3 13 AA==" { + t.Fatalf("unexpected DNSKEY result: %#v", record.DNSKEY) + } + if len(record.RRSIG) != 1 || record.RRSIG[0] != "A 8 2 3600 20260101000000 20250101000000 12345 gateway.lthn. AA==" { + t.Fatalf("unexpected RRSIG result: %#v", record.RRSIG) + } +} + func TestHSDClientGetBlockchainInfo(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { var payload struct { diff --git a/serve.go b/serve.go index 82afc20..002ce63 100644 --- a/serve.go +++ b/serve.go @@ -344,6 +344,16 @@ func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWr goto noRecord } appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeDS, record.DS) + case dnsprotocol.TypeDNSKEY: + if !found { + goto noRecord + } + appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeDNSKEY, record.DNSKEY) + case dnsprotocol.TypeRRSIG: + if !found { + goto noRecord + } + appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeRRSIG, record.RRSIG) default: reply.SetRcode(request, dnsprotocol.RcodeNotImplemented) _ = responseWriter.WriteMsg(reply) @@ -454,6 +464,14 @@ func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, lookupName st } } + for _, value := range record.DNSKEY { + appendDNSSECResourceRecords(reply, questionName, dnsprotocol.TypeDNSKEY, []string{value}) + } + + for _, value := range record.RRSIG { + appendDNSSECResourceRecords(reply, questionName, dnsprotocol.TypeRRSIG, []string{value}) + } + if normalizeName(lookupName) == normalizeName(zoneApex) && normalizeName(zoneApex) != "" { if len(record.NS) == 0 { reply.Answer = append(reply.Answer, &dnsprotocol.NS{ diff --git a/service.go b/service.go index b64fac2..a0cc3c6 100644 --- a/service.go +++ b/service.go @@ -28,11 +28,13 @@ const DefaultHTTPPort = 5554 // TXT: []string{"v=lthn1 type=gateway"}, // } type NameRecords struct { - A []string `json:"a"` - AAAA []string `json:"aaaa"` - TXT []string `json:"txt"` - NS []string `json:"ns"` - DS []string `json:"ds"` + A []string `json:"a"` + AAAA []string `json:"aaaa"` + TXT []string `json:"txt"` + NS []string `json:"ns"` + DS []string `json:"ds"` + DNSKEY []string `json:"dnskey"` + RRSIG []string `json:"rrsig"` } type ResolveAllResult struct { @@ -1039,6 +1041,12 @@ func computeTreeRoot(records map[string]NameRecords) string { builder.WriteString("DS=") builder.WriteString(serializeRecordValues(record.DS)) builder.WriteByte('\n') + builder.WriteString("DNSKEY=") + builder.WriteString(serializeRecordValues(record.DNSKEY)) + builder.WriteByte('\n') + builder.WriteString("RRSIG=") + builder.WriteString(serializeRecordValues(record.RRSIG)) + builder.WriteByte('\n') } sum := sha256.Sum256([]byte(builder.String())) diff --git a/service_test.go b/service_test.go index 9212773..3a3f1c8 100644 --- a/service_test.go +++ b/service_test.go @@ -1710,6 +1710,66 @@ func TestServiceServeAnswersDSRecords(t *testing.T) { } } +func TestServiceServeAnswersDNSKEYRecords(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + DNSKEY: []string{"257 3 13 AA=="}, + }, + }, + }) + + srv, err := service.Serve("127.0.0.1", 0) + if err != nil { + t.Fatalf("expected server to start: %v", err) + } + defer func() { _ = srv.Close() }() + + client := dnsprotocol.Client{} + request := new(dnsprotocol.Msg) + request.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeDNSKEY) + response := exchangeWithRetry(t, client, request, srv.Address()) + if response.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("unexpected DNSKEY rcode: %d", response.Rcode) + } + if len(response.Answer) != 1 { + t.Fatalf("expected one DNSKEY answer, got %d", len(response.Answer)) + } + if _, ok := response.Answer[0].(*dnsprotocol.DNSKEY); !ok { + t.Fatalf("expected DNSKEY answer, got %#v", response.Answer[0]) + } +} + +func TestServiceServeAnswersRRSIGRecords(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + RRSIG: []string{"A 8 2 3600 20260101000000 20250101000000 12345 gateway.charon.lthn. AA=="}, + }, + }, + }) + + srv, err := service.Serve("127.0.0.1", 0) + if err != nil { + t.Fatalf("expected server to start: %v", err) + } + defer func() { _ = srv.Close() }() + + client := dnsprotocol.Client{} + request := new(dnsprotocol.Msg) + request.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeRRSIG) + response := exchangeWithRetry(t, client, request, srv.Address()) + if response.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("unexpected RRSIG rcode: %d", response.Rcode) + } + if len(response.Answer) != 1 { + t.Fatalf("expected one RRSIG answer, got %d", len(response.Answer)) + } + if _, ok := response.Answer[0].(*dnsprotocol.RRSIG); !ok { + t.Fatalf("expected RRSIG answer, got %#v", response.Answer[0]) + } +} + func TestServiceServeAnswersANYWithAllRecordTypes(t *testing.T) { service := NewService(ServiceOptions{ Records: map[string]NameRecords{