feat(dns): add nil-safe service method guards
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
b6f9d50393
commit
5fd82dd342
3 changed files with 115 additions and 2 deletions
43
serve.go
43
serve.go
|
|
@ -119,7 +119,10 @@ func (server *DNSServer) Close() error {
|
|||
// port := service.ResolveDNSPort()
|
||||
// server, err := service.Serve("127.0.0.1", port)
|
||||
func (service *Service) ResolveDNSPort() int {
|
||||
if service == nil || service.dnsPort <= 0 {
|
||||
if service == nil {
|
||||
return DefaultDNSPort
|
||||
}
|
||||
if service.dnsPort <= 0 {
|
||||
return DefaultDNSPort
|
||||
}
|
||||
return service.dnsPort
|
||||
|
|
@ -130,6 +133,9 @@ func (service *Service) ResolveDNSPort() int {
|
|||
// port := service.DNSListenPort()
|
||||
// server, err := service.Serve("127.0.0.1", port)
|
||||
func (service *Service) DNSListenPort() int {
|
||||
if service == nil {
|
||||
return DefaultDNSPort
|
||||
}
|
||||
return service.ResolveDNSPort()
|
||||
}
|
||||
|
||||
|
|
@ -138,16 +144,25 @@ func (service *Service) DNSListenPort() int {
|
|||
// port := service.DNSPort()
|
||||
// server, err := service.Serve("127.0.0.1", port)
|
||||
func (service *Service) DNSPort() int {
|
||||
if service == nil {
|
||||
return DefaultDNSPort
|
||||
}
|
||||
return service.ResolveDNSPort()
|
||||
}
|
||||
|
||||
// resolveDNSListenPort keeps internal callers aligned with explicit naming.
|
||||
func (service *Service) resolveDNSListenPort() int {
|
||||
if service == nil {
|
||||
return DefaultDNSPort
|
||||
}
|
||||
return service.DNSListenPort()
|
||||
}
|
||||
|
||||
// resolveServePort is a legacy compatibility helper.
|
||||
func (service *Service) resolveServePort() int {
|
||||
if service == nil {
|
||||
return DefaultDNSPort
|
||||
}
|
||||
return service.ResolveDNSPort()
|
||||
}
|
||||
|
||||
|
|
@ -156,7 +171,10 @@ func (service *Service) resolveServePort() int {
|
|||
// port := service.ResolveHTTPPort()
|
||||
// healthServer, err := service.ServeHTTPHealth("127.0.0.1", port)
|
||||
func (service *Service) ResolveHTTPPort() int {
|
||||
if service == nil || service.httpPort <= 0 {
|
||||
if service == nil {
|
||||
return DefaultHTTPPort
|
||||
}
|
||||
if service.httpPort <= 0 {
|
||||
return DefaultHTTPPort
|
||||
}
|
||||
return service.httpPort
|
||||
|
|
@ -167,6 +185,9 @@ func (service *Service) ResolveHTTPPort() int {
|
|||
// port := service.HTTPListenPort()
|
||||
// server, err := service.ServeHTTPHealth("127.0.0.1", port)
|
||||
func (service *Service) HTTPListenPort() int {
|
||||
if service == nil {
|
||||
return DefaultHTTPPort
|
||||
}
|
||||
return service.ResolveHTTPPort()
|
||||
}
|
||||
|
||||
|
|
@ -175,16 +196,25 @@ func (service *Service) HTTPListenPort() int {
|
|||
// port := service.HTTPPort()
|
||||
// healthServer, err := service.ServeHTTPHealth("127.0.0.1", port)
|
||||
func (service *Service) HTTPPort() int {
|
||||
if service == nil {
|
||||
return DefaultHTTPPort
|
||||
}
|
||||
return service.ResolveHTTPPort()
|
||||
}
|
||||
|
||||
// resolveHTTPListenPort keeps internal callers aligned with explicit naming.
|
||||
func (service *Service) resolveHTTPListenPort() int {
|
||||
if service == nil {
|
||||
return DefaultHTTPPort
|
||||
}
|
||||
return service.HTTPListenPort()
|
||||
}
|
||||
|
||||
// resolveHTTPPort is a legacy compatibility helper.
|
||||
func (service *Service) resolveHTTPPort() int {
|
||||
if service == nil {
|
||||
return DefaultHTTPPort
|
||||
}
|
||||
return service.ResolveHTTPPort()
|
||||
}
|
||||
|
||||
|
|
@ -195,6 +225,9 @@ func (service *Service) resolveHTTPPort() int {
|
|||
// lookup := new(dnsprotocol.Msg)
|
||||
// lookup.SetQuestion("gateway.charon.lthn.", dnsprotocol.TypeA)
|
||||
func (service *Service) Serve(bind string, port int) (*DNSServer, error) {
|
||||
if service == nil {
|
||||
return nil, fmt.Errorf("service is required")
|
||||
}
|
||||
if bind == "" {
|
||||
bind = "127.0.0.1"
|
||||
}
|
||||
|
|
@ -243,6 +276,9 @@ func (service *Service) Serve(bind string, port int) (*DNSServer, error) {
|
|||
// defer func() { _ = runtime.Close() }()
|
||||
// fmt.Println("dns:", runtime.DNSAddress(), "health:", runtime.HealthAddress())
|
||||
func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*ServiceRuntime, error) {
|
||||
if service == nil {
|
||||
return nil, fmt.Errorf("service is required")
|
||||
}
|
||||
if dnsPort <= 0 {
|
||||
dnsPort = service.resolveDNSListenPort()
|
||||
}
|
||||
|
|
@ -276,6 +312,9 @@ func (service *Service) ServeAll(bind string, dnsPort int, httpPort int) (*Servi
|
|||
// })
|
||||
// runtime, err := service.ServeConfigured("127.0.0.1")
|
||||
func (service *Service) ServeConfigured(bind string) (*ServiceRuntime, error) {
|
||||
if service == nil {
|
||||
return nil, fmt.Errorf("service is required")
|
||||
}
|
||||
return service.ServeAll(bind, service.dnsPort, service.httpPort)
|
||||
}
|
||||
|
||||
|
|
|
|||
35
service.go
35
service.go
|
|
@ -783,6 +783,9 @@ func (service *Service) RemoveRecord(name string) {
|
|||
//
|
||||
// result, ok := service.Resolve("gateway.charon.lthn")
|
||||
func (service *Service) Resolve(name string) (ResolveAllResult, bool) {
|
||||
if service == nil {
|
||||
return ResolveAllResult{}, false
|
||||
}
|
||||
record, ok := service.findRecord(name)
|
||||
if !ok {
|
||||
return ResolveAllResult{}, false
|
||||
|
|
@ -794,6 +797,9 @@ func (service *Service) Resolve(name string) (ResolveAllResult, bool) {
|
|||
//
|
||||
// result, ok, usedWildcard := service.ResolveWithWildcardMatch("node.charon.lthn")
|
||||
func (service *Service) ResolveWithWildcardMatch(name string) (ResolveAllResult, bool, bool) {
|
||||
if service == nil {
|
||||
return ResolveAllResult{}, false, false
|
||||
}
|
||||
record, ok, usedWildcard := service.findRecordWithMatch(name)
|
||||
if !ok {
|
||||
return ResolveAllResult{}, false, false
|
||||
|
|
@ -805,6 +811,9 @@ func (service *Service) ResolveWithWildcardMatch(name string) (ResolveAllResult,
|
|||
//
|
||||
// result, found, usedWildcard := service.ResolveWithMatch("node.charon.lthn")
|
||||
func (service *Service) ResolveWithMatch(name string) (ResolveAllResult, bool, bool) {
|
||||
if service == nil {
|
||||
return ResolveAllResult{}, false, false
|
||||
}
|
||||
record, ok, usedWildcard := service.findRecordWithMatch(name)
|
||||
if !ok {
|
||||
return ResolveAllResult{}, false, false
|
||||
|
|
@ -816,6 +825,9 @@ func (service *Service) ResolveWithMatch(name string) (ResolveAllResult, bool, b
|
|||
//
|
||||
// txt, ok := service.ResolveTXT("gateway.charon.lthn")
|
||||
func (service *Service) ResolveTXT(name string) ([]string, bool) {
|
||||
if service == nil {
|
||||
return nil, false
|
||||
}
|
||||
result, ok := service.ResolveTXTRecords(name)
|
||||
if !ok {
|
||||
return nil, false
|
||||
|
|
@ -827,6 +839,9 @@ func (service *Service) ResolveTXT(name string) ([]string, bool) {
|
|||
//
|
||||
// result, ok := service.ResolveTXTRecords("gateway.charon.lthn")
|
||||
func (service *Service) ResolveTXTRecords(name string) (ResolveTXTResult, bool) {
|
||||
if service == nil {
|
||||
return ResolveTXTResult{}, false
|
||||
}
|
||||
record, ok := service.findRecord(name)
|
||||
if !ok {
|
||||
return ResolveTXTResult{}, false
|
||||
|
|
@ -909,6 +924,9 @@ func (service *Service) refreshDerivedStateLocked() {
|
|||
//
|
||||
// addresses, ok := service.ResolveAddress("gateway.charon.lthn")
|
||||
func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) {
|
||||
if service == nil {
|
||||
return ResolveAddressResult{}, false
|
||||
}
|
||||
record, ok := service.findRecord(name)
|
||||
if !ok {
|
||||
return ResolveAddressResult{}, false
|
||||
|
|
@ -922,6 +940,9 @@ func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool)
|
|||
//
|
||||
// names, ok := service.ResolveReverse("10.10.10.10")
|
||||
func (service *Service) ResolveReverse(ip string) ([]string, bool) {
|
||||
if service == nil {
|
||||
return nil, false
|
||||
}
|
||||
service.pruneExpiredRecords()
|
||||
|
||||
service.mu.RLock()
|
||||
|
|
@ -943,6 +964,9 @@ func (service *Service) ResolveReverse(ip string) ([]string, bool) {
|
|||
//
|
||||
// Missing names still return empty arrays so the action payload stays stable.
|
||||
func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) {
|
||||
if service == nil {
|
||||
return ResolveAllResult{}, false
|
||||
}
|
||||
record, ok := service.findRecord(name)
|
||||
if !ok {
|
||||
if normalizeName(name) == service.ZoneApex() && service.ZoneApex() != "" {
|
||||
|
|
@ -972,6 +996,11 @@ func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) {
|
|||
// health := service.Health()
|
||||
// fmt.Println(health.Status, health.NamesCached, health.TreeRoot)
|
||||
func (service *Service) Health() HealthResult {
|
||||
if service == nil {
|
||||
return HealthResult{
|
||||
Status: "not_ready",
|
||||
}
|
||||
}
|
||||
service.pruneExpiredRecords()
|
||||
|
||||
service.mu.RLock()
|
||||
|
|
@ -994,6 +1023,9 @@ func (service *Service) Health() HealthResult {
|
|||
// apex := service.ZoneApex()
|
||||
// // "charon.lthn"
|
||||
func (service *Service) ZoneApex() string {
|
||||
if service == nil {
|
||||
return ""
|
||||
}
|
||||
service.pruneExpiredRecords()
|
||||
|
||||
service.mu.RLock()
|
||||
|
|
@ -1005,6 +1037,9 @@ func (service *Service) ZoneApex() string {
|
|||
//
|
||||
// result, ok := service.ResolveReverseNames("10.10.10.10")
|
||||
func (service *Service) ResolveReverseNames(ip string) (ReverseLookupResult, bool) {
|
||||
if service == nil {
|
||||
return ReverseLookupResult{}, false
|
||||
}
|
||||
names, ok := service.ResolveReverse(ip)
|
||||
if !ok {
|
||||
return ReverseLookupResult{}, false
|
||||
|
|
|
|||
|
|
@ -3436,6 +3436,45 @@ func TestIntActionValueAcceptsWholeFloat(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestServiceMethodsHandleNilReceiverWithoutPanicking(t *testing.T) {
|
||||
var service *Service
|
||||
|
||||
if _, ok := service.Resolve("gateway.charon.lthn"); ok {
|
||||
t.Fatal("expected nil service Resolve to return not found")
|
||||
}
|
||||
if _, _, ok := service.ResolveWithMatch("gateway.charon.lthn"); ok {
|
||||
t.Fatal("expected nil service ResolveWithMatch to return not found")
|
||||
}
|
||||
if _, ok := service.ResolveReverse("10.10.10.10"); ok {
|
||||
t.Fatal("expected nil service ResolveReverse to return not found")
|
||||
}
|
||||
if _, ok := service.ResolveReverseNames("10.10.10.10"); ok {
|
||||
t.Fatal("expected nil service ResolveReverseNames to return not found")
|
||||
}
|
||||
if got := service.ResolveDNSPort(); got != DefaultDNSPort {
|
||||
t.Fatalf("expected default DNS port from nil service, got %d", got)
|
||||
}
|
||||
if got := service.ResolveHTTPPort(); got != DefaultHTTPPort {
|
||||
t.Fatalf("expected default HTTP port from nil service, got %d", got)
|
||||
}
|
||||
if got := service.Health().Status; got != "not_ready" {
|
||||
t.Fatalf("expected nil service health status \"not_ready\", got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceServeReturnsErrorOnNilReceiver(t *testing.T) {
|
||||
var service *Service
|
||||
if _, err := service.Serve("127.0.0.1", 0); err == nil {
|
||||
t.Fatal("expected Serve to fail for nil service receiver")
|
||||
}
|
||||
if _, err := service.ServeAll("127.0.0.1", 0, 0); err == nil {
|
||||
t.Fatal("expected ServeAll to fail for nil service receiver")
|
||||
}
|
||||
if _, err := service.ServeConfigured("127.0.0.1"); err == nil {
|
||||
t.Fatal("expected ServeConfigured to fail for nil service receiver")
|
||||
}
|
||||
}
|
||||
|
||||
type actionRecorder struct {
|
||||
names []string
|
||||
handlers map[string]func(map[string]any) (any, bool, error)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue