diff --git a/service.go b/service.go index 8e33c88..91059ce 100644 --- a/service.go +++ b/service.go @@ -53,6 +53,37 @@ type ReverseLookupResult struct { Names []string `json:"names"` } +// ReverseIndex stores IP-to-name lookups in a dedicated semantic wrapper. +// +// index := buildReverseIndex(records, 15*time.Second) +// names, ok := index.Lookup("10.10.10.10") +type ReverseIndex struct { + cache *cache.Cache +} + +func (index *ReverseIndex) Lookup(ip string) ([]string, bool) { + if index == nil || index.cache == nil { + return nil, false + } + + normalizedIP := normalizeIP(ip) + if normalizedIP == "" { + return nil, false + } + + rawNames, found := index.cache.Get(normalizedIP) + if !found { + return nil, false + } + + names, ok := rawNames.([]string) + if !ok || len(names) == 0 { + return nil, false + } + + return append([]string(nil), names...), true +} + // HealthResult is the typed payload returned by Health and dns.health. // // health := service.Health() @@ -67,7 +98,7 @@ type Service struct { mu sync.RWMutex records map[string]NameRecords recordExpiry map[string]time.Time - reverseIndex *cache.Cache + reverseIndex *ReverseIndex treeRoot string zoneApex string dnsPort int @@ -180,7 +211,7 @@ func NewService(options ServiceOptions) *Service { service := &Service{ records: cached, recordExpiry: make(map[string]time.Time, len(cached)), - reverseIndex: buildReverseIndexCache(cached, options.RecordTTL), + reverseIndex: buildReverseIndex(cached, options.RecordTTL), treeRoot: treeRoot, zoneApex: computeZoneApex(cached), dnsPort: options.DNSPort, @@ -780,7 +811,7 @@ func (service *Service) pruneExpiredRecords() { } func (service *Service) refreshDerivedStateLocked() { - service.reverseIndex = buildReverseIndexCache(service.records, service.recordTTL) + service.reverseIndex = buildReverseIndex(service.records, service.recordTTL) service.treeRoot = computeTreeRoot(service.records) service.zoneApex = computeZoneApex(service.records) } @@ -804,11 +835,6 @@ func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) func (service *Service) ResolveReverse(ip string) ([]string, bool) { service.pruneExpiredRecords() - normalizedIP := normalizeIP(ip) - if normalizedIP == "" { - return nil, false - } - service.mu.RLock() reverseIndex := service.reverseIndex service.mu.RUnlock() @@ -817,16 +843,7 @@ func (service *Service) ResolveReverse(ip string) ([]string, bool) { return nil, false } - rawNames, found := reverseIndex.Get(normalizedIP) - if !found { - return nil, false - } - - names, ok := rawNames.([]string) - if !ok || len(names) == 0 { - return nil, false - } - return append([]string(nil), names...), true + return reverseIndex.Lookup(ip) } // ResolveAll returns the full record set for a name, including synthesized apex NS data. @@ -930,7 +947,7 @@ func resolveResult(record NameRecords) ResolveAllResult { } } -func buildReverseIndexCache(records map[string]NameRecords, ttl time.Duration) *cache.Cache { +func buildReverseIndex(records map[string]NameRecords, ttl time.Duration) *ReverseIndex { raw := map[string]map[string]struct{}{} for name, record := range records { for _, ip := range record.A { @@ -976,7 +993,7 @@ func buildReverseIndexCache(records map[string]NameRecords, ttl time.Duration) * for ip, names := range reverseIndex { reverseIndexCache.Set(ip, names, cacheTTL) } - return reverseIndexCache + return &ReverseIndex{cache: reverseIndexCache} } func normalizeIP(ip string) string {