From acd6d70ac25abe68bdec9883be94af8af8466036 Mon Sep 17 00:00:00 2001 From: Virgil Date: Fri, 3 Apr 2026 19:53:38 +0000 Subject: [PATCH] feat(dns): add DNS serve implementation Co-Authored-By: Virgil --- go.mod | 10 ++ go.sum | 12 +++ serve.go | 274 ++++++++++++++++++++++++++++++++++++++++++++++++ service_test.go | 129 ++++++++++++++++++++++- 4 files changed, 424 insertions(+), 1 deletion(-) create mode 100644 go.sum create mode 100644 serve.go diff --git a/go.mod b/go.mod index 493f497..92f2952 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,13 @@ module dappco.re/go/dns go 1.22 + +require github.com/miekg/dns v1.1.62 + +require ( + golang.org/x/mod v0.18.0 // indirect + golang.org/x/net v0.27.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/tools v0.22.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..95e8194 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= +github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= diff --git a/serve.go b/serve.go new file mode 100644 index 0000000..d95e791 --- /dev/null +++ b/serve.go @@ -0,0 +1,274 @@ +package dns + +import ( + "net" + "strconv" + "strings" + "time" + + dnsprotocol "github.com/miekg/dns" +) + +const dnsTTL = 300 + +// DNSServer handles a live UDP+TCP DNS endpoint and owns listener resources. +type DNSServer struct { + udpListener net.PacketConn + tcpListener net.Listener + udpServer *dnsprotocol.Server + tcpServer *dnsprotocol.Server +} + +func (server *DNSServer) Address() string { + if server.udpListener == nil { + return "" + } + return server.udpListener.LocalAddr().String() +} + +func (server *DNSServer) Close() error { + if server.udpListener != nil { + _ = server.udpListener.Close() + } + if server.tcpListener != nil { + _ = server.tcpListener.Close() + } + var err error + if server.udpServer != nil { + err = server.udpServer.Shutdown() + } + if server.tcpServer != nil { + err = server.tcpServer.Shutdown() + } + return err +} + +// Serve starts DNS over UDP and TCP at bind:port and returns a running server handle. +// Example: +// +// srv, err := service.Serve("127.0.0.1", 0) +func (service *Service) Serve(bind string, port int) (*DNSServer, error) { + if bind == "" { + bind = "127.0.0.1" + } + addr := net.JoinHostPort(bind, strconv.Itoa(port)) + + udpListener, err := net.ListenPacket("udp", addr) + if err != nil { + return nil, err + } + tcpListener, err := net.Listen("tcp", addr) + if err != nil { + _ = udpListener.Close() + return nil, err + } + + handler := &dnsRequestHandler{service: service} + mux := dnsprotocol.NewServeMux() + mux.HandleFunc(".", handler.ServeDNS) + + udpServer := &dnsprotocol.Server{Net: "udp", PacketConn: udpListener, Handler: mux} + tcpServer := &dnsprotocol.Server{Net: "tcp", Listener: tcpListener, Handler: mux} + + run := &DNSServer{ + udpListener: udpListener, + tcpListener: tcpListener, + udpServer: udpServer, + tcpServer: tcpServer, + } + + go func() { + _ = udpServer.ActivateAndServe() + }() + go func() { + _ = tcpServer.ActivateAndServe() + }() + + return run, nil +} + +type dnsRequestHandler struct { + service *Service +} + +func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWriter, request *dnsprotocol.Msg) { + reply := new(dnsprotocol.Msg) + reply.Compress = true + + if request == nil { + reply.SetRcode(reply, dnsprotocol.RcodeNotImplemented) + _ = responseWriter.WriteMsg(reply) + return + } + + reply.SetReply(request) + if request.Opcode != dnsprotocol.OpcodeQuery { + reply.SetRcode(request, dnsprotocol.RcodeNotImplemented) + _ = responseWriter.WriteMsg(reply) + return + } + + if len(request.Question) == 0 { + reply.SetRcode(request, dnsprotocol.RcodeFormatError) + _ = responseWriter.WriteMsg(reply) + return + } + + question := request.Question[0] + name := strings.TrimSuffix(strings.ToLower(question.Name), ".") + record, found := handler.service.findRecord(name) + + switch question.Qtype { + case dnsprotocol.TypeA: + if !found { + goto noRecord + } + for _, value := range record.A { + parsedIP := net.ParseIP(value) + if parsedIP == nil || parsedIP.To4() == nil { + continue + } + reply.Answer = append(reply.Answer, &dnsprotocol.A{ + Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + A: parsedIP.To4(), + }) + } + case dnsprotocol.TypeAAAA: + if !found { + goto noRecord + } + for _, value := range record.AAAA { + parsedIP := net.ParseIP(value) + if parsedIP == nil || parsedIP.To4() != nil { + continue + } + reply.Answer = append(reply.Answer, &dnsprotocol.AAAA{ + Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeAAAA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + AAAA: parsedIP.To16(), + }) + } + case dnsprotocol.TypeTXT: + if !found { + goto noRecord + } + for _, value := range record.TXT { + reply.Answer = append(reply.Answer, &dnsprotocol.TXT{ + Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeTXT, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Txt: []string{value}, + }) + } + case dnsprotocol.TypeNS: + if !found { + goto noRecord + } + for _, value := range record.NS { + reply.Answer = append(reply.Answer, &dnsprotocol.NS{ + Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Ns: normalizeName(value) + ".", + }) + } + case dnsprotocol.TypePTR: + ip, ok := parsePTRIP(name) + if !ok { + reply.SetRcode(request, dnsprotocol.RcodeFormatError) + _ = responseWriter.WriteMsg(reply) + return + } + values, ok := handler.service.ResolveReverse(ip) + if !ok { + goto noRecord + } + for _, value := range values { + reply.Answer = append(reply.Answer, &dnsprotocol.PTR{ + Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypePTR, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Ptr: normalizeName(value) + ".", + }) + } + case dnsprotocol.TypeSOA: + if !found { + goto noRecord + } + reply.Answer = append(reply.Answer, &dnsprotocol.SOA{ + Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeSOA, Class: dnsprotocol.ClassINET, Ttl: dnsTTL}, + Ns: normalizeName("ns."+name) + ".", + Mbox: "hostmaster." + name + ".", + Serial: uint32(time.Now().Unix()), + Refresh: 86400, + Retry: 3600, + Expire: 3600, + Minttl: 300, + }) + default: + reply.SetRcode(request, dnsprotocol.RcodeNotImplemented) + _ = responseWriter.WriteMsg(reply) + return + } + + if len(reply.Answer) == 0 { + goto noRecord + } + + _ = responseWriter.WriteMsg(reply) + return + +noRecord: + reply.SetRcode(request, dnsprotocol.RcodeNameError) + _ = responseWriter.WriteMsg(reply) +} + +func parsePTRIP(name string) (string, bool) { + lookup := strings.TrimSuffix(strings.ToLower(name), ".") + if strings.HasSuffix(lookup, ".in-addr.arpa") { + base := strings.TrimSuffix(strings.TrimSuffix(lookup, ".in-addr.arpa"), ".") + labels := strings.Split(base, ".") + if len(labels) != 4 { + return "", false + } + for _, label := range labels { + if label == "" { + return "", false + } + } + reversed := make([]string, 4) + for i := 0; i < 4; i++ { + reversed[i] = labels[3-i] + } + candidate := strings.Join(reversed, ".") + parsed := net.ParseIP(candidate) + if parsed == nil || parsed.To4() == nil { + return "", false + } + return parsed.String(), true + } + if strings.HasSuffix(lookup, ".ip6.arpa") { + base := strings.TrimSuffix(strings.TrimSuffix(lookup, ".ip6.arpa"), ".") + labels := strings.Split(base, ".") + if len(labels) != 32 { + return "", false + } + for _, label := range labels { + if len(label) != 1 { + return "", false + } + } + reversed := make([]string, 32) + for i := 0; i < 32; i++ { + reversed[i] = labels[31-i] + } + blocks := make([]string, 8) + for i := 0; i < 8; i++ { + block := strings.Join(reversed[i*4:(i+1)*4], "") + if len(block) != 4 { + return "", false + } + blocks[i] = block + } + candidate := strings.Join(blocks, ":") + parsed := net.ParseIP(candidate) + if parsed == nil { + return "", false + } + return parsed.String(), true + } + return "", false +} diff --git a/service_test.go b/service_test.go index b775660..5660aaf 100644 --- a/service_test.go +++ b/service_test.go @@ -1,6 +1,30 @@ package dns -import "testing" +import ( + "strings" + "testing" + "time" + + dnsprotocol "github.com/miekg/dns" +) + +func exchangeWithRetry(t *testing.T, client dnsprotocol.Client, request *dnsprotocol.Msg, address string) *dnsprotocol.Msg { + t.Helper() + + for attempt := 0; attempt < 80; attempt++ { + response, _, err := client.Exchange(request, address) + if err == nil { + return response + } + if !strings.Contains(err.Error(), "connection refused") { + t.Fatalf("dns query failed: %v", err) + } + time.Sleep(25 * time.Millisecond) + } + + t.Fatalf("dns query failed after retrying due to startup timing") + return nil +} func TestServiceResolveUsesExactNameBeforeWildcard(t *testing.T) { service := NewService(ServiceOptions{ @@ -243,3 +267,106 @@ func TestServiceDiscoverReturnsNilWithoutDiscoverer(t *testing.T) { t.Fatalf("expected no error when discoverer is missing: %v", err) } } + +func TestServiceServeResolvesAAndAAAARecords(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + AAAA: []string{"2600:1f1c:7f0:4f01::1"}, + }, + }, + }) + + srv, err := service.Serve("127.0.0.1", 0) + if err != nil { + t.Fatalf("expected server to start: %v", err) + } + defer func() { _ = srv.Close() }() + + client := dnsprotocol.Client{} + query := func(qtype uint16) *dnsprotocol.Msg { + request := new(dnsprotocol.Msg) + request.SetQuestion("gateway.charon.lthn.", qtype) + response := exchangeWithRetry(t, client, request, srv.Address()) + if response.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("unexpected rcode for qtype %d: %d", qtype, response.Rcode) + } + return response + } + + aResponse := query(dnsprotocol.TypeA) + if len(aResponse.Answer) != 1 { + t.Fatalf("expected one A answer, got %d", len(aResponse.Answer)) + } + if got, ok := aResponse.Answer[0].(*dnsprotocol.A); !ok || got.A.String() != "10.10.10.10" { + t.Fatalf("unexpected A answer: %#v", aResponse.Answer[0]) + } + + aaaaResponse := query(dnsprotocol.TypeAAAA) + if len(aaaaResponse.Answer) != 1 { + t.Fatalf("expected one AAAA answer, got %d", len(aaaaResponse.Answer)) + } + if got, ok := aaaaResponse.Answer[0].(*dnsprotocol.AAAA); !ok || got.AAAA.String() != "2600:1f1c:7f0:4f01::1" { + t.Fatalf("unexpected AAAA answer: %#v", aaaaResponse.Answer[0]) + } +} + +func TestServiceServeResolvesWildcardAndPTRRecords(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "*.charon.lthn": { + A: []string{"10.0.0.1"}, + }, + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + }, + }, + }) + + srv, err := service.Serve("127.0.0.1", 0) + if err != nil { + t.Fatalf("expected server to start: %v", err) + } + defer func() { _ = srv.Close() }() + + client := dnsprotocol.Client{} + request := new(dnsprotocol.Msg) + request.SetQuestion("node1.charon.lthn.", dnsprotocol.TypeA) + response := exchangeWithRetry(t, client, request, srv.Address()) + if response.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("unexpected rcode: %d", response.Rcode) + } + if got, ok := response.Answer[0].(*dnsprotocol.A); !ok || got.A.String() != "10.0.0.1" { + t.Fatalf("unexpected wildcard A answer: %#v", response.Answer) + } + + ptrName := "10.10.10.10.in-addr.arpa." + ptrRequest := new(dnsprotocol.Msg) + ptrRequest.SetQuestion(ptrName, dnsprotocol.TypePTR) + ptrResponse := exchangeWithRetry(t, client, ptrRequest, srv.Address()) + if len(ptrResponse.Answer) == 0 { + t.Fatal("expected PTR answer") + } + if got, ok := ptrResponse.Answer[0].(*dnsprotocol.PTR); !ok || got.Ptr != "gateway.charon.lthn." { + t.Fatalf("unexpected PTR answer: %#v", ptrResponse.Answer) + } +} + +func TestServiceServeReturnsNXDOMAINWhenMissing(t *testing.T) { + service := NewService(ServiceOptions{}) + + srv, err := service.Serve("127.0.0.1", 0) + if err != nil { + t.Fatalf("expected server to start: %v", err) + } + defer func() { _ = srv.Close() }() + + client := dnsprotocol.Client{} + request := new(dnsprotocol.Msg) + request.SetQuestion("missing.charon.lthn.", dnsprotocol.TypeA) + response := exchangeWithRetry(t, client, request, srv.Address()) + if response.Rcode != dnsprotocol.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", response.Rcode) + } +} -- 2.45.3