go-dns/service.go
Virgil 8fb7816316 feat(dns): add HSD-sidechain discovery client
Co-Authored-By: Virgil <virgil@lethean.io>
2026-04-03 19:56:16 +00:00

398 lines
9.5 KiB
Go

package dns
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"slices"
"strings"
"sync"
)
// NameRecords holds every cached DNS record value for a single name, one
// string per record. A and AAAA entries are expected to be textual IP
// addresses; entries that fail to parse are left out of reverse lookups.
type NameRecords struct {
	A    []string `json:"a"`
	AAAA []string `json:"aaaa"`
	TXT  []string `json:"txt"`
	NS   []string `json:"ns"`
}
// ResolveAllResult is the answer produced by Resolve and ResolveAll: copies
// of every record type cached for the matched name, so callers may modify
// the slices without affecting the cache.
type ResolveAllResult struct {
	A    []string `json:"a"`
	AAAA []string `json:"aaaa"`
	TXT  []string `json:"txt"`
	NS   []string `json:"ns"`
}
// ResolveAddressResult is the answer produced by ResolveAddress. Addresses
// holds the A and AAAA values of the matched name merged via MergeRecords
// (sorted, duplicates removed).
type ResolveAddressResult struct {
	Addresses []string `json:"addresses"`
}
// Service is an in-memory DNS record cache. It answers exact and wildcard
// ("*.suffix") lookups, keeps a reverse index from IP address to names, and
// maintains a SHA-256 digest (treeRoot) over every cached record.
//
// mu guards records, reverseIndex, and treeRoot. The discoverer functions are
// assigned once by NewService and read by Discover without locking.
type Service struct {
	mu sync.RWMutex

	// records maps normalized names (see normalizeName) to their records;
	// wildcard entries live under keys beginning with "*.".
	records map[string]NameRecords
	// reverseIndex maps a canonical IP string to the sorted names listing it
	// in A or AAAA. Rebuilt whenever records changes.
	reverseIndex map[string][]string
	// treeRoot is the hex SHA-256 digest of all records in name order.
	// Rebuilt whenever records changes.
	treeRoot string

	// discoverer is the primary record source consulted by Discover.
	discoverer func() (map[string]NameRecords, error)
	// fallbackDiscoverer is consulted by Discover when discoverer is unset
	// or fails.
	fallbackDiscoverer func() (map[string]NameRecords, error)
}
// ServiceOptions configures NewService.
type ServiceOptions struct {
	// Records seeds the cache; names are normalized before being stored.
	Records map[string]NameRecords
	// Discoverer is the primary record source used by Service.Discover.
	Discoverer func() (map[string]NameRecords, error)
	// FallbackDiscoverer is used by Service.Discover when Discoverer is nil
	// or returns an error.
	FallbackDiscoverer func() (map[string]NameRecords, error)
}
// NewService builds a Service seeded with options.Records and wired to the
// optional discoverers.
//
// Record names are normalized (lower-cased, trimmed, one trailing dot
// removed). Names that normalize to the empty string are skipped, matching
// how Discover-driven refreshes treat them. Record slices are copied so later
// mutation of the caller's data cannot change the cache behind its lock.
func NewService(options ServiceOptions) *Service {
	cached := make(map[string]NameRecords, len(options.Records))
	for name, record := range options.Records {
		normalized := normalizeName(name)
		if normalized == "" {
			// An empty key would otherwise answer lookups for blank names.
			continue
		}
		cached[normalized] = NameRecords{
			A:    slices.Clone(record.A),
			AAAA: slices.Clone(record.AAAA),
			TXT:  slices.Clone(record.TXT),
			NS:   slices.Clone(record.NS),
		}
	}
	treeRoot := computeTreeRoot(cached)
	return &Service{
		records:            cached,
		reverseIndex:       buildReverseIndex(cached),
		treeRoot:           treeRoot,
		discoverer:         options.Discoverer,
		fallbackDiscoverer: options.FallbackDiscoverer,
	}
}
// Discover refreshes the cache from the configured discoverers.
//
// The primary discoverer is tried first. If it is unset or fails, the
// fallback discoverer is tried instead. The first successful result replaces
// the entire cache. With neither discoverer configured, Discover does nothing
// and returns nil.
//
// If the primary fails and no fallback exists, the primary's error is
// returned unchanged. If only a fallback is configured and it fails, its
// error is returned unchanged. If both fail, the returned error wraps both,
// so either can be matched with errors.Is or errors.As.
func (service *Service) Discover() error {
	primary := service.discoverer
	fallback := service.fallbackDiscoverer

	var primaryErr error
	if primary != nil {
		discovered, err := primary()
		if err == nil {
			service.replaceRecords(discovered)
			return nil
		}
		primaryErr = err
	}

	if fallback == nil {
		// Either nothing is configured (primaryErr is nil) or the primary
		// failed with no fallback to try.
		return primaryErr
	}

	discovered, err := fallback()
	if err != nil {
		if primaryErr != nil {
			// Keep the primary failure visible instead of discarding it.
			return fmt.Errorf("primary discoverer: %w; fallback discoverer: %w", primaryErr, err)
		}
		return err
	}
	service.replaceRecords(discovered)
	return nil
}
// replaceRecords swaps the whole cache for discovered. Names are normalized,
// and names that normalize to "" are dropped. The reverse index and tree root
// are derived from the new map before the write lock is taken, so readers
// are blocked only for the final assignment.
func (service *Service) replaceRecords(discovered map[string]NameRecords) {
	next := make(map[string]NameRecords, len(discovered))
	for rawName, record := range discovered {
		if key := normalizeName(rawName); key != "" {
			next[key] = record
		}
	}
	nextIndex := buildReverseIndex(next)
	nextRoot := computeTreeRoot(next)

	service.mu.Lock()
	service.records, service.reverseIndex, service.treeRoot = next, nextIndex, nextRoot
	service.mu.Unlock()
}
// SetRecord stores record under the normalized form of name, replacing any
// existing entry, then rebuilds the reverse index and tree root.
//
// Names that normalize to "" are ignored, consistent with Discover-driven
// refreshes. The record's slices are copied so the caller may keep mutating
// its own data. SetRecord also works on a zero-value Service, whose records
// map is still nil.
func (service *Service) SetRecord(name string, record NameRecords) {
	normalized := normalizeName(name)
	if normalized == "" {
		return
	}
	stored := NameRecords{
		A:    slices.Clone(record.A),
		AAAA: slices.Clone(record.AAAA),
		TXT:  slices.Clone(record.TXT),
		NS:   slices.Clone(record.NS),
	}

	service.mu.Lock()
	defer service.mu.Unlock()
	if service.records == nil {
		// Writing to a nil map panics; a Service{} literal starts this way.
		service.records = make(map[string]NameRecords)
	}
	service.records[normalized] = stored
	service.reverseIndex = buildReverseIndex(service.records)
	service.treeRoot = computeTreeRoot(service.records)
}
// RemoveRecord deletes the entry stored under the normalized form of name (a
// no-op if absent) and rebuilds the reverse index and tree root.
func (service *Service) RemoveRecord(name string) {
	key := normalizeName(name)

	service.mu.Lock()
	defer service.mu.Unlock()

	delete(service.records, key)
	service.reverseIndex = buildReverseIndex(service.records)
	service.treeRoot = computeTreeRoot(service.records)
}
// Resolve looks up name, trying an exact match first and then the most
// specific wildcard. It returns copies of all cached record types. The bool
// reports whether any entry matched. Resolve is equivalent to ResolveAll.
func (service *Service) Resolve(name string) (ResolveAllResult, bool) {
	return service.ResolveAll(name)
}
// ResolveTXT returns a copy of the TXT values for name, which is matched
// exactly or by wildcard. It returns (nil, false) when nothing matches.
func (service *Service) ResolveTXT(name string) ([]string, bool) {
	if record, found := service.findRecord(name); found {
		return append([]string(nil), record.TXT...), true
	}
	return nil, false
}
// DiscoverWithHSD refreshes DNS records for each alias by calling HSD.
//
// Each alias is normalized. Blank aliases are skipped, and aliases that
// normalize to the same name are queried only once. Every remaining name is
// fetched with client.GetNameResource. On the first failure the cache is left
// untouched, and the error is returned wrapped with the failing name. On
// success the cache is replaced with exactly the fetched names; previously
// cached names not listed in aliases are dropped.
//
// Example:
//
//	err := service.DiscoverWithHSD(context.Background(), []string{"gateway.lthn"}, dns.NewHSDClient(dns.HSDClientOptions{
//		URL:      "http://127.0.0.1:14037",
//		Username: "user",
//		Password: "pass",
//	}))
func (service *Service) DiscoverWithHSD(ctx context.Context, aliases []string, client *HSDClient) error {
	if client == nil {
		return fmt.Errorf("hsd client is required")
	}
	resolved := make(map[string]NameRecords, len(aliases))
	for _, alias := range aliases {
		normalized := normalizeName(alias)
		if normalized == "" {
			continue
		}
		if _, seen := resolved[normalized]; seen {
			// "Gateway.lthn" and "gateway.lthn." map to the same key; one
			// round-trip to HSD is enough.
			continue
		}
		record, err := client.GetNameResource(ctx, normalized)
		if err != nil {
			return fmt.Errorf("hsd lookup for %q: %w", normalized, err)
		}
		resolved[normalized] = record
	}
	service.replaceRecords(resolved)
	return nil
}
// ResolveAddress returns the merged A and AAAA values for name (exact or
// wildcard match), sorted with duplicates removed. The bool reports whether
// any entry matched.
func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) {
	record, found := service.findRecord(name)
	if !found {
		return ResolveAddressResult{}, false
	}
	var result ResolveAddressResult
	result.Addresses = MergeRecords(record.A, record.AAAA)
	return result, true
}
// ResolveReverse returns a sorted copy of the names whose A or AAAA records
// contain ip. The IP is canonicalized first, so "::FFFF:10.0.0.1" and
// "10.0.0.1" are equivalent. Unparsable or unknown IPs yield (nil, false).
func (service *Service) ResolveReverse(ip string) ([]string, bool) {
	service.mu.RLock()
	defer service.mu.RUnlock()

	key := normalizeIP(ip)
	if key == "" {
		return nil, false
	}
	if names, found := service.reverseIndex[key]; found {
		return append([]string(nil), names...), true
	}
	return nil, false
}
// ResolveAll looks up name, trying an exact match first and then the most
// specific wildcard. It returns copies of every cached record type. The bool
// reports whether any entry matched.
func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) {
	if record, found := service.findRecord(name); found {
		return resolveResult(record), true
	}
	return ResolveAllResult{}, false
}
// Health reports a snapshot of the cache. "status" is always "ready",
// "names_cached" is the number of cached names, and "tree_root" is the
// current digest over all records.
func (service *Service) Health() map[string]any {
	service.mu.RLock()
	cachedNames, root := len(service.records), service.treeRoot
	service.mu.RUnlock()

	return map[string]any{
		"status":       "ready",
		"names_cached": cachedNames,
		"tree_root":    root,
	}
}
// findRecord normalizes name and returns its cached record. An exact match
// wins. Otherwise the most specific "*.suffix" wildcard covering the name is
// used. The lookup runs under the read lock.
func (service *Service) findRecord(name string) (NameRecords, bool) {
	service.mu.RLock()
	defer service.mu.RUnlock()

	key := normalizeName(name)
	if exact, found := service.records[key]; found {
		return exact, true
	}
	return findWildcardMatch(key, service.records)
}
// resolveResult converts a cached record into a ResolveAllResult whose slices
// are detached copies. Empty or nil inputs come back as nil.
func resolveResult(record NameRecords) ResolveAllResult {
	detach := func(values []string) []string {
		return append([]string(nil), values...)
	}

	var result ResolveAllResult
	result.A = detach(record.A)
	result.AAAA = detach(record.AAAA)
	result.TXT = detach(record.TXT)
	result.NS = detach(record.NS)
	return result
}
// buildReverseIndex maps every canonical IP found in the A and AAAA records
// to the sorted, de-duplicated list of names that contain it. Values that do
// not parse as IPs are skipped.
func buildReverseIndex(records map[string]NameRecords) map[string][]string {
	namesByIP := make(map[string]map[string]struct{})
	for name, record := range records {
		for _, family := range [][]string{record.A, record.AAAA} {
			for _, address := range family {
				key := normalizeIP(address)
				if key == "" {
					continue
				}
				if namesByIP[key] == nil {
					namesByIP[key] = make(map[string]struct{})
				}
				namesByIP[key][name] = struct{}{}
			}
		}
	}

	index := make(map[string][]string, len(namesByIP))
	for address, nameSet := range namesByIP {
		sorted := make([]string, 0, len(nameSet))
		for name := range nameSet {
			sorted = append(sorted, name)
		}
		slices.Sort(sorted)
		index[address] = sorted
	}
	return index
}
// normalizeIP returns the canonical text form of ip after surrounding
// whitespace is trimmed. IPv6 is lower-cased and compressed, and IPv4-mapped
// IPv6 collapses to dotted IPv4. It returns "" when ip does not parse.
func normalizeIP(ip string) string {
	if parsed := net.ParseIP(strings.TrimSpace(ip)); parsed != nil {
		return parsed.String()
	}
	return ""
}
// computeTreeRoot returns the hex SHA-256 digest of a canonical text dump of
// records. Names are visited in sorted order, and each name contributes
// "name\nA=...\nAAAA=...\nTXT=...\nNS=...\n". Values within each field are
// sorted and comma-joined by serializeRecordValues, so the digest does not
// depend on map iteration or value order.
func computeTreeRoot(records map[string]NameRecords) string {
	sortedNames := make([]string, 0, len(records))
	for name := range records {
		sortedNames = append(sortedNames, name)
	}
	slices.Sort(sortedNames)

	var canonical strings.Builder
	for _, name := range sortedNames {
		record := records[name]
		canonical.WriteString(name)
		canonical.WriteByte('\n')

		fields := [...]struct {
			label  string
			values []string
		}{
			{"A=", record.A},
			{"AAAA=", record.AAAA},
			{"TXT=", record.TXT},
			{"NS=", record.NS},
		}
		for _, field := range fields {
			canonical.WriteString(field.label)
			canonical.WriteString(serializeRecordValues(field.values))
			canonical.WriteByte('\n')
		}
	}

	digest := sha256.Sum256([]byte(canonical.String()))
	return hex.EncodeToString(digest[:])
}
// serializeRecordValues renders values as a sorted, comma-separated string
// without modifying the caller's slice.
//
// NOTE(review): values are joined with a bare ",", so ["a,b"] and ["a","b"]
// serialize identically. TXT data may contain commas, so the tree root can
// collide for distinct record sets. Changing the encoding would change every
// existing digest, so confirm nothing persists or compares roots first.
func serializeRecordValues(values []string) string {
	sorted := slices.Clone(values)
	slices.Sort(sorted)
	return strings.Join(sorted, ",")
}
// findWildcardMatch returns the record of the most specific wildcard entry
// ("*.suffix") covering name. A wildcard covers name when name ends with
// ".suffix" and has at least one character before that dot. Deeper names also
// match, so "*.example.com" covers "a.b.example.com".
//
// Instead of scanning every cached entry, it walks the dot boundaries of name
// from left to right and probes the map for "*" + name[i:]. The first hit
// covers the longest suffix, which makes it the most specific wildcard.
// Boundary 0 is skipped so the wildcard label always replaces at least one
// character.
func findWildcardMatch(name string, records map[string]NameRecords) (NameRecords, bool) {
	for index := 1; index < len(name); index++ {
		if name[index] != '.' {
			continue
		}
		if record, ok := records["*"+name[index:]]; ok {
			return record, true
		}
	}
	return NameRecords{}, false
}
// wildcardMatches reports whether the wildcard "*.suffix" covers name. That
// requires name to end with "."+suffix and to have at least one character
// before that dot. Any depth matches, e.g. suffix "example.com" covers both
// "a.example.com" and "a.b.example.com", but not "example.com" itself.
//
// The original split/count checks were dead code: strings.Split never returns
// an empty slice, and once the "."+suffix test passes, the prefix always
// contains a dot.
func wildcardMatches(suffix, name string) bool {
	return len(name) > len(suffix)+1 && strings.HasSuffix(name, "."+suffix)
}
// betterWildcardMatch reports whether wildcard candidate should replace the
// current best. Any candidate beats an empty current. A longer suffix (after
// "*.") wins, and equal lengths fall back to lexical order so the choice is
// deterministic.
func betterWildcardMatch(candidate, current string) bool {
	if current == "" {
		return true
	}
	candidateLen := len(strings.TrimPrefix(candidate, "*."))
	currentLen := len(strings.TrimPrefix(current, "*."))
	switch {
	case candidateLen != currentLen:
		return candidateLen > currentLen
	default:
		return candidate < current
	}
}
// normalizeName canonicalizes a DNS name for use as a cache key. It
// lower-cases the name, trims surrounding whitespace, and removes a single
// trailing dot (so "Example.COM." becomes "example.com"). Blank input yields "".
func normalizeName(name string) string {
	cleaned := strings.TrimSpace(strings.ToLower(name))
	return strings.TrimSuffix(cleaned, ".")
}
// String implements fmt.Stringer, reporting how many names are cached.
// The count is read under the read lock, like every other accessor, so
// formatting a Service does not race with concurrent updates.
func (service *Service) String() string {
	service.mu.RLock()
	count := len(service.records)
	service.mu.RUnlock()
	return fmt.Sprintf("dns.Service{records=%d}", count)
}
// MergeRecords combines any number of value lists into one sorted slice with
// duplicates removed. The inputs are not modified. The result is never nil:
// with no values it is an empty, non-nil slice, which encodes as [] in JSON.
func MergeRecords(values ...[]string) []string {
	merged := []string{}
	for _, batch := range values {
		merged = append(merged, batch...)
	}
	slices.Sort(merged)
	return slices.Compact(merged)
}