diff --git a/serve.go b/serve.go index eb73bdf..7afa025 100644 --- a/serve.go +++ b/serve.go @@ -142,6 +142,13 @@ func (service *Service) Serve(bind string, port int) (*DNSServer, error) { // defer func() { _ = runtime.Close() }() // fmt.Println("dns:", runtime.DNSAddress(), "health:", runtime.HTTPAddress()) func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*ServiceRuntime, error) { + if dnsPort == 0 { + dnsPort = service.dnsPort + } + if httpPort == 0 { + httpPort = service.httpPort + } + dnsServer, err := service.Serve(bind, dnsPort) if err != nil { return nil, err @@ -159,6 +166,17 @@ func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*Servi }, nil } +// ServeConfigured starts DNS and health using the ports stored on the service. +// +// service := dns.NewService(dns.ServiceOptions{ +// DNSPort: 1053, +// HTTPPort: 5554, +// }) +// runtime, err := service.ServeConfigured("127.0.0.1") +func (service *Service) ServeConfigured(bind string) (*ServiceRuntime, error) { + return service.ServeAll(bind, service.dnsPort, service.httpPort) +} + type dnsRequestHandler struct { service *Service } diff --git a/service.go b/service.go index 3067125..4175549 100644 --- a/service.go +++ b/service.go @@ -69,6 +69,8 @@ type Service struct { reverseIndex *cache.Cache treeRoot string zoneApex string + dnsPort int + httpPort int lastAliasFingerprint string hsdClient *HSDClient mainchainAliasClient *MainchainAliasClient @@ -96,6 +98,8 @@ type ServiceOptions struct { Records map[string]NameRecords RecordDiscoverer func() (map[string]NameRecords, error) FallbackRecordDiscoverer func() (map[string]NameRecords, error) + DNSPort int + HTTPPort int HSDURL string HSDUsername string HSDPassword string @@ -162,6 +166,8 @@ func NewService(options ServiceOptions) *Service { reverseIndex: buildReverseIndexCache(cached), treeRoot: treeRoot, zoneApex: computeZoneApex(cached), + dnsPort: options.DNSPort, + httpPort: options.HTTPPort, hsdClient: hsdClient, mainchainAliasClient: mainchainClient, chainAliasActionCaller: options.ChainAliasActionCaller, diff --git a/service_test.go b/service_test.go index 3ddc847..9d2851f 100644 --- a/service_test.go +++ b/service_test.go @@ -5,8 +5,10 @@ import ( "encoding/json" "errors" "io" + "net" "net/http" "net/http/httptest" + "strconv" "strings" "sync/atomic" "testing" @@ -33,6 +35,23 @@ func exchangeWithRetry(t *testing.T, client dnsprotocol.Client, request *dnsprot return nil } +func pickFreeTCPPort(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("expected free TCP port: %v", err) + } + defer func() { _ = listener.Close() }() + + tcpAddress, ok := listener.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("expected TCP listener address, got %T", listener.Addr()) + } + + return tcpAddress.Port +} + func TestServiceResolveUsesExactNameBeforeWildcard(t *testing.T) { service := NewService(ServiceOptions{ Records: map[string]NameRecords{ @@ -441,6 +460,45 @@ func TestServiceServeAllStartsDNSAndHTTPTogether(t *testing.T) { } } +func TestServiceServeConfiguredUsesPortsFromServiceOptions(t *testing.T) { + dnsPort := pickFreeTCPPort(t) + httpPort := pickFreeTCPPort(t) + + service := NewService(ServiceOptions{ + DNSPort: dnsPort, + HTTPPort: httpPort, + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + }, + }, + }) + + runtime, err := service.ServeConfigured("127.0.0.1") + if err != nil { + t.Fatalf("expected configured runtime to start: %v", err) + } + defer func() { + _ = runtime.Close() + }() + + _, dnsActualPort, err := net.SplitHostPort(runtime.DNSAddress()) + if err != nil { + t.Fatalf("expected DNS address to parse: %v", err) + } + if dnsActualPort != strconv.Itoa(dnsPort) { + t.Fatalf("expected configured DNS port %d, got %s", dnsPort, dnsActualPort) + } + + _, httpActualPort, err := net.SplitHostPort(runtime.HTTPAddress()) + if err != nil { + t.Fatalf("expected HTTP address to parse: %v", err) + } + if httpActualPort != strconv.Itoa(httpPort) { + t.Fatalf("expected configured HTTP port %d, got %s", httpPort, httpActualPort) + } +} + func TestServiceDiscoverReplacesRecordsFromDiscoverer(t *testing.T) { records := []map[string]NameRecords{ {