feat(dns): add nil-safe service method guards

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-04 00:04:47 +00:00
parent b6f9d50393
commit 5fd82dd342
3 changed files with 115 additions and 2 deletions

View file

@ -119,7 +119,10 @@ func (server *DNSServer) Close() error {
// port := service.ResolveDNSPort()
// server, err := service.Serve("127.0.0.1", port)
func (service *Service) ResolveDNSPort() int {
if service == nil || service.dnsPort <= 0 {
if service == nil {
return DefaultDNSPort
}
if service.dnsPort <= 0 {
return DefaultDNSPort
}
return service.dnsPort
@ -130,6 +133,9 @@ func (service *Service) ResolveDNSPort() int {
// port := service.DNSListenPort()
// server, err := service.Serve("127.0.0.1", port)
func (service *Service) DNSListenPort() int {
if service == nil {
return DefaultDNSPort
}
return service.ResolveDNSPort()
}
@ -138,16 +144,25 @@ func (service *Service) DNSListenPort() int {
// port := service.DNSPort()
// server, err := service.Serve("127.0.0.1", port)
func (service *Service) DNSPort() int {
if service == nil {
return DefaultDNSPort
}
return service.ResolveDNSPort()
}
// resolveDNSListenPort keeps internal callers aligned with explicit naming.
func (service *Service) resolveDNSListenPort() int {
if service == nil {
return DefaultDNSPort
}
return service.DNSListenPort()
}
// resolveServePort is a legacy compatibility helper.
func (service *Service) resolveServePort() int {
if service == nil {
return DefaultDNSPort
}
return service.ResolveDNSPort()
}
@ -156,7 +171,10 @@ func (service *Service) resolveServePort() int {
// port := service.ResolveHTTPPort()
// healthServer, err := service.ServeHTTPHealth("127.0.0.1", port)
func (service *Service) ResolveHTTPPort() int {
if service == nil || service.httpPort <= 0 {
if service == nil {
return DefaultHTTPPort
}
if service.httpPort <= 0 {
return DefaultHTTPPort
}
return service.httpPort
@ -167,6 +185,9 @@ func (service *Service) ResolveHTTPPort() int {
// port := service.HTTPListenPort()
// server, err := service.ServeHTTPHealth("127.0.0.1", port)
func (service *Service) HTTPListenPort() int {
if service == nil {
return DefaultHTTPPort
}
return service.ResolveHTTPPort()
}
@ -175,16 +196,25 @@ func (service *Service) HTTPListenPort() int {
// port := service.HTTPPort()
// healthServer, err := service.ServeHTTPHealth("127.0.0.1", port)
func (service *Service) HTTPPort() int {
if service == nil {
return DefaultHTTPPort
}
return service.ResolveHTTPPort()
}
// resolveHTTPListenPort keeps internal callers aligned with explicit naming.
func (service *Service) resolveHTTPListenPort() int {
if service == nil {
return DefaultHTTPPort
}
return service.HTTPListenPort()
}
// resolveHTTPPort is a legacy compatibility helper.
func (service *Service) resolveHTTPPort() int {
if service == nil {
return DefaultHTTPPort
}
return service.ResolveHTTPPort()
}
@ -195,6 +225,9 @@ func (service *Service) resolveHTTPPort() int {
// lookup := new(dnsprotocol.Msg)
// lookup.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeA)
func (service *Service) Serve(bind string, port int) (*DNSServer, error) {
if service == nil {
return nil, fmt.Errorf("service is required")
}
if bind == "" {
bind = "127.0.0.1"
}
@ -243,6 +276,9 @@ func (service *Service) Serve(bind string, port int) (*DNSServer, error) {
// defer func() { _ = runtime.Close() }()
// fmt.Println("dns:", runtime.DNSAddress(), "health:", runtime.HealthAddress())
func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*ServiceRuntime, error) {
if service == nil {
return nil, fmt.Errorf("service is required")
}
if dnsPort <= 0 {
dnsPort = service.resolveDNSListenPort()
}
@ -276,6 +312,9 @@ func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*Servi
// })
// runtime, err := service.ServeConfigured("127.0.0.1")
func (service *Service) ServeConfigured(bind string) (*ServiceRuntime, error) {
if service == nil {
return nil, fmt.Errorf("service is required")
}
return service.ServeAll(bind, service.dnsPort, service.httpPort)
}

View file

@ -783,6 +783,9 @@ func (service *Service) RemoveRecord(name string) {
//
// result, ok := service.Resolve("gateway.charon.lthn")
func (service *Service) Resolve(name string) (ResolveAllResult, bool) {
if service == nil {
return ResolveAllResult{}, false
}
record, ok := service.findRecord(name)
if !ok {
return ResolveAllResult{}, false
@ -794,6 +797,9 @@ func (service *Service) Resolve(name string) (ResolveAllResult, bool) {
//
// result, ok, usedWildcard := service.ResolveWithWildcardMatch("node.charon.lthn")
func (service *Service) ResolveWithWildcardMatch(name string) (ResolveAllResult, bool, bool) {
if service == nil {
return ResolveAllResult{}, false, false
}
record, ok, usedWildcard := service.findRecordWithMatch(name)
if !ok {
return ResolveAllResult{}, false, false
@ -805,6 +811,9 @@ func (service *Service) ResolveWithWildcardMatch(name string) (ResolveAllResult,
//
// result, found, usedWildcard := service.ResolveWithMatch("node.charon.lthn")
func (service *Service) ResolveWithMatch(name string) (ResolveAllResult, bool, bool) {
if service == nil {
return ResolveAllResult{}, false, false
}
record, ok, usedWildcard := service.findRecordWithMatch(name)
if !ok {
return ResolveAllResult{}, false, false
@ -816,6 +825,9 @@ func (service *Service) ResolveWithMatch(name string) (ResolveAllResult, bool, b
//
// txt, ok := service.ResolveTXT("gateway.charon.lthn")
func (service *Service) ResolveTXT(name string) ([]string, bool) {
if service == nil {
return nil, false
}
result, ok := service.ResolveTXTRecords(name)
if !ok {
return nil, false
@ -827,6 +839,9 @@ func (service *Service) ResolveTXT(name string) ([]string, bool) {
//
// result, ok := service.ResolveTXTRecords("gateway.charon.lthn")
func (service *Service) ResolveTXTRecords(name string) (ResolveTXTResult, bool) {
if service == nil {
return ResolveTXTResult{}, false
}
record, ok := service.findRecord(name)
if !ok {
return ResolveTXTResult{}, false
@ -909,6 +924,9 @@ func (service *Service) refreshDerivedStateLocked() {
//
// addresses, ok := service.ResolveAddress("gateway.charon.lthn")
func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) {
if service == nil {
return ResolveAddressResult{}, false
}
record, ok := service.findRecord(name)
if !ok {
return ResolveAddressResult{}, false
@ -922,6 +940,9 @@ func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool)
//
// names, ok := service.ResolveReverse("10.10.10.10")
func (service *Service) ResolveReverse(ip string) ([]string, bool) {
if service == nil {
return nil, false
}
service.pruneExpiredRecords()
service.mu.RLock()
@ -943,6 +964,9 @@ func (service *Service) ResolveReverse(ip string) ([]string, bool) {
//
// Missing names still return empty arrays so the action payload stays stable.
func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) {
if service == nil {
return ResolveAllResult{}, false
}
record, ok := service.findRecord(name)
if !ok {
if normalizeName(name) == service.ZoneApex() && service.ZoneApex() != "" {
@ -972,6 +996,11 @@ func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) {
// health := service.Health()
// fmt.Println(health.Status, health.NamesCached, health.TreeRoot)
func (service *Service) Health() HealthResult {
if service == nil {
return HealthResult{
Status: "not_ready",
}
}
service.pruneExpiredRecords()
service.mu.RLock()
@ -994,6 +1023,9 @@ func (service *Service) Health() HealthResult {
// apex := service.ZoneApex()
// // "charon.lthn"
func (service *Service) ZoneApex() string {
if service == nil {
return ""
}
service.pruneExpiredRecords()
service.mu.RLock()
@ -1005,6 +1037,9 @@ func (service *Service) ZoneApex() string {
//
// result, ok := service.ResolveReverseNames("10.10.10.10")
func (service *Service) ResolveReverseNames(ip string) (ReverseLookupResult, bool) {
if service == nil {
return ReverseLookupResult{}, false
}
names, ok := service.ResolveReverse(ip)
if !ok {
return ReverseLookupResult{}, false

View file

@ -3436,6 +3436,45 @@ func TestIntActionValueAcceptsWholeFloat(t *testing.T) {
}
}
func TestServiceMethodsHandleNilReceiverWithoutPanicking(t *testing.T) {
var service *Service
if _, ok := service.Resolve("gateway.charon.lthn"); ok {
t.Fatal("expected nil service Resolve to return not found")
}
if _, _, ok := service.ResolveWithMatch("gateway.charon.lthn"); ok {
t.Fatal("expected nil service ResolveWithMatch to return not found")
}
if _, ok := service.ResolveReverse("10.10.10.10"); ok {
t.Fatal("expected nil service ResolveReverse to return not found")
}
if _, ok := service.ResolveReverseNames("10.10.10.10"); ok {
t.Fatal("expected nil service ResolveReverseNames to return not found")
}
if got := service.ResolveDNSPort(); got != DefaultDNSPort {
t.Fatalf("expected default DNS port from nil service, got %d", got)
}
if got := service.ResolveHTTPPort(); got != DefaultHTTPPort {
t.Fatalf("expected default HTTP port from nil service, got %d", got)
}
if got := service.Health().Status; got != "not_ready" {
t.Fatalf("expected nil service health status \"not_ready\", got %q", got)
}
}
func TestServiceServeReturnsErrorOnNilReceiver(t *testing.T) {
var service *Service
if _, err := service.Serve("127.0.0.1", 0); err == nil {
t.Fatal("expected Serve to fail for nil service receiver")
}
if _, err := service.ServeAll("127.0.0.1", 0, 0); err == nil {
t.Fatal("expected ServeAll to fail for nil service receiver")
}
if _, err := service.ServeConfigured("127.0.0.1"); err == nil {
t.Fatal("expected ServeConfigured to fail for nil service receiver")
}
}
type actionRecorder struct {
names []string
handlers map[string]func(map[string]any) (any, bool, error)