diff --git a/http_server.go b/http_server.go index d9d8920..d4f17e7 100644 --- a/http_server.go +++ b/http_server.go @@ -22,13 +22,18 @@ type HealthServer struct { // HTTPServer is kept as a compatibility alias for HealthServer. type HTTPServer = HealthServer -func (server *HealthServer) Address() string { +func (server *HealthServer) HealthAddress() string { if server == nil || server.listener == nil { return "" } return server.listener.Addr().String() } +// Address is retained for compatibility with older call sites. +func (server *HealthServer) Address() string { + return server.HealthAddress() +} + func (server *HealthServer) Close() error { if server == nil { return nil @@ -54,7 +59,7 @@ func (server *HealthServer) Close() error { // // server, err := service.ServeHTTPHealth("127.0.0.1", 5554) // defer func() { _ = server.Close() }() -// resp, _ := http.Get("http://" + server.Address() + "/health") +// resp, _ := http.Get("http://" + server.HealthAddress() + "/health") func (service *Service) ServeHTTPHealth(bind string, port int) (*HealthServer, error) { if bind == "" { bind = "127.0.0.1" diff --git a/serve.go b/serve.go index c21d294..a85022e 100644 --- a/serve.go +++ b/serve.go @@ -15,7 +15,7 @@ const defaultDNSTTL = 300 // // srv, err := service.Serve("127.0.0.1", 53) // defer func() { _ = srv.Close() }() -// fmt.Println("dns at", srv.Address()) +// fmt.Println("dns at", srv.DNSAddress()) type DNSServer struct { udpListener net.PacketConn tcpListener net.Listener @@ -29,7 +29,9 @@ type DNSServer struct { // defer func() { _ = runtime.Close() }() // fmt.Println(runtime.DNSAddress(), runtime.HealthAddress()) type ServiceRuntime struct { - DNS *DNSServer + DNS *DNSServer + Health *HealthServer + // HTTP is retained for compatibility with older call sites. HTTP *HealthServer } @@ -37,14 +39,20 @@ func (runtime *ServiceRuntime) DNSAddress() string { if runtime == nil || runtime.DNS == nil { return "" } - return runtime.DNS.Address() + return runtime.DNS.DNSAddress() } func (runtime *ServiceRuntime) HealthAddress() string { - if runtime == nil || runtime.HTTP == nil { + if runtime == nil { return "" } - return runtime.HTTP.Address() + if runtime.Health != nil { + return runtime.Health.HealthAddress() + } + if runtime.HTTP != nil { + return runtime.HTTP.HealthAddress() + } + return "" } // HTTPAddress is retained for compatibility with older call sites. @@ -63,7 +71,12 @@ func (runtime *ServiceRuntime) Close() error { firstError = err } } - if runtime.HTTP != nil { + if runtime.Health != nil { + if err := runtime.Health.Close(); err != nil && firstError == nil { + firstError = err + } + } + if runtime.HTTP != nil && runtime.HTTP != runtime.Health { if err := runtime.HTTP.Close(); err != nil && firstError == nil { firstError = err } @@ -71,13 +84,18 @@ func (runtime *ServiceRuntime) Close() error { return firstError } -func (server *DNSServer) Address() string { +func (server *DNSServer) DNSAddress() string { if server.udpListener == nil { return "" } return server.udpListener.LocalAddr().String() } +// Address is retained for compatibility with older call sites. +func (server *DNSServer) Address() string { + return server.DNSAddress() +} + func (server *DNSServer) Close() error { if server.udpListener != nil { _ = server.udpListener.Close() @@ -166,8 +184,9 @@ func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*Servi } return &ServiceRuntime{ - DNS: dnsServer, - HTTP: httpServer, + DNS: dnsServer, + Health: httpServer, + HTTP: httpServer, }, nil } diff --git a/service_test.go b/service_test.go index ba55488..bbf0c8f 100644 --- a/service_test.go +++ b/service_test.go @@ -413,7 +413,14 @@ func TestServiceServeHTTPHealthReturnsJSON(t *testing.T) { _ = httpServer.Close() }() - response, err := http.Get("http://" + httpServer.Address() + "/health") + if httpServer.HealthAddress() == "" { + t.Fatal("expected health address from health server") + } + if httpServer.Address() != httpServer.HealthAddress() { + t.Fatalf("expected Address and HealthAddress to match, got %q and %q", httpServer.Address(), httpServer.HealthAddress()) + } + + response, err := http.Get("http://" + httpServer.HealthAddress() + "/health") if err != nil { t.Fatalf("expected health endpoint to respond: %v", err) } @@ -464,6 +471,9 @@ func TestServiceServeAllStartsDNSAndHTTPTogether(t *testing.T) { if runtime.HealthAddress() == "" { t.Fatal("expected health address from combined runtime") } + if runtime.DNS.Address() != runtime.DNSAddress() { + t.Fatalf("expected DNSAddress and Address to match, got %q and %q", runtime.DNS.DNSAddress(), runtime.DNS.Address()) + } if runtime.HTTPAddress() != runtime.HealthAddress() { t.Fatalf("expected HTTPAddress and HealthAddress to match, got %q and %q", runtime.HTTPAddress(), runtime.HealthAddress()) } @@ -2235,9 +2245,12 @@ func TestServiceHandleActionReverseHealthServeAndDiscover(t *testing.T) { if !ok { t.Fatalf("expected DNSServer payload, got %T", srvPayload) } - if dnsServer.Address() == "" { + if dnsServer.DNSAddress() == "" { t.Fatal("expected server address from serve action") } + if dnsServer.Address() != dnsServer.DNSAddress() { + t.Fatalf("expected Address and DNSAddress to match, got %q and %q", dnsServer.Address(), dnsServer.DNSAddress()) + } _ = dnsServer.Close() discoverPayload, ok, err := service.HandleAction(ActionDiscover, nil)