diff --git a/serve.go b/serve.go index d46cb61..9d8670d 100644 --- a/serve.go +++ b/serve.go @@ -19,6 +19,45 @@ type DNSServer struct { tcpServer *dnsprotocol.Server } +// ServiceRuntime owns the DNS and HTTP listeners created by ServeAll. +type ServiceRuntime struct { + DNS *DNSServer + HTTP *HTTPServer +} + +func (runtime *ServiceRuntime) DNSAddress() string { + if runtime == nil || runtime.DNS == nil { + return "" + } + return runtime.DNS.Address() +} + +func (runtime *ServiceRuntime) HTTPAddress() string { + if runtime == nil || runtime.HTTP == nil { + return "" + } + return runtime.HTTP.Address() +} + +func (runtime *ServiceRuntime) Close() error { + if runtime == nil { + return nil + } + + var firstError error + if runtime.DNS != nil { + if err := runtime.DNS.Close(); err != nil && firstError == nil { + firstError = err + } + } + if runtime.HTTP != nil { + if err := runtime.HTTP.Close(); err != nil && firstError == nil { + firstError = err + } + } + return firstError +} + func (server *DNSServer) Address() string { if server.udpListener == nil { return "" @@ -86,6 +125,27 @@ func (service *Service) Serve(bind string, port int) (*DNSServer, error) { return run, nil } +// ServeAll starts the DNS endpoint and the HTTP health endpoint together. +// +// runtime, err := service.ServeAll("127.0.0.1", 53, 5554) +func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*ServiceRuntime, error) { + dnsServer, err := service.Serve(bind, dnsPort) + if err != nil { + return nil, err + } + + httpServer, err := service.ServeHTTPHealth(bind, httpPort) + if err != nil { + _ = dnsServer.Close() + return nil, err + } + + return &ServiceRuntime{ + DNS: dnsServer, + HTTP: httpServer, + }, nil +} + type dnsRequestHandler struct { service *Service } diff --git a/service_test.go b/service_test.go index 0ad844b..6d3d6ae 100644 --- a/service_test.go +++ b/service_test.go @@ -346,6 +346,54 @@ func TestServiceServeHTTPHealthReturnsJSON(t *testing.T) { } } +func TestServiceServeAllStartsDNSAndHTTPTogether(t *testing.T) { + service := NewService(ServiceOptions{ + Records: map[string]NameRecords{ + "gateway.charon.lthn": { + A: []string{"10.10.10.10"}, + }, + }, + }) + + runtime, err := service.ServeAll("127.0.0.1", 0, 0) + if err != nil { + t.Fatalf("expected combined runtime to start: %v", err) + } + defer func() { + _ = runtime.Close() + }() + + if runtime.DNSAddress() == "" { + t.Fatal("expected DNS address from combined runtime") + } + if runtime.HTTPAddress() == "" { + t.Fatal("expected HTTP address from combined runtime") + } + + response, err := http.Get("http://" + runtime.HTTPAddress() + "/health") + if err != nil { + t.Fatalf("expected combined HTTP health endpoint to respond: %v", err) + } + defer func() { + _ = response.Body.Close() + }() + + if response.StatusCode != http.StatusOK { + t.Fatalf("unexpected combined health status: %d", response.StatusCode) + } + + client := dnsprotocol.Client{} + request := new(dnsprotocol.Msg) + request.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeA) + dnsResponse := exchangeWithRetry(t, client, request, runtime.DNSAddress()) + if dnsResponse.Rcode != dnsprotocol.RcodeSuccess { + t.Fatalf("unexpected combined DNS rcode: %d", dnsResponse.Rcode) + } + if len(dnsResponse.Answer) != 1 { + t.Fatalf("expected one DNS answer from combined runtime, got %d", len(dnsResponse.Answer)) + } +} + func TestServiceDiscoverReplacesRecordsFromDiscoverer(t *testing.T) { records := []map[string]NameRecords{ {