From 5fd82dd34229125943b07ea7fd57ac7ff9f4f354 Mon Sep 17 00:00:00 2001 From: Virgil Date: Sat, 4 Apr 2026 00:04:47 +0000 Subject: [PATCH] feat(dns): add nil-safe service method guards Co-Authored-By: Virgil --- serve.go | 43 +++++++++++++++++++++++++++++++++++++++++-- service.go | 35 +++++++++++++++++++++++++++++++++++ service_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 2 deletions(-) diff --git a/serve.go b/serve.go index be60178..235be1f 100644 --- a/serve.go +++ b/serve.go @@ -119,7 +119,10 @@ func (server *DNSServer) Close() error { // port := service.ResolveDNSPort() // server, err := service.Serve("127.0.0.1", port) func (service *Service) ResolveDNSPort() int { - if service == nil || service.dnsPort <= 0 { + if service == nil { + return DefaultDNSPort + } + if service.dnsPort <= 0 { return DefaultDNSPort } return service.dnsPort @@ -130,6 +133,9 @@ func (service *Service) ResolveDNSPort() int { // port := service.DNSListenPort() // server, err := service.Serve("127.0.0.1", port) func (service *Service) DNSListenPort() int { + if service == nil { + return DefaultDNSPort + } return service.ResolveDNSPort() } @@ -138,16 +144,25 @@ func (service *Service) DNSListenPort() int { // port := service.DNSPort() // server, err := service.Serve("127.0.0.1", port) func (service *Service) DNSPort() int { + if service == nil { + return DefaultDNSPort + } return service.ResolveDNSPort() } // resolveDNSListenPort keeps internal callers aligned with explicit naming. func (service *Service) resolveDNSListenPort() int { + if service == nil { + return DefaultDNSPort + } return service.DNSListenPort() } // resolveServePort is a legacy compatibility helper. func (service *Service) resolveServePort() int { + if service == nil { + return DefaultDNSPort + } return service.ResolveDNSPort() } @@ -156,7 +171,10 @@ func (service *Service) resolveServePort() int { // port := service.ResolveHTTPPort() // healthServer, err := service.ServeHTTPHealth("127.0.0.1", port) func (service *Service) ResolveHTTPPort() int { - if service == nil || service.httpPort <= 0 { + if service == nil { + return DefaultHTTPPort + } + if service.httpPort <= 0 { return DefaultHTTPPort } return service.httpPort @@ -167,6 +185,9 @@ func (service *Service) ResolveHTTPPort() int { // port := service.HTTPListenPort() // server, err := service.ServeHTTPHealth("127.0.0.1", port) func (service *Service) HTTPListenPort() int { + if service == nil { + return DefaultHTTPPort + } return service.ResolveHTTPPort() } @@ -175,16 +196,25 @@ func (service *Service) HTTPListenPort() int { // port := service.HTTPPort() // healthServer, err := service.ServeHTTPHealth("127.0.0.1", port) func (service *Service) HTTPPort() int { + if service == nil { + return DefaultHTTPPort + } return service.ResolveHTTPPort() } // resolveHTTPListenPort keeps internal callers aligned with explicit naming. func (service *Service) resolveHTTPListenPort() int { + if service == nil { + return DefaultHTTPPort + } return service.HTTPListenPort() } // resolveHTTPPort is a legacy compatibility helper. func (service *Service) resolveHTTPPort() int { + if service == nil { + return DefaultHTTPPort + } return service.ResolveHTTPPort() } @@ -195,6 +225,9 @@ func (service *Service) resolveHTTPPort() int { // lookup := new(dnsprotocol.Msg) // lookup.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeA) func (service *Service) Serve(bind string, port int) (*DNSServer, error) { + if service == nil { + return nil, fmt.Errorf("service is required") + } if bind == "" { bind = "127.0.0.1" } @@ -243,6 +276,9 @@ func (service *Service) Serve(bind string, port int) (*DNSServer, error) { // defer func() { _ = runtime.Close() }() // fmt.Println("dns:", runtime.DNSAddress(), "health:", runtime.HealthAddress()) func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*ServiceRuntime, error) { + if service == nil { + return nil, fmt.Errorf("service is required") + } if dnsPort <= 0 { dnsPort = service.resolveDNSListenPort() } @@ -276,6 +312,9 @@ func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*Servi // }) // runtime, err := service.ServeConfigured("127.0.0.1") func (service *Service) ServeConfigured(bind string) (*ServiceRuntime, error) { + if service == nil { + return nil, fmt.Errorf("service is required") + } return service.ServeAll(bind, service.dnsPort, service.httpPort) } diff --git a/service.go b/service.go index bc159fb..7443be7 100644 --- a/service.go +++ b/service.go @@ -783,6 +783,9 @@ func (service *Service) RemoveRecord(name string) { // // result, ok := service.Resolve("gateway.charon.lthn") func (service *Service) Resolve(name string) (ResolveAllResult, bool) { + if service == nil { + return ResolveAllResult{}, false + } record, ok := service.findRecord(name) if !ok { return ResolveAllResult{}, false @@ -794,6 +797,9 @@ func (service *Service) Resolve(name string) (ResolveAllResult, bool) { // // result, ok, usedWildcard := service.ResolveWithWildcardMatch("node.charon.lthn") func (service *Service) ResolveWithWildcardMatch(name string) (ResolveAllResult, bool, bool) { + if service == nil { + return ResolveAllResult{}, false, false + } record, ok, usedWildcard := service.findRecordWithMatch(name) if !ok { return ResolveAllResult{}, false, false @@ -805,6 +811,9 @@ func (service *Service) ResolveWithWildcardMatch(name string) (ResolveAllResult, // // result, found, usedWildcard := service.ResolveWithMatch("node.charon.lthn") func (service *Service) ResolveWithMatch(name string) (ResolveAllResult, bool, bool) { + if service == nil { + return ResolveAllResult{}, false, false + } record, ok, usedWildcard := service.findRecordWithMatch(name) if !ok { return ResolveAllResult{}, false, false @@ -816,6 +825,9 @@ func (service *Service) ResolveWithMatch(name string) (ResolveAllResult, bool, b // // txt, ok := service.ResolveTXT("gateway.charon.lthn") func (service *Service) ResolveTXT(name string) ([]string, bool) { + if service == nil { + return nil, false + } result, ok := service.ResolveTXTRecords(name) if !ok { return nil, false @@ -827,6 +839,9 @@ func (service *Service) ResolveTXT(name string) ([]string, bool) { // // result, ok := service.ResolveTXTRecords("gateway.charon.lthn") func (service *Service) ResolveTXTRecords(name string) (ResolveTXTResult, bool) { + if service == nil { + return ResolveTXTResult{}, false + } record, ok := service.findRecord(name) if !ok { return ResolveTXTResult{}, false @@ -909,6 +924,9 @@ func (service *Service) refreshDerivedStateLocked() { // // addresses, ok := service.ResolveAddress("gateway.charon.lthn") func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) { + if service == nil { + return ResolveAddressResult{}, false + } record, ok := service.findRecord(name) if !ok { return ResolveAddressResult{}, false @@ -922,6 +940,9 @@ func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) // // names, ok := service.ResolveReverse("10.10.10.10") func (service *Service) ResolveReverse(ip string) ([]string, bool) { + if service == nil { + return nil, false + } service.pruneExpiredRecords() service.mu.RLock() @@ -943,6 +964,9 @@ func (service *Service) ResolveReverse(ip string) ([]string, bool) { // // Missing names still return empty arrays so the action payload stays stable. func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) { + if service == nil { + return ResolveAllResult{}, false + } record, ok := service.findRecord(name) if !ok { if normalizeName(name) == service.ZoneApex() && service.ZoneApex() != "" { @@ -972,6 +996,11 @@ func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) { // health := service.Health() // fmt.Println(health.Status, health.NamesCached, health.TreeRoot) func (service *Service) Health() HealthResult { + if service == nil { + return HealthResult{ + Status: "not_ready", + } + } service.pruneExpiredRecords() service.mu.RLock() @@ -994,6 +1023,9 @@ func (service *Service) Health() HealthResult { // apex := service.ZoneApex() // // "charon.lthn" func (service *Service) ZoneApex() string { + if service == nil { + return "" + } service.pruneExpiredRecords() service.mu.RLock() @@ -1005,6 +1037,9 @@ func (service *Service) ZoneApex() string { // // result, ok := service.ResolveReverseNames("10.10.10.10") func (service *Service) ResolveReverseNames(ip string) (ReverseLookupResult, bool) { + if service == nil { + return ReverseLookupResult{}, false + } names, ok := service.ResolveReverse(ip) if !ok { return ReverseLookupResult{}, false diff --git a/service_test.go b/service_test.go index cf51419..818e468 100644 --- a/service_test.go +++ b/service_test.go @@ -3436,6 +3436,45 @@ func TestIntActionValueAcceptsWholeFloat(t *testing.T) { } } +func TestServiceMethodsHandleNilReceiverWithoutPanicking(t *testing.T) { + var service *Service + + if _, ok := service.Resolve("gateway.charon.lthn"); ok { + t.Fatal("expected nil service Resolve to return not found") + } + if _, _, ok := service.ResolveWithMatch("gateway.charon.lthn"); ok { + t.Fatal("expected nil service ResolveWithMatch to return not found") + } + if _, ok := service.ResolveReverse("10.10.10.10"); ok { + t.Fatal("expected nil service ResolveReverse to return not found") + } + if _, ok := service.ResolveReverseNames("10.10.10.10"); ok { + t.Fatal("expected nil service ResolveReverseNames to return not found") + } + if got := service.ResolveDNSPort(); got != DefaultDNSPort { + t.Fatalf("expected default DNS port from nil service, got %d", got) + } + if got := service.ResolveHTTPPort(); got != DefaultHTTPPort { + t.Fatalf("expected default HTTP port from nil service, got %d", got) + } + if got := service.Health().Status; got != "not_ready" { + t.Fatalf("expected nil service health status \"not_ready\", got %q", got) + } +} + +func TestServiceServeReturnsErrorOnNilReceiver(t *testing.T) { + var service *Service + if _, err := service.Serve("127.0.0.1", 0); err == nil { + t.Fatal("expected Serve to fail for nil service receiver") + } + if _, err := service.ServeAll("127.0.0.1", 0, 0); err == nil { + t.Fatal("expected ServeAll to fail for nil service receiver") + } + if _, err := service.ServeConfigured("127.0.0.1"); err == nil { + t.Fatal("expected ServeConfigured to fail for nil service receiver") + } +} + type actionRecorder struct { names []string handlers map[string]func(map[string]any) (any, bool, error)