580 lines
16 KiB
Go
580 lines
16 KiB
Go
package dns
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
dnsprotocol "github.com/miekg/dns"
|
|
)
|
|
|
|
// defaultDNSTTL is the TTL written into the header of every resource record
// built in this file (DNS TTLs are expressed in seconds).
const defaultDNSTTL = 300
|
|
|
|
// DNSServer owns the UDP+TCP DNS listeners returned by Serve.
|
|
//
|
|
// srv, err := service.Serve("127.0.0.1", 53)
|
|
// defer func() { _ = srv.Close() }()
|
|
// fmt.Println("dns at", srv.DNSAddress())
|
|
type DNSServer struct {
|
|
udpListener net.PacketConn
|
|
tcpListener net.Listener
|
|
udpServer *dnsprotocol.Server
|
|
tcpServer *dnsprotocol.Server
|
|
}
|
|
|
|
// ServiceRuntime wraps the DNS and health listeners returned by ServeAll.
|
|
//
|
|
// runtime, err := service.ServeAll("127.0.0.1", 53, 5554)
|
|
// defer func() { _ = runtime.Close() }()
|
|
// fmt.Println(runtime.DNSAddress(), runtime.HealthAddress())
|
|
type ServiceRuntime struct {
|
|
DNS *DNSServer
|
|
Health *HealthServer
|
|
// HTTP is retained for compatibility with older call sites.
|
|
HTTP *HealthServer
|
|
}
|
|
|
|
func (runtime *ServiceRuntime) DNSAddress() string {
|
|
if runtime == nil || runtime.DNS == nil {
|
|
return ""
|
|
}
|
|
return runtime.DNS.DNSAddress()
|
|
}
|
|
|
|
func (runtime *ServiceRuntime) HealthAddress() string {
|
|
if runtime == nil {
|
|
return ""
|
|
}
|
|
if runtime.Health != nil {
|
|
return runtime.Health.HealthAddress()
|
|
}
|
|
if runtime.HTTP != nil {
|
|
return runtime.HTTP.HealthAddress()
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// HTTPAddress is retained for compatibility with older call sites.
|
|
func (runtime *ServiceRuntime) HTTPAddress() string {
|
|
return runtime.HealthAddress()
|
|
}
|
|
|
|
func (runtime *ServiceRuntime) Close() error {
|
|
if runtime == nil {
|
|
return nil
|
|
}
|
|
|
|
var firstError error
|
|
if runtime.DNS != nil {
|
|
if err := runtime.DNS.Close(); err != nil && firstError == nil {
|
|
firstError = err
|
|
}
|
|
}
|
|
if runtime.Health != nil {
|
|
if err := runtime.Health.Close(); err != nil && firstError == nil {
|
|
firstError = err
|
|
}
|
|
}
|
|
if runtime.HTTP != nil && runtime.HTTP != runtime.Health {
|
|
if err := runtime.HTTP.Close(); err != nil && firstError == nil {
|
|
firstError = err
|
|
}
|
|
}
|
|
return firstError
|
|
}
|
|
|
|
func (server *DNSServer) DNSAddress() string {
|
|
if server.udpListener == nil {
|
|
return ""
|
|
}
|
|
return server.udpListener.LocalAddr().String()
|
|
}
|
|
|
|
// Address is retained for compatibility with older call sites.
|
|
func (server *DNSServer) Address() string {
|
|
return server.DNSAddress()
|
|
}
|
|
|
|
func (server *DNSServer) Close() error {
|
|
if server.udpListener != nil {
|
|
_ = server.udpListener.Close()
|
|
}
|
|
if server.tcpListener != nil {
|
|
_ = server.tcpListener.Close()
|
|
}
|
|
var err error
|
|
if server.udpServer != nil {
|
|
err = server.udpServer.Shutdown()
|
|
}
|
|
if server.tcpServer != nil {
|
|
err = server.tcpServer.Shutdown()
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ResolveDNSPort returns the DNS port used for `dns.serve` and `Serve`.
|
|
//
|
|
// port := service.ResolveDNSPort()
|
|
// server, err := service.Serve("127.0.0.1", port)
|
|
func (service *Service) ResolveDNSPort() int {
|
|
if service == nil || service.dnsPort <= 0 {
|
|
return DefaultDNSPort
|
|
}
|
|
return service.dnsPort
|
|
}
|
|
|
|
// DNSPort is an explicit alias for ResolveDNSPort.
|
|
//
|
|
// port := service.DNSPort()
|
|
// server, err := service.Serve("127.0.0.1", port)
|
|
func (service *Service) DNSPort() int {
|
|
return service.ResolveDNSPort()
|
|
}
|
|
|
|
// resolveServePort keeps internal callers aligned with existing behavior.
|
|
func (service *Service) resolveServePort() int {
|
|
return service.ResolveDNSPort()
|
|
}
|
|
|
|
// ResolveHTTPPort returns the HTTP health port used by `ServeHTTPHealth`.
|
|
//
|
|
// port := service.ResolveHTTPPort()
|
|
// healthServer, err := service.ServeHTTPHealth("127.0.0.1", port)
|
|
func (service *Service) ResolveHTTPPort() int {
|
|
if service == nil || service.httpPort <= 0 {
|
|
return DefaultHTTPPort
|
|
}
|
|
return service.httpPort
|
|
}
|
|
|
|
// HTTPPort is an explicit alias for ResolveHTTPPort.
|
|
//
|
|
// port := service.HTTPPort()
|
|
// healthServer, err := service.ServeHTTPHealth("127.0.0.1", port)
|
|
func (service *Service) HTTPPort() int {
|
|
return service.ResolveHTTPPort()
|
|
}
|
|
|
|
func (service *Service) resolveHTTPPort() int {
|
|
return service.ResolveHTTPPort()
|
|
}
|
|
|
|
// Serve starts DNS over UDP and TCP.
|
|
//
|
|
// srv, err := service.Serve("0.0.0.0", 53)
|
|
// defer func() { _ = srv.Close() }()
|
|
// lookup := new(dnsprotocol.Msg)
|
|
// lookup.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeA)
|
|
func (service *Service) Serve(bind string, port int) (*DNSServer, error) {
|
|
if bind == "" {
|
|
bind = "127.0.0.1"
|
|
}
|
|
addr := net.JoinHostPort(bind, strconv.Itoa(port))
|
|
|
|
udpListener, err := net.ListenPacket("udp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tcpListener, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
_ = udpListener.Close()
|
|
return nil, err
|
|
}
|
|
|
|
requestHandler := &dnsRequestHandler{service: service}
|
|
serveMux := dnsprotocol.NewServeMux()
|
|
serveMux.HandleFunc(".", requestHandler.ServeDNS)
|
|
|
|
udpServer := &dnsprotocol.Server{Net: "udp", PacketConn: udpListener, Handler: serveMux}
|
|
tcpServer := &dnsprotocol.Server{Net: "tcp", Listener: tcpListener, Handler: serveMux}
|
|
|
|
dnsServer := &DNSServer{
|
|
udpListener: udpListener,
|
|
tcpListener: tcpListener,
|
|
udpServer: udpServer,
|
|
tcpServer: tcpServer,
|
|
}
|
|
|
|
go func() {
|
|
_ = udpServer.ActivateAndServe()
|
|
}()
|
|
go func() {
|
|
_ = tcpServer.ActivateAndServe()
|
|
}()
|
|
|
|
return dnsServer, nil
|
|
}
|
|
|
|
// ServeAll starts DNS and health together.
|
|
//
|
|
// runtime, err := service.ServeAll("127.0.0.1", 53, 5554)
|
|
// defer func() { _ = runtime.Close() }()
|
|
// fmt.Println("dns:", runtime.DNSAddress(), "health:", runtime.HealthAddress())
|
|
func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*ServiceRuntime, error) {
|
|
if dnsPort == 0 {
|
|
dnsPort = service.dnsPort
|
|
}
|
|
if httpPort <= 0 {
|
|
httpPort = service.resolveHTTPPort()
|
|
}
|
|
|
|
dnsServer, err := service.Serve(bind, dnsPort)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpServer, err := service.ServeHTTPHealth(bind, httpPort)
|
|
if err != nil {
|
|
_ = dnsServer.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return &ServiceRuntime{
|
|
DNS: dnsServer,
|
|
Health: httpServer,
|
|
HTTP: httpServer,
|
|
}, nil
|
|
}
|
|
|
|
// ServeConfigured starts DNS and health using the ports stored on the service.
|
|
//
|
|
// service := dns.NewService(dns.ServiceOptions{
|
|
// DNSPort: 1053,
|
|
// HTTPPort: 5554,
|
|
// })
|
|
// runtime, err := service.ServeConfigured("127.0.0.1")
|
|
func (service *Service) ServeConfigured(bind string) (*ServiceRuntime, error) {
|
|
return service.ServeAll(bind, service.dnsPort, service.httpPort)
|
|
}
|
|
|
|
type dnsRequestHandler struct {
|
|
service *Service
|
|
}
|
|
|
|
func (handler *dnsRequestHandler) ServeDNS(responseWriter dnsprotocol.ResponseWriter, request *dnsprotocol.Msg) {
|
|
reply := new(dnsprotocol.Msg)
|
|
reply.Compress = true
|
|
|
|
if request == nil {
|
|
reply.SetRcode(reply, dnsprotocol.RcodeNotImplemented)
|
|
_ = responseWriter.WriteMsg(reply)
|
|
return
|
|
}
|
|
|
|
reply.SetReply(request)
|
|
if request.Opcode != dnsprotocol.OpcodeQuery {
|
|
reply.SetRcode(request, dnsprotocol.RcodeNotImplemented)
|
|
_ = responseWriter.WriteMsg(reply)
|
|
return
|
|
}
|
|
|
|
if len(request.Question) == 0 {
|
|
reply.SetRcode(request, dnsprotocol.RcodeFormatError)
|
|
_ = responseWriter.WriteMsg(reply)
|
|
return
|
|
}
|
|
|
|
question := request.Question[0]
|
|
lookupName := strings.TrimSuffix(strings.ToLower(question.Name), ".")
|
|
record, found := handler.service.findRecord(lookupName)
|
|
|
|
switch question.Qtype {
|
|
case dnsprotocol.TypeA:
|
|
if !found {
|
|
goto noRecord
|
|
}
|
|
for _, value := range record.A {
|
|
parsedIP := net.ParseIP(value)
|
|
if parsedIP == nil || parsedIP.To4() == nil {
|
|
continue
|
|
}
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.A{
|
|
Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeA, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
A: parsedIP.To4(),
|
|
})
|
|
}
|
|
case dnsprotocol.TypeAAAA:
|
|
if !found {
|
|
goto noRecord
|
|
}
|
|
for _, value := range record.AAAA {
|
|
parsedIP := net.ParseIP(value)
|
|
if parsedIP == nil || parsedIP.To4() != nil {
|
|
continue
|
|
}
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.AAAA{
|
|
Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeAAAA, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
AAAA: parsedIP.To16(),
|
|
})
|
|
}
|
|
case dnsprotocol.TypeTXT:
|
|
if !found {
|
|
goto noRecord
|
|
}
|
|
for _, value := range record.TXT {
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.TXT{
|
|
Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeTXT, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Txt: []string{value},
|
|
})
|
|
}
|
|
case dnsprotocol.TypeNS:
|
|
if found {
|
|
for _, value := range record.NS {
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.NS{
|
|
Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Ns: normalizeName(value) + ".",
|
|
})
|
|
}
|
|
}
|
|
if len(reply.Answer) == 0 && normalizeName(lookupName) == handler.service.ZoneApex() && handler.service.ZoneApex() != "" {
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.NS{
|
|
Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Ns: normalizeName("ns."+lookupName) + ".",
|
|
})
|
|
}
|
|
if len(reply.Answer) == 0 && !found {
|
|
goto noRecord
|
|
}
|
|
case dnsprotocol.TypePTR:
|
|
ip, ok := parsePTRIP(lookupName)
|
|
if !ok {
|
|
reply.SetRcode(request, dnsprotocol.RcodeFormatError)
|
|
_ = responseWriter.WriteMsg(reply)
|
|
return
|
|
}
|
|
values, ok := handler.service.ResolveReverse(ip)
|
|
if !ok {
|
|
goto noRecord
|
|
}
|
|
for _, value := range values {
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.PTR{
|
|
Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypePTR, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Ptr: normalizeName(value) + ".",
|
|
})
|
|
}
|
|
case dnsprotocol.TypeSOA:
|
|
if normalizeName(lookupName) == handler.service.ZoneApex() && handler.service.ZoneApex() != "" {
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.SOA{
|
|
Hdr: dnsprotocol.RR_Header{Name: question.Name, Rrtype: dnsprotocol.TypeSOA, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Ns: normalizeName("ns."+lookupName) + ".",
|
|
Mbox: "hostmaster." + lookupName + ".",
|
|
Serial: uint32(time.Now().Unix()),
|
|
Refresh: 86400,
|
|
Retry: 3600,
|
|
Expire: 3600,
|
|
Minttl: 300,
|
|
})
|
|
break
|
|
}
|
|
if !found {
|
|
goto noRecord
|
|
}
|
|
case dnsprotocol.TypeANY:
|
|
if found {
|
|
appendAnyAnswers(reply, question.Name, lookupName, record, handler.service.ZoneApex())
|
|
} else if normalizeName(lookupName) == handler.service.ZoneApex() && handler.service.ZoneApex() != "" {
|
|
appendAnyAnswers(reply, question.Name, lookupName, NameRecords{}, handler.service.ZoneApex())
|
|
} else {
|
|
goto noRecord
|
|
}
|
|
case dnsprotocol.TypeDS:
|
|
if !found {
|
|
goto noRecord
|
|
}
|
|
appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeDS, record.DS)
|
|
case dnsprotocol.TypeDNSKEY:
|
|
if !found {
|
|
goto noRecord
|
|
}
|
|
appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeDNSKEY, record.DNSKEY)
|
|
case dnsprotocol.TypeRRSIG:
|
|
if !found {
|
|
goto noRecord
|
|
}
|
|
appendDNSSECResourceRecords(reply, question.Name, dnsprotocol.TypeRRSIG, record.RRSIG)
|
|
default:
|
|
reply.SetRcode(request, dnsprotocol.RcodeNotImplemented)
|
|
_ = responseWriter.WriteMsg(reply)
|
|
return
|
|
}
|
|
|
|
_ = responseWriter.WriteMsg(reply)
|
|
return
|
|
|
|
noRecord:
|
|
reply.SetRcode(request, dnsprotocol.RcodeNameError)
|
|
_ = responseWriter.WriteMsg(reply)
|
|
}
|
|
|
|
// parsePTRIP converts a reverse-lookup name into the textual IP address it
// stands for.
//
// Two forms are accepted, case-insensitively and with an optional trailing dot:
//   - "d.c.b.a.in-addr.arpa": exactly four non-empty labels that form a valid
//     IPv4 address once reversed.
//   - 32 single-character labels followed by ".ip6.arpa" that form a valid
//     IPv6 address once reversed and grouped in fours.
//
// The address is returned in net.IP.String form. The second result is false
// for any other name.
func parsePTRIP(name string) (string, bool) {
	const (
		v4Suffix = ".in-addr.arpa"
		v6Suffix = ".ip6.arpa"
	)

	host := strings.TrimSuffix(strings.ToLower(name), ".")

	switch {
	case strings.HasSuffix(host, v4Suffix):
		prefix := strings.TrimSuffix(strings.TrimSuffix(host, v4Suffix), ".")
		octets := strings.Split(prefix, ".")
		if len(octets) != 4 {
			return "", false
		}
		for _, octet := range octets {
			if octet == "" {
				return "", false
			}
		}

		// PTR names list the octets least-significant first.
		dotted := octets[3] + "." + octets[2] + "." + octets[1] + "." + octets[0]
		addr := net.ParseIP(dotted)
		if addr == nil || addr.To4() == nil {
			return "", false
		}
		return addr.String(), true

	case strings.HasSuffix(host, v6Suffix):
		prefix := strings.TrimSuffix(strings.TrimSuffix(host, v6Suffix), ".")
		nibbles := strings.Split(prefix, ".")
		if len(nibbles) != 32 {
			return "", false
		}

		// Walk the nibbles back to front, inserting ':' after every fourth.
		var text strings.Builder
		for i := len(nibbles) - 1; i >= 0; i-- {
			if len(nibbles[i]) != 1 {
				return "", false
			}
			text.WriteString(nibbles[i])
			if i%4 == 0 && i != 0 {
				text.WriteByte(':')
			}
		}

		addr := net.ParseIP(text.String())
		if addr == nil {
			return "", false
		}
		return addr.String(), true
	}

	return "", false
}
|
|
|
|
func appendAnyAnswers(reply *dnsprotocol.Msg, questionName string, lookupName string, record NameRecords, zoneApex string) {
|
|
for _, value := range record.A {
|
|
parsedIP := net.ParseIP(value)
|
|
if parsedIP == nil || parsedIP.To4() == nil {
|
|
continue
|
|
}
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.A{
|
|
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeA, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
A: parsedIP.To4(),
|
|
})
|
|
}
|
|
|
|
for _, value := range record.AAAA {
|
|
parsedIP := net.ParseIP(value)
|
|
if parsedIP == nil || parsedIP.To4() != nil {
|
|
continue
|
|
}
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.AAAA{
|
|
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeAAAA, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
AAAA: parsedIP.To16(),
|
|
})
|
|
}
|
|
|
|
for _, value := range record.TXT {
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.TXT{
|
|
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeTXT, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Txt: []string{value},
|
|
})
|
|
}
|
|
|
|
if len(record.NS) > 0 {
|
|
for _, value := range record.NS {
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.NS{
|
|
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Ns: normalizeName(value) + ".",
|
|
})
|
|
}
|
|
}
|
|
|
|
for _, value := range record.DNSKEY {
|
|
appendDNSSECResourceRecords(reply, questionName, dnsprotocol.TypeDNSKEY, []string{value})
|
|
}
|
|
|
|
for _, value := range record.DS {
|
|
appendDNSSECResourceRecords(reply, questionName, dnsprotocol.TypeDS, []string{value})
|
|
}
|
|
|
|
for _, value := range record.RRSIG {
|
|
appendDNSSECResourceRecords(reply, questionName, dnsprotocol.TypeRRSIG, []string{value})
|
|
}
|
|
|
|
if normalizeName(lookupName) == normalizeName(zoneApex) && normalizeName(zoneApex) != "" {
|
|
if len(record.NS) == 0 {
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.NS{
|
|
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeNS, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Ns: normalizeName("ns."+lookupName) + ".",
|
|
})
|
|
}
|
|
reply.Answer = append(reply.Answer, &dnsprotocol.SOA{
|
|
Hdr: dnsprotocol.RR_Header{Name: questionName, Rrtype: dnsprotocol.TypeSOA, Class: dnsprotocol.ClassINET, Ttl: defaultDNSTTL},
|
|
Ns: normalizeName("ns."+lookupName) + ".",
|
|
Mbox: "hostmaster." + lookupName + ".",
|
|
Serial: uint32(time.Now().Unix()),
|
|
Refresh: 86400,
|
|
Retry: 3600,
|
|
Expire: 3600,
|
|
Minttl: 300,
|
|
})
|
|
}
|
|
}
|
|
|
|
func appendDNSSECResourceRecords(reply *dnsprotocol.Msg, questionName string, recordType uint16, values []string) {
|
|
for _, value := range values {
|
|
rr, err := parseDNSSECResourceRecord(questionName, recordType, value)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
reply.Answer = append(reply.Answer, rr)
|
|
}
|
|
}
|
|
|
|
func parseDNSSECResourceRecord(questionName string, recordType uint16, raw string) (dnsprotocol.RR, error) {
|
|
trimmed := strings.TrimSpace(raw)
|
|
if trimmed == "" {
|
|
return nil, fmt.Errorf("empty dnssec resource value")
|
|
}
|
|
|
|
fallback := fmt.Sprintf("%s %d IN %s %s", questionName, defaultDNSTTL, dnsprotocol.TypeToString[recordType], trimmed)
|
|
rr, err := dnsprotocol.NewRR(fallback)
|
|
if err == nil {
|
|
header := rr.Header()
|
|
if header.Rrtype != recordType {
|
|
return nil, fmt.Errorf("dnsprotocol record type mismatch")
|
|
}
|
|
header.Name = questionName
|
|
header.Class = dnsprotocol.ClassINET
|
|
header.Ttl = defaultDNSTTL
|
|
return rr, nil
|
|
}
|
|
|
|
if rr, err := dnsprotocol.NewRR(trimmed); err == nil {
|
|
header := rr.Header()
|
|
if header.Rrtype != recordType {
|
|
return nil, fmt.Errorf("dnsprotocol record type mismatch")
|
|
}
|
|
header.Name = questionName
|
|
header.Class = dnsprotocol.ClassINET
|
|
header.Ttl = defaultDNSTTL
|
|
return rr, nil
|
|
}
|
|
|
|
return nil, err
|
|
}
|