diff --git a/hsd.go b/hsd.go new file mode 100644 index 0000000..7d8ac27 --- /dev/null +++ b/hsd.go @@ -0,0 +1,238 @@ +package dns + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" +) + +const defaultHSDJSONRPCVersion = "1.0" + +type HSDClientOptions struct { + URL string + Username string + Password string + HTTPClient *http.Client +} + +type HSDClient struct { + baseURL string + username string + password string + httpClient *http.Client +} + +type HSDNameResourceResult struct { + Address NameRecords +} + +type HSDBlockchainInfo struct { + TreeRoot string +} + +type HSDRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params []any `json:"params"` + ID int `json:"id"` +} + +type HSDRPCResponse struct { + Result json.RawMessage `json:"result"` + Error *HSDRPCError `json:"error"` +} + +type HSDRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (err *HSDRPCError) Error() string { + if err == nil { + return "" + } + return fmt.Sprintf("hsd rpc error (%d): %s", err.Code, err.Message) +} + +func NewHSDClient(options HSDClientOptions) *HSDClient { + client := options.HTTPClient + if client == nil { + client = &http.Client{} + } + + baseURL := strings.TrimSpace(options.URL) + if baseURL == "" { + baseURL = "http://127.0.0.1:14037" + } + + return &HSDClient{ + baseURL: baseURL, + username: options.Username, + password: options.Password, + httpClient: client, + } +} + +func (client *HSDClient) GetNameResource(ctx context.Context, name string) (NameRecords, error) { + normalized := strings.TrimSpace(name) + if normalized == "" { + return NameRecords{}, errors.New("name is required for getnameresource") + } + + request := HSDRPCRequest{ + JSONRPC: defaultHSDJSONRPCVersion, + Method: "getnameresource", + Params: []any{normalized}, + ID: 1, + } + + var result NameRecords + raw, err := client.call(ctx, request) + if err != nil { + return result, err + } + + result, err = parseHSDNameResource(raw) + if err != nil { + return NameRecords{}, err + } + + return result, nil +} + +func (client *HSDClient) GetBlockchainInfo(ctx context.Context) (HSDBlockchainInfo, error) { + var result HSDBlockchainInfo + request := HSDRPCRequest{ + JSONRPC: defaultHSDJSONRPCVersion, + Method: "getblockchaininfo", + Params: []any{}, + ID: 1, + } + + raw, err := client.call(ctx, request) + if err != nil { + return result, err + } + + result, err = parseHSDBlockchainInfo(raw) + if err != nil { + return HSDBlockchainInfo{}, err + } + + return result, nil +} + +func (client *HSDClient) call(ctx context.Context, request HSDRPCRequest) (json.RawMessage, error) { + body, err := json.Marshal(request) + if err != nil { + return nil, err + } + + httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, client.baseURL, io.NopCloser(io.Reader(strings.NewReader(string(body))))) + if err != nil { + return nil, err + } + httpRequest.Header.Set("Content-Type", "application/json") + + if client.username != "" || client.password != "" { + httpRequest.Header.Set("Authorization", "Basic "+basicAuthToken(client.username, client.password)) + } + + response, err := client.httpClient.Do(httpRequest) + if err != nil { + return nil, err + } + defer func() { _ = response.Body.Close() }() + + responseBody, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + if response.StatusCode < 200 || response.StatusCode >= 300 { + return nil, fmt.Errorf("hsd rpc request failed with status %d: %s", response.StatusCode, strings.TrimSpace(string(responseBody))) + } + + var decoded HSDRPCResponse + if err := json.Unmarshal(responseBody, &decoded); err != nil { + return nil, err + } + + if decoded.Error != nil { + return nil, decoded.Error + } + + return decoded.Result, nil +} + +func parseHSDNameResource(raw json.RawMessage) (NameRecords, error) { + var result NameRecords + var wrapper map[string]json.RawMessage + if err := json.Unmarshal(raw, &wrapper); err != nil { + return result, err + } + + if recordsRaw, ok := wrapper["records"]; ok { + if err := json.Unmarshal(recordsRaw, &result); err != nil { + return NameRecords{}, err + } + return result, nil + } + + if _, ok := wrapper["a"]; ok { + if err := json.Unmarshal(raw, &result); err != nil { + return NameRecords{}, err + } + return result, nil + } + + var wrapped struct { + A []string `json:"a"` + AAAA []string `json:"aaaa"` + TXT []string `json:"txt"` + NS []string `json:"ns"` + } + if err := json.Unmarshal(raw, &wrapped); err == nil { + result = NameRecords{ + A: wrapped.A, + AAAA: wrapped.AAAA, + TXT: wrapped.TXT, + NS: wrapped.NS, + } + return result, nil + } + + return NameRecords{}, errors.New("unable to parse getnameresource result") +} + +func parseHSDBlockchainInfo(raw json.RawMessage) (HSDBlockchainInfo, error) { + var info HSDBlockchainInfo + var wrapper map[string]json.RawMessage + if err := json.Unmarshal(raw, &wrapper); err != nil { + return info, err + } + + if rawTreeRoot, ok := wrapper["tree_root"]; ok { + if err := json.Unmarshal(rawTreeRoot, &info.TreeRoot); err != nil { + return HSDBlockchainInfo{}, err + } + return info, nil + } + + if rawTreeRoot, ok := wrapper["treeRoot"]; ok { + if err := json.Unmarshal(rawTreeRoot, &info.TreeRoot); err != nil { + return HSDBlockchainInfo{}, err + } + return info, nil + } + + return HSDBlockchainInfo{}, errors.New("unable to parse getblockchaininfo result") +} + +func basicAuthToken(username, password string) string { + return base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) +} diff --git a/hsd_test.go b/hsd_test.go new file mode 100644 index 0000000..0f6d2c1 --- /dev/null +++ b/hsd_test.go @@ -0,0 +1,173 @@ +package dns + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHSDClientGetNameResourceCallsRPCAndParsesResult(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + Params []any `json:"params"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + if payload.Method != "getnameresource" { + t.Fatalf("expected method getnameresource, got %s", payload.Method) + } + if len(payload.Params) != 1 || payload.Params[0] != "gateway.lthn" { + t.Fatalf("expected single name param gateway.lthn, got %#v", payload.Params) + } + + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + "aaaa": []string{"2600:1f1c:7f0:4f01::1"}, + "txt": []string{"v=lthn1 type=gateway"}, + "ns": []string{"ns.gateway.lthn"}, + }, + }) + })) + defer server.Close() + + client := NewHSDClient(HSDClientOptions{ + URL: server.URL, + }) + + record, err := client.GetNameResource(context.Background(), "gateway.lthn") + if err != nil { + t.Fatalf("unexpected getnameresource error: %v", err) + } + if len(record.A) != 1 || record.A[0] != "10.10.10.10" { + t.Fatalf("unexpected A result: %#v", record.A) + } + if len(record.AAAA) != 1 || record.AAAA[0] != "2600:1f1c:7f0:4f01::1" { + t.Fatalf("unexpected AAAA result: %#v", record.AAAA) + } +} + +func TestHSDClientGetNameResourceParsesWrappedRecords(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + if payload.Method != "getnameresource" { + t.Fatalf("expected method getnameresource, got %s", payload.Method) + } + + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "records": map[string]any{ + "a": []string{"10.11.11.11"}, + "txt": []string{"v=lthn1 type=node"}, + }, + }, + }) + })) + defer server.Close() + + client := NewHSDClient(HSDClientOptions{ + URL: server.URL, + }) + + record, err := client.GetNameResource(context.Background(), "node.lthn") + if err != nil { + t.Fatalf("unexpected getnameresource error: %v", err) + } + if len(record.A) != 1 || record.A[0] != "10.11.11.11" { + t.Fatalf("unexpected wrapped A result: %#v", record.A) + } + if len(record.TXT) != 1 || record.TXT[0] != "v=lthn1 type=node" { + t.Fatalf("unexpected wrapped TXT result: %#v", record.TXT) + } +} + +func TestHSDClientGetBlockchainInfo(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + if payload.Method != "getblockchaininfo" { + t.Fatalf("expected method getblockchaininfo, got %s", payload.Method) + } + + responseWriter.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "tree_root": "f00dc0de", + }, + }) + })) + defer server.Close() + + client := NewHSDClient(HSDClientOptions{ + URL: server.URL, + }) + + info, err := client.GetBlockchainInfo(context.Background()) + if err != nil { + t.Fatalf("unexpected getblockchaininfo error: %v", err) + } + if info.TreeRoot != "f00dc0de" { + t.Fatalf("unexpected tree root: %q", info.TreeRoot) + } +} + +func TestServiceDiscoverWithHSDRefreshesRecords(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + var payload struct { + Method string `json:"method"` + Params []any `json:"params"` + } + if err := json.NewDecoder(request.Body).Decode(&payload); err != nil { + t.Fatalf("unexpected request payload: %v", err) + } + switch payload.Params[0] { + case "gateway.lthn": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "a": []string{"10.10.10.10"}, + }, + }) + case "node.lthn": + _ = json.NewEncoder(responseWriter).Encode(map[string]any{ + "result": map[string]any{ + "aaaa": []string{"2600:1f1c:7f0:4f01::2"}, + }, + }) + default: + t.Fatalf("unexpected alias query: %#v", payload.Params) + responseWriter.WriteHeader(http.StatusBadRequest) + } + })) + defer server.Close() + + service := NewService(ServiceOptions{}) + client := NewHSDClient(HSDClientOptions{ + URL: server.URL, + }) + + if err := service.DiscoverWithHSD(context.Background(), []string{"gateway.lthn", "node.lthn"}, client); err != nil { + t.Fatalf("expected discovery via hsd to succeed: %v", err) + } + + if resolved, ok := service.Resolve("gateway.lthn"); !ok || len(resolved.A) != 1 || resolved.A[0] != "10.10.10.10" { + t.Fatalf("expected refreshed A record, got %#v (ok=%t)", resolved, ok) + } + if resolved, ok := service.Resolve("node.lthn"); !ok || len(resolved.AAAA) != 1 || resolved.AAAA[0] != "2600:1f1c:7f0:4f01::2" { + t.Fatalf("expected refreshed AAAA record, got %#v (ok=%t)", resolved, ok) + } +} diff --git a/service.go b/service.go index 0d41840..c5075b7 100644 --- a/service.go +++ b/service.go @@ -1,6 +1,7 @@ package dns import ( + "context" "crypto/sha256" "encoding/hex" "fmt" @@ -138,6 +139,37 @@ func (service *Service) ResolveTXT(name string) ([]string, bool) { return append([]string(nil), record.TXT...), true } +// DiscoverWithHSD refreshes DNS records for each alias by calling HSD. +// Example: +// +// err := service.DiscoverWithHSD(context.Background(), []string{"gateway.lthn"}, dns.NewHSDClient(dns.HSDClientOptions{ +// URL: "http://127.0.0.1:14037", +// Username: "user", +// Password: "pass", +// })) +func (service *Service) DiscoverWithHSD(ctx context.Context, aliases []string, client *HSDClient) error { + if client == nil { + return fmt.Errorf("hsd client is required") + } + + resolved := make(map[string]NameRecords, len(aliases)) + for _, alias := range aliases { + normalized := normalizeName(alias) + if normalized == "" { + continue + } + + record, err := client.GetNameResource(ctx, normalized) + if err != nil { + return err + } + resolved[normalized] = record + } + + service.replaceRecords(resolved) + return nil +} + func (service *Service) ResolveAddress(name string) (ResolveAddressResult, bool) { record, ok := service.findRecord(name) if !ok {