diff --git a/service.go b/service.go index 4175549..37c1b3e 100644 --- a/service.go +++ b/service.go @@ -66,11 +66,13 @@ type HealthResult struct { type Service struct { mu sync.RWMutex records map[string]NameRecords + recordExpiry map[string]time.Time reverseIndex *cache.Cache treeRoot string zoneApex string dnsPort int httpPort int + recordTTL time.Duration lastAliasFingerprint string hsdClient *HSDClient mainchainAliasClient *MainchainAliasClient @@ -91,13 +93,16 @@ type Service struct { // Records: map[string]dns.NameRecords{ // "gateway.charon.lthn": {A: []string{"10.10.10.10"}}, // }, +// RecordTTL: 15 * time.Second, // HSDURL: "http://127.0.0.1:14037", // MainchainURL: "http://127.0.0.1:14037", // }) type ServiceOptions struct { - Records map[string]NameRecords - RecordDiscoverer func() (map[string]NameRecords, error) - FallbackRecordDiscoverer func() (map[string]NameRecords, error) + Records map[string]NameRecords + RecordDiscoverer func() (map[string]NameRecords, error) + FallbackRecordDiscoverer func() (map[string]NameRecords, error) + // RecordTTL keeps forward records and the reverse index alive for the same duration. + RecordTTL time.Duration DNSPort int HTTPPort int HSDURL string @@ -163,11 +168,13 @@ func NewService(options ServiceOptions) *Service { treeRoot := computeTreeRoot(cached) service := &Service{ records: cached, - reverseIndex: buildReverseIndexCache(cached), + recordExpiry: make(map[string]time.Time, len(cached)), + reverseIndex: buildReverseIndexCache(cached, options.RecordTTL), treeRoot: treeRoot, zoneApex: computeZoneApex(cached), dnsPort: options.DNSPort, httpPort: options.HTTPPort, + recordTTL: options.RecordTTL, hsdClient: hsdClient, mainchainAliasClient: mainchainClient, chainAliasActionCaller: options.ChainAliasActionCaller, @@ -183,6 +190,13 @@ func NewService(options ServiceOptions) *Service { service.RegisterActions(options.ActionRegistrar) } + if options.RecordTTL > 0 { + expiresAt := time.Now().Add(options.RecordTTL) + for name := range cached { + service.recordExpiry[name] = expiresAt + } + } + return service } @@ -587,44 +601,64 @@ func (service *Service) DiscoverAliases(ctx context.Context) error { func (service *Service) replaceRecords(discovered map[string]NameRecords) { cached := make(map[string]NameRecords, len(discovered)) + expiry := make(map[string]time.Time, len(discovered)) + now := time.Now() for name, record := range discovered { normalizedName := normalizeName(name) if normalizedName == "" { continue } cached[normalizedName] = record + if service.recordTTL > 0 { + expiry[normalizedName] = now.Add(service.recordTTL) + } } service.mu.Lock() defer service.mu.Unlock() service.records = cached - service.reverseIndex = buildReverseIndexCache(service.records) - service.treeRoot = computeTreeRoot(service.records) - service.zoneApex = computeZoneApex(service.records) + service.recordExpiry = expiry + service.refreshDerivedStateLocked() } // SetRecord inserts or replaces one cached name. // // service.SetRecord("gateway.charon.lthn", dns.NameRecords{A: []string{"10.10.10.10"}}) func (service *Service) SetRecord(name string, record NameRecords) { + normalizedName := normalizeName(name) + now := time.Now() service.mu.Lock() defer service.mu.Unlock() - service.records[normalizeName(name)] = record - service.reverseIndex = buildReverseIndexCache(service.records) - service.treeRoot = computeTreeRoot(service.records) - service.zoneApex = computeZoneApex(service.records) + if normalizedName == "" { + return + } + service.records[normalizedName] = record + if service.recordTTL > 0 { + if service.recordExpiry == nil { + service.recordExpiry = make(map[string]time.Time) + } + service.recordExpiry[normalizedName] = now.Add(service.recordTTL) + } else if service.recordExpiry != nil { + delete(service.recordExpiry, normalizedName) + } + service.refreshDerivedStateLocked() } // RemoveRecord deletes one cached name. // // service.RemoveRecord("gateway.charon.lthn") func (service *Service) RemoveRecord(name string) { + normalizedName := normalizeName(name) service.mu.Lock() defer service.mu.Unlock() - delete(service.records, normalizeName(name)) - service.reverseIndex = buildReverseIndexCache(service.records) - service.treeRoot = computeTreeRoot(service.records) - service.zoneApex = computeZoneApex(service.records) + if normalizedName == "" { + return + } + delete(service.records, normalizedName) + if service.recordExpiry != nil { + delete(service.recordExpiry, normalizedName) + } + service.refreshDerivedStateLocked() } // Resolve returns all record types for a name when an exact or wildcard match exists. @@ -694,6 +728,40 @@ func (service *Service) DiscoverWithHSD(ctx context.Context, aliases []string, c return nil } +func (service *Service) pruneExpiredRecords() { + if service.recordTTL <= 0 { + return + } + + now := time.Now() + service.mu.Lock() + defer service.mu.Unlock() + + if len(service.recordExpiry) == 0 { + return + } + + changed := false + for name, expiresAt := range service.recordExpiry { + if expiresAt.IsZero() || now.Before(expiresAt) { + continue + } + delete(service.recordExpiry, name) + delete(service.records, name) + changed = true + } + + if changed { + service.refreshDerivedStateLocked() + } +} + +func (service *Service) refreshDerivedStateLocked() { + service.reverseIndex = buildReverseIndexCache(service.records, service.recordTTL) + service.treeRoot = computeTreeRoot(service.records) + service.zoneApex = computeZoneApex(service.records) +} + // ResolveAddress returns A and AAAA values merged into one address list. // // addresses, ok := service.ResolveAddress("gateway.charon.lthn") @@ -711,6 +779,8 @@ func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) // // names, ok := service.ResolveReverse("10.10.10.10") func (service *Service) ResolveReverse(ip string) ([]string, bool) { + service.pruneExpiredRecords() + normalizedIP := normalizeIP(ip) if normalizedIP == "" { return nil, false @@ -773,6 +843,8 @@ func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) { // health := service.Health() // fmt.Println(health.Status, health.NamesCached, health.TreeRoot) func (service *Service) Health() HealthResult { + service.pruneExpiredRecords() + service.mu.RLock() defer service.mu.RUnlock() @@ -793,6 +865,8 @@ func (service *Service) Health() HealthResult { // apex := service.ZoneApex() // // "charon.lthn" func (service *Service) ZoneApex() string { + service.pruneExpiredRecords() + service.mu.RLock() defer service.mu.RUnlock() return service.zoneApex @@ -810,6 +884,8 @@ func (service *Service) ResolveReverseNames(ip string) (ReverseLookupResult, boo } func (service *Service) findRecord(name string) (NameRecords, bool) { + service.pruneExpiredRecords() + service.mu.RLock() defer service.mu.RUnlock() @@ -831,7 +907,7 @@ func resolveResult(record NameRecords) ResolveAllResult { } } -func buildReverseIndexCache(records map[string]NameRecords) *cache.Cache { +func buildReverseIndexCache(records map[string]NameRecords, ttl time.Duration) *cache.Cache { raw := map[string]map[string]struct{}{} for name, record := range records { for _, ip := range record.A { @@ -869,9 +945,13 @@ func buildReverseIndexCache(records map[string]NameRecords) *cache.Cache { slices.Sort(unique) reverseIndex[ip] = unique } - reverseIndexCache := cache.New(cache.NoExpiration, cache.NoExpiration) + cacheTTL := cache.NoExpiration + if ttl > 0 { + cacheTTL = ttl + } + reverseIndexCache := cache.New(cacheTTL, cacheTTL) for ip, names := range reverseIndex { - reverseIndexCache.Set(ip, names, cache.NoExpiration) + reverseIndexCache.Set(ip, names, cacheTTL) } return reverseIndexCache } diff --git a/service_test.go b/service_test.go index 231b94c..54c1ca2 100644 --- a/service_test.go +++ b/service_test.go @@ -274,6 +274,35 @@ func TestServiceResolveReverseUsesSetAndRemove(t *testing.T) { } } +func TestServiceRecordTTLExpiresForwardAndReverseLookups(t *testing.T) { + service := NewService(ServiceOptions{ + RecordTTL: 25 * time.Millisecond, + }) + + service.SetRecord("gateway.charon.lthn", NameRecords{ + A: []string{"10.10.10.10"}, + }) + + if _, ok := service.Resolve("gateway.charon.lthn"); !ok { + t.Fatal("expected record to resolve before expiry") + } + if names, ok := service.ResolveReverse("10.10.10.10"); !ok || len(names) != 1 || names[0] != "gateway.charon.lthn" { + t.Fatalf("expected reverse record before expiry, got %#v (ok=%t)", names, ok) + } + + time.Sleep(100 * time.Millisecond) + + if _, ok := service.Resolve("gateway.charon.lthn"); ok { + t.Fatal("expected forward record to expire") + } + if _, ok := service.ResolveReverse("10.10.10.10"); ok { + t.Fatal("expected reverse record to expire with the forward record") + } + if health := service.Health(); health.NamesCached != 0 { + t.Fatalf("expected expired record to be pruned from health, got %#v", health) + } +} + func TestServiceHealthUsesDeterministicTreeRootAndUpdatesOnMutations(t *testing.T) { service := NewService(ServiceOptions{ Records: map[string]NameRecords{