398 lines
9.5 KiB
Go
398 lines
9.5 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// NameRecords holds the record values cached for one DNS name.
//
// Each field is a list of raw string values for that record type. The JSON
// tags name the record types in lower case.
type NameRecords struct {
	// A lists A record values; the reverse index parses them as IP addresses.
	A []string `json:"a"`
	// AAAA lists AAAA record values; the reverse index parses them as IP addresses.
	AAAA []string `json:"aaaa"`
	// TXT lists TXT record values.
	TXT []string `json:"txt"`
	// NS lists NS record values.
	NS []string `json:"ns"`
}
|
|
|
|
// ResolveAllResult is returned by Resolve and ResolveAll.
//
// It carries copies of every record type cached for the resolved name. The
// fields mirror NameRecords, and the JSON tags use lower-case record-type keys.
type ResolveAllResult struct {
	// A lists the A record values.
	A []string `json:"a"`
	// AAAA lists the AAAA record values.
	AAAA []string `json:"aaaa"`
	// TXT lists the TXT record values.
	TXT []string `json:"txt"`
	// NS lists the NS record values.
	NS []string `json:"ns"`
}
|
|
|
|
// ResolveAddressResult is returned by ResolveAddress.
type ResolveAddressResult struct {
	// Addresses is the de-duplicated, sorted union of a name's A and AAAA values.
	Addresses []string `json:"addresses"`
}
|
|
|
|
type Service struct {
|
|
mu sync.RWMutex
|
|
records map[string]NameRecords
|
|
reverseIndex map[string][]string
|
|
treeRoot string
|
|
discoverer func() (map[string]NameRecords, error)
|
|
fallbackDiscoverer func() (map[string]NameRecords, error)
|
|
}
|
|
|
|
type ServiceOptions struct {
|
|
Records map[string]NameRecords
|
|
Discoverer func() (map[string]NameRecords, error)
|
|
FallbackDiscoverer func() (map[string]NameRecords, error)
|
|
}
|
|
|
|
func NewService(options ServiceOptions) *Service {
|
|
cached := make(map[string]NameRecords, len(options.Records))
|
|
for name, record := range options.Records {
|
|
cached[normalizeName(name)] = record
|
|
}
|
|
treeRoot := computeTreeRoot(cached)
|
|
return &Service{
|
|
records: cached,
|
|
reverseIndex: buildReverseIndex(cached),
|
|
treeRoot: treeRoot,
|
|
discoverer: options.Discoverer,
|
|
fallbackDiscoverer: options.FallbackDiscoverer,
|
|
}
|
|
}
|
|
|
|
func (service *Service) Discover() error {
|
|
discoverer := service.discoverer
|
|
fallback := service.fallbackDiscoverer
|
|
if discoverer == nil {
|
|
if fallback == nil {
|
|
return nil
|
|
}
|
|
discovered, err := fallback()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
service.replaceRecords(discovered)
|
|
return nil
|
|
}
|
|
|
|
discovered, err := discoverer()
|
|
if err != nil {
|
|
if fallback == nil {
|
|
return err
|
|
}
|
|
discovered, err = fallback()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
service.replaceRecords(discovered)
|
|
return err
|
|
}
|
|
service.replaceRecords(discovered)
|
|
return nil
|
|
}
|
|
|
|
func (service *Service) replaceRecords(discovered map[string]NameRecords) {
|
|
cached := make(map[string]NameRecords, len(discovered))
|
|
for name, record := range discovered {
|
|
normalizedName := normalizeName(name)
|
|
if normalizedName == "" {
|
|
continue
|
|
}
|
|
cached[normalizedName] = record
|
|
}
|
|
|
|
service.mu.Lock()
|
|
defer service.mu.Unlock()
|
|
service.records = cached
|
|
service.reverseIndex = buildReverseIndex(service.records)
|
|
service.treeRoot = computeTreeRoot(service.records)
|
|
}
|
|
|
|
func (service *Service) SetRecord(name string, record NameRecords) {
|
|
service.mu.Lock()
|
|
defer service.mu.Unlock()
|
|
service.records[normalizeName(name)] = record
|
|
service.reverseIndex = buildReverseIndex(service.records)
|
|
service.treeRoot = computeTreeRoot(service.records)
|
|
}
|
|
|
|
func (service *Service) RemoveRecord(name string) {
|
|
service.mu.Lock()
|
|
defer service.mu.Unlock()
|
|
delete(service.records, normalizeName(name))
|
|
service.reverseIndex = buildReverseIndex(service.records)
|
|
service.treeRoot = computeTreeRoot(service.records)
|
|
}
|
|
|
|
func (service *Service) Resolve(name string) (ResolveAllResult, bool) {
|
|
record, ok := service.findRecord(name)
|
|
if !ok {
|
|
return ResolveAllResult{}, false
|
|
}
|
|
return resolveResult(record), true
|
|
}
|
|
|
|
func (service *Service) ResolveTXT(name string) ([]string, bool) {
|
|
record, ok := service.findRecord(name)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
return append([]string(nil), record.TXT...), true
|
|
}
|
|
|
|
// DiscoverWithHSD refreshes DNS records for each alias by calling HSD.
|
|
// Example:
|
|
//
|
|
// err := service.DiscoverWithHSD(context.Background(), []string{"gateway.lthn"}, dns.NewHSDClient(dns.HSDClientOptions{
|
|
// URL: "http://127.0.0.1:14037",
|
|
// Username: "user",
|
|
// Password: "pass",
|
|
// }))
|
|
func (service *Service) DiscoverWithHSD(ctx context.Context, aliases []string, client *HSDClient) error {
|
|
if client == nil {
|
|
return fmt.Errorf("hsd client is required")
|
|
}
|
|
|
|
resolved := make(map[string]NameRecords, len(aliases))
|
|
for _, alias := range aliases {
|
|
normalized := normalizeName(alias)
|
|
if normalized == "" {
|
|
continue
|
|
}
|
|
|
|
record, err := client.GetNameResource(ctx, normalized)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resolved[normalized] = record
|
|
}
|
|
|
|
service.replaceRecords(resolved)
|
|
return nil
|
|
}
|
|
|
|
func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) {
|
|
record, ok := service.findRecord(name)
|
|
if !ok {
|
|
return ResolveAddressResult{}, false
|
|
}
|
|
return ResolveAddressResult{
|
|
Addresses: MergeRecords(record.A, record.AAAA),
|
|
}, true
|
|
}
|
|
|
|
func (service *Service) ResolveReverse(ip string) ([]string, bool) {
|
|
service.mu.RLock()
|
|
defer service.mu.RUnlock()
|
|
|
|
normalizedIP := normalizeIP(ip)
|
|
if normalizedIP == "" {
|
|
return nil, false
|
|
}
|
|
|
|
names, ok := service.reverseIndex[normalizedIP]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
return append([]string(nil), names...), true
|
|
}
|
|
|
|
func (service *Service) ResolveAll(name string) (ResolveAllResult, bool) {
|
|
record, ok := service.findRecord(name)
|
|
if !ok {
|
|
return ResolveAllResult{}, false
|
|
}
|
|
return resolveResult(record), true
|
|
}
|
|
|
|
func (service *Service) Health() map[string]any {
|
|
service.mu.RLock()
|
|
defer service.mu.RUnlock()
|
|
return map[string]any{
|
|
"status": "ready",
|
|
"names_cached": len(service.records),
|
|
"tree_root": service.treeRoot,
|
|
}
|
|
}
|
|
|
|
func (service *Service) findRecord(name string) (NameRecords, bool) {
|
|
service.mu.RLock()
|
|
defer service.mu.RUnlock()
|
|
|
|
normalized := normalizeName(name)
|
|
if record, ok := service.records[normalized]; ok {
|
|
return record, true
|
|
}
|
|
|
|
match, ok := findWildcardMatch(normalized, service.records)
|
|
return match, ok
|
|
}
|
|
|
|
func resolveResult(record NameRecords) ResolveAllResult {
|
|
return ResolveAllResult{
|
|
A: append([]string(nil), record.A...),
|
|
AAAA: append([]string(nil), record.AAAA...),
|
|
TXT: append([]string(nil), record.TXT...),
|
|
NS: append([]string(nil), record.NS...),
|
|
}
|
|
}
|
|
|
|
func buildReverseIndex(records map[string]NameRecords) map[string][]string {
|
|
raw := map[string]map[string]struct{}{}
|
|
for name, record := range records {
|
|
for _, ip := range record.A {
|
|
normalized := normalizeIP(ip)
|
|
if normalized == "" {
|
|
continue
|
|
}
|
|
index := raw[normalized]
|
|
if index == nil {
|
|
index = map[string]struct{}{}
|
|
raw[normalized] = index
|
|
}
|
|
index[name] = struct{}{}
|
|
}
|
|
for _, ip := range record.AAAA {
|
|
normalized := normalizeIP(ip)
|
|
if normalized == "" {
|
|
continue
|
|
}
|
|
index := raw[normalized]
|
|
if index == nil {
|
|
index = map[string]struct{}{}
|
|
raw[normalized] = index
|
|
}
|
|
index[name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
reverseIndex := make(map[string][]string, len(raw))
|
|
for ip, names := range raw {
|
|
unique := make([]string, 0, len(names))
|
|
for name := range names {
|
|
unique = append(unique, name)
|
|
}
|
|
slices.Sort(unique)
|
|
reverseIndex[ip] = unique
|
|
}
|
|
return reverseIndex
|
|
}
|
|
|
|
// normalizeIP returns the canonical text form of ip (surrounding whitespace
// ignored) as produced by net.IP.String, or "" when ip does not parse.
func normalizeIP(ip string) string {
	if parsed := net.ParseIP(strings.TrimSpace(ip)); parsed != nil {
		return parsed.String()
	}
	return ""
}
|
|
|
|
func computeTreeRoot(records map[string]NameRecords) string {
|
|
names := make([]string, 0, len(records))
|
|
for name := range records {
|
|
names = append(names, name)
|
|
}
|
|
slices.Sort(names)
|
|
|
|
var builder strings.Builder
|
|
for _, name := range names {
|
|
record := records[name]
|
|
builder.WriteString(name)
|
|
builder.WriteByte('\n')
|
|
builder.WriteString("A=")
|
|
builder.WriteString(serializeRecordValues(record.A))
|
|
builder.WriteByte('\n')
|
|
builder.WriteString("AAAA=")
|
|
builder.WriteString(serializeRecordValues(record.AAAA))
|
|
builder.WriteByte('\n')
|
|
builder.WriteString("TXT=")
|
|
builder.WriteString(serializeRecordValues(record.TXT))
|
|
builder.WriteByte('\n')
|
|
builder.WriteString("NS=")
|
|
builder.WriteString(serializeRecordValues(record.NS))
|
|
builder.WriteByte('\n')
|
|
}
|
|
|
|
sum := sha256.Sum256([]byte(builder.String()))
|
|
return hex.EncodeToString(sum[:])
|
|
}
|
|
|
|
// recordValueEscaper escapes the delimiters used by the tree-root canonical
// form ("," between values, "\n" between lines) and the escape character
// itself. A delimiter inside a value then cannot be mistaken for a real one.
var recordValueEscaper = strings.NewReplacer(`\`, `\\`, ",", `\,`, "\n", `\n`)

// serializeRecordValues renders values as an order-independent string: the
// values are copied, sorted, escaped and joined with ",". The input slice is
// not modified.
//
// Without escaping, ["a,b"] and ["a", "b"] serialized identically and so
// produced the same tree root. Values containing no '\', ',' or newline are
// emitted unchanged, so ordinary records serialize exactly as a plain join.
func serializeRecordValues(values []string) string {
	sorted := append([]string(nil), values...)
	slices.Sort(sorted)
	for i, value := range sorted {
		sorted[i] = recordValueEscaper.Replace(value)
	}
	return strings.Join(sorted, ",")
}
|
|
|
|
func findWildcardMatch(name string, records map[string]NameRecords) (NameRecords, bool) {
|
|
bestMatch := ""
|
|
for candidate := range records {
|
|
if !strings.HasPrefix(candidate, "*.") {
|
|
continue
|
|
}
|
|
suffix := strings.TrimPrefix(candidate, "*.")
|
|
if wildcardMatches(suffix, name) {
|
|
if betterWildcardMatch(candidate, bestMatch) {
|
|
bestMatch = candidate
|
|
}
|
|
}
|
|
}
|
|
if bestMatch == "" {
|
|
return NameRecords{}, false
|
|
}
|
|
return records[bestMatch], true
|
|
}
|
|
|
|
// wildcardMatches reports whether the wildcard "*."+suffix covers name. The
// name must end in "."+suffix and have at least one character before that
// dot; any number of labels may precede it.
//
// The original's strings.Split result was never used meaningfully (Split
// always returns at least one element), and its trailing dot-count check was
// always true once the suffix matched. Both are removed, which also drops an
// allocation per call.
func wildcardMatches(suffix, name string) bool {
	return len(name) > len(suffix)+1 && strings.HasSuffix(name, "."+suffix)
}
|
|
|
|
// betterWildcardMatch reports whether wildcard key candidate should replace
// current as the best match. An empty current always loses. Otherwise the
// longer suffix (after the "*." prefix) wins, and equal lengths are broken by
// lexical order.
func betterWildcardMatch(candidate, current string) bool {
	if current == "" {
		return true
	}
	candidateLen := len(strings.TrimPrefix(candidate, "*."))
	currentLen := len(strings.TrimPrefix(current, "*."))
	switch {
	case candidateLen != currentLen:
		return candidateLen > currentLen
	default:
		return candidate < current
	}
}
|
|
|
|
// normalizeName lower-cases name, trims surrounding whitespace and removes a
// single trailing dot, so "Example.COM. " and "example.com" share one key.
// An empty or whitespace-only name (or a lone ".") yields "".
func normalizeName(name string) string {
	return strings.TrimSuffix(strings.TrimSpace(strings.ToLower(name)), ".")
}
|
|
|
|
func (service *Service) String() string {
|
|
return fmt.Sprintf("dns.Service{records=%d}", len(service.records))
|
|
}
|
|
|
|
// MergeRecords returns the sorted, de-duplicated union of all values. The
// result is never nil: with no input values it is an empty, non-nil slice,
// which encodes to JSON as [] rather than null.
func MergeRecords(values ...[]string) []string {
	merged := []string{}
	for _, batch := range values {
		merged = append(merged, batch...)
	}
	// After sorting, equal strings are adjacent, so Compact removes duplicates.
	slices.Sort(merged)
	return slices.Compact(merged)
}
|