diff --git a/CLAUDE.md b/CLAUDE.md index 5f35b3c..b5fd999 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Project -`go-p2p` is the P2P networking layer for the Lethean network. Module path: `forge.lthn.ai/core/go-p2p` +`go-p2p` is the P2P networking layer for the Lethean network. Module path: `dappco.re/go/core/p2p` ## Prerequisites @@ -40,7 +40,7 @@ logging/ — Structured levelled logger with component scoping (stdlib only) ### Data flow -1. **Identity** (`identity.go`) — Ed25519 keypair via Borg STMF. Shared secrets derived via X25519 ECDH + SHA-256. +1. **Identity** (`identity.go`) — X25519 keypair via Borg STMF. Shared secrets derived via X25519 ECDH + SHA-256. 2. **Transport** (`transport.go`) — WebSocket server/client (gorilla/websocket). Handshake exchanges `NodeIdentity` + HMAC-SHA256 challenge-response. Post-handshake messages are Borg SMSG-encrypted. Includes deduplication (5-min TTL), rate limiting (token bucket: 100 burst/50 per sec), and MaxConns enforcement. 3. **Dispatcher** (`dispatcher.go`) — Routes verified UEPS packets to intent handlers. Threat circuit breaker drops packets with `ThreatScore > 50,000` before routing. 4. **Controller** (`controller.go`) — Issues requests to remote peers using a pending-map pattern (`map[string]chan *Message`). Auto-connects to peers on demand. 
@@ -75,13 +75,13 @@ type ProfileManager interface { - UK English (colour, organisation, centre, behaviour, recognise) - All parameters and return types explicitly annotated -- Tests use `testify` assert/require; table-driven subtests with `t.Run()` +- Tests use `testify` assert/require; prefer table-driven subtests with `t.Run()` when multiple related cases share one shape - Test name suffixes: `_Good` (happy path), `_Bad` (expected errors), `_Ugly` (panic/edge cases) - Licence: EUPL-1.2 — new files need `// SPDX-License-Identifier: EUPL-1.2` - Security-first: do not weaken HMAC, challenge-response, Zip Slip defence, or rate limiting - Use `logging` package only — no `fmt.Println` or `log.Printf` in library code -- Error handling: use `coreerr.E()` from `go-log` — never `fmt.Errorf` or `errors.New` in library code -- File I/O: use `coreio.Local` from `go-io` — never `os.ReadFile`/`os.WriteFile` in library code (exception: `os.OpenFile` for streaming writes where `coreio` lacks support) +- Error handling: use `core.E()` from `dappco.re/go/core` — never `fmt.Errorf` or `errors.New` in library code +- File I/O: use `dappco.re/go/core` filesystem helpers (package-level adapters in `node/` backed by `core.Fs`) — never `os.ReadFile`/`os.WriteFile` in library code (exception: `os.OpenFile` for streaming writes where filesystem helpers cannot preserve tar header mode bits) - Hot-path debug logging uses sampling pattern: `if counter.Add(1)%interval == 0` ### Transport test helper diff --git a/CODEX.md b/CODEX.md new file mode 100644 index 0000000..3ba5218 --- /dev/null +++ b/CODEX.md @@ -0,0 +1,11 @@ + + +# CODEX.md + +Codex-compatible entrypoint for this repository. + +- Treat `CLAUDE.md` as the authoritative local conventions file for commands, architecture notes, coding standards, and commit format. +- Current module path: `dappco.re/go/core/p2p`. +- Verification baseline: `go build ./...`, `go vet ./...`, and `go test ./...`. 
+- Use conventional commits with `Co-Authored-By: Virgil `. +- If `.core/reference/docs/RFC.md` is absent in the checkout, report that gap explicitly and use the local docs under `docs/` plus the code as the available reference set. diff --git a/README.md b/README.md index a11c72f..c7da1ca 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,52 @@ -[![Go Reference](https://pkg.go.dev/badge/forge.lthn.ai/core/go-p2p.svg)](https://pkg.go.dev/forge.lthn.ai/core/go-p2p) -[![License: EUPL-1.2](https://img.shields.io/badge/License-EUPL--1.2-blue.svg)](LICENSE.md) +[![Go Reference](https://pkg.go.dev/badge/dappco.re/go/core/p2p.svg)](https://pkg.go.dev/dappco.re/go/core/p2p) +[![License: EUPL-1.2](https://img.shields.io/badge/License-EUPL--1.2-blue.svg)](CONTRIBUTING.md#license) [![Go Version](https://img.shields.io/badge/Go-1.26-00ADD8?style=flat&logo=go)](go.mod) # go-p2p -P2P mesh networking layer for the Lethean network. Provides Ed25519 node identity, an encrypted WebSocket transport with HMAC-SHA256 challenge-response handshake, KD-tree peer selection across four dimensions (latency, hops, geography, reliability score), UEPS wire protocol (RFC-021) TLV packet builder and reader, UEPS intent routing with a threat circuit breaker, and TIM deployment bundle encryption with Zip Slip and decompression-bomb defences. +P2P mesh networking layer for the Lethean network. Provides X25519 node identity, an encrypted WebSocket transport with HMAC-SHA256 challenge-response handshake, KD-tree peer selection across four dimensions (latency, hops, geography, reliability score), UEPS wire protocol (RFC-021) TLV packet builder and reader, UEPS intent routing with a threat circuit breaker, and TIM deployment bundle encryption with Zip Slip and decompression-bomb defences. 
-**Module**: `forge.lthn.ai/core/go-p2p` +**Module**: `dappco.re/go/core/p2p` **Licence**: EUPL-1.2 -**Language**: Go 1.25 +**Language**: Go 1.26 ## Quick Start ```go import ( - "forge.lthn.ai/core/go-p2p/node" - "forge.lthn.ai/core/go-p2p/ueps" + "log" + + "dappco.re/go/core/p2p/node" + "dappco.re/go/core/p2p/ueps" ) -// Start a P2P node -identity, _ := node.LoadOrCreateIdentity() -transport := node.NewTransport(identity, node.TransportConfig{ListenAddr: ":9091"}) -transport.Start(ctx) +nm, err := node.NewNodeManager() +if err != nil { + log.Fatal(err) +} +if !nm.HasIdentity() { + if err := nm.GenerateIdentity("worker-1", node.RoleWorker); err != nil { + log.Fatal(err) + } +} -// Build a UEPS packet -pkt, _ := ueps.NewBuilder(ueps.IntentCompute, payload).MarshalAndSign(sharedSecret) +registry, err := node.NewPeerRegistry() +if err != nil { + log.Fatal(err) +} + +transport := node.NewTransport(nm, registry, node.DefaultTransportConfig()) +if err := transport.Start(); err != nil { + log.Fatal(err) +} + +payload := []byte(`{"job":"hashrate"}`) +sharedSecret := make([]byte, 32) +pkt, err := ueps.NewBuilder(node.IntentCompute, payload).MarshalAndSign(sharedSecret) +if err != nil { + log.Fatal(err) +} +_ = pkt ``` ## Documentation @@ -44,4 +66,4 @@ go build ./... ## Licence -European Union Public Licence 1.2 — see [LICENCE](LICENCE) for details. +European Union Public Licence 1.2 — see [CONTRIBUTING](CONTRIBUTING.md#license) for details. 
diff --git a/SESSION-BRIEF.md b/SESSION-BRIEF.md index 03d6ffb..0029f62 100644 --- a/SESSION-BRIEF.md +++ b/SESSION-BRIEF.md @@ -1,129 +1,57 @@ # Session Brief: core/go-p2p -**Repo**: `forge.lthn.ai/core/go-p2p` (clone at `/tmp/core-go-p2p`) -**Module**: `forge.lthn.ai/core/go-p2p` -**Status**: 16 Go files, ~2,500 LOC, node tests PASS (42% coverage), ueps has NO TESTS -**Wiki**: https://forge.lthn.ai/core/go-p2p/wiki (6 pages) +**Repo**: `forge.lthn.ai/core/go-p2p` +**Module**: `dappco.re/go/core/p2p` +**Status**: `go build ./...`, `go vet ./...`, and `go test ./...` pass on 2026-03-27. +**Primary references**: `CLAUDE.md`, `docs/architecture.md`, `docs/development.md` ## What This Is -P2P networking layer for the Lethean network. Three packages: +P2P networking layer for the Lethean network. The repository currently consists of four Go packages: -### node/ — P2P Mesh (14 files) -- **Identity**: Ed25519 keypair generation, PEM serialisation, challenge-response auth -- **Transport**: Encrypted WebSocket connections via gorilla/websocket + Borg (encrypted blob storage) -- **Peers**: Registry with scoring, persistence, auth modes (open/allowlist), name validation -- **Messages**: Typed protocol messages (handshake, ping, stats, miner control, deploy, logs) -- **Protocol**: Response handler with validation and typed parsing -- **Worker**: Command handler (ping, stats, miner start/stop, deploy profiles, get logs) -- **Dispatcher**: UEPS packet routing skeleton with threat circuit breaker -- **Controller**: Remote node operations (connect, command, disconnect) -- **Bundle**: Service factory for Core framework DI registration - -### ueps/ — Wire Protocol (2 files, NO TESTS) -- **PacketBuilder**: Constructs signed UEPS frames with TLV encoding -- **ReadAndVerify**: Parses and verifies HMAC-SHA256 integrity -- TLV tags: 0x01-0x05 (header fields), 0x06 (HMAC), 0xFF (payload marker) -- Header: Version, CurrentLayer, TargetLayer, IntentID, ThreatScore - -### logging/ — 
Structured Logger (1 file) -- Simple levelled logger (INFO/WARN/ERROR/DEBUG) with key-value pairs +- `node/` — P2P mesh: identity, transport, peer registry, messages, protocol helpers, worker/controller logic, dispatcher, and deployment bundles +- `node/levin/` — standalone CryptoNote Levin binary protocol support +- `ueps/` — UEPS TLV wire protocol with HMAC-SHA256 integrity verification +- `logging/` — structured levelled logger with component scoping ## Current State | Area | Status | |------|--------| -| node/ tests | PASS — 42% statement coverage | -| ueps/ tests | NONE — zero test files | -| logging/ tests | NONE | -| go vet | Clean | -| TODOs/FIXMEs | None found | -| Identity (Ed25519) | Well tested — keypair, challenge-response, deterministic sigs | -| PeerRegistry | Well tested — add/remove, scoring, persistence, auth modes, name validation | -| Messages | Well tested — all 15 message types, serialisation, error codes | -| Worker | Well tested — ping, stats, miner, deploy, logs handlers | -| Transport | NOT tested — WebSocket + Borg encryption | -| Controller | NOT tested — remote node operations | -| Dispatcher | NOT tested — UEPS routing skeleton | +| Build | PASS | +| Vet | PASS | +| Tests | PASS | +| `logging/` | Has direct unit coverage | +| `ueps/` | Has round-trip, malformed packet, and coverage-path tests | +| `node/transport` | Has real WebSocket handshake and integration tests | +| `node/controller` | Has request/response, auto-connect, ping, and miner-control tests | +| `node/dispatcher` | Has routing, threshold, and concurrency tests | +| `node/levin` | Has protocol encode/decode coverage | + +## Key Behaviours + +- **Identity** — X25519 keypair generation via Borg STMF, persisted through XDG paths +- **Transport** — WebSocket mesh with challenge-response authentication, SMSG encryption, deduplication, rate limiting, and keepalive handling +- **Peer registry** — KD-tree selection across latency, hops, geography, and reliability score +- 
**Controller/worker** — request/response messaging for stats, miner control, logs, and deployment +- **Dispatcher** — UEPS intent routing with a threat circuit breaker at `ThreatScore > 50000` +- **Bundles** — TIM-based profile and miner bundle handling with defensive tar extraction ## Dependencies -- `github.com/Snider/Borg` v0.2.0 (encrypted blob storage) -- `github.com/Snider/Enchantrix` v0.0.2 (secure environment) -- `github.com/Snider/Poindexter` (secure pointer) -- `github.com/gorilla/websocket` v1.5.3 -- `github.com/google/uuid` v1.6.0 -- `github.com/ProtonMail/go-crypto` v1.3.0 +- `dappco.re/go/core` v0.8.0-alpha.1 +- `forge.lthn.ai/Snider/Borg` v0.3.1 +- `forge.lthn.ai/Snider/Poindexter` v0.0.3 - `github.com/adrg/xdg` v0.5.3 +- `github.com/google/uuid` v1.6.0 +- `github.com/gorilla/websocket` v1.5.3 - `github.com/stretchr/testify` v1.11.1 -- `golang.org/x/crypto` v0.45.0 - -## Priority Work - -### High (coverage gaps) -1. **UEPS tests** — Zero tests for the wire protocol. This is the consent-gated TLV protocol from RFC-021. Need: builder round-trip, HMAC verification, malformed packet rejection, boundary conditions (max ThreatScore, empty payload, oversized payload). -2. **Transport tests** — WebSocket connection, Borg encryption handshake, reconnection logic. -3. **Controller tests** — Connect/command/disconnect flow. -4. **Dispatcher tests** — UEPS routing, threat circuit breaker (ThreatScore > 50000 drops). - -### Medium (hardening) -5. **Increase node/ coverage** from 42% to 70%+ — focus on transport.go, controller.go, dispatcher.go -6. **Benchmarks** — Peer scoring, UEPS marshal/unmarshal, identity key generation -7. **Integration test** — Full node-to-node handshake over localhost WebSocket - -### Low (completeness) -8. **Logging tests** — Simple but should have coverage -9. **Peer discovery** — Currently manual. Add mDNS or DHT discovery -10. 
**Connection pooling** — Transport creates fresh connections; add pool for controller - -## File Map - -``` -/tmp/core-go-p2p/ -├── node/ -│ ├── bundle.go + bundle_test.go — Core DI factory -│ ├── identity.go + identity_test.go — Ed25519 keypair, PEM, challenge-response -│ ├── message.go + message_test.go — Protocol message types -│ ├── peer.go + peer_test.go — Registry, scoring, auth -│ ├── protocol.go + protocol_test.go — Response validation, typed parsing -│ ├── worker.go + worker_test.go — Command handlers -│ ├── transport.go (NO TEST) — WebSocket + Borg encryption -│ ├── controller.go (NO TEST) — Remote node operations -│ ├── dispatcher.go (NO TEST) — UEPS routing skeleton -│ └── logging.go — Package-level logger setup -├── ueps/ -│ ├── ueps.go (NO TEST) — PacketBuilder, ReadAndVerify, TLV -│ └── types.go (NO TEST) — UEPSHeader, ParsedPacket, intent IDs -├── logging/ -│ └── logger.go (NO TEST) — Levelled structured logger -├── go.mod -└── go.sum -``` - -## Key Interfaces - -```go -// node/message.go — 15 message types -const ( - MsgHandshake MsgHandshakeAck MsgPing MsgPong - MsgDisconnect MsgGetStats MsgStats MsgStartMiner - MsgStopMiner MsgMinerAck MsgDeploy MsgDeployAck - MsgGetLogs MsgLogs MsgError -) - -// ueps/types.go — UEPS header -type UEPSHeader struct { - Version uint8 // 0x09 - CurrentLayer uint8 - TargetLayer uint8 - IntentID uint8 // 0x01=Handshake, 0x20=Compute, 0x30=Rehab, 0xFF=Extended - ThreatScore uint16 -} -``` ## Conventions -- UK English -- Tests: testify assert/require -- Licence: EUPL-1.2 -- Lethean codenames: Borg (Secure/Blob), Poindexter (Secure/Pointer), Enchantrix (Secure/Environment) +- UK English in comments, logs, and docs +- `core.E()` for library error wrapping and sentinel definitions +- `core.Fs` adapters for library file I/O in `node/` +- `testify` in tests; prefer `t.Run()` tables for related cases +- EUPL-1.2 SPDX identifiers on new files +- Conventional commits with `Co-Authored-By: Virgil ` diff --git 
a/docs/RFC-CORE-008-AGENT-EXPERIENCE.md b/docs/RFC-CORE-008-AGENT-EXPERIENCE.md new file mode 100644 index 0000000..3763521 --- /dev/null +++ b/docs/RFC-CORE-008-AGENT-EXPERIENCE.md @@ -0,0 +1,440 @@ +# RFC-CORE-008: Agent Experience (AX) Design Principles + +- **Status:** Draft +- **Authors:** Snider, Cladius +- **Date:** 2026-03-19 +- **Applies to:** All Core ecosystem packages (CoreGO, CorePHP, CoreTS, core-agent) + +## Abstract + +Agent Experience (AX) is a design paradigm for software systems where the primary code consumer is an AI agent, not a human developer. AX sits alongside User Experience (UX) and Developer Experience (DX) as the third era of interface design. + +This RFC establishes AX as a formal design principle for the Core ecosystem and defines the conventions that follow from it. + +## Motivation + +As of early 2026, AI agents write, review, and maintain the majority of code in the Core ecosystem. The original author has not manually edited code (outside of Core struct design) since October 2025. Code is processed semantically — agents reason about intent, not characters. + +Design patterns inherited from the human-developer era optimise for the wrong consumer: + +- **Short names** save keystrokes but increase semantic ambiguity +- **Functional option chains** are fluent for humans but opaque for agents tracing configuration +- **Error-at-every-call-site** produces 50% boilerplate that obscures intent +- **Generic type parameters** force agents to carry type context that the runtime already has +- **Panic-hiding conventions** (`Must*`) create implicit control flow that agents must special-case + +AX acknowledges this shift and provides principles for designing code, APIs, file structures, and conventions that serve AI agents as first-class consumers. 
+ +## The Three Eras + +| Era | Primary Consumer | Optimises For | Key Metric | +|-----|-----------------|---------------|------------| +| UX | End users | Discoverability, forgiveness, visual clarity | Task completion time | +| DX | Developers | Typing speed, IDE support, convention familiarity | Time to first commit | +| AX | AI agents | Predictability, composability, semantic navigation | Correct-on-first-pass rate | + +AX does not replace UX or DX. End users still need good UX. Developers still need good DX. But when the primary code author and maintainer is an AI agent, the codebase should be designed for that consumer first. + +## Principles + +### 1. Predictable Names Over Short Names + +Names are tokens that agents pattern-match across languages and contexts. Abbreviations introduce mapping overhead. + +``` +Config not Cfg +Service not Srv +Embed not Emb +Error not Err (as a subsystem name; err for local variables is fine) +Options not Opts +``` + +**Rule:** If a name would require a comment to explain, it is too short. + +**Exception:** Industry-standard abbreviations that are universally understood (`HTTP`, `URL`, `ID`, `IPC`, `I18n`) are acceptable. The test: would an agent trained on any mainstream language recognise it without context? + +### 2. Comments as Usage Examples + +The function signature tells WHAT. The comment shows HOW with real values. + +```go +// Detect the project type from files present +setup.Detect("/path/to/project") + +// Set up a workspace with auto-detected template +setup.Run(setup.Options{Path: ".", Template: "auto"}) + +// Scaffold a PHP module workspace +setup.Run(setup.Options{Path: "./my-module", Template: "php"}) +``` + +**Rule:** If a comment restates what the type signature already says, delete it. If a comment shows a concrete usage with realistic values, keep it. + +**Rationale:** Agents learn from examples more effectively than from descriptions. A comment like "Run executes the setup process" adds zero information. 
A comment like `setup.Run(setup.Options{Path: ".", Template: "auto"})` teaches an agent exactly how to call the function. + +### 3. Path Is Documentation + +File and directory paths should be self-describing. An agent navigating the filesystem should understand what it is looking at without reading a README. + +``` +flow/deploy/to/homelab.yaml — deploy TO the homelab +flow/deploy/from/github.yaml — deploy FROM GitHub +flow/code/review.yaml — code review flow +template/file/go/struct.go.tmpl — Go struct file template +template/dir/workspace/php/ — PHP workspace scaffold +``` + +**Rule:** If an agent needs to read a file to understand what a directory contains, the directory naming has failed. + +**Corollary:** The unified path convention (folder structure = HTTP route = CLI command = test path) is AX-native. One path, every surface. + +### 4. Templates Over Freeform + +When an agent generates code from a template, the output is constrained to known-good shapes. When an agent writes freeform, the output varies. + +```go +// Template-driven — consistent output +lib.RenderFile("php/action", data) +lib.ExtractDir("php", targetDir, data) + +// Freeform — variance in output +"write a PHP action class that..." +``` + +**Rule:** For any code pattern that recurs, provide a template. Templates are guardrails for agents. + +**Scope:** Templates apply to file generation, workspace scaffolding, config generation, and commit messages. They do NOT apply to novel logic — agents should write business logic freeform with the domain knowledge available. + +### 5. Declarative Over Imperative + +Agents reason better about declarations of intent than sequences of operations. 
+ +```yaml +# Declarative — agent sees what should happen +steps: + - name: build + flow: tools/docker-build + with: + context: "{{ .app_dir }}" + image_name: "{{ .image_name }}" + + - name: deploy + flow: deploy/with/docker + with: + host: "{{ .host }}" +``` + +```go +// Imperative — agent must trace execution +cmd := exec.Command("docker", "build", "--platform", "linux/amd64", "-t", imageName, ".") +cmd.Dir = appDir +if err := cmd.Run(); err != nil { + return fmt.Errorf("docker build: %w", err) +} +``` + +**Rule:** Orchestration, configuration, and pipeline logic should be declarative (YAML/JSON). Implementation logic should be imperative (Go/PHP/TS). The boundary is: if an agent needs to compose or modify the logic, make it declarative. + +### 6. Universal Types (Core Primitives) + +Every component in the ecosystem accepts and returns the same primitive types. An agent processing any level of the tree sees identical shapes. + +```go +// Universal contract +setup.Run(core.Options{Path: ".", Template: "auto"}) +brain.New(core.Options{Name: "openbrain"}) +deploy.Run(core.Options{Flow: "deploy/to/homelab"}) + +// Fractal — Core itself is a Service +core.New(core.Options{ + Services: []core.Service{ + process.New(core.Options{Name: "process"}), + brain.New(core.Options{Name: "brain"}), + }, +}) +``` + +**Core primitive types:** + +| Type | Purpose | +|------|---------| +| `core.Options` | Input configuration (what you want) | +| `core.Config` | Runtime settings (what is active) | +| `core.Data` | Embedded or stored content | +| `core.Service` | A managed component with lifecycle | +| `core.Result[T]` | Return value with OK/fail state | + +**What this replaces:** + +| Go Convention | Core AX | Why | +|--------------|---------|-----| +| `func With*(v) Option` | `core.Options{Field: v}` | Struct literal is parseable; option chain requires tracing | +| `func Must*(v) T` | `core.Result[T]` | No hidden panics; errors flow through Core | +| `func *For[T](c) T` | 
`c.Service("name")` | String lookup is greppable; generics require type context | +| `val, err :=` everywhere | Single return via `core.Result` | Intent not obscured by error handling | +| `_ = err` | Never needed | Core handles all errors internally | + +### 7. Directory as Semantics + +The directory structure tells an agent the intent before it reads a word. Top-level directories are semantic categories, not organisational bins. + +``` +plans/ +├── code/ # Pure primitives — read for WHAT exists +├── project/ # Products — read for WHAT we're building and WHY +└── rfc/ # Contracts — read for constraints and rules +``` + +**Rule:** An agent should know what kind of document it's reading from the path alone. `code/core/go/io/RFC.md` = a lib primitive spec. `project/ofm/RFC.md` = a product spec that cross-references code/. `rfc/snider/borg/RFC-BORG-006-SMSG-FORMAT.md` = an immutable contract for the Borg SMSG protocol. + +**Corollary:** The three-way split (code/project/rfc) extends principle 3 (Path Is Documentation) from files to entire subtrees. The path IS the metadata. + +### 8. Lib Never Imports Consumer + +Dependency flows one direction. Libraries define primitives. Consumers compose from them. A new feature in a consumer can never break a library. + +``` +code/core/go/* → lib tier (stable foundation) +code/core/agent/ → consumer tier (composes from go/*) +code/core/cli/ → consumer tier (composes from go/*) +code/core/gui/ → consumer tier (composes from go/*) +``` + +**Rule:** If package A is in `go/` and package B is in the consumer tier, B may import A but A must never import B. The repo naming convention enforces this: `go-{name}` = lib, bare `{name}` = consumer. + +**Why this matters for agents:** When an agent is dispatched to implement a feature in `core/agent`, it can freely import from `go-io`, `go-scm`, `go-process`. 
But if an agent is dispatched to `go-io`, it knows its changes are foundational — every consumer depends on it, so the contract must not break. + +### 9. Issues Are N+(rounds) Deep + +Problems in code and specs are layered. Surface issues mask deeper issues. Fixing the surface reveals the next layer. This is not a failure mode — it is the discovery process. + +``` +Pass 1: Find 16 issues (surface — naming, imports, obvious errors) +Pass 2: Find 11 issues (structural — contradictions, missing types) +Pass 3: Find 5 issues (architectural — signature mismatches, registration gaps) +Pass 4: Find 4 issues (contract — cross-spec API mismatches) +Pass 5: Find 2 issues (mechanical — path format, nil safety) +Pass N: Findings are trivial → spec/code is complete +``` + +**Rule:** Iteration is required, not a failure. Each pass sees what the previous pass could not, because the context changed. An agent dispatched with the same task on the same repo will find different things each time — this is correct behaviour. + +**Corollary:** The cheapest model should do the most passes (surface work). The frontier model should arrive last, when only deep issues remain. Tiered iteration: grunt model grinds → mid model pre-warms → frontier model polishes. + +**Anti-pattern:** One-shot generation expecting valid output. No model, no human, produces correct-on-first-pass for non-trivial work. Expecting it wastes the first pass on surface issues that a cheaper pass would have caught. + +### 10. CLI Tests as Artifact Validation + +Unit tests verify the code. CLI tests verify the binary. The directory structure IS the command structure — path maps to command, Taskfile runs the test. 
+ +``` +tests/cli/ +├── core/ +│ └── lint/ +│ ├── Taskfile.yaml ← test `core-lint` (root) +│ ├── run/ +│ │ ├── Taskfile.yaml ← test `core-lint run` +│ │ └── fixtures/ +│ ├── go/ +│ │ ├── Taskfile.yaml ← test `core-lint go` +│ │ └── fixtures/ +│ └── security/ +│ ├── Taskfile.yaml ← test `core-lint security` +│ └── fixtures/ +``` + +**Rule:** Every CLI command has a matching `tests/cli/{path}/Taskfile.yaml`. The Taskfile runs the compiled binary against fixtures with known inputs and validates the output. If the CLI test passes, the underlying actions work — because CLI commands call actions, MCP tools call actions, API endpoints call actions. Test the CLI, trust the rest. + +**Pattern:** + +```yaml +# tests/cli/core/lint/go/Taskfile.yaml +version: '3' +tasks: + test: + cmds: + - core-lint go --output json fixtures/ > /tmp/result.json + - jq -e '.findings | length > 0' /tmp/result.json + - jq -e '.summary.passed == false' /tmp/result.json +``` + +**Why this matters for agents:** An agent can validate its own work by running `task test` in the matching `tests/cli/` directory. No test framework, no mocking, no setup — just the binary, fixtures, and `jq` assertions. The agent builds the binary, runs the test, sees the result. If it fails, the agent can read the fixture, read the output, and fix the code. + +**Corollary:** Fixtures are planted bugs. Each fixture file has a known issue that the linter must find. If the linter doesn't find it, the test fails. Fixtures are the spec for what the tool must detect — they ARE the test cases, not descriptions of test cases. 
+ +## Applying AX to Existing Patterns + +### File Structure + +``` +# AX-native: path describes content +core/agent/ +├── go/ # Go source +├── php/ # PHP source +├── ui/ # Frontend source +├── claude/ # Claude Code plugin +└── codex/ # Codex plugin + +# Not AX: generic names requiring README +src/ +├── lib/ +├── utils/ +└── helpers/ +``` + +### Error Handling + +```go +// AX-native: errors are infrastructure, not application logic +svc := c.Service("brain") +cfg := c.Config().Get("database.host") +// Errors logged by Core. Code reads like a spec. + +// Not AX: errors dominate the code +svc, err := core.ServiceFor[brain.Service](c) +if err != nil { + return fmt.Errorf("get brain service: %w", err) +} +cfg, err := c.Config().Get("database.host") +if err != nil { + _ = err // silenced because "it'll be fine" +} +``` + +### API Design + +```go +// AX-native: one shape, every surface +core.New(core.Options{ + Name: "my-app", + Services: []core.Service{...}, + Config: core.Config{...}, +}) + +// Not AX: multiple patterns for the same thing +core.New( + core.WithName("my-app"), + core.WithService(factory1), + core.WithService(factory2), + core.WithConfig(cfg), +) +``` + +## The Plans Convention — AX Development Lifecycle + +The `plans/` directory structure encodes a development methodology designed for how generative AI actually works: iterative refinement across structured phases, not one-shot generation. + +### The Three-Way Split + +``` +plans/ +├── project/ # 1. WHAT and WHY — start here +├── rfc/ # 2. CONSTRAINTS — immutable contracts +└── code/ # 3. HOW — implementation specs +``` + +Each directory is a phase. Work flows from project → rfc → code. Each transition forces a refinement pass — you cannot write a code spec without discovering gaps in the project spec, and you cannot write an RFC without discovering assumptions in both. 
+ +**Three places for data that can't be written simultaneously = three guaranteed iterations of "actually, this needs changing."** Refinement is baked into the structure, not bolted on as a review step. + +### Phase 1: Project (Vision) + +Start with `project/`. No code exists yet. Define: +- What the product IS and who it serves +- What existing primitives it consumes (cross-ref to `code/`) +- What constraints it operates under (cross-ref to `rfc/`) + +This is where creativity lives. Map features to building blocks. Connect systems. The project spec is integrative — it references everything else. + +### Phase 2: RFC (Contracts) + +Extract the immutable rules into `rfc/`. These are constraints that don't change with implementation: +- Wire formats, protocols, hash algorithms +- Security properties that must hold +- Compatibility guarantees + +RFCs are numbered per component (`RFC-BORG-006-SMSG-FORMAT.md`) and never modified after acceptance. If the contract changes, write a new RFC. + +### Phase 3: Code (Implementation Specs) + +Define the implementation in `code/`. Each component gets an RFC.md that an agent can implement from: +- Struct definitions (the DTOs — see principle 6) +- Method signatures and behaviour +- Error conditions and edge cases +- Cross-references to other code/ specs + +The code spec IS the product. Write the spec → dispatch to an agent → review output → iterate. + +### Pre-Launch: Alignment Protocol + +Before dispatching for implementation, verify spec-model alignment: + +``` +1. REVIEW — The implementation model (Codex/Jules) reads the spec + and reports missing elements. This surfaces the delta between + the model's training and the spec's assumptions. + + "I need X, Y, Z to implement this" is the model saying + "I hear you but I'm missing context" — without asking. + +2. ADJUST — Update the spec to close the gaps. Add examples, + clarify ambiguities, provide the context the model needs. + This is shared alignment, not compromise. + +3. 
VERIFY — A different model (or sub-agent) reviews the adjusted + spec without the planner's bias. Fresh eyes on the contract. + "Does this make sense to someone who wasn't in the room?" + +4. READY — When the review findings are trivial or deployment- + related (not architectural), the spec is ready to dispatch. +``` + +### Implementation: Iterative Dispatch + +Same prompt, multiple runs. Each pass sees deeper because the context evolved: + +``` +Round 1: Build features (the obvious gaps) +Round 2: Write tests (verify what was built) +Round 3: Harden security (what can go wrong?) +Round 4: Next RFC section (what's still missing?) +Round N: Findings are trivial → implementation is complete +``` + +Re-running is not failure. It is the process. Each pass changes the codebase, which changes what the next pass can see. The iteration IS the refinement. + +### Post-Implementation: Auto-Documentation + +The QA/verify chain produces artefacts that feed forward: +- Test results document the contract (what works, what doesn't) +- Coverage reports surface untested paths +- Diff summaries prep the changelog for the next release +- Doc site updates from the spec (the spec IS the documentation) + +The output of one cycle is the input to the next. The plans repo stays current because the specs drive the code, not the other way round. + +## Compatibility + +AX conventions are valid, idiomatic Go/PHP/TS. They do not require language extensions, code generation, or non-standard tooling. An AX-designed codebase compiles, tests, and deploys with standard toolchains. + +The conventions diverge from community patterns (functional options, Must/For, etc.) but do not violate language specifications. This is a style choice, not a fork. + +## Adoption + +AX applies to all new code in the Core ecosystem. Existing code migrates incrementally as it is touched — no big-bang rewrite. + +Priority order: +1. **Public APIs** (package-level functions, struct constructors) +2. 
**File structure** (path naming, template locations) +3. **Internal fields** (struct field names, local variables) + +## References + +- dAppServer unified path convention (2024) +- CoreGO DTO pattern refactor (2026-03-18) +- Core primitives design (2026-03-19) +- Go Proverbs, Rob Pike (2015) — AX provides an updated lens + +## Changelog + +- 2026-03-19: Initial draft diff --git a/docs/architecture.md b/docs/architecture.md index 2915608..dc204a1 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -1,6 +1,6 @@ # Architecture — go-p2p -`go-p2p` is the P2P networking layer for the Lethean network. Module path: `forge.lthn.ai/core/go-p2p`. +`go-p2p` is the P2P networking layer for the Lethean network. Module path: `dappco.re/go/core/p2p`. ## Package Structure @@ -17,7 +17,7 @@ go-p2p/ ### identity.go — Node Identity -Each node holds an Ed25519 keypair generated via Borg STMF (X25519 curve). The private key is stored at `~/.local/share/lethean-desktop/node/private.key` (mode 0600) and the public identity JSON at `~/.config/lethean-desktop/node.json`. +Each node holds an X25519 keypair generated via Borg STMF. The private key is stored at `~/.local/share/lethean-desktop/node/private.key` (mode 0600) and the public identity JSON at `~/.config/lethean-desktop/node.json`. 
`NodeIdentity` carries: - `ID` — 32-character hex string derived from SHA-256 of the public key (first 16 bytes) @@ -36,9 +36,9 @@ The `Transport` manages a WebSocket server (gorilla/websocket) and outbound conn | Field | Default | Purpose | |-------|---------|---------| -| `ListenAddr` | `:9091` | HTTP bind address | -| `WSPath` | `/ws` | WebSocket endpoint | -| `MaxConns` | 100 | Maximum concurrent connections | +| `ListenAddress` | `:9091` | HTTP bind address | +| `WebSocketPath` | `/ws` | WebSocket endpoint | +| `MaxConnections` | 100 | Maximum concurrent connections | | `MaxMessageSize` | 1 MB | Read limit per message | | `PingInterval` | 30 s | Keepalive ping period | | `PongTimeout` | 10 s | Maximum time to wait for pong | @@ -56,11 +56,11 @@ The `Transport` manages a WebSocket server (gorilla/websocket) and outbound conn **Rate limiting**: Each `PeerConnection` holds a `PeerRateLimiter` (token bucket: 100 burst, 50 tokens/second refill). Messages from rate-limited peers are dropped in the read loop. -**MaxConns enforcement**: The handler tracks `pendingConns` (atomic counter) during the handshake phase in addition to established connections, preventing races where a surge of simultaneous inbounds could exceed the limit. +**MaxConnections enforcement**: The handler tracks `pendingHandshakeCount` (atomic counter) during the handshake phase in addition to established connections, preventing races where a surge of simultaneous inbounds could exceed the limit. **Keepalive**: A goroutine per connection ticks at `PingInterval`. If `LastActivity` has not been updated within `PingInterval + PongTimeout`, the connection is removed. -**Graceful close**: `GracefulClose` sends `MsgDisconnect` before closing the underlying WebSocket. Write deadlines are managed exclusively inside `Send()` under `writeMu` to prevent the race (P2P-RACE-1) where a bare `SetWriteDeadline` call could race with concurrent sends. 
+**Graceful close**: `GracefulClose` sends `MsgDisconnect` before closing the underlying WebSocket. Write deadlines are managed exclusively inside `Send()` under `writeMutex` to prevent the race (P2P-RACE-1) where a bare `SetWriteDeadline` call could race with concurrent sends. **Buffer pool**: `MarshalJSON` uses a `sync.Pool` of `bytes.Buffer` (initial capacity 1 KB, maximum pooled size 64 KB) to reduce allocation pressure in the message serialisation hot path. HTML escaping is disabled to match `json.Marshal` semantics. @@ -70,20 +70,20 @@ The `Transport` manages a WebSocket server (gorilla/websocket) and outbound conn **Peer fields persisted**: - `ID`, `Name`, `PublicKey`, `Address`, `Role`, `AddedAt`, `LastSeen` -- `PingMS`, `Hops`, `GeoKM`, `Score` (float64, 0–100) +- `PingMilliseconds`, `Hops`, `GeographicKilometres`, `Score` (float64, 0–100) **KD-tree dimensions** (lower is better in all axes): | Dimension | Weight | Rationale | |-----------|--------|-----------| -| `PingMS` | 1.0 | Latency dominates interactive performance | +| `PingMilliseconds` | 1.0 | Latency dominates interactive performance | | `Hops` | 0.7 | Network hop count (routing cost) | -| `GeoKM` | 0.2 | Geographic distance (minor factor) | +| `GeographicKilometres` | 0.2 | Geographic distance (minor factor) | | `100 - Score` | 1.2 | Reliability (inverted so lower = better peer) | `SelectOptimalPeer()` queries the tree for the point nearest to the origin (ideal: zero latency, zero hops, zero distance, maximum score). `SelectNearestPeers(n)` returns the n best. -**Persistence**: Writes are debounced with a 5-second coalesce window (`scheduleSave`). The actual write uses an atomic rename pattern (write to `.tmp`, then `os.Rename`) to prevent partial file corruption. `Close()` flushes any pending dirty state synchronously. +**Persistence**: Writes are debounced with a 5-second coalesce window (`scheduleSave`). 
The actual write uses an atomic rename pattern (write to `.tmp`, then rename) to prevent partial file corruption. `Close()` flushes any pending dirty state synchronously. **Auth modes**: - `PeerAuthOpen` — any connecting peer is accepted (default). @@ -179,7 +179,7 @@ Auto-connect: if the target peer is not yet connected, `sendRequest` calls `tran |----------|-------|---------| | `IntentHandshake` | `0x01` | Connection establishment | | `IntentCompute` | `0x20` | Compute job request | -| `IntentRehab` | `0x30` | Benevolent intervention (pause execution) | +| `IntentPauseExecution` | `0x30` | Benevolent intervention (pause execution) | | `IntentCustom` | `0xFF` | Application-level sub-protocols | **Sentinel errors**: @@ -209,10 +209,10 @@ The Unified Encrypted Packet Structure defines a TLV-encoded binary frame authen [0x04][len][IntentID] Header: Semantic routing token [0x05][0x02][ThreatScore] Header: uint16, big-endian [0x06][0x20][HMAC-SHA256] Signature: 32 bytes, covers header TLVs + payload data -[0xFF][...payload...] Data: no length prefix (relies on external framing) +[0xFF][len][...payload...] Data: length-prefixed payload ``` -**HMAC coverage**: The signature is computed over the serialised header TLVs (tags 0x01–0x05) concatenated with the raw payload bytes. The HMAC TLV itself (tag 0x06) and the payload tag byte (0xFF) are excluded from the signed data. +**HMAC coverage**: The signature is computed over the serialised header TLVs (tags 0x01–0x05) concatenated with the raw payload bytes. The HMAC TLV itself (tag 0x06) and the payload TLV header (tag `0xFF` plus the 2-byte length) are excluded from the signed data. ### PacketBuilder @@ -220,9 +220,7 @@ The Unified Encrypted Packet Structure defines a TLV-encoded binary frame authen ### ReadAndVerify -`ReadAndVerify(r *bufio.Reader, sharedSecret)` reads a stream, decodes the TLV fields in order, reconstructs the signed data buffer, and verifies the HMAC with `hmac.Equal`. 
Unknown TLV tags are accumulated into the signed data buffer (forward-compatible extension mechanism) but their semantics are ignored. - -**Known limitation**: Tag 0xFF carries no length prefix. The reader calls `io.ReadAll` on the remaining stream, which requires external TCP framing (e.g. a 4-byte length prefix on the enclosing connection) to delimit the packet boundary. The packet is not self-delimiting. +`ReadAndVerify(r *bufio.Reader, sharedSecret)` reads a stream, decodes the TLV fields in order, reconstructs the signed data buffer, and verifies the HMAC with `hmac.Equal`. Unknown TLV tags are accumulated into the signed data buffer (forward-compatible extension mechanism) but their semantics are ignored. The payload TLV is length-prefixed like every other field, so UEPS frames are self-delimiting. ## logging/ — Structured Logger @@ -238,7 +236,7 @@ A global logger instance is available via `logging.Debug(...)`, `logging.Info(.. |----------|------------| | `Transport.conns` | `sync.RWMutex` | | `Transport.handler` | `sync.RWMutex` | -| `PeerConnection` writes | `sync.Mutex` (`writeMu`) | +| `PeerConnection` writes | `sync.Mutex` (`writeMutex`) | | `PeerConnection` close | `sync.Once` (`closeOnce`) | | `PeerRegistry.peers` + KD-tree | `sync.RWMutex` | | `PeerRegistry.allowedPublicKeys` | separate `sync.RWMutex` | @@ -246,7 +244,7 @@ A global logger instance is available via `logging.Debug(...)`, `logging.Info(.. | `Controller.pending` | `sync.RWMutex` | | `MessageDeduplicator.seen` | `sync.RWMutex` | | `Dispatcher.handlers` | `sync.RWMutex` | -| `Transport.pendingConns` | `atomic.Int32` | +| `Transport.pendingHandshakeCount` | `atomic.Int32` | The codebase is verified race-free under `go test -race`. @@ -255,8 +253,8 @@ The codebase is verified race-free under `go test -race`. 
``` node/ ──► ueps/ node/ ──► logging/ -node/ ──► github.com/Snider/Borg (STMF crypto, SMSG encryption, TIM) -node/ ──► github.com/Snider/Poindexter (KD-tree peer selection) +node/ ──► forge.lthn.ai/Snider/Borg (STMF crypto, SMSG encryption, TIM) +node/ ──► forge.lthn.ai/Snider/Poindexter (KD-tree peer selection) node/ ──► github.com/gorilla/websocket node/ ──► github.com/google/uuid ueps/ ──► (stdlib only) diff --git a/docs/development.md b/docs/development.md index 42045f9..196c370 100644 --- a/docs/development.md +++ b/docs/development.md @@ -2,7 +2,7 @@ ## Prerequisites -- Go 1.25 or later (the module declares `go 1.25.5`) +- Go 1.26 or later (the module declares `go 1.26.0`) - Network access to `forge.lthn.ai` for private dependencies (Borg, Poindexter, Enchantrix) - SSH key configured for `git@forge.lthn.ai:2223` (HTTPS auth is not supported on Forge) @@ -43,7 +43,7 @@ go vet ./... ### Table-Driven Subtests -All tests use table-driven subtests with `t.Run()`. A test that does not follow this pattern should be refactored before merging. +Prefer table-driven subtests with `t.Run()` when multiple related cases share the same structure. Use clear case names and keep setup and verification consistent across the table. ```go func TestFoo(t *testing.T) { @@ -177,12 +177,12 @@ All parameters and return types must carry explicit type annotations. Avoid `int ### Error Handling - Never discard errors silently. -- Wrap errors with context using `fmt.Errorf("context: %w", err)`. +- Wrap library errors with context using `core.E("operation", "context", err)`. - Return typed sentinel errors for conditions callers need to inspect programmatically. ### Licence Header -Every new file must carry the EUPL-1.2 licence identifier. The module's `LICENSE` file governs the package. Do not include the full licence text in each file; a short SPDX identifier comment at the top is sufficient for new files: +Every new file must carry the EUPL-1.2 licence identifier. 
The project is licensed under EUPL-1.2; do not include the full licence text in each file. A short SPDX identifier comment at the top is sufficient for new files: ```go // SPDX-License-Identifier: EUPL-1.2 @@ -233,7 +233,7 @@ Examples: ``` feat(dispatcher): implement UEPS threat circuit breaker -test(transport): add keepalive timeout and MaxConns enforcement tests +test(transport): add keepalive timeout and MaxConnections enforcement tests fix(peer): prevent data race in GracefulClose (P2P-RACE-1) ``` diff --git a/docs/discovery.md b/docs/discovery.md index cc92c0d..452bb94 100644 --- a/docs/discovery.md +++ b/docs/discovery.md @@ -20,9 +20,9 @@ type Peer struct { LastSeen time.Time `json:"lastSeen"` // Poindexter metrics (updated dynamically) - PingMS float64 `json:"pingMs"` // Latency in milliseconds + PingMilliseconds float64 `json:"pingMs"` // Latency in milliseconds Hops int `json:"hops"` // Network hop count - GeoKM float64 `json:"geoKm"` // Geographic distance in kilometres + GeographicKilometres float64 `json:"geoKm"` // Geographic distance in kilometres Score float64 `json:"score"` // Reliability score 0--100 Connected bool `json:"-"` // Not persisted @@ -83,9 +83,9 @@ The registry maintains a 4-dimensional KD-tree for optimal peer selection. Each | Dimension | Source | Weight | Direction | |-----------|--------|--------|-----------| -| Latency | `PingMS` | 1.0 | Lower is better | +| Latency | `PingMilliseconds` | 1.0 | Lower is better | | Hops | `Hops` | 0.7 | Lower is better | -| Geographic distance | `GeoKM` | 0.2 | Lower is better | +| Geographic distance | `GeographicKilometres` | 0.2 | Lower is better | | Reliability | `100 - Score` | 1.2 | Inverted so lower is better | The score dimension is inverted so that the "ideal peer" target point `[0, 0, 0, 0]` represents zero latency, zero hops, zero distance, and maximum reliability (score 100). @@ -146,7 +146,7 @@ This also updates `LastSeen` and triggers a KD-tree rebuild. 
```go // Create registry, err := node.NewPeerRegistry() // XDG paths -registry, err := node.NewPeerRegistryWithPath(path) // Custom path (testing) +registry, err := node.NewPeerRegistryFromPath(path) // Custom path (testing) // CRUD err := registry.AddPeer(peer) @@ -177,7 +177,7 @@ Peers are persisted to `~/.config/lethean-desktop/peers.json` as a JSON array. ### Debounced Writes -To avoid excessive disk I/O, saves are debounced with a 5-second coalesce interval. Multiple mutations within that window produce a single disk write. The write uses an atomic rename pattern (write to `.tmp`, then `os.Rename`) to prevent corruption on crash. +To avoid excessive disk I/O, saves are debounced with a 5-second coalesce interval. Multiple mutations within that window produce a single disk write. The write uses an atomic rename pattern (write to `.tmp`, then rename) to prevent corruption on crash. ```go // Flush pending changes on shutdown diff --git a/docs/history.md b/docs/history.md index 02f5819..986727f 100644 --- a/docs/history.md +++ b/docs/history.md @@ -10,10 +10,10 @@ Implemented the complete test suite for the UEPS binary framing layer. Tests cov - PacketBuilder round-trip: basic, binary payload, elevated threat score, large payload - HMAC verification: payload tampering detected, header tampering detected, wrong shared secret detected -- Boundary conditions: nil payload, empty slice payload, `uint16` max ThreatScore (65,535), TLV value exceeding 255 bytes (`writeTLV` error path) +- Boundary conditions: nil payload, empty slice payload, `uint16` max ThreatScore (65,535), TLV value exceeding 65,535 bytes (`writeTLV` error path) - Stream robustness: truncated packets detected at multiple cut points (EOF mid-tag, mid-length, mid-value), missing HMAC tag, unknown TLV tags skipped and included in signed data -The 11.5% gap from 100% coverage is the reader's `io.ReadAll` error path, which requires a contrived broken `io.Reader` to exercise. 
+The remaining gap from 100% coverage at the time was the payload read-error path, which required a contrived broken reader to exercise. ### Phase 2 — Transport Tests @@ -28,10 +28,10 @@ Tests covered: - Encrypted message round-trip: SMSG encrypt on one side, decrypt on other - Message deduplication: duplicate UUID dropped silently - Rate limiting: burst of more than 100 messages, subsequent drops after token bucket empties -- MaxConns enforcement: 503 HTTP rejection when limit is reached +- MaxConnections enforcement: 503 HTTP rejection when limit is reached - Keepalive timeout: connection cleaned up after `PingInterval + PongTimeout` elapses - Graceful close: `MsgDisconnect` sent before underlying WebSocket close -- Concurrent sends: no data races under `go test -race` (`writeMu` protects all writes) +- Concurrent sends: no data races under `go test -race` (`writeMutex` protects all writes) ### Phase 3 — Controller Tests @@ -86,15 +86,13 @@ Three integration tests (`TestIntegration_*`) exercise the full stack end-to-end ## Known Limitations -### UEPS 0xFF Payload Not Self-Delimiting +### UEPS Payload Framing (Resolved) -The `TagPayload` (0xFF) field carries no length prefix. `ReadAndVerify` calls `io.ReadAll` on the remaining stream, which means the packet format relies on external TCP framing to delimit the packet boundary. The enclosing transport must provide a length-prefixed frame before calling `ReadAndVerify`. This is noted in comments in both `packet.go` and `reader.go` but no solution is implemented. - -Consequence: UEPS packets cannot be chained in a raw stream without an outer framing protocol. The current WebSocket transport encapsulates each UEPS frame in a single WebSocket message, which provides the necessary boundary implicitly. +The `TagPayload` (0xFF) field now uses the same 2-byte length prefix as the other TLVs. 
`ReadAndVerify` reads that explicit length, so UEPS packets are self-delimiting and can be chained in a stream without relying on an outer framing layer. ### No Resource Cleanup on Some Error Paths -`transport.handleWSUpgrade` does not clean up on handshake timeout (the `pendingConns` counter is decremented correctly via `defer`, but the underlying WebSocket connection may linger briefly before the read deadline fires). `transport.Connect` does not clean up the temporary connection object on handshake failure (the raw WebSocket `conn` is closed, but there is no registry or metrics cleanup for the partially constructed `PeerConnection`). +`transport.handleWebSocketUpgrade` does not clean up on handshake timeout (the `pendingHandshakeCount` counter is decremented correctly via `defer`, but the underlying WebSocket connection may linger briefly before the read deadline fires). `transport.Connect` does not clean up the temporary connection object on handshake failure (the raw WebSocket `conn` is closed, but there is no registry or metrics cleanup for the partially constructed `PeerConnection`). These are low-severity gaps. They do not cause goroutine leaks under the current implementation because the connection's read loop is not started until after a successful handshake. @@ -106,9 +104,9 @@ The originally identified risk — that `transport.OnMessage(c.handleResponse)` ### P2P-RACE-1 — GracefulClose Data Race (Phase 3) -`GracefulClose` previously called `pc.Conn.SetWriteDeadline()` outside of `writeMu`, racing with concurrent `Send()` calls that also set the write deadline. Detected by `go test -race`. +`GracefulClose` previously called `pc.WebSocketConnection.SetWriteDeadline()` outside of `writeMutex`, racing with concurrent `Send()` calls that also set the write deadline. Detected by `go test -race`. -Fix: removed the bare `SetWriteDeadline` call from `GracefulClose`. The method now relies entirely on `Send()`, which manages write deadlines under `writeMu`. 
This is documented in a comment in `transport.go` to prevent the pattern from being reintroduced. +Fix: removed the bare `SetWriteDeadline` call from `GracefulClose`. The method now relies entirely on `Send()`, which manages write deadlines under `writeMutex`. This is documented in a comment in `transport.go` to prevent the pattern from being reintroduced. ## Wiki Corrections (19 February 2026) diff --git a/docs/identity.md b/docs/identity.md index c16a559..05c1a58 100644 --- a/docs/identity.md +++ b/docs/identity.md @@ -39,13 +39,13 @@ Paths follow XDG base directories via `github.com/adrg/xdg`. The private key is ### Creating an Identity ```go -nm, err := node.NewNodeManager() +nodeManager, err := node.NewNodeManager() if err != nil { log.Fatal(err) } // Generate a new identity (persists key and config to disk) -err = nm.GenerateIdentity("eu-controller-01", node.RoleController) +err = nodeManager.GenerateIdentity("eu-controller-01", node.RoleController) ``` Internally this calls `stmf.GenerateKeyPair()` from the Borg library to produce the X25519 keypair. @@ -53,7 +53,7 @@ Internally this calls `stmf.GenerateKeyPair()` from the Borg library to produce ### Custom Paths (Testing) ```go -nm, err := node.NewNodeManagerWithPaths( +nodeManager, err := node.NewNodeManagerFromPaths( "/tmp/test/private.key", "/tmp/test/node.json", ) @@ -62,8 +62,8 @@ nm, err := node.NewNodeManagerWithPaths( ### Checking and Retrieving Identity ```go -if nm.HasIdentity() { - identity := nm.GetIdentity() // Returns a copy +if nodeManager.HasIdentity() { + identity := nodeManager.GetIdentity() // Returns a copy fmt.Println(identity.ID, identity.Name) } ``` @@ -73,7 +73,7 @@ if nm.HasIdentity() { ### Deriving Shared Secrets ```go -sharedSecret, err := nm.DeriveSharedSecret(peerPublicKeyBase64) +sharedSecret, err := nodeManager.DeriveSharedSecret(peerPublicKeyBase64) ``` This performs X25519 ECDH with the peer's public key and hashes the result with SHA-256, producing a 32-byte symmetric key. 
The same shared secret is derived independently by both sides (no secret is transmitted). @@ -81,7 +81,7 @@ This performs X25519 ECDH with the peer's public key and hashes the result with ### Deleting an Identity ```go -err := nm.Delete() // Removes key and config from disk, clears in-memory state +err := nodeManager.Delete() // Removes key and config from disk, clears in-memory state ``` ## Challenge-Response Authentication diff --git a/docs/index.md b/docs/index.md index 3451618..4508eb8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -7,7 +7,7 @@ description: P2P mesh networking layer for the Lethean network. P2P networking layer for the Lethean network. Encrypted WebSocket mesh with UEPS wire protocol. -**Module:** `forge.lthn.ai/core/go-p2p` +**Module:** `dappco.re/go/core/p2p` **Go:** 1.26 **Licence:** EUPL-1.2 diff --git a/docs/routing.md b/docs/routing.md index b9536d5..3881f01 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -5,7 +5,7 @@ description: UEPS intent-based packet routing with threat circuit breaker. # Intent Routing -The `Dispatcher` routes verified UEPS packets to registered intent handlers. Before routing, it enforces a threat circuit breaker that silently drops packets with elevated threat scores. +The `Dispatcher` routes verified UEPS packets to registered intent handlers. Before routing, it enforces a threat circuit breaker that blocks packets with elevated threat scores and returns sentinel errors to the caller. **File:** `node/dispatcher.go` @@ -74,8 +74,8 @@ Dropped packets are logged at WARN level with the threat score, threshold, inten ### Design Rationale -- **High-threat packets are dropped silently** (from the sender's perspective) rather than returning an error, consistent with the "don't even parse the payload" philosophy. -- **Unknown intents are dropped**, not forwarded, to avoid back-pressure on the transport layer. They are logged at WARN level for debugging. +- **High-threat packets are not dispatched**. 
The dispatcher logs them and returns `ErrThreatScoreExceeded` to the caller; the sender still receives no protocol-level response.
+- **Unknown intents are not forwarded**. The dispatcher logs them and returns `ErrUnknownIntent`, avoiding back-pressure on the transport layer.
 - **Handler errors propagate** to the caller, allowing upstream code to record failures.
 
 ## Intent Constants
 
@@ -84,7 +84,7 @@ Dropped packets are logged at WARN level with the threat score, threshold, inten
 ```go
 const (
 	IntentHandshake byte = 0x01 // Connection establishment / hello
 	IntentCompute   byte = 0x20 // Compute job request
-	IntentRehab     byte = 0x30 // Benevolent intervention (pause execution)
+	IntentPauseExecution byte = 0x30 // Benevolent intervention (pause execution)
 	IntentCustom    byte = 0xFF // Extended / application-level sub-protocols
 )
 ```
@@ -100,12 +100,13 @@ const (
 ```go
 var (
-	ErrThreatScoreExceeded = fmt.Errorf(
-		"packet rejected: threat score exceeds safety threshold (%d)",
-		ThreatScoreThreshold,
+	ErrThreatScoreExceeded = core.E(
+		"Dispatcher.Dispatch",
+		core.Sprintf("packet rejected: threat score exceeds safety threshold (%d)", ThreatScoreThreshold),
+		nil,
 	)
-	ErrUnknownIntent = errors.New("packet dropped: unknown intent")
-	ErrNilPacket = errors.New("dispatch: nil packet")
+	ErrUnknownIntent = core.E("Dispatcher.Dispatch", "packet dropped: unknown intent", nil)
+	ErrNilPacket = core.E("Dispatcher.Dispatch", "nil packet", nil)
 )
 ```
diff --git a/docs/transport.md b/docs/transport.md
index 19a7987..56b55f9 100644
--- a/docs/transport.md
+++ b/docs/transport.md
@@ -11,11 +11,11 @@ The `Transport` manages encrypted WebSocket connections between nodes.
After an ```go type TransportConfig struct { - ListenAddr string // ":9091" default - WSPath string // "/ws" -- WebSocket endpoint path + ListenAddress string // ":9091" default + WebSocketPath string // "/ws" -- WebSocket endpoint path TLSCertPath string // Optional TLS for wss:// TLSKeyPath string - MaxConns int // Maximum concurrent connections (default 100) + MaxConnections int // Maximum concurrent connections (default 100) MaxMessageSize int64 // Maximum message size in bytes (default 1MB) PingInterval time.Duration // Keepalive interval (default 30s) PongTimeout time.Duration // Pong wait timeout (default 10s) @@ -25,18 +25,18 @@ type TransportConfig struct { Sensible defaults via `DefaultTransportConfig()`: ```go -cfg := node.DefaultTransportConfig() -// ListenAddr: ":9091", WSPath: "/ws", MaxConns: 100 +transportConfig := node.DefaultTransportConfig() +// ListenAddress: ":9091", WebSocketPath: "/ws", MaxConnections: 100 // MaxMessageSize: 1MB, PingInterval: 30s, PongTimeout: 10s ``` ## Creating and Starting ```go -transport := node.NewTransport(nodeManager, peerRegistry, cfg) +transport := node.NewTransport(nodeManager, peerRegistry, transportConfig) // Set message handler before Start() to avoid races -transport.OnMessage(func(conn *node.PeerConnection, msg *node.Message) { +transport.OnMessage(func(peerConnection *node.PeerConnection, msg *node.Message) { // Handle incoming messages }) @@ -86,8 +86,8 @@ Each active connection is wrapped in a `PeerConnection`: ```go type PeerConnection struct { - Peer *Peer // Remote peer identity - Conn *websocket.Conn // Underlying WebSocket + Peer *Peer // Remote peer identity + WebSocketConnection *websocket.Conn // Underlying WebSocket SharedSecret []byte // From X25519 ECDH LastActivity time.Time } @@ -96,15 +96,15 @@ type PeerConnection struct { ### Sending Messages ```go -err := peerConn.Send(msg) +err := peerConnection.Send(msg) ``` -`Send()` serialises the message to JSON, encrypts it with SMSG, sets a 10-second 
write deadline, and writes as a binary WebSocket frame. A `writeMu` mutex serialises concurrent writes. +`Send()` serialises the message to JSON, encrypts it with SMSG, sets a 10-second write deadline, and writes as a binary WebSocket frame. A `writeMutex` serialises concurrent writes. ### Graceful Close ```go -err := peerConn.GracefulClose("shutting down", node.DisconnectShutdown) +err := peerConnection.GracefulClose("shutting down", node.DisconnectShutdown) ``` Sends a `disconnect` message (best-effort) before closing the connection. Uses `sync.Once` to ensure the connection is only closed once. @@ -123,9 +123,9 @@ const ( ## Incoming Connections -The transport exposes an HTTP handler at the configured `WSPath` that upgrades to WebSocket. Origin checks restrict browser clients to `localhost`, `127.0.0.1`, and `::1`; non-browser clients (no `Origin` header) are allowed. +The transport exposes an HTTP handler at the configured `WebSocketPath` that upgrades to WebSocket. Origin checks restrict browser clients to `localhost`, `127.0.0.1`, and `::1`; non-browser clients (no `Origin` header) are allowed. -The `MaxConns` limit is enforced before the WebSocket upgrade, counting both established and pending (mid-handshake) connections. Excess connections receive HTTP 503. +The `MaxConnections` limit is enforced before the WebSocket upgrade, counting both established and pending (mid-handshake) connections. Excess connections receive HTTP 503. 
## Message Deduplication @@ -166,7 +166,7 @@ err = transport.Send(peerID, msg) err = transport.Broadcast(msg) // Query connections -count := transport.ConnectedPeers() +count := transport.ConnectedPeerCount() conn := transport.GetConnection(peerID) // Iterate over all connections diff --git a/docs/ueps.md b/docs/ueps.md index c86b498..24ab498 100644 --- a/docs/ueps.md +++ b/docs/ueps.md @@ -7,7 +7,7 @@ description: TLV-encoded wire protocol with HMAC-SHA256 integrity verification ( The `ueps` package implements the Universal Encrypted Payload System -- a consent-gated TLV (Type-Length-Value) wire protocol with HMAC-SHA256 integrity verification. This is the low-level binary protocol that sits beneath the JSON-over-WebSocket mesh layer. -**Package:** `forge.lthn.ai/core/go-p2p/ueps` +**Package:** `dappco.re/go/core/p2p/ueps` ## TLV Format @@ -25,8 +25,8 @@ Each field is encoded as a 1-byte tag, 2-byte big-endian length (uint16), and va | Tag | Constant | Value Size | Description | |-----|----------|------------|-------------| | `0x01` | `TagVersion` | 1 byte | Protocol version (default `0x09` for IPv9) | -| `0x02` | `TagCurrentLay` | 1 byte | Current network layer | -| `0x03` | `TagTargetLay` | 1 byte | Target network layer | +| `0x02` | `TagCurrentLayer` | 1 byte | Current network layer | +| `0x03` | `TagTargetLayer` | 1 byte | Target network layer | | `0x04` | `TagIntent` | 1 byte | Semantic intent token (routes the packet) | | `0x05` | `TagThreatScore` | 2 bytes | Threat score (0--65535, big-endian uint16) | | `0x06` | `TagHMAC` | 32 bytes | HMAC-SHA256 signature | @@ -156,7 +156,7 @@ Reserved intent values: |----|----------|---------| | `0x01` | `IntentHandshake` | Connection establishment / hello | | `0x20` | `IntentCompute` | Compute job request | -| `0x30` | `IntentRehab` | Benevolent intervention (pause execution) | +| `0x30` | `IntentPauseExecution` | Benevolent intervention (pause execution) | | `0xFF` | `IntentCustom` | Extended / application-level 
sub-protocols | ## Threat Score diff --git a/go.mod b/go.mod index 003f271..e395272 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,7 @@ module dappco.re/go/core/p2p go 1.26.0 require ( - dappco.re/go/core/io v0.2.0 - dappco.re/go/core/log v0.1.0 + dappco.re/go/core v0.8.0-alpha.1 forge.lthn.ai/Snider/Borg v0.3.1 forge.lthn.ai/Snider/Poindexter v0.0.3 github.com/adrg/xdg v0.5.3 @@ -15,11 +14,11 @@ require ( require ( forge.lthn.ai/Snider/Enchantrix v0.0.4 // indirect - forge.lthn.ai/core/go-log v0.0.4 // indirect github.com/ProtonMail/go-crypto v1.4.0 // indirect github.com/cloudflare/circl v1.6.3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/klauspost/compress v1.18.4 // indirect + github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect golang.org/x/crypto v0.49.0 // indirect golang.org/x/sys v0.42.0 // indirect diff --git a/go.sum b/go.sum index d7916fc..6783e51 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,18 @@ -dappco.re/go/core/io v0.2.0 h1:zuudgIiTsQQ5ipVt97saWdGLROovbEB/zdVyy9/l+I4= -dappco.re/go/core/io v0.2.0/go.mod h1:1QnQV6X9LNgFKfm8SkOtR9LLaj3bDcsOIeJOOyjbL5E= -dappco.re/go/core/log v0.1.0 h1:pa71Vq2TD2aoEUQWFKwNcaJ3GBY8HbaNGqtE688Unyc= -dappco.re/go/core/log v0.1.0/go.mod h1:Nkqb8gsXhZAO8VLpx7B8i1iAmohhzqA20b9Zr8VUcJs= +dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk= +dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A= forge.lthn.ai/Snider/Borg v0.3.1 h1:gfC1ZTpLoZai07oOWJiVeQ8+qJYK8A795tgVGJHbVL8= forge.lthn.ai/Snider/Borg v0.3.1/go.mod h1:Z7DJD0yHXsxSyM7Mjl6/g4gH1NBsIz44Bf5AFlV76Wg= forge.lthn.ai/Snider/Enchantrix v0.0.4 h1:biwpix/bdedfyc0iVeK15awhhJKH6TEMYOTXzHXx5TI= forge.lthn.ai/Snider/Enchantrix v0.0.4/go.mod h1:OGCwuVeZPq3OPe2h6TX/ZbgEjHU6B7owpIBeXQGbSe0= forge.lthn.ai/Snider/Poindexter v0.0.3 h1:cx5wRhuLRKBM8riIZyNVAT2a8rwRhn1dodFBktocsVE= 
forge.lthn.ai/Snider/Poindexter v0.0.3/go.mod h1:ddzGia98k3HKkR0gl58IDzqz+MmgW2cQJOCNLfuWPpo= -forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0= -forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw= github.com/ProtonMail/go-crypto v1.4.0 h1:Zq/pbM3F5DFgJiMouxEdSVY44MVoQNEKp5d5QxIQceQ= github.com/ProtonMail/go-crypto v1.4.0/go.mod h1:e1OaTyu5SYVrO9gKOEhTc+5UcXtTUa+P3uLudwcgPqo= github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78= github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/logging/logger.go b/logging/logger.go index 4ab0666..0f463b7 100644 --- a/logging/logger.go +++ b/logging/logger.go @@ -1,33 +1,27 @@ -// Package logging provides structured logging with log levels and fields. +// logger := New(DefaultConfig()) package logging import ( - "fmt" "io" "maps" - "os" - "strings" "sync" + "syscall" "time" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" ) -// Level represents the severity of a log message. +// level := LevelInfo type Level int const ( - // LevelDebug is the most verbose log level. LevelDebug Level = iota - // LevelInfo is for general informational messages. LevelInfo - // LevelWarn is for warning messages. LevelWarn - // LevelError is for error messages. LevelError ) -// String returns the string representation of the log level. 
+// label := LevelWarn.String() func (l Level) String() string { switch l { case LevelDebug: @@ -43,44 +37,44 @@ func (l Level) String() string { } } -// Logger provides structured logging with configurable output and level. +// logger := New(DefaultConfig()) type Logger struct { - mu sync.Mutex + mu sync.RWMutex output io.Writer level Level component string } -// Config holds configuration for creating a new Logger. +// config := Config{Output: io.Discard, Level: LevelDebug, Component: "sync"} type Config struct { Output io.Writer Level Level Component string } -// DefaultConfig returns the default logger configuration. +// config := DefaultConfig() func DefaultConfig() Config { return Config{ - Output: os.Stderr, + Output: defaultOutput, Level: LevelInfo, Component: "", } } -// New creates a new Logger with the given configuration. -func New(cfg Config) *Logger { - if cfg.Output == nil { - cfg.Output = os.Stderr +// logger := New(DefaultConfig()) +func New(config Config) *Logger { + if config.Output == nil { + config.Output = defaultOutput } return &Logger{ - output: cfg.Output, - level: cfg.Level, - component: cfg.Component, + output: config.Output, + level: config.Level, + component: config.Component, } } -// WithComponent returns a new Logger with the specified component name. -func (l *Logger) WithComponent(component string) *Logger { +// transportLogger := logger.ComponentLogger("transport") +func (l *Logger) ComponentLogger(component string) *Logger { return &Logger{ output: l.output, level: l.level, @@ -88,25 +82,36 @@ func (l *Logger) WithComponent(component string) *Logger { } } -// SetLevel sets the minimum log level. +// logger.SetLevel(LevelDebug) func (l *Logger) SetLevel(level Level) { l.mu.Lock() defer l.mu.Unlock() l.level = level } -// GetLevel returns the current log level. 
+// level := logger.GetLevel() func (l *Logger) GetLevel() Level { - l.mu.Lock() - defer l.mu.Unlock() + l.mu.RLock() + defer l.mu.RUnlock() return l.level } -// Fields represents key-value pairs for structured logging. +// fields := Fields{"peer_id": "node-1", "attempt": 2} type Fields map[string]any -// log writes a log message at the specified level. -func (l *Logger) log(level Level, msg string, fields Fields) { +type stderrWriter struct{} + +func (stderrWriter) Write(p []byte) (int, error) { + written, err := syscall.Write(syscall.Stderr, p) + if err != nil { + return written, core.E("logging.stderrWriter.Write", "failed to write log line", err) + } + return written, nil +} + +var defaultOutput io.Writer = stderrWriter{} + +func (l *Logger) log(level Level, message string, fields Fields) { l.mu.Lock() defer l.mu.Unlock() @@ -114,8 +119,7 @@ func (l *Logger) log(level Level, msg string, fields Fields) { return } - // Build the log line - var sb strings.Builder + sb := core.NewBuilder() timestamp := time.Now().Format("2006/01/02 15:04:05") sb.WriteString(timestamp) sb.WriteString(" [") @@ -129,64 +133,63 @@ func (l *Logger) log(level Level, msg string, fields Fields) { } sb.WriteString(" ") - sb.WriteString(msg) + sb.WriteString(message) - // Add fields if present if len(fields) > 0 { sb.WriteString(" |") for k, v := range fields { sb.WriteString(" ") sb.WriteString(k) sb.WriteString("=") - sb.WriteString(fmt.Sprintf("%v", v)) + sb.WriteString(core.Sprint(v)) } } sb.WriteString("\n") - fmt.Fprint(l.output, sb.String()) + _, _ = l.output.Write([]byte(sb.String())) } -// Debug logs a debug message. -func (l *Logger) Debug(msg string, fields ...Fields) { - l.log(LevelDebug, msg, mergeFields(fields)) +// Debug("connected", Fields{"peer_id": "node-1"}) +func (l *Logger) Debug(message string, fields ...Fields) { + l.log(LevelDebug, message, mergeFields(fields)) } -// Info logs an informational message. 
-func (l *Logger) Info(msg string, fields ...Fields) { - l.log(LevelInfo, msg, mergeFields(fields)) +// Info("worker started", Fields{"component": "transport"}) +func (l *Logger) Info(message string, fields ...Fields) { + l.log(LevelInfo, message, mergeFields(fields)) } -// Warn logs a warning message. -func (l *Logger) Warn(msg string, fields ...Fields) { - l.log(LevelWarn, msg, mergeFields(fields)) +// Warn("peer rate limited", Fields{"peer_id": "node-1"}) +func (l *Logger) Warn(message string, fields ...Fields) { + l.log(LevelWarn, message, mergeFields(fields)) } -// Error logs an error message. -func (l *Logger) Error(msg string, fields ...Fields) { - l.log(LevelError, msg, mergeFields(fields)) +// Error("send failed", Fields{"peer_id": "node-1"}) +func (l *Logger) Error(message string, fields ...Fields) { + l.log(LevelError, message, mergeFields(fields)) } -// Debugf logs a formatted debug message. +// Debugf("connected peer %s", "node-1") func (l *Logger) Debugf(format string, args ...any) { - l.log(LevelDebug, fmt.Sprintf(format, args...), nil) + l.log(LevelDebug, core.Sprintf(format, args...), nil) } -// Infof logs a formatted informational message. +// Infof("worker %s ready", "node-1") func (l *Logger) Infof(format string, args ...any) { - l.log(LevelInfo, fmt.Sprintf(format, args...), nil) + l.log(LevelInfo, core.Sprintf(format, args...), nil) } -// Warnf logs a formatted warning message. +// Warnf("peer %s is slow", "node-1") func (l *Logger) Warnf(format string, args ...any) { - l.log(LevelWarn, fmt.Sprintf(format, args...), nil) + l.log(LevelWarn, core.Sprintf(format, args...), nil) } -// Errorf logs a formatted error message. +// Errorf("peer %s failed", "node-1") func (l *Logger) Errorf(format string, args ...any) { - l.log(LevelError, fmt.Sprintf(format, args...), nil) + l.log(LevelError, core.Sprintf(format, args...), nil) } -// mergeFields combines multiple Fields maps into one. 
+// fields := mergeFields([]Fields{{"peer_id": "node-1"}, {"attempt": 2}}) func mergeFields(fields []Fields) Fields { if len(fields) == 0 { return nil @@ -198,79 +201,75 @@ func mergeFields(fields []Fields) Fields { return result } -// --- Global logger for convenience --- - var ( globalLogger = New(DefaultConfig()) globalMu sync.RWMutex ) -// SetGlobal sets the global logger instance. +// SetGlobal(New(DefaultConfig())) func SetGlobal(l *Logger) { globalMu.Lock() defer globalMu.Unlock() globalLogger = l } -// GetGlobal returns the global logger instance. +// logger := GetGlobal() func GetGlobal() *Logger { globalMu.RLock() defer globalMu.RUnlock() return globalLogger } -// SetGlobalLevel sets the log level of the global logger. +// SetGlobalLevel(LevelDebug) func SetGlobalLevel(level Level) { globalMu.RLock() defer globalMu.RUnlock() globalLogger.SetLevel(level) } -// Global convenience functions that use the global logger - -// Debug logs a debug message using the global logger. -func Debug(msg string, fields ...Fields) { - GetGlobal().Debug(msg, fields...) +// Debug("connected", Fields{"peer_id": "node-1"}) +func Debug(message string, fields ...Fields) { + GetGlobal().Debug(message, fields...) } -// Info logs an informational message using the global logger. -func Info(msg string, fields ...Fields) { - GetGlobal().Info(msg, fields...) +// Info("worker started", Fields{"component": "transport"}) +func Info(message string, fields ...Fields) { + GetGlobal().Info(message, fields...) } -// Warn logs a warning message using the global logger. -func Warn(msg string, fields ...Fields) { - GetGlobal().Warn(msg, fields...) +// Warn("peer rate limited", Fields{"peer_id": "node-1"}) +func Warn(message string, fields ...Fields) { + GetGlobal().Warn(message, fields...) } -// Error logs an error message using the global logger. -func Error(msg string, fields ...Fields) { - GetGlobal().Error(msg, fields...) 
+// Error("send failed", Fields{"peer_id": "node-1"}) +func Error(message string, fields ...Fields) { + GetGlobal().Error(message, fields...) } -// Debugf logs a formatted debug message using the global logger. +// Debugf("connected peer %s", "node-1") func Debugf(format string, args ...any) { GetGlobal().Debugf(format, args...) } -// Infof logs a formatted informational message using the global logger. +// Infof("worker %s ready", "node-1") func Infof(format string, args ...any) { GetGlobal().Infof(format, args...) } -// Warnf logs a formatted warning message using the global logger. +// Warnf("peer %s is slow", "node-1") func Warnf(format string, args ...any) { GetGlobal().Warnf(format, args...) } -// Errorf logs a formatted error message using the global logger. +// Errorf("peer %s failed", "node-1") func Errorf(format string, args ...any) { GetGlobal().Errorf(format, args...) } -// ParseLevel parses a string into a log level. +// level, err := ParseLevel("warn") func ParseLevel(s string) (Level, error) { - switch strings.ToUpper(s) { + switch core.Upper(s) { case "DEBUG": return LevelDebug, nil case "INFO": @@ -280,6 +279,6 @@ func ParseLevel(s string) (Level, error) { case "ERROR": return LevelError, nil default: - return LevelInfo, coreerr.E("logging.ParseLevel", "unknown log level: "+s, nil) + return LevelInfo, core.E("logging.ParseLevel", "unknown log level: "+s, nil) } } diff --git a/logging/logger_test.go b/logging/logger_test.go index 5fa5163..c553130 100644 --- a/logging/logger_test.go +++ b/logging/logger_test.go @@ -2,11 +2,13 @@ package logging import ( "bytes" - "strings" + "sync" "testing" + + core "dappco.re/go/core" ) -func TestLoggerLevels(t *testing.T) { +func TestLogger_Levels_Good(t *testing.T) { var buf bytes.Buffer logger := New(Config{ Output: &buf, @@ -21,29 +23,29 @@ func TestLoggerLevels(t *testing.T) { // Info should appear logger.Info("info message") - if !strings.Contains(buf.String(), "[INFO]") { + if !core.Contains(buf.String(), 
"[INFO]") { t.Error("Info message should appear") } - if !strings.Contains(buf.String(), "info message") { + if !core.Contains(buf.String(), "info message") { t.Error("Info message content should appear") } buf.Reset() // Warn should appear logger.Warn("warn message") - if !strings.Contains(buf.String(), "[WARN]") { + if !core.Contains(buf.String(), "[WARN]") { t.Error("Warn message should appear") } buf.Reset() // Error should appear logger.Error("error message") - if !strings.Contains(buf.String(), "[ERROR]") { + if !core.Contains(buf.String(), "[ERROR]") { t.Error("Error message should appear") } } -func TestLoggerDebugLevel(t *testing.T) { +func TestLogger_DebugLevel_Good(t *testing.T) { var buf bytes.Buffer logger := New(Config{ Output: &buf, @@ -51,12 +53,12 @@ func TestLoggerDebugLevel(t *testing.T) { }) logger.Debug("debug message") - if !strings.Contains(buf.String(), "[DEBUG]") { + if !core.Contains(buf.String(), "[DEBUG]") { t.Error("Debug message should appear at Debug level") } } -func TestLoggerWithFields(t *testing.T) { +func TestLogger_WithFields_Good(t *testing.T) { var buf bytes.Buffer logger := New(Config{ Output: &buf, @@ -66,15 +68,15 @@ func TestLoggerWithFields(t *testing.T) { logger.Info("test message", Fields{"key": "value", "num": 42}) output := buf.String() - if !strings.Contains(output, "key=value") { + if !core.Contains(output, "key=value") { t.Error("Field key=value should appear") } - if !strings.Contains(output, "num=42") { + if !core.Contains(output, "num=42") { t.Error("Field num=42 should appear") } } -func TestLoggerWithComponent(t *testing.T) { +func TestLogger_ConfigComponent_Good(t *testing.T) { var buf bytes.Buffer logger := New(Config{ Output: &buf, @@ -85,28 +87,33 @@ func TestLoggerWithComponent(t *testing.T) { logger.Info("test message") output := buf.String() - if !strings.Contains(output, "[TestComponent]") { + if !core.Contains(output, "[TestComponent]") { t.Error("Component name should appear in log") } } -func 
TestLoggerDerivedComponent(t *testing.T) { +func TestLogger_ComponentLogger_Good(t *testing.T) { var buf bytes.Buffer parent := New(Config{ Output: &buf, Level: LevelInfo, }) - child := parent.WithComponent("ChildComponent") + child := parent.ComponentLogger("ChildComponent") child.Info("child message") + secondaryLogger := parent.ComponentLogger("SecondaryComponent") + secondaryLogger.Info("secondary message") output := buf.String() - if !strings.Contains(output, "[ChildComponent]") { + if !core.Contains(output, "[ChildComponent]") { t.Error("Derived component name should appear") } + if !core.Contains(output, "[SecondaryComponent]") { + t.Error("Secondary component should preserve the component name") + } } -func TestLoggerFormatted(t *testing.T) { +func TestLogger_Formatted_Good(t *testing.T) { var buf bytes.Buffer logger := New(Config{ Output: &buf, @@ -116,12 +123,12 @@ func TestLoggerFormatted(t *testing.T) { logger.Infof("formatted %s %d", "string", 123) output := buf.String() - if !strings.Contains(output, "formatted string 123") { + if !core.Contains(output, "formatted string 123") { t.Errorf("Formatted message should appear, got: %s", output) } } -func TestSetLevel(t *testing.T) { +func TestLogger_SetLevel_Good(t *testing.T) { var buf bytes.Buffer logger := New(Config{ Output: &buf, @@ -137,17 +144,17 @@ func TestSetLevel(t *testing.T) { // Change to Info level logger.SetLevel(LevelInfo) logger.Info("should appear now") - if !strings.Contains(buf.String(), "should appear now") { + if !core.Contains(buf.String(), "should appear now") { t.Error("Info should appear after level change") } - // Verify GetLevel + // Verify Level if logger.GetLevel() != LevelInfo { t.Error("GetLevel should return LevelInfo") } } -func TestParseLevel(t *testing.T) { +func TestLogger_ParseLevel_Good(t *testing.T) { tests := []struct { input string expected Level @@ -180,7 +187,7 @@ func TestParseLevel(t *testing.T) { } } -func TestGlobalLogger(t *testing.T) { +func 
TestLogger_GlobalLogger_Good(t *testing.T) { var buf bytes.Buffer logger := New(Config{ Output: &buf, @@ -190,7 +197,7 @@ func TestGlobalLogger(t *testing.T) { SetGlobal(logger) Info("global test") - if !strings.Contains(buf.String(), "global test") { + if !core.Contains(buf.String(), "global test") { t.Error("Global logger should write message") } @@ -205,7 +212,7 @@ func TestGlobalLogger(t *testing.T) { SetGlobal(New(DefaultConfig())) } -func TestLevelString(t *testing.T) { +func TestLogger_LevelString_Good(t *testing.T) { tests := []struct { level Level expected string @@ -224,7 +231,7 @@ func TestLevelString(t *testing.T) { } } -func TestMergeFields(t *testing.T) { +func TestLogger_MergeFields_Good(t *testing.T) { // Empty fields result := mergeFields(nil) if result != nil { @@ -260,3 +267,35 @@ func TestMergeFields(t *testing.T) { t.Error("Later fields should override earlier ones") } } + +func TestLogger_ParseLevel_Bad(t *testing.T) { + _, err := ParseLevel("bogus") + if err == nil { + t.Error("ParseLevel should return an error for an unrecognised level string") + } +} + +func TestLogger_ConcurrentWrite_Ugly(t *testing.T) { + var buf bytes.Buffer + logger := New(Config{ + Output: &buf, + Level: LevelDebug, + }) + + const goroutines = 50 + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := range goroutines { + go func(n int) { + defer wg.Done() + logger.Infof("concurrent message %d", n) + }(i) + } + + wg.Wait() + // Only assert no panics / races occurred; output ordering is non-deterministic. 
+ if buf.Len() == 0 { + t.Error("expected concurrent log writes to produce output") + } +} diff --git a/node/ax_test_helpers_test.go b/node/ax_test_helpers_test.go new file mode 100644 index 0000000..067165a --- /dev/null +++ b/node/ax_test_helpers_test.go @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package node + +import ( + "io/fs" + "testing" + + core "dappco.re/go/core" + "github.com/stretchr/testify/require" +) + +func testJoinPath(parts ...string) string { + return core.JoinPath(parts...) +} + +func testNodeManagerPaths(dir string) (string, string) { + return testJoinPath(dir, "private.key"), testJoinPath(dir, "node.json") +} + +func testWriteFile(t *testing.T, path string, content []byte, mode fs.FileMode) { + t.Helper() + require.NoError(t, filesystemResultError(localFileSystem.WriteMode(path, string(content), mode))) +} + +func testReadFile(t *testing.T, path string) []byte { + t.Helper() + content, err := filesystemRead(path) + require.NoError(t, err) + return []byte(content) +} + +func testJSONMarshal(t *testing.T, v any) []byte { + t.Helper() + result := core.JSONMarshal(v) + require.True(t, result.OK, "marshal should succeed: %v", result.Value) + return result.Value.([]byte) +} + +func testJSONUnmarshal(t *testing.T, data []byte, target any) { + t.Helper() + result := core.JSONUnmarshal(data, target) + require.True(t, result.OK, "unmarshal should succeed: %v", result.Value) +} diff --git a/node/bench_test.go b/node/bench_test.go index 7123797..f96cefb 100644 --- a/node/bench_test.go +++ b/node/bench_test.go @@ -2,11 +2,10 @@ package node import ( "encoding/base64" - "encoding/json" - "path/filepath" "testing" "time" + core "dappco.re/go/core" "forge.lthn.ai/Snider/Borg/pkg/smsg" ) @@ -16,10 +15,7 @@ func BenchmarkIdentityGenerate(b *testing.B) { b.ReportAllocs() for b.Loop() { dir := b.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := 
NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { b.Fatalf("create node manager: %v", err) } @@ -34,10 +30,10 @@ func BenchmarkDeriveSharedSecret(b *testing.B) { dir1 := b.TempDir() dir2 := b.TempDir() - nm1, _ := NewNodeManagerWithPaths(filepath.Join(dir1, "k"), filepath.Join(dir1, "n")) + nm1, _ := NewNodeManagerFromPaths(testJoinPath(dir1, "k"), testJoinPath(dir1, "n")) nm1.GenerateIdentity("node1", RoleDual) - nm2, _ := NewNodeManagerWithPaths(filepath.Join(dir2, "k"), filepath.Join(dir2, "n")) + nm2, _ := NewNodeManagerFromPaths(testJoinPath(dir2, "k"), testJoinPath(dir2, "n")) nm2.GenerateIdentity("node2", RoleDual) peerPubKey := nm2.GetIdentity().PublicKey @@ -77,7 +73,7 @@ func BenchmarkMessageSerialise(b *testing.B) { b.ResetTimer() for b.Loop() { - msg, err := NewMessage(MsgStats, "sender-id", "receiver-id", payload) + msg, err := NewMessage(MessageStats, "sender-id", "receiver-id", payload) if err != nil { b.Fatalf("create message: %v", err) } @@ -88,8 +84,8 @@ func BenchmarkMessageSerialise(b *testing.B) { } var restored Message - if err := json.Unmarshal(data, &restored); err != nil { - b.Fatalf("unmarshal message: %v", err) + if result := core.JSONUnmarshal(data, &restored); !result.OK { + b.Fatalf("unmarshal message: %v", result.Value) } } } @@ -102,7 +98,7 @@ func BenchmarkMessageCreateOnly(b *testing.B) { b.ResetTimer() for b.Loop() { - _, err := NewMessage(MsgPing, "sender", "receiver", payload) + _, err := NewMessage(MessagePing, "sender", "receiver", payload) if err != nil { b.Fatalf("create message: %v", err) } @@ -136,9 +132,8 @@ func BenchmarkMarshalJSON(b *testing.B) { b.Run("Stdlib", func(b *testing.B) { b.ReportAllocs() for b.Loop() { - _, err := json.Marshal(data) - if err != nil { - b.Fatal(err) + if result := core.JSONMarshal(data); !result.OK { + b.Fatal(result.Value) } } }) @@ -150,10 +145,10 @@ func BenchmarkSMSGEncryptDecrypt(b *testing.B) { dir1 := b.TempDir() dir2 := b.TempDir() - nm1, _ := 
NewNodeManagerWithPaths(filepath.Join(dir1, "k"), filepath.Join(dir1, "n")) + nm1, _ := NewNodeManagerFromPaths(testJoinPath(dir1, "k"), testJoinPath(dir1, "n")) nm1.GenerateIdentity("node1", RoleDual) - nm2, _ := NewNodeManagerWithPaths(filepath.Join(dir2, "k"), filepath.Join(dir2, "n")) + nm2, _ := NewNodeManagerFromPaths(testJoinPath(dir2, "k"), testJoinPath(dir2, "n")) nm2.GenerateIdentity("node2", RoleDual) sharedSecret, _ := nm1.DeriveSharedSecret(nm2.GetIdentity().PublicKey) @@ -202,7 +197,7 @@ func BenchmarkChallengeSignVerify(b *testing.B) { // BenchmarkPeerScoring measures KD-tree rebuild and peer selection. func BenchmarkPeerScoring(b *testing.B) { dir := b.TempDir() - reg, err := NewPeerRegistryWithPath(filepath.Join(dir, "peers.json")) + reg, err := NewPeerRegistryFromPath(testJoinPath(dir, "peers.json")) if err != nil { b.Fatalf("create registry: %v", err) } @@ -211,13 +206,13 @@ func BenchmarkPeerScoring(b *testing.B) { // Add 50 peers with varied metrics for i := range 50 { peer := &Peer{ - ID: filepath.Join("peer", string(rune('A'+i%26)), string(rune('0'+i/26))), - Name: "peer", - PingMS: float64(i*10 + 5), - Hops: i%5 + 1, - GeoKM: float64(i * 100), - Score: float64(50 + i%50), - AddedAt: time.Now(), + ID: testJoinPath("peer", string(rune('A'+i%26)), string(rune('0'+i/26))), + Name: "peer", + PingMilliseconds: float64(i*10 + 5), + Hops: i%5 + 1, + GeographicKilometres: float64(i * 100), + Score: float64(50 + i%50), + AddedAt: time.Now(), } // Bypass AddPeer's duplicate check by adding directly reg.mu.Lock() diff --git a/node/buffer_pool.go b/node/buffer_pool.go new file mode 100644 index 0000000..b883e09 --- /dev/null +++ b/node/buffer_pool.go @@ -0,0 +1,51 @@ +package node + +import ( + "bytes" + "sync" + + core "dappco.re/go/core" +) + +// bufferPool provides reusable byte buffers for JSON encoding in hot paths. +// This reduces allocation overhead in message serialization. 
+var bufferPool = sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, 1024)) + }, +} + +func getBuffer() *bytes.Buffer { + buffer := bufferPool.Get().(*bytes.Buffer) + buffer.Reset() + return buffer +} + +func putBuffer(buffer *bytes.Buffer) { + // Don't pool buffers that grew too large (>64KB) + if buffer.Cap() <= 65536 { + bufferPool.Put(buffer) + } +} + +// MarshalJSON encodes a value to JSON using Core's JSON primitive and then +// restores the historical no-EscapeHTML behaviour expected by the node package. +// Returns a copy of the encoded bytes (safe to use after the function returns). +// +// data, err := MarshalJSON(value) +func MarshalJSON(value any) ([]byte, error) { + encoded := core.JSONMarshal(value) + if !encoded.OK { + return nil, encoded.Value.(error) + } + data := encoded.Value.([]byte) + + data = bytes.ReplaceAll(data, []byte(`\u003c`), []byte("<")) + data = bytes.ReplaceAll(data, []byte(`\u003e`), []byte(">")) + data = bytes.ReplaceAll(data, []byte(`\u0026`), []byte("&")) + + // Return a copy since callers may retain the slice after subsequent calls. 
+ out := make([]byte, len(data)) + copy(out, data) + return out, nil +} diff --git a/node/bufpool_test.go b/node/buffer_pool_test.go similarity index 82% rename from node/bufpool_test.go rename to node/buffer_pool_test.go index cd0c786..35e1cc5 100644 --- a/node/bufpool_test.go +++ b/node/buffer_pool_test.go @@ -2,17 +2,17 @@ package node import ( "bytes" - "encoding/json" "sync" "testing" + core "dappco.re/go/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// --- bufpool.go tests --- +// --- buffer_pool.go tests --- -func TestGetBuffer_ReturnsResetBuffer(t *testing.T) { +func TestBufferPool_Buffer_ReturnsResetBuffer_Good(t *testing.T) { t.Run("buffer is initially empty", func(t *testing.T) { buf := getBuffer() defer putBuffer(buf) @@ -33,7 +33,7 @@ func TestGetBuffer_ReturnsResetBuffer(t *testing.T) { }) } -func TestPutBuffer_DiscardsOversizedBuffers(t *testing.T) { +func TestBufferPool_PutBuffer_DiscardsOversizedBuffers_Good(t *testing.T) { t.Run("buffer at 64KB limit is pooled", func(t *testing.T) { buf := getBuffer() buf.Grow(65536) @@ -59,7 +59,7 @@ func TestPutBuffer_DiscardsOversizedBuffers(t *testing.T) { }) } -func TestBufPool_BufferIndependence(t *testing.T) { +func TestBufferPool_BufferIndependence_Good(t *testing.T) { buf1 := getBuffer() buf2 := getBuffer() @@ -77,7 +77,7 @@ func TestBufPool_BufferIndependence(t *testing.T) { putBuffer(buf2) } -func TestMarshalJSON_BasicTypes(t *testing.T) { +func TestBufferPool_MarshalJSON_BasicTypes_Good(t *testing.T) { tests := []struct { name string input any @@ -121,8 +121,7 @@ func TestMarshalJSON_BasicTypes(t *testing.T) { got, err := MarshalJSON(tt.input) require.NoError(t, err) - expected, err := json.Marshal(tt.input) - require.NoError(t, err) + expected := testJSONMarshal(t, tt.input) assert.JSONEq(t, string(expected), string(got), "MarshalJSON output should match json.Marshal") @@ -130,7 +129,7 @@ func TestMarshalJSON_BasicTypes(t *testing.T) { } } -func 
TestMarshalJSON_NoTrailingNewline(t *testing.T) { +func TestBufferPool_MarshalJSON_NoTrailingNewline_Good(t *testing.T) { data, err := MarshalJSON(map[string]string{"key": "value"}) require.NoError(t, err) @@ -138,7 +137,7 @@ func TestMarshalJSON_NoTrailingNewline(t *testing.T) { "MarshalJSON should strip the trailing newline added by json.Encoder") } -func TestMarshalJSON_HTMLEscaping(t *testing.T) { +func TestBufferPool_MarshalJSON_HTMLEscaping_Good(t *testing.T) { input := map[string]string{"html": ""} data, err := MarshalJSON(input) require.NoError(t, err) @@ -147,7 +146,7 @@ func TestMarshalJSON_HTMLEscaping(t *testing.T) { "HTML characters should not be escaped when EscapeHTML is false") } -func TestMarshalJSON_ReturnsCopy(t *testing.T) { +func TestBufferPool_MarshalJSON_ReturnsCopy_Good(t *testing.T) { data1, err := MarshalJSON("first") require.NoError(t, err) @@ -162,7 +161,7 @@ func TestMarshalJSON_ReturnsCopy(t *testing.T) { "returned slice should be a copy and not be mutated by subsequent calls") } -func TestMarshalJSON_ReturnsIndependentCopy(t *testing.T) { +func TestBufferPool_MarshalJSON_ReturnsIndependentCopy_Good(t *testing.T) { data1, err := MarshalJSON(map[string]string{"first": "call"}) require.NoError(t, err) @@ -175,13 +174,13 @@ func TestMarshalJSON_ReturnsIndependentCopy(t *testing.T) { "second result should contain its own data") } -func TestMarshalJSON_InvalidValue(t *testing.T) { +func TestBufferPool_MarshalJSON_InvalidValue_Bad(t *testing.T) { ch := make(chan int) _, err := MarshalJSON(ch) assert.Error(t, err, "marshalling an unserialisable type should return an error") } -func TestBufferPool_ConcurrentAccess(t *testing.T) { +func TestBufferPool_ConcurrentAccess_Ugly(t *testing.T) { const goroutines = 100 const iterations = 50 @@ -206,7 +205,7 @@ func TestBufferPool_ConcurrentAccess(t *testing.T) { wg.Wait() } -func TestMarshalJSON_ConcurrentSafety(t *testing.T) { +func TestBufferPool_MarshalJSON_ConcurrentSafety_Ugly(t *testing.T) { 
const goroutines = 50 var wg sync.WaitGroup @@ -223,8 +222,8 @@ func TestMarshalJSON_ConcurrentSafety(t *testing.T) { if err == nil { var parsed PingPayload - err = json.Unmarshal(data, &parsed) - if err != nil { + if result := core.JSONUnmarshal(data, &parsed); !result.OK { + err = result.Value.(error) errs[idx] = err return } @@ -242,7 +241,7 @@ func TestMarshalJSON_ConcurrentSafety(t *testing.T) { } } -func TestBufferPool_ReuseAfterReset(t *testing.T) { +func TestBufferPool_ReuseAfterReset_Ugly(t *testing.T) { buf := getBuffer() buf.Write(make([]byte, 4096)) putBuffer(buf) diff --git a/node/bufpool.go b/node/bufpool.go deleted file mode 100644 index 7848214..0000000 --- a/node/bufpool.go +++ /dev/null @@ -1,55 +0,0 @@ -package node - -import ( - "bytes" - "encoding/json" - "sync" -) - -// bufferPool provides reusable byte buffers for JSON encoding. -// This reduces allocation overhead in hot paths like message serialization. -var bufferPool = sync.Pool{ - New: func() any { - return bytes.NewBuffer(make([]byte, 0, 1024)) - }, -} - -// getBuffer retrieves a buffer from the pool. -func getBuffer() *bytes.Buffer { - buf := bufferPool.Get().(*bytes.Buffer) - buf.Reset() - return buf -} - -// putBuffer returns a buffer to the pool. -func putBuffer(buf *bytes.Buffer) { - // Don't pool buffers that grew too large (>64KB) - if buf.Cap() <= 65536 { - bufferPool.Put(buf) - } -} - -// MarshalJSON encodes a value to JSON using a pooled buffer. -// Returns a copy of the encoded bytes (safe to use after the function returns). 
-func MarshalJSON(v any) ([]byte, error) { - buf := getBuffer() - defer putBuffer(buf) - - enc := json.NewEncoder(buf) - // Don't escape HTML characters (matches json.Marshal behavior for these use cases) - enc.SetEscapeHTML(false) - if err := enc.Encode(v); err != nil { - return nil, err - } - - // json.Encoder.Encode adds a newline; remove it to match json.Marshal - data := buf.Bytes() - if len(data) > 0 && data[len(data)-1] == '\n' { - data = data[:len(data)-1] - } - - // Return a copy since the buffer will be reused - result := make([]byte, len(data)) - copy(result, data) - return result, nil -} diff --git a/node/bundle.go b/node/bundle.go index 8c48f57..034fd95 100644 --- a/node/bundle.go +++ b/node/bundle.go @@ -5,29 +5,28 @@ import ( "bytes" "crypto/sha256" "encoding/hex" - "encoding/json" "io" - "os" - "path/filepath" - "strings" + "io/fs" - coreio "dappco.re/go/core/io" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" "forge.lthn.ai/Snider/Borg/pkg/datanode" "forge.lthn.ai/Snider/Borg/pkg/tim" ) -// BundleType defines the type of deployment bundle. +// bundleType := BundleProfile type BundleType string const ( - BundleProfile BundleType = "profile" // Just config/profile JSON - BundleMiner BundleType = "miner" // Miner binary + config - BundleFull BundleType = "full" // Everything (miner + profiles + config) + // BundleProfile contains a profile JSON payload. + BundleProfile BundleType = "profile" + // BundleMiner contains a miner binary and optional profile data. + BundleMiner BundleType = "miner" + // BundleFull contains the full deployment payload. + BundleFull BundleType = "full" ) -// Bundle represents a deployment bundle for P2P transfer. +// bundle := &Bundle{Type: BundleProfile, Name: "xmrig", Data: []byte("{}")} type Bundle struct { Type BundleType `json:"type"` Name string `json:"name"` @@ -35,7 +34,7 @@ type Bundle struct { Checksum string `json:"checksum"` // SHA-256 of Data } -// BundleManifest describes the contents of a bundle. 
+// manifest := BundleManifest{Name: "xmrig", Type: BundleMiner} type BundleManifest struct { Type BundleType `json:"type"` Name string `json:"name"` @@ -45,22 +44,19 @@ type BundleManifest struct { CreatedAt string `json:"createdAt"` } -// CreateProfileBundle creates an encrypted bundle containing a mining profile. +// bundle, err := CreateProfileBundle(profileJSON, "xmrig-default", "password") func CreateProfileBundle(profileJSON []byte, name string, password string) (*Bundle, error) { - // Create a TIM with just the profile config - t, err := tim.New() + timBundle, err := tim.New() if err != nil { - return nil, coreerr.E("CreateProfileBundle", "failed to create TIM", err) + return nil, core.E("CreateProfileBundle", "failed to create TIM", err) } - t.Config = profileJSON + timBundle.Config = profileJSON - // Encrypt to STIM format - stimData, err := t.ToSigil(password) + stimData, err := timBundle.ToSigil(password) if err != nil { - return nil, coreerr.E("CreateProfileBundle", "failed to encrypt bundle", err) + return nil, core.E("CreateProfileBundle", "failed to encrypt bundle", err) } - // Calculate checksum checksum := calculateChecksum(stimData) return &Bundle{ @@ -71,7 +67,7 @@ func CreateProfileBundle(profileJSON []byte, name string, password string) (*Bun }, nil } -// CreateProfileBundleUnencrypted creates a plain JSON bundle (for testing or trusted networks). +// bundle, err := CreateProfileBundleUnencrypted(profileJSON, "xmrig-default") func CreateProfileBundleUnencrypted(profileJSON []byte, name string) (*Bundle, error) { checksum := calculateChecksum(profileJSON) @@ -83,44 +79,38 @@ func CreateProfileBundleUnencrypted(profileJSON []byte, name string) (*Bundle, e }, nil } -// CreateMinerBundle creates an encrypted bundle containing a miner binary and optional profile. 
+// bundle, err := CreateMinerBundle("/srv/miners/xmrig", profileJSON, "xmrig", "password") func CreateMinerBundle(minerPath string, profileJSON []byte, name string, password string) (*Bundle, error) { - // Read miner binary - minerContent, err := coreio.Local.Read(minerPath) + minerContent, err := filesystemRead(minerPath) if err != nil { - return nil, coreerr.E("CreateMinerBundle", "failed to read miner binary", err) + return nil, core.E("CreateMinerBundle", "failed to read miner binary", err) } minerData := []byte(minerContent) - // Create a tarball with the miner binary tarData, err := createTarball(map[string][]byte{ - filepath.Base(minerPath): minerData, + core.PathBase(minerPath): minerData, }) if err != nil { - return nil, coreerr.E("CreateMinerBundle", "failed to create tarball", err) + return nil, core.E("CreateMinerBundle", "failed to create tarball", err) } - // Create DataNode from tarball - dn, err := datanode.FromTar(tarData) + dataNode, err := datanode.FromTar(tarData) if err != nil { - return nil, coreerr.E("CreateMinerBundle", "failed to create datanode", err) + return nil, core.E("CreateMinerBundle", "failed to create datanode", err) } - // Create TIM from DataNode - t, err := tim.FromDataNode(dn) + timBundle, err := tim.FromDataNode(dataNode) if err != nil { - return nil, coreerr.E("CreateMinerBundle", "failed to create TIM", err) + return nil, core.E("CreateMinerBundle", "failed to create TIM", err) } - // Set profile as config if provided if profileJSON != nil { - t.Config = profileJSON + timBundle.Config = profileJSON } - // Encrypt to STIM format - stimData, err := t.ToSigil(password) + stimData, err := timBundle.ToSigil(password) if err != nil { - return nil, coreerr.E("CreateMinerBundle", "failed to encrypt bundle", err) + return nil, core.E("CreateMinerBundle", "failed to encrypt bundle", err) } checksum := calculateChecksum(stimData) @@ -133,67 +123,58 @@ func CreateMinerBundle(minerPath string, profileJSON []byte, name string, passwo }, 
nil } -// ExtractProfileBundle decrypts and extracts a profile bundle. +// profileJSON, err := ExtractProfileBundle(bundle, "password") func ExtractProfileBundle(bundle *Bundle, password string) ([]byte, error) { - // Verify checksum first if calculateChecksum(bundle.Data) != bundle.Checksum { - return nil, coreerr.E("ExtractProfileBundle", "checksum mismatch - bundle may be corrupted", nil) + return nil, core.E("ExtractProfileBundle", "checksum mismatch - bundle may be corrupted", nil) } - // If it's unencrypted JSON, just return it if isJSON(bundle.Data) { return bundle.Data, nil } - // Decrypt STIM format - t, err := tim.FromSigil(bundle.Data, password) + timBundle, err := tim.FromSigil(bundle.Data, password) if err != nil { - return nil, coreerr.E("ExtractProfileBundle", "failed to decrypt bundle", err) + return nil, core.E("ExtractProfileBundle", "failed to decrypt bundle", err) } - return t.Config, nil + return timBundle.Config, nil } -// ExtractMinerBundle decrypts and extracts a miner bundle, returning the miner path and profile. 
+// minerPath, profileJSON, err := ExtractMinerBundle(bundle, "password", "/srv/miners") func ExtractMinerBundle(bundle *Bundle, password string, destDir string) (string, []byte, error) { - // Verify checksum if calculateChecksum(bundle.Data) != bundle.Checksum { - return "", nil, coreerr.E("ExtractMinerBundle", "checksum mismatch - bundle may be corrupted", nil) + return "", nil, core.E("ExtractMinerBundle", "checksum mismatch - bundle may be corrupted", nil) } - // Decrypt STIM format - t, err := tim.FromSigil(bundle.Data, password) + timBundle, err := tim.FromSigil(bundle.Data, password) if err != nil { - return "", nil, coreerr.E("ExtractMinerBundle", "failed to decrypt bundle", err) + return "", nil, core.E("ExtractMinerBundle", "failed to decrypt bundle", err) } - // Convert rootfs to tarball and extract - tarData, err := t.RootFS.ToTar() + tarData, err := timBundle.RootFS.ToTar() if err != nil { - return "", nil, coreerr.E("ExtractMinerBundle", "failed to convert rootfs to tar", err) + return "", nil, core.E("ExtractMinerBundle", "failed to convert rootfs to tar", err) } - // Extract tarball to destination minerPath, err := extractTarball(tarData, destDir) if err != nil { - return "", nil, coreerr.E("ExtractMinerBundle", "failed to extract tarball", err) + return "", nil, core.E("ExtractMinerBundle", "failed to extract tarball", err) } - return minerPath, t.Config, nil + return minerPath, timBundle.Config, nil } -// VerifyBundle checks if a bundle's checksum is valid. +// ok := VerifyBundle(bundle) func VerifyBundle(bundle *Bundle) bool { return calculateChecksum(bundle.Data) == bundle.Checksum } -// calculateChecksum computes SHA-256 checksum of data. func calculateChecksum(data []byte) string { hash := sha256.Sum256(data) return hex.EncodeToString(hash[:]) } -// isJSON checks if data starts with JSON characters. 
func isJSON(data []byte) bool { if len(data) == 0 { return false @@ -202,73 +183,78 @@ func isJSON(data []byte) bool { return data[0] == '{' || data[0] == '[' } -// createTarball creates a tar archive from a map of filename -> content. func createTarball(files map[string][]byte) ([]byte, error) { var buf bytes.Buffer - tw := tar.NewWriter(&buf) + tarWriter := tar.NewWriter(&buf) - // Track directories we've created - dirs := make(map[string]bool) + createdDirectories := make(map[string]bool) for name, content := range files { - // Create parent directories if needed - dir := filepath.Dir(name) - if dir != "." && !dirs[dir] { - hdr := &tar.Header{ + dir := core.PathDir(name) + if dir != "." && !createdDirectories[dir] { + header := &tar.Header{ Name: dir + "/", Mode: 0755, Typeflag: tar.TypeDir, } - if err := tw.WriteHeader(hdr); err != nil { + if err := tarWriter.WriteHeader(header); err != nil { return nil, err } - dirs[dir] = true + createdDirectories[dir] = true } - // Determine file mode (executable for binaries in miners/) + // Binaries in miners/ and non-JSON content get executable permissions. mode := int64(0644) - if filepath.Dir(name) == "miners" || !isJSON(content) { + if core.PathDir(name) == "miners" || !isJSON(content) { mode = 0755 } - hdr := &tar.Header{ + header := &tar.Header{ Name: name, Mode: mode, Size: int64(len(content)), } - if err := tw.WriteHeader(hdr); err != nil { + if err := tarWriter.WriteHeader(header); err != nil { return nil, err } - if _, err := tw.Write(content); err != nil { + if _, err := tarWriter.Write(content); err != nil { return nil, err } } - if err := tw.Close(); err != nil { + if err := tarWriter.Close(); err != nil { return nil, err } return buf.Bytes(), nil } -// extractTarball extracts a tar archive to a directory, returns first executable found. 
func extractTarball(tarData []byte, destDir string) (string, error) { // Ensure destDir is an absolute, clean path for security checks - absDestDir, err := filepath.Abs(destDir) - if err != nil { - return "", coreerr.E("extractTarball", "failed to resolve destination directory", err) + absDestDir := destDir + pathSeparator := core.Env("DS") + if pathSeparator == "" { + pathSeparator = "/" + } + if !core.PathIsAbs(absDestDir) { + cwd := core.Env("DIR_CWD") + if cwd == "" { + return "", core.E("extractTarball", "failed to resolve destination directory", nil) + } + absDestDir = core.CleanPath(core.Concat(cwd, pathSeparator, absDestDir), pathSeparator) + } else { + absDestDir = core.CleanPath(absDestDir, pathSeparator) } - absDestDir = filepath.Clean(absDestDir) - if err := coreio.Local.EnsureDir(absDestDir); err != nil { + if err := filesystemEnsureDir(absDestDir); err != nil { return "", err } - tr := tar.NewReader(bytes.NewReader(tarData)) + tarReader := tar.NewReader(bytes.NewReader(tarData)) var firstExecutable string for { - hdr, err := tr.Next() + header, err := tarReader.Next() if err == io.EOF { break } @@ -277,61 +263,58 @@ func extractTarball(tarData []byte, destDir string) (string, error) { } // Security: Sanitize the tar entry name to prevent path traversal (Zip Slip) - cleanName := filepath.Clean(hdr.Name) + cleanName := core.CleanPath(header.Name, "/") // Reject absolute paths - if filepath.IsAbs(cleanName) { - return "", coreerr.E("extractTarball", "invalid tar entry: absolute path not allowed: "+hdr.Name, nil) + if core.PathIsAbs(cleanName) { + return "", core.E("extractTarball", "invalid tar entry: absolute path not allowed: "+header.Name, nil) } // Reject paths that escape the destination directory - if strings.HasPrefix(cleanName, ".."+string(os.PathSeparator)) || cleanName == ".." { - return "", coreerr.E("extractTarball", "invalid tar entry: path traversal attempt: "+hdr.Name, nil) + if core.HasPrefix(cleanName, "../") || cleanName == ".." 
{ + return "", core.E("extractTarball", "invalid tar entry: path traversal attempt: "+header.Name, nil) } // Build the full path and verify it's within destDir - fullPath := filepath.Join(absDestDir, cleanName) - fullPath = filepath.Clean(fullPath) + fullPath := core.CleanPath(core.Concat(absDestDir, pathSeparator, cleanName), pathSeparator) // Final security check: ensure the path is still within destDir - if !strings.HasPrefix(fullPath, absDestDir+string(os.PathSeparator)) && fullPath != absDestDir { - return "", coreerr.E("extractTarball", "invalid tar entry: path escape attempt: "+hdr.Name, nil) + allowedPrefix := core.Concat(absDestDir, pathSeparator) + if absDestDir == pathSeparator { + allowedPrefix = absDestDir + } + if !core.HasPrefix(fullPath, allowedPrefix) && fullPath != absDestDir { + return "", core.E("extractTarball", "invalid tar entry: path escape attempt: "+header.Name, nil) } - switch hdr.Typeflag { + switch header.Typeflag { case tar.TypeDir: - if err := coreio.Local.EnsureDir(fullPath); err != nil { + if err := filesystemEnsureDir(fullPath); err != nil { return "", err } case tar.TypeReg: // Ensure parent directory exists - if err := coreio.Local.EnsureDir(filepath.Dir(fullPath)); err != nil { + if err := filesystemEnsureDir(core.PathDir(fullPath)); err != nil { return "", err } - // os.OpenFile is used deliberately here instead of coreio.Local.Create/Write - // because coreio hardcodes file permissions (0644) and we need to preserve - // the tar header's mode bits — executable binaries require 0755. 
- f, err := os.OpenFile(fullPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)) - if err != nil { - return "", coreerr.E("extractTarball", "failed to create file "+hdr.Name, err) - } - // Limit file size to prevent decompression bombs (100MB max per file) const maxFileSize int64 = 100 * 1024 * 1024 - limitedReader := io.LimitReader(tr, maxFileSize+1) - written, err := io.Copy(f, limitedReader) - f.Close() + limitedReader := io.LimitReader(tarReader, maxFileSize+1) + content, err := io.ReadAll(limitedReader) if err != nil { - return "", coreerr.E("extractTarball", "failed to write file "+hdr.Name, err) + return "", core.E("extractTarball", "failed to write file "+header.Name, err) } - if written > maxFileSize { - coreio.Local.Delete(fullPath) - return "", coreerr.E("extractTarball", "file "+hdr.Name+" exceeds maximum size", nil) + if int64(len(content)) > maxFileSize { + filesystemDelete(fullPath) + return "", core.E("extractTarball", "file "+header.Name+" exceeds maximum size", nil) + } + if err := filesystemResultError(localFileSystem.WriteMode(fullPath, string(content), fs.FileMode(header.Mode))); err != nil { + return "", core.E("extractTarball", "failed to create file "+header.Name, err) } // Track first executable - if hdr.Mode&0111 != 0 && firstExecutable == "" { + if header.Mode&0111 != 0 && firstExecutable == "" { firstExecutable = fullPath } // Explicitly ignore symlinks and hard links to prevent symlink attacks @@ -344,18 +327,27 @@ func extractTarball(tarData []byte, destDir string) (string, error) { return firstExecutable, nil } -// StreamBundle writes a bundle to a writer (for large transfers). +// err := StreamBundle(bundle, writer) func StreamBundle(bundle *Bundle, w io.Writer) error { - encoder := json.NewEncoder(w) - return encoder.Encode(bundle) + result := core.JSONMarshal(bundle) + if !result.OK { + return result.Value.(error) + } + _, err := w.Write(result.Value.([]byte)) + return err } -// ReadBundle reads a bundle from a reader. 
+// bundle, err := ReadBundle(reader) func ReadBundle(r io.Reader) (*Bundle, error) { - var bundle Bundle - decoder := json.NewDecoder(r) - if err := decoder.Decode(&bundle); err != nil { + var buf bytes.Buffer + if _, err := io.Copy(&buf, r); err != nil { return nil, err } + + var bundle Bundle + result := core.JSONUnmarshal(buf.Bytes(), &bundle) + if !result.OK { + return nil, result.Value.(error) + } return &bundle, nil } diff --git a/node/bundle_test.go b/node/bundle_test.go index 80c2ed3..59d22a5 100644 --- a/node/bundle_test.go +++ b/node/bundle_test.go @@ -3,12 +3,10 @@ package node import ( "archive/tar" "bytes" - "os" - "path/filepath" "testing" ) -func TestCreateProfileBundleUnencrypted(t *testing.T) { +func TestBundle_CreateProfileBundleUnencrypted_Good(t *testing.T) { profileJSON := []byte(`{"name":"test-profile","minerType":"xmrig","config":{}}`) bundle, err := CreateProfileBundleUnencrypted(profileJSON, "test-profile") @@ -33,7 +31,7 @@ func TestCreateProfileBundleUnencrypted(t *testing.T) { } } -func TestVerifyBundle(t *testing.T) { +func TestBundle_VerifyBundle_Good(t *testing.T) { t.Run("ValidChecksum", func(t *testing.T) { bundle, _ := CreateProfileBundleUnencrypted([]byte(`{"test":"data"}`), "test") @@ -61,7 +59,7 @@ func TestVerifyBundle(t *testing.T) { }) } -func TestCreateProfileBundle(t *testing.T) { +func TestBundle_CreateProfileBundle_Good(t *testing.T) { profileJSON := []byte(`{"name":"encrypted-profile","minerType":"xmrig"}`) password := "test-password-123" @@ -90,7 +88,7 @@ func TestCreateProfileBundle(t *testing.T) { } } -func TestExtractProfileBundle(t *testing.T) { +func TestBundle_ExtractProfileBundle_Good(t *testing.T) { t.Run("UnencryptedBundle", func(t *testing.T) { originalJSON := []byte(`{"name":"plain","config":{}}`) bundle, _ := CreateProfileBundleUnencrypted(originalJSON, "plain") @@ -142,7 +140,7 @@ func TestExtractProfileBundle(t *testing.T) { }) } -func TestTarballFunctions(t *testing.T) { +func 
TestBundle_TarballFunctions_Good(t *testing.T) { t.Run("CreateAndExtractTarball", func(t *testing.T) { files := map[string][]byte{ "file1.txt": []byte("content of file 1"), @@ -160,8 +158,7 @@ func TestTarballFunctions(t *testing.T) { } // Extract to temp directory - tmpDir, _ := os.MkdirTemp("", "tarball-test") - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() firstExec, err := extractTarball(tarData, tmpDir) if err != nil { @@ -170,12 +167,7 @@ func TestTarballFunctions(t *testing.T) { // Check files exist for name, content := range files { - path := filepath.Join(tmpDir, name) - data, err := os.ReadFile(path) - if err != nil { - t.Errorf("failed to read extracted file %s: %v", name, err) - continue - } + data := testReadFile(t, testJoinPath(tmpDir, name)) if !bytes.Equal(data, content) { t.Errorf("content mismatch for %s", name) @@ -189,7 +181,7 @@ func TestTarballFunctions(t *testing.T) { }) } -func TestStreamAndReadBundle(t *testing.T) { +func TestBundle_StreamAndReadBundle_Good(t *testing.T) { original, _ := CreateProfileBundleUnencrypted([]byte(`{"streaming":"test"}`), "stream-test") // Stream to buffer @@ -218,7 +210,7 @@ func TestStreamAndReadBundle(t *testing.T) { } } -func TestCalculateChecksum(t *testing.T) { +func TestBundle_CalculateChecksum_Good(t *testing.T) { t.Run("Deterministic", func(t *testing.T) { data := []byte("test data for checksum") @@ -256,7 +248,7 @@ func TestCalculateChecksum(t *testing.T) { }) } -func TestIsJSON(t *testing.T) { +func TestBundle_IsJSON_Good(t *testing.T) { tests := []struct { data []byte expected bool @@ -279,7 +271,7 @@ func TestIsJSON(t *testing.T) { } } -func TestBundleTypes(t *testing.T) { +func TestBundle_Types_Good(t *testing.T) { types := []BundleType{ BundleProfile, BundleMiner, @@ -295,16 +287,11 @@ func TestBundleTypes(t *testing.T) { } } -func TestCreateMinerBundle(t *testing.T) { +func TestBundle_CreateMinerBundle_Good(t *testing.T) { // Create a temp "miner binary" - tmpDir, _ := os.MkdirTemp("", 
"miner-bundle-test") - defer os.RemoveAll(tmpDir) - - minerPath := filepath.Join(tmpDir, "test-miner") - err := os.WriteFile(minerPath, []byte("fake miner binary content"), 0755) - if err != nil { - t.Fatalf("failed to create test miner: %v", err) - } + tmpDir := t.TempDir() + minerPath := testJoinPath(tmpDir, "test-miner") + testWriteFile(t, minerPath, []byte("fake miner binary content"), 0o755) profileJSON := []byte(`{"profile":"data"}`) password := "miner-password" @@ -323,8 +310,7 @@ func TestCreateMinerBundle(t *testing.T) { } // Extract and verify - extractDir, _ := os.MkdirTemp("", "miner-extract-test") - defer os.RemoveAll(extractDir) + extractDir := t.TempDir() extractedPath, extractedProfile, err := ExtractMinerBundle(bundle, password, extractDir) if err != nil { @@ -341,10 +327,7 @@ func TestCreateMinerBundle(t *testing.T) { // If we got an extracted path, verify its content if extractedPath != "" { - minerData, err := os.ReadFile(extractedPath) - if err != nil { - t.Fatalf("failed to read extracted miner: %v", err) - } + minerData := testReadFile(t, extractedPath) if string(minerData) != "fake miner binary content" { t.Error("miner content mismatch") @@ -354,7 +337,7 @@ func TestCreateMinerBundle(t *testing.T) { // --- Additional coverage tests for bundle.go --- -func TestExtractTarball_PathTraversal(t *testing.T) { +func TestBundle_ExtractTarball_PathTraversal_Bad(t *testing.T) { t.Run("AbsolutePath", func(t *testing.T) { // Create a tarball with an absolute path entry tarData, err := createTarballWithCustomName("/etc/passwd", []byte("malicious")) @@ -446,8 +429,8 @@ func TestExtractTarball_PathTraversal(t *testing.T) { } // Verify symlink was not created - linkPath := filepath.Join(tmpDir, "link") - if _, statErr := os.Lstat(linkPath); !os.IsNotExist(statErr) { + linkPath := testJoinPath(tmpDir, "link") + if filesystemExists(linkPath) { t.Error("symlink should not be created") } }) @@ -481,10 +464,7 @@ func TestExtractTarball_PathTraversal(t 
*testing.T) { } // Verify directory and file exist - data, err := os.ReadFile(filepath.Join(tmpDir, "mydir", "file.txt")) - if err != nil { - t.Fatalf("failed to read extracted file: %v", err) - } + data := testReadFile(t, testJoinPath(tmpDir, "mydir", "file.txt")) if !bytes.Equal(data, content) { t.Error("content mismatch") } @@ -531,7 +511,7 @@ func createTarballWithSymlink(name, target string) ([]byte, error) { return buf.Bytes(), nil } -func TestExtractMinerBundle_ChecksumMismatch(t *testing.T) { +func TestBundle_ExtractMinerBundle_ChecksumMismatch_Bad(t *testing.T) { bundle := &Bundle{ Type: BundleMiner, Name: "bad-bundle", @@ -545,17 +525,17 @@ func TestExtractMinerBundle_ChecksumMismatch(t *testing.T) { } } -func TestCreateMinerBundle_NonExistentFile(t *testing.T) { +func TestBundle_CreateMinerBundle_NonExistentFile_Bad(t *testing.T) { _, err := CreateMinerBundle("/non/existent/miner", nil, "test", "password") if err == nil { t.Error("expected error for non-existent miner file") } } -func TestCreateMinerBundle_NilProfile(t *testing.T) { +func TestBundle_CreateMinerBundle_NilProfile_Ugly(t *testing.T) { tmpDir := t.TempDir() - minerPath := filepath.Join(tmpDir, "miner") - os.WriteFile(minerPath, []byte("binary"), 0755) + minerPath := testJoinPath(tmpDir, "miner") + testWriteFile(t, minerPath, []byte("binary"), 0o755) bundle, err := CreateMinerBundle(minerPath, nil, "nil-profile", "pass") if err != nil { @@ -566,7 +546,7 @@ func TestCreateMinerBundle_NilProfile(t *testing.T) { } } -func TestReadBundle_InvalidJSON(t *testing.T) { +func TestBundle_ReadBundle_InvalidJSON_Bad(t *testing.T) { reader := bytes.NewReader([]byte("not json")) _, err := ReadBundle(reader) if err == nil { @@ -574,7 +554,7 @@ func TestReadBundle_InvalidJSON(t *testing.T) { } } -func TestStreamBundle_EmptyBundle(t *testing.T) { +func TestBundle_StreamBundle_EmptyBundle_Ugly(t *testing.T) { bundle := &Bundle{ Type: BundleProfile, Name: "empty", @@ -598,7 +578,7 @@ func 
TestStreamBundle_EmptyBundle(t *testing.T) { } } -func TestCreateTarball_MultipleDirs(t *testing.T) { +func TestBundle_CreateTarball_MultipleDirs_Good(t *testing.T) { files := map[string][]byte{ "dir1/file1.txt": []byte("content1"), "dir2/file2.txt": []byte("content2"), @@ -616,11 +596,7 @@ func TestCreateTarball_MultipleDirs(t *testing.T) { } for name, content := range files { - data, err := os.ReadFile(filepath.Join(tmpDir, name)) - if err != nil { - t.Errorf("failed to read %s: %v", name, err) - continue - } + data := testReadFile(t, testJoinPath(tmpDir, name)) if !bytes.Equal(data, content) { t.Errorf("content mismatch for %s", name) } diff --git a/node/controller.go b/node/controller.go index 224c4d6..8441352 100644 --- a/node/controller.go +++ b/node/controller.go @@ -2,33 +2,32 @@ package node import ( "context" - "encoding/json" "sync" "time" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" "dappco.re/go/core/p2p/logging" ) -// Controller manages remote peer operations from a controller node. +// controller := NewController(nodeManager, peerRegistry, transport) type Controller struct { - node *NodeManager - peers *PeerRegistry - transport *Transport - mu sync.RWMutex + nodeManager *NodeManager + peerRegistry *PeerRegistry + transport *Transport + mutex sync.RWMutex - // Pending requests awaiting responses - pending map[string]chan *Message // message ID -> response channel + // Pending requests awaiting responses. + pendingRequests map[string]chan *Message // message ID -> response channel } -// NewController creates a new Controller instance. 
-func NewController(node *NodeManager, peers *PeerRegistry, transport *Transport) *Controller { +// controller := NewController(nodeManager, peerRegistry, transport) +func NewController(nodeManager *NodeManager, peerRegistry *PeerRegistry, transport *Transport) *Controller { c := &Controller{ - node: node, - peers: peers, - transport: transport, - pending: make(map[string]chan *Message), + nodeManager: nodeManager, + peerRegistry: peerRegistry, + transport: transport, + pendingRequests: make(map[string]chan *Message), } // Register message handler for responses @@ -37,114 +36,107 @@ func NewController(node *NodeManager, peers *PeerRegistry, transport *Transport) return c } -// handleResponse processes incoming messages that are responses to our requests. -func (c *Controller) handleResponse(conn *PeerConnection, msg *Message) { - if msg.ReplyTo == "" { +func (c *Controller) handleResponse(_ *PeerConnection, message *Message) { + if message.ReplyTo == "" { return // Not a response, let worker handle it } - c.mu.Lock() - ch, exists := c.pending[msg.ReplyTo] - if exists { - delete(c.pending, msg.ReplyTo) + c.mutex.Lock() + responseChannel, hasPendingRequest := c.pendingRequests[message.ReplyTo] + if hasPendingRequest { + delete(c.pendingRequests, message.ReplyTo) } - c.mu.Unlock() + c.mutex.Unlock() - if exists && ch != nil { + if hasPendingRequest && responseChannel != nil { select { - case ch <- msg: + case responseChannel <- message: default: - // Channel full or closed + // Late duplicate response; drop it. } } } -// sendRequest sends a message and waits for a response. 
-func (c *Controller) sendRequest(peerID string, msg *Message, timeout time.Duration) (*Message, error) { - actualPeerID := peerID +func (c *Controller) sendRequest(peerID string, message *Message, timeout time.Duration) (*Message, error) { + resolvedPeerID := peerID // Auto-connect if not already connected if c.transport.GetConnection(peerID) == nil { - peer := c.peers.GetPeer(peerID) + peer := c.peerRegistry.GetPeer(peerID) if peer == nil { - return nil, coreerr.E("Controller.sendRequest", "peer not found: "+peerID, nil) + return nil, core.E("Controller.sendRequest", "peer not found: "+peerID, nil) } conn, err := c.transport.Connect(peer) if err != nil { - return nil, coreerr.E("Controller.sendRequest", "failed to connect to peer", err) + return nil, core.E("Controller.sendRequest", "failed to connect to peer", err) } - // Use the real peer ID after handshake (it may have changed) - actualPeerID = conn.Peer.ID - // Update the message destination - msg.To = actualPeerID + resolvedPeerID = conn.Peer.ID + message.To = resolvedPeerID } - // Create response channel - respCh := make(chan *Message, 1) + responseChannel := make(chan *Message, 1) - c.mu.Lock() - c.pending[msg.ID] = respCh - c.mu.Unlock() + c.mutex.Lock() + c.pendingRequests[message.ID] = responseChannel + c.mutex.Unlock() - // Clean up on exit - ensure channel is closed and removed from map + // Clean up on exit. Deleting the pending entry is enough because + // handleResponse only routes through the map. 
defer func() { - c.mu.Lock() - delete(c.pending, msg.ID) - c.mu.Unlock() - close(respCh) // Close channel to allow garbage collection + c.mutex.Lock() + delete(c.pendingRequests, message.ID) + c.mutex.Unlock() }() - // Send the message - if err := c.transport.Send(actualPeerID, msg); err != nil { - return nil, coreerr.E("Controller.sendRequest", "failed to send message", err) + if err := c.transport.Send(resolvedPeerID, message); err != nil { + return nil, core.E("Controller.sendRequest", "failed to send message", err) } - // Wait for response ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() select { - case resp := <-respCh: - return resp, nil + case response := <-responseChannel: + return response, nil case <-ctx.Done(): - return nil, coreerr.E("Controller.sendRequest", "request timeout", nil) + return nil, core.E("Controller.sendRequest", "request timeout", nil) } } -// GetRemoteStats requests miner statistics from a remote peer. +// stats, err := controller.GetRemoteStats("worker-1") func (c *Controller) GetRemoteStats(peerID string) (*StatsPayload, error) { - identity := c.node.GetIdentity() + identity := c.nodeManager.GetIdentity() if identity == nil { - return nil, ErrIdentityNotInitialized + return nil, ErrorIdentityNotInitialized } - msg, err := NewMessage(MsgGetStats, identity.ID, peerID, nil) + requestMessage, err := NewMessage(MessageGetStats, identity.ID, peerID, nil) if err != nil { - return nil, coreerr.E("Controller.GetRemoteStats", "failed to create message", err) + return nil, core.E("Controller.GetRemoteStats", "failed to create message", err) } - resp, err := c.sendRequest(peerID, msg, 10*time.Second) + response, err := c.sendRequest(peerID, requestMessage, 10*time.Second) if err != nil { return nil, err } var stats StatsPayload - if err := ParseResponse(resp, MsgStats, &stats); err != nil { + if err := ParseResponse(response, MessageStats, &stats); err != nil { return nil, err } return &stats, nil } -// 
StartRemoteMiner requests a remote peer to start a miner with a given profile. -func (c *Controller) StartRemoteMiner(peerID, minerType, profileID string, configOverride json.RawMessage) error { - identity := c.node.GetIdentity() +// err := controller.StartRemoteMiner("worker-1", "xmrig", "profile-1", nil) +func (c *Controller) StartRemoteMiner(peerID, minerType, profileID string, configOverride RawMessage) error { + identity := c.nodeManager.GetIdentity() if identity == nil { - return ErrIdentityNotInitialized + return ErrorIdentityNotInitialized } if minerType == "" { - return coreerr.E("Controller.StartRemoteMiner", "miner type is required", nil) + return core.E("Controller.StartRemoteMiner", "miner type is required", nil) } payload := StartMinerPayload{ @@ -153,98 +145,98 @@ func (c *Controller) StartRemoteMiner(peerID, minerType, profileID string, confi Config: configOverride, } - msg, err := NewMessage(MsgStartMiner, identity.ID, peerID, payload) + requestMessage, err := NewMessage(MessageStartMiner, identity.ID, peerID, payload) if err != nil { - return coreerr.E("Controller.StartRemoteMiner", "failed to create message", err) + return core.E("Controller.StartRemoteMiner", "failed to create message", err) } - resp, err := c.sendRequest(peerID, msg, 30*time.Second) + response, err := c.sendRequest(peerID, requestMessage, 30*time.Second) if err != nil { return err } var ack MinerAckPayload - if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil { + if err := ParseResponse(response, MessageMinerAck, &ack); err != nil { return err } if !ack.Success { - return coreerr.E("Controller.StartRemoteMiner", "miner start failed: "+ack.Error, nil) + return core.E("Controller.StartRemoteMiner", "miner start failed: "+ack.Error, nil) } return nil } -// StopRemoteMiner requests a remote peer to stop a miner. 
+// err := controller.StopRemoteMiner("worker-1", "xmrig-0") func (c *Controller) StopRemoteMiner(peerID, minerName string) error { - identity := c.node.GetIdentity() + identity := c.nodeManager.GetIdentity() if identity == nil { - return ErrIdentityNotInitialized + return ErrorIdentityNotInitialized } payload := StopMinerPayload{ MinerName: minerName, } - msg, err := NewMessage(MsgStopMiner, identity.ID, peerID, payload) + requestMessage, err := NewMessage(MessageStopMiner, identity.ID, peerID, payload) if err != nil { - return coreerr.E("Controller.StopRemoteMiner", "failed to create message", err) + return core.E("Controller.StopRemoteMiner", "failed to create message", err) } - resp, err := c.sendRequest(peerID, msg, 30*time.Second) + response, err := c.sendRequest(peerID, requestMessage, 30*time.Second) if err != nil { return err } var ack MinerAckPayload - if err := ParseResponse(resp, MsgMinerAck, &ack); err != nil { + if err := ParseResponse(response, MessageMinerAck, &ack); err != nil { return err } if !ack.Success { - return coreerr.E("Controller.StopRemoteMiner", "miner stop failed: "+ack.Error, nil) + return core.E("Controller.StopRemoteMiner", "miner stop failed: "+ack.Error, nil) } return nil } -// GetRemoteLogs requests console logs from a remote miner. 
+// logs, err := controller.GetRemoteLogs("worker-1", "xmrig-0", 100) func (c *Controller) GetRemoteLogs(peerID, minerName string, lines int) ([]string, error) { - identity := c.node.GetIdentity() + identity := c.nodeManager.GetIdentity() if identity == nil { - return nil, ErrIdentityNotInitialized + return nil, ErrorIdentityNotInitialized } - payload := GetLogsPayload{ + payload := LogsRequestPayload{ MinerName: minerName, Lines: lines, } - msg, err := NewMessage(MsgGetLogs, identity.ID, peerID, payload) + requestMessage, err := NewMessage(MessageGetLogs, identity.ID, peerID, payload) if err != nil { - return nil, coreerr.E("Controller.GetRemoteLogs", "failed to create message", err) + return nil, core.E("Controller.GetRemoteLogs", "failed to create message", err) } - resp, err := c.sendRequest(peerID, msg, 10*time.Second) + response, err := c.sendRequest(peerID, requestMessage, 10*time.Second) if err != nil { return nil, err } var logs LogsPayload - if err := ParseResponse(resp, MsgLogs, &logs); err != nil { + if err := ParseResponse(response, MessageLogs, &logs); err != nil { return nil, err } return logs.Lines, nil } -// GetAllStats fetches stats from all connected peers. +// statsByPeerID := controller.GetAllStats() func (c *Controller) GetAllStats() map[string]*StatsPayload { results := make(map[string]*StatsPayload) var mu sync.Mutex var wg sync.WaitGroup - for peer := range c.peers.ConnectedPeers() { + for peer := range c.peerRegistry.ConnectedPeers() { wg.Add(1) go func(p *Peer) { defer wg.Done() @@ -267,11 +259,11 @@ func (c *Controller) GetAllStats() map[string]*StatsPayload { return results } -// PingPeer sends a ping to a peer and updates metrics. 
+// rttMilliseconds, err := controller.PingPeer("worker-1") func (c *Controller) PingPeer(peerID string) (float64, error) { - identity := c.node.GetIdentity() + identity := c.nodeManager.GetIdentity() if identity == nil { - return 0, ErrIdentityNotInitialized + return 0, ErrorIdentityNotInitialized } sentAt := time.Now() @@ -279,48 +271,48 @@ func (c *Controller) PingPeer(peerID string) (float64, error) { SentAt: sentAt.UnixMilli(), } - msg, err := NewMessage(MsgPing, identity.ID, peerID, payload) + requestMessage, err := NewMessage(MessagePing, identity.ID, peerID, payload) if err != nil { - return 0, coreerr.E("Controller.PingPeer", "failed to create message", err) + return 0, core.E("Controller.PingPeer", "failed to create message", err) } - resp, err := c.sendRequest(peerID, msg, 5*time.Second) + response, err := c.sendRequest(peerID, requestMessage, 5*time.Second) if err != nil { return 0, err } - if err := ValidateResponse(resp, MsgPong); err != nil { + if err := ValidateResponse(response, MessagePong); err != nil { return 0, err } - // Calculate round-trip time - rtt := time.Since(sentAt).Seconds() * 1000 // Convert to ms + // Calculate round-trip time in milliseconds. + rtt := time.Since(sentAt).Seconds() * 1000 // Update peer metrics - peer := c.peers.GetPeer(peerID) + peer := c.peerRegistry.GetPeer(peerID) if peer != nil { - c.peers.UpdateMetrics(peerID, rtt, peer.GeoKM, peer.Hops) + c.peerRegistry.UpdateMetrics(peerID, rtt, peer.GeographicKilometres, peer.Hops) } return rtt, nil } -// ConnectToPeer establishes a connection to a peer. 
+// err := controller.ConnectToPeer("worker-1") func (c *Controller) ConnectToPeer(peerID string) error { - peer := c.peers.GetPeer(peerID) + peer := c.peerRegistry.GetPeer(peerID) if peer == nil { - return coreerr.E("Controller.ConnectToPeer", "peer not found: "+peerID, nil) + return core.E("Controller.ConnectToPeer", "peer not found: "+peerID, nil) } _, err := c.transport.Connect(peer) return err } -// DisconnectFromPeer closes connection to a peer. +// err := controller.DisconnectFromPeer("worker-1") func (c *Controller) DisconnectFromPeer(peerID string) error { conn := c.transport.GetConnection(peerID) if conn == nil { - return coreerr.E("Controller.DisconnectFromPeer", "peer not connected: "+peerID, nil) + return core.E("Controller.DisconnectFromPeer", "peer not connected: "+peerID, nil) } return conn.Close() diff --git a/node/controller_test.go b/node/controller_test.go index ee9a383..dd05231 100644 --- a/node/controller_test.go +++ b/node/controller_test.go @@ -1,17 +1,15 @@ package node import ( - "encoding/json" - "fmt" "net/http" "net/http/httptest" "net/url" - "path/filepath" "sync" "sync/atomic" "testing" "time" + core "dappco.re/go/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,7 +24,7 @@ func setupControllerPair(t *testing.T) (*Controller, *Worker, *testTransportPair // Server side: register a Worker to handle incoming requests. worker := NewWorker(tp.ServerNode, tp.Server) - worker.RegisterWithTransport() + worker.RegisterOnTransport() // Client side: create a Controller (registers handleResponse via OnMessage). controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) @@ -42,23 +40,23 @@ func setupControllerPair(t *testing.T) (*Controller, *Worker, *testTransportPair // makeWorkerServer spins up an independent server transport with a Worker // registered, returning the server's NodeManager, address, and a cleanup func. -// Useful for multi-peer tests (GetAllStats, ConcurrentRequests). 
+// Useful for multi-peer tests (AllStats, ConcurrentRequests). func makeWorkerServer(t *testing.T) (*NodeManager, string, *Transport) { t.Helper() - nm := testNode(t, "worker", RoleWorker) - reg := testRegistry(t) + nm := newTestNodeManager(t, "worker", RoleWorker) + reg := newTestPeerRegistry(t) cfg := DefaultTransportConfig() srv := NewTransport(nm, reg, cfg) mux := http.NewServeMux() - mux.HandleFunc(cfg.WSPath, srv.handleWSUpgrade) + mux.HandleFunc(cfg.WebSocketPath, srv.handleWebSocketUpgrade) ts := httptest.NewServer(mux) u, _ := url.Parse(ts.URL) worker := NewWorker(nm, srv) - worker.RegisterWithTransport() + worker.RegisterOnTransport() t.Cleanup(func() { // Brief pause to let in-flight readLoop/Send operations finish before @@ -75,18 +73,18 @@ func makeWorkerServer(t *testing.T) (*NodeManager, string, *Transport) { // --- Controller Tests --- -func TestController_RequestResponseCorrelation(t *testing.T) { +func TestController_RequestResponseCorrelation_Good(t *testing.T) { controller, _, tp := setupControllerPair(t) serverID := tp.ServerNode.GetIdentity().ID // Send a ping request via the controller; the server-side worker - // replies with MsgPong, setting ReplyTo to the original message ID. + // replies with MessagePong, setting ReplyTo to the original message ID. rtt, err := controller.PingPeer(serverID) require.NoError(t, err, "PingPeer should succeed") assert.Greater(t, rtt, 0.0, "RTT should be positive") } -func TestController_RequestTimeout(t *testing.T) { +func TestController_RequestTimeout_Bad(t *testing.T) { tp := setupTestTransportPair(t) // Register a handler on the server that deliberately ignores all messages, @@ -103,7 +101,7 @@ func TestController_RequestTimeout(t *testing.T) { clientID := tp.ClientNode.GetIdentity().ID // Use sendRequest directly with a short deadline (PingPeer uses 5s internally). 
- msg, err := NewMessage(MsgPing, clientID, serverID, PingPayload{ + msg, err := NewMessage(MessagePing, clientID, serverID, PingPayload{ SentAt: time.Now().UnixMilli(), }) require.NoError(t, err) @@ -117,12 +115,12 @@ func TestController_RequestTimeout(t *testing.T) { assert.Less(t, elapsed, 1*time.Second, "should return quickly after the deadline") } -func TestController_AutoConnect(t *testing.T) { +func TestController_AutoConnect_Good(t *testing.T) { tp := setupTestTransportPair(t) // Register worker on the server side. worker := NewWorker(tp.ServerNode, tp.Server) - worker.RegisterWithTransport() + worker.RegisterOnTransport() // Create controller WITHOUT establishing a connection first. controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) @@ -138,7 +136,7 @@ func TestController_AutoConnect(t *testing.T) { tp.ClientReg.AddPeer(peer) // Confirm no connection exists yet. - assert.Equal(t, 0, tp.Client.ConnectedPeers(), "should have no connections initially") + assert.Equal(t, 0, tp.Client.ConnectedPeerCount(), "should have no connections initially") // Send a request — controller should auto-connect via transport before sending. rtt, err := controller.PingPeer(serverIdentity.ID) @@ -146,13 +144,13 @@ func TestController_AutoConnect(t *testing.T) { assert.Greater(t, rtt, 0.0, "RTT should be positive after auto-connect") // Verify connection was established. - assert.Equal(t, 1, tp.Client.ConnectedPeers(), "should have 1 connection after auto-connect") + assert.Equal(t, 1, tp.Client.ConnectedPeerCount(), "should have 1 connection after auto-connect") } -func TestController_GetAllStats(t *testing.T) { +func TestController_AllStats_Good(t *testing.T) { // Controller node with connections to two independent worker servers. 
- controllerNM := testNode(t, "controller", RoleController) - controllerReg := testRegistry(t) + controllerNM := newTestNodeManager(t, "controller", RoleController) + controllerReg := newTestPeerRegistry(t) controllerTransport := NewTransport(controllerNM, controllerReg, DefaultTransportConfig()) t.Cleanup(func() { controllerTransport.Stop() }) @@ -180,7 +178,7 @@ func TestController_GetAllStats(t *testing.T) { controller := NewController(controllerNM, controllerReg, controllerTransport) - // GetAllStats fetches stats from all connected peers in parallel. + // GetAllStats fetches stats from all connected peers in parallel. stats := controller.GetAllStats() assert.Len(t, stats, numWorkers, "should get stats from all connected workers") @@ -194,14 +192,14 @@ func TestController_GetAllStats(t *testing.T) { } } -func TestController_PingPeerRTT(t *testing.T) { +func TestController_PingPeerRTT_Good(t *testing.T) { controller, _, tp := setupControllerPair(t) serverID := tp.ServerNode.GetIdentity().ID // Record initial peer metrics. peerBefore := tp.ClientReg.GetPeer(serverID) require.NotNil(t, peerBefore, "server peer should exist in the client registry") - initialPingMS := peerBefore.PingMS + initialPingMilliseconds := peerBefore.PingMilliseconds // Send a ping. rtt, err := controller.PingPeer(serverID) @@ -212,16 +210,16 @@ func TestController_PingPeerRTT(t *testing.T) { // Verify the peer registry was updated with the measured latency.
peerAfter := tp.ClientReg.GetPeer(serverID) require.NotNil(t, peerAfter, "server peer should still exist after ping") - assert.NotEqual(t, initialPingMS, peerAfter.PingMS, - "PingMS should be updated after a successful ping") - assert.Greater(t, peerAfter.PingMS, 0.0, "PingMS should be positive") + assert.NotEqual(t, initialPingMilliseconds, peerAfter.PingMilliseconds, + "PingMilliseconds should be updated after a successful ping") + assert.Greater(t, peerAfter.PingMilliseconds, 0.0, "PingMilliseconds should be positive") } -func TestController_ConcurrentRequests(t *testing.T) { +func TestController_ConcurrentRequests_Ugly(t *testing.T) { // Multiple goroutines send pings to different peers simultaneously. // Verify correct correlation — no cross-talk between responses. - controllerNM := testNode(t, "controller", RoleController) - controllerReg := testRegistry(t) + controllerNM := newTestNodeManager(t, "controller", RoleController) + controllerReg := newTestPeerRegistry(t) controllerTransport := NewTransport(controllerNM, controllerReg, DefaultTransportConfig()) t.Cleanup(func() { controllerTransport.Stop() }) @@ -271,7 +269,7 @@ func TestController_ConcurrentRequests(t *testing.T) { } } -func TestController_DeadPeerCleanup(t *testing.T) { +func TestController_DeadPeerCleanup_Good(t *testing.T) { tp := setupTestTransportPair(t) // Server deliberately ignores all messages. @@ -285,7 +283,7 @@ func TestController_DeadPeerCleanup(t *testing.T) { clientID := tp.ClientNode.GetIdentity().ID // Fire off a request that will time out. - msg, err := NewMessage(MsgPing, clientID, serverID, PingPayload{ + msg, err := NewMessage(MessagePing, clientID, serverID, PingPayload{ SentAt: time.Now().UnixMilli(), }) require.NoError(t, err) @@ -297,9 +295,9 @@ func TestController_DeadPeerCleanup(t *testing.T) { // The defer block inside sendRequest should have cleaned up the pending entry. 
time.Sleep(50 * time.Millisecond) - controller.mu.RLock() - pendingCount := len(controller.pending) - controller.mu.RUnlock() + controller.mutex.RLock() + pendingCount := len(controller.pendingRequests) + controller.mutex.RUnlock() assert.Equal(t, 0, pendingCount, "pending map should be empty after timeout — no goroutine/memory leak") @@ -307,7 +305,7 @@ func TestController_DeadPeerCleanup(t *testing.T) { // --- Additional edge-case tests --- -func TestController_MultipleSequentialPings(t *testing.T) { +func TestController_MultipleSequentialPings_Good(t *testing.T) { // Ensures sequential requests to the same peer are correctly correlated. controller, _, tp := setupControllerPair(t) serverID := tp.ServerNode.GetIdentity().ID @@ -319,7 +317,7 @@ func TestController_MultipleSequentialPings(t *testing.T) { } } -func TestController_ConcurrentRequestsSamePeer(t *testing.T) { +func TestController_ConcurrentRequestsSamePeer_Ugly(t *testing.T) { // Multiple goroutines sending requests to the SAME peer simultaneously. // Tests concurrent pending-map insertions/deletions under contention. 
controller, _, tp := setupControllerPair(t) @@ -343,12 +341,12 @@ "all concurrent requests to the same peer should succeed") } -func TestController_GetRemoteStats(t *testing.T) { +func TestController_RemoteStats_Good(t *testing.T) { controller, _, tp := setupControllerPair(t) serverID := tp.ServerNode.GetIdentity().ID stats, err := controller.GetRemoteStats(serverID) - require.NoError(t, err, "GetRemoteStats should succeed") + require.NoError(t, err, "GetRemoteStats should succeed") require.NotNil(t, stats) assert.NotEmpty(t, stats.NodeID, "stats should contain the node ID") @@ -357,7 +355,7 @@ assert.GreaterOrEqual(t, stats.Uptime, int64(0), "uptime should be non-negative") } -func TestController_ConnectToPeerUnknown(t *testing.T) { +func TestController_ConnectToPeerUnknown_Bad(t *testing.T) { tp := setupTestTransportPair(t) controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) @@ -366,17 +364,17 @@ assert.Contains(t, err.Error(), "not found") } -func TestController_DisconnectFromPeer(t *testing.T) { +func TestController_DisconnectFromPeer_Good(t *testing.T) { controller, _, tp := setupControllerPair(t) serverID := tp.ServerNode.GetIdentity().ID - assert.Equal(t, 1, tp.Client.ConnectedPeers(), "should have 1 connection") + assert.Equal(t, 1, tp.Client.ConnectedPeerCount(), "should have 1 connection") err := controller.DisconnectFromPeer(serverID) require.NoError(t, err, "DisconnectFromPeer should succeed") } -func TestController_DisconnectFromPeerNotConnected(t *testing.T) { +func TestController_DisconnectFromPeerNotConnected_Bad(t *testing.T) { tp := setupTestTransportPair(t) controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) @@ -385,12 +383,12 @@ assert.Contains(t, err.Error(), "not connected") }
-func TestController_SendRequestPeerNotFound(t *testing.T) { +func TestController_SendRequestPeerNotFound_Bad(t *testing.T) { tp := setupTestTransportPair(t) controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) clientID := tp.ClientNode.GetIdentity().ID - msg, err := NewMessage(MsgPing, clientID, "ghost-peer", PingPayload{ + msg, err := NewMessage(MessagePing, clientID, "ghost-peer", PingPayload{ SentAt: time.Now().UnixMilli(), }) require.NoError(t, err) @@ -401,7 +399,7 @@ assert.Contains(t, err.Error(), "peer not found") } -// --- Tests for StartRemoteMiner, StopRemoteMiner, GetRemoteLogs --- +// --- Tests for StartRemoteMiner, StopRemoteMiner, GetRemoteLogs --- // setupControllerPairWithMiner creates a controller/worker pair where the worker // has a fully configured MinerManager so that start/stop/logs handlers work. @@ -434,7 +432,7 @@ }, } worker.SetMinerManager(mm) - worker.RegisterWithTransport() + worker.RegisterOnTransport() // Client side: create a Controller. controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) @@ -446,7 +444,7 @@ return controller, worker, tp } -// mockMinerManagerFull implements MinerManager with functional start/stop/list/get. +// mockMinerManagerFull implements MinerManager with functional start/stop/list/lookup.
type mockMinerManagerFull struct { mu sync.Mutex miners map[string]*mockMinerFull @@ -475,7 +473,7 @@ func (m *mockMinerManagerFull) StopMiner(name string) error { defer m.mu.Unlock() if _, exists := m.miners[name]; !exists { - return fmt.Errorf("miner %s not found", name) + return core.E("mockMinerManagerFull.StopMiner", "miner "+name+" not found", nil) } delete(m.miners, name) return nil @@ -498,11 +496,15 @@ func (m *mockMinerManagerFull) GetMiner(name string) (MinerInstance, error) { miner, exists := m.miners[name] if !exists { - return nil, fmt.Errorf("miner %s not found", name) + return nil, core.E("mockMinerManagerFull.GetMiner", "miner "+name+" not found", nil) } return miner, nil } +func (m *mockMinerManagerFull) Miner(name string) (MinerInstance, error) { + return m.GetMiner(name) +} + // mockMinerFull implements MinerInstance with real data. type mockMinerFull struct { name string @@ -521,25 +523,30 @@ func (m *mockMinerFull) GetConsoleHistory(lines int) []string { return m.consoleHistory[:lines] } -func TestController_StartRemoteMiner(t *testing.T) { +func (m *mockMinerFull) Name() string { return m.GetName() } +func (m *mockMinerFull) Type() string { return m.GetType() } +func (m *mockMinerFull) Stats() (any, error) { return m.GetStats() } +func (m *mockMinerFull) ConsoleHistory(lines int) []string { return m.GetConsoleHistory(lines) } + +func TestController_StartRemoteMiner_Good(t *testing.T) { controller, _, tp := setupControllerPairWithMiner(t) serverID := tp.ServerNode.GetIdentity().ID - configOverride := json.RawMessage(`{"pool":"pool.example.com:3333"}`) + configOverride := RawMessage(`{"pool":"pool.example.com:3333"}`) err := controller.StartRemoteMiner(serverID, "xmrig", "profile-1", configOverride) require.NoError(t, err, "StartRemoteMiner should succeed") } -func TestController_StartRemoteMiner_WithConfig(t *testing.T) { +func TestController_StartRemoteMiner_WithConfig_Good(t *testing.T) { controller, _, tp := setupControllerPairWithMiner(t) 
serverID := tp.ServerNode.GetIdentity().ID - configOverride := json.RawMessage(`{"pool":"custom-pool:3333","threads":4}`) + configOverride := RawMessage(`{"pool":"custom-pool:3333","threads":4}`) err := controller.StartRemoteMiner(serverID, "xmrig", "", configOverride) require.NoError(t, err, "StartRemoteMiner with config override should succeed") } -func TestController_StartRemoteMiner_EmptyType(t *testing.T) { +func TestController_StartRemoteMiner_EmptyType_Bad(t *testing.T) { controller, _, tp := setupControllerPairWithMiner(t) serverID := tp.ServerNode.GetIdentity().ID @@ -548,14 +555,12 @@ func TestController_StartRemoteMiner_EmptyType(t *testing.T) { assert.Contains(t, err.Error(), "miner type is required") } -func TestController_StartRemoteMiner_NoIdentity(t *testing.T) { +func TestController_StartRemoteMiner_NoIdentity_Bad(t *testing.T) { tp := setupTestTransportPair(t) // Create a node without identity - nmNoID, err := NewNodeManagerWithPaths( - filepath.Join(t.TempDir(), "priv.key"), - filepath.Join(t.TempDir(), "node.json"), - ) + keyPath, configPath := testNodeManagerPaths(t.TempDir()) + nmNoID, err := NewNodeManagerFromPaths(keyPath, configPath) require.NoError(t, err) controller := NewController(nmNoID, tp.ClientReg, tp.Client) @@ -565,7 +570,7 @@ func TestController_StartRemoteMiner_NoIdentity(t *testing.T) { assert.Contains(t, err.Error(), "identity not initialized") } -func TestController_StopRemoteMiner(t *testing.T) { +func TestController_StopRemoteMiner_Good(t *testing.T) { controller, _, tp := setupControllerPairWithMiner(t) serverID := tp.ServerNode.GetIdentity().ID @@ -573,7 +578,7 @@ func TestController_StopRemoteMiner(t *testing.T) { require.NoError(t, err, "StopRemoteMiner should succeed for existing miner") } -func TestController_StopRemoteMiner_NotFound(t *testing.T) { +func TestController_StopRemoteMiner_NotFound_Bad(t *testing.T) { controller, _, tp := setupControllerPairWithMiner(t) serverID := tp.ServerNode.GetIdentity().ID @@ 
-581,12 +586,10 @@ func TestController_StopRemoteMiner_NotFound(t *testing.T) { require.Error(t, err, "StopRemoteMiner should fail for non-existent miner") } -func TestController_StopRemoteMiner_NoIdentity(t *testing.T) { +func TestController_StopRemoteMiner_NoIdentity_Bad(t *testing.T) { tp := setupTestTransportPair(t) - nmNoID, err := NewNodeManagerWithPaths( - filepath.Join(t.TempDir(), "priv.key"), - filepath.Join(t.TempDir(), "node.json"), - ) + keyPath, configPath := testNodeManagerPaths(t.TempDir()) + nmNoID, err := NewNodeManagerFromPaths(keyPath, configPath) require.NoError(t, err) controller := NewController(nmNoID, tp.ClientReg, tp.Client) @@ -596,32 +599,30 @@ func TestController_StopRemoteMiner_NoIdentity(t *testing.T) { assert.Contains(t, err.Error(), "identity not initialized") } -func TestController_GetRemoteLogs(t *testing.T) { +func TestController_RemoteLogs_Good(t *testing.T) { controller, _, tp := setupControllerPairWithMiner(t) serverID := tp.ServerNode.GetIdentity().ID lines, err := controller.GetRemoteLogs(serverID, "running-miner", 10) - require.NoError(t, err, "GetRemoteLogs should succeed") + require.NoError(t, err, "RemoteLogs should succeed") require.NotNil(t, lines) assert.Len(t, lines, 3, "should return all 3 console history lines") assert.Contains(t, lines[0], "started") } -func TestController_GetRemoteLogs_LimitedLines(t *testing.T) { +func TestController_RemoteLogs_LimitedLines_Good(t *testing.T) { controller, _, tp := setupControllerPairWithMiner(t) serverID := tp.ServerNode.GetIdentity().ID lines, err := controller.GetRemoteLogs(serverID, "running-miner", 1) - require.NoError(t, err, "GetRemoteLogs with limited lines should succeed") + require.NoError(t, err, "RemoteLogs with limited lines should succeed") assert.Len(t, lines, 1, "should return only 1 line") } -func TestController_GetRemoteLogs_NoIdentity(t *testing.T) { +func TestController_RemoteLogs_NoIdentity_Bad(t *testing.T) { tp := setupTestTransportPair(t) - nmNoID, err := 
NewNodeManagerWithPaths( - filepath.Join(t.TempDir(), "priv.key"), - filepath.Join(t.TempDir(), "node.json"), - ) + keyPath, configPath := testNodeManagerPaths(t.TempDir()) + nmNoID, err := NewNodeManagerFromPaths(keyPath, configPath) require.NoError(t, err) controller := NewController(nmNoID, tp.ClientReg, tp.Client) @@ -631,12 +632,12 @@ func TestController_GetRemoteLogs_NoIdentity(t *testing.T) { assert.Contains(t, err.Error(), "identity not initialized") } -func TestController_GetRemoteStats_WithMiners(t *testing.T) { +func TestController_RemoteStats_WithMiners_Good(t *testing.T) { controller, _, tp := setupControllerPairWithMiner(t) serverID := tp.ServerNode.GetIdentity().ID stats, err := controller.GetRemoteStats(serverID) - require.NoError(t, err, "GetRemoteStats should succeed") + require.NoError(t, err, "RemoteStats should succeed") require.NotNil(t, stats) assert.NotEmpty(t, stats.NodeID) // The worker has a miner manager with 1 running miner @@ -645,12 +646,10 @@ func TestController_GetRemoteStats_WithMiners(t *testing.T) { assert.Equal(t, 1234.5, stats.Miners[0].Hashrate) } -func TestController_GetRemoteStats_NoIdentity(t *testing.T) { +func TestController_RemoteStats_NoIdentity_Bad(t *testing.T) { tp := setupTestTransportPair(t) - nmNoID, err := NewNodeManagerWithPaths( - filepath.Join(t.TempDir(), "priv.key"), - filepath.Join(t.TempDir(), "node.json"), - ) + keyPath, configPath := testNodeManagerPaths(t.TempDir()) + nmNoID, err := NewNodeManagerFromPaths(keyPath, configPath) require.NoError(t, err) controller := NewController(nmNoID, tp.ClientReg, tp.Client) @@ -660,11 +659,11 @@ func TestController_GetRemoteStats_NoIdentity(t *testing.T) { assert.Contains(t, err.Error(), "identity not initialized") } -func TestController_ConnectToPeer_Success(t *testing.T) { +func TestController_ConnectToPeer_Success_Good(t *testing.T) { tp := setupTestTransportPair(t) worker := NewWorker(tp.ServerNode, tp.Server) - worker.RegisterWithTransport() + 
worker.RegisterOnTransport() controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) @@ -681,25 +680,25 @@ func TestController_ConnectToPeer_Success(t *testing.T) { err := controller.ConnectToPeer(serverIdentity.ID) require.NoError(t, err, "ConnectToPeer should succeed") - assert.Equal(t, 1, tp.Client.ConnectedPeers(), "should have 1 connection after ConnectToPeer") + assert.Equal(t, 1, tp.Client.ConnectedPeerCount(), "should have 1 connection after ConnectToPeer") } -func TestController_HandleResponse_NonReply(t *testing.T) { +func TestController_HandleResponse_NonReply_Good(t *testing.T) { tp := setupTestTransportPair(t) controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) // handleResponse should ignore messages without ReplyTo - msg, _ := NewMessage(MsgPing, "sender", "target", PingPayload{SentAt: 123}) + msg, _ := NewMessage(MessagePing, "sender", "target", PingPayload{SentAt: 123}) controller.handleResponse(nil, msg) // No pending entries should be affected - controller.mu.RLock() - count := len(controller.pending) - controller.mu.RUnlock() + controller.mutex.RLock() + count := len(controller.pendingRequests) + controller.mutex.RUnlock() assert.Equal(t, 0, count) } -func TestController_HandleResponse_FullChannel(t *testing.T) { +func TestController_HandleResponse_FullChannel_Ugly(t *testing.T) { tp := setupTestTransportPair(t) controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) @@ -707,28 +706,26 @@ func TestController_HandleResponse_FullChannel(t *testing.T) { ch := make(chan *Message, 1) ch <- &Message{} // Fill the channel - controller.mu.Lock() - controller.pending["test-id"] = ch - controller.mu.Unlock() + controller.mutex.Lock() + controller.pendingRequests["test-id"] = ch + controller.mutex.Unlock() // handleResponse with matching reply should not panic on full channel - msg, _ := NewMessage(MsgPong, "sender", "target", PongPayload{SentAt: 123}) + msg, _ := NewMessage(MessagePong, "sender", "target", 
PongPayload{SentAt: 123}) msg.ReplyTo = "test-id" controller.handleResponse(nil, msg) // The pending entry should be removed despite channel being full - controller.mu.RLock() - _, exists := controller.pending["test-id"] - controller.mu.RUnlock() + controller.mutex.RLock() + _, exists := controller.pendingRequests["test-id"] + controller.mutex.RUnlock() assert.False(t, exists, "pending entry should be removed after handling") } -func TestController_PingPeer_NoIdentity(t *testing.T) { +func TestController_PingPeer_NoIdentity_Bad(t *testing.T) { tp := setupTestTransportPair(t) - nmNoID, _ := NewNodeManagerWithPaths( - filepath.Join(t.TempDir(), "priv.key"), - filepath.Join(t.TempDir(), "node.json"), - ) + keyPath, configPath := testNodeManagerPaths(t.TempDir()) + nmNoID, _ := NewNodeManagerFromPaths(keyPath, configPath) controller := NewController(nmNoID, tp.ClientReg, tp.Client) _, err := controller.PingPeer("some-peer") diff --git a/node/dispatcher.go b/node/dispatcher.go index c240e18..53e5cff 100644 --- a/node/dispatcher.go +++ b/node/dispatcher.go @@ -1,56 +1,39 @@ package node import ( - "fmt" "iter" "sync" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" "dappco.re/go/core/p2p/logging" "dappco.re/go/core/p2p/ueps" ) -// ThreatScoreThreshold is the maximum allowable threat score. Packets exceeding -// this value are silently dropped by the circuit breaker and logged as threat -// events. The threshold sits at ~76% of the uint16 range (50,000 / 65,535), -// providing headroom for legitimate elevated-risk traffic whilst rejecting -// clearly hostile payloads. +// threshold := ThreatScoreThreshold const ThreatScoreThreshold uint16 = 50000 -// Well-known intent identifiers. These correspond to the semantic tokens -// carried in the UEPS IntentID header field (RFC-021). 
+// intentID := IntentPauseExecution const ( - IntentHandshake byte = 0x01 // Connection establishment / hello - IntentCompute byte = 0x20 // Compute job request - IntentRehab byte = 0x30 // Benevolent intervention (pause execution) - IntentCustom byte = 0xFF // Extended / application-level sub-protocols + IntentHandshake byte = 0x01 + IntentCompute byte = 0x20 + IntentPauseExecution byte = 0x30 + IntentCustom byte = 0xFF ) -// IntentHandler processes a UEPS packet that has been routed by intent. -// Implementations receive the fully parsed and HMAC-verified packet. -type IntentHandler func(pkt *ueps.ParsedPacket) error +// var handler IntentHandler = func(packet *ueps.ParsedPacket) error { return nil } +type IntentHandler func(packet *ueps.ParsedPacket) error -// Dispatcher routes verified UEPS packets to registered intent handlers. -// It enforces a threat circuit breaker before routing: any packet whose -// ThreatScore exceeds ThreatScoreThreshold is dropped and logged. -// -// Design decisions: -// - Handlers are registered per IntentID (1:1 mapping). -// - Unknown intents are logged at WARN level and silently dropped (no error -// returned to the caller) to avoid back-pressure on the transport layer. -// - High-threat packets are dropped silently (logged at WARN) rather than -// returning an error, consistent with the "don't even parse the payload" -// philosophy from the original stub. -// - The dispatcher is safe for concurrent use; a RWMutex protects the -// handler map. +// dispatcher := NewDispatcher() +// dispatcher.RegisterHandler(IntentCompute, func(packet *ueps.ParsedPacket) error { return nil }) +// err := dispatcher.Dispatch(packet) type Dispatcher struct { handlers map[byte]IntentHandler mu sync.RWMutex log *logging.Logger } -// NewDispatcher creates a Dispatcher with no registered handlers. 
+// dispatcher := NewDispatcher() func NewDispatcher() *Dispatcher { return &Dispatcher{ handlers: make(map[byte]IntentHandler), @@ -61,19 +44,20 @@ func NewDispatcher() *Dispatcher { } } -// RegisterHandler associates an IntentHandler with a specific IntentID. -// Calling RegisterHandler with an IntentID that already has a handler will -// replace the previous handler. +// dispatcher.RegisterHandler(IntentCompute, func(packet *ueps.ParsedPacket) error { return nil }) func (d *Dispatcher) RegisterHandler(intentID byte, handler IntentHandler) { d.mu.Lock() defer d.mu.Unlock() d.handlers[intentID] = handler d.log.Debug("handler registered", logging.Fields{ - "intent_id": fmt.Sprintf("0x%02X", intentID), + "intent_id": core.Sprintf("0x%02X", intentID), }) } -// Handlers returns an iterator over all registered intent handlers. +// for intentID, handler := range dispatcher.Handlers() { +// _ = intentID +// _ = handler +// } func (d *Dispatcher) Handlers() iter.Seq2[byte, IntentHandler] { return func(yield func(byte, IntentHandler) bool) { d.mu.RLock() @@ -87,59 +71,46 @@ func (d *Dispatcher) Handlers() iter.Seq2[byte, IntentHandler] { } } -// Dispatch routes a parsed UEPS packet through the threat circuit breaker -// and then to the appropriate intent handler. -// -// Behaviour: -// - Returns ErrThreatScoreExceeded if the packet's ThreatScore exceeds the -// threshold (packet is dropped and logged). -// - Returns ErrUnknownIntent if no handler is registered for the IntentID -// (packet is dropped and logged). -// - Returns nil on successful delivery to a handler, or any error the -// handler itself returns. -// - A nil packet returns ErrNilPacket immediately. -func (d *Dispatcher) Dispatch(pkt *ueps.ParsedPacket) error { - if pkt == nil { - return ErrNilPacket +// err := dispatcher.Dispatch(packet) +func (d *Dispatcher) Dispatch(packet *ueps.ParsedPacket) error { + if packet == nil { + return ErrorNilPacket } // 1. 
Threat circuit breaker (L5 guard) - if pkt.Header.ThreatScore > ThreatScoreThreshold { + if packet.Header.ThreatScore > ThreatScoreThreshold { d.log.Warn("packet dropped: threat score exceeds safety threshold", logging.Fields{ - "threat_score": pkt.Header.ThreatScore, + "threat_score": packet.Header.ThreatScore, "threshold": ThreatScoreThreshold, - "intent_id": fmt.Sprintf("0x%02X", pkt.Header.IntentID), - "version": pkt.Header.Version, + "intent_id": core.Sprintf("0x%02X", packet.Header.IntentID), + "version": packet.Header.Version, }) - return ErrThreatScoreExceeded + return ErrorThreatScoreExceeded } // 2. Intent routing (L9 semantic) d.mu.RLock() - handler, exists := d.handlers[pkt.Header.IntentID] + handler, exists := d.handlers[packet.Header.IntentID] d.mu.RUnlock() if !exists { d.log.Warn("packet dropped: unknown intent", logging.Fields{ - "intent_id": fmt.Sprintf("0x%02X", pkt.Header.IntentID), - "version": pkt.Header.Version, + "intent_id": core.Sprintf("0x%02X", packet.Header.IntentID), + "version": packet.Header.Version, }) - return ErrUnknownIntent + return ErrorUnknownIntent } - return handler(pkt) + return handler(packet) } -// Sentinel errors returned by Dispatch. var ( - // ErrThreatScoreExceeded is returned when a packet's ThreatScore exceeds - // the safety threshold. - ErrThreatScoreExceeded = coreerr.E("Dispatcher.Dispatch", fmt.Sprintf("packet rejected: threat score exceeds safety threshold (%d)", ThreatScoreThreshold), nil) + // err := ErrorThreatScoreExceeded + ErrorThreatScoreExceeded = core.E("Dispatcher.Dispatch", core.Sprintf("packet rejected: threat score exceeds safety threshold (%d)", ThreatScoreThreshold), nil) - // ErrUnknownIntent is returned when no handler is registered for the - // packet's IntentID. 
- ErrUnknownIntent = coreerr.E("Dispatcher.Dispatch", "packet dropped: unknown intent", nil) + // err := ErrorUnknownIntent + ErrorUnknownIntent = core.E("Dispatcher.Dispatch", "packet dropped: unknown intent", nil) - // ErrNilPacket is returned when a nil packet is passed to Dispatch. - ErrNilPacket = coreerr.E("Dispatcher.Dispatch", "nil packet", nil) + // err := ErrorNilPacket + ErrorNilPacket = core.E("Dispatcher.Dispatch", "nil packet", nil) ) diff --git a/node/dispatcher_test.go b/node/dispatcher_test.go index f817c03..d458875 100644 --- a/node/dispatcher_test.go +++ b/node/dispatcher_test.go @@ -1,11 +1,11 @@ package node import ( - "fmt" "sync" "sync/atomic" "testing" + core "dappco.re/go/core" "dappco.re/go/core/p2p/ueps" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,7 +28,7 @@ func makePacket(intentID byte, threatScore uint16, payload []byte) *ueps.ParsedP // --- Dispatcher Tests --- -func TestDispatcher_RegisterAndDispatch(t *testing.T) { +func TestDispatcher_RegisterAndDispatch_Good(t *testing.T) { t.Run("handler receives the correct packet", func(t *testing.T) { d := NewDispatcher() var received *ueps.ParsedPacket @@ -49,7 +49,7 @@ func TestDispatcher_RegisterAndDispatch(t *testing.T) { t.Run("handler error propagates to caller", func(t *testing.T) { d := NewDispatcher() - handlerErr := fmt.Errorf("compute failed") + handlerErr := core.NewError("compute failed") d.RegisterHandler(IntentCompute, func(pkt *ueps.ParsedPacket) error { return handlerErr @@ -62,7 +62,7 @@ func TestDispatcher_RegisterAndDispatch(t *testing.T) { }) } -func TestDispatcher_ThreatCircuitBreaker(t *testing.T) { +func TestDispatcher_ThreatCircuitBreaker_Good(t *testing.T) { tests := []struct { name string threatScore uint16 @@ -78,13 +78,13 @@ func TestDispatcher_ThreatCircuitBreaker(t *testing.T) { { name: "score just above threshold is rejected", threatScore: ThreatScoreThreshold + 1, - wantErr: ErrThreatScoreExceeded, + wantErr: 
ErrorThreatScoreExceeded, dispatched: false, }, { name: "maximum uint16 score is rejected", threatScore: 65535, - wantErr: ErrThreatScoreExceeded, + wantErr: ErrorThreatScoreExceeded, dispatched: false, }, { @@ -118,7 +118,7 @@ func TestDispatcher_ThreatCircuitBreaker(t *testing.T) { } } -func TestDispatcher_UnknownIntentDropped(t *testing.T) { +func TestDispatcher_UnknownIntentDropped_Bad(t *testing.T) { d := NewDispatcher() // Register handlers for known intents only @@ -130,13 +130,13 @@ func TestDispatcher_UnknownIntentDropped(t *testing.T) { pkt := makePacket(0x42, 0, []byte("unknown")) err := d.Dispatch(pkt) - assert.ErrorIs(t, err, ErrUnknownIntent) + assert.ErrorIs(t, err, ErrorUnknownIntent) } -func TestDispatcher_MultipleHandlersCorrectRouting(t *testing.T) { +func TestDispatcher_MultipleHandlersCorrectRouting_Good(t *testing.T) { d := NewDispatcher() - var handshakeCalled, computeCalled, rehabCalled, customCalled bool + var handshakeCalled, computeCalled, pauseExecutionCalled, customCalled bool d.RegisterHandler(IntentHandshake, func(pkt *ueps.ParsedPacket) error { handshakeCalled = true @@ -146,8 +146,8 @@ func TestDispatcher_MultipleHandlersCorrectRouting(t *testing.T) { computeCalled = true return nil }) - d.RegisterHandler(IntentRehab, func(pkt *ueps.ParsedPacket) error { - rehabCalled = true + d.RegisterHandler(IntentPauseExecution, func(pkt *ueps.ParsedPacket) error { + pauseExecutionCalled = true return nil }) d.RegisterHandler(IntentCustom, func(pkt *ueps.ParsedPacket) error { @@ -162,7 +162,7 @@ func TestDispatcher_MultipleHandlersCorrectRouting(t *testing.T) { }{ {"handshake routes correctly", IntentHandshake, &handshakeCalled}, {"compute routes correctly", IntentCompute, &computeCalled}, - {"rehab routes correctly", IntentRehab, &rehabCalled}, + {"pause execution routes correctly", IntentPauseExecution, &pauseExecutionCalled}, {"custom routes correctly", IntentCustom, &customCalled}, } @@ -171,7 +171,7 @@ func 
TestDispatcher_MultipleHandlersCorrectRouting(t *testing.T) { // Reset all flags handshakeCalled = false computeCalled = false - rehabCalled = false + pauseExecutionCalled = false customCalled = false pkt := makePacket(tt.intentID, 0, []byte("payload")) @@ -192,11 +192,11 @@ func TestDispatcher_MultipleHandlersCorrectRouting(t *testing.T) { } } -func TestDispatcher_NilAndEmptyPayload(t *testing.T) { - t.Run("nil packet returns ErrNilPacket", func(t *testing.T) { +func TestDispatcher_NilAndEmptyPayload_Ugly(t *testing.T) { + t.Run("nil packet returns ErrorNilPacket", func(t *testing.T) { d := NewDispatcher() err := d.Dispatch(nil) - assert.ErrorIs(t, err, ErrNilPacket) + assert.ErrorIs(t, err, ErrorNilPacket) }) t.Run("nil payload is delivered to handler", func(t *testing.T) { @@ -234,7 +234,7 @@ func TestDispatcher_NilAndEmptyPayload(t *testing.T) { }) } -func TestDispatcher_ConcurrentDispatchSafety(t *testing.T) { +func TestDispatcher_ConcurrentDispatchSafety_Ugly(t *testing.T) { d := NewDispatcher() var count atomic.Int64 @@ -261,7 +261,7 @@ func TestDispatcher_ConcurrentDispatchSafety(t *testing.T) { assert.Equal(t, int64(goroutines), count.Load()) } -func TestDispatcher_ConcurrentRegisterAndDispatch(t *testing.T) { +func TestDispatcher_ConcurrentRegisterAndDispatch_Ugly(t *testing.T) { d := NewDispatcher() var count atomic.Int64 @@ -301,7 +301,7 @@ func TestDispatcher_ConcurrentRegisterAndDispatch(t *testing.T) { assert.True(t, count.Load() >= 0) } -func TestDispatcher_ReplaceHandler(t *testing.T) { +func TestDispatcher_ReplaceHandler_Good(t *testing.T) { d := NewDispatcher() var firstCalled, secondCalled bool @@ -325,22 +325,22 @@ func TestDispatcher_ReplaceHandler(t *testing.T) { assert.True(t, secondCalled, "replacement handler should be called") } -func TestDispatcher_ThreatBlocksBeforeRouting(t *testing.T) { +func TestDispatcher_ThreatBlocksBeforeRouting_Good(t *testing.T) { // Verify that the circuit breaker fires before intent routing, - // so even an 
unknown intent returns ErrThreatScoreExceeded (not ErrUnknownIntent). + // so even an unknown intent returns ErrorThreatScoreExceeded (not ErrorUnknownIntent). d := NewDispatcher() pkt := makePacket(0x42, ThreatScoreThreshold+1, []byte("hostile")) err := d.Dispatch(pkt) - assert.ErrorIs(t, err, ErrThreatScoreExceeded, + assert.ErrorIs(t, err, ErrorThreatScoreExceeded, "threat circuit breaker should fire before intent routing") } -func TestDispatcher_IntentConstants(t *testing.T) { +func TestDispatcher_IntentConstants_Good(t *testing.T) { // Verify the well-known intent IDs match the spec (RFC-021). assert.Equal(t, byte(0x01), IntentHandshake) assert.Equal(t, byte(0x20), IntentCompute) - assert.Equal(t, byte(0x30), IntentRehab) + assert.Equal(t, byte(0x30), IntentPauseExecution) assert.Equal(t, byte(0xFF), IntentCustom) } diff --git a/node/errors.go b/node/errors.go index 218610b..26ac3d6 100644 --- a/node/errors.go +++ b/node/errors.go @@ -1,14 +1,9 @@ package node -import coreerr "dappco.re/go/core/log" +import core "dappco.re/go/core" -// Sentinel errors shared across the node package. var ( - // ErrIdentityNotInitialized is returned when a node operation requires - // a node identity but none has been generated or loaded. - ErrIdentityNotInitialized = coreerr.E("node", "node identity not initialized", nil) + ErrorIdentityNotInitialized = core.E("node", "node identity not initialized", nil) - // ErrMinerManagerNotConfigured is returned when a miner operation is - // attempted but no MinerManager has been set on the Worker. 
- ErrMinerManagerNotConfigured = coreerr.E("node", "miner manager not configured", nil) + ErrorMinerManagerNotConfigured = core.E("node", "miner manager not configured", nil) ) diff --git a/node/filesystem.go b/node/filesystem.go new file mode 100644 index 0000000..4139b7b --- /dev/null +++ b/node/filesystem.go @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package node + +import core "dappco.re/go/core" + +// localFileSystem is the package-scoped filesystem rooted at `/` so node code +// can use Core file operations without os helpers. +var localFileSystem = (&core.Fs{}).New("/") + +func filesystemEnsureDir(path string) error { + return filesystemResultError(localFileSystem.EnsureDir(path)) +} + +func filesystemWrite(path, content string) error { + return filesystemResultError(localFileSystem.Write(path, content)) +} + +func filesystemRead(path string) (string, error) { + result := localFileSystem.Read(path) + if !result.OK { + return "", filesystemResultError(result) + } + + content, ok := result.Value.(string) + if !ok { + return "", core.E("node.filesystemRead", "filesystem read returned non-string content", nil) + } + + return content, nil +} + +func filesystemDelete(path string) error { + return filesystemResultError(localFileSystem.Delete(path)) +} + +func filesystemRename(oldPath, newPath string) error { + return filesystemResultError(localFileSystem.Rename(oldPath, newPath)) +} + +func filesystemExists(path string) bool { + return localFileSystem.Exists(path) +} + +func filesystemResultError(result core.Result) error { + if result.OK { + return nil + } + + if err, ok := result.Value.(error); ok && err != nil { + return err + } + + return core.E("node.filesystem", "filesystem operation failed", nil) +} diff --git a/node/identity.go b/node/identity.go index 5b650a5..9679305 100644 --- a/node/identity.go +++ b/node/identity.go @@ -7,45 +7,41 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" - "encoding/json" - "path/filepath" "sync" "time" - 
coreio "dappco.re/go/core/io" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" "forge.lthn.ai/Snider/Borg/pkg/stmf" "github.com/adrg/xdg" ) -// ChallengeSize is the size of the challenge in bytes +// challenge := make([]byte, ChallengeSize) const ChallengeSize = 32 -// GenerateChallenge creates a random challenge for authentication. +// challenge, err := GenerateChallenge() func GenerateChallenge() ([]byte, error) { challenge := make([]byte, ChallengeSize) if _, err := rand.Read(challenge); err != nil { - return nil, coreerr.E("GenerateChallenge", "failed to generate challenge", err) + return nil, core.E("GenerateChallenge", "failed to generate challenge", err) } return challenge, nil } -// SignChallenge creates an HMAC signature of a challenge using a shared secret. -// The signature proves possession of the shared secret without revealing it. +// signature := SignChallenge(challenge, sharedSecret) func SignChallenge(challenge []byte, sharedSecret []byte) []byte { mac := hmac.New(sha256.New, sharedSecret) mac.Write(challenge) return mac.Sum(nil) } -// VerifyChallenge verifies that a challenge response was signed with the correct shared secret. +// ok := VerifyChallenge(challenge, signature, sharedSecret) func VerifyChallenge(challenge, response, sharedSecret []byte) bool { expected := SignChallenge(challenge, sharedSecret) return hmac.Equal(response, expected) } -// NodeRole defines the operational mode of a node. +// role := RoleWorker type NodeRole string const ( @@ -57,7 +53,7 @@ const ( RoleDual NodeRole = "dual" ) -// NodeIdentity represents the public identity of a node. +// identity := NodeIdentity{Name: "worker-1", Role: RoleWorker} type NodeIdentity struct { ID string `json:"id"` // Derived from public key (first 16 bytes hex) Name string `json:"name"` // Human-friendly name @@ -66,7 +62,7 @@ type NodeIdentity struct { Role NodeRole `json:"role"` } -// NodeManager handles node identity operations including key generation and storage. 
+// nodeManager, err := NewNodeManager() type NodeManager struct { identity *NodeIdentity privateKey []byte // Never serialized to JSON @@ -76,46 +72,50 @@ type NodeManager struct { mu sync.RWMutex } -// NewNodeManager creates a new NodeManager, loading existing identity if available. +// nodeManager, err := NewNodeManager() func NewNodeManager() (*NodeManager, error) { keyPath, err := xdg.DataFile("lethean-desktop/node/private.key") if err != nil { - return nil, coreerr.E("NodeManager.New", "failed to get key path", err) + return nil, core.E("NodeManager.New", "failed to get key path", err) } configPath, err := xdg.ConfigFile("lethean-desktop/node.json") if err != nil { - return nil, coreerr.E("NodeManager.New", "failed to get config path", err) + return nil, core.E("NodeManager.New", "failed to get config path", err) } - return NewNodeManagerWithPaths(keyPath, configPath) + return NewNodeManagerFromPaths(keyPath, configPath) } -// NewNodeManagerWithPaths creates a NodeManager with custom paths. -// This is primarily useful for testing to avoid xdg path caching issues. -func NewNodeManagerWithPaths(keyPath, configPath string) (*NodeManager, error) { +// nodeManager, err := NewNodeManagerFromPaths("/srv/p2p/private.key", "/srv/p2p/node.json") +// Missing files are treated as a fresh install; malformed or partial identity +// state returns an error so callers can handle it explicitly. +func NewNodeManagerFromPaths(keyPath, configPath string) (*NodeManager, error) { nm := &NodeManager{ keyPath: keyPath, configPath: configPath, } - // Try to load existing identity - if err := nm.loadIdentity(); err != nil { - // Identity doesn't exist yet, that's ok + // Missing files indicate a first run; anything else is a load failure. + if !filesystemExists(keyPath) && !filesystemExists(configPath) { return nm, nil } + if err := nm.loadIdentity(); err != nil { + return nil, err + } + return nm, nil } -// HasIdentity returns true if a node identity has been initialized. 
+// hasIdentity := nodeManager.HasIdentity() func (n *NodeManager) HasIdentity() bool { n.mu.RLock() defer n.mu.RUnlock() return n.identity != nil } -// GetIdentity returns the node's public identity. +// identity := nodeManager.GetIdentity() func (n *NodeManager) GetIdentity() *NodeIdentity { n.mu.RLock() defer n.mu.RUnlock() @@ -127,7 +127,7 @@ func (n *NodeManager) GetIdentity() *NodeIdentity { return &identity } -// GenerateIdentity creates a new node identity with the given name and role. +// err := nodeManager.GenerateIdentity("worker-1", RoleWorker) func (n *NodeManager) GenerateIdentity(name string, role NodeRole) error { n.mu.Lock() defer n.mu.Unlock() @@ -135,7 +135,7 @@ func (n *NodeManager) GenerateIdentity(name string, role NodeRole) error { // Generate X25519 keypair using STMF keyPair, err := stmf.GenerateKeyPair() if err != nil { - return coreerr.E("NodeManager.GenerateIdentity", "failed to generate keypair", err) + return core.E("NodeManager.GenerateIdentity", "failed to generate keypair", err) } // Derive node ID from public key (first 16 bytes as hex = 32 char ID) @@ -156,43 +156,42 @@ func (n *NodeManager) GenerateIdentity(name string, role NodeRole) error { // Save private key if err := n.savePrivateKey(); err != nil { - return coreerr.E("NodeManager.GenerateIdentity", "failed to save private key", err) + return core.E("NodeManager.GenerateIdentity", "failed to save private key", err) } // Save identity config if err := n.saveIdentity(); err != nil { - return coreerr.E("NodeManager.GenerateIdentity", "failed to save identity", err) + return core.E("NodeManager.GenerateIdentity", "failed to save identity", err) } return nil } -// DeriveSharedSecret derives a shared secret with a peer using X25519 ECDH. -// The result is hashed with SHA-256 for use as a symmetric key. 
+// sharedSecret, err := nodeManager.DeriveSharedSecret(peer.PublicKey) func (n *NodeManager) DeriveSharedSecret(peerPubKeyBase64 string) ([]byte, error) { n.mu.RLock() defer n.mu.RUnlock() if n.privateKey == nil { - return nil, ErrIdentityNotInitialized + return nil, ErrorIdentityNotInitialized } // Load peer's public key peerPubKey, err := stmf.LoadPublicKeyBase64(peerPubKeyBase64) if err != nil { - return nil, coreerr.E("NodeManager.DeriveSharedSecret", "failed to load peer public key", err) + return nil, core.E("NodeManager.DeriveSharedSecret", "failed to load peer public key", err) } // Load our private key privateKey, err := ecdh.X25519().NewPrivateKey(n.privateKey) if err != nil { - return nil, coreerr.E("NodeManager.DeriveSharedSecret", "failed to load private key", err) + return nil, core.E("NodeManager.DeriveSharedSecret", "failed to load private key", err) } // Derive shared secret using ECDH sharedSecret, err := privateKey.ECDH(peerPubKey) if err != nil { - return nil, coreerr.E("NodeManager.DeriveSharedSecret", "failed to derive shared secret", err) + return nil, core.E("NodeManager.DeriveSharedSecret", "failed to derive shared secret", err) } // Hash the shared secret using SHA-256 (same pattern as Borg/trix) @@ -200,66 +199,59 @@ func (n *NodeManager) DeriveSharedSecret(peerPubKeyBase64 string) ([]byte, error return hash[:], nil } -// savePrivateKey saves the private key to disk with restricted permissions. 
func (n *NodeManager) savePrivateKey() error { - // Ensure directory exists - dir := filepath.Dir(n.keyPath) - if err := coreio.Local.EnsureDir(dir); err != nil { - return coreerr.E("NodeManager.savePrivateKey", "failed to create key directory", err) + dir := core.PathDir(n.keyPath) + if err := filesystemEnsureDir(dir); err != nil { + return core.E("NodeManager.savePrivateKey", "failed to create key directory", err) } - // Write private key - if err := coreio.Local.Write(n.keyPath, string(n.privateKey)); err != nil { - return coreerr.E("NodeManager.savePrivateKey", "failed to write private key", err) + if err := filesystemWrite(n.keyPath, string(n.privateKey)); err != nil { + return core.E("NodeManager.savePrivateKey", "failed to write private key", err) } return nil } -// saveIdentity saves the public identity to the config file. func (n *NodeManager) saveIdentity() error { - // Ensure directory exists - dir := filepath.Dir(n.configPath) - if err := coreio.Local.EnsureDir(dir); err != nil { - return coreerr.E("NodeManager.saveIdentity", "failed to create config directory", err) + dir := core.PathDir(n.configPath) + if err := filesystemEnsureDir(dir); err != nil { + return core.E("NodeManager.saveIdentity", "failed to create config directory", err) } - data, err := json.MarshalIndent(n.identity, "", " ") - if err != nil { - return coreerr.E("NodeManager.saveIdentity", "failed to marshal identity", err) + result := core.JSONMarshal(n.identity) + if !result.OK { + return core.E("NodeManager.saveIdentity", "failed to marshal identity", result.Value.(error)) } + data := result.Value.([]byte) - if err := coreio.Local.Write(n.configPath, string(data)); err != nil { - return coreerr.E("NodeManager.saveIdentity", "failed to write identity", err) + if err := filesystemWrite(n.configPath, string(data)); err != nil { + return core.E("NodeManager.saveIdentity", "failed to write identity", err) } return nil } -// loadIdentity loads the node identity from disk. 
func (n *NodeManager) loadIdentity() error { - // Load identity config - content, err := coreio.Local.Read(n.configPath) + content, err := filesystemRead(n.configPath) if err != nil { - return coreerr.E("NodeManager.loadIdentity", "failed to read identity", err) + return core.E("NodeManager.loadIdentity", "failed to read identity", err) } var identity NodeIdentity - if err := json.Unmarshal([]byte(content), &identity); err != nil { - return coreerr.E("NodeManager.loadIdentity", "failed to unmarshal identity", err) + result := core.JSONUnmarshalString(content, &identity) + if !result.OK { + return core.E("NodeManager.loadIdentity", "failed to unmarshal identity", result.Value.(error)) } - // Load private key - keyContent, err := coreio.Local.Read(n.keyPath) + keyContent, err := filesystemRead(n.keyPath) if err != nil { - return coreerr.E("NodeManager.loadIdentity", "failed to read private key", err) + return core.E("NodeManager.loadIdentity", "failed to read private key", err) } privateKey := []byte(keyContent) - // Reconstruct keypair from private key keyPair, err := stmf.LoadKeyPair(privateKey) if err != nil { - return coreerr.E("NodeManager.loadIdentity", "failed to load keypair", err) + return core.E("NodeManager.loadIdentity", "failed to load keypair", err) } n.identity = &identity @@ -269,22 +261,22 @@ func (n *NodeManager) loadIdentity() error { return nil } -// Delete removes the node identity and keys from disk. 
+// err := nodeManager.Delete() func (n *NodeManager) Delete() error { n.mu.Lock() defer n.mu.Unlock() // Remove private key (ignore if already absent) - if coreio.Local.Exists(n.keyPath) { - if err := coreio.Local.Delete(n.keyPath); err != nil { - return coreerr.E("NodeManager.Delete", "failed to remove private key", err) + if filesystemExists(n.keyPath) { + if err := filesystemDelete(n.keyPath); err != nil { + return core.E("NodeManager.Delete", "failed to remove private key", err) } } // Remove identity config (ignore if already absent) - if coreio.Local.Exists(n.configPath) { - if err := coreio.Local.Delete(n.configPath); err != nil { - return coreerr.E("NodeManager.Delete", "failed to remove identity", err) + if filesystemExists(n.configPath) { + if err := filesystemDelete(n.configPath); err != nil { + return core.E("NodeManager.Delete", "failed to remove identity", err) } } diff --git a/node/identity_test.go b/node/identity_test.go index e2af1fb..f72d8ef 100644 --- a/node/identity_test.go +++ b/node/identity_test.go @@ -1,38 +1,22 @@ package node import ( - "os" - "path/filepath" "testing" ) -// setupTestNodeManager creates a NodeManager with paths in a temp directory. 
-func setupTestNodeManager(t *testing.T) (*NodeManager, func()) { - tmpDir, err := os.MkdirTemp("", "node-identity-test") +func newTestNodeManagerWithoutIdentity(t *testing.T) *NodeManager { + tmpDir := t.TempDir() + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(tmpDir)) if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - - keyPath := filepath.Join(tmpDir, "private.key") - configPath := filepath.Join(tmpDir, "node.json") - - nm, err := NewNodeManagerWithPaths(keyPath, configPath) - if err != nil { - os.RemoveAll(tmpDir) t.Fatalf("failed to create node manager: %v", err) } - cleanup := func() { - os.RemoveAll(tmpDir) - } - - return nm, cleanup + return nm } -func TestNodeIdentity(t *testing.T) { +func TestIdentity_NodeIdentity_Good(t *testing.T) { t.Run("NewNodeManager", func(t *testing.T) { - nm, cleanup := setupTestNodeManager(t) - defer cleanup() + nm := newTestNodeManagerWithoutIdentity(t) if nm.HasIdentity() { t.Error("new node manager should not have identity") @@ -40,8 +24,7 @@ func TestNodeIdentity(t *testing.T) { }) t.Run("GenerateIdentity", func(t *testing.T) { - nm, cleanup := setupTestNodeManager(t) - defer cleanup() + nm := newTestNodeManagerWithoutIdentity(t) err := nm.GenerateIdentity("test-node", RoleDual) if err != nil { @@ -75,17 +58,11 @@ func TestNodeIdentity(t *testing.T) { }) t.Run("LoadExistingIdentity", func(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "node-load-test") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - keyPath := filepath.Join(tmpDir, "private.key") - configPath := filepath.Join(tmpDir, "node.json") + tmpDir := t.TempDir() + keyPath, configPath := testNodeManagerPaths(tmpDir) // First, create an identity - nm1, err := NewNodeManagerWithPaths(keyPath, configPath) + nm1, err := NewNodeManagerFromPaths(keyPath, configPath) if err != nil { t.Fatalf("failed to create first node manager: %v", err) } @@ -99,7 +76,7 @@ func TestNodeIdentity(t 
*testing.T) { originalPubKey := nm1.GetIdentity().PublicKey // Create a new manager - should load existing identity - nm2, err := NewNodeManagerWithPaths(keyPath, configPath) + nm2, err := NewNodeManagerFromPaths(keyPath, configPath) if err != nil { t.Fatalf("failed to create second node manager: %v", err) } @@ -120,16 +97,11 @@ func TestNodeIdentity(t *testing.T) { t.Run("DeriveSharedSecret", func(t *testing.T) { // Create two node managers with separate temp directories - tmpDir1, _ := os.MkdirTemp("", "node1") - tmpDir2, _ := os.MkdirTemp("", "node2") - defer os.RemoveAll(tmpDir1) - defer os.RemoveAll(tmpDir2) + tmpDir1 := t.TempDir() + tmpDir2 := t.TempDir() // Node 1 - nm1, err := NewNodeManagerWithPaths( - filepath.Join(tmpDir1, "private.key"), - filepath.Join(tmpDir1, "node.json"), - ) + nm1, err := NewNodeManagerFromPaths(testNodeManagerPaths(tmpDir1)) if err != nil { t.Fatalf("failed to create node manager 1: %v", err) } @@ -139,10 +111,7 @@ func TestNodeIdentity(t *testing.T) { } // Node 2 - nm2, err := NewNodeManagerWithPaths( - filepath.Join(tmpDir2, "private.key"), - filepath.Join(tmpDir2, "node.json"), - ) + nm2, err := NewNodeManagerFromPaths(testNodeManagerPaths(tmpDir2)) if err != nil { t.Fatalf("failed to create node manager 2: %v", err) } @@ -175,8 +144,7 @@ func TestNodeIdentity(t *testing.T) { }) t.Run("DeleteIdentity", func(t *testing.T) { - nm, cleanup := setupTestNodeManager(t) - defer cleanup() + nm := newTestNodeManagerWithoutIdentity(t) err := nm.GenerateIdentity("delete-me", RoleDual) if err != nil { @@ -198,7 +166,23 @@ func TestNodeIdentity(t *testing.T) { }) } -func TestNodeRoles(t *testing.T) { +func TestIdentity_NodeManagerFromPaths_CorruptIdentity_Bad(t *testing.T) { + tmpDir := t.TempDir() + keyPath, configPath := testNodeManagerPaths(tmpDir) + + testWriteFile(t, configPath, []byte(`{"id":"node-1","name":"broken","publicKey":"not-json"`), 0o600) + + nm, err := NewNodeManagerFromPaths(keyPath, configPath) + if err == nil { + 
t.Fatal("expected error when loading a corrupted node identity") + } + + if nm != nil { + t.Fatal("expected nil node manager when identity data is corrupted") + } +} + +func TestIdentity_NodeRoles_Good(t *testing.T) { tests := []struct { role NodeRole expected string @@ -217,7 +201,7 @@ func TestNodeRoles(t *testing.T) { } } -func TestChallengeResponse(t *testing.T) { +func TestIdentity_ChallengeResponse_Good(t *testing.T) { t.Run("GenerateChallenge", func(t *testing.T) { challenge, err := GenerateChallenge() if err != nil { @@ -315,21 +299,13 @@ func TestChallengeResponse(t *testing.T) { t.Run("IntegrationWithSharedSecret", func(t *testing.T) { // Create two nodes and test end-to-end challenge-response - tmpDir1, _ := os.MkdirTemp("", "node-challenge-1") - tmpDir2, _ := os.MkdirTemp("", "node-challenge-2") - defer os.RemoveAll(tmpDir1) - defer os.RemoveAll(tmpDir2) + tmpDir1 := t.TempDir() + tmpDir2 := t.TempDir() - nm1, _ := NewNodeManagerWithPaths( - filepath.Join(tmpDir1, "private.key"), - filepath.Join(tmpDir1, "node.json"), - ) + nm1, _ := NewNodeManagerFromPaths(testNodeManagerPaths(tmpDir1)) nm1.GenerateIdentity("challenger", RoleDual) - nm2, _ := NewNodeManagerWithPaths( - filepath.Join(tmpDir2, "private.key"), - filepath.Join(tmpDir2, "node.json"), - ) + nm2, _ := NewNodeManagerFromPaths(testNodeManagerPaths(tmpDir2)) nm2.GenerateIdentity("responder", RoleDual) // Challenger generates challenge @@ -352,9 +328,8 @@ func TestChallengeResponse(t *testing.T) { }) } -func TestNodeManager_DeriveSharedSecret_NoIdentity(t *testing.T) { - nm, cleanup := setupTestNodeManager(t) - defer cleanup() +func TestIdentity_NodeManager_DeriveSharedSecret_NoIdentity_Bad(t *testing.T) { + nm := newTestNodeManagerWithoutIdentity(t) // No identity generated _, err := nm.DeriveSharedSecret("some-key") @@ -363,9 +338,8 @@ func TestNodeManager_DeriveSharedSecret_NoIdentity(t *testing.T) { } } -func TestNodeManager_GetIdentity_NilWhenNoIdentity(t *testing.T) { - nm, cleanup := 
setupTestNodeManager(t) - defer cleanup() +func TestIdentity_NodeManager_Identity_NilWhenNoIdentity_Ugly(t *testing.T) { + nm := newTestNodeManagerWithoutIdentity(t) identity := nm.GetIdentity() if identity != nil { @@ -373,11 +347,11 @@ func TestNodeManager_GetIdentity_NilWhenNoIdentity(t *testing.T) { } } -func TestNodeManager_Delete_NoFiles(t *testing.T) { +func TestIdentity_NodeManager_Delete_NoFiles_Ugly(t *testing.T) { tmpDir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(tmpDir, "nonexistent.key"), - filepath.Join(tmpDir, "nonexistent.json"), + nm, err := NewNodeManagerFromPaths( + testJoinPath(tmpDir, "nonexistent.key"), + testJoinPath(tmpDir, "nonexistent.json"), ) if err != nil { t.Fatalf("failed to create node manager: %v", err) } diff --git a/node/integration_test.go b/node/integration_test.go index 990419f..13d40bc 100644 --- a/node/integration_test.go +++ b/node/integration_test.go @@ -3,11 +3,9 @@ package node import ( "bufio" "bytes" - "encoding/json" "net/http" "net/http/httptest" "net/url" - "path/filepath" "sync" "sync/atomic" "testing" @@ -29,12 +27,12 @@ import ( // 5. 
Graceful shutdown with disconnect messages // ============================================================================ -func TestIntegration_FullNodeLifecycle(t *testing.T) { +func TestIntegration_FullNodeLifecycle_Good(t *testing.T) { // ---------------------------------------------------------------- // Step 1: Identity creation // ---------------------------------------------------------------- - controllerNM := testNode(t, "integration-controller", RoleController) - workerNM := testNode(t, "integration-worker", RoleWorker) + controllerNM := newTestNodeManager(t, "integration-controller", RoleController) + workerNM := newTestNodeManager(t, "integration-worker", RoleWorker) controllerIdentity := controllerNM.GetIdentity() workerIdentity := workerNM.GetIdentity() @@ -50,8 +48,8 @@ func TestIntegration_FullNodeLifecycle(t *testing.T) { // ---------------------------------------------------------------- // Step 2: Set up transports, registries, worker, and controller // ---------------------------------------------------------------- - workerReg := testRegistry(t) - controllerReg := testRegistry(t) + workerReg := newTestPeerRegistry(t) + controllerReg := newTestPeerRegistry(t) workerCfg := DefaultTransportConfig() workerCfg.PingInterval = 2 * time.Second @@ -80,11 +78,11 @@ func TestIntegration_FullNodeLifecycle(t *testing.T) { }, }, }) - worker.RegisterWithTransport() + worker.RegisterOnTransport() // Start the worker transport behind httptest. mux := http.NewServeMux() - mux.HandleFunc(workerCfg.WSPath, workerTransport.handleWSUpgrade) + mux.HandleFunc(workerCfg.WebSocketPath, workerTransport.handleWebSocketUpgrade) ts := httptest.NewServer(mux) t.Cleanup(func() { controllerTransport.Stop() @@ -118,9 +116,9 @@ func TestIntegration_FullNodeLifecycle(t *testing.T) { // Allow server-side goroutines to register the connection. 
time.Sleep(100 * time.Millisecond) - assert.Equal(t, 1, controllerTransport.ConnectedPeers(), + assert.Equal(t, 1, controllerTransport.ConnectedPeerCount(), "controller should have 1 connected peer") - assert.Equal(t, 1, workerTransport.ConnectedPeers(), + assert.Equal(t, 1, workerTransport.ConnectedPeerCount(), "worker should have 1 connected peer") // Verify the peer's real identity is stored. @@ -140,13 +138,13 @@ func TestIntegration_FullNodeLifecycle(t *testing.T) { // Verify registry metrics were updated. peerAfterPing := controllerReg.GetPeer(serverPeerID) require.NotNil(t, peerAfterPing) - assert.Greater(t, peerAfterPing.PingMS, 0.0, "PingMS should be updated") + assert.Greater(t, peerAfterPing.PingMilliseconds, 0.0, "PingMilliseconds should be updated") // ---------------------------------------------------------------- - // Step 5: Encrypted message exchange — GetRemoteStats + // Step 5: Encrypted message exchange — RemoteStats // ---------------------------------------------------------------- stats, err := controller.GetRemoteStats(serverPeerID) - require.NoError(t, err, "GetRemoteStats should succeed") + require.NoError(t, err, "RemoteStats should succeed") require.NotNil(t, stats) assert.Equal(t, workerIdentity.ID, stats.NodeID) assert.Equal(t, "integration-worker", stats.NodeName) @@ -201,7 +199,7 @@ func TestIntegration_FullNodeLifecycle(t *testing.T) { parsed3, err := ueps.ReadAndVerify(bufio.NewReader(bytes.NewReader(wireData3)), sharedSecret) require.NoError(t, err) err = dispatcher.Dispatch(parsed3) - assert.ErrorIs(t, err, ErrThreatScoreExceeded, + assert.ErrorIs(t, err, ErrorThreatScoreExceeded, "high-threat packet should be dropped by circuit breaker") // Compute handler should NOT have been called again. 
assert.Equal(t, int32(1), computeReceived.Load()) @@ -211,7 +209,7 @@ func TestIntegration_FullNodeLifecycle(t *testing.T) { // ---------------------------------------------------------------- disconnectReceived := make(chan *Message, 1) workerTransport.OnMessage(func(conn *PeerConnection, msg *Message) { - if msg.Type == MsgDisconnect { + if msg.Type == MessageDisconnect { disconnectReceived <- msg } }) @@ -221,7 +219,7 @@ func TestIntegration_FullNodeLifecycle(t *testing.T) { select { case msg := <-disconnectReceived: - assert.Equal(t, MsgDisconnect, msg.Type) + assert.Equal(t, MessageDisconnect, msg.Type) var payload DisconnectPayload require.NoError(t, msg.ParsePayload(&payload)) assert.Equal(t, "integration test complete", payload.Reason) @@ -234,15 +232,15 @@ func TestIntegration_FullNodeLifecycle(t *testing.T) { time.Sleep(200 * time.Millisecond) // After graceful close, the controller should have 0 peers. - assert.Equal(t, 0, controllerTransport.ConnectedPeers(), + assert.Equal(t, 0, controllerTransport.ConnectedPeerCount(), "controller should have 0 peers after graceful close") } // TestIntegration_SharedSecretAgreement verifies that two independently created // nodes derive the same shared secret via ECDH. -func TestIntegration_SharedSecretAgreement(t *testing.T) { - nodeA := testNode(t, "secret-node-a", RoleDual) - nodeB := testNode(t, "secret-node-b", RoleDual) +func TestIntegration_SharedSecretAgreement_Good(t *testing.T) { + nodeA := newTestNodeManager(t, "secret-node-a", RoleDual) + nodeB := newTestNodeManager(t, "secret-node-b", RoleDual) pubKeyA := nodeA.GetIdentity().PublicKey pubKeyB := nodeB.GetIdentity().PublicKey @@ -260,7 +258,7 @@ func TestIntegration_SharedSecretAgreement(t *testing.T) { // TestIntegration_TwoNodeBidirectionalMessages verifies that both nodes // can send and receive encrypted messages after the handshake. 
-func TestIntegration_TwoNodeBidirectionalMessages(t *testing.T) { +func TestIntegration_TwoNodeBidirectionalMessages_Good(t *testing.T) { controller, _, tp := setupControllerPair(t) serverID := tp.ServerNode.GetIdentity().ID @@ -285,9 +283,9 @@ func TestIntegration_TwoNodeBidirectionalMessages(t *testing.T) { // TestIntegration_MultiPeerTopology verifies that a controller can // simultaneously communicate with multiple workers. -func TestIntegration_MultiPeerTopology(t *testing.T) { - controllerNM := testNode(t, "multi-controller", RoleController) - controllerReg := testRegistry(t) +func TestIntegration_MultiPeerTopology_Good(t *testing.T) { + controllerNM := newTestNodeManager(t, "multi-controller", RoleController) + controllerReg := newTestPeerRegistry(t) controllerTransport := NewTransport(controllerNM, controllerReg, DefaultTransportConfig()) t.Cleanup(func() { controllerTransport.Stop() }) @@ -312,7 +310,7 @@ func TestIntegration_MultiPeerTopology(t *testing.T) { } time.Sleep(100 * time.Millisecond) - assert.Equal(t, numWorkers, controllerTransport.ConnectedPeers(), + assert.Equal(t, numWorkers, controllerTransport.ConnectedPeerCount(), "controller should be connected to all workers") controller := NewController(controllerNM, controllerReg, controllerTransport) @@ -343,13 +341,12 @@ func TestIntegration_MultiPeerTopology(t *testing.T) { // TestIntegration_IdentityPersistenceAndReload verifies that a node identity // can be generated, persisted, and reloaded from disk. -func TestIntegration_IdentityPersistenceAndReload(t *testing.T) { +func TestIntegration_IdentityPersistenceAndReload_Good(t *testing.T) { dir := t.TempDir() - keyPath := filepath.Join(dir, "private.key") - configPath := filepath.Join(dir, "node.json") + keyPath, configPath := testNodeManagerPaths(dir) // Create and persist identity. 
- nm1, err := NewNodeManagerWithPaths(keyPath, configPath) + nm1, err := NewNodeManagerFromPaths(keyPath, configPath) require.NoError(t, err) require.NoError(t, nm1.GenerateIdentity("persistent-node", RoleDual)) @@ -357,7 +354,7 @@ func TestIntegration_IdentityPersistenceAndReload(t *testing.T) { require.NotNil(t, original) // Reload from disk. - nm2, err := NewNodeManagerWithPaths(keyPath, configPath) + nm2, err := NewNodeManagerFromPaths(keyPath, configPath) require.NoError(t, err) require.True(t, nm2.HasIdentity(), "identity should be loaded from disk") @@ -386,10 +383,7 @@ func TestIntegration_IdentityPersistenceAndReload(t *testing.T) { // stmfGenerateKeyPair is a helper that generates a keypair and returns // the public key as base64 (for use in DeriveSharedSecret tests). func stmfGenerateKeyPair(dir string) (string, error) { - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { return "", err } @@ -399,12 +393,11 @@ func stmfGenerateKeyPair(dir string) (string, error) { return nm.GetIdentity().PublicKey, nil } - // TestIntegration_UEPSFullRoundTrip exercises a complete UEPS packet // lifecycle: build, sign, transmit (simulated), read, verify, dispatch. -func TestIntegration_UEPSFullRoundTrip(t *testing.T) { - nodeA := testNode(t, "ueps-node-a", RoleController) - nodeB := testNode(t, "ueps-node-b", RoleWorker) +func TestIntegration_UEPSFullRoundTrip_Ugly(t *testing.T) { + nodeA := newTestNodeManager(t, "ueps-node-a", RoleController) + nodeB := newTestNodeManager(t, "ueps-node-b", RoleWorker) bPubKey := nodeB.GetIdentity().PublicKey sharedSecret, err := nodeA.DeriveSharedSecret(bPubKey) @@ -453,9 +446,9 @@ func TestIntegration_UEPSFullRoundTrip(t *testing.T) { // TestIntegration_UEPSIntegrityFailure verifies that a tampered UEPS packet // is rejected by HMAC verification. 
-func TestIntegration_UEPSIntegrityFailure(t *testing.T) { - nodeA := testNode(t, "integrity-a", RoleController) - nodeB := testNode(t, "integrity-b", RoleWorker) +func TestIntegration_UEPSIntegrityFailure_Bad(t *testing.T) { + nodeA := newTestNodeManager(t, "integrity-a", RoleController) + nodeB := newTestNodeManager(t, "integrity-b", RoleWorker) bPubKey := nodeB.GetIdentity().PublicKey sharedSecret, err := nodeA.DeriveSharedSecret(bPubKey) @@ -484,15 +477,15 @@ func TestIntegration_UEPSIntegrityFailure(t *testing.T) { // TestIntegration_AllowlistHandshakeRejection verifies that a peer not in the // allowlist is rejected during the WebSocket handshake. -func TestIntegration_AllowlistHandshakeRejection(t *testing.T) { - workerNM := testNode(t, "allowlist-worker", RoleWorker) - workerReg := testRegistry(t) +func TestIntegration_AllowlistHandshakeRejection_Bad(t *testing.T) { + workerNM := newTestNodeManager(t, "allowlist-worker", RoleWorker) + workerReg := newTestPeerRegistry(t) workerReg.SetAuthMode(PeerAuthAllowlist) workerTransport := NewTransport(workerNM, workerReg, DefaultTransportConfig()) mux := http.NewServeMux() - mux.HandleFunc("/ws", workerTransport.handleWSUpgrade) + mux.HandleFunc("/ws", workerTransport.handleWebSocketUpgrade) ts := httptest.NewServer(mux) t.Cleanup(func() { workerTransport.Stop() @@ -501,8 +494,8 @@ func TestIntegration_AllowlistHandshakeRejection(t *testing.T) { u, _ := url.Parse(ts.URL) - controllerNM := testNode(t, "rejected-controller", RoleController) - controllerReg := testRegistry(t) + controllerNM := newTestNodeManager(t, "rejected-controller", RoleController) + controllerReg := newTestPeerRegistry(t) controllerTransport := NewTransport(controllerNM, controllerReg, DefaultTransportConfig()) t.Cleanup(func() { controllerTransport.Stop() }) @@ -521,22 +514,22 @@ func TestIntegration_AllowlistHandshakeRejection(t *testing.T) { // TestIntegration_AllowlistHandshakeAccepted verifies that an allowlisted // peer can connect 
successfully. -func TestIntegration_AllowlistHandshakeAccepted(t *testing.T) { - workerNM := testNode(t, "allowlist-worker-ok", RoleWorker) - workerReg := testRegistry(t) +func TestIntegration_AllowlistHandshakeAccepted_Good(t *testing.T) { + workerNM := newTestNodeManager(t, "allowlist-worker-ok", RoleWorker) + workerReg := newTestPeerRegistry(t) workerReg.SetAuthMode(PeerAuthAllowlist) - controllerNM := testNode(t, "allowed-controller", RoleController) - controllerReg := testRegistry(t) + controllerNM := newTestNodeManager(t, "allowed-controller", RoleController) + controllerReg := newTestPeerRegistry(t) workerReg.AllowPublicKey(controllerNM.GetIdentity().PublicKey) workerTransport := NewTransport(workerNM, workerReg, DefaultTransportConfig()) worker := NewWorker(workerNM, workerTransport) - worker.RegisterWithTransport() + worker.RegisterOnTransport() mux := http.NewServeMux() - mux.HandleFunc("/ws", workerTransport.handleWSUpgrade) + mux.HandleFunc("/ws", workerTransport.handleWebSocketUpgrade) ts := httptest.NewServer(mux) t.Cleanup(func() { workerTransport.Stop() @@ -563,7 +556,7 @@ func TestIntegration_AllowlistHandshakeAccepted(t *testing.T) { // TestIntegration_DispatcherWithRealUEPSPackets builds real UEPS packets // from wire bytes and routes them through the dispatcher. 
-func TestIntegration_DispatcherWithRealUEPSPackets(t *testing.T) { +func TestIntegration_DispatcherWithRealUEPSPackets_Good(t *testing.T) { sharedSecret := make([]byte, 32) for i := range sharedSecret { sharedSecret[i] = byte(i ^ 0x42) @@ -579,7 +572,7 @@ func TestIntegration_DispatcherWithRealUEPSPackets(t *testing.T) { }{ {IntentHandshake, "handshake", "hello"}, {IntentCompute, "compute", `{"job":"123"}`}, - {IntentRehab, "rehab", "pause"}, + {IntentPauseExecution, "pause-execution", "pause"}, {IntentCustom, "custom", "app-specific-data"}, } @@ -614,11 +607,11 @@ func TestIntegration_DispatcherWithRealUEPSPackets(t *testing.T) { // TestIntegration_MessageSerialiseDeserialise verifies that messages survive // the full serialisation/encryption/decryption/deserialisation pipeline // with all fields intact. -func TestIntegration_MessageSerialiseDeserialise(t *testing.T) { +func TestIntegration_MessageSerialiseDeserialise_Good(t *testing.T) { tp := setupTestTransportPair(t) pc := tp.connectClient(t) - original, err := NewMessage(MsgStats, tp.ClientNode.GetIdentity().ID, tp.ServerNode.GetIdentity().ID, StatsPayload{ + original, err := NewMessage(MessageStats, tp.ClientNode.GetIdentity().ID, tp.ServerNode.GetIdentity().ID, StatsPayload{ NodeID: "test-node", NodeName: "test-name", Miners: []MinerStatsItem{ @@ -653,18 +646,18 @@ func TestIntegration_MessageSerialiseDeserialise(t *testing.T) { assert.Equal(t, original.ReplyTo, decrypted.ReplyTo) var originalStats, decryptedStats StatsPayload - require.NoError(t, json.Unmarshal(original.Payload, &originalStats)) - require.NoError(t, json.Unmarshal(decrypted.Payload, &decryptedStats)) + testJSONUnmarshal(t, original.Payload, &originalStats) + testJSONUnmarshal(t, decrypted.Payload, &decryptedStats) assert.Equal(t, originalStats, decryptedStats) } -// TestIntegration_GetRemoteStats_EndToEnd tests the full stats retrieval flow +// TestIntegration_RemoteStats_EndToEnd tests the full stats retrieval flow // across a real 
WebSocket connection. -func TestIntegration_GetRemoteStats_EndToEnd(t *testing.T) { +func TestIntegration_RemoteStats_EndToEnd_Good(t *testing.T) { tp := setupTestTransportPair(t) worker := NewWorker(tp.ServerNode, tp.Server) - worker.RegisterWithTransport() + worker.RegisterOnTransport() controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) @@ -674,7 +667,7 @@ func TestIntegration_GetRemoteStats_EndToEnd(t *testing.T) { serverID := tp.ServerNode.GetIdentity().ID stats, err := controller.GetRemoteStats(serverID) - require.NoError(t, err, "GetRemoteStats should succeed end-to-end") + require.NoError(t, err, "RemoteStats should succeed end-to-end") require.NotNil(t, stats) assert.Equal(t, serverID, stats.NodeID) assert.Equal(t, "server", stats.NodeName) diff --git a/node/levin/connection.go b/node/levin/connection.go index a3e1a11..a479b04 100644 --- a/node/levin/connection.go +++ b/node/levin/connection.go @@ -10,24 +10,22 @@ import ( "time" ) -// Levin protocol flags. +// flags := FlagRequest | FlagResponse const ( FlagRequest uint32 = 0x00000001 FlagResponse uint32 = 0x00000002 ) -// LevinProtocolVersion is the protocol version field written into every header. +// header.ProtocolVersion = LevinProtocolVersion const LevinProtocolVersion uint32 = 1 -// Default timeout values for Connection read and write operations. +// connection.ReadTimeout = DefaultReadTimeout const ( DefaultReadTimeout = 120 * time.Second DefaultWriteTimeout = 30 * time.Second ) -// Connection wraps a net.Conn and provides framed Levin packet I/O. -// All writes are serialised by an internal mutex, making it safe to call -// WritePacket and WriteResponse concurrently from multiple goroutines. +// connection := NewConnection(networkConnection) type Connection struct { // MaxPayloadSize is the upper bound accepted for incoming payloads. // Defaults to the package-level MaxPayloadSize (100 MB). 
@@ -39,66 +37,64 @@ type Connection struct { // WriteTimeout is the deadline applied before each write call. WriteTimeout time.Duration - conn net.Conn - writeMu sync.Mutex + networkConnection net.Conn + writeMutex sync.Mutex } -// NewConnection creates a Connection that wraps conn with sensible defaults. -func NewConnection(conn net.Conn) *Connection { +// connection := NewConnection(networkConnection) +func NewConnection(connection net.Conn) *Connection { return &Connection{ - MaxPayloadSize: MaxPayloadSize, - ReadTimeout: DefaultReadTimeout, - WriteTimeout: DefaultWriteTimeout, - conn: conn, + MaxPayloadSize: MaxPayloadSize, + ReadTimeout: DefaultReadTimeout, + WriteTimeout: DefaultWriteTimeout, + networkConnection: connection, } } -// WritePacket sends a Levin request or notification. It builds a 33-byte -// header, then writes header + payload atomically under the write mutex. -func (c *Connection) WritePacket(cmd uint32, payload []byte, expectResponse bool) error { - h := Header{ +// err := connection.WritePacket(CommandPing, payload, true) +func (connection *Connection) WritePacket(commandID uint32, payload []byte, expectResponse bool) error { + header := Header{ Signature: Signature, PayloadSize: uint64(len(payload)), ExpectResponse: expectResponse, - Command: cmd, + Command: commandID, ReturnCode: ReturnOK, Flags: FlagRequest, ProtocolVersion: LevinProtocolVersion, } - return c.writeFrame(&h, payload) + return connection.writeFrame(&header, payload) } -// WriteResponse sends a Levin response packet with the given return code. 
-func (c *Connection) WriteResponse(cmd uint32, payload []byte, returnCode int32) error { - h := Header{ +// err := connection.WriteResponse(CommandPing, payload, ReturnOK) +func (connection *Connection) WriteResponse(commandID uint32, payload []byte, returnCode int32) error { + header := Header{ Signature: Signature, PayloadSize: uint64(len(payload)), ExpectResponse: false, - Command: cmd, + Command: commandID, ReturnCode: returnCode, Flags: FlagResponse, ProtocolVersion: LevinProtocolVersion, } - return c.writeFrame(&h, payload) + return connection.writeFrame(&header, payload) } -// writeFrame serialises header + payload and writes them atomically. -func (c *Connection) writeFrame(h *Header, payload []byte) error { - buf := EncodeHeader(h) +func (connection *Connection) writeFrame(header *Header, payload []byte) error { + headerBytes := EncodeHeader(header) - c.writeMu.Lock() - defer c.writeMu.Unlock() + connection.writeMutex.Lock() + defer connection.writeMutex.Unlock() - if err := c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil { + if err := connection.networkConnection.SetWriteDeadline(time.Now().Add(connection.WriteTimeout)); err != nil { return err } - if _, err := c.conn.Write(buf[:]); err != nil { + if _, err := connection.networkConnection.Write(headerBytes[:]); err != nil { return err } if len(payload) > 0 { - if _, err := c.conn.Write(payload); err != nil { + if _, err := connection.networkConnection.Write(payload); err != nil { return err } } @@ -106,49 +102,46 @@ func (c *Connection) writeFrame(h *Header, payload []byte) error { return nil } -// ReadPacket reads exactly 33 header bytes, validates the signature, -// checks the payload size against MaxPayloadSize, then reads exactly -// PayloadSize bytes of payload data. 
-func (c *Connection) ReadPacket() (Header, []byte, error) { - if err := c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)); err != nil { +// header, payload, err := connection.ReadPacket() +func (connection *Connection) ReadPacket() (Header, []byte, error) { + if err := connection.networkConnection.SetReadDeadline(time.Now().Add(connection.ReadTimeout)); err != nil { return Header{}, nil, err } - // Read header. - var hdrBuf [HeaderSize]byte - if _, err := io.ReadFull(c.conn, hdrBuf[:]); err != nil { + var headerBytes [HeaderSize]byte + if _, err := io.ReadFull(connection.networkConnection, headerBytes[:]); err != nil { return Header{}, nil, err } - h, err := DecodeHeader(hdrBuf) + header, err := DecodeHeader(headerBytes) if err != nil { return Header{}, nil, err } // Check against the connection-specific payload limit. - if h.PayloadSize > c.MaxPayloadSize { - return Header{}, nil, ErrPayloadTooBig + if header.PayloadSize > connection.MaxPayloadSize { + return Header{}, nil, ErrorPayloadTooBig } // Empty payload is valid — return nil data without allocation. - if h.PayloadSize == 0 { - return h, nil, nil + if header.PayloadSize == 0 { + return header, nil, nil } - payload := make([]byte, h.PayloadSize) - if _, err := io.ReadFull(c.conn, payload); err != nil { + payload := make([]byte, header.PayloadSize) + if _, err := io.ReadFull(connection.networkConnection, payload); err != nil { return Header{}, nil, err } - return h, payload, nil + return header, payload, nil } -// Close closes the underlying network connection. -func (c *Connection) Close() error { - return c.conn.Close() +// err := connection.Close() +func (connection *Connection) Close() error { + return connection.networkConnection.Close() } -// RemoteAddr returns the remote address of the underlying connection as a string. 
-func (c *Connection) RemoteAddr() string { - return c.conn.RemoteAddr().String() +// addr := connection.RemoteAddr() +func (connection *Connection) RemoteAddr() string { + return connection.networkConnection.RemoteAddr().String() } diff --git a/node/levin/connection_test.go b/node/levin/connection_test.go index 84e494c..2b355f4 100644 --- a/node/levin/connection_test.go +++ b/node/levin/connection_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestConnection_RoundTrip(t *testing.T) { +func TestConnection_RoundTrip_Ugly(t *testing.T) { a, b := net.Pipe() defer a.Close() defer b.Close() @@ -41,7 +41,7 @@ func TestConnection_RoundTrip(t *testing.T) { assert.Equal(t, payload, data) } -func TestConnection_EmptyPayload(t *testing.T) { +func TestConnection_EmptyPayload_Ugly(t *testing.T) { a, b := net.Pipe() defer a.Close() defer b.Close() @@ -64,7 +64,7 @@ func TestConnection_EmptyPayload(t *testing.T) { assert.Nil(t, data) } -func TestConnection_Response(t *testing.T) { +func TestConnection_Response_Good(t *testing.T) { a, b := net.Pipe() defer a.Close() defer b.Close() @@ -73,7 +73,7 @@ func TestConnection_Response(t *testing.T) { receiver := NewConnection(b) payload := []byte("response data") - retCode := ReturnErrFormat + retCode := ReturnErrorFormat errCh := make(chan error, 1) go func() { @@ -91,7 +91,7 @@ func TestConnection_Response(t *testing.T) { assert.Equal(t, payload, data) } -func TestConnection_PayloadTooBig(t *testing.T) { +func TestConnection_PayloadTooBig_Bad(t *testing.T) { a, b := net.Pipe() defer a.Close() defer b.Close() @@ -120,12 +120,12 @@ func TestConnection_PayloadTooBig(t *testing.T) { _, _, err := receiver.ReadPacket() require.Error(t, err) - assert.ErrorIs(t, err, ErrPayloadTooBig) + assert.ErrorIs(t, err, ErrorPayloadTooBig) require.NoError(t, <-errCh) } -func TestConnection_ReadTimeout(t *testing.T) { +func TestConnection_ReadTimeout_Bad(t *testing.T) { a, b := net.Pipe() defer a.Close() defer b.Close() 
@@ -143,7 +143,7 @@ func TestConnection_ReadTimeout(t *testing.T) { assert.True(t, netErr.Timeout(), "expected timeout error") } -func TestConnection_RemoteAddr(t *testing.T) { +func TestConnection_RemoteAddr_Good(t *testing.T) { a, b := net.Pipe() defer a.Close() defer b.Close() @@ -153,7 +153,7 @@ func TestConnection_RemoteAddr(t *testing.T) { assert.NotEmpty(t, addr) } -func TestConnection_Close(t *testing.T) { +func TestConnection_Close_Ugly(t *testing.T) { a, b := net.Pipe() defer b.Close() diff --git a/node/levin/header.go b/node/levin/header.go index e93531f..9da0760 100644 --- a/node/levin/header.go +++ b/node/levin/header.go @@ -8,28 +8,28 @@ package levin import ( "encoding/binary" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" ) -// HeaderSize is the exact byte length of a serialised Levin header. +// headerBytes := make([]byte, HeaderSize) const HeaderSize = 33 -// Signature is the magic value that opens every Levin packet. +// header.Signature = Signature const Signature uint64 = 0x0101010101012101 -// MaxPayloadSize is the upper bound we accept for a single payload (100 MB). +// header.PayloadSize <= MaxPayloadSize const MaxPayloadSize uint64 = 100 * 1024 * 1024 -// Return-code constants carried in every Levin response. const ( - ReturnOK int32 = 0 - ReturnErrConnection int32 = -1 - ReturnErrFormat int32 = -7 - ReturnErrSignature int32 = -13 + // returnCode := ReturnOK + ReturnOK int32 = 0 + ReturnErrorConnection int32 = -1 + ReturnErrorFormat int32 = -7 + ReturnErrorSignature int32 = -13 ) -// Command IDs for the CryptoNote P2P layer. const ( + // commandID := CommandHandshake CommandHandshake uint32 = 1001 CommandTimedSync uint32 = 1002 CommandPing uint32 = 1003 @@ -41,13 +41,13 @@ const ( CommandResponseChain uint32 = 2007 ) -// Sentinel errors returned by DecodeHeader. 
var ( - ErrBadSignature = coreerr.E("levin", "bad signature", nil) - ErrPayloadTooBig = coreerr.E("levin", "payload exceeds maximum size", nil) + // err := ErrorBadSignature + ErrorBadSignature = core.E("levin", "bad signature", nil) + ErrorPayloadTooBig = core.E("levin", "payload exceeds maximum size", nil) ) -// Header is the 33-byte packed header that prefixes every Levin message. +// header := Header{Command: CommandHandshake, ExpectResponse: true} type Header struct { Signature uint64 PayloadSize uint64 @@ -58,39 +58,38 @@ type Header struct { ProtocolVersion uint32 } -// EncodeHeader serialises h into a fixed-size 33-byte array (little-endian). -func EncodeHeader(h *Header) [HeaderSize]byte { - var buf [HeaderSize]byte - binary.LittleEndian.PutUint64(buf[0:8], h.Signature) - binary.LittleEndian.PutUint64(buf[8:16], h.PayloadSize) - if h.ExpectResponse { - buf[16] = 0x01 +// encoded := EncodeHeader(header) +func EncodeHeader(header *Header) [HeaderSize]byte { + var headerBytes [HeaderSize]byte + binary.LittleEndian.PutUint64(headerBytes[0:8], header.Signature) + binary.LittleEndian.PutUint64(headerBytes[8:16], header.PayloadSize) + if header.ExpectResponse { + headerBytes[16] = 0x01 } else { - buf[16] = 0x00 + headerBytes[16] = 0x00 } - binary.LittleEndian.PutUint32(buf[17:21], h.Command) - binary.LittleEndian.PutUint32(buf[21:25], uint32(h.ReturnCode)) - binary.LittleEndian.PutUint32(buf[25:29], h.Flags) - binary.LittleEndian.PutUint32(buf[29:33], h.ProtocolVersion) - return buf + binary.LittleEndian.PutUint32(headerBytes[17:21], header.Command) + binary.LittleEndian.PutUint32(headerBytes[21:25], uint32(header.ReturnCode)) + binary.LittleEndian.PutUint32(headerBytes[25:29], header.Flags) + binary.LittleEndian.PutUint32(headerBytes[29:33], header.ProtocolVersion) + return headerBytes } -// DecodeHeader deserialises a 33-byte array into a Header, validating -// the magic signature. 
-func DecodeHeader(buf [HeaderSize]byte) (Header, error) { - var h Header - h.Signature = binary.LittleEndian.Uint64(buf[0:8]) - if h.Signature != Signature { - return Header{}, ErrBadSignature +// header, err := DecodeHeader(headerBytes) +func DecodeHeader(headerBytes [HeaderSize]byte) (Header, error) { + var header Header + header.Signature = binary.LittleEndian.Uint64(headerBytes[0:8]) + if header.Signature != Signature { + return Header{}, ErrorBadSignature } - h.PayloadSize = binary.LittleEndian.Uint64(buf[8:16]) - if h.PayloadSize > MaxPayloadSize { - return Header{}, ErrPayloadTooBig + header.PayloadSize = binary.LittleEndian.Uint64(headerBytes[8:16]) + if header.PayloadSize > MaxPayloadSize { + return Header{}, ErrorPayloadTooBig } - h.ExpectResponse = buf[16] == 0x01 - h.Command = binary.LittleEndian.Uint32(buf[17:21]) - h.ReturnCode = int32(binary.LittleEndian.Uint32(buf[21:25])) - h.Flags = binary.LittleEndian.Uint32(buf[25:29]) - h.ProtocolVersion = binary.LittleEndian.Uint32(buf[29:33]) - return h, nil + header.ExpectResponse = headerBytes[16] == 0x01 + header.Command = binary.LittleEndian.Uint32(headerBytes[17:21]) + header.ReturnCode = int32(binary.LittleEndian.Uint32(headerBytes[21:25])) + header.Flags = binary.LittleEndian.Uint32(headerBytes[25:29]) + header.ProtocolVersion = binary.LittleEndian.Uint32(headerBytes[29:33]) + return header, nil } diff --git a/node/levin/header_test.go b/node/levin/header_test.go index 4edfdaf..39acdd0 100644 --- a/node/levin/header_test.go +++ b/node/levin/header_test.go @@ -11,11 +11,11 @@ import ( "github.com/stretchr/testify/require" ) -func TestHeaderSizeIs33(t *testing.T) { +func TestHeader_SizeIs33_Good(t *testing.T) { assert.Equal(t, 33, HeaderSize) } -func TestEncodeHeader_KnownValues(t *testing.T) { +func TestHeader_EncodeHeader_KnownValues_Good(t *testing.T) { h := &Header{ Signature: Signature, PayloadSize: 256, @@ -56,7 +56,7 @@ func TestEncodeHeader_KnownValues(t *testing.T) { assert.Equal(t, uint32(0), 
pv) } -func TestEncodeHeader_ExpectResponseFalse(t *testing.T) { +func TestHeader_EncodeHeader_ExpectResponseFalse_Good(t *testing.T) { h := &Header{ Signature: Signature, PayloadSize: 42, @@ -68,26 +68,26 @@ func TestEncodeHeader_ExpectResponseFalse(t *testing.T) { assert.Equal(t, byte(0x00), buf[16]) } -func TestEncodeHeader_NegativeReturnCode(t *testing.T) { +func TestHeader_EncodeHeader_NegativeReturnCode_Good(t *testing.T) { h := &Header{ Signature: Signature, PayloadSize: 0, ExpectResponse: false, Command: CommandHandshake, - ReturnCode: ReturnErrFormat, + ReturnCode: ReturnErrorFormat, } buf := EncodeHeader(h) rc := int32(binary.LittleEndian.Uint32(buf[21:25])) - assert.Equal(t, ReturnErrFormat, rc) + assert.Equal(t, ReturnErrorFormat, rc) } -func TestDecodeHeader_RoundTrip(t *testing.T) { +func TestHeader_DecodeHeader_RoundTrip_Ugly(t *testing.T) { original := &Header{ Signature: Signature, PayloadSize: 1024, ExpectResponse: true, Command: CommandTimedSync, - ReturnCode: ReturnErrConnection, + ReturnCode: ReturnErrorConnection, Flags: 0, ProtocolVersion: 0, } @@ -105,7 +105,7 @@ func TestDecodeHeader_RoundTrip(t *testing.T) { assert.Equal(t, original.ProtocolVersion, decoded.ProtocolVersion) } -func TestDecodeHeader_AllCommands(t *testing.T) { +func TestHeader_DecodeHeader_AllCommands_Good(t *testing.T) { commands := []uint32{ CommandHandshake, CommandTimedSync, @@ -131,7 +131,7 @@ func TestDecodeHeader_AllCommands(t *testing.T) { } } -func TestDecodeHeader_BadSignature(t *testing.T) { +func TestHeader_DecodeHeader_BadSignature_Bad(t *testing.T) { h := &Header{ Signature: 0xDEADBEEF, PayloadSize: 0, @@ -140,10 +140,10 @@ func TestDecodeHeader_BadSignature(t *testing.T) { buf := EncodeHeader(h) _, err := DecodeHeader(buf) require.Error(t, err) - assert.ErrorIs(t, err, ErrBadSignature) + assert.ErrorIs(t, err, ErrorBadSignature) } -func TestDecodeHeader_PayloadTooBig(t *testing.T) { +func TestHeader_DecodeHeader_PayloadTooBig_Bad(t *testing.T) { h := &Header{ 
Signature: Signature, PayloadSize: MaxPayloadSize + 1, @@ -152,10 +152,10 @@ func TestDecodeHeader_PayloadTooBig(t *testing.T) { buf := EncodeHeader(h) _, err := DecodeHeader(buf) require.Error(t, err) - assert.ErrorIs(t, err, ErrPayloadTooBig) + assert.ErrorIs(t, err, ErrorPayloadTooBig) } -func TestDecodeHeader_MaxPayloadExact(t *testing.T) { +func TestHeader_DecodeHeader_MaxPayloadExact_Ugly(t *testing.T) { h := &Header{ Signature: Signature, PayloadSize: MaxPayloadSize, diff --git a/node/levin/storage.go b/node/levin/storage.go index 3d39718..08aa7ca 100644 --- a/node/levin/storage.go +++ b/node/levin/storage.go @@ -5,12 +5,11 @@ package levin import ( "encoding/binary" - "fmt" "maps" "math" "slices" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" ) // Portable storage signatures and version (9-byte header). @@ -41,20 +40,18 @@ const ( // Sentinel errors for storage encoding and decoding. var ( - ErrStorageBadSignature = coreerr.E("levin.storage", "bad storage signature", nil) - ErrStorageTruncated = coreerr.E("levin.storage", "truncated storage data", nil) - ErrStorageBadVersion = coreerr.E("levin.storage", "unsupported storage version", nil) - ErrStorageNameTooLong = coreerr.E("levin.storage", "entry name exceeds 255 bytes", nil) - ErrStorageTypeMismatch = coreerr.E("levin.storage", "value type mismatch", nil) - ErrStorageUnknownType = coreerr.E("levin.storage", "unknown type tag", nil) + ErrorStorageBadSignature = core.E("levin.storage", "bad storage signature", nil) + ErrorStorageTruncated = core.E("levin.storage", "truncated storage data", nil) + ErrorStorageBadVersion = core.E("levin.storage", "unsupported storage version", nil) + ErrorStorageNameTooLong = core.E("levin.storage", "entry name exceeds 255 bytes", nil) + ErrorStorageTypeMismatch = core.E("levin.storage", "value type mismatch", nil) + ErrorStorageUnknownType = core.E("levin.storage", "unknown type tag", nil) ) -// Section is an ordered map of named values forming a portable 
storage section. -// Field iteration order is always alphabetical by key for deterministic encoding. +// section := Section{"id": StringValue([]byte("peer-1"))} type Section map[string]Value -// Value holds a typed portable storage value. Use the constructor functions -// (Uint64Val, StringVal, ObjectVal, etc.) to create instances. +// value := StringValue([]byte("peer-1")) type Value struct { Type uint8 @@ -77,162 +74,162 @@ type Value struct { // Scalar constructors // --------------------------------------------------------------------------- -// Uint64Val creates a Value of TypeUint64. -func Uint64Val(v uint64) Value { return Value{Type: TypeUint64, uintVal: v} } +// value := Uint64Value(42) +func Uint64Value(value uint64) Value { return Value{Type: TypeUint64, uintVal: value} } -// Uint32Val creates a Value of TypeUint32. -func Uint32Val(v uint32) Value { return Value{Type: TypeUint32, uintVal: uint64(v)} } +// value := Uint32Value(42) +func Uint32Value(value uint32) Value { return Value{Type: TypeUint32, uintVal: uint64(value)} } -// Uint16Val creates a Value of TypeUint16. -func Uint16Val(v uint16) Value { return Value{Type: TypeUint16, uintVal: uint64(v)} } +// value := Uint16Value(42) +func Uint16Value(value uint16) Value { return Value{Type: TypeUint16, uintVal: uint64(value)} } -// Uint8Val creates a Value of TypeUint8. -func Uint8Val(v uint8) Value { return Value{Type: TypeUint8, uintVal: uint64(v)} } +// value := Uint8Value(42) +func Uint8Value(value uint8) Value { return Value{Type: TypeUint8, uintVal: uint64(value)} } -// Int64Val creates a Value of TypeInt64. -func Int64Val(v int64) Value { return Value{Type: TypeInt64, intVal: v} } +// value := Int64Value(42) +func Int64Value(value int64) Value { return Value{Type: TypeInt64, intVal: value} } -// Int32Val creates a Value of TypeInt32. 
-func Int32Val(v int32) Value { return Value{Type: TypeInt32, intVal: int64(v)} } +// value := Int32Value(42) +func Int32Value(value int32) Value { return Value{Type: TypeInt32, intVal: int64(value)} } -// Int16Val creates a Value of TypeInt16. -func Int16Val(v int16) Value { return Value{Type: TypeInt16, intVal: int64(v)} } +// value := Int16Value(42) +func Int16Value(value int16) Value { return Value{Type: TypeInt16, intVal: int64(value)} } -// Int8Val creates a Value of TypeInt8. -func Int8Val(v int8) Value { return Value{Type: TypeInt8, intVal: int64(v)} } +// value := Int8Value(42) +func Int8Value(value int8) Value { return Value{Type: TypeInt8, intVal: int64(value)} } -// BoolVal creates a Value of TypeBool. -func BoolVal(v bool) Value { return Value{Type: TypeBool, boolVal: v} } +// value := BoolValue(true) +func BoolValue(value bool) Value { return Value{Type: TypeBool, boolVal: value} } -// DoubleVal creates a Value of TypeDouble. -func DoubleVal(v float64) Value { return Value{Type: TypeDouble, floatVal: v} } +// value := DoubleValue(3.14) +func DoubleValue(value float64) Value { return Value{Type: TypeDouble, floatVal: value} } -// StringVal creates a Value of TypeString. The slice is not copied. -func StringVal(v []byte) Value { return Value{Type: TypeString, bytesVal: v} } +// value := StringValue([]byte("hello")) +func StringValue(value []byte) Value { return Value{Type: TypeString, bytesVal: value} } -// ObjectVal creates a Value of TypeObject wrapping a nested Section. -func ObjectVal(s Section) Value { return Value{Type: TypeObject, objectVal: s} } +// value := ObjectValue(Section{"id": StringValue([]byte("peer-1"))}) +func ObjectValue(section Section) Value { return Value{Type: TypeObject, objectVal: section} } // --------------------------------------------------------------------------- // Array constructors // --------------------------------------------------------------------------- -// Uint64ArrayVal creates a typed array of uint64 values. 
-func Uint64ArrayVal(vs []uint64) Value { - return Value{Type: ArrayFlag | TypeUint64, uint64Array: vs} +// value := Uint64ArrayValue([]uint64{1, 2, 3}) +func Uint64ArrayValue(values []uint64) Value { + return Value{Type: ArrayFlag | TypeUint64, uint64Array: values} } -// Uint32ArrayVal creates a typed array of uint32 values. -func Uint32ArrayVal(vs []uint32) Value { - return Value{Type: ArrayFlag | TypeUint32, uint32Array: vs} +// value := Uint32ArrayValue([]uint32{1, 2, 3}) +func Uint32ArrayValue(values []uint32) Value { + return Value{Type: ArrayFlag | TypeUint32, uint32Array: values} } -// StringArrayVal creates a typed array of byte-string values. -func StringArrayVal(vs [][]byte) Value { - return Value{Type: ArrayFlag | TypeString, stringArray: vs} +// value := StringArrayValue([][]byte{[]byte("a"), []byte("b")}) +func StringArrayValue(values [][]byte) Value { + return Value{Type: ArrayFlag | TypeString, stringArray: values} } -// ObjectArrayVal creates a typed array of Section values. -func ObjectArrayVal(vs []Section) Value { - return Value{Type: ArrayFlag | TypeObject, objectArray: vs} +// value := ObjectArrayValue([]Section{{"id": StringValue([]byte("peer-1"))}}) +func ObjectArrayValue(values []Section) Value { + return Value{Type: ArrayFlag | TypeObject, objectArray: values} } // --------------------------------------------------------------------------- // Scalar accessors // --------------------------------------------------------------------------- -// AsUint64 returns the uint64 value or an error on type mismatch. +// value, err := Uint64Value(42).AsUint64() func (v Value) AsUint64() (uint64, error) { if v.Type != TypeUint64 { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return v.uintVal, nil } -// AsUint32 returns the uint32 value or an error on type mismatch. 
+// value, err := Uint32Value(42).AsUint32() func (v Value) AsUint32() (uint32, error) { if v.Type != TypeUint32 { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return uint32(v.uintVal), nil } -// AsUint16 returns the uint16 value or an error on type mismatch. +// value, err := Uint16Value(42).AsUint16() func (v Value) AsUint16() (uint16, error) { if v.Type != TypeUint16 { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return uint16(v.uintVal), nil } -// AsUint8 returns the uint8 value or an error on type mismatch. +// value, err := Uint8Value(42).AsUint8() func (v Value) AsUint8() (uint8, error) { if v.Type != TypeUint8 { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return uint8(v.uintVal), nil } -// AsInt64 returns the int64 value or an error on type mismatch. +// value, err := Int64Value(42).AsInt64() func (v Value) AsInt64() (int64, error) { if v.Type != TypeInt64 { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return v.intVal, nil } -// AsInt32 returns the int32 value or an error on type mismatch. +// value, err := Int32Value(42).AsInt32() func (v Value) AsInt32() (int32, error) { if v.Type != TypeInt32 { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return int32(v.intVal), nil } -// AsInt16 returns the int16 value or an error on type mismatch. +// value, err := Int16Value(42).AsInt16() func (v Value) AsInt16() (int16, error) { if v.Type != TypeInt16 { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return int16(v.intVal), nil } -// AsInt8 returns the int8 value or an error on type mismatch. +// value, err := Int8Value(42).AsInt8() func (v Value) AsInt8() (int8, error) { if v.Type != TypeInt8 { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return int8(v.intVal), nil } -// AsBool returns the bool value or an error on type mismatch. 
+// value, err := BoolValue(true).AsBool() func (v Value) AsBool() (bool, error) { if v.Type != TypeBool { - return false, ErrStorageTypeMismatch + return false, ErrorStorageTypeMismatch } return v.boolVal, nil } -// AsDouble returns the float64 value or an error on type mismatch. +// value, err := DoubleValue(3.14).AsDouble() func (v Value) AsDouble() (float64, error) { if v.Type != TypeDouble { - return 0, ErrStorageTypeMismatch + return 0, ErrorStorageTypeMismatch } return v.floatVal, nil } -// AsString returns the byte-string value or an error on type mismatch. +// value, err := StringValue([]byte("hello")).AsString() func (v Value) AsString() ([]byte, error) { if v.Type != TypeString { - return nil, ErrStorageTypeMismatch + return nil, ErrorStorageTypeMismatch } return v.bytesVal, nil } -// AsSection returns the nested Section or an error on type mismatch. +// section, err := ObjectValue(Section{"id": StringValue([]byte("peer-1"))}).AsSection() func (v Value) AsSection() (Section, error) { if v.Type != TypeObject { - return nil, ErrStorageTypeMismatch + return nil, ErrorStorageTypeMismatch } return v.objectVal, nil } @@ -241,34 +238,34 @@ func (v Value) AsSection() (Section, error) { // Array accessors // --------------------------------------------------------------------------- -// AsUint64Array returns the []uint64 array or an error on type mismatch. +// values, err := Uint64ArrayValue([]uint64{1, 2, 3}).AsUint64Array() func (v Value) AsUint64Array() ([]uint64, error) { if v.Type != (ArrayFlag | TypeUint64) { - return nil, ErrStorageTypeMismatch + return nil, ErrorStorageTypeMismatch } return v.uint64Array, nil } -// AsUint32Array returns the []uint32 array or an error on type mismatch. 
+// values, err := Uint32ArrayValue([]uint32{1, 2, 3}).AsUint32Array() func (v Value) AsUint32Array() ([]uint32, error) { if v.Type != (ArrayFlag | TypeUint32) { - return nil, ErrStorageTypeMismatch + return nil, ErrorStorageTypeMismatch } return v.uint32Array, nil } -// AsStringArray returns the [][]byte array or an error on type mismatch. +// values, err := StringArrayValue([][]byte{[]byte("a"), []byte("b")}).AsStringArray() func (v Value) AsStringArray() ([][]byte, error) { if v.Type != (ArrayFlag | TypeString) { - return nil, ErrStorageTypeMismatch + return nil, ErrorStorageTypeMismatch } return v.stringArray, nil } -// AsSectionArray returns the []Section array or an error on type mismatch. +// values, err := ObjectArrayValue([]Section{{"id": StringValue([]byte("peer-1"))}}).AsSectionArray() func (v Value) AsSectionArray() ([]Section, error) { if v.Type != (ArrayFlag | TypeObject) { - return nil, ErrStorageTypeMismatch + return nil, ErrorStorageTypeMismatch } return v.objectArray, nil } @@ -277,28 +274,26 @@ func (v Value) AsSectionArray() ([]Section, error) { // Encoder // --------------------------------------------------------------------------- -// EncodeStorage serialises a Section to the portable storage binary format, -// including the 9-byte header. Keys are sorted alphabetically to ensure -// deterministic output. -func EncodeStorage(s Section) ([]byte, error) { - buf := make([]byte, 0, 256) +// data, err := EncodeStorage(section) +func EncodeStorage(section Section) ([]byte, error) { + buffer := make([]byte, 0, 256) // 9-byte storage header. var hdr [StorageHeaderSize]byte binary.LittleEndian.PutUint32(hdr[0:4], StorageSignatureA) binary.LittleEndian.PutUint32(hdr[4:8], StorageSignatureB) hdr[8] = StorageVersion - buf = append(buf, hdr[:]...) + buffer = append(buffer, hdr[:]...) // Encode root section. 
- out, err := encodeSection(buf, s) + out, err := encodeSection(buffer, section) if err != nil { return nil, err } return out, nil } -// encodeSection appends a section (entry count + entries) to buf. +// encodeSection appends a section (entry count + entries) to buffer. func encodeSection(buf []byte, s Section) ([]byte, error) { // Sort keys for deterministic output. keys := slices.Sorted(maps.Keys(s)) @@ -311,7 +306,7 @@ func encodeSection(buf []byte, s Section) ([]byte, error) { // Name: uint8 length + raw bytes. if len(name) > 255 { - return nil, ErrStorageNameTooLong + return nil, ErrorStorageNameTooLong } buf = append(buf, byte(len(name))) buf = append(buf, name...) @@ -394,7 +389,7 @@ func encodeValue(buf []byte, v Value) ([]byte, error) { return encodeSection(buf, v.objectVal) default: - return nil, coreerr.E("levin.encodeValue", fmt.Sprintf("unknown type tag: 0x%02x", v.Type), ErrStorageUnknownType) + return nil, core.E("levin.encodeValue", core.Sprintf("unknown type tag: 0x%02x", v.Type), ErrorStorageUnknownType) } } @@ -441,7 +436,7 @@ func encodeArray(buf []byte, v Value) ([]byte, error) { return buf, nil default: - return nil, coreerr.E("levin.encodeArray", fmt.Sprintf("unknown type tag: array of 0x%02x", elemType), ErrStorageUnknownType) + return nil, core.E("levin.encodeArray", core.Sprintf("unknown type tag: array of 0x%02x", elemType), ErrorStorageUnknownType) } } @@ -449,11 +444,10 @@ func encodeArray(buf []byte, v Value) ([]byte, error) { // Decoder // --------------------------------------------------------------------------- -// DecodeStorage deserialises portable storage binary data (including the -// 9-byte header) into a Section. 
+// section, err := DecodeStorage(data) func DecodeStorage(data []byte) (Section, error) { if len(data) < StorageHeaderSize { - return nil, ErrStorageTruncated + return nil, ErrorStorageTruncated } sigA := binary.LittleEndian.Uint32(data[0:4]) @@ -461,22 +455,22 @@ func DecodeStorage(data []byte) (Section, error) { ver := data[8] if sigA != StorageSignatureA || sigB != StorageSignatureB { - return nil, ErrStorageBadSignature + return nil, ErrorStorageBadSignature } if ver != StorageVersion { - return nil, ErrStorageBadVersion + return nil, ErrorStorageBadVersion } s, _, err := decodeSection(data[StorageHeaderSize:]) return s, err } -// decodeSection reads a section from buf and returns the section plus +// decodeSection reads a section from buffer and returns the section plus // the number of bytes consumed. func decodeSection(buf []byte) (Section, int, error) { count, n, err := UnpackVarint(buf) if err != nil { - return nil, 0, coreerr.E("levin.decodeSection", "section entry count", err) + return nil, 0, core.E("levin.decodeSection", "section entry count", err) } off := n @@ -485,21 +479,21 @@ func decodeSection(buf []byte) (Section, int, error) { for range count { // Name length (1 byte). if off >= len(buf) { - return nil, 0, ErrStorageTruncated + return nil, 0, ErrorStorageTruncated } nameLen := int(buf[off]) off++ // Name bytes. if off+nameLen > len(buf) { - return nil, 0, ErrStorageTruncated + return nil, 0, ErrorStorageTruncated } name := string(buf[off : off+nameLen]) off += nameLen // Type tag (1 byte). if off >= len(buf) { - return nil, 0, ErrStorageTruncated + return nil, 0, ErrorStorageTruncated } tag := buf[off] off++ @@ -507,7 +501,7 @@ func decodeSection(buf []byte) (Section, int, error) { // Value. 
val, consumed, err := decodeValue(buf[off:], tag) if err != nil { - return nil, 0, coreerr.E("levin.decodeSection", "field "+name, err) + return nil, 0, core.E("levin.decodeSection", "field "+name, err) } off += consumed @@ -517,7 +511,7 @@ func decodeSection(buf []byte) (Section, int, error) { return s, off, nil } -// decodeValue reads a value of the given type tag from buf and returns +// decodeValue reads a value of the given type tag from buffer and returns // the value plus bytes consumed. func decodeValue(buf []byte, tag uint8) (Value, int, error) { // Array types. @@ -528,68 +522,68 @@ func decodeValue(buf []byte, tag uint8) (Value, int, error) { switch tag { case TypeUint64: if len(buf) < 8 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } v := binary.LittleEndian.Uint64(buf[:8]) return Value{Type: TypeUint64, uintVal: v}, 8, nil case TypeInt64: if len(buf) < 8 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } v := int64(binary.LittleEndian.Uint64(buf[:8])) return Value{Type: TypeInt64, intVal: v}, 8, nil case TypeDouble: if len(buf) < 8 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } bits := binary.LittleEndian.Uint64(buf[:8]) return Value{Type: TypeDouble, floatVal: math.Float64frombits(bits)}, 8, nil case TypeUint32: if len(buf) < 4 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } v := binary.LittleEndian.Uint32(buf[:4]) return Value{Type: TypeUint32, uintVal: uint64(v)}, 4, nil case TypeInt32: if len(buf) < 4 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } v := int32(binary.LittleEndian.Uint32(buf[:4])) return Value{Type: TypeInt32, intVal: int64(v)}, 4, nil case TypeUint16: if len(buf) < 2 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } v := binary.LittleEndian.Uint16(buf[:2]) return Value{Type: TypeUint16, uintVal: 
uint64(v)}, 2, nil case TypeInt16: if len(buf) < 2 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } v := int16(binary.LittleEndian.Uint16(buf[:2])) return Value{Type: TypeInt16, intVal: int64(v)}, 2, nil case TypeUint8: if len(buf) < 1 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } return Value{Type: TypeUint8, uintVal: uint64(buf[0])}, 1, nil case TypeInt8: if len(buf) < 1 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } return Value{Type: TypeInt8, intVal: int64(int8(buf[0]))}, 1, nil case TypeBool: if len(buf) < 1 { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } return Value{Type: TypeBool, boolVal: buf[0] != 0}, 1, nil @@ -599,7 +593,7 @@ func decodeValue(buf []byte, tag uint8) (Value, int, error) { return Value{}, 0, err } if uint64(len(buf)-n) < strLen { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } data := make([]byte, strLen) copy(data, buf[n:n+int(strLen)]) @@ -613,11 +607,11 @@ func decodeValue(buf []byte, tag uint8) (Value, int, error) { return Value{Type: TypeObject, objectVal: sec}, consumed, nil default: - return Value{}, 0, coreerr.E("levin.decodeValue", fmt.Sprintf("unknown type tag: 0x%02x", tag), ErrStorageUnknownType) + return Value{}, 0, core.E("levin.decodeValue", core.Sprintf("unknown type tag: 0x%02x", tag), ErrorStorageUnknownType) } } -// decodeArray reads a typed array from buf (tag has ArrayFlag set). +// decodeArray reads a typed array from buffer (tag has ArrayFlag set). 
func decodeArray(buf []byte, tag uint8) (Value, int, error) { elemType := tag & ^ArrayFlag @@ -632,7 +626,7 @@ func decodeArray(buf []byte, tag uint8) (Value, int, error) { arr := make([]uint64, count) for i := range count { if off+8 > len(buf) { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } arr[i] = binary.LittleEndian.Uint64(buf[off : off+8]) off += 8 @@ -643,7 +637,7 @@ func decodeArray(buf []byte, tag uint8) (Value, int, error) { arr := make([]uint32, count) for i := range count { if off+4 > len(buf) { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } arr[i] = binary.LittleEndian.Uint32(buf[off : off+4]) off += 4 @@ -659,7 +653,7 @@ func decodeArray(buf []byte, tag uint8) (Value, int, error) { } off += sn if uint64(len(buf)-off) < strLen { - return Value{}, 0, ErrStorageTruncated + return Value{}, 0, ErrorStorageTruncated } data := make([]byte, strLen) copy(data, buf[off:off+int(strLen)]) @@ -681,6 +675,6 @@ func decodeArray(buf []byte, tag uint8) (Value, int, error) { return Value{Type: tag, objectArray: arr}, off, nil default: - return Value{}, 0, coreerr.E("levin.decodeArray", fmt.Sprintf("unknown type tag: array of 0x%02x", elemType), ErrStorageUnknownType) + return Value{}, 0, core.E("levin.decodeArray", core.Sprintf("unknown type tag: array of 0x%02x", elemType), ErrorStorageUnknownType) } } diff --git a/node/levin/storage_test.go b/node/levin/storage_test.go index ae16c52..7118d48 100644 --- a/node/levin/storage_test.go +++ b/node/levin/storage_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestEncodeStorage_EmptySection(t *testing.T) { +func TestStorage_EncodeStorage_EmptySection_Ugly(t *testing.T) { s := Section{} data, err := EncodeStorage(s) require.NoError(t, err) @@ -35,19 +35,19 @@ func TestEncodeStorage_EmptySection(t *testing.T) { assert.Equal(t, byte(0x00), data[9]) } -func TestStorage_PrimitivesRoundTrip(t *testing.T) { +func 
TestStorage_PrimitivesRoundTrip_Ugly(t *testing.T) { s := Section{ - "u64": Uint64Val(0xDEADBEEFCAFEBABE), - "u32": Uint32Val(0xCAFEBABE), - "u16": Uint16Val(0xBEEF), - "u8": Uint8Val(42), - "i64": Int64Val(-9223372036854775808), - "i32": Int32Val(-2147483648), - "i16": Int16Val(-32768), - "i8": Int8Val(-128), - "flag": BoolVal(true), - "height": StringVal([]byte("hello world")), - "pi": DoubleVal(3.141592653589793), + "u64": Uint64Value(0xDEADBEEFCAFEBABE), + "u32": Uint32Value(0xCAFEBABE), + "u16": Uint16Value(0xBEEF), + "u8": Uint8Value(42), + "i64": Int64Value(-9223372036854775808), + "i32": Int32Value(-2147483648), + "i16": Int16Value(-32768), + "i8": Int8Value(-128), + "flag": BoolValue(true), + "height": StringValue([]byte("hello world")), + "pi": DoubleValue(3.141592653589793), } data, err := EncodeStorage(s) @@ -106,14 +106,14 @@ func TestStorage_PrimitivesRoundTrip(t *testing.T) { assert.Equal(t, 3.141592653589793, pi) } -func TestStorage_NestedObject(t *testing.T) { +func TestStorage_NestedObject_Good(t *testing.T) { inner := Section{ - "port": Uint16Val(18080), - "host": StringVal([]byte("127.0.0.1")), + "port": Uint16Value(18080), + "host": StringValue([]byte("127.0.0.1")), } outer := Section{ - "node_data": ObjectVal(inner), - "version": Uint32Val(1), + "node_data": ObjectValue(inner), + "version": Uint32Value(1), } data, err := EncodeStorage(outer) @@ -138,9 +138,9 @@ func TestStorage_NestedObject(t *testing.T) { assert.Equal(t, []byte("127.0.0.1"), host) } -func TestStorage_Uint64Array(t *testing.T) { +func TestStorage_Uint64Array_Good(t *testing.T) { s := Section{ - "heights": Uint64ArrayVal([]uint64{10, 20, 30}), + "heights": Uint64ArrayValue([]uint64{10, 20, 30}), } data, err := EncodeStorage(s) @@ -154,9 +154,9 @@ func TestStorage_Uint64Array(t *testing.T) { assert.Equal(t, []uint64{10, 20, 30}, arr) } -func TestStorage_StringArray(t *testing.T) { +func TestStorage_StringArray_Good(t *testing.T) { s := Section{ - "peers": 
StringArrayVal([][]byte{[]byte("foo"), []byte("bar")}), + "peers": StringArrayValue([][]byte{[]byte("foo"), []byte("bar")}), } data, err := EncodeStorage(s) @@ -172,13 +172,13 @@ func TestStorage_StringArray(t *testing.T) { assert.Equal(t, []byte("bar"), arr[1]) } -func TestStorage_ObjectArray(t *testing.T) { +func TestStorage_ObjectArray_Good(t *testing.T) { sections := []Section{ - {"id": Uint32Val(1), "name": StringVal([]byte("alice"))}, - {"id": Uint32Val(2), "name": StringVal([]byte("bob"))}, + {"id": Uint32Value(1), "name": StringValue([]byte("alice"))}, + {"id": Uint32Value(2), "name": StringValue([]byte("bob"))}, } s := Section{ - "nodes": ObjectArrayVal(sections), + "nodes": ObjectArrayValue(sections), } data, err := EncodeStorage(s) @@ -208,30 +208,30 @@ func TestStorage_ObjectArray(t *testing.T) { assert.Equal(t, []byte("bob"), name2) } -func TestDecodeStorage_BadSignature(t *testing.T) { +func TestStorage_DecodeStorage_BadSignature_Bad(t *testing.T) { // Corrupt the first 4 bytes. 
data := []byte{0xFF, 0xFF, 0xFF, 0xFF, 0x01, 0x01, 0x02, 0x01, 0x01, 0x00} _, err := DecodeStorage(data) require.Error(t, err) - assert.ErrorIs(t, err, ErrStorageBadSignature) + assert.ErrorIs(t, err, ErrorStorageBadSignature) } -func TestDecodeStorage_TooShort(t *testing.T) { +func TestStorage_DecodeStorage_TooShort_Bad(t *testing.T) { _, err := DecodeStorage([]byte{0x01, 0x11}) require.Error(t, err) - assert.ErrorIs(t, err, ErrStorageTruncated) + assert.ErrorIs(t, err, ErrorStorageTruncated) } -func TestStorage_ByteIdenticalReencode(t *testing.T) { +func TestStorage_ByteIdenticalReencode_Ugly(t *testing.T) { s := Section{ - "alpha": Uint64Val(999), - "bravo": StringVal([]byte("deterministic")), - "charlie": BoolVal(false), - "delta": ObjectVal(Section{ - "x": Int32Val(-42), - "y": Int32Val(100), + "alpha": Uint64Value(999), + "bravo": StringValue([]byte("deterministic")), + "charlie": BoolValue(false), + "delta": ObjectValue(Section{ + "x": Int32Value(-42), + "y": Int32Value(100), }), - "echo": Uint64ArrayVal([]uint64{1, 2, 3}), + "echo": Uint64ArrayValue([]uint64{1, 2, 3}), } data1, err := EncodeStorage(s) @@ -246,28 +246,28 @@ func TestStorage_ByteIdenticalReencode(t *testing.T) { assert.Equal(t, data1, data2, "re-encoded bytes must be identical") } -func TestStorage_TypeMismatchErrors(t *testing.T) { - v := Uint64Val(42) +func TestStorage_TypeMismatchErrors_Bad(t *testing.T) { + v := Uint64Value(42) _, err := v.AsUint32() - assert.ErrorIs(t, err, ErrStorageTypeMismatch) + assert.ErrorIs(t, err, ErrorStorageTypeMismatch) _, err = v.AsString() - assert.ErrorIs(t, err, ErrStorageTypeMismatch) + assert.ErrorIs(t, err, ErrorStorageTypeMismatch) _, err = v.AsBool() - assert.ErrorIs(t, err, ErrStorageTypeMismatch) + assert.ErrorIs(t, err, ErrorStorageTypeMismatch) _, err = v.AsSection() - assert.ErrorIs(t, err, ErrStorageTypeMismatch) + assert.ErrorIs(t, err, ErrorStorageTypeMismatch) _, err = v.AsUint64Array() - assert.ErrorIs(t, err, ErrStorageTypeMismatch) + 
assert.ErrorIs(t, err, ErrorStorageTypeMismatch) } -func TestStorage_Uint32Array(t *testing.T) { +func TestStorage_Uint32Array_Good(t *testing.T) { s := Section{ - "ports": Uint32ArrayVal([]uint32{8080, 8443, 9090}), + "ports": Uint32ArrayValue([]uint32{8080, 8443, 9090}), } data, err := EncodeStorage(s) @@ -281,19 +281,19 @@ func TestStorage_Uint32Array(t *testing.T) { assert.Equal(t, []uint32{8080, 8443, 9090}, arr) } -func TestDecodeStorage_BadVersion(t *testing.T) { +func TestStorage_DecodeStorage_BadVersion_Bad(t *testing.T) { // Valid signatures but version 2 instead of 1. data := []byte{0x01, 0x11, 0x01, 0x01, 0x01, 0x01, 0x02, 0x01, 0x02, 0x00} _, err := DecodeStorage(data) require.Error(t, err) - assert.ErrorIs(t, err, ErrStorageBadVersion) + assert.ErrorIs(t, err, ErrorStorageBadVersion) } -func TestStorage_EmptyArrays(t *testing.T) { +func TestStorage_EmptyArrays_Ugly(t *testing.T) { s := Section{ - "empty_u64": Uint64ArrayVal([]uint64{}), - "empty_str": StringArrayVal([][]byte{}), - "empty_obj": ObjectArrayVal([]Section{}), + "empty_u64": Uint64ArrayValue([]uint64{}), + "empty_str": StringArrayValue([][]byte{}), + "empty_obj": ObjectArrayValue([]Section{}), } data, err := EncodeStorage(s) @@ -315,10 +315,10 @@ func TestStorage_EmptyArrays(t *testing.T) { assert.Empty(t, objarr) } -func TestStorage_BoolFalseRoundTrip(t *testing.T) { +func TestStorage_BoolFalseRoundTrip_Ugly(t *testing.T) { s := Section{ - "off": BoolVal(false), - "on": BoolVal(true), + "off": BoolValue(false), + "on": BoolValue(true), } data, err := EncodeStorage(s) diff --git a/node/levin/varint.go b/node/levin/varint.go index 2830e71..b245833 100644 --- a/node/levin/varint.go +++ b/node/levin/varint.go @@ -6,89 +6,86 @@ package levin import ( "encoding/binary" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" ) // Size-mark bits occupying the two lowest bits of the first byte. 
const ( - varintMask = 0x03 - varintMark1 = 0x00 // 1 byte, max 63 - varintMark2 = 0x01 // 2 bytes, max 16,383 - varintMark4 = 0x02 // 4 bytes, max 1,073,741,823 - varintMark8 = 0x03 // 8 bytes, max 4,611,686,018,427,387,903 - varintMax1 = 63 - varintMax2 = 16_383 - varintMax4 = 1_073_741_823 - varintMax8 = 4_611_686_018_427_387_903 + varintMask = 0x03 + varintMark1 = 0x00 // 1 byte, max 63 + varintMark2 = 0x01 // 2 bytes, max 16,383 + varintMark4 = 0x02 // 4 bytes, max 1,073,741,823 + varintMark8 = 0x03 // 8 bytes, max 4,611,686,018,427,387,903 + varintMax1 = 63 + varintMax2 = 16_383 + varintMax4 = 1_073_741_823 + varintMax8 = 4_611_686_018_427_387_903 ) -// ErrVarintTruncated is returned when the buffer is too short. -var ErrVarintTruncated = coreerr.E("levin", "truncated varint", nil) +// ErrorVarintTruncated is returned when the buffer is too short. +var ErrorVarintTruncated = core.E("levin", "truncated varint", nil) -// ErrVarintOverflow is returned when the value is too large to encode. -var ErrVarintOverflow = coreerr.E("levin", "varint overflow", nil) +// ErrorVarintOverflow is returned when the value is too large to encode. +var ErrorVarintOverflow = core.E("levin", "varint overflow", nil) -// PackVarint encodes v using the epee portable-storage varint scheme. -// The low two bits of the first byte indicate the total encoded width; -// the remaining bits carry the value in little-endian order. 
-func PackVarint(v uint64) []byte { +// encoded := PackVarint(42) +func PackVarint(value uint64) []byte { switch { - case v <= varintMax1: - return []byte{byte((v << 2) | varintMark1)} - case v <= varintMax2: - raw := uint16((v << 2) | varintMark2) - buf := make([]byte, 2) - binary.LittleEndian.PutUint16(buf, raw) - return buf - case v <= varintMax4: - raw := uint32((v << 2) | varintMark4) - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, raw) - return buf + case value <= varintMax1: + return []byte{byte((value << 2) | varintMark1)} + case value <= varintMax2: + raw := uint16((value << 2) | varintMark2) + buffer := make([]byte, 2) + binary.LittleEndian.PutUint16(buffer, raw) + return buffer + case value <= varintMax4: + raw := uint32((value << 2) | varintMark4) + buffer := make([]byte, 4) + binary.LittleEndian.PutUint32(buffer, raw) + return buffer default: - raw := (v << 2) | varintMark8 - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, raw) - return buf + raw := (value << 2) | varintMark8 + buffer := make([]byte, 8) + binary.LittleEndian.PutUint64(buffer, raw) + return buffer } } -// UnpackVarint decodes one epee portable-storage varint from buf. -// It returns the decoded value, the number of bytes consumed, and any error. 
-func UnpackVarint(buf []byte) (value uint64, bytesConsumed int, err error) { - if len(buf) == 0 { - return 0, 0, ErrVarintTruncated +// value, err := UnpackVarint(buffer) +func UnpackVarint(buffer []byte) (value uint64, bytesConsumed int, err error) { + if len(buffer) == 0 { + return 0, 0, ErrorVarintTruncated } - mark := buf[0] & varintMask + mark := buffer[0] & varintMask switch mark { case varintMark1: - value = uint64(buf[0]) >> 2 + value = uint64(buffer[0]) >> 2 return value, 1, nil case varintMark2: - if len(buf) < 2 { - return 0, 0, ErrVarintTruncated + if len(buffer) < 2 { + return 0, 0, ErrorVarintTruncated } - raw := binary.LittleEndian.Uint16(buf[:2]) + raw := binary.LittleEndian.Uint16(buffer[:2]) value = uint64(raw) >> 2 return value, 2, nil case varintMark4: - if len(buf) < 4 { - return 0, 0, ErrVarintTruncated + if len(buffer) < 4 { + return 0, 0, ErrorVarintTruncated } - raw := binary.LittleEndian.Uint32(buf[:4]) + raw := binary.LittleEndian.Uint32(buffer[:4]) value = uint64(raw) >> 2 return value, 4, nil case varintMark8: - if len(buf) < 8 { - return 0, 0, ErrVarintTruncated + if len(buffer) < 8 { + return 0, 0, ErrorVarintTruncated } - raw := binary.LittleEndian.Uint64(buf[:8]) + raw := binary.LittleEndian.Uint64(buffer[:8]) value = raw >> 2 return value, 8, nil default: // Unreachable — mark is masked to 2 bits. 
- return 0, 0, ErrVarintTruncated + return 0, 0, ErrorVarintTruncated } } diff --git a/node/levin/varint_test.go b/node/levin/varint_test.go index 2082864..d2088eb 100644 --- a/node/levin/varint_test.go +++ b/node/levin/varint_test.go @@ -10,41 +10,41 @@ import ( "github.com/stretchr/testify/require" ) -func TestPackVarint_Value5(t *testing.T) { +func TestVarint_PackVarint_Value5_Good(t *testing.T) { // 5 << 2 | 0x00 = 20 = 0x14 got := PackVarint(5) assert.Equal(t, []byte{0x14}, got) } -func TestPackVarint_Value100(t *testing.T) { +func TestVarint_PackVarint_Value100_Good(t *testing.T) { // 100 << 2 | 0x01 = 401 = 0x0191 → LE [0x91, 0x01] got := PackVarint(100) assert.Equal(t, []byte{0x91, 0x01}, got) } -func TestPackVarint_Value65536(t *testing.T) { +func TestVarint_PackVarint_Value65536_Good(t *testing.T) { // 65536 << 2 | 0x02 = 262146 = 0x00040002 → LE [0x02, 0x00, 0x04, 0x00] got := PackVarint(65536) assert.Equal(t, []byte{0x02, 0x00, 0x04, 0x00}, got) } -func TestPackVarint_Value2Billion(t *testing.T) { +func TestVarint_PackVarint_Value2Billion_Good(t *testing.T) { got := PackVarint(2_000_000_000) require.Len(t, got, 8) // Low 2 bits must be 0x03 (8-byte mark). 
assert.Equal(t, byte(0x03), got[0]&0x03) } -func TestPackVarint_Zero(t *testing.T) { +func TestVarint_PackVarint_Zero_Ugly(t *testing.T) { got := PackVarint(0) assert.Equal(t, []byte{0x00}, got) } -func TestPackVarint_Boundaries(t *testing.T) { +func TestVarint_PackVarint_Boundaries_Good(t *testing.T) { tests := []struct { - name string - value uint64 - wantLen int + name string + value uint64 + wantLen int }{ {"1-byte max (63)", 63, 1}, {"2-byte min (64)", 64, 2}, @@ -63,7 +63,7 @@ func TestPackVarint_Boundaries(t *testing.T) { } } -func TestVarint_RoundTrip(t *testing.T) { +func TestVarint_RoundTrip_Ugly(t *testing.T) { values := []uint64{ 0, 1, 63, 64, 100, 16_383, 16_384, 1_073_741_823, 1_073_741_824, @@ -79,38 +79,38 @@ func TestVarint_RoundTrip(t *testing.T) { } } -func TestUnpackVarint_EmptyInput(t *testing.T) { +func TestVarint_UnpackVarint_EmptyInput_Ugly(t *testing.T) { _, _, err := UnpackVarint([]byte{}) require.Error(t, err) - assert.ErrorIs(t, err, ErrVarintTruncated) + assert.ErrorIs(t, err, ErrorVarintTruncated) } -func TestUnpackVarint_Truncated2Byte(t *testing.T) { +func TestVarint_UnpackVarint_Truncated2Byte_Bad(t *testing.T) { // Encode 64 (needs 2 bytes), then only pass 1 byte. 
buf := PackVarint(64) require.Len(t, buf, 2) _, _, err := UnpackVarint(buf[:1]) require.Error(t, err) - assert.ErrorIs(t, err, ErrVarintTruncated) + assert.ErrorIs(t, err, ErrorVarintTruncated) } -func TestUnpackVarint_Truncated4Byte(t *testing.T) { +func TestVarint_UnpackVarint_Truncated4Byte_Bad(t *testing.T) { buf := PackVarint(16_384) require.Len(t, buf, 4) _, _, err := UnpackVarint(buf[:2]) require.Error(t, err) - assert.ErrorIs(t, err, ErrVarintTruncated) + assert.ErrorIs(t, err, ErrorVarintTruncated) } -func TestUnpackVarint_Truncated8Byte(t *testing.T) { +func TestVarint_UnpackVarint_Truncated8Byte_Bad(t *testing.T) { buf := PackVarint(1_073_741_824) require.Len(t, buf, 8) _, _, err := UnpackVarint(buf[:4]) require.Error(t, err) - assert.ErrorIs(t, err, ErrVarintTruncated) + assert.ErrorIs(t, err, ErrorVarintTruncated) } -func TestUnpackVarint_ExtraBytes(t *testing.T) { +func TestVarint_UnpackVarint_ExtraBytes_Good(t *testing.T) { // Ensure that extra trailing bytes are not consumed. buf := append(PackVarint(42), 0xFF, 0xFF) decoded, consumed, err := UnpackVarint(buf) @@ -119,7 +119,7 @@ func TestUnpackVarint_ExtraBytes(t *testing.T) { assert.Equal(t, 1, consumed) } -func TestPackVarint_SizeMarkBits(t *testing.T) { +func TestVarint_PackVarint_SizeMarkBits_Good(t *testing.T) { tests := []struct { name string value uint64 diff --git a/node/message.go b/node/message.go index d4b2b6e..7fc85e4 100644 --- a/node/message.go +++ b/node/message.go @@ -1,85 +1,106 @@ package node import ( - "encoding/json" "slices" "time" + core "dappco.re/go/core" "github.com/google/uuid" ) -// Protocol version constants const ( - // ProtocolVersion is the current protocol version + // version := ProtocolVersion ProtocolVersion = "1.0" - // MinProtocolVersion is the minimum supported version + // minimumVersion := MinProtocolVersion MinProtocolVersion = "1.0" ) -// SupportedProtocolVersions lists all protocol versions this node supports. 
-// Used for version negotiation during handshake. +// versions := SupportedProtocolVersions var SupportedProtocolVersions = []string{"1.0"} -// IsProtocolVersionSupported checks if a given version is supported. +// payload := RawMessage(`{"pool":"pool.example.com:3333"}`) +type RawMessage []byte + +// data, err := RawMessage(`{"ok":true}`).MarshalJSON() +func (m RawMessage) MarshalJSON() ([]byte, error) { + if m == nil { + return []byte("null"), nil + } + + return m, nil +} + +// var payload RawMessage +// _ = payload.UnmarshalJSON([]byte(`{"ok":true}`)) +func (m *RawMessage) UnmarshalJSON(data []byte) error { + if m == nil { + return core.E("node.RawMessage.UnmarshalJSON", "raw message target is nil", nil) + } + + *m = append((*m)[:0], data...) + return nil +} + +// ok := IsProtocolVersionSupported("1.0") func IsProtocolVersionSupported(version string) bool { return slices.Contains(SupportedProtocolVersions, version) } -// MessageType defines the type of P2P message. +// messageType := MessagePing type MessageType string const ( // Connection lifecycle - MsgHandshake MessageType = "handshake" - MsgHandshakeAck MessageType = "handshake_ack" - MsgPing MessageType = "ping" - MsgPong MessageType = "pong" - MsgDisconnect MessageType = "disconnect" + MessageHandshake MessageType = "handshake" + MessageHandshakeAck MessageType = "handshake_ack" + MessagePing MessageType = "ping" + MessagePong MessageType = "pong" + MessageDisconnect MessageType = "disconnect" // Miner operations - MsgGetStats MessageType = "get_stats" - MsgStats MessageType = "stats" - MsgStartMiner MessageType = "start_miner" - MsgStopMiner MessageType = "stop_miner" - MsgMinerAck MessageType = "miner_ack" + MessageGetStats MessageType = "get_stats" + MessageStats MessageType = "stats" + MessageStartMiner MessageType = "start_miner" + MessageStopMiner MessageType = "stop_miner" + MessageMinerAck MessageType = "miner_ack" // Deployment - MsgDeploy MessageType = "deploy" - MsgDeployAck MessageType = 
"deploy_ack" + MessageDeploy MessageType = "deploy" + MessageDeployAck MessageType = "deploy_ack" // Logs - MsgGetLogs MessageType = "get_logs" - MsgLogs MessageType = "logs" + MessageGetLogs MessageType = "get_logs" + MessageLogs MessageType = "logs" // Error response - MsgError MessageType = "error" + MessageError MessageType = "error" ) -// Message represents a P2P message between nodes. +// message, err := NewMessage(MessagePing, "controller", "worker", PingPayload{SentAt: time.Now().UnixMilli()}) type Message struct { - ID string `json:"id"` // UUID - Type MessageType `json:"type"` - From string `json:"from"` // Sender node ID - To string `json:"to"` // Recipient node ID (empty for broadcast) - Timestamp time.Time `json:"ts"` - Payload json.RawMessage `json:"payload"` - ReplyTo string `json:"replyTo,omitempty"` // ID of message being replied to + ID string `json:"id"` // UUID + Type MessageType `json:"type"` + From string `json:"from"` // Sender node ID + To string `json:"to"` // Recipient node ID (empty for broadcast) + Timestamp time.Time `json:"ts"` + Payload RawMessage `json:"payload"` + ReplyTo string `json:"replyTo,omitempty"` // ID of message being replied to } -// NewMessage creates a new message with a generated ID and timestamp. 
-func NewMessage(msgType MessageType, from, to string, payload any) (*Message, error) { - var payloadBytes json.RawMessage +// message, err := NewMessage(MessagePing, "controller", "worker-1", PingPayload{SentAt: 42}) +func NewMessage(messageType MessageType, from, to string, payload any) (*Message, error) { + var payloadBytes RawMessage if payload != nil { data, err := MarshalJSON(payload) if err != nil { return nil, err } - payloadBytes = data + payloadBytes = RawMessage(data) } return &Message{ ID: uuid.New().String(), - Type: msgType, + Type: messageType, From: from, To: to, Timestamp: time.Now(), @@ -87,9 +108,9 @@ func NewMessage(msgType MessageType, from, to string, payload any) (*Message, er }, nil } -// Reply creates a reply message to this message. -func (m *Message) Reply(msgType MessageType, payload any) (*Message, error) { - reply, err := NewMessage(msgType, m.To, m.From, payload) +// reply, err := message.Reply(MessagePong, PongPayload{SentAt: 42, ReceivedAt: 43}) +func (m *Message) Reply(messageType MessageType, payload any) (*Message, error) { + reply, err := NewMessage(messageType, m.To, m.From, payload) if err != nil { return nil, err } @@ -97,24 +118,29 @@ func (m *Message) Reply(msgType MessageType, payload any) (*Message, error) { return reply, nil } -// ParsePayload unmarshals the payload into the given struct. -func (m *Message) ParsePayload(v any) error { +// var ping PingPayload +// err := message.ParsePayload(&ping) +func (m *Message) ParsePayload(target any) error { if m.Payload == nil { return nil } - return json.Unmarshal(m.Payload, v) + result := core.JSONUnmarshal(m.Payload, target) + if !result.OK { + return result.Value.(error) + } + return nil } // --- Payload Types --- -// HandshakePayload is sent during connection establishment. 
+// payload := HandshakePayload{Identity: NodeIdentity{Name: "worker-1"}, Version: ProtocolVersion} type HandshakePayload struct { Identity NodeIdentity `json:"identity"` Challenge []byte `json:"challenge,omitempty"` // Random bytes for auth Version string `json:"version"` // Protocol version } -// HandshakeAckPayload is the response to a handshake. +// ack := HandshakeAckPayload{Accepted: true} type HandshakeAckPayload struct { Identity NodeIdentity `json:"identity"` ChallengeResponse []byte `json:"challengeResponse,omitempty"` @@ -122,37 +148,37 @@ type HandshakeAckPayload struct { Reason string `json:"reason,omitempty"` // If not accepted } -// PingPayload for keepalive/latency measurement. +// payload := PingPayload{SentAt: 42} type PingPayload struct { SentAt int64 `json:"sentAt"` // Unix timestamp in milliseconds } -// PongPayload response to ping. +// payload := PongPayload{SentAt: 42, ReceivedAt: 43} type PongPayload struct { SentAt int64 `json:"sentAt"` // Echo of ping's sentAt ReceivedAt int64 `json:"receivedAt"` // When ping was received } -// StartMinerPayload requests starting a miner. +// payload := StartMinerPayload{MinerType: "xmrig"} type StartMinerPayload struct { - MinerType string `json:"minerType"` // Required: miner type (e.g., "xmrig", "tt-miner") - ProfileID string `json:"profileId,omitempty"` - Config json.RawMessage `json:"config,omitempty"` // Override profile config + MinerType string `json:"minerType"` // Required: miner type (e.g., "xmrig", "tt-miner") + ProfileID string `json:"profileId,omitempty"` + Config RawMessage `json:"config,omitempty"` // Override profile config } -// StopMinerPayload requests stopping a miner. +// payload := StopMinerPayload{MinerName: "xmrig-0"} type StopMinerPayload struct { MinerName string `json:"minerName"` } -// MinerAckPayload acknowledges a miner start/stop operation. 
+// ack := MinerAckPayload{Success: true, MinerName: "xmrig-0"} type MinerAckPayload struct { Success bool `json:"success"` MinerName string `json:"minerName,omitempty"` Error string `json:"error,omitempty"` } -// MinerStatsItem represents stats for a single miner. +// miner := MinerStatsItem{Name: "xmrig-0", Hashrate: 1200} type MinerStatsItem struct { Name string `json:"name"` Type string `json:"type"` @@ -165,7 +191,7 @@ type MinerStatsItem struct { CPUThreads int `json:"cpuThreads,omitempty"` } -// StatsPayload contains miner statistics. +// stats := StatsPayload{NodeID: "worker-1"} type StatsPayload struct { NodeID string `json:"nodeId"` NodeName string `json:"nodeName"` @@ -173,21 +199,21 @@ type StatsPayload struct { Uptime int64 `json:"uptime"` // Node uptime in seconds } -// GetLogsPayload requests console logs from a miner. -type GetLogsPayload struct { +// payload := LogsRequestPayload{MinerName: "xmrig-0", Lines: 100} +type LogsRequestPayload struct { MinerName string `json:"minerName"` Lines int `json:"lines"` // Number of lines to fetch Since int64 `json:"since,omitempty"` // Unix timestamp, logs after this time } -// LogsPayload contains console log lines. +// payload := LogsPayload{MinerName: "xmrig-0", Lines: []string{"started"}} type LogsPayload struct { MinerName string `json:"minerName"` Lines []string `json:"lines"` HasMore bool `json:"hasMore"` // More logs available } -// DeployPayload contains a deployment bundle. +// payload := DeployPayload{Name: "xmrig", BundleType: string(BundleMiner)} type DeployPayload struct { BundleType string `json:"type"` // "profile" | "miner" | "full" Data []byte `json:"data"` // STIM-encrypted bundle @@ -195,39 +221,39 @@ type DeployPayload struct { Name string `json:"name"` // Profile or miner name } -// DeployAckPayload acknowledges a deployment. 
+// ack := DeployAckPayload{Success: true, Name: "xmrig"} type DeployAckPayload struct { Success bool `json:"success"` Name string `json:"name,omitempty"` Error string `json:"error,omitempty"` } -// ErrorPayload contains error information. +// payload := ErrorPayload{Code: ErrorCodeOperationFailed, Message: "start failed"} type ErrorPayload struct { Code int `json:"code"` Message string `json:"message"` Details string `json:"details,omitempty"` } -// Common error codes const ( - ErrCodeUnknown = 1000 - ErrCodeInvalidMessage = 1001 - ErrCodeUnauthorized = 1002 - ErrCodeNotFound = 1003 - ErrCodeOperationFailed = 1004 - ErrCodeTimeout = 1005 + ErrorCodeUnknown = 1000 + ErrorCodeInvalidMessage = 1001 + ErrorCodeUnauthorized = 1002 + ErrorCodeNotFound = 1003 + // code := ErrorCodeOperationFailed + ErrorCodeOperationFailed = 1004 + ErrorCodeTimeout = 1005 ) -// NewErrorMessage creates an error response message. +// errorMessage, err := NewErrorMessage("worker-1", "controller-1", ErrorCodeOperationFailed, "miner start failed", "req-1") func NewErrorMessage(from, to string, code int, message string, replyTo string) (*Message, error) { - msg, err := NewMessage(MsgError, from, to, ErrorPayload{ + errorMessage, err := NewMessage(MessageError, from, to, ErrorPayload{ Code: code, Message: message, }) if err != nil { return nil, err } - msg.ReplyTo = replyTo - return msg, nil + errorMessage.ReplyTo = replyTo + return errorMessage, nil } diff --git a/node/message_test.go b/node/message_test.go index 4443470..85ec47d 100644 --- a/node/message_test.go +++ b/node/message_test.go @@ -1,20 +1,19 @@ package node import ( - "encoding/json" "testing" "time" ) -func TestNewMessage(t *testing.T) { +func TestMessage_NewMessage_Good(t *testing.T) { t.Run("BasicMessage", func(t *testing.T) { - msg, err := NewMessage(MsgPing, "sender-id", "receiver-id", nil) + msg, err := NewMessage(MessagePing, "sender-id", "receiver-id", nil) if err != nil { t.Fatalf("failed to create message: %v", err) } - 
if msg.Type != MsgPing { - t.Errorf("expected type MsgPing, got %s", msg.Type) + if msg.Type != MessagePing { + t.Errorf("expected type MessagePing, got %s", msg.Type) } if msg.From != "sender-id" { @@ -39,7 +38,7 @@ func TestNewMessage(t *testing.T) { SentAt: time.Now().UnixMilli(), } - msg, err := NewMessage(MsgPing, "sender", "receiver", payload) + msg, err := NewMessage(MessagePing, "sender", "receiver", payload) if err != nil { t.Fatalf("failed to create message: %v", err) } @@ -60,10 +59,10 @@ func TestNewMessage(t *testing.T) { }) } -func TestMessageReply(t *testing.T) { - original, _ := NewMessage(MsgPing, "sender", "receiver", PingPayload{SentAt: 12345}) +func TestMessage_Reply_Good(t *testing.T) { + original, _ := NewMessage(MessagePing, "sender", "receiver", PingPayload{SentAt: 12345}) - reply, err := original.Reply(MsgPong, PongPayload{ + reply, err := original.Reply(MessagePong, PongPayload{ SentAt: 12345, ReceivedAt: 12350, }) @@ -84,19 +83,19 @@ func TestMessageReply(t *testing.T) { t.Error("reply To should be original From") } - if reply.Type != MsgPong { - t.Errorf("expected type MsgPong, got %s", reply.Type) + if reply.Type != MessagePong { + t.Errorf("expected type MessagePong, got %s", reply.Type) } } -func TestParsePayload(t *testing.T) { +func TestMessage_ParsePayload_Good(t *testing.T) { t.Run("ValidPayload", func(t *testing.T) { payload := StartMinerPayload{ MinerType: "xmrig", ProfileID: "test-profile", } - msg, _ := NewMessage(MsgStartMiner, "ctrl", "worker", payload) + msg, _ := NewMessage(MessageStartMiner, "ctrl", "worker", payload) var parsed StartMinerPayload err := msg.ParsePayload(&parsed) @@ -110,7 +109,7 @@ func TestParsePayload(t *testing.T) { }) t.Run("NilPayload", func(t *testing.T) { - msg, _ := NewMessage(MsgGetStats, "ctrl", "worker", nil) + msg, _ := NewMessage(MessageGetStats, "ctrl", "worker", nil) var parsed StatsPayload err := msg.ParsePayload(&parsed) @@ -138,7 +137,7 @@ func TestParsePayload(t *testing.T) { Uptime: 
86400, } - msg, _ := NewMessage(MsgStats, "worker", "ctrl", stats) + msg, _ := NewMessage(MessageStats, "worker", "ctrl", stats) var parsed StatsPayload err := msg.ParsePayload(&parsed) @@ -160,14 +159,14 @@ func TestParsePayload(t *testing.T) { }) } -func TestNewErrorMessage(t *testing.T) { - errMsg, err := NewErrorMessage("sender", "receiver", ErrCodeOperationFailed, "something went wrong", "original-msg-id") +func TestMessage_NewErrorMessage_Bad(t *testing.T) { + errMsg, err := NewErrorMessage("sender", "receiver", ErrorCodeOperationFailed, "something went wrong", "original-msg-id") if err != nil { t.Fatalf("failed to create error message: %v", err) } - if errMsg.Type != MsgError { - t.Errorf("expected type MsgError, got %s", errMsg.Type) + if errMsg.Type != MessageError { + t.Errorf("expected type MessageError, got %s", errMsg.Type) } if errMsg.ReplyTo != "original-msg-id" { @@ -180,8 +179,8 @@ func TestNewErrorMessage(t *testing.T) { t.Fatalf("failed to parse error payload: %v", err) } - if errPayload.Code != ErrCodeOperationFailed { - t.Errorf("expected code %d, got %d", ErrCodeOperationFailed, errPayload.Code) + if errPayload.Code != ErrorCodeOperationFailed { + t.Errorf("expected code %d, got %d", ErrorCodeOperationFailed, errPayload.Code) } if errPayload.Message != "something went wrong" { @@ -189,24 +188,18 @@ func TestNewErrorMessage(t *testing.T) { } } -func TestMessageSerialization(t *testing.T) { - original, _ := NewMessage(MsgStartMiner, "ctrl", "worker", StartMinerPayload{ +func TestMessage_Serialization_Good(t *testing.T) { + original, _ := NewMessage(MessageStartMiner, "ctrl", "worker", StartMinerPayload{ MinerType: "xmrig", ProfileID: "my-profile", }) // Serialize - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("failed to serialize message: %v", err) - } + data := testJSONMarshal(t, original) // Deserialize var restored Message - err = json.Unmarshal(data, &restored) - if err != nil { - t.Fatalf("failed to deserialize message: 
%v", err) - } + testJSONUnmarshal(t, data, &restored) if restored.ID != original.ID { t.Error("ID mismatch after serialization") @@ -221,8 +214,7 @@ func TestMessageSerialization(t *testing.T) { } var payload StartMinerPayload - err = restored.ParsePayload(&payload) - if err != nil { + if err := restored.ParsePayload(&payload); err != nil { t.Fatalf("failed to parse restored payload: %v", err) } @@ -231,23 +223,23 @@ func TestMessageSerialization(t *testing.T) { } } -func TestMessageTypes(t *testing.T) { +func TestMessage_Types_Good(t *testing.T) { types := []MessageType{ - MsgHandshake, - MsgHandshakeAck, - MsgPing, - MsgPong, - MsgDisconnect, - MsgGetStats, - MsgStats, - MsgStartMiner, - MsgStopMiner, - MsgMinerAck, - MsgDeploy, - MsgDeployAck, - MsgGetLogs, - MsgLogs, - MsgError, + MessageHandshake, + MessageHandshakeAck, + MessagePing, + MessagePong, + MessageDisconnect, + MessageGetStats, + MessageStats, + MessageStartMiner, + MessageStopMiner, + MessageMinerAck, + MessageDeploy, + MessageDeployAck, + MessageGetLogs, + MessageLogs, + MessageError, } for _, msgType := range types { @@ -264,14 +256,14 @@ func TestMessageTypes(t *testing.T) { } } -func TestErrorCodes(t *testing.T) { +func TestMessage_ErrorCodes_Bad(t *testing.T) { codes := map[int]string{ - ErrCodeUnknown: "Unknown", - ErrCodeInvalidMessage: "InvalidMessage", - ErrCodeUnauthorized: "Unauthorized", - ErrCodeNotFound: "NotFound", - ErrCodeOperationFailed: "OperationFailed", - ErrCodeTimeout: "Timeout", + ErrorCodeUnknown: "Unknown", + ErrorCodeInvalidMessage: "InvalidMessage", + ErrorCodeUnauthorized: "Unauthorized", + ErrorCodeNotFound: "NotFound", + ErrorCodeOperationFailed: "OperationFailed", + ErrorCodeTimeout: "Timeout", } for code, name := range codes { @@ -283,8 +275,8 @@ func TestErrorCodes(t *testing.T) { } } -func TestNewMessage_NilPayload(t *testing.T) { - msg, err := NewMessage(MsgPing, "from", "to", nil) +func TestMessage_NewMessage_NilPayload_Ugly(t *testing.T) { + msg, err := 
NewMessage(MessagePing, "from", "to", nil) if err != nil { t.Fatalf("NewMessage with nil payload should succeed: %v", err) } @@ -293,7 +285,7 @@ func TestNewMessage_NilPayload(t *testing.T) { } } -func TestMessage_ParsePayload_Nil(t *testing.T) { +func TestMessage_ParsePayload_Nil_Ugly(t *testing.T) { msg := &Message{Payload: nil} var target PingPayload err := msg.ParsePayload(&target) @@ -302,22 +294,25 @@ func TestMessage_ParsePayload_Nil(t *testing.T) { } } -func TestNewErrorMessage_Success(t *testing.T) { - msg, err := NewErrorMessage("from", "to", ErrCodeOperationFailed, "something went wrong", "reply-123") +func TestMessage_NewErrorMessage_Success_Bad(t *testing.T) { + msg, err := NewErrorMessage("from", "to", ErrorCodeOperationFailed, "something went wrong", "reply-123") if err != nil { t.Fatalf("NewErrorMessage failed: %v", err) } - if msg.Type != MsgError { - t.Errorf("expected type %s, got %s", MsgError, msg.Type) + if msg.Type != MessageError { + t.Errorf("expected type %s, got %s", MessageError, msg.Type) } if msg.ReplyTo != "reply-123" { t.Errorf("expected ReplyTo 'reply-123', got '%s'", msg.ReplyTo) } var payload ErrorPayload - msg.ParsePayload(&payload) - if payload.Code != ErrCodeOperationFailed { - t.Errorf("expected code %d, got %d", ErrCodeOperationFailed, payload.Code) + err = msg.ParsePayload(&payload) + if err != nil { + t.Fatalf("ParsePayload failed: %v", err) + } + if payload.Code != ErrorCodeOperationFailed { + t.Errorf("expected code %d, got %d", ErrorCodeOperationFailed, payload.Code) } if payload.Message != "something went wrong" { t.Errorf("expected message 'something went wrong', got '%s'", payload.Message) diff --git a/node/peer.go b/node/peer.go index d4ff02c..2b7606b 100644 --- a/node/peer.go +++ b/node/peer.go @@ -1,24 +1,28 @@ package node import ( - "encoding/json" "iter" "maps" - "path/filepath" "regexp" "slices" "sync" "time" - coreio "dappco.re/go/core/io" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" 
"dappco.re/go/core/p2p/logging" poindexter "forge.lthn.ai/Snider/Poindexter" "github.com/adrg/xdg" ) -// Peer represents a known remote node. +// peer := &Peer{ +// ID: "worker-1", +// Name: "Worker 1", +// Address: "127.0.0.1:9101", +// PingMilliseconds: 42.5, +// GeographicKilometres: 100, +// Score: 80, +// } type Peer struct { ID string `json:"id"` Name string `json:"name"` @@ -29,23 +33,22 @@ type Peer struct { LastSeen time.Time `json:"lastSeen"` // Poindexter metrics (updated dynamically) - PingMS float64 `json:"pingMs"` // Latency in milliseconds - Hops int `json:"hops"` // Network hop count - GeoKM float64 `json:"geoKm"` // Geographic distance in kilometers - Score float64 `json:"score"` // Reliability score 0-100 + PingMilliseconds float64 `json:"pingMs"` // Latency in milliseconds + Hops int `json:"hops"` // Network hop count + GeographicKilometres float64 `json:"geoKm"` // Geographic distance in kilometres + Score float64 `json:"score"` // Reliability score 0-100 // Connection state (not persisted) Connected bool `json:"-"` } -// saveDebounceInterval is the minimum time between disk writes. -const saveDebounceInterval = 5 * time.Second +const peerRegistrySaveDebounceInterval = 5 * time.Second -// PeerAuthMode controls how unknown peers are handled +// mode := PeerAuthAllowlist type PeerAuthMode int const ( - // PeerAuthOpen allows any peer to connect (original behavior) + // PeerAuthOpen allows any peer to connect. PeerAuthOpen PeerAuthMode = iota // PeerAuthAllowlist only allows pre-registered peers or those with allowed public keys PeerAuthAllowlist @@ -57,10 +60,9 @@ const ( PeerNameMaxLength = 64 ) -// peerNameRegex validates peer names: alphanumeric, hyphens, underscores, and spaces -var peerNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\-_ ]{0,62}[a-zA-Z0-9]$|^[a-zA-Z0-9]$`) +// peerNamePattern validates peer names: alphanumeric, hyphens, underscores, and spaces. 
+var peerNamePattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\-_ ]{0,62}[a-zA-Z0-9]$|^[a-zA-Z0-9]$`) -// safeKeyPrefix returns a truncated key for logging, handling short keys safely func safeKeyPrefix(key string) string { if len(key) >= 16 { return key[:16] + "..." @@ -71,26 +73,23 @@ func safeKeyPrefix(key string) string { return key } -// validatePeerName checks if a peer name is valid. -// Peer names must be 1-64 characters, start and end with alphanumeric, -// and contain only alphanumeric, hyphens, underscores, and spaces. func validatePeerName(name string) error { if name == "" { return nil // Empty names are allowed (optional field) } if len(name) < PeerNameMinLength { - return coreerr.E("validatePeerName", "peer name too short", nil) + return core.E("validatePeerName", "peer name too short", nil) } if len(name) > PeerNameMaxLength { - return coreerr.E("validatePeerName", "peer name too long", nil) + return core.E("validatePeerName", "peer name too long", nil) } - if !peerNameRegex.MatchString(name) { - return coreerr.E("validatePeerName", "peer name contains invalid characters (use alphanumeric, hyphens, underscores, spaces)", nil) + if !peerNamePattern.MatchString(name) { + return core.E("validatePeerName", "peer name contains invalid characters (use alphanumeric, hyphens, underscores, spaces)", nil) } return nil } -// PeerRegistry manages known peers with KD-tree based selection. 
+// peerRegistry, err := NewPeerRegistry() type PeerRegistry struct { peers map[string]*Peer kdTree *poindexter.KDTree[string] // KD-tree with peer ID as payload @@ -103,55 +102,56 @@ type PeerRegistry struct { allowedPublicKeyMu sync.RWMutex // Protects allowedPublicKeys // Debounce disk writes - dirty bool // Whether there are unsaved changes - saveTimer *time.Timer // Timer for debounced save - saveMu sync.Mutex // Protects dirty and saveTimer - stopChan chan struct{} // Signal to stop background save - saveStopOnce sync.Once // Ensure stopChan is closed only once + hasPendingChanges bool // Whether there are unsaved changes + pendingSaveTimer *time.Timer // Timer for debounced save + saveMutex sync.Mutex // Protects pending save state } -// Dimension weights for peer selection -// Lower ping, hops, geo are better; higher score is better +// Dimension weights for peer selection. +// Lower ping, hops, and geographic distance are better; higher score is better. var ( - pingWeight = 1.0 - hopsWeight = 0.7 - geoWeight = 0.2 - scoreWeight = 1.2 + pingWeight = 1.0 + hopsWeight = 0.7 + geographicWeight = 0.2 + scoreWeight = 1.2 ) -// NewPeerRegistry creates a new PeerRegistry, loading existing peers if available. +// peerRegistry, err := NewPeerRegistry() func NewPeerRegistry() (*PeerRegistry, error) { peersPath, err := xdg.ConfigFile("lethean-desktop/peers.json") if err != nil { - return nil, coreerr.E("PeerRegistry.New", "failed to get peers path", err) + return nil, core.E("PeerRegistry.New", "failed to get peers path", err) } - return NewPeerRegistryWithPath(peersPath) + return NewPeerRegistryFromPath(peersPath) } -// NewPeerRegistryWithPath creates a new PeerRegistry with a custom path. -// This is primarily useful for testing to avoid xdg path caching issues. 
-func NewPeerRegistryWithPath(peersPath string) (*PeerRegistry, error) { +// peerRegistry, err := NewPeerRegistryFromPath("/srv/p2p/peers.json") +// Missing files are treated as an empty registry; malformed registry files +// return an error so callers can repair the persisted state. +func NewPeerRegistryFromPath(peersPath string) (*PeerRegistry, error) { pr := &PeerRegistry{ peers: make(map[string]*Peer), path: peersPath, - stopChan: make(chan struct{}), - authMode: PeerAuthOpen, // Default to open for backward compatibility + authMode: PeerAuthOpen, // Default to open. allowedPublicKeys: make(map[string]bool), } - // Try to load existing peers - if err := pr.load(); err != nil { - // No existing peers, that's ok + // Missing files indicate a first run; any existing file must parse cleanly. + if !filesystemExists(peersPath) { pr.rebuildKDTree() return pr, nil } + if err := pr.load(); err != nil { + return nil, err + } + pr.rebuildKDTree() return pr, nil } -// SetAuthMode sets the authentication mode for peer connections. +// registry.SetAuthMode(PeerAuthAllowlist) func (r *PeerRegistry) SetAuthMode(mode PeerAuthMode) { r.allowedPublicKeyMu.Lock() defer r.allowedPublicKeyMu.Unlock() @@ -159,14 +159,14 @@ func (r *PeerRegistry) SetAuthMode(mode PeerAuthMode) { logging.Info("peer auth mode changed", logging.Fields{"mode": mode}) } -// GetAuthMode returns the current authentication mode. +// mode := registry.GetAuthMode() func (r *PeerRegistry) GetAuthMode() PeerAuthMode { r.allowedPublicKeyMu.RLock() defer r.allowedPublicKeyMu.RUnlock() return r.authMode } -// AllowPublicKey adds a public key to the allowlist. 
+// registry.AllowPublicKey(peer.PublicKey) func (r *PeerRegistry) AllowPublicKey(publicKey string) { r.allowedPublicKeyMu.Lock() defer r.allowedPublicKeyMu.Unlock() @@ -174,7 +174,7 @@ func (r *PeerRegistry) AllowPublicKey(publicKey string) { logging.Debug("public key added to allowlist", logging.Fields{"key": safeKeyPrefix(publicKey)}) } -// RevokePublicKey removes a public key from the allowlist. +// registry.RevokePublicKey(peer.PublicKey) func (r *PeerRegistry) RevokePublicKey(publicKey string) { r.allowedPublicKeyMu.Lock() defer r.allowedPublicKeyMu.Unlock() @@ -182,17 +182,17 @@ func (r *PeerRegistry) RevokePublicKey(publicKey string) { logging.Debug("public key removed from allowlist", logging.Fields{"key": safeKeyPrefix(publicKey)}) } -// IsPublicKeyAllowed checks if a public key is in the allowlist. +// allowed := registry.IsPublicKeyAllowed(peer.PublicKey) func (r *PeerRegistry) IsPublicKeyAllowed(publicKey string) bool { r.allowedPublicKeyMu.RLock() defer r.allowedPublicKeyMu.RUnlock() return r.allowedPublicKeys[publicKey] } -// IsPeerAllowed checks if a peer is allowed to connect based on auth mode. -// Returns true if: -// - AuthMode is Open (allow all) -// - AuthMode is Allowlist AND (peer is pre-registered OR public key is allowlisted) +// Returns true when AuthMode is Open (all allowed), or when Allowlist mode is active +// and the peer is pre-registered or its public key is in the allowlist. +// +// allowed := registry.IsPeerAllowed(peer.ID, peer.PublicKey) func (r *PeerRegistry) IsPeerAllowed(peerID string, publicKey string) bool { r.allowedPublicKeyMu.RLock() authMode := r.authMode @@ -217,12 +217,14 @@ func (r *PeerRegistry) IsPeerAllowed(peerID string, publicKey string) bool { return keyAllowed } -// ListAllowedPublicKeys returns all allowlisted public keys. 
+// keys := registry.ListAllowedPublicKeys() func (r *PeerRegistry) ListAllowedPublicKeys() []string { return slices.Collect(r.AllowedPublicKeys()) } -// AllowedPublicKeys returns an iterator over all allowlisted public keys. +// for key := range registry.AllowedPublicKeys() { +// log.Printf("allowed: %s", key[:16]) +// } func (r *PeerRegistry) AllowedPublicKeys() iter.Seq[string] { return func(yield func(string) bool) { r.allowedPublicKeyMu.RLock() @@ -236,15 +238,20 @@ func (r *PeerRegistry) AllowedPublicKeys() iter.Seq[string] { } } -// AddPeer adds a new peer to the registry. -// Note: Persistence is debounced (writes batched every 5s). Call Close() to ensure -// all changes are flushed to disk before shutdown. +// err := registry.AddPeer(&Peer{ID: "worker-1", Address: "10.0.0.1:9091", Role: RoleWorker}) func (r *PeerRegistry) AddPeer(peer *Peer) error { + if peer == nil { + return core.E("PeerRegistry.AddPeer", "peer is nil", nil) + } + + peerCopy := *peer + peer = &peerCopy + r.mu.Lock() if peer.ID == "" { r.mu.Unlock() - return coreerr.E("PeerRegistry.AddPeer", "peer ID is required", nil) + return core.E("PeerRegistry.AddPeer", "peer ID is required", nil) } // Validate peer name (P2P-LOW-3) @@ -255,7 +262,7 @@ func (r *PeerRegistry) AddPeer(peer *Peer) error { if _, exists := r.peers[peer.ID]; exists { r.mu.Unlock() - return coreerr.E("PeerRegistry.AddPeer", "peer "+peer.ID+" already exists", nil) + return core.E("PeerRegistry.AddPeer", "peer "+peer.ID+" already exists", nil) } // Set defaults @@ -263,51 +270,67 @@ func (r *PeerRegistry) AddPeer(peer *Peer) error { peer.AddedAt = time.Now() } if peer.Score == 0 { - peer.Score = 50 // Default neutral score + peer.Score = ScoreDefault } r.peers[peer.ID] = peer r.rebuildKDTree() r.mu.Unlock() - return r.save() + r.scheduleSave() + return nil } -// UpdatePeer updates an existing peer's information. -// Note: Persistence is debounced. Call Close() to flush before shutdown. +// Persistence is debounced. 
Call Close() to flush before shutdown. +// +// err := registry.UpdatePeer(&Peer{ID: "worker-1", Score: 90}) func (r *PeerRegistry) UpdatePeer(peer *Peer) error { + if peer == nil { + return core.E("PeerRegistry.UpdatePeer", "peer is nil", nil) + } + + if peer.ID == "" { + return core.E("PeerRegistry.UpdatePeer", "peer ID is required", nil) + } + + peerCopy := *peer + peer = &peerCopy + r.mu.Lock() if _, exists := r.peers[peer.ID]; !exists { r.mu.Unlock() - return coreerr.E("PeerRegistry.UpdatePeer", "peer "+peer.ID+" not found", nil) + return core.E("PeerRegistry.UpdatePeer", "peer "+peer.ID+" not found", nil) } r.peers[peer.ID] = peer r.rebuildKDTree() r.mu.Unlock() - return r.save() + r.scheduleSave() + return nil } -// RemovePeer removes a peer from the registry. -// Note: Persistence is debounced. Call Close() to flush before shutdown. +// Persistence is debounced. Call Close() to flush before shutdown. +// +// err := registry.RemovePeer("worker-1") func (r *PeerRegistry) RemovePeer(id string) error { r.mu.Lock() if _, exists := r.peers[id]; !exists { r.mu.Unlock() - return coreerr.E("PeerRegistry.RemovePeer", "peer "+id+" not found", nil) + return core.E("PeerRegistry.RemovePeer", "peer "+id+" not found", nil) } delete(r.peers, id) r.rebuildKDTree() r.mu.Unlock() - return r.save() + r.scheduleSave() + return nil } -// GetPeer returns a peer by ID. +// peer := registry.GetPeer("worker-1") func (r *PeerRegistry) GetPeer(id string) *Peer { r.mu.RLock() defer r.mu.RUnlock() @@ -317,18 +340,20 @@ func (r *PeerRegistry) GetPeer(id string) *Peer { return nil } - // Return a copy peerCopy := *peer return &peerCopy } -// ListPeers returns all registered peers. +// peers := registry.ListPeers() func (r *PeerRegistry) ListPeers() []*Peer { return slices.Collect(r.Peers()) } -// Peers returns an iterator over all registered peers. // Each peer is a copy to prevent mutation. 
+// +// for peer := range registry.Peers() { +// _ = peer +// } func (r *PeerRegistry) Peers() iter.Seq[*Peer] { return func(yield func(*Peer) bool) { r.mu.RLock() @@ -343,29 +368,30 @@ func (r *PeerRegistry) Peers() iter.Seq[*Peer] { } } -// UpdateMetrics updates a peer's performance metrics. +// registry.UpdateMetrics("worker-1", 42.5, 100, 3) // Note: Persistence is debounced. Call Close() to flush before shutdown. -func (r *PeerRegistry) UpdateMetrics(id string, pingMS, geoKM float64, hops int) error { +func (r *PeerRegistry) UpdateMetrics(id string, pingMilliseconds, geographicKilometres float64, hopCount int) error { r.mu.Lock() peer, exists := r.peers[id] if !exists { r.mu.Unlock() - return coreerr.E("PeerRegistry.UpdateMetrics", "peer "+id+" not found", nil) + return core.E("PeerRegistry.UpdateMetrics", "peer "+id+" not found", nil) } - peer.PingMS = pingMS - peer.GeoKM = geoKM - peer.Hops = hops + peer.PingMilliseconds = pingMilliseconds + peer.GeographicKilometres = geographicKilometres + peer.Hops = hopCount peer.LastSeen = time.Now() r.rebuildKDTree() r.mu.Unlock() - return r.save() + r.scheduleSave() + return nil } -// UpdateScore updates a peer's reliability score. +// registry.UpdateScore("worker-1", 90) // Note: Persistence is debounced. Call Close() to flush before shutdown. func (r *PeerRegistry) UpdateScore(id string, score float64) error { r.mu.Lock() @@ -373,7 +399,7 @@ func (r *PeerRegistry) UpdateScore(id string, score float64) error { peer, exists := r.peers[id] if !exists { r.mu.Unlock() - return coreerr.E("PeerRegistry.UpdateScore", "peer "+id+" not found", nil) + return core.E("PeerRegistry.UpdateScore", "peer "+id+" not found", nil) } // Clamp score to 0-100 @@ -383,10 +409,11 @@ func (r *PeerRegistry) UpdateScore(id string, score float64) error { r.rebuildKDTree() r.mu.Unlock() - return r.save() + r.scheduleSave() + return nil } -// SetConnected updates a peer's connection state. 
+// registry.SetConnected("worker-1", true) func (r *PeerRegistry) SetConnected(id string, connected bool) { r.mu.Lock() defer r.mu.Unlock() @@ -409,7 +436,7 @@ const ( ScoreDefault = 50.0 // Default score for new peers ) -// RecordSuccess records a successful interaction with a peer, improving their score. +// registry.RecordSuccess("worker-1") func (r *PeerRegistry) RecordSuccess(id string) { r.mu.Lock() peer, exists := r.peers[id] @@ -421,10 +448,10 @@ func (r *PeerRegistry) RecordSuccess(id string) { peer.Score = min(peer.Score+ScoreSuccessIncrement, ScoreMaximum) peer.LastSeen = time.Now() r.mu.Unlock() - r.save() + r.scheduleSave() } -// RecordFailure records a failed interaction with a peer, reducing their score. +// registry.RecordFailure("worker-1") func (r *PeerRegistry) RecordFailure(id string) { r.mu.Lock() peer, exists := r.peers[id] @@ -436,7 +463,7 @@ func (r *PeerRegistry) RecordFailure(id string) { peer.Score = max(peer.Score-ScoreFailureDecrement, ScoreMinimum) newScore := peer.Score r.mu.Unlock() - r.save() + r.scheduleSave() logging.Debug("peer score decreased", logging.Fields{ "peer_id": id, @@ -445,7 +472,7 @@ func (r *PeerRegistry) RecordFailure(id string) { }) } -// RecordTimeout records a timeout when communicating with a peer. +// registry.RecordTimeout("worker-1") func (r *PeerRegistry) RecordTimeout(id string) { r.mu.Lock() peer, exists := r.peers[id] @@ -457,7 +484,7 @@ func (r *PeerRegistry) RecordTimeout(id string) { peer.Score = max(peer.Score-ScoreTimeoutDecrement, ScoreMinimum) newScore := peer.Score r.mu.Unlock() - r.save() + r.scheduleSave() logging.Debug("peer score decreased", logging.Fields{ "peer_id": id, @@ -466,14 +493,13 @@ func (r *PeerRegistry) RecordTimeout(id string) { }) } -// GetPeersByScore returns peers sorted by score (highest first). 
+// peers := registry.GetPeersByScore() func (r *PeerRegistry) GetPeersByScore() []*Peer { r.mu.RLock() defer r.mu.RUnlock() peers := slices.Collect(maps.Values(r.peers)) - // Sort by score descending slices.SortFunc(peers, func(a, b *Peer) int { if b.Score > a.Score { return 1 @@ -484,10 +510,18 @@ func (r *PeerRegistry) GetPeersByScore() []*Peer { return 0 }) - return peers + peerCopies := make([]*Peer, 0, len(peers)) + for _, peer := range peers { + peerCopy := *peer + peerCopies = append(peerCopies, &peerCopy) + } + + return peerCopies } -// PeersByScore returns an iterator over peers sorted by score (highest first). +// for peer := range registry.PeersByScore() { +// log.Printf("peer %s score=%.0f", peer.ID, peer.Score) +// } func (r *PeerRegistry) PeersByScore() iter.Seq[*Peer] { return func(yield func(*Peer) bool) { peers := r.GetPeersByScore() @@ -499,8 +533,9 @@ func (r *PeerRegistry) PeersByScore() iter.Seq[*Peer] { } } -// SelectOptimalPeer returns the best peer based on multi-factor optimization. -// Uses Poindexter KD-tree to find the peer closest to ideal metrics. +// Uses Poindexter KD-tree to find the peer closest to ideal metrics (low ping, low hops, high score). +// +// peer := registry.SelectOptimalPeer() func (r *PeerRegistry) SelectOptimalPeer() *Peer { r.mu.RLock() defer r.mu.RUnlock() @@ -509,7 +544,7 @@ func (r *PeerRegistry) SelectOptimalPeer() *Peer { return nil } - // Target: ideal peer (0 ping, 0 hops, 0 geo, 100 score) + // Target: ideal peer (0 ping, 0 hops, 0 geographic distance, 100 score) // Score is inverted (100 - score) so lower is better in the tree target := []float64{0, 0, 0, 0} @@ -527,7 +562,7 @@ func (r *PeerRegistry) SelectOptimalPeer() *Peer { return &peerCopy } -// SelectNearestPeers returns the n best peers based on multi-factor optimization. 
+// peers := registry.SelectNearestPeers(3) func (r *PeerRegistry) SelectNearestPeers(n int) []*Peer { r.mu.RLock() defer r.mu.RUnlock() @@ -552,13 +587,16 @@ func (r *PeerRegistry) SelectNearestPeers(n int) []*Peer { return peers } -// GetConnectedPeers returns all currently connected peers. +// connectedPeers := registry.GetConnectedPeers() func (r *PeerRegistry) GetConnectedPeers() []*Peer { return slices.Collect(r.ConnectedPeers()) } -// ConnectedPeers returns an iterator over all currently connected peers. // Each peer is a copy to prevent mutation. +// +// for peer := range registry.ConnectedPeers() { +// _ = peer +// } func (r *PeerRegistry) ConnectedPeers() iter.Seq[*Peer] { return func(yield func(*Peer) bool) { r.mu.RLock() @@ -575,15 +613,13 @@ func (r *PeerRegistry) ConnectedPeers() iter.Seq[*Peer] { } } -// Count returns the number of registered peers. +// n := registry.Count() func (r *PeerRegistry) Count() int { r.mu.RLock() defer r.mu.RUnlock() return len(r.peers) } -// rebuildKDTree rebuilds the KD-tree from current peers. -// Must be called with lock held. func (r *PeerRegistry) rebuildKDTree() { if len(r.peers) == 0 { r.kdTree = nil @@ -592,14 +628,14 @@ func (r *PeerRegistry) rebuildKDTree() { points := make([]poindexter.KDPoint[string], 0, len(r.peers)) for _, peer := range r.peers { - // Build 4D point with weighted, normalized values + // Build a 4D point with weighted, normalised values. // Invert score so that higher score = lower value (better) point := poindexter.KDPoint[string]{ ID: peer.ID, Coords: []float64{ - peer.PingMS * pingWeight, + peer.PingMilliseconds * pingWeight, float64(peer.Hops) * hopsWeight, - peer.GeoKM * geoWeight, + peer.GeographicKilometres * geographicWeight, (100 - peer.Score) * scoreWeight, // Invert score }, Value: peer.ID, @@ -618,26 +654,26 @@ func (r *PeerRegistry) rebuildKDTree() { } // scheduleSave schedules a debounced save operation. 
-// Multiple calls within saveDebounceInterval will be coalesced into a single save. -// Must NOT be called with r.mu held. +// Multiple calls within peerRegistrySaveDebounceInterval will be coalesced into a single save. +// Call it after releasing r.mu so peer state and save state do not interleave. func (r *PeerRegistry) scheduleSave() { - r.saveMu.Lock() - defer r.saveMu.Unlock() + r.saveMutex.Lock() + defer r.saveMutex.Unlock() - r.dirty = true + r.hasPendingChanges = true // If timer already running, let it handle the save - if r.saveTimer != nil { + if r.pendingSaveTimer != nil { return } // Start a new timer - r.saveTimer = time.AfterFunc(saveDebounceInterval, func() { - r.saveMu.Lock() - r.saveTimer = nil - shouldSave := r.dirty - r.dirty = false - r.saveMu.Unlock() + r.pendingSaveTimer = time.AfterFunc(peerRegistrySaveDebounceInterval, func() { + r.saveMutex.Lock() + r.pendingSaveTimer = nil + shouldSave := r.hasPendingChanges + r.hasPendingChanges = false + r.saveMutex.Unlock() if shouldSave { r.mu.RLock() @@ -655,48 +691,45 @@ func (r *PeerRegistry) scheduleSave() { // Must be called with r.mu held (at least RLock). 
func (r *PeerRegistry) saveNow() error { // Ensure directory exists - dir := filepath.Dir(r.path) - if err := coreio.Local.EnsureDir(dir); err != nil { - return coreerr.E("PeerRegistry.saveNow", "failed to create peers directory", err) + dir := core.PathDir(r.path) + if err := filesystemEnsureDir(dir); err != nil { + return core.E("PeerRegistry.saveNow", "failed to create peers directory", err) } // Convert to slice for JSON peers := slices.Collect(maps.Values(r.peers)) - data, err := json.MarshalIndent(peers, "", " ") - if err != nil { - return coreerr.E("PeerRegistry.saveNow", "failed to marshal peers", err) + result := core.JSONMarshal(peers) + if !result.OK { + return core.E("PeerRegistry.saveNow", "failed to marshal peers", result.Value.(error)) } + data := result.Value.([]byte) // Use atomic write pattern: write to temp file, then rename tmpPath := r.path + ".tmp" - if err := coreio.Local.Write(tmpPath, string(data)); err != nil { - return coreerr.E("PeerRegistry.saveNow", "failed to write peers temp file", err) + if err := filesystemWrite(tmpPath, string(data)); err != nil { + return core.E("PeerRegistry.saveNow", "failed to write peers temp file", err) } - if err := coreio.Local.Rename(tmpPath, r.path); err != nil { - coreio.Local.Delete(tmpPath) // Clean up temp file - return coreerr.E("PeerRegistry.saveNow", "failed to rename peers file", err) + if err := filesystemRename(tmpPath, r.path); err != nil { + filesystemDelete(tmpPath) // Clean up temp file + return core.E("PeerRegistry.saveNow", "failed to rename peers file", err) } return nil } -// Close flushes any pending changes and releases resources. +// registry.Close() func (r *PeerRegistry) Close() error { - r.saveStopOnce.Do(func() { - close(r.stopChan) - }) - - // Cancel pending timer and save immediately if dirty - r.saveMu.Lock() - if r.saveTimer != nil { - r.saveTimer.Stop() - r.saveTimer = nil + // Cancel any pending timer and save immediately if changes are queued. 
+ r.saveMutex.Lock() + if r.pendingSaveTimer != nil { + r.pendingSaveTimer.Stop() + r.pendingSaveTimer = nil } - shouldSave := r.dirty - r.dirty = false - r.saveMu.Unlock() + shouldSave := r.hasPendingChanges + r.hasPendingChanges = false + r.saveMutex.Unlock() if shouldSave { r.mu.RLock() @@ -708,24 +741,16 @@ func (r *PeerRegistry) Close() error { return nil } -// save is a helper that schedules a debounced save. -// Kept for backward compatibility but now debounces writes. -// Must NOT be called with r.mu held. -func (r *PeerRegistry) save() error { - r.scheduleSave() - return nil // Errors will be logged asynchronously -} - -// load reads peers from disk. func (r *PeerRegistry) load() error { - content, err := coreio.Local.Read(r.path) + content, err := filesystemRead(r.path) if err != nil { - return coreerr.E("PeerRegistry.load", "failed to read peers", err) + return core.E("PeerRegistry.load", "failed to read peers", err) } var peers []*Peer - if err := json.Unmarshal([]byte(content), &peers); err != nil { - return coreerr.E("PeerRegistry.load", "failed to unmarshal peers", err) + result := core.JSONUnmarshalString(content, &peers) + if !result.OK { + return core.E("PeerRegistry.load", "failed to unmarshal peers", result.Value.(error)) } r.peers = make(map[string]*Peer) @@ -735,5 +760,3 @@ func (r *PeerRegistry) load() error { return nil } - -// Example usage inside a connection handler diff --git a/node/peer_test.go b/node/peer_test.go index 9653cbe..45d6f0f 100644 --- a/node/peer_test.go +++ b/node/peer_test.go @@ -1,35 +1,24 @@ package node import ( - "os" - "path/filepath" "slices" "testing" "time" ) func setupTestPeerRegistry(t *testing.T) (*PeerRegistry, func()) { - tmpDir, err := os.MkdirTemp("", "peer-registry-test") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) - } + tmpDir := t.TempDir() + peersPath := testJoinPath(tmpDir, "peers.json") - peersPath := filepath.Join(tmpDir, "peers.json") - - pr, err := 
NewPeerRegistryWithPath(peersPath) + pr, err := NewPeerRegistryFromPath(peersPath) if err != nil { - os.RemoveAll(tmpDir) t.Fatalf("failed to create peer registry: %v", err) } - cleanup := func() { - os.RemoveAll(tmpDir) - } - - return pr, cleanup + return pr, func() {} } -func TestPeerRegistry_NewPeerRegistry(t *testing.T) { +func TestPeer_Registry_NewPeerRegistry_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -38,7 +27,23 @@ func TestPeerRegistry_NewPeerRegistry(t *testing.T) { } } -func TestPeerRegistry_AddPeer(t *testing.T) { +func TestPeer_Registry_NewPeerRegistryFromPath_CorruptFile_Bad(t *testing.T) { + tmpDir := t.TempDir() + peersPath := testJoinPath(tmpDir, "peers.json") + + testWriteFile(t, peersPath, []byte(`{"id":"peer-1"`), 0o600) + + pr, err := NewPeerRegistryFromPath(peersPath) + if err == nil { + t.Fatal("expected error when loading a corrupted peer registry") + } + + if pr != nil { + t.Fatal("expected nil peer registry when persisted data is corrupted") + } +} + +func TestPeer_Registry_AddPeer_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -60,6 +65,15 @@ func TestPeerRegistry_AddPeer(t *testing.T) { t.Errorf("expected 1 peer, got %d", pr.Count()) } + peer.Name = "Mutated after add" + stored := pr.GetPeer("test-peer-1") + if stored == nil { + t.Fatal("expected peer to exist after add") + } + if stored.Name != "Test Peer" { + t.Errorf("expected stored peer to remain unchanged, got %q", stored.Name) + } + // Try to add duplicate err = pr.AddPeer(peer) if err == nil { @@ -67,7 +81,7 @@ func TestPeerRegistry_AddPeer(t *testing.T) { } } -func TestPeerRegistry_GetPeer(t *testing.T) { +func TestPeer_Registry_Peer_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -97,7 +111,7 @@ func TestPeerRegistry_GetPeer(t *testing.T) { } } -func TestPeerRegistry_ListPeers(t *testing.T) { +func TestPeer_Registry_ListPeers_Good(t *testing.T) { pr, cleanup := 
setupTestPeerRegistry(t) defer cleanup() @@ -117,7 +131,7 @@ func TestPeerRegistry_ListPeers(t *testing.T) { } } -func TestPeerRegistry_RemovePeer(t *testing.T) { +func TestPeer_Registry_RemovePeer_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -150,7 +164,7 @@ func TestPeerRegistry_RemovePeer(t *testing.T) { } } -func TestPeerRegistry_UpdateMetrics(t *testing.T) { +func TestPeer_Registry_UpdateMetrics_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -172,18 +186,18 @@ func TestPeerRegistry_UpdateMetrics(t *testing.T) { if updated == nil { t.Fatal("expected peer to exist") } - if updated.PingMS != 50.5 { - t.Errorf("expected ping 50.5, got %f", updated.PingMS) + if updated.PingMilliseconds != 50.5 { + t.Errorf("expected ping 50.5, got %f", updated.PingMilliseconds) } - if updated.GeoKM != 100.2 { - t.Errorf("expected geo 100.2, got %f", updated.GeoKM) + if updated.GeographicKilometres != 100.2 { + t.Errorf("expected geographic distance 100.2, got %f", updated.GeographicKilometres) } if updated.Hops != 3 { t.Errorf("expected hops 3, got %d", updated.Hops) } } -func TestPeerRegistry_UpdateScore(t *testing.T) { +func TestPeer_Registry_UpdateScore_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -237,7 +251,7 @@ func TestPeerRegistry_UpdateScore(t *testing.T) { } } -func TestPeerRegistry_SetConnected(t *testing.T) { +func TestPeer_Registry_MarkConnected_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -272,7 +286,7 @@ func TestPeerRegistry_SetConnected(t *testing.T) { } } -func TestPeerRegistry_GetConnectedPeers(t *testing.T) { +func TestPeer_Registry_ConnectedPeerList_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -295,15 +309,15 @@ func TestPeerRegistry_GetConnectedPeers(t *testing.T) { } } -func TestPeerRegistry_SelectOptimalPeer(t *testing.T) { +func TestPeer_Registry_SelectOptimalPeer_Good(t *testing.T) { 
pr, cleanup := setupTestPeerRegistry(t) defer cleanup() // Add peers with different metrics peers := []*Peer{ - {ID: "opt-1", Name: "Slow Peer", PingMS: 200, Hops: 5, GeoKM: 1000, Score: 50}, - {ID: "opt-2", Name: "Fast Peer", PingMS: 10, Hops: 1, GeoKM: 50, Score: 90}, - {ID: "opt-3", Name: "Medium Peer", PingMS: 50, Hops: 2, GeoKM: 200, Score: 70}, + {ID: "opt-1", Name: "Slow Peer", PingMilliseconds: 200, Hops: 5, GeographicKilometres: 1000, Score: 50}, + {ID: "opt-2", Name: "Fast Peer", PingMilliseconds: 10, Hops: 1, GeographicKilometres: 50, Score: 90}, + {ID: "opt-3", Name: "Medium Peer", PingMilliseconds: 50, Hops: 2, GeographicKilometres: 200, Score: 70}, } for _, p := range peers { @@ -321,15 +335,15 @@ func TestPeerRegistry_SelectOptimalPeer(t *testing.T) { } } -func TestPeerRegistry_SelectNearestPeers(t *testing.T) { +func TestPeer_Registry_SelectNearestPeers_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() peers := []*Peer{ - {ID: "near-1", Name: "Peer 1", PingMS: 100, Score: 50}, - {ID: "near-2", Name: "Peer 2", PingMS: 10, Score: 90}, - {ID: "near-3", Name: "Peer 3", PingMS: 50, Score: 70}, - {ID: "near-4", Name: "Peer 4", PingMS: 200, Score: 30}, + {ID: "near-1", Name: "Peer 1", PingMilliseconds: 100, Score: 50}, + {ID: "near-2", Name: "Peer 2", PingMilliseconds: 10, Score: 90}, + {ID: "near-3", Name: "Peer 3", PingMilliseconds: 50, Score: 70}, + {ID: "near-4", Name: "Peer 4", PingMilliseconds: 200, Score: 30}, } for _, p := range peers { @@ -342,14 +356,12 @@ func TestPeerRegistry_SelectNearestPeers(t *testing.T) { } } -func TestPeerRegistry_Persistence(t *testing.T) { - tmpDir, _ := os.MkdirTemp("", "persist-test") - defer os.RemoveAll(tmpDir) - - peersPath := filepath.Join(tmpDir, "peers.json") +func TestPeer_Registry_Persistence_Good(t *testing.T) { + tmpDir := t.TempDir() + peersPath := testJoinPath(tmpDir, "peers.json") // Create and save - pr1, err := NewPeerRegistryWithPath(peersPath) + pr1, err := 
NewPeerRegistryFromPath(peersPath) if err != nil { t.Fatalf("failed to create first registry: %v", err) } @@ -370,7 +382,7 @@ func TestPeerRegistry_Persistence(t *testing.T) { } // Load in new registry from same path - pr2, err := NewPeerRegistryWithPath(peersPath) + pr2, err := NewPeerRegistryFromPath(peersPath) if err != nil { t.Fatalf("failed to create second registry: %v", err) } @@ -391,7 +403,7 @@ func TestPeerRegistry_Persistence(t *testing.T) { // --- Security Feature Tests --- -func TestPeerRegistry_AuthMode(t *testing.T) { +func TestPeer_Registry_AuthMode_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -413,7 +425,7 @@ func TestPeerRegistry_AuthMode(t *testing.T) { } } -func TestPeerRegistry_PublicKeyAllowlist(t *testing.T) { +func TestPeer_Registry_PublicKeyAllowlist_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -450,7 +462,7 @@ func TestPeerRegistry_PublicKeyAllowlist(t *testing.T) { } } -func TestPeerRegistry_IsPeerAllowed_OpenMode(t *testing.T) { +func TestPeer_Registry_IsPeerAllowed_OpenMode_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -466,7 +478,7 @@ func TestPeerRegistry_IsPeerAllowed_OpenMode(t *testing.T) { } } -func TestPeerRegistry_IsPeerAllowed_AllowlistMode(t *testing.T) { +func TestPeer_Registry_IsPeerAllowed_AllowlistMode_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -501,7 +513,7 @@ func TestPeerRegistry_IsPeerAllowed_AllowlistMode(t *testing.T) { } } -func TestPeerRegistry_PeerNameValidation(t *testing.T) { +func TestPeer_Registry_PeerNameValidation_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -545,7 +557,7 @@ func TestPeerRegistry_PeerNameValidation(t *testing.T) { } } -func TestPeerRegistry_ScoreRecording(t *testing.T) { +func TestPeer_Registry_ScoreRecording_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -601,7 +613,7 @@ func 
TestPeerRegistry_ScoreRecording(t *testing.T) { } } -func TestPeerRegistry_GetPeersByScore(t *testing.T) { +func TestPeer_Registry_PeersSortedByScore_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -631,11 +643,39 @@ func TestPeerRegistry_GetPeersByScore(t *testing.T) { if sorted[2].ID != "low-score" { t.Errorf("third peer should be low-score, got %s", sorted[2].ID) } + + sorted[0].Name = "Mutated" + restored := pr.GetPeer("high-score") + if restored == nil { + t.Fatal("expected high-score peer to still exist") + } + if restored.Name != "High" { + t.Errorf("expected registry peer to remain unchanged, got %q", restored.Name) + } +} + +func TestPeer_Registry_NilPeerInputs_Bad(t *testing.T) { + pr, cleanup := setupTestPeerRegistry(t) + defer cleanup() + + t.Run("AddPeer", func(t *testing.T) { + err := pr.AddPeer(nil) + if err == nil { + t.Fatal("expected error when adding nil peer") + } + }) + + t.Run("UpdatePeer", func(t *testing.T) { + err := pr.UpdatePeer(nil) + if err == nil { + t.Fatal("expected error when updating nil peer") + } + }) } // --- Additional coverage tests for peer.go --- -func TestSafeKeyPrefix(t *testing.T) { +func TestPeer_SafeKeyPrefix_Good(t *testing.T) { tests := []struct { name string key string @@ -658,7 +698,7 @@ func TestSafeKeyPrefix(t *testing.T) { } } -func TestValidatePeerName(t *testing.T) { +func TestPeer_ValidatePeerName_Good(t *testing.T) { tests := []struct { name string peerName string @@ -691,7 +731,7 @@ func TestValidatePeerName(t *testing.T) { } } -func TestPeerRegistry_AddPeer_EmptyID(t *testing.T) { +func TestPeer_Registry_AddPeer_EmptyID_Bad(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -702,7 +742,7 @@ func TestPeerRegistry_AddPeer_EmptyID(t *testing.T) { } } -func TestPeerRegistry_UpdatePeer(t *testing.T) { +func TestPeer_Registry_UpdatePeer_Good(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -733,9 +773,22 @@ func 
TestPeerRegistry_UpdatePeer(t *testing.T) { if updated.Score != 80 { t.Errorf("expected score 80, got %f", updated.Score) } + + peer.Name = "Mutated after update" + peer.Score = 12 + stored := pr.GetPeer("update-test") + if stored == nil { + t.Fatal("expected peer to exist after update mutation") + } + if stored.Name != "Updated" { + t.Errorf("expected stored peer name to remain Updated, got %q", stored.Name) + } + if stored.Score != 80 { + t.Errorf("expected stored peer score to remain 80, got %f", stored.Score) + } } -func TestPeerRegistry_UpdateMetrics_NotFound(t *testing.T) { +func TestPeer_Registry_UpdateMetrics_NotFound_Bad(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -745,7 +798,7 @@ func TestPeerRegistry_UpdateMetrics_NotFound(t *testing.T) { } } -func TestPeerRegistry_UpdateScore_NotFound(t *testing.T) { +func TestPeer_Registry_UpdateScore_NotFound_Bad(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -755,7 +808,7 @@ func TestPeerRegistry_UpdateScore_NotFound(t *testing.T) { } } -func TestPeerRegistry_RecordSuccess_NotFound(t *testing.T) { +func TestPeer_Registry_RecordSuccess_NotFound_Bad(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -763,21 +816,21 @@ func TestPeerRegistry_RecordSuccess_NotFound(t *testing.T) { pr.RecordSuccess("ghost-peer") } -func TestPeerRegistry_RecordFailure_NotFound(t *testing.T) { +func TestPeer_Registry_RecordFailure_NotFound_Bad(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() pr.RecordFailure("ghost-peer") } -func TestPeerRegistry_RecordTimeout_NotFound(t *testing.T) { +func TestPeer_Registry_RecordTimeout_NotFound_Bad(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() pr.RecordTimeout("ghost-peer") } -func TestPeerRegistry_SelectOptimalPeer_EmptyRegistry(t *testing.T) { +func TestPeer_Registry_SelectOptimalPeer_EmptyRegistry_Ugly(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() 
@@ -787,7 +840,7 @@ func TestPeerRegistry_SelectOptimalPeer_EmptyRegistry(t *testing.T) { } } -func TestPeerRegistry_SelectNearestPeers_EmptyRegistry(t *testing.T) { +func TestPeer_Registry_SelectNearestPeers_EmptyRegistry_Ugly(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -797,7 +850,7 @@ func TestPeerRegistry_SelectNearestPeers_EmptyRegistry(t *testing.T) { } } -func TestPeerRegistry_SetConnected_NonExistent(t *testing.T) { +func TestPeer_Registry_MarkConnected_NonExistent_Bad(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -805,7 +858,7 @@ func TestPeerRegistry_SetConnected_NonExistent(t *testing.T) { pr.SetConnected("ghost-peer", true) } -func TestPeerRegistry_Close_NoDirtyData(t *testing.T) { +func TestPeer_Registry_Close_NoDirtyData_Ugly(t *testing.T) { pr, cleanup := setupTestPeerRegistry(t) defer cleanup() @@ -816,12 +869,10 @@ func TestPeerRegistry_Close_NoDirtyData(t *testing.T) { } } -func TestPeerRegistry_Close_WithDirtyData(t *testing.T) { - tmpDir, _ := os.MkdirTemp("", "close-dirty-test") - defer os.RemoveAll(tmpDir) - - peersPath := filepath.Join(tmpDir, "peers.json") - pr, err := NewPeerRegistryWithPath(peersPath) +func TestPeer_Registry_Close_WithDirtyData_Ugly(t *testing.T) { + tmpDir := t.TempDir() + peersPath := testJoinPath(tmpDir, "peers.json") + pr, err := NewPeerRegistryFromPath(peersPath) if err != nil { t.Fatalf("failed to create registry: %v", err) } @@ -836,7 +887,7 @@ func TestPeerRegistry_Close_WithDirtyData(t *testing.T) { } // Verify data was saved - pr2, err := NewPeerRegistryWithPath(peersPath) + pr2, err := NewPeerRegistryFromPath(peersPath) if err != nil { t.Fatalf("failed to reload: %v", err) } @@ -845,12 +896,10 @@ func TestPeerRegistry_Close_WithDirtyData(t *testing.T) { } } -func TestPeerRegistry_ScheduleSave_Debounce(t *testing.T) { - tmpDir, _ := os.MkdirTemp("", "debounce-test") - defer os.RemoveAll(tmpDir) - - peersPath := filepath.Join(tmpDir, "peers.json") - pr, 
err := NewPeerRegistryWithPath(peersPath) +func TestPeer_Registry_ScheduleSave_Debounce_Ugly(t *testing.T) { + tmpDir := t.TempDir() + peersPath := testJoinPath(tmpDir, "peers.json") + pr, err := NewPeerRegistryFromPath(peersPath) if err != nil { t.Fatalf("failed to create registry: %v", err) } @@ -867,12 +916,10 @@ func TestPeerRegistry_ScheduleSave_Debounce(t *testing.T) { } } -func TestPeerRegistry_SaveNow(t *testing.T) { - tmpDir, _ := os.MkdirTemp("", "savenow-test") - defer os.RemoveAll(tmpDir) - - peersPath := filepath.Join(tmpDir, "subdir", "peers.json") - pr, err := NewPeerRegistryWithPath(peersPath) +func TestPeer_Registry_SaveNow_Good(t *testing.T) { + tmpDir := t.TempDir() + peersPath := testJoinPath(tmpDir, "subdir", "peers.json") + pr, err := NewPeerRegistryFromPath(peersPath) if err != nil { t.Fatalf("failed to create registry: %v", err) } @@ -888,21 +935,19 @@ func TestPeerRegistry_SaveNow(t *testing.T) { } // Verify the file was written - if _, err := os.Stat(peersPath); os.IsNotExist(err) { + if !filesystemExists(peersPath) { t.Error("peers.json should exist after saveNow") } } -func TestPeerRegistry_ScheduleSave_TimerFires(t *testing.T) { +func TestPeer_Registry_ScheduleSave_TimerFires_Ugly(t *testing.T) { if testing.Short() { t.Skip("skipping debounce timer test in short mode") } - tmpDir, _ := os.MkdirTemp("", "timer-fire-test") - defer os.RemoveAll(tmpDir) - - peersPath := filepath.Join(tmpDir, "peers.json") - pr, err := NewPeerRegistryWithPath(peersPath) + tmpDir := t.TempDir() + peersPath := testJoinPath(tmpDir, "peers.json") + pr, err := NewPeerRegistryFromPath(peersPath) if err != nil { t.Fatalf("failed to create registry: %v", err) } @@ -913,12 +958,12 @@ func TestPeerRegistry_ScheduleSave_TimerFires(t *testing.T) { time.Sleep(6 * time.Second) // The file should have been saved by the timer - if _, err := os.Stat(peersPath); os.IsNotExist(err) { + if !filesystemExists(peersPath) { t.Error("peers.json should exist after debounce timer 
fires") } // Reload and verify - pr2, err := NewPeerRegistryWithPath(peersPath) + pr2, err := NewPeerRegistryFromPath(peersPath) if err != nil { t.Fatalf("failed to reload: %v", err) } diff --git a/node/protocol.go b/node/protocol.go index 80ca346..0565e76 100644 --- a/node/protocol.go +++ b/node/protocol.go @@ -1,53 +1,46 @@ package node import ( - "fmt" - - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" ) -// ProtocolError represents an error from the remote peer. +// err := &ProtocolError{Code: ErrorCodeOperationFailed, Message: "start failed"} type ProtocolError struct { Code int Message string } func (e *ProtocolError) Error() string { - return fmt.Sprintf("remote error (%d): %s", e.Code, e.Message) + return core.Sprintf("remote error (%d): %s", e.Code, e.Message) } -// ResponseHandler provides helpers for handling protocol responses. +// handler := &ResponseHandler{} type ResponseHandler struct{} -// ValidateResponse checks if the response is valid and returns a parsed error if it's an error response. -// It checks: -// 1. If response is nil (returns error) -// 2. If response is an error message (returns ProtocolError) -// 3. 
If response type matches expected (returns error if not) +// err := handler.ValidateResponse(resp, MessageStats) func (h *ResponseHandler) ValidateResponse(resp *Message, expectedType MessageType) error { if resp == nil { - return coreerr.E("ResponseHandler.ValidateResponse", "nil response", nil) + return core.E("ResponseHandler.ValidateResponse", "nil response", nil) } // Check for error response - if resp.Type == MsgError { + if resp.Type == MessageError { var errPayload ErrorPayload if err := resp.ParsePayload(&errPayload); err != nil { - return &ProtocolError{Code: ErrCodeUnknown, Message: "unable to parse error response"} + return &ProtocolError{Code: ErrorCodeUnknown, Message: "unable to parse error response"} } return &ProtocolError{Code: errPayload.Code, Message: errPayload.Message} } // Check expected type if resp.Type != expectedType { - return coreerr.E("ResponseHandler.ValidateResponse", "unexpected response type: expected "+string(expectedType)+", got "+string(resp.Type), nil) + return core.E("ResponseHandler.ValidateResponse", "unexpected response type: expected "+string(expectedType)+", got "+string(resp.Type), nil) } return nil } -// ParseResponse validates the response and parses the payload into the target. -// This combines ValidateResponse and ParsePayload into a single call. 
+// err := handler.ParseResponse(resp, MessageStats, &stats) func (h *ResponseHandler) ParseResponse(resp *Message, expectedType MessageType, target any) error { if err := h.ValidateResponse(resp, expectedType); err != nil { return err @@ -55,33 +48,33 @@ func (h *ResponseHandler) ParseResponse(resp *Message, expectedType MessageType, if target != nil { if err := resp.ParsePayload(target); err != nil { - return coreerr.E("ResponseHandler.ParseResponse", "failed to parse "+string(expectedType)+" payload", err) + return core.E("ResponseHandler.ParseResponse", "failed to parse "+string(expectedType)+" payload", err) } } return nil } -// DefaultResponseHandler is the default response handler instance. +// handler := DefaultResponseHandler var DefaultResponseHandler = &ResponseHandler{} -// ValidateResponse is a convenience function using the default handler. +// err := ValidateResponse(message, MessageStats) func ValidateResponse(resp *Message, expectedType MessageType) error { return DefaultResponseHandler.ValidateResponse(resp, expectedType) } -// ParseResponse is a convenience function using the default handler. +// err := ParseResponse(message, MessageStats, &stats) func ParseResponse(resp *Message, expectedType MessageType, target any) error { return DefaultResponseHandler.ParseResponse(resp, expectedType, target) } -// IsProtocolError returns true if the error is a ProtocolError. +// ok := IsProtocolError(err) func IsProtocolError(err error) bool { _, ok := err.(*ProtocolError) return ok } -// GetProtocolErrorCode returns the error code if err is a ProtocolError, otherwise returns 0. 
+// code := GetProtocolErrorCode(err) func GetProtocolErrorCode(err error) int { if pe, ok := err.(*ProtocolError); ok { return pe.Code diff --git a/node/protocol_test.go b/node/protocol_test.go index 1d728a4..5598c55 100644 --- a/node/protocol_test.go +++ b/node/protocol_test.go @@ -1,23 +1,24 @@ package node import ( - "fmt" "testing" + + core "dappco.re/go/core" ) -func TestResponseHandler_ValidateResponse(t *testing.T) { +func TestProtocol_ResponseHandler_ValidateResponse_Good(t *testing.T) { handler := &ResponseHandler{} t.Run("NilResponse", func(t *testing.T) { - err := handler.ValidateResponse(nil, MsgStats) + err := handler.ValidateResponse(nil, MessageStats) if err == nil { t.Error("Expected error for nil response") } }) t.Run("ErrorResponse", func(t *testing.T) { - errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeOperationFailed, "operation failed", "") - err := handler.ValidateResponse(errMsg, MsgStats) + errMsg, _ := NewErrorMessage("sender", "receiver", ErrorCodeOperationFailed, "operation failed", "") + err := handler.ValidateResponse(errMsg, MessageStats) if err == nil { t.Fatal("Expected error for error response") } @@ -26,14 +27,14 @@ func TestResponseHandler_ValidateResponse(t *testing.T) { t.Errorf("Expected ProtocolError, got %T", err) } - if GetProtocolErrorCode(err) != ErrCodeOperationFailed { - t.Errorf("Expected code %d, got %d", ErrCodeOperationFailed, GetProtocolErrorCode(err)) + if GetProtocolErrorCode(err) != ErrorCodeOperationFailed { + t.Errorf("Expected code %d, got %d", ErrorCodeOperationFailed, GetProtocolErrorCode(err)) } }) t.Run("WrongType", func(t *testing.T) { - msg, _ := NewMessage(MsgPong, "sender", "receiver", nil) - err := handler.ValidateResponse(msg, MsgStats) + msg, _ := NewMessage(MessagePong, "sender", "receiver", nil) + err := handler.ValidateResponse(msg, MessageStats) if err == nil { t.Error("Expected error for wrong type") } @@ -43,15 +44,15 @@ func TestResponseHandler_ValidateResponse(t *testing.T) { }) 
t.Run("ValidResponse", func(t *testing.T) { - msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"}) - err := handler.ValidateResponse(msg, MsgStats) + msg, _ := NewMessage(MessageStats, "sender", "receiver", StatsPayload{NodeID: "test"}) + err := handler.ValidateResponse(msg, MessageStats) if err != nil { t.Errorf("Unexpected error: %v", err) } }) } -func TestResponseHandler_ParseResponse(t *testing.T) { +func TestProtocol_ResponseHandler_ParseResponse_Good(t *testing.T) { handler := &ResponseHandler{} t.Run("ParseStats", func(t *testing.T) { @@ -60,10 +61,10 @@ func TestResponseHandler_ParseResponse(t *testing.T) { NodeName: "Test Node", Uptime: 3600, } - msg, _ := NewMessage(MsgStats, "sender", "receiver", payload) + msg, _ := NewMessage(MessageStats, "sender", "receiver", payload) var parsed StatsPayload - err := handler.ParseResponse(msg, MsgStats, &parsed) + err := handler.ParseResponse(msg, MessageStats, &parsed) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -81,10 +82,10 @@ func TestResponseHandler_ParseResponse(t *testing.T) { Success: true, MinerName: "xmrig-1", } - msg, _ := NewMessage(MsgMinerAck, "sender", "receiver", payload) + msg, _ := NewMessage(MessageMinerAck, "sender", "receiver", payload) var parsed MinerAckPayload - err := handler.ParseResponse(msg, MsgMinerAck, &parsed) + err := handler.ParseResponse(msg, MessageMinerAck, &parsed) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -98,10 +99,10 @@ func TestResponseHandler_ParseResponse(t *testing.T) { }) t.Run("ErrorResponse", func(t *testing.T) { - errMsg, _ := NewErrorMessage("sender", "receiver", ErrCodeNotFound, "not found", "") + errMsg, _ := NewErrorMessage("sender", "receiver", ErrorCodeNotFound, "not found", "") var parsed StatsPayload - err := handler.ParseResponse(errMsg, MsgStats, &parsed) + err := handler.ParseResponse(errMsg, MessageStats, &parsed) if err == nil { t.Error("Expected error for error response") } @@ -111,15 +112,15 @@ 
func TestResponseHandler_ParseResponse(t *testing.T) { }) t.Run("NilTarget", func(t *testing.T) { - msg, _ := NewMessage(MsgPong, "sender", "receiver", nil) - err := handler.ParseResponse(msg, MsgPong, nil) + msg, _ := NewMessage(MessagePong, "sender", "receiver", nil) + err := handler.ParseResponse(msg, MessagePong, nil) if err != nil { t.Errorf("Unexpected error with nil target: %v", err) } }) } -func TestProtocolError(t *testing.T) { +func TestProtocol_Error_Bad(t *testing.T) { err := &ProtocolError{Code: 1001, Message: "test error"} if err.Error() != "remote error (1001): test error" { @@ -135,17 +136,17 @@ func TestProtocolError(t *testing.T) { } } -func TestConvenienceFunctions(t *testing.T) { - msg, _ := NewMessage(MsgStats, "sender", "receiver", StatsPayload{NodeID: "test"}) +func TestProtocol_ConvenienceFunctions_Good(t *testing.T) { + msg, _ := NewMessage(MessageStats, "sender", "receiver", StatsPayload{NodeID: "test"}) // Test ValidateResponse - if err := ValidateResponse(msg, MsgStats); err != nil { + if err := ValidateResponse(msg, MessageStats); err != nil { t.Errorf("ValidateResponse failed: %v", err) } // Test ParseResponse var parsed StatsPayload - if err := ParseResponse(msg, MsgStats, &parsed); err != nil { + if err := ParseResponse(msg, MessageStats, &parsed); err != nil { t.Errorf("ParseResponse failed: %v", err) } if parsed.NodeID != "test" { @@ -153,8 +154,8 @@ func TestConvenienceFunctions(t *testing.T) { } } -func TestGetProtocolErrorCode_NonProtocolError(t *testing.T) { - err := fmt.Errorf("regular error") +func TestProtocol_ProtocolErrorCode_NonProtocolError_Bad(t *testing.T) { + err := core.NewError("regular error") if GetProtocolErrorCode(err) != 0 { t.Error("Expected 0 for non-ProtocolError") } diff --git a/node/transport.go b/node/transport.go index e30c0e5..08eee04 100644 --- a/node/transport.go +++ b/node/transport.go @@ -4,183 +4,247 @@ import ( "context" "crypto/tls" "encoding/base64" - "encoding/json" - "fmt" "iter" "maps" 
"net/http" "net/url" "slices" + "strings" "sync" "sync/atomic" "time" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" "dappco.re/go/core/p2p/logging" "forge.lthn.ai/Snider/Borg/pkg/smsg" "github.com/gorilla/websocket" ) -// debugLogCounter tracks message counts for rate limiting debug logs -var debugLogCounter atomic.Int64 +var messageLogSampleCounter atomic.Int64 -// debugLogInterval controls how often we log debug messages in hot paths (1 in N) -const debugLogInterval = 100 +// messageLogSampleInterval controls how often we log debug messages in hot paths (1 in N). +const messageLogSampleInterval = 100 -// DefaultMaxMessageSize is the default maximum message size (1MB) +// limit := DefaultMaxMessageSize const DefaultMaxMessageSize int64 = 1 << 20 // 1MB -// TransportConfig configures the WebSocket transport. +// prefix := agentUserAgentPrefix +const agentUserAgentPrefix = "agent-go-p2p" + +const ( + defaultTransportListenAddress = ":9091" + defaultTransportWebSocketPath = "/ws" + defaultTransportMaximumConnections = 100 +) + +// transportConfig := DefaultTransportConfig() type TransportConfig struct { - ListenAddr string // ":9091" default - WSPath string // "/ws" - WebSocket endpoint path - TLSCertPath string // Optional TLS for wss:// - TLSKeyPath string - MaxConns int // Maximum concurrent connections - MaxMessageSize int64 // Maximum message size in bytes (0 = 1MB default) - PingInterval time.Duration // WebSocket keepalive interval - PongTimeout time.Duration // Timeout waiting for pong + ListenAddress string // config.ListenAddress = ":9091" + ListenAddr string + WebSocketPath string // config.WebSocketPath = "/ws" + TLSCertPath string // config.TLSCertPath = "/srv/p2p/tls.crt" + TLSKeyPath string // config.TLSKeyPath = "/srv/p2p/tls.key" + MaxConnections int // config.MaxConnections = 100 + MaxMessageSize int64 // config.MaxMessageSize = 1 << 20 + PingInterval time.Duration // config.PingInterval = 30 * time.Second + PongTimeout time.Duration 
// config.PongTimeout = 10 * time.Second } -// DefaultTransportConfig returns sensible defaults. +// transportConfig := DefaultTransportConfig() func DefaultTransportConfig() TransportConfig { return TransportConfig{ - ListenAddr: ":9091", - WSPath: "/ws", - MaxConns: 100, + ListenAddress: defaultTransportListenAddress, + ListenAddr: defaultTransportListenAddress, + WebSocketPath: defaultTransportWebSocketPath, + MaxConnections: defaultTransportMaximumConnections, MaxMessageSize: DefaultMaxMessageSize, PingInterval: 30 * time.Second, PongTimeout: 10 * time.Second, } } -// MessageHandler processes incoming messages. -type MessageHandler func(conn *PeerConnection, msg *Message) - -// MessageDeduplicator tracks seen message IDs to prevent duplicate processing -type MessageDeduplicator struct { - seen map[string]time.Time - mu sync.RWMutex - ttl time.Duration +func (c TransportConfig) listenAddress() string { + if c.ListenAddress != "" && c.ListenAddress != defaultTransportListenAddress { + return c.ListenAddress + } + if c.ListenAddr != "" && c.ListenAddr != defaultTransportListenAddress { + return c.ListenAddr + } + if c.ListenAddress != "" { + return c.ListenAddress + } + if c.ListenAddr != "" { + return c.ListenAddr + } + return defaultTransportListenAddress } -// NewMessageDeduplicator creates a deduplicator with specified TTL -func NewMessageDeduplicator(ttl time.Duration) *MessageDeduplicator { +func (c TransportConfig) webSocketPath() string { + if c.WebSocketPath != "" { + return c.WebSocketPath + } + return defaultTransportWebSocketPath +} + +func (c TransportConfig) maximumConnections() int { + if c.MaxConnections > 0 { + return c.MaxConnections + } + return defaultTransportMaximumConnections +} + +// var handler MessageHandler = func(peerConnection *PeerConnection, message *Message) {} +type MessageHandler func(peerConnection *PeerConnection, message *Message) + +// deduplicator := NewMessageDeduplicator(5 * time.Minute) +type MessageDeduplicator struct { + 
recentMessageTimes map[string]time.Time + mutex sync.RWMutex + timeToLive time.Duration +} + +// deduplicator := NewMessageDeduplicator(5 * time.Minute) +func NewMessageDeduplicator(retentionWindow time.Duration) *MessageDeduplicator { d := &MessageDeduplicator{ - seen: make(map[string]time.Time), - ttl: ttl, + recentMessageTimes: make(map[string]time.Time), + timeToLive: retentionWindow, } return d } -// IsDuplicate checks if a message ID has been seen recently +// duplicate := deduplicator.IsDuplicate(message.ID) func (d *MessageDeduplicator) IsDuplicate(msgID string) bool { - d.mu.RLock() - _, exists := d.seen[msgID] - d.mu.RUnlock() - return exists + d.mutex.RLock() + seenAt, exists := d.recentMessageTimes[msgID] + retentionWindow := d.timeToLive + d.mutex.RUnlock() + + if !exists { + return false + } + + if retentionWindow > 0 && time.Since(seenAt) <= retentionWindow { + return true + } + + d.mutex.Lock() + defer d.mutex.Unlock() + + seenAt, exists = d.recentMessageTimes[msgID] + if !exists { + return false + } + + if retentionWindow <= 0 || time.Since(seenAt) > retentionWindow { + delete(d.recentMessageTimes, msgID) + return false + } + + return true } -// Mark records a message ID as seen +// deduplicator.Mark(message.ID) func (d *MessageDeduplicator) Mark(msgID string) { - d.mu.Lock() - d.seen[msgID] = time.Now() - d.mu.Unlock() + d.mutex.Lock() + d.recentMessageTimes[msgID] = time.Now() + d.mutex.Unlock() } -// Cleanup removes expired entries +// deduplicator.Cleanup() func (d *MessageDeduplicator) Cleanup() { - d.mu.Lock() - defer d.mu.Unlock() + d.mutex.Lock() + defer d.mutex.Unlock() now := time.Now() - for id, seen := range d.seen { - if now.Sub(seen) > d.ttl { - delete(d.seen, id) + for id, seenAt := range d.recentMessageTimes { + if now.Sub(seenAt) > d.timeToLive { + delete(d.recentMessageTimes, id) } } } -// Transport manages WebSocket connections with SMSG encryption. 
+// transport := NewTransport(nodeManager, peerRegistry, DefaultTransportConfig()) type Transport struct { - config TransportConfig - server *http.Server - upgrader websocket.Upgrader - conns map[string]*PeerConnection // peer ID -> connection - pendingConns atomic.Int32 // tracks connections during handshake - node *NodeManager - registry *PeerRegistry - handler MessageHandler - dedup *MessageDeduplicator // Message deduplication - mu sync.RWMutex - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup + config TransportConfig + httpServer *http.Server + upgrader websocket.Upgrader + connections map[string]*PeerConnection // peer ID -> connection + pendingHandshakeCount atomic.Int32 // tracks connections during handshake + nodeManager *NodeManager + peerRegistry *PeerRegistry + messageHandler MessageHandler + messageDeduplicator *MessageDeduplicator // Message deduplication + mutex sync.RWMutex + lifecycleContext context.Context + cancelLifecycle context.CancelFunc + waitGroup sync.WaitGroup } -// PeerRateLimiter implements a simple token bucket rate limiter per peer +// rateLimiter := NewPeerRateLimiter(100, 50) type PeerRateLimiter struct { - tokens int - maxTokens int - refillRate int // tokens per second - lastRefill time.Time - mu sync.Mutex + availableTokens int + capacity int + refillPerSecond int // tokens per second + lastRefillTime time.Time + mutex sync.Mutex } -// NewPeerRateLimiter creates a rate limiter with specified messages/second -func NewPeerRateLimiter(maxTokens, refillRate int) *PeerRateLimiter { +// rateLimiter := NewPeerRateLimiter(100, 50) +func NewPeerRateLimiter(maxTokens, refillPerSecond int) *PeerRateLimiter { return &PeerRateLimiter{ - tokens: maxTokens, - maxTokens: maxTokens, - refillRate: refillRate, - lastRefill: time.Now(), + availableTokens: maxTokens, + capacity: maxTokens, + refillPerSecond: refillPerSecond, + lastRefillTime: time.Now(), } } -// Allow checks if a message is allowed and consumes a token if so +// 
allowed := rateLimiter.Allow() func (r *PeerRateLimiter) Allow() bool { - r.mu.Lock() - defer r.mu.Unlock() + r.mutex.Lock() + defer r.mutex.Unlock() // Refill tokens based on elapsed time now := time.Now() - elapsed := now.Sub(r.lastRefill) - tokensToAdd := int(elapsed.Seconds()) * r.refillRate + elapsed := now.Sub(r.lastRefillTime) + tokensToAdd := int(elapsed.Seconds()) * r.refillPerSecond if tokensToAdd > 0 { - r.tokens = min(r.tokens+tokensToAdd, r.maxTokens) - r.lastRefill = now + r.availableTokens = min(r.availableTokens+tokensToAdd, r.capacity) + r.lastRefillTime = now } // Check if we have tokens available - if r.tokens > 0 { - r.tokens-- + if r.availableTokens > 0 { + r.availableTokens-- return true } return false } -// PeerConnection represents an active connection to a peer. +// peerConnection := &PeerConnection{Peer: &Peer{ID: "worker-1"}} type PeerConnection struct { - Peer *Peer - Conn *websocket.Conn - SharedSecret []byte // Derived via X25519 ECDH, used for SMSG - LastActivity time.Time - writeMu sync.Mutex // Serialize WebSocket writes - transport *Transport - closeOnce sync.Once // Ensure Close() is only called once - rateLimiter *PeerRateLimiter // Per-peer message rate limiting + Peer *Peer + WebSocketConnection *websocket.Conn + Conn *websocket.Conn + SharedSecret []byte // Derived via X25519 ECDH, used for SMSG + LastActivity time.Time + UserAgent string // Request identity advertised by the peer + writeMutex sync.Mutex // Serialize WebSocket writes + transport *Transport + closeOnce sync.Once // Ensure Close() is only called once + rateLimiter *PeerRateLimiter // Per-peer message rate limiting } -// NewTransport creates a new WebSocket transport. 
-func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport { - ctx, cancel := context.WithCancel(context.Background()) +// transport := NewTransport(nodeManager, peerRegistry, DefaultTransportConfig()) +func NewTransport(nodeManager *NodeManager, peerRegistry *PeerRegistry, config TransportConfig) *Transport { + lifecycleContext, cancelLifecycle := context.WithCancel(context.Background()) return &Transport{ - config: config, - node: node, - registry: registry, - conns: make(map[string]*PeerConnection), - dedup: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup + config: config, + nodeManager: nodeManager, + peerRegistry: peerRegistry, + connections: make(map[string]*PeerConnection), + messageDeduplicator: NewMessageDeduplicator(5 * time.Minute), // 5 minute TTL for dedup upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -191,26 +255,84 @@ func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportCon return true // No origin header (non-browser client) } // Allow localhost and 127.0.0.1 origins - u, err := url.Parse(origin) + originURL, err := url.Parse(origin) if err != nil { return false } - host := u.Hostname() + host := originURL.Hostname() return host == "localhost" || host == "127.0.0.1" || host == "::1" }, }, - ctx: ctx, - cancel: cancel, + lifecycleContext: lifecycleContext, + cancelLifecycle: cancelLifecycle, } } -// Start begins listening for incoming connections. 
+func (pc *PeerConnection) webSocketConnection() *websocket.Conn { + if pc.WebSocketConnection != nil { + return pc.WebSocketConnection + } + return pc.Conn +} + +func agentHeaderToken(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "unknown" + } + + var sb strings.Builder + sb.Grow(len(value)) + for _, r := range value { + switch { + case r >= 'a' && r <= 'z': + sb.WriteRune(r) + case r >= 'A' && r <= 'Z': + sb.WriteRune(r) + case r >= '0' && r <= '9': + sb.WriteRune(r) + case r == '-' || r == '_' || r == '.': + sb.WriteRune(r) + case r == ' ': + sb.WriteByte('_') + default: + sb.WriteByte('_') + } + } + + token := sb.String() + if token == "" { + return "unknown" + } + + return token +} + +func (t *Transport) agentUserAgent() string { + identity := t.nodeManager.GetIdentity() + if identity == nil { + return core.Sprintf("%s proto=%s", agentUserAgentPrefix, ProtocolVersion) + } + + return core.Sprintf( + "%s id=%s name=%s role=%s proto=%s", + agentUserAgentPrefix, + identity.ID, + agentHeaderToken(identity.Name), + identity.Role, + ProtocolVersion, + ) +} + +// err := transport.Start() func (t *Transport) Start() error { mux := http.NewServeMux() - mux.HandleFunc(t.config.WSPath, t.handleWSUpgrade) + mux.HandleFunc(t.config.webSocketPath(), t.handleWebSocketUpgrade) - t.server = &http.Server{ - Addr: t.config.ListenAddr, + listenAddress := t.config.listenAddress() + + t.httpServer = &http.Server{ + Addr: listenAddress, Handler: mux, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, @@ -220,7 +342,7 @@ func (t *Transport) Start() error { // Apply TLS hardening if TLS is enabled if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" { - t.server.TLSConfig = &tls.Config{ + t.httpServer.TLSConfig = &tls.Config{ MinVersion: tls.VersionTLS12, CipherSuites: []uint16{ // TLS 1.3 ciphers (automatically used when available) @@ -242,28 +364,28 @@ func (t *Transport) Start() error { } } - t.wg.Go(func() { + 
t.waitGroup.Go(func() { var err error if t.config.TLSCertPath != "" && t.config.TLSKeyPath != "" { - err = t.server.ListenAndServeTLS(t.config.TLSCertPath, t.config.TLSKeyPath) + err = t.httpServer.ListenAndServeTLS(t.config.TLSCertPath, t.config.TLSKeyPath) } else { - err = t.server.ListenAndServe() + err = t.httpServer.ListenAndServe() } if err != nil && err != http.ErrServerClosed { - logging.Error("HTTP server error", logging.Fields{"error": err, "addr": t.config.ListenAddr}) + logging.Error("HTTP server error", logging.Fields{"error": err, "addr": listenAddress}) } }) // Start message deduplication cleanup goroutine - t.wg.Go(func() { + t.waitGroup.Go(func() { ticker := time.NewTicker(time.Minute) defer ticker.Stop() for { select { - case <-t.ctx.Done(): + case <-t.lifecycleContext.Done(): return case <-ticker.C: - t.dedup.Cleanup() + t.messageDeduplicator.Cleanup() } } }) @@ -271,166 +393,175 @@ func (t *Transport) Start() error { return nil } -// Stop gracefully shuts down the transport. +// err := transport.Stop() func (t *Transport) Stop() error { - t.cancel() + t.cancelLifecycle() // Gracefully close all connections with shutdown message - t.mu.RLock() - conns := slices.Collect(maps.Values(t.conns)) - t.mu.RUnlock() + t.mutex.RLock() + connections := slices.Collect(maps.Values(t.connections)) + t.mutex.RUnlock() - for _, pc := range conns { + for _, pc := range connections { pc.GracefulClose("server shutdown", DisconnectShutdown) } // Shutdown HTTP server if it was started - if t.server != nil { + if t.httpServer != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := t.server.Shutdown(ctx); err != nil { - return coreerr.E("Transport.Stop", "server shutdown error", err) + if err := t.httpServer.Shutdown(ctx); err != nil { + return core.E("Transport.Stop", "server shutdown error", err) } } - t.wg.Wait() + t.waitGroup.Wait() return nil } -// OnMessage sets the handler for incoming messages. 
-// Must be called before Start() to avoid races. +// transport.OnMessage(worker.HandleMessage) func (t *Transport) OnMessage(handler MessageHandler) { - t.mu.Lock() - defer t.mu.Unlock() - t.handler = handler + t.mutex.Lock() + defer t.mutex.Unlock() + t.messageHandler = handler } -// Connect establishes a connection to a peer. +// peerConnection, err := transport.Connect(&Peer{ID: "worker-1", Address: "127.0.0.1:9091"}) func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) { // Build WebSocket URL scheme := "ws" if t.config.TLSCertPath != "" { scheme = "wss" } - u := url.URL{Scheme: scheme, Host: peer.Address, Path: t.config.WSPath} + peerURL := url.URL{Scheme: scheme, Host: peer.Address, Path: t.config.webSocketPath()} + userAgent := t.agentUserAgent() // Dial the peer with timeout to prevent hanging on unresponsive peers dialer := websocket.Dialer{ HandshakeTimeout: 10 * time.Second, } - conn, _, err := dialer.Dial(u.String(), nil) + conn, _, err := dialer.Dial(peerURL.String(), http.Header{ + "User-Agent": []string{userAgent}, + }) if err != nil { - return nil, coreerr.E("Transport.Connect", "failed to connect to peer", err) + return nil, core.E("Transport.Connect", "failed to connect to peer", err) } - pc := &PeerConnection{ - Peer: peer, - Conn: conn, - LastActivity: time.Now(), - transport: t, - rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill + peerConnection := &PeerConnection{ + Peer: peer, + WebSocketConnection: conn, + Conn: conn, + LastActivity: time.Now(), + UserAgent: userAgent, + transport: t, + rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill } - // Perform handshake with challenge-response authentication - // This also derives and stores the shared secret in pc.SharedSecret - if err := t.performHandshake(pc); err != nil { + // Perform handshake with challenge-response authentication. + // This also derives and stores the shared secret in peerConnection.SharedSecret. 
+ if err := t.performHandshake(peerConnection); err != nil { conn.Close() - return nil, coreerr.E("Transport.Connect", "handshake failed", err) + return nil, core.E("Transport.Connect", "handshake failed", err) } // Store connection using the real peer ID from handshake - t.mu.Lock() - t.conns[pc.Peer.ID] = pc - t.mu.Unlock() + t.mutex.Lock() + t.connections[peerConnection.Peer.ID] = peerConnection + t.mutex.Unlock() - logging.Debug("connected to peer", logging.Fields{"peer_id": pc.Peer.ID, "secret_len": len(pc.SharedSecret)}) + logging.Debug("connected to peer", logging.Fields{"peer_id": peerConnection.Peer.ID, "secret_len": len(peerConnection.SharedSecret)}) + logging.Debug("connected peer metadata", logging.Fields{ + "peer_id": peerConnection.Peer.ID, + "user_agent": peerConnection.UserAgent, + }) // Update registry - t.registry.SetConnected(pc.Peer.ID, true) + t.peerRegistry.SetConnected(peerConnection.Peer.ID, true) // Start read loop - t.wg.Add(1) - go t.readLoop(pc) + t.waitGroup.Add(1) + go t.readLoop(peerConnection) - logging.Debug("started readLoop for peer", logging.Fields{"peer_id": pc.Peer.ID}) + logging.Debug("started readLoop for peer", logging.Fields{"peer_id": peerConnection.Peer.ID}) // Start keepalive - t.wg.Add(1) - go t.keepalive(pc) + t.waitGroup.Add(1) + go t.keepalive(peerConnection) - return pc, nil + return peerConnection, nil } -// Send sends a message to a specific peer. 
-func (t *Transport) Send(peerID string, msg *Message) error { - t.mu.RLock() - pc, exists := t.conns[peerID] - t.mu.RUnlock() +// err := transport.Send("worker-1", message) +func (t *Transport) Send(peerID string, message *Message) error { + t.mutex.RLock() + peerConnection, exists := t.connections[peerID] + t.mutex.RUnlock() if !exists { - return coreerr.E("Transport.Send", "peer "+peerID+" not connected", nil) + return core.E("Transport.Send", "peer "+peerID+" not connected", nil) } - return pc.Send(msg) + return peerConnection.Send(message) } -// Connections returns an iterator over all active peer connections. +// for peerConnection := range transport.Connections() { +// _ = peerConnection +// } func (t *Transport) Connections() iter.Seq[*PeerConnection] { return func(yield func(*PeerConnection) bool) { - t.mu.RLock() - defer t.mu.RUnlock() + t.mutex.RLock() + defer t.mutex.RUnlock() - for _, pc := range t.conns { - if !yield(pc) { + for _, peerConnection := range t.connections { + if !yield(peerConnection) { return } } } } -// Broadcast sends a message to all connected peers except the sender. -// The sender is identified by msg.From and excluded to prevent echo. -func (t *Transport) Broadcast(msg *Message) error { +// err := transport.Broadcast(announcement) +func (t *Transport) Broadcast(message *Message) error { conns := slices.Collect(t.Connections()) var lastErr error - for _, pc := range conns { - // Exclude sender from broadcast to prevent echo (P2P-MED-6) - if pc.Peer != nil && pc.Peer.ID == msg.From { + for _, peerConnection := range conns { + if peerConnection.Peer != nil && peerConnection.Peer.ID == message.From { continue } - if err := pc.Send(msg); err != nil { + if err := peerConnection.Send(message); err != nil { lastErr = err } } return lastErr } -// GetConnection returns an active connection to a peer. 
+// connection := transport.GetConnection("worker-1") func (t *Transport) GetConnection(peerID string) *PeerConnection { - t.mu.RLock() - defer t.mu.RUnlock() - return t.conns[peerID] + t.mutex.RLock() + defer t.mutex.RUnlock() + return t.connections[peerID] } -// handleWSUpgrade handles incoming WebSocket connections. -func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { - // Enforce MaxConns limit (including pending connections during handshake) - t.mu.RLock() - currentConns := len(t.conns) - t.mu.RUnlock() - pendingConns := int(t.pendingConns.Load()) +func (t *Transport) handleWebSocketUpgrade(w http.ResponseWriter, r *http.Request) { + userAgent := r.Header.Get("User-Agent") - totalConns := currentConns + pendingConns - if totalConns >= t.config.MaxConns { + // Enforce the maximum connection limit, including pending handshakes. + t.mutex.RLock() + currentConnections := len(t.connections) + t.mutex.RUnlock() + pendingHandshakeCount := int(t.pendingHandshakeCount.Load()) + + totalConnections := currentConnections + pendingHandshakeCount + if totalConnections >= t.config.maximumConnections() { http.Error(w, "Too many connections", http.StatusServiceUnavailable) return } // Track this connection as pending during handshake - t.pendingConns.Add(1) - defer t.pendingConns.Add(-1) + t.pendingHandshakeCount.Add(1) + defer t.pendingHandshakeCount.Add(-1) conn, err := t.upgrader.Upgrade(w, r, nil) if err != nil { @@ -457,12 +588,12 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { // Decode handshake message (not encrypted yet, contains public key) var msg Message - if err := json.Unmarshal(data, &msg); err != nil { + if result := core.JSONUnmarshal(data, &msg); !result.OK { conn.Close() return } - if msg.Type != MsgHandshake { + if msg.Type != MessageHandshake { conn.Close() return } @@ -479,15 +610,16 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { "peer_version": payload.Version, 
"supported_versions": SupportedProtocolVersions, "peer_id": payload.Identity.ID, + "user_agent": userAgent, }) - identity := t.node.GetIdentity() + identity := t.nodeManager.GetIdentity() if identity != nil { rejectPayload := HandshakeAckPayload{ Identity: *identity, Accepted: false, - Reason: fmt.Sprintf("incompatible protocol version %s, supported: %v", payload.Version, SupportedProtocolVersions), + Reason: core.Sprintf("incompatible protocol version %s, supported: %v", payload.Version, SupportedProtocolVersions), } - rejectMsg, _ := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload) + rejectMsg, _ := NewMessage(MessageHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload) if rejectData, err := MarshalJSON(rejectMsg); err == nil { conn.WriteMessage(websocket.TextMessage, rejectData) } @@ -497,28 +629,29 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { } // Derive shared secret from peer's public key - sharedSecret, err := t.node.DeriveSharedSecret(payload.Identity.PublicKey) + sharedSecret, err := t.nodeManager.DeriveSharedSecret(payload.Identity.PublicKey) if err != nil { conn.Close() return } // Check if peer is allowed to connect (allowlist check) - if !t.registry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) { + if !t.peerRegistry.IsPeerAllowed(payload.Identity.ID, payload.Identity.PublicKey) { logging.Warn("peer connection rejected: not in allowlist", logging.Fields{ "peer_id": payload.Identity.ID, "peer_name": payload.Identity.Name, "public_key": safeKeyPrefix(payload.Identity.PublicKey), + "user_agent": userAgent, }) // Send rejection before closing - identity := t.node.GetIdentity() + identity := t.nodeManager.GetIdentity() if identity != nil { rejectPayload := HandshakeAckPayload{ Identity: *identity, Accepted: false, Reason: "peer not authorized", } - rejectMsg, _ := NewMessage(MsgHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload) + rejectMsg, _ := 
NewMessage(MessageHandshakeAck, identity.ID, payload.Identity.ID, rejectPayload) if rejectData, err := MarshalJSON(rejectMsg); err == nil { conn.WriteMessage(websocket.TextMessage, rejectData) } @@ -528,7 +661,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { } // Create peer if not exists (only if auth passed) - peer := t.registry.GetPeer(payload.Identity.ID) + peer := t.peerRegistry.GetPeer(payload.Identity.ID) if peer == nil { // Auto-register the peer since they passed allowlist check peer = &Peer{ @@ -539,7 +672,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { AddedAt: time.Now(), Score: 50, } - t.registry.AddPeer(peer) + t.peerRegistry.AddPeer(peer) logging.Info("auto-registered new peer", logging.Fields{ "peer_id": peer.ID, "peer_name": peer.Name, @@ -547,16 +680,17 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { } pc := &PeerConnection{ - Peer: peer, - Conn: conn, - SharedSecret: sharedSecret, - LastActivity: time.Now(), - transport: t, - rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill + Peer: peer, + WebSocketConnection: conn, + Conn: conn, + SharedSecret: sharedSecret, + LastActivity: time.Now(), + UserAgent: userAgent, + transport: t, + rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill } - // Send handshake acknowledgment - identity := t.node.GetIdentity() + identity := t.nodeManager.GetIdentity() if identity == nil { conn.Close() return @@ -574,7 +708,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { Accepted: true, } - ackMsg, err := NewMessage(MsgHandshakeAck, identity.ID, peer.ID, ackPayload) + ackMsg, err := NewMessage(MessageHandshakeAck, identity.ID, peer.ID, ackPayload) if err != nil { conn.Close() return @@ -587,49 +721,61 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { return } + // Make the accepted connection visible before the client 
reads the ack. + // Connect() returns only after that read completes, so this keeps the + // server registry aligned with the caller's view of the handshake. + t.mutex.Lock() + t.connections[peer.ID] = pc + t.mutex.Unlock() + if err := conn.WriteMessage(websocket.TextMessage, ackData); err != nil { - conn.Close() + t.removeConnection(pc) return } - // Store connection - t.mu.Lock() - t.conns[peer.ID] = pc - t.mu.Unlock() - // Update registry - t.registry.SetConnected(peer.ID, true) + t.peerRegistry.SetConnected(peer.ID, true) + + logging.Debug("accepted peer connection", logging.Fields{ + "peer_id": peer.ID, + "peer_name": peer.Name, + "user_agent": userAgent, + }) // Start read loop - t.wg.Add(1) + t.waitGroup.Add(1) go t.readLoop(pc) // Start keepalive - t.wg.Add(1) + t.waitGroup.Add(1) go t.keepalive(pc) } -// performHandshake initiates handshake with a peer. func (t *Transport) performHandshake(pc *PeerConnection) error { // Set handshake timeout handshakeTimeout := 10 * time.Second - pc.Conn.SetWriteDeadline(time.Now().Add(handshakeTimeout)) - pc.Conn.SetReadDeadline(time.Now().Add(handshakeTimeout)) + connection := pc.webSocketConnection() + if connection == nil { + return core.E("Transport.performHandshake", "websocket connection is nil", nil) + } + + connection.SetWriteDeadline(time.Now().Add(handshakeTimeout)) + connection.SetReadDeadline(time.Now().Add(handshakeTimeout)) defer func() { // Reset deadlines after handshake - pc.Conn.SetWriteDeadline(time.Time{}) - pc.Conn.SetReadDeadline(time.Time{}) + connection.SetWriteDeadline(time.Time{}) + connection.SetReadDeadline(time.Time{}) }() - identity := t.node.GetIdentity() + identity := t.nodeManager.GetIdentity() if identity == nil { - return ErrIdentityNotInitialized + return ErrorIdentityNotInitialized } // Generate challenge for the server to prove it has the matching private key challenge, err := GenerateChallenge() if err != nil { - return coreerr.E("Transport.performHandshake", "generate challenge", 
err) + return core.E("Transport.performHandshake", "generate challenge", err) } payload := HandshakePayload{ @@ -638,43 +784,43 @@ func (t *Transport) performHandshake(pc *PeerConnection) error { Version: ProtocolVersion, } - msg, err := NewMessage(MsgHandshake, identity.ID, pc.Peer.ID, payload) + msg, err := NewMessage(MessageHandshake, identity.ID, pc.Peer.ID, payload) if err != nil { - return coreerr.E("Transport.performHandshake", "create handshake message", err) + return core.E("Transport.performHandshake", "create handshake message", err) } // First message is unencrypted (peer needs our public key) data, err := MarshalJSON(msg) if err != nil { - return coreerr.E("Transport.performHandshake", "marshal handshake message", err) + return core.E("Transport.performHandshake", "marshal handshake message", err) } - if err := pc.Conn.WriteMessage(websocket.TextMessage, data); err != nil { - return coreerr.E("Transport.performHandshake", "send handshake", err) + if err := connection.WriteMessage(websocket.TextMessage, data); err != nil { + return core.E("Transport.performHandshake", "send handshake", err) } // Wait for ack - _, ackData, err := pc.Conn.ReadMessage() + _, ackData, err := connection.ReadMessage() if err != nil { - return coreerr.E("Transport.performHandshake", "read handshake ack", err) + return core.E("Transport.performHandshake", "read handshake ack", err) } var ackMsg Message - if err := json.Unmarshal(ackData, &ackMsg); err != nil { - return coreerr.E("Transport.performHandshake", "unmarshal handshake ack", err) + if result := core.JSONUnmarshal(ackData, &ackMsg); !result.OK { + return core.E("Transport.performHandshake", "unmarshal handshake ack", result.Value.(error)) } - if ackMsg.Type != MsgHandshakeAck { - return coreerr.E("Transport.performHandshake", "expected handshake_ack, got "+string(ackMsg.Type), nil) + if ackMsg.Type != MessageHandshakeAck { + return core.E("Transport.performHandshake", "expected handshake_ack, got "+string(ackMsg.Type), 
nil) } var ackPayload HandshakeAckPayload if err := ackMsg.ParsePayload(&ackPayload); err != nil { - return coreerr.E("Transport.performHandshake", "parse handshake ack payload", err) + return core.E("Transport.performHandshake", "parse handshake ack payload", err) } if !ackPayload.Accepted { - return coreerr.E("Transport.performHandshake", "handshake rejected: "+ackPayload.Reason, nil) + return core.E("Transport.performHandshake", "handshake rejected: "+ackPayload.Reason, nil) } // Update peer with the received identity info @@ -684,39 +830,39 @@ func (t *Transport) performHandshake(pc *PeerConnection) error { pc.Peer.Role = ackPayload.Identity.Role // Verify challenge response - derive shared secret first using the peer's public key - sharedSecret, err := t.node.DeriveSharedSecret(pc.Peer.PublicKey) + sharedSecret, err := t.nodeManager.DeriveSharedSecret(pc.Peer.PublicKey) if err != nil { - return coreerr.E("Transport.performHandshake", "derive shared secret for challenge verification", err) + return core.E("Transport.performHandshake", "derive shared secret for challenge verification", err) } // Verify the server's response to our challenge if len(ackPayload.ChallengeResponse) == 0 { - return coreerr.E("Transport.performHandshake", "server did not provide challenge response", nil) + return core.E("Transport.performHandshake", "server did not provide challenge response", nil) } if !VerifyChallenge(challenge, ackPayload.ChallengeResponse, sharedSecret) { - return coreerr.E("Transport.performHandshake", "challenge response verification failed: server may not have matching private key", nil) + return core.E("Transport.performHandshake", "challenge response verification failed: server may not have matching private key", nil) } // Store the shared secret for later use pc.SharedSecret = sharedSecret // Update the peer in registry with the real identity - if err := t.registry.UpdatePeer(pc.Peer); err != nil { + if err := t.peerRegistry.UpdatePeer(pc.Peer); err != nil { 
// If update fails (peer not found with old ID), add as new - t.registry.AddPeer(pc.Peer) + t.peerRegistry.AddPeer(pc.Peer) } logging.Debug("handshake completed with challenge-response verification", logging.Fields{ - "peer_id": pc.Peer.ID, - "peer_name": pc.Peer.Name, + "peer_id": pc.Peer.ID, + "peer_name": pc.Peer.Name, + "user_agent": pc.UserAgent, }) return nil } -// readLoop reads messages from a peer connection. func (t *Transport) readLoop(pc *PeerConnection) { - defer t.wg.Done() + defer t.waitGroup.Done() defer t.removeConnection(pc) // Apply message size limit to prevent memory exhaustion attacks @@ -724,23 +870,28 @@ func (t *Transport) readLoop(pc *PeerConnection) { if maxSize <= 0 { maxSize = DefaultMaxMessageSize } - pc.Conn.SetReadLimit(maxSize) + connection := pc.webSocketConnection() + if connection == nil { + return + } + + connection.SetReadLimit(maxSize) for { select { - case <-t.ctx.Done(): + case <-t.lifecycleContext.Done(): return default: } // Set read deadline to prevent blocking forever on unresponsive connections readDeadline := t.config.PingInterval + t.config.PongTimeout - if err := pc.Conn.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { + if err := connection.SetReadDeadline(time.Now().Add(readDeadline)); err != nil { logging.Error("SetReadDeadline error", logging.Fields{"peer_id": pc.Peer.ID, "error": err}) return } - _, data, err := pc.Conn.ReadMessage() + _, data, err := connection.ReadMessage() if err != nil { logging.Debug("read error from peer", logging.Fields{"peer_id": pc.Peer.ID, "error": err}) return @@ -762,37 +913,36 @@ func (t *Transport) readLoop(pc *PeerConnection) { } // Check for duplicate messages (prevents amplification attacks) - if t.dedup.IsDuplicate(msg.ID) { + if t.messageDeduplicator.IsDuplicate(msg.ID) { logging.Debug("dropping duplicate message", logging.Fields{"msg_id": msg.ID, "peer_id": pc.Peer.ID}) continue } - t.dedup.Mark(msg.ID) + t.messageDeduplicator.Mark(msg.ID) // Rate limit debug logs 
in hot path to reduce noise (log 1 in N messages) - if debugLogCounter.Add(1)%debugLogInterval == 0 { + if messageLogSampleCounter.Add(1)%messageLogSampleInterval == 0 { logging.Debug("received message from peer", logging.Fields{"type": msg.Type, "peer_id": pc.Peer.ID, "reply_to": msg.ReplyTo, "sample": "1/100"}) } // Dispatch to handler (read handler under lock to avoid race) - t.mu.RLock() - handler := t.handler - t.mu.RUnlock() + t.mutex.RLock() + handler := t.messageHandler + t.mutex.RUnlock() if handler != nil { handler(pc, msg) } } } -// keepalive sends periodic pings. func (t *Transport) keepalive(pc *PeerConnection) { - defer t.wg.Done() + defer t.waitGroup.Done() ticker := time.NewTicker(t.config.PingInterval) defer ticker.Stop() for { select { - case <-t.ctx.Done(): + case <-t.lifecycleContext.Done(): return case <-ticker.C: // Check if connection is still alive @@ -802,8 +952,8 @@ func (t *Transport) keepalive(pc *PeerConnection) { } // Send ping - identity := t.node.GetIdentity() - pingMsg, err := NewMessage(MsgPing, identity.ID, pc.Peer.ID, PingPayload{ + identity := t.nodeManager.GetIdentity() + pingMsg, err := NewMessage(MessagePing, identity.ID, pc.Peer.ID, PingPayload{ SentAt: time.Now().UnixMilli(), }) if err != nil { @@ -818,46 +968,54 @@ func (t *Transport) keepalive(pc *PeerConnection) { } } -// removeConnection removes and cleans up a connection. func (t *Transport) removeConnection(pc *PeerConnection) { - t.mu.Lock() - delete(t.conns, pc.Peer.ID) - t.mu.Unlock() - - t.registry.SetConnected(pc.Peer.ID, false) - pc.Close() + _ = pc.Close() } -// Send sends an encrypted message over the connection. 
+// err := peerConnection.Send(message) func (pc *PeerConnection) Send(msg *Message) error { - pc.writeMu.Lock() - defer pc.writeMu.Unlock() + pc.writeMutex.Lock() + defer pc.writeMutex.Unlock() - // Encrypt message using SMSG + return pc.sendLocked(msg) +} + +func (pc *PeerConnection) sendLocked(msg *Message) error { data, err := pc.transport.encryptMessage(msg, pc.SharedSecret) if err != nil { return err } - // Set write deadline to prevent blocking forever - if err := pc.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil { - return coreerr.E("PeerConnection.Send", "failed to set write deadline", err) + connection := pc.webSocketConnection() + if connection == nil { + return core.E("PeerConnection.Send", "websocket connection is nil", nil) } - defer pc.Conn.SetWriteDeadline(time.Time{}) // Reset deadline after send - return pc.Conn.WriteMessage(websocket.BinaryMessage, data) + if err := connection.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil { + return core.E("PeerConnection.Send", "failed to set write deadline", err) + } + defer connection.SetWriteDeadline(time.Time{}) + + return connection.WriteMessage(websocket.BinaryMessage, data) } -// Close closes the connection. +// err := peerConnection.Close() func (pc *PeerConnection) Close() error { var err error pc.closeOnce.Do(func() { - err = pc.Conn.Close() + if pc.transport != nil { + pc.transport.detachConnection(pc) + } + connection := pc.webSocketConnection() + if connection == nil { + return + } + err = connection.Close() }) return err } -// DisconnectPayload contains reason for disconnect. +// payload := DisconnectPayload{Reason: "shutdown", Code: DisconnectNormal} type DisconnectPayload struct { Reason string `json:"reason"` Code int `json:"code"` // Optional disconnect code @@ -872,35 +1030,44 @@ const ( DisconnectShutdown = 1004 // Server shutdown ) -// GracefulClose sends a disconnect message before closing the connection. 
+// err := peerConnection.GracefulClose("server shutdown", DisconnectShutdown) func (pc *PeerConnection) GracefulClose(reason string, code int) error { - var err error - pc.closeOnce.Do(func() { - // Try to send disconnect message (best effort). - // Note: we must NOT call SetWriteDeadline outside writeMu — Send() - // already manages write deadlines under the lock. Setting it here - // without the lock races with concurrent Send() calls (P2P-RACE-1). - if pc.transport != nil && pc.SharedSecret != nil { - identity := pc.transport.node.GetIdentity() - if identity != nil { - payload := DisconnectPayload{ - Reason: reason, - Code: code, - } - msg, msgErr := NewMessage(MsgDisconnect, identity.ID, pc.Peer.ID, payload) - if msgErr == nil { - pc.Send(msg) - } + pc.writeMutex.Lock() + connection := pc.webSocketConnection() + if connection != nil && pc.transport != nil && pc.SharedSecret != nil { + identity := pc.transport.nodeManager.GetIdentity() + if identity != nil { + payload := DisconnectPayload{ + Reason: reason, + Code: code, + } + msg, msgErr := NewMessage(MessageDisconnect, identity.ID, pc.Peer.ID, payload) + if msgErr == nil { + _ = pc.sendLocked(msg) } } + } + pc.writeMutex.Unlock() - // Close the underlying connection - err = pc.Conn.Close() - }) - return err + return pc.Close() +} + +func (t *Transport) detachConnection(pc *PeerConnection) { + if pc == nil || pc.Peer == nil { + return + } + + t.mutex.Lock() + current, exists := t.connections[pc.Peer.ID] + if exists && current == pc { + delete(t.connections, pc.Peer.ID) + t.mutex.Unlock() + t.peerRegistry.SetConnected(pc.Peer.ID, false) + return + } + t.mutex.Unlock() } -// encryptMessage encrypts a message using SMSG with the shared secret. 
func (t *Transport) encryptMessage(msg *Message, sharedSecret []byte) ([]byte, error) { // Serialize message to JSON (using pooled buffer for efficiency) msgData, err := MarshalJSON(msg) @@ -921,7 +1088,6 @@ func (t *Transport) encryptMessage(msg *Message, sharedSecret []byte) ([]byte, e return encrypted, nil } -// decryptMessage decrypts a message using SMSG with the shared secret. func (t *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message, error) { // Decrypt using shared secret as password password := base64.StdEncoding.EncodeToString(sharedSecret) @@ -932,16 +1098,16 @@ func (t *Transport) decryptMessage(data []byte, sharedSecret []byte) (*Message, // Parse message from JSON var msg Message - if err := json.Unmarshal([]byte(smsgMsg.Body), &msg); err != nil { - return nil, err + if result := core.JSONUnmarshalString(smsgMsg.Body, &msg); !result.OK { + return nil, result.Value.(error) } return &msg, nil } -// ConnectedPeers returns the number of connected peers. -func (t *Transport) ConnectedPeers() int { - t.mu.RLock() - defer t.mu.RUnlock() - return len(t.conns) +// count := transport.ConnectedPeerCount() +func (t *Transport) ConnectedPeerCount() int { + t.mutex.RLock() + defer t.mutex.RUnlock() + return len(t.connections) } diff --git a/node/transport_test.go b/node/transport_test.go index ffa6e5a..378f8b8 100644 --- a/node/transport_test.go +++ b/node/transport_test.go @@ -1,30 +1,25 @@ package node import ( - "encoding/json" "net/http" "net/http/httptest" "net/url" - "path/filepath" "strings" "sync" "sync/atomic" "testing" "time" + core "dappco.re/go/core" "github.com/gorilla/websocket" ) // --- Test Helpers --- -// testNode creates a NodeManager with a generated identity in a temp directory. 
-func testNode(t *testing.T, name string, role NodeRole) *NodeManager { +func newTestNodeManager(t *testing.T, name string, role NodeRole) *NodeManager { t.Helper() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("create node manager %q: %v", name, err) } @@ -34,11 +29,10 @@ func testNode(t *testing.T, name string, role NodeRole) *NodeManager { return nm } -// testRegistry creates a PeerRegistry with open auth in a temp directory. -func testRegistry(t *testing.T) *PeerRegistry { +func newTestPeerRegistry(t *testing.T) *PeerRegistry { t.Helper() dir := t.TempDir() - reg, err := NewPeerRegistryWithPath(filepath.Join(dir, "peers.json")) + reg, err := NewPeerRegistryFromPath(testJoinPath(dir, "peers.json")) if err != nil { t.Fatalf("create registry: %v", err) } @@ -68,17 +62,17 @@ func setupTestTransportPair(t *testing.T) *testTransportPair { func setupTestTransportPairWithConfig(t *testing.T, serverCfg, clientCfg TransportConfig) *testTransportPair { t.Helper() - serverNM := testNode(t, "server", RoleWorker) - clientNM := testNode(t, "client", RoleController) - serverReg := testRegistry(t) - clientReg := testRegistry(t) + serverNM := newTestNodeManager(t, "server", RoleWorker) + clientNM := newTestNodeManager(t, "client", RoleController) + serverReg := newTestPeerRegistry(t) + clientReg := newTestPeerRegistry(t) serverTransport := NewTransport(serverNM, serverReg, serverCfg) clientTransport := NewTransport(clientNM, clientReg, clientCfg) // Use httptest.Server with the transport's WebSocket handler mux := http.NewServeMux() - mux.HandleFunc(serverCfg.WSPath, serverTransport.handleWSUpgrade) + mux.HandleFunc(serverCfg.WebSocketPath, serverTransport.handleWebSocketUpgrade) ts := httptest.NewServer(mux) u, _ := url.Parse(ts.URL) @@ -124,7 +118,7 @@ func (tp *testTransportPair) connectClient(t 
*testing.T) *PeerConnection { // --- Unit Tests for Sub-Components --- -func TestMessageDeduplicator(t *testing.T) { +func TestTransport_MessageDeduplicator_Good(t *testing.T) { t.Run("MarkAndCheck", func(t *testing.T) { d := NewMessageDeduplicator(5 * time.Minute) @@ -159,6 +153,17 @@ func TestMessageDeduplicator(t *testing.T) { } }) + t.Run("ExpiredEntriesDoNotLinger", func(t *testing.T) { + d := NewMessageDeduplicator(50 * time.Millisecond) + d.Mark("msg-1") + + time.Sleep(75 * time.Millisecond) + + if d.IsDuplicate("msg-1") { + t.Error("should not be duplicate after TTL even before cleanup runs") + } + }) + t.Run("ConcurrentAccess", func(t *testing.T) { d := NewMessageDeduplicator(5 * time.Minute) var wg sync.WaitGroup @@ -175,7 +180,7 @@ func TestMessageDeduplicator(t *testing.T) { }) } -func TestPeerRateLimiter(t *testing.T) { +func TestTransport_PeerRateLimiter_Good(t *testing.T) { t.Run("AllowUpToBurst", func(t *testing.T) { rl := NewPeerRateLimiter(10, 5) @@ -213,7 +218,7 @@ func TestPeerRateLimiter(t *testing.T) { // --- Transport Integration Tests --- -func TestTransport_FullHandshake(t *testing.T) { +func TestTransport_FullHandshake_Good(t *testing.T) { tp := setupTestTransportPair(t) pc := tp.connectClient(t) @@ -225,11 +230,11 @@ func TestTransport_FullHandshake(t *testing.T) { // Allow server goroutines to register the connection time.Sleep(50 * time.Millisecond) - if tp.Server.ConnectedPeers() != 1 { - t.Errorf("server connected peers: got %d, want 1", tp.Server.ConnectedPeers()) + if tp.Server.ConnectedPeerCount() != 1 { + t.Errorf("server connected peers: got %d, want 1", tp.Server.ConnectedPeerCount()) } - if tp.Client.ConnectedPeers() != 1 { - t.Errorf("client connected peers: got %d, want 1", tp.Client.ConnectedPeers()) + if tp.Client.ConnectedPeerCount() != 1 { + t.Errorf("client connected peers: got %d, want 1", tp.Client.ConnectedPeerCount()) } // Verify peer identity was exchanged correctly @@ -243,7 +248,72 @@ func 
TestTransport_FullHandshake(t *testing.T) { } } -func TestTransport_HandshakeRejectWrongVersion(t *testing.T) { +func TestTransport_ConnectSendsAgentUserAgent_Good(t *testing.T) { + serverNM := newTestNodeManager(t, "ua-server", RoleWorker) + clientNM := newTestNodeManager(t, "ua-client", RoleController) + serverReg := newTestPeerRegistry(t) + clientReg := newTestPeerRegistry(t) + + serverCfg := DefaultTransportConfig() + clientCfg := DefaultTransportConfig() + serverTransport := NewTransport(serverNM, serverReg, serverCfg) + clientTransport := NewTransport(clientNM, clientReg, clientCfg) + + var capturedUserAgent atomic.Value + + mux := http.NewServeMux() + mux.HandleFunc(serverCfg.WebSocketPath, func(w http.ResponseWriter, r *http.Request) { + capturedUserAgent.Store(r.Header.Get("User-Agent")) + serverTransport.handleWebSocketUpgrade(w, r) + }) + + ts := httptest.NewServer(mux) + t.Cleanup(func() { + clientTransport.Stop() + serverTransport.Stop() + ts.Close() + }) + + u, _ := url.Parse(ts.URL) + serverAddr := u.Host + + peer := &Peer{ + ID: serverNM.GetIdentity().ID, + Name: "server", + Address: serverAddr, + Role: RoleWorker, + } + clientReg.AddPeer(peer) + + pc, err := clientTransport.Connect(peer) + if err != nil { + t.Fatalf("client connect failed: %v", err) + } + + ua, ok := capturedUserAgent.Load().(string) + if !ok || ua == "" { + t.Fatal("expected user-agent to be captured during websocket upgrade") + } + if !strings.HasPrefix(ua, agentUserAgentPrefix) { + t.Fatalf("user-agent prefix: got %q, want prefix %q", ua, agentUserAgentPrefix) + } + if !strings.Contains(ua, "id="+clientNM.GetIdentity().ID) { + t.Fatalf("user-agent should include client identity, got %q", ua) + } + if pc.UserAgent != ua { + t.Fatalf("client connection user-agent: got %q, want %q", pc.UserAgent, ua) + } + + serverConn := serverTransport.GetConnection(clientNM.GetIdentity().ID) + if serverConn == nil { + t.Fatal("server should retain the accepted connection") + } + if 
serverConn.UserAgent != ua { + t.Fatalf("server connection user-agent: got %q, want %q", serverConn.UserAgent, ua) + } +} + +func TestTransport_HandshakeRejectWrongVersion_Bad(t *testing.T) { tp := setupTestTransportPair(t) // Dial raw WebSocket and send handshake with unsupported version @@ -259,7 +329,7 @@ func TestTransport_HandshakeRejectWrongVersion(t *testing.T) { Identity: *clientIdentity, Version: "99.99", // Unsupported } - msg, _ := NewMessage(MsgHandshake, clientIdentity.ID, "", payload) + msg, _ := NewMessage(MessageHandshake, clientIdentity.ID, "", payload) data, _ := MarshalJSON(msg) if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { @@ -272,9 +342,7 @@ func TestTransport_HandshakeRejectWrongVersion(t *testing.T) { } var resp Message - if err := json.Unmarshal(respData, &resp); err != nil { - t.Fatalf("unmarshal response: %v", err) - } + testJSONUnmarshal(t, respData, &resp) var ack HandshakeAckPayload resp.ParsePayload(&ack) @@ -282,12 +350,12 @@ func TestTransport_HandshakeRejectWrongVersion(t *testing.T) { if ack.Accepted { t.Error("should reject incompatible protocol version") } - if !strings.Contains(ack.Reason, "incompatible protocol version") { + if !core.Contains(ack.Reason, "incompatible protocol version") { t.Errorf("expected version rejection reason, got: %s", ack.Reason) } } -func TestTransport_HandshakeRejectAllowlist(t *testing.T) { +func TestTransport_HandshakeRejectAllowlist_Bad(t *testing.T) { tp := setupTestTransportPair(t) // Switch server to allowlist mode WITHOUT adding client's key @@ -305,12 +373,12 @@ func TestTransport_HandshakeRejectAllowlist(t *testing.T) { if err == nil { t.Fatal("should reject peer not in allowlist") } - if !strings.Contains(err.Error(), "rejected") { + if !core.Contains(err.Error(), "rejected") { t.Errorf("expected rejection error, got: %v", err) } } -func TestTransport_EncryptedMessageRoundTrip(t *testing.T) { +func TestTransport_EncryptedMessageRoundTrip_Ugly(t *testing.T) { tp := 
setupTestTransportPair(t) received := make(chan *Message, 1) @@ -323,7 +391,7 @@ func TestTransport_EncryptedMessageRoundTrip(t *testing.T) { // Send an encrypted message from client to server clientID := tp.ClientNode.GetIdentity().ID serverID := tp.ServerNode.GetIdentity().ID - sentMsg, _ := NewMessage(MsgPing, clientID, serverID, PingPayload{ + sentMsg, _ := NewMessage(MessagePing, clientID, serverID, PingPayload{ SentAt: time.Now().UnixMilli(), }) @@ -333,8 +401,8 @@ func TestTransport_EncryptedMessageRoundTrip(t *testing.T) { select { case msg := <-received: - if msg.Type != MsgPing { - t.Errorf("type: got %s, want %s", msg.Type, MsgPing) + if msg.Type != MessagePing { + t.Errorf("type: got %s, want %s", msg.Type, MessagePing) } if msg.ID != sentMsg.ID { t.Error("message ID mismatch after encrypt/decrypt round-trip") @@ -353,7 +421,7 @@ func TestTransport_EncryptedMessageRoundTrip(t *testing.T) { } } -func TestTransport_MessageDedup(t *testing.T) { +func TestTransport_MessageDedup_Good(t *testing.T) { tp := setupTestTransportPair(t) var count atomic.Int32 @@ -365,7 +433,7 @@ func TestTransport_MessageDedup(t *testing.T) { clientID := tp.ClientNode.GetIdentity().ID serverID := tp.ServerNode.GetIdentity().ID - msg, _ := NewMessage(MsgPing, clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()}) + msg, _ := NewMessage(MessagePing, clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()}) // Send the same message twice if err := pc.Send(msg); err != nil { @@ -383,7 +451,7 @@ func TestTransport_MessageDedup(t *testing.T) { } } -func TestTransport_RateLimiting(t *testing.T) { +func TestTransport_RateLimiting_Good(t *testing.T) { tp := setupTestTransportPair(t) var count atomic.Int32 @@ -398,7 +466,7 @@ func TestTransport_RateLimiting(t *testing.T) { // Send 150 messages rapidly (rate limiter burst = 100) for range 150 { - msg, _ := NewMessage(MsgPing, clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()}) + msg, _ := NewMessage(MessagePing, 
clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()}) pc.Send(msg) } @@ -415,17 +483,17 @@ func TestTransport_RateLimiting(t *testing.T) { } } -func TestTransport_MaxConnsEnforcement(t *testing.T) { - // Server with MaxConns=1 - serverNM := testNode(t, "maxconns-server", RoleWorker) - serverReg := testRegistry(t) +func TestTransport_MaxConnectionsEnforcement_Good(t *testing.T) { + // Server with MaxConnections=1 + serverNM := newTestNodeManager(t, "maxconns-server", RoleWorker) + serverReg := newTestPeerRegistry(t) serverCfg := DefaultTransportConfig() - serverCfg.MaxConns = 1 + serverCfg.MaxConnections = 1 serverTransport := NewTransport(serverNM, serverReg, serverCfg) mux := http.NewServeMux() - mux.HandleFunc(serverCfg.WSPath, serverTransport.handleWSUpgrade) + mux.HandleFunc(serverCfg.WebSocketPath, serverTransport.handleWebSocketUpgrade) ts := httptest.NewServer(mux) t.Cleanup(func() { serverTransport.Stop() @@ -436,8 +504,8 @@ func TestTransport_MaxConnsEnforcement(t *testing.T) { serverAddr := u.Host // First client connects successfully - client1NM := testNode(t, "client1", RoleController) - client1Reg := testRegistry(t) + client1NM := newTestNodeManager(t, "client1", RoleController) + client1Reg := newTestPeerRegistry(t) client1Transport := NewTransport(client1NM, client1Reg, DefaultTransportConfig()) t.Cleanup(func() { client1Transport.Stop() }) @@ -452,9 +520,9 @@ func TestTransport_MaxConnsEnforcement(t *testing.T) { // Allow server to register the connection time.Sleep(50 * time.Millisecond) - // Second client should be rejected (MaxConns=1 reached) - client2NM := testNode(t, "client2", RoleController) - client2Reg := testRegistry(t) + // Second client should be rejected (MaxConnections=1 reached) + client2NM := newTestNodeManager(t, "client2", RoleController) + client2Reg := newTestPeerRegistry(t) client2Transport := NewTransport(client2NM, client2Reg, DefaultTransportConfig()) t.Cleanup(func() { client2Transport.Stop() }) @@ -463,11 
+531,11 @@ func TestTransport_MaxConnsEnforcement(t *testing.T) { _, err = client2Transport.Connect(peer2) if err == nil { - t.Fatal("second connection should be rejected when MaxConns=1") + t.Fatal("second connection should be rejected when MaxConnections=1") } } -func TestTransport_KeepaliveTimeout(t *testing.T) { +func TestTransport_KeepaliveTimeout_Bad(t *testing.T) { // Use short keepalive settings so the test is fast serverCfg := DefaultTransportConfig() serverCfg.PingInterval = 100 * time.Millisecond @@ -482,8 +550,8 @@ func TestTransport_KeepaliveTimeout(t *testing.T) { // Verify connection is established time.Sleep(50 * time.Millisecond) - if tp.Server.ConnectedPeers() != 1 { - t.Fatalf("server should have 1 peer initially, got %d", tp.Server.ConnectedPeers()) + if tp.Server.ConnectedPeerCount() != 1 { + t.Fatalf("server should have 1 peer initially, got %d", tp.Server.ConnectedPeerCount()) } // Close the underlying WebSocket on the client side to simulate network failure. @@ -494,16 +562,16 @@ func TestTransport_KeepaliveTimeout(t *testing.T) { if clientConn == nil { t.Fatal("client should have connection to server") } - clientConn.Conn.Close() + clientConn.WebSocketConnection.Close() // Wait for server to detect and clean up deadline := time.After(2 * time.Second) for { select { case <-deadline: - t.Fatalf("server did not clean up connection: still has %d peers", tp.Server.ConnectedPeers()) + t.Fatalf("server did not clean up connection: still has %d peers", tp.Server.ConnectedPeerCount()) default: - if tp.Server.ConnectedPeers() == 0 { + if tp.Server.ConnectedPeerCount() == 0 { // Verify registry updated peer := tp.ServerReg.GetPeer(clientID) if peer != nil && peer.Connected { @@ -516,7 +584,7 @@ func TestTransport_KeepaliveTimeout(t *testing.T) { } } -func TestTransport_GracefulClose(t *testing.T) { +func TestTransport_GracefulClose_Ugly(t *testing.T) { tp := setupTestTransportPair(t) received := make(chan *Message, 10) @@ -529,13 +597,13 @@ func 
TestTransport_GracefulClose(t *testing.T) { // Allow connection to fully establish time.Sleep(50 * time.Millisecond) - // Graceful close should send a MsgDisconnect before closing + // Graceful close should send a MessageDisconnect before closing pc.GracefulClose("test shutdown", DisconnectNormal) // Check if disconnect message was received select { case msg := <-received: - if msg.Type != MsgDisconnect { + if msg.Type != MessageDisconnect { t.Errorf("expected disconnect message, got %s", msg.Type) } var payload DisconnectPayload @@ -551,7 +619,38 @@ func TestTransport_GracefulClose(t *testing.T) { } } -func TestTransport_ConcurrentSends(t *testing.T) { +func TestTransport_PeerConnectionClose_ReleasesState_Good(t *testing.T) { + tp := setupTestTransportPair(t) + + pc := tp.connectClient(t) + clientID := tp.ClientNode.GetIdentity().ID + if tp.Client.GetConnection(tp.ServerNode.GetIdentity().ID) == nil { + t.Fatal("client should have an active connection before close") + } + + if err := pc.Close(); err != nil { + t.Fatalf("close peer connection: %v", err) + } + + deadline := time.After(2 * time.Second) + for { + select { + case <-deadline: + t.Fatal("connection state was not released after close") + default: + if tp.Client.GetConnection(tp.ServerNode.GetIdentity().ID) == nil && tp.Client.ConnectedPeerCount() == 0 { + peer := tp.ClientReg.GetPeer(clientID) + if peer != nil && peer.Connected { + t.Fatal("registry should show peer as disconnected after close") + } + return + } + time.Sleep(20 * time.Millisecond) + } + } +} + +func TestTransport_ConcurrentSends_Ugly(t *testing.T) { tp := setupTestTransportPair(t) var count atomic.Int32 @@ -572,7 +671,7 @@ func TestTransport_ConcurrentSends(t *testing.T) { for range goroutines { wg.Go(func() { for range msgsPerGoroutine { - msg, _ := NewMessage(MsgPing, clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()}) + msg, _ := NewMessage(MessagePing, clientID, serverID, PingPayload{SentAt: time.Now().UnixMilli()}) 
pc.Send(msg) } }) @@ -591,10 +690,10 @@ func TestTransport_ConcurrentSends(t *testing.T) { // --- Additional coverage tests --- -func TestTransport_Broadcast(t *testing.T) { +func TestTransport_Broadcast_Good(t *testing.T) { // Set up a controller with two worker peers connected. - controllerNM := testNode(t, "broadcast-controller", RoleController) - controllerReg := testRegistry(t) + controllerNM := newTestNodeManager(t, "broadcast-controller", RoleController) + controllerReg := newTestPeerRegistry(t) controllerTransport := NewTransport(controllerNM, controllerReg, DefaultTransportConfig()) t.Cleanup(func() { controllerTransport.Stop() }) @@ -629,7 +728,7 @@ func TestTransport_Broadcast(t *testing.T) { // Broadcast a message from the controller controllerID := controllerNM.GetIdentity().ID - msg, _ := NewMessage(MsgPing, controllerID, "", PingPayload{ + msg, _ := NewMessage(MessagePing, controllerID, "", PingPayload{ SentAt: time.Now().UnixMilli(), }) @@ -648,7 +747,7 @@ func TestTransport_Broadcast(t *testing.T) { } } -func TestTransport_BroadcastExcludesSender(t *testing.T) { +func TestTransport_BroadcastExcludesSender_Good(t *testing.T) { // Verify that Broadcast excludes the sender. tp := setupTestTransportPair(t) @@ -665,7 +764,7 @@ func TestTransport_BroadcastExcludesSender(t *testing.T) { // connection peer ID check, not the server's own ID. Let's verify sender exclusion // by broadcasting from the server with its own ID. serverID := tp.ServerNode.GetIdentity().ID - msg, _ := NewMessage(MsgPing, serverID, "", PingPayload{SentAt: time.Now().UnixMilli()}) + msg, _ := NewMessage(MessagePing, serverID, "", PingPayload{SentAt: time.Now().UnixMilli()}) // This broadcasts from server to all connected peers (the client). // The server itself won't receive it back because it's not connected to itself. 
@@ -675,9 +774,9 @@ func TestTransport_BroadcastExcludesSender(t *testing.T) { } } -func TestTransport_NewTransport_DefaultMaxMessageSize(t *testing.T) { - nm := testNode(t, "defaults", RoleWorker) - reg := testRegistry(t) +func TestTransport_NewTransport_DefaultMaxMessageSize_Good(t *testing.T) { + nm := newTestNodeManager(t, "defaults", RoleWorker) + reg := newTestPeerRegistry(t) cfg := TransportConfig{ MaxMessageSize: 0, // should use default } @@ -689,29 +788,29 @@ func TestTransport_NewTransport_DefaultMaxMessageSize(t *testing.T) { if tr.config.MaxMessageSize != 0 { t.Errorf("config should preserve 0 value, got %d", tr.config.MaxMessageSize) } - // The actual default is applied at usage time (readLoop, handleWSUpgrade) + // The actual default is applied at usage time (readLoop, handleWebSocketUpgrade) } -func TestTransport_ConnectedPeers(t *testing.T) { +func TestTransport_ConnectedPeerCount_Good(t *testing.T) { tp := setupTestTransportPair(t) - if tp.Server.ConnectedPeers() != 0 { - t.Errorf("expected 0 connected peers initially, got %d", tp.Server.ConnectedPeers()) + if tp.Server.ConnectedPeerCount() != 0 { + t.Errorf("expected 0 connected peers initially, got %d", tp.Server.ConnectedPeerCount()) } tp.connectClient(t) time.Sleep(50 * time.Millisecond) - if tp.Server.ConnectedPeers() != 1 { - t.Errorf("expected 1 connected peer after connect, got %d", tp.Server.ConnectedPeers()) + if tp.Server.ConnectedPeerCount() != 1 { + t.Errorf("expected 1 connected peer after connect, got %d", tp.Server.ConnectedPeerCount()) } } -func TestTransport_StartAndStop(t *testing.T) { - nm := testNode(t, "start-test", RoleWorker) - reg := testRegistry(t) +func TestTransport_StartAndStop_Good(t *testing.T) { + nm := newTestNodeManager(t, "start-test", RoleWorker) + reg := newTestPeerRegistry(t) cfg := DefaultTransportConfig() - cfg.ListenAddr = ":0" // Let OS pick a free port + cfg.ListenAddress = ":0" // Let OS pick a free port tr := NewTransport(nm, reg, cfg) @@ -729,9 +828,9 
@@ func TestTransport_StartAndStop(t *testing.T) { } } -func TestTransport_CheckOrigin(t *testing.T) { - nm := testNode(t, "origin-test", RoleWorker) - reg := testRegistry(t) +func TestTransport_CheckOrigin_Good(t *testing.T) { + nm := newTestNodeManager(t, "origin-test", RoleWorker) + reg := newTestPeerRegistry(t) cfg := DefaultTransportConfig() tr := NewTransport(nm, reg, cfg) diff --git a/node/worker.go b/node/worker.go index af917d4..03875f2 100644 --- a/node/worker.go +++ b/node/worker.go @@ -2,18 +2,15 @@ package node import ( "encoding/base64" - "encoding/json" - "path/filepath" "time" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" "dappco.re/go/core/p2p/logging" "github.com/adrg/xdg" ) -// MinerManager interface for the mining package integration. -// This allows the node package to interact with mining.Manager without import cycles. +// var minerManager MinerManager type MinerManager interface { StartMiner(minerType string, config any) (MinerInstance, error) StopMiner(name string) error @@ -21,7 +18,7 @@ type MinerManager interface { GetMiner(name string) (MinerInstance, error) } -// MinerInstance represents a running miner for stats collection. +// var miner MinerInstance type MinerInstance interface { GetName() string GetType() string @@ -29,61 +26,60 @@ type MinerInstance interface { GetConsoleHistory(lines int) []string } -// ProfileManager interface for profile operations. +// var profileManager ProfileManager type ProfileManager interface { GetProfile(id string) (any, error) SaveProfile(profile any) error } -// Worker handles incoming messages on a worker node. 
+// worker := NewWorker(nodeManager, transport) type Worker struct { - node *NodeManager - transport *Transport - minerManager MinerManager - profileManager ProfileManager - startTime time.Time - DataDir string // Base directory for deployments (defaults to xdg.DataHome) + nodeManager *NodeManager + transport *Transport + minerManager MinerManager + profileManager ProfileManager + startedAt time.Time + DeploymentDirectory string // worker.DeploymentDirectory = "/srv/p2p/deployments" } -// NewWorker creates a new Worker instance. -func NewWorker(node *NodeManager, transport *Transport) *Worker { +// worker := NewWorker(nodeManager, transport) +func NewWorker(nodeManager *NodeManager, transport *Transport) *Worker { return &Worker{ - node: node, - transport: transport, - startTime: time.Now(), - DataDir: xdg.DataHome, + nodeManager: nodeManager, + transport: transport, + startedAt: time.Now(), + DeploymentDirectory: xdg.DataHome, } } - -// SetMinerManager sets the miner manager for handling miner operations. +// worker.SetMinerManager(minerManager) func (w *Worker) SetMinerManager(manager MinerManager) { w.minerManager = manager } -// SetProfileManager sets the profile manager for handling profile operations. +// worker.SetProfileManager(profileManager) func (w *Worker) SetProfileManager(manager ProfileManager) { w.profileManager = manager } -// HandleMessage processes incoming messages and returns a response. 
-func (w *Worker) HandleMessage(conn *PeerConnection, msg *Message) { +// worker.HandleMessage(peerConnection, message) +func (w *Worker) HandleMessage(peerConnection *PeerConnection, message *Message) { var response *Message var err error - switch msg.Type { - case MsgPing: - response, err = w.handlePing(msg) - case MsgGetStats: - response, err = w.handleGetStats(msg) - case MsgStartMiner: - response, err = w.handleStartMiner(msg) - case MsgStopMiner: - response, err = w.handleStopMiner(msg) - case MsgGetLogs: - response, err = w.handleGetLogs(msg) - case MsgDeploy: - response, err = w.handleDeploy(conn, msg) + switch message.Type { + case MessagePing: + response, err = w.handlePing(message) + case MessageGetStats: + response, err = w.handleStats(message) + case MessageStartMiner: + response, err = w.handleStartMiner(message) + case MessageStopMiner: + response, err = w.handleStopMiner(message) + case MessageGetLogs: + response, err = w.handleLogs(message) + case MessageDeploy: + response, err = w.handleDeploy(peerConnection, message) default: // Unknown message type - ignore or send error return @@ -91,23 +87,23 @@ func (w *Worker) HandleMessage(conn *PeerConnection, msg *Message) { if err != nil { // Send error response - identity := w.node.GetIdentity() + identity := w.nodeManager.GetIdentity() if identity != nil { errMsg, _ := NewErrorMessage( identity.ID, - msg.From, - ErrCodeOperationFailed, + message.From, + ErrorCodeOperationFailed, err.Error(), - msg.ID, + message.ID, ) - conn.Send(errMsg) + peerConnection.Send(errMsg) } return } if response != nil { - logging.Debug("sending response", logging.Fields{"type": response.Type, "to": msg.From}) - if err := conn.Send(response); err != nil { + logging.Debug("sending response", logging.Fields{"type": response.Type, "to": message.From}) + if err := peerConnection.Send(response); err != nil { logging.Error("failed to send response", logging.Fields{"error": err}) } else { logging.Debug("response sent successfully") 
@@ -115,11 +111,10 @@ func (w *Worker) HandleMessage(conn *PeerConnection, msg *Message) { } } -// handlePing responds to ping requests. -func (w *Worker) handlePing(msg *Message) (*Message, error) { +func (w *Worker) handlePing(message *Message) (*Message, error) { var ping PingPayload - if err := msg.ParsePayload(&ping); err != nil { - return nil, coreerr.E("Worker.handlePing", "invalid ping payload", err) + if err := message.ParsePayload(&ping); err != nil { + return nil, core.E("Worker.handlePing", "invalid ping payload", err) } pong := PongPayload{ @@ -127,21 +122,20 @@ func (w *Worker) handlePing(msg *Message) (*Message, error) { ReceivedAt: time.Now().UnixMilli(), } - return msg.Reply(MsgPong, pong) + return message.Reply(MessagePong, pong) } -// handleGetStats responds with current miner statistics. -func (w *Worker) handleGetStats(msg *Message) (*Message, error) { - identity := w.node.GetIdentity() +func (w *Worker) handleStats(message *Message) (*Message, error) { + identity := w.nodeManager.GetIdentity() if identity == nil { - return nil, ErrIdentityNotInitialized + return nil, ErrorIdentityNotInitialized } stats := StatsPayload{ NodeID: identity.ID, NodeName: identity.Name, Miners: []MinerStatsItem{}, - Uptime: int64(time.Since(w.startTime).Seconds()), + Uptime: int64(time.Since(w.startedAt).Seconds()), } if w.minerManager != nil { @@ -152,24 +146,20 @@ func (w *Worker) handleGetStats(msg *Message) (*Message, error) { continue } - // Convert to MinerStatsItem - this is a simplified conversion - // The actual implementation would need to match the mining package's stats structure item := convertMinerStats(miner, minerStats) stats.Miners = append(stats.Miners, item) } } - return msg.Reply(MsgStats, stats) + return message.Reply(MessageStats, stats) } -// convertMinerStats converts miner stats to the protocol format. 
func convertMinerStats(miner MinerInstance, rawStats any) MinerStatsItem { item := MinerStatsItem{ Name: miner.GetName(), Type: miner.GetType(), } - // Try to extract common fields from the stats if statsMap, ok := rawStats.(map[string]any); ok { if hashrate, ok := statsMap["hashrate"].(float64); ok { item.Hashrate = hashrate @@ -194,62 +184,57 @@ func convertMinerStats(miner MinerInstance, rawStats any) MinerStatsItem { return item } -// handleStartMiner starts a miner with the given profile. -func (w *Worker) handleStartMiner(msg *Message) (*Message, error) { +func (w *Worker) handleStartMiner(message *Message) (*Message, error) { if w.minerManager == nil { - return nil, ErrMinerManagerNotConfigured + return nil, ErrorMinerManagerNotConfigured } var payload StartMinerPayload - if err := msg.ParsePayload(&payload); err != nil { - return nil, coreerr.E("Worker.handleStartMiner", "invalid start miner payload", err) + if err := message.ParsePayload(&payload); err != nil { + return nil, core.E("Worker.handleStartMiner", "invalid start miner payload", err) } - // Validate miner type is provided if payload.MinerType == "" { - return nil, coreerr.E("Worker.handleStartMiner", "miner type is required", nil) + return nil, core.E("Worker.handleStartMiner", "miner type is required", nil) } - // Get the config from the profile or use the override var config any if payload.Config != nil { config = payload.Config } else if w.profileManager != nil { profile, err := w.profileManager.GetProfile(payload.ProfileID) if err != nil { - return nil, coreerr.E("Worker.handleStartMiner", "profile not found: "+payload.ProfileID, nil) + return nil, core.E("Worker.handleStartMiner", "profile not found: "+payload.ProfileID, nil) } config = profile } else { - return nil, coreerr.E("Worker.handleStartMiner", "no config provided and no profile manager configured", nil) + return nil, core.E("Worker.handleStartMiner", "no config provided and no profile manager configured", nil) } - // Start the 
miner miner, err := w.minerManager.StartMiner(payload.MinerType, config) if err != nil { ack := MinerAckPayload{ Success: false, Error: err.Error(), } - return msg.Reply(MsgMinerAck, ack) + return message.Reply(MessageMinerAck, ack) } ack := MinerAckPayload{ Success: true, MinerName: miner.GetName(), } - return msg.Reply(MsgMinerAck, ack) + return message.Reply(MessageMinerAck, ack) } -// handleStopMiner stops a running miner. -func (w *Worker) handleStopMiner(msg *Message) (*Message, error) { +func (w *Worker) handleStopMiner(message *Message) (*Message, error) { if w.minerManager == nil { - return nil, ErrMinerManagerNotConfigured + return nil, ErrorMinerManagerNotConfigured } var payload StopMinerPayload - if err := msg.ParsePayload(&payload); err != nil { - return nil, coreerr.E("Worker.handleStopMiner", "invalid stop miner payload", err) + if err := message.ParsePayload(&payload); err != nil { + return nil, core.E("Worker.handleStopMiner", "invalid stop miner payload", err) } err := w.minerManager.StopMiner(payload.MinerName) @@ -261,21 +246,19 @@ func (w *Worker) handleStopMiner(msg *Message) (*Message, error) { ack.Error = err.Error() } - return msg.Reply(MsgMinerAck, ack) + return message.Reply(MessageMinerAck, ack) } -// handleGetLogs returns console logs from a miner. 
-func (w *Worker) handleGetLogs(msg *Message) (*Message, error) { +func (w *Worker) handleLogs(message *Message) (*Message, error) { if w.minerManager == nil { - return nil, ErrMinerManagerNotConfigured + return nil, ErrorMinerManagerNotConfigured } - var payload GetLogsPayload - if err := msg.ParsePayload(&payload); err != nil { - return nil, coreerr.E("Worker.handleGetLogs", "invalid get logs payload", err) + var payload LogsRequestPayload + if err := message.ParsePayload(&payload); err != nil { + return nil, core.E("Worker.handleLogs", "invalid logs payload", err) } - // Validate and limit the Lines parameter to prevent resource exhaustion const maxLogLines = 10000 if payload.Lines <= 0 || payload.Lines > maxLogLines { payload.Lines = maxLogLines @@ -283,7 +266,7 @@ func (w *Worker) handleGetLogs(msg *Message) (*Message, error) { miner, err := w.minerManager.GetMiner(payload.MinerName) if err != nil { - return nil, coreerr.E("Worker.handleGetLogs", "miner not found: "+payload.MinerName, nil) + return nil, core.E("Worker.handleLogs", "miner not found: "+payload.MinerName, nil) } lines := miner.GetConsoleHistory(payload.Lines) @@ -294,14 +277,13 @@ func (w *Worker) handleGetLogs(msg *Message) (*Message, error) { HasMore: len(lines) >= payload.Lines, } - return msg.Reply(MsgLogs, logs) + return message.Reply(MessageLogs, logs) } -// handleDeploy handles deployment of profiles or miner bundles. 
-func (w *Worker) handleDeploy(conn *PeerConnection, msg *Message) (*Message, error) { +func (w *Worker) handleDeploy(peerConnection *PeerConnection, message *Message) (*Message, error) { var payload DeployPayload - if err := msg.ParsePayload(&payload); err != nil { - return nil, coreerr.E("Worker.handleDeploy", "invalid deploy payload", err) + if err := message.ParsePayload(&payload); err != nil { + return nil, core.E("Worker.handleDeploy", "invalid deploy payload", err) } // Reconstruct Bundle object from payload @@ -314,26 +296,24 @@ func (w *Worker) handleDeploy(conn *PeerConnection, msg *Message) (*Message, err // Use shared secret as password (base64 encoded) password := "" - if conn != nil && len(conn.SharedSecret) > 0 { - password = base64.StdEncoding.EncodeToString(conn.SharedSecret) + if peerConnection != nil && len(peerConnection.SharedSecret) > 0 { + password = base64.StdEncoding.EncodeToString(peerConnection.SharedSecret) } switch bundle.Type { case BundleProfile: if w.profileManager == nil { - return nil, coreerr.E("Worker.handleDeploy", "profile manager not configured", nil) + return nil, core.E("Worker.handleDeploy", "profile manager not configured", nil) } - // Decrypt and extract profile data profileData, err := ExtractProfileBundle(bundle, password) if err != nil { - return nil, coreerr.E("Worker.handleDeploy", "failed to extract profile bundle", err) + return nil, core.E("Worker.handleDeploy", "failed to extract profile bundle", err) } - // Unmarshal into interface{} to pass to ProfileManager var profile any - if err := json.Unmarshal(profileData, &profile); err != nil { - return nil, coreerr.E("Worker.handleDeploy", "invalid profile data JSON", err) + if result := core.JSONUnmarshal(profileData, &profile); !result.OK { + return nil, core.E("Worker.handleDeploy", "invalid profile data JSON", result.Value.(error)) } if err := w.profileManager.SaveProfile(profile); err != nil { @@ -342,20 +322,18 @@ func (w *Worker) handleDeploy(conn 
*PeerConnection, msg *Message) (*Message, err Name: payload.Name, Error: err.Error(), } - return msg.Reply(MsgDeployAck, ack) + return message.Reply(MessageDeployAck, ack) } ack := DeployAckPayload{ Success: true, Name: payload.Name, } - return msg.Reply(MsgDeployAck, ack) + return message.Reply(MessageDeployAck, ack) case BundleMiner, BundleFull: - // Determine installation directory - // We use w.DataDir/lethean-desktop/miners/ - minersDir := filepath.Join(w.DataDir, "lethean-desktop", "miners") - installDir := filepath.Join(minersDir, payload.Name) + minersDir := core.JoinPath(w.deploymentDirectory(), "lethean-desktop", "miners") + installDir := core.JoinPath(minersDir, payload.Name) logging.Info("deploying miner bundle", logging.Fields{ "name": payload.Name, @@ -363,17 +341,15 @@ func (w *Worker) handleDeploy(conn *PeerConnection, msg *Message) (*Message, err "type": payload.BundleType, }) - // Extract miner bundle minerPath, profileData, err := ExtractMinerBundle(bundle, password, installDir) if err != nil { - return nil, coreerr.E("Worker.handleDeploy", "failed to extract miner bundle", err) + return nil, core.E("Worker.handleDeploy", "failed to extract miner bundle", err) } - // If the bundle contained a profile config, save it if len(profileData) > 0 && w.profileManager != nil { var profile any - if err := json.Unmarshal(profileData, &profile); err != nil { - logging.Warn("failed to parse profile from miner bundle", logging.Fields{"error": err}) + if result := core.JSONUnmarshal(profileData, &profile); !result.OK { + logging.Warn("failed to parse profile from miner bundle", logging.Fields{"error": result.Value.(error)}) } else { if err := w.profileManager.SaveProfile(profile); err != nil { logging.Warn("failed to save profile from miner bundle", logging.Fields{"error": err}) @@ -381,26 +357,30 @@ func (w *Worker) handleDeploy(conn *PeerConnection, msg *Message) (*Message, err } } - // Success response ack := DeployAckPayload{ Success: true, Name: 
payload.Name, } - // Log the installation logging.Info("miner bundle installed successfully", logging.Fields{ "name": payload.Name, "miner_path": minerPath, }) - return msg.Reply(MsgDeployAck, ack) + return message.Reply(MessageDeployAck, ack) default: - return nil, coreerr.E("Worker.handleDeploy", "unknown bundle type: "+payload.BundleType, nil) + return nil, core.E("Worker.handleDeploy", "unknown bundle type: "+payload.BundleType, nil) } } -// RegisterWithTransport registers the worker's message handler with the transport. -func (w *Worker) RegisterWithTransport() { +func (w *Worker) RegisterOnTransport() { w.transport.OnMessage(w.HandleMessage) } + +func (w *Worker) deploymentDirectory() string { + if w.DeploymentDirectory != "" { + return w.DeploymentDirectory + } + return xdg.DataHome +} diff --git a/node/worker_test.go b/node/worker_test.go index ee3ed31..5915b33 100644 --- a/node/worker_test.go +++ b/node/worker_test.go @@ -2,34 +2,25 @@ package node import ( "encoding/base64" - "encoding/json" - "fmt" - "os" - "path/filepath" "testing" "time" + + core "dappco.re/go/core" ) -// setupTestEnv sets up a temporary environment for testing and returns cleanup function -func setupTestEnv(t *testing.T) func() { +func setupTestEnvironment(t *testing.T) func() { tmpDir := t.TempDir() - os.Setenv("XDG_CONFIG_HOME", filepath.Join(tmpDir, "config")) - os.Setenv("XDG_DATA_HOME", filepath.Join(tmpDir, "data")) - return func() { - os.Unsetenv("XDG_CONFIG_HOME") - os.Unsetenv("XDG_DATA_HOME") - } + t.Setenv("XDG_CONFIG_HOME", testJoinPath(tmpDir, "config")) + t.Setenv("XDG_DATA_HOME", testJoinPath(tmpDir, "data")) + return func() {} } -func TestNewWorker(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_NewWorker_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := 
NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -37,35 +28,32 @@ func TestNewWorker(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() if worker == nil { t.Fatal("NewWorker returned nil") } - if worker.node != nm { - t.Error("worker.node not set correctly") + if worker.nodeManager != nm { + t.Error("worker.nodeManager not set correctly") } if worker.transport != transport { t.Error("worker.transport not set correctly") } } -func TestWorker_SetMinerManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_SetMinerManager_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -73,14 +61,14 @@ func TestWorker_SetMinerManager(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() mockManager := &mockMinerManager{} worker.SetMinerManager(mockManager) @@ -90,15 +78,12 @@ func TestWorker_SetMinerManager(t *testing.T) { } } 
-func TestWorker_SetProfileManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_SetProfileManager_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -106,14 +91,14 @@ func TestWorker_SetProfileManager(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() mockProfile := &mockProfileManager{} worker.SetProfileManager(mockProfile) @@ -123,15 +108,12 @@ func TestWorker_SetProfileManager(t *testing.T) { } } -func TestWorker_HandlePing(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandlePing_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -139,14 +121,14 @@ func TestWorker_HandlePing(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = 
t.TempDir() + worker.DeploymentDirectory = t.TempDir() // Create a ping message identity := nm.GetIdentity() @@ -154,7 +136,7 @@ func TestWorker_HandlePing(t *testing.T) { t.Fatal("expected identity to be generated") } pingPayload := PingPayload{SentAt: time.Now().UnixMilli()} - pingMsg, err := NewMessage(MsgPing, "sender-id", identity.ID, pingPayload) + pingMsg, err := NewMessage(MessagePing, "sender-id", identity.ID, pingPayload) if err != nil { t.Fatalf("failed to create ping message: %v", err) } @@ -169,8 +151,8 @@ func TestWorker_HandlePing(t *testing.T) { t.Fatal("handlePing returned nil response") } - if response.Type != MsgPong { - t.Errorf("expected response type %s, got %s", MsgPong, response.Type) + if response.Type != MessagePong { + t.Errorf("expected response type %s, got %s", MessagePong, response.Type) } var pong PongPayload @@ -187,15 +169,12 @@ func TestWorker_HandlePing(t *testing.T) { } } -func TestWorker_HandleGetStats(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleStats_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -203,37 +182,37 @@ func TestWorker_HandleGetStats(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() - // Create a get_stats message + // Create a stats request message. 
identity := nm.GetIdentity() if identity == nil { t.Fatal("expected identity to be generated") } - msg, err := NewMessage(MsgGetStats, "sender-id", identity.ID, nil) + msg, err := NewMessage(MessageGetStats, "sender-id", identity.ID, nil) if err != nil { - t.Fatalf("failed to create get_stats message: %v", err) + t.Fatalf("failed to create stats request message: %v", err) } - // Call handleGetStats directly (without miner manager) - response, err := worker.handleGetStats(msg) + // Call handleStats directly (without miner manager). + response, err := worker.handleStats(msg) if err != nil { - t.Fatalf("handleGetStats returned error: %v", err) + t.Fatalf("handleStats returned error: %v", err) } if response == nil { - t.Fatal("handleGetStats returned nil response") + t.Fatal("handleStats returned nil response") } - if response.Type != MsgStats { - t.Errorf("expected response type %s, got %s", MsgStats, response.Type) + if response.Type != MessageStats { + t.Errorf("expected response type %s, got %s", MessageStats, response.Type) } var stats StatsPayload @@ -250,15 +229,12 @@ func TestWorker_HandleGetStats(t *testing.T) { } } -func TestWorker_HandleStartMiner_NoManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleStartMiner_NoManager_Bad(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -266,14 +242,14 @@ func TestWorker_HandleStartMiner_NoManager(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, 
DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() // Create a start_miner message identity := nm.GetIdentity() @@ -281,7 +257,7 @@ func TestWorker_HandleStartMiner_NoManager(t *testing.T) { t.Fatal("expected identity to be generated") } payload := StartMinerPayload{MinerType: "xmrig", ProfileID: "test-profile"} - msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload) + msg, err := NewMessage(MessageStartMiner, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create start_miner message: %v", err) } @@ -293,15 +269,12 @@ func TestWorker_HandleStartMiner_NoManager(t *testing.T) { } } -func TestWorker_HandleStopMiner_NoManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleStopMiner_NoManager_Bad(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -309,14 +282,14 @@ func TestWorker_HandleStopMiner_NoManager(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() // Create a stop_miner message identity := nm.GetIdentity() @@ -324,7 +297,7 @@ func TestWorker_HandleStopMiner_NoManager(t *testing.T) { t.Fatal("expected identity to be generated") } payload := StopMinerPayload{MinerName: "test-miner"} - msg, err := NewMessage(MsgStopMiner, "sender-id", identity.ID, 
payload) + msg, err := NewMessage(MessageStopMiner, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create stop_miner message: %v", err) } @@ -336,15 +309,12 @@ func TestWorker_HandleStopMiner_NoManager(t *testing.T) { } } -func TestWorker_HandleGetLogs_NoManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleLogs_NoManager_Bad(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -352,42 +322,39 @@ func TestWorker_HandleGetLogs_NoManager(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() - // Create a get_logs message + // Create a logs request message. 
identity := nm.GetIdentity() if identity == nil { t.Fatal("expected identity to be generated") } - payload := GetLogsPayload{MinerName: "test-miner", Lines: 100} - msg, err := NewMessage(MsgGetLogs, "sender-id", identity.ID, payload) + payload := LogsRequestPayload{MinerName: "test-miner", Lines: 100} + msg, err := NewMessage(MessageGetLogs, "sender-id", identity.ID, payload) if err != nil { - t.Fatalf("failed to create get_logs message: %v", err) + t.Fatalf("failed to create logs request message: %v", err) } // Without miner manager, should return error - _, err = worker.handleGetLogs(msg) + _, err = worker.handleLogs(msg) if err == nil { t.Error("expected error when miner manager is nil") } } -func TestWorker_HandleDeploy_Profile(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleDeploy_Profile_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -395,14 +362,14 @@ func TestWorker_HandleDeploy_Profile(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() // Create a deploy message for profile identity := nm.GetIdentity() @@ -414,7 +381,7 @@ func TestWorker_HandleDeploy_Profile(t *testing.T) { Data: []byte(`{"id": "test", "name": "Test Profile"}`), Name: "test-profile", } - msg, err := NewMessage(MsgDeploy, "sender-id", identity.ID, payload) + msg, err := 
NewMessage(MessageDeploy, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create deploy message: %v", err) } @@ -426,15 +393,12 @@ func TestWorker_HandleDeploy_Profile(t *testing.T) { } } -func TestWorker_HandleDeploy_UnknownType(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleDeploy_UnknownType_Bad(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -442,14 +406,14 @@ func TestWorker_HandleDeploy_UnknownType(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() // Create a deploy message with unknown type identity := nm.GetIdentity() @@ -461,7 +425,7 @@ func TestWorker_HandleDeploy_UnknownType(t *testing.T) { Data: []byte(`{}`), Name: "test", } - msg, err := NewMessage(MsgDeploy, "sender-id", identity.ID, payload) + msg, err := NewMessage(MessageDeploy, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create deploy message: %v", err) } @@ -472,7 +436,7 @@ func TestWorker_HandleDeploy_UnknownType(t *testing.T) { } } -func TestConvertMinerStats(t *testing.T) { +func TestWorker_ConvertMinerStats_Good(t *testing.T) { tests := []struct { name string rawStats any @@ -544,6 +508,10 @@ func (m *mockMinerManager) GetMiner(name string) (MinerInstance, error) { return nil, nil } +func (m *mockMinerManager) Miner(name string) (MinerInstance, 
error) { + return m.GetMiner(name) +} + type mockMinerInstance struct { name string minerType string @@ -554,6 +522,10 @@ func (m *mockMinerInstance) GetName() string { return m.nam func (m *mockMinerInstance) GetType() string { return m.minerType } func (m *mockMinerInstance) GetStats() (any, error) { return m.stats, nil } func (m *mockMinerInstance) GetConsoleHistory(lines int) []string { return []string{} } +func (m *mockMinerInstance) Name() string { return m.GetName() } +func (m *mockMinerInstance) Type() string { return m.GetType() } +func (m *mockMinerInstance) Stats() (any, error) { return m.GetStats() } +func (m *mockMinerInstance) ConsoleHistory(lines int) []string { return m.GetConsoleHistory(lines) } type mockProfileManager struct{} @@ -561,6 +533,10 @@ func (m *mockProfileManager) GetProfile(id string) (any, error) { return nil, nil } +func (m *mockProfileManager) Profile(id string) (any, error) { + return m.GetProfile(id) +} + func (m *mockProfileManager) SaveProfile(profile any) error { return nil } @@ -573,15 +549,19 @@ type mockMinerManagerFailing struct { } func (m *mockMinerManagerFailing) StartMiner(minerType string, config any) (MinerInstance, error) { - return nil, fmt.Errorf("mining hardware not available") + return nil, core.E("mockMinerManagerFailing.StartMiner", "mining hardware not available", nil) } func (m *mockMinerManagerFailing) StopMiner(name string) error { - return fmt.Errorf("miner %s not found", name) + return core.E("mockMinerManagerFailing.StopMiner", "miner "+name+" not found", nil) } func (m *mockMinerManagerFailing) GetMiner(name string) (MinerInstance, error) { - return nil, fmt.Errorf("miner %s not found", name) + return nil, core.E("mockMinerManagerFailing.Miner", "miner "+name+" not found", nil) +} + +func (m *mockMinerManagerFailing) Miner(name string) (MinerInstance, error) { + return m.GetMiner(name) } // mockProfileManagerFull implements ProfileManager that returns real data. 
@@ -592,7 +572,7 @@ type mockProfileManagerFull struct { func (m *mockProfileManagerFull) GetProfile(id string) (any, error) { p, ok := m.profiles[id] if !ok { - return nil, fmt.Errorf("profile %s not found", id) + return nil, core.E("mockProfileManagerFull.GetProfile", "profile "+id+" not found", nil) } return p, nil } @@ -601,26 +581,31 @@ func (m *mockProfileManagerFull) SaveProfile(profile any) error { return nil } +func (m *mockProfileManagerFull) Profile(id string) (any, error) { + return m.GetProfile(id) +} + // mockProfileManagerFailing always returns errors. type mockProfileManagerFailing struct{} func (m *mockProfileManagerFailing) GetProfile(id string) (any, error) { - return nil, fmt.Errorf("profile %s not found", id) + return nil, core.E("mockProfileManagerFailing.GetProfile", "profile "+id+" not found", nil) } func (m *mockProfileManagerFailing) SaveProfile(profile any) error { - return fmt.Errorf("save failed") + return core.E("mockProfileManagerFailing.SaveProfile", "save failed", nil) } -func TestWorker_HandleStartMiner_WithManager(t *testing.T) { - cleanup := setupTestEnv(t) +func (m *mockProfileManagerFailing) Profile(id string) (any, error) { + return m.GetProfile(id) +} + +func TestWorker_HandleStartMiner_WithManager_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } @@ -628,14 +613,14 @@ func TestWorker_HandleStartMiner_WithManager(t *testing.T) { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, 
DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() mm := &mockMinerManager{ miners: []MinerInstance{}, @@ -646,12 +631,12 @@ func TestWorker_HandleStartMiner_WithManager(t *testing.T) { identity := nm.GetIdentity() - t.Run("WithConfigOverride", func(t *testing.T) { + t.Run("ConfigOverride", func(t *testing.T) { payload := StartMinerPayload{ MinerType: "xmrig", - Config: json.RawMessage(`{"pool":"test:3333"}`), + Config: RawMessage(`{"pool":"test:3333"}`), } - msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload) + msg, err := NewMessage(MessageStartMiner, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create message: %v", err) } @@ -661,8 +646,8 @@ func TestWorker_HandleStartMiner_WithManager(t *testing.T) { t.Fatalf("handleStartMiner returned error: %v", err) } - if response.Type != MsgMinerAck { - t.Errorf("expected type %s, got %s", MsgMinerAck, response.Type) + if response.Type != MessageMinerAck { + t.Errorf("expected type %s, got %s", MessageMinerAck, response.Type) } var ack MinerAckPayload @@ -680,9 +665,9 @@ func TestWorker_HandleStartMiner_WithManager(t *testing.T) { t.Run("EmptyMinerType", func(t *testing.T) { payload := StartMinerPayload{ MinerType: "", - Config: json.RawMessage(`{}`), + Config: RawMessage(`{}`), } - msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload) + msg, err := NewMessage(MessageStartMiner, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create message: %v", err) } @@ -693,7 +678,7 @@ func TestWorker_HandleStartMiner_WithManager(t *testing.T) { } }) - t.Run("WithProfileManager", func(t *testing.T) { + t.Run("ProfileManagerConfigured", func(t *testing.T) { pm := &mockProfileManagerFull{ profiles: map[string]any{ "test-profile": map[string]any{"pool": "pool.test:3333"}, @@ -705,7 +690,7 @@ func TestWorker_HandleStartMiner_WithManager(t *testing.T) { MinerType: 
"xmrig", ProfileID: "test-profile", } - msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload) + msg, err := NewMessage(MessageStartMiner, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create message: %v", err) } @@ -730,7 +715,7 @@ func TestWorker_HandleStartMiner_WithManager(t *testing.T) { MinerType: "xmrig", ProfileID: "missing-profile", } - msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload) + msg, err := NewMessage(MessageStartMiner, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create message: %v", err) } @@ -747,9 +732,9 @@ func TestWorker_HandleStartMiner_WithManager(t *testing.T) { payload := StartMinerPayload{ MinerType: "xmrig", - Config: json.RawMessage(`{}`), + Config: RawMessage(`{}`), } - msg, err := NewMessage(MsgStartMiner, "sender-id", identity.ID, payload) + msg, err := NewMessage(MessageStartMiner, "sender-id", identity.ID, payload) if err != nil { t.Fatalf("failed to create message: %v", err) } @@ -780,39 +765,36 @@ type mockMinerManagerWithStart struct { func (m *mockMinerManagerWithStart) StartMiner(minerType string, config any) (MinerInstance, error) { m.counter++ - name := fmt.Sprintf("%s-%d", minerType, m.counter) + name := core.Sprintf("%s-%d", minerType, m.counter) return &mockMinerInstance{name: name, minerType: minerType}, nil } -func TestWorker_HandleStopMiner_WithManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleStopMiner_WithManager_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := 
NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() identity := nm.GetIdentity() t.Run("Success", func(t *testing.T) { worker.SetMinerManager(&mockMinerManager{}) payload := StopMinerPayload{MinerName: "test-miner"} - msg, _ := NewMessage(MsgStopMiner, "sender-id", identity.ID, payload) + msg, _ := NewMessage(MessageStopMiner, "sender-id", identity.ID, payload) response, err := worker.handleStopMiner(msg) if err != nil { @@ -833,7 +815,7 @@ func TestWorker_HandleStopMiner_WithManager(t *testing.T) { worker.SetMinerManager(&mockMinerManagerFailing{}) payload := StopMinerPayload{MinerName: "missing-miner"} - msg, _ := NewMessage(MsgStopMiner, "sender-id", identity.ID, payload) + msg, _ := NewMessage(MessageStopMiner, "sender-id", identity.ID, payload) response, err := worker.handleStopMiner(msg) if err != nil { @@ -851,28 +833,25 @@ func TestWorker_HandleStopMiner_WithManager(t *testing.T) { }) } -func TestWorker_HandleGetLogs_WithManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleLogs_WithManager_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to 
create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() identity := nm.GetIdentity() t.Run("Success", func(t *testing.T) { @@ -886,16 +865,16 @@ func TestWorker_HandleGetLogs_WithManager(t *testing.T) { } worker.SetMinerManager(mm) - payload := GetLogsPayload{MinerName: "test-miner", Lines: 100} - msg, _ := NewMessage(MsgGetLogs, "sender-id", identity.ID, payload) + payload := LogsRequestPayload{MinerName: "test-miner", Lines: 100} + msg, _ := NewMessage(MessageGetLogs, "sender-id", identity.ID, payload) - response, err := worker.handleGetLogs(msg) + response, err := worker.handleLogs(msg) if err != nil { - t.Fatalf("handleGetLogs returned error: %v", err) + t.Fatalf("handleLogs returned error: %v", err) } - if response.Type != MsgLogs { - t.Errorf("expected type %s, got %s", MsgLogs, response.Type) + if response.Type != MessageLogs { + t.Errorf("expected type %s, got %s", MessageLogs, response.Type) } var logs LogsPayload @@ -906,14 +885,14 @@ func TestWorker_HandleGetLogs_WithManager(t *testing.T) { }) t.Run("MinerNotFound", func(t *testing.T) { - // Use a manager that returns error for GetMiner + // Use a manager that returns an error when the miner is missing. 
mm := &mockMinerManagerFailing{} worker.SetMinerManager(mm) - payload := GetLogsPayload{MinerName: "non-existent", Lines: 50} - msg, _ := NewMessage(MsgGetLogs, "sender-id", identity.ID, payload) + payload := LogsRequestPayload{MinerName: "non-existent", Lines: 50} + msg, _ := NewMessage(MessageGetLogs, "sender-id", identity.ID, payload) - _, err := worker.handleGetLogs(msg) + _, err := worker.handleLogs(msg) if err == nil { t.Error("expected error for non-existent miner") } @@ -927,16 +906,16 @@ func TestWorker_HandleGetLogs_WithManager(t *testing.T) { } worker.SetMinerManager(mm) - payload := GetLogsPayload{MinerName: "test-miner", Lines: -1} - msg, _ := NewMessage(MsgGetLogs, "sender-id", identity.ID, payload) + payload := LogsRequestPayload{MinerName: "test-miner", Lines: -1} + msg, _ := NewMessage(MessageGetLogs, "sender-id", identity.ID, payload) - response, err := worker.handleGetLogs(msg) + response, err := worker.handleLogs(msg) if err != nil { - t.Fatalf("handleGetLogs returned error: %v", err) + t.Fatalf("handleLogs returned error: %v", err) } // Lines <= 0 should be clamped to maxLogLines - if response.Type != MsgLogs { - t.Errorf("expected %s, got %s", MsgLogs, response.Type) + if response.Type != MessageLogs { + t.Errorf("expected %s, got %s", MessageLogs, response.Type) } }) @@ -948,41 +927,38 @@ func TestWorker_HandleGetLogs_WithManager(t *testing.T) { } worker.SetMinerManager(mm) - payload := GetLogsPayload{MinerName: "test-miner", Lines: 999999} - msg, _ := NewMessage(MsgGetLogs, "sender-id", identity.ID, payload) + payload := LogsRequestPayload{MinerName: "test-miner", Lines: 999999} + msg, _ := NewMessage(MessageGetLogs, "sender-id", identity.ID, payload) - response, err := worker.handleGetLogs(msg) + response, err := worker.handleLogs(msg) if err != nil { - t.Fatalf("handleGetLogs returned error: %v", err) + t.Fatalf("handleLogs returned error: %v", err) } - if response.Type != MsgLogs { - t.Errorf("expected %s, got %s", MsgLogs, response.Type) 
+ if response.Type != MessageLogs { + t.Errorf("expected %s, got %s", MessageLogs, response.Type) } }) } -func TestWorker_HandleGetStats_WithMinerManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleStats_WithMinerManager_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() identity := nm.GetIdentity() // Set miner manager with miners that have real stats @@ -1011,10 +987,10 @@ func TestWorker_HandleGetStats_WithMinerManager(t *testing.T) { } worker.SetMinerManager(mm) - msg, _ := NewMessage(MsgGetStats, "sender-id", identity.ID, nil) - response, err := worker.handleGetStats(msg) + msg, _ := NewMessage(MessageGetStats, "sender-id", identity.ID, nil) + response, err := worker.handleStats(msg) if err != nil { - t.Fatalf("handleGetStats returned error: %v", err) + t.Fatalf("handleStats returned error: %v", err) } var stats StatsPayload @@ -1025,28 +1001,25 @@ func TestWorker_HandleGetStats_WithMinerManager(t *testing.T) { } } -func TestWorker_HandleMessage_UnknownType(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleMessage_UnknownType_Bad(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := 
NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() identity := nm.GetIdentity() msg, _ := NewMessage("unknown_type", "sender-id", identity.ID, nil) @@ -1055,28 +1028,25 @@ func TestWorker_HandleMessage_UnknownType(t *testing.T) { worker.HandleMessage(nil, msg) } -func TestWorker_HandleDeploy_ProfileWithManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleDeploy_ProfileWithManager_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() pm := &mockProfileManagerFull{profiles: 
make(map[string]any)} worker.SetProfileManager(pm) @@ -1096,7 +1066,7 @@ func TestWorker_HandleDeploy_ProfileWithManager(t *testing.T) { Checksum: bundle.Checksum, Name: "deploy-test", } - msg, _ := NewMessage(MsgDeploy, "sender-id", identity.ID, payload) + msg, _ := NewMessage(MessageDeploy, "sender-id", identity.ID, payload) response, err := worker.handleDeploy(nil, msg) if err != nil { @@ -1113,28 +1083,25 @@ func TestWorker_HandleDeploy_ProfileWithManager(t *testing.T) { } } -func TestWorker_HandleDeploy_ProfileSaveFails(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleDeploy_ProfileSaveFails_Bad(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() worker.SetProfileManager(&mockProfileManagerFailing{}) identity := nm.GetIdentity() @@ -1148,7 +1115,7 @@ func TestWorker_HandleDeploy_ProfileSaveFails(t *testing.T) { Checksum: bundle.Checksum, Name: "fail-test", } - msg, _ := NewMessage(MsgDeploy, "sender-id", identity.ID, payload) + msg, _ := NewMessage(MessageDeploy, "sender-id", identity.ID, payload) response, err := worker.handleDeploy(nil, msg) if err != nil { @@ -1162,36 +1129,33 @@ func TestWorker_HandleDeploy_ProfileSaveFails(t *testing.T) { } } -func 
TestWorker_HandleDeploy_MinerBundle(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleDeploy_MinerBundle_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() pm := &mockProfileManagerFull{profiles: make(map[string]any)} worker.SetProfileManager(pm) identity := nm.GetIdentity() tmpDir := t.TempDir() - minerPath := filepath.Join(tmpDir, "test-miner") - os.WriteFile(minerPath, []byte("fake miner binary"), 0755) + minerPath := testJoinPath(tmpDir, "test-miner") + testWriteFile(t, minerPath, []byte("fake miner binary"), 0o755) profileJSON := []byte(`{"pool":"test:3333"}`) @@ -1211,7 +1175,7 @@ func TestWorker_HandleDeploy_MinerBundle(t *testing.T) { Checksum: bundle.Checksum, Name: "deploy-miner", } - msg, _ := NewMessage(MsgDeploy, "sender-id", identity.ID, payload) + msg, _ := NewMessage(MessageDeploy, "sender-id", identity.ID, payload) conn := &PeerConnection{ SharedSecret: sharedSecret, @@ -1229,34 +1193,31 @@ func TestWorker_HandleDeploy_MinerBundle(t *testing.T) { } } -func TestWorker_HandleDeploy_FullBundle(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleDeploy_FullBundle_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir 
:= t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil { t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() identity := nm.GetIdentity() tmpDir := t.TempDir() - minerPath := filepath.Join(tmpDir, "test-miner") - os.WriteFile(minerPath, []byte("miner binary"), 0755) + minerPath := testJoinPath(tmpDir, "test-miner") + testWriteFile(t, minerPath, []byte("miner binary"), 0o755) sharedSecret := []byte("full-secret-key!") bundlePassword := base64.StdEncoding.EncodeToString(sharedSecret) @@ -1272,7 +1233,7 @@ func TestWorker_HandleDeploy_FullBundle(t *testing.T) { Checksum: bundle.Checksum, Name: "full-deploy", } - msg, _ := NewMessage(MsgDeploy, "sender-id", identity.ID, payload) + msg, _ := NewMessage(MessageDeploy, "sender-id", identity.ID, payload) conn := &PeerConnection{SharedSecret: sharedSecret} @@ -1288,28 +1249,25 @@ func TestWorker_HandleDeploy_FullBundle(t *testing.T) { } } -func TestWorker_HandleDeploy_MinerBundle_WithProfileManager(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleDeploy_MinerBundle_ProfileManager_Good(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, err := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, err := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) if err != nil 
{ t.Fatalf("failed to create node manager: %v", err) } if err := nm.GenerateIdentity("test-worker", RoleWorker); err != nil { t.Fatalf("failed to generate identity: %v", err) } - pr, err := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, err := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) if err != nil { t.Fatalf("failed to create peer registry: %v", err) } transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() // Set a failing profile manager to exercise the warn-and-continue path worker.SetProfileManager(&mockProfileManagerFailing{}) @@ -1317,8 +1275,8 @@ func TestWorker_HandleDeploy_MinerBundle_WithProfileManager(t *testing.T) { identity := nm.GetIdentity() tmpDir := t.TempDir() - minerPath := filepath.Join(tmpDir, "test-miner") - os.WriteFile(minerPath, []byte("miner binary"), 0755) + minerPath := testJoinPath(tmpDir, "test-miner") + testWriteFile(t, minerPath, []byte("miner binary"), 0o755) profileJSON := []byte(`{"pool":"test:3333"}`) sharedSecret := []byte("profile-secret!!") @@ -1335,7 +1293,7 @@ func TestWorker_HandleDeploy_MinerBundle_WithProfileManager(t *testing.T) { Checksum: bundle.Checksum, Name: "deploy-with-profile", } - msg, _ := NewMessage(MsgDeploy, "sender-id", identity.ID, payload) + msg, _ := NewMessage(MessageDeploy, "sender-id", identity.ID, payload) conn := &PeerConnection{SharedSecret: sharedSecret} @@ -1352,24 +1310,21 @@ func TestWorker_HandleDeploy_MinerBundle_WithProfileManager(t *testing.T) { } } -func TestWorker_HandleDeploy_InvalidPayload(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleDeploy_InvalidPayload_Bad(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() dir := t.TempDir() - nm, _ := NewNodeManagerWithPaths( - filepath.Join(dir, "private.key"), - filepath.Join(dir, "node.json"), - ) + nm, _ := NewNodeManagerFromPaths(testNodeManagerPaths(dir)) 
nm.GenerateIdentity("test", RoleWorker) - pr, _ := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, _ := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() identity := nm.GetIdentity() // Create a message with invalid payload - msg, _ := NewMessage(MsgDeploy, "sender-id", identity.ID, "invalid-payload-not-struct") + msg, _ := NewMessage(MessageDeploy, "sender-id", identity.ID, "invalid-payload-not-struct") _, err := worker.handleDeploy(nil, msg) if err == nil { @@ -1377,34 +1332,35 @@ func TestWorker_HandleDeploy_InvalidPayload(t *testing.T) { } } -func TestWorker_HandleGetStats_NoIdentity(t *testing.T) { - cleanup := setupTestEnv(t) +func TestWorker_HandleStats_NoIdentity_Bad(t *testing.T) { + cleanup := setupTestEnvironment(t) defer cleanup() - nm, _ := NewNodeManagerWithPaths( - filepath.Join(t.TempDir(), "priv.key"), - filepath.Join(t.TempDir(), "node.json"), + tmpDir := t.TempDir() + nm, _ := NewNodeManagerFromPaths( + testJoinPath(tmpDir, "priv.key"), + testJoinPath(tmpDir, "node.json"), ) // Don't generate identity - pr, _ := NewPeerRegistryWithPath(t.TempDir() + "/peers.json") + pr, _ := NewPeerRegistryFromPath(testJoinPath(t.TempDir(), "peers.json")) transport := NewTransport(nm, pr, DefaultTransportConfig()) worker := NewWorker(nm, transport) - worker.DataDir = t.TempDir() + worker.DeploymentDirectory = t.TempDir() - msg, _ := NewMessage(MsgGetStats, "sender-id", "target-id", nil) - _, err := worker.handleGetStats(msg) + msg, _ := NewMessage(MessageGetStats, "sender-id", "target-id", nil) + _, err := worker.handleStats(msg) if err == nil { t.Error("expected error when identity is not initialized") } } -func TestWorker_HandleMessage_IntegrationViaWebSocket(t *testing.T) { +func TestWorker_HandleMessage_IntegrationViaWebSocket_Good(t *testing.T) { // Test 
HandleMessage through real WebSocket -- exercises error response sending path tp := setupTestTransportPair(t) worker := NewWorker(tp.ServerNode, tp.Server) // No miner manager set -- start_miner will fail and send error response - worker.RegisterWithTransport() + worker.RegisterOnTransport() controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) tp.connectClient(t) @@ -1414,15 +1370,15 @@ func TestWorker_HandleMessage_IntegrationViaWebSocket(t *testing.T) { // Send start_miner which will fail because no manager is set. // The worker should send an error response via the connection. - err := controller.StartRemoteMiner(serverID, "xmrig", "", json.RawMessage(`{}`)) + err := controller.StartRemoteMiner(serverID, "xmrig", "", RawMessage(`{}`)) // Should get an error back (either protocol error or operation failed) if err == nil { t.Error("expected error when worker has no miner manager") } } -func TestWorker_HandleMessage_GetStats_IntegrationViaWebSocket(t *testing.T) { - // HandleMessage dispatch for get_stats through real WebSocket +func TestWorker_HandleMessage_Stats_IntegrationViaWebSocket_Good(t *testing.T) { + // HandleMessage dispatch for stats through real WebSocket. 
tp := setupTestTransportPair(t) worker := NewWorker(tp.ServerNode, tp.Server) @@ -1443,7 +1399,7 @@ func TestWorker_HandleMessage_GetStats_IntegrationViaWebSocket(t *testing.T) { }, } worker.SetMinerManager(mm) - worker.RegisterWithTransport() + worker.RegisterOnTransport() controller := NewController(tp.ClientNode, tp.ClientReg, tp.Client) tp.connectClient(t) @@ -1453,7 +1409,7 @@ func TestWorker_HandleMessage_GetStats_IntegrationViaWebSocket(t *testing.T) { stats, err := controller.GetRemoteStats(serverID) if err != nil { - t.Fatalf("GetRemoteStats failed: %v", err) + t.Fatalf("RemoteStats failed: %v", err) } if len(stats.Miners) != 1 { t.Errorf("expected 1 miner, got %d", len(stats.Miners)) diff --git a/specs/logging.md b/specs/logging.md new file mode 100644 index 0000000..b3ba1ef --- /dev/null +++ b/specs/logging.md @@ -0,0 +1,80 @@ +# logging + +**Import:** `dappco.re/go/core/p2p/logging` + +**Files:** 1 + +## Types + +### `Level` +`type Level int` + +Log severity used by `Logger`. `String` renders the level name in upper case, and `ParseLevel` accepts `debug`, `info`, `warn` or `warning`, and `error`. + +### `Config` +```go +type Config struct { + Output io.Writer + Level Level + Component string +} +``` + +Configuration passed to `New`. + +- `Output`: destination for log lines. `New` falls back to stderr when this is `nil`. +- `Level`: minimum severity that will be emitted. +- `Component`: optional component label added to each line. + +### `Fields` +`type Fields map[string]any` + +Structured key/value fields passed to logging calls. When multiple `Fields` values are supplied, they are merged from left to right, so later maps override earlier keys. + +### `Logger` +`type Logger struct { /* unexported fields */ }` + +Structured logger with configurable output, severity filtering, and component scoping. Log writes are serialised by a mutex and are formatted as timestamped single-line records. 
+ +## Functions + +### Top-level + +| Name | Signature | Description | +| --- | --- | --- | +| `DefaultConfig` | `func DefaultConfig() Config` | Returns the default configuration: stderr output, `LevelInfo`, and no component label. | +| `New` | `func New(config Config) *Logger` | Creates a `Logger` from `config`, substituting the default stderr writer when `config.Output` is `nil`. | +| `SetGlobal` | `func SetGlobal(l *Logger)` | Replaces the package-level global logger instance. | +| `GetGlobal` | `func GetGlobal() *Logger` | Returns the current package-level global logger. | +| `SetGlobalLevel` | `func SetGlobalLevel(level Level)` | Updates the minimum severity on the current global logger. | +| `Debug` | `func Debug(msg string, fields ...Fields)` | Logs a debug message through the global logger. | +| `Info` | `func Info(msg string, fields ...Fields)` | Logs an informational message through the global logger. | +| `Warn` | `func Warn(msg string, fields ...Fields)` | Logs a warning message through the global logger. | +| `Error` | `func Error(msg string, fields ...Fields)` | Logs an error message through the global logger. | +| `Debugf` | `func Debugf(format string, args ...any)` | Formats and logs a debug message through the global logger. | +| `Infof` | `func Infof(format string, args ...any)` | Formats and logs an informational message through the global logger. | +| `Warnf` | `func Warnf(format string, args ...any)` | Formats and logs a warning message through the global logger. | +| `Errorf` | `func Errorf(format string, args ...any)` | Formats and logs an error message through the global logger. | +| `ParseLevel` | `func ParseLevel(s string) (Level, error)` | Parses a text level into `Level`. Unknown strings return `LevelInfo` plus an error. | + +### `Level` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `String` | `func (l Level) String() string` | Returns `DEBUG`, `INFO`, `WARN`, `ERROR`, or `UNKNOWN` for out-of-range values. 
| + +### `*Logger` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `ComponentLogger` | `func (l *Logger) ComponentLogger(component string) *Logger` | Returns a new logger scoped to `component`. | +| `SetLevel` | `func (l *Logger) SetLevel(level Level)` | Sets the minimum severity that the logger will emit. | +| `GetLevel` | `func (l *Logger) GetLevel() Level` | Returns the current minimum severity. | +| `Debug` | `func (l *Logger) Debug(msg string, fields ...Fields)` | Logs `msg` at debug level after merging any supplied field maps. | +| `Info` | `func (l *Logger) Info(msg string, fields ...Fields)` | Logs `msg` at info level after merging any supplied field maps. | +| `Warn` | `func (l *Logger) Warn(msg string, fields ...Fields)` | Logs `msg` at warning level after merging any supplied field maps. | +| `Error` | `func (l *Logger) Error(msg string, fields ...Fields)` | Logs `msg` at error level after merging any supplied field maps. | +| `Debugf` | `func (l *Logger) Debugf(format string, args ...any)` | Formats and logs a debug message. | +| `Infof` | `func (l *Logger) Infof(format string, args ...any)` | Formats and logs an informational message. | +| `Warnf` | `func (l *Logger) Warnf(format string, args ...any)` | Formats and logs a warning message. | +| `Errorf` | `func (l *Logger) Errorf(format string, args ...any)` | Formats and logs an error message. | diff --git a/specs/node-levin.md b/specs/node-levin.md new file mode 100644 index 0000000..ca6b93c --- /dev/null +++ b/specs/node-levin.md @@ -0,0 +1,117 @@ +# levin + +**Import:** `dappco.re/go/core/p2p/node/levin` + +**Files:** 4 + +## Types + +### `Connection` +```go +type Connection struct { + MaxPayloadSize uint64 + ReadTimeout time.Duration + WriteTimeout time.Duration +} +``` + +Wrapper around `net.Conn` that reads and writes framed Levin packets. + +- `MaxPayloadSize`: per-connection payload ceiling enforced by `ReadPacket`. 
`NewConnection` starts with the package `MaxPayloadSize` default. +- `ReadTimeout`: deadline applied before each `ReadPacket` call. `NewConnection` sets this to `DefaultReadTimeout`. +- `WriteTimeout`: deadline applied before each write. `NewConnection` sets this to `DefaultWriteTimeout`. + +### `Header` +```go +type Header struct { + Signature uint64 + PayloadSize uint64 + ExpectResponse bool + Command uint32 + ReturnCode int32 + Flags uint32 + ProtocolVersion uint32 +} +``` + +Packed 33-byte Levin frame header. `EncodeHeader` writes these fields little-endian, and `DecodeHeader` validates the `Signature` and package-level `MaxPayloadSize`. + +### `Section` +`type Section map[string]Value` + +Portable-storage object used by the Levin encoder and decoder. `EncodeStorage` sorts keys alphabetically for deterministic output. + +### `Value` +```go +type Value struct { + Type uint8 +} +``` + +Tagged portable-storage value. The exported `Type` field identifies which internal scalar or array slot is populated; constructors such as `Uint64Value`, `StringValue`, and `ObjectArrayValue` create correctly-typed instances. + +## Functions + +### Top-level framing and storage functions + +| Name | Signature | Description | +| --- | --- | --- | +| `NewConnection` | `func NewConnection(conn net.Conn) *Connection` | Wraps `conn` with Levin defaults: 100 MB payload limit, 120 s read timeout, and 30 s write timeout. | +| `EncodeHeader` | `func EncodeHeader(h *Header) [HeaderSize]byte` | Serialises `h` into the fixed 33-byte Levin header format. | +| `DecodeHeader` | `func DecodeHeader(buf [HeaderSize]byte) (Header, error)` | Parses a 33-byte header, rejecting bad magic signatures and payload sizes above the package-level limit. | +| `PackVarint` | `func PackVarint(v uint64) []byte` | Encodes `v` using the epee portable-storage varint scheme where the low two bits of the first byte encode the width. 
| +| `UnpackVarint` | `func UnpackVarint(buf []byte) (value uint64, bytesConsumed int, err error)` | Decodes one portable-storage varint and returns the value, consumed width, and any truncation or overflow error. | +| `EncodeStorage` | `func EncodeStorage(s Section) ([]byte, error)` | Serialises a `Section` into portable-storage binary form, including the 9-byte storage header. | +| `DecodeStorage` | `func DecodeStorage(data []byte) (Section, error)` | Deserialises portable-storage binary data, validates the storage signatures and version, and reconstructs a `Section`. | + +### `Value` constructors + +| Name | Signature | Description | +| --- | --- | --- | +| `Uint64Value` | `func Uint64Value(v uint64) Value` | Creates a scalar `Value` with `TypeUint64`. | +| `Uint32Value` | `func Uint32Value(v uint32) Value` | Creates a scalar `Value` with `TypeUint32`. | +| `Uint16Value` | `func Uint16Value(v uint16) Value` | Creates a scalar `Value` with `TypeUint16`. | +| `Uint8Value` | `func Uint8Value(v uint8) Value` | Creates a scalar `Value` with `TypeUint8`. | +| `Int64Value` | `func Int64Value(v int64) Value` | Creates a scalar `Value` with `TypeInt64`. | +| `Int32Value` | `func Int32Value(v int32) Value` | Creates a scalar `Value` with `TypeInt32`. | +| `Int16Value` | `func Int16Value(v int16) Value` | Creates a scalar `Value` with `TypeInt16`. | +| `Int8Value` | `func Int8Value(v int8) Value` | Creates a scalar `Value` with `TypeInt8`. | +| `BoolValue` | `func BoolValue(v bool) Value` | Creates a scalar `Value` with `TypeBool`. | +| `DoubleValue` | `func DoubleValue(v float64) Value` | Creates a scalar `Value` with `TypeDouble`. | +| `StringValue` | `func StringValue(v []byte) Value` | Creates a scalar `Value` with `TypeString`. The byte slice is stored without copying. | +| `ObjectValue` | `func ObjectValue(s Section) Value` | Creates a scalar `Value` with `TypeObject` that wraps a nested `Section`. 
| `Uint64ArrayValue` | `func Uint64ArrayValue(vs []uint64) Value` | Creates an array `Value` tagged as `ArrayFlag \| TypeUint64`. |
+| `Uint32ArrayValue` | `func Uint32ArrayValue(vs []uint32) Value` | Creates an array `Value` tagged as `ArrayFlag \| TypeUint32`. |
+| `StringArrayValue` | `func StringArrayValue(vs [][]byte) Value` | Creates an array `Value` tagged as `ArrayFlag \| TypeString`. |
+| `ObjectArrayValue` | `func ObjectArrayValue(vs []Section) Value` | Creates an array `Value` tagged as `ArrayFlag \| TypeObject`. |
+
+### `*Connection` methods
+
+| Name | Signature | Description |
+| --- | --- | --- |
+| `WritePacket` | `func (c *Connection) WritePacket(cmd uint32, payload []byte, expectResponse bool) error` | Sends a Levin request or notification with `FlagRequest`, `ReturnOK`, and the current protocol version. Header and payload writes are serialised by an internal mutex. |
+| `WriteResponse` | `func (c *Connection) WriteResponse(cmd uint32, payload []byte, returnCode int32) error` | Sends a Levin response with `FlagResponse` and the supplied return code. |
+| `ReadPacket` | `func (c *Connection) ReadPacket() (Header, []byte, error)` | Applies the read deadline, reads exactly one header and payload, validates the frame, and enforces the connection-specific `MaxPayloadSize`. Empty payloads are returned as `nil` without allocation. |
+| `Close` | `func (c *Connection) Close() error` | Closes the wrapped network connection. |
+| `RemoteAddr` | `func (c *Connection) RemoteAddr() string` | Returns the wrapped connection's remote address string. |
+
+### `Value` methods
+
+| Name | Signature | Description |
+| --- | --- | --- |
+| `AsUint64` | `func (v Value) AsUint64() (uint64, error)` | Returns the scalar `uint64` value or `ErrStorageTypeMismatch`. |
+| `AsUint32` | `func (v Value) AsUint32() (uint32, error)` | Returns the scalar `uint32` value or `ErrStorageTypeMismatch`. 
| +| `AsUint16` | `func (v Value) AsUint16() (uint16, error)` | Returns the scalar `uint16` value or `ErrStorageTypeMismatch`. | +| `AsUint8` | `func (v Value) AsUint8() (uint8, error)` | Returns the scalar `uint8` value or `ErrStorageTypeMismatch`. | +| `AsInt64` | `func (v Value) AsInt64() (int64, error)` | Returns the scalar `int64` value or `ErrStorageTypeMismatch`. | +| `AsInt32` | `func (v Value) AsInt32() (int32, error)` | Returns the scalar `int32` value or `ErrStorageTypeMismatch`. | +| `AsInt16` | `func (v Value) AsInt16() (int16, error)` | Returns the scalar `int16` value or `ErrStorageTypeMismatch`. | +| `AsInt8` | `func (v Value) AsInt8() (int8, error)` | Returns the scalar `int8` value or `ErrStorageTypeMismatch`. | +| `AsBool` | `func (v Value) AsBool() (bool, error)` | Returns the scalar `bool` value or `ErrStorageTypeMismatch`. | +| `AsDouble` | `func (v Value) AsDouble() (float64, error)` | Returns the scalar `float64` value or `ErrStorageTypeMismatch`. | +| `AsString` | `func (v Value) AsString() ([]byte, error)` | Returns the scalar byte-string or `ErrStorageTypeMismatch`. | +| `AsSection` | `func (v Value) AsSection() (Section, error)` | Returns the nested `Section` or `ErrStorageTypeMismatch`. | +| `AsUint64Array` | `func (v Value) AsUint64Array() ([]uint64, error)` | Returns the `[]uint64` array or `ErrStorageTypeMismatch`. | +| `AsUint32Array` | `func (v Value) AsUint32Array() ([]uint32, error)` | Returns the `[]uint32` array or `ErrStorageTypeMismatch`. | +| `AsStringArray` | `func (v Value) AsStringArray() ([][]byte, error)` | Returns the `[][]byte` array or `ErrStorageTypeMismatch`. | +| `AsSectionArray` | `func (v Value) AsSectionArray() ([]Section, error)` | Returns the `[]Section` array or `ErrStorageTypeMismatch`. 
| diff --git a/specs/node.md b/specs/node.md new file mode 100644 index 0000000..d00f942 --- /dev/null +++ b/specs/node.md @@ -0,0 +1,237 @@ +# node + +**Import:** `dappco.re/go/core/p2p/node` + +**Files:** 12 + +## Types + +### Core types + +| Type | Definition | Description | +| --- | --- | --- | +| `BundleType` | `type BundleType string` | Deployment bundle kind used by `Bundle` and `BundleManifest`. | +| `Bundle` | `struct{ Type BundleType; Name string; Data []byte; Checksum string }` | Transferable deployment bundle. `Data` contains STIM-encrypted bytes or raw JSON, and `Checksum` is the SHA-256 hex digest of `Data`. | +| `BundleManifest` | `struct{ Type BundleType; Name string; Version string; MinerType string; ProfileIDs []string; CreatedAt string }` | Metadata describing the logical contents of a bundle payload. | +| `Controller` | `struct{ /* unexported fields */ }` | High-level controller client for remote peer operations. It keeps a pending-response map keyed by request ID and registers its internal response handler with the transport in `NewController`. | +| `Dispatcher` | `struct{ /* unexported fields */ }` | Concurrent-safe UEPS router. It applies the threat-score circuit breaker before dispatching to a handler map keyed by `IntentID`. | +| `IntentHandler` | `type IntentHandler func(pkt *ueps.ParsedPacket) error` | Callback signature used by `Dispatcher` for verified UEPS packets. | +| `Message` | `struct{ ID string; Type MessageType; From string; To string; Timestamp time.Time; Payload RawMessage; ReplyTo string }` | Generic P2P message envelope. `Payload` stores raw JSON, and `ReplyTo` links responses back to the originating request. | +| `MessageDeduplicator` | `struct{ /* unexported fields */ }` | TTL cache of recently seen message IDs used to suppress duplicates. | +| `MessageHandler` | `type MessageHandler func(conn *PeerConnection, msg *Message)` | Callback signature for decrypted inbound transport messages. 
| +| `MessageType` | `type MessageType string` | String message discriminator stored in `Message.Type`. | +| `NodeIdentity` | `struct{ ID string; Name string; PublicKey string; CreatedAt time.Time; Role NodeRole }` | Public node identity. `ID` is derived from the first 16 bytes of the SHA-256 hash of the public key. | +| `NodeManager` | `struct{ /* unexported fields */ }` | Identity and key manager that loads, generates, persists, and deletes X25519 node credentials. | +| `NodeRole` | `type NodeRole string` | Operational mode string for controller, worker, or dual-role nodes. | +| `Peer` | `struct{ ID string; Name string; PublicKey string; Address string; Role NodeRole; AddedAt time.Time; LastSeen time.Time; PingMS float64; Hops int; GeoKM float64; Score float64; Connected bool }` | Registry record for a remote node, including addressing, role, scoring metrics, and transient connection state. | +| `PeerAuthMode` | `type PeerAuthMode int` | Peer admission policy used by `PeerRegistry` when unknown peers attempt to connect. | +| `PeerConnection` | `struct{ Peer *Peer; WebSocketConnection *websocket.Conn; SharedSecret []byte; LastActivity time.Time }` | Active WebSocket session to a peer, including the negotiated shared secret and transport-owned write/close coordination. | +| `PeerRateLimiter` | `struct{ /* unexported fields */ }` | Per-peer token bucket limiter used by the transport hot path. | +| `PeerRegistry` | `struct{ /* unexported fields */ }` | Concurrent peer store with KD-tree selection, allowlist state, and debounced persistence to disk. | +| `ProtocolError` | `struct{ Code int; Message string }` | Structured remote error returned by protocol response helpers when a peer replies with `MsgError`. | +| `RawMessage` | `type RawMessage []byte` | Raw JSON payload bytes preserved without eager decoding. | +| `ResponseHandler` | `struct{}` | Helper for validating message envelopes and decoding typed responses. 
| +| `Transport` | `struct{ /* unexported fields */ }` | WebSocket transport that manages listeners, connections, encryption, deduplication, and shutdown coordination. | +| `TransportConfig` | `struct{ ListenAddress string; ListenAddr string; WebSocketPath string; TLSCertPath string; TLSKeyPath string; MaxConnections int; MaxMessageSize int64; PingInterval time.Duration; PongTimeout time.Duration }` | Listener, TLS, sizing, and keepalive settings for `Transport`. | +| `Worker` | `struct{ DataDir string /* plus unexported fields */ }` | Inbound command handler for worker nodes. It tracks uptime, optional miner/profile integrations, and the base directory used for deployments. | + +### Payload and integration types + +| Type | Definition | Description | +| --- | --- | --- | +| `DeployAckPayload` | `struct{ Success bool; Name string; Error string }` | Deployment acknowledgement with success state, optional deployed name, and optional error text. | +| `DeployPayload` | `struct{ BundleType string; Data []byte; Checksum string; Name string }` | Deployment request carrying STIM-encrypted bundle bytes (or other bundle data), checksum, and logical name. | +| `DisconnectPayload` | `struct{ Reason string; Code int }` | Disconnect notice with human-readable reason and optional disconnect code. | +| `ErrorPayload` | `struct{ Code int; Message string; Details string }` | Payload used by `MsgError` responses. | +| `LogsRequestPayload` | `struct{ MinerName string; Lines int; Since int64 }` | Request for miner console output, optionally bounded by line count and a Unix timestamp. | +| `HandshakeAckPayload` | `struct{ Identity NodeIdentity; ChallengeResponse []byte; Accepted bool; Reason string }` | Handshake reply containing the responder identity, optional challenge response, acceptance flag, and optional rejection reason. 
| +| `HandshakePayload` | `struct{ Identity NodeIdentity; Challenge []byte; Version string }` | Handshake request containing node identity, optional authentication challenge, and protocol version. | +| `LogsPayload` | `struct{ MinerName string; Lines []string; HasMore bool }` | Returned miner log lines plus an indicator that more lines are available. | +| `MinerAckPayload` | `struct{ Success bool; MinerName string; Error string }` | Acknowledgement for remote miner start and stop operations. | +| `MinerInstance` | `interface{ GetName() string; GetType() string; GetStats() (any, error); GetConsoleHistory(lines int) []string }` | Minimal runtime miner contract used by the worker to collect stats and logs without importing the mining package. | +| `MinerManager` | `interface{ StartMiner(minerType string, config any) (MinerInstance, error); StopMiner(name string) error; ListMiners() []MinerInstance; GetMiner(name string) (MinerInstance, error) }` | Worker-facing miner control contract. | +| `MinerStatsItem` | `struct{ Name string; Type string; Hashrate float64; Shares int; Rejected int; Uptime int; Pool string; Algorithm string; CPUThreads int }` | Protocol-facing summary of one miner's runtime statistics. | +| `PingPayload` | `struct{ SentAt int64 }` | Ping payload carrying the sender's millisecond timestamp. | +| `PongPayload` | `struct{ SentAt int64; ReceivedAt int64 }` | Ping response carrying the echoed send time and the receiver's millisecond timestamp. | +| `ProfileManager` | `interface{ GetProfile(id string) (any, error); SaveProfile(profile any) error }` | Worker-facing profile storage contract. | +| `StartMinerPayload` | `struct{ MinerType string; ProfileID string; Config RawMessage }` | Request to start a miner with an optional profile ID and raw JSON config override. 
| +| `StatsPayload` | `struct{ NodeID string; NodeName string; Miners []MinerStatsItem; Uptime int64 }` | Node-wide stats response with node identity fields, miner summaries, and uptime in seconds. | +| `StopMinerPayload` | `struct{ MinerName string }` | Request to stop a miner by name. | + +## Functions + +### Bundle, protocol, and utility functions + +| Name | Signature | Description | +| --- | --- | --- | +| `CreateProfileBundle` | `func CreateProfileBundle(profileJSON []byte, name string, password string) (*Bundle, error)` | Builds a TIM containing `profileJSON`, encrypts it to STIM with `password`, and returns a `BundleProfile` bundle with a SHA-256 checksum. | +| `CreateProfileBundleUnencrypted` | `func CreateProfileBundleUnencrypted(profileJSON []byte, name string) (*Bundle, error)` | Returns a `BundleProfile` bundle whose `Data` is the raw JSON payload and whose checksum is computed over that JSON. | +| `CreateMinerBundle` | `func CreateMinerBundle(minerPath string, profileJSON []byte, name string, password string) (*Bundle, error)` | Reads a miner binary, tars it, loads it into a TIM, optionally attaches `profileJSON`, encrypts the result to STIM, and returns a `BundleMiner` bundle. | +| `ExtractProfileBundle` | `func ExtractProfileBundle(bundle *Bundle, password string) ([]byte, error)` | Verifies `bundle.Checksum`, returns raw JSON directly when `bundle.Data` already looks like JSON, otherwise decrypts STIM and returns the embedded config bytes. | +| `ExtractMinerBundle` | `func ExtractMinerBundle(bundle *Bundle, password string, destDir string) (string, []byte, error)` | Verifies checksum, decrypts STIM, extracts the root filesystem tarball into `destDir`, and returns the first executable path plus the embedded config bytes. | +| `VerifyBundle` | `func VerifyBundle(bundle *Bundle) bool` | Returns whether `bundle.Checksum` matches the SHA-256 checksum of `bundle.Data`. 
| +| `StreamBundle` | `func StreamBundle(bundle *Bundle, w io.Writer) error` | JSON-encodes `bundle` and writes it to `w`. | +| `ReadBundle` | `func ReadBundle(r io.Reader) (*Bundle, error)` | Reads all bytes from `r`, JSON-decodes them into a `Bundle`, and returns the result. | +| `GenerateChallenge` | `func GenerateChallenge() ([]byte, error)` | Returns a new 32-byte random authentication challenge. | +| `SignChallenge` | `func SignChallenge(challenge []byte, sharedSecret []byte) []byte` | Computes the HMAC-SHA256 signature of `challenge` using `sharedSecret`. | +| `VerifyChallenge` | `func VerifyChallenge(challenge, response, sharedSecret []byte) bool` | Recomputes the expected challenge signature and compares it to `response` with `hmac.Equal`. | +| `IsProtocolVersionSupported` | `func IsProtocolVersionSupported(version string) bool` | Returns whether `version` is present in `SupportedProtocolVersions`. | +| `MarshalJSON` | `func MarshalJSON(v any) ([]byte, error)` | Encodes `v` with the core JSON helper, restores the package's historical no-EscapeHTML behaviour, and returns a caller-owned copy of the bytes. | +| `NewMessage` | `func NewMessage(msgType MessageType, from, to string, payload any) (*Message, error)` | Creates a message with a generated UUID, current timestamp, and JSON-encoded payload. A `nil` payload leaves `Payload` empty. | +| `NewErrorMessage` | `func NewErrorMessage(from, to string, code int, message string, replyTo string) (*Message, error)` | Creates a `MsgError` response containing an `ErrorPayload` and sets `ReplyTo` to the supplied request ID. | +| `ValidateResponse` | `func ValidateResponse(resp *Message, expectedType MessageType) error` | Convenience wrapper that delegates to `DefaultResponseHandler.ValidateResponse`. | +| `ParseResponse` | `func ParseResponse(resp *Message, expectedType MessageType, target any) error` | Convenience wrapper that delegates to `DefaultResponseHandler.ParseResponse`. 
| +| `IsProtocolError` | `func IsProtocolError(err error) bool` | Returns whether `err` is a `*ProtocolError`. | +| `GetProtocolErrorCode` | `func GetProtocolErrorCode(err error) int` | Returns `err.(*ProtocolError).Code` when `err` is a `*ProtocolError`, otherwise `0`. | + +### Constructors + +| Name | Signature | Description | +| --- | --- | --- | +| `DefaultTransportConfig` | `func DefaultTransportConfig() TransportConfig` | Returns the transport defaults: `ListenAddress=:9091`, `ListenAddr=:9091`, `WebSocketPath=/ws`, `MaxConnections=100`, `MaxMessageSize=1<<20`, `PingInterval=30s`, and `PongTimeout=10s`. | +| `NewController` | `func NewController(node *NodeManager, peers *PeerRegistry, transport *Transport) *Controller` | Creates a controller, initialises its pending-response map, and installs its response handler on `transport`. | +| `NewDispatcher` | `func NewDispatcher() *Dispatcher` | Creates an empty dispatcher with a debug-level component logger named `dispatcher`. | +| `NewMessageDeduplicator` | `func NewMessageDeduplicator(ttl time.Duration) *MessageDeduplicator` | Creates a deduplicator that retains message IDs for the supplied TTL. | +| `NewNodeManager` | `func NewNodeManager() (*NodeManager, error)` | Resolves XDG key and config paths, then loads an existing identity if present. | +| `NewNodeManagerFromPaths` | `func NewNodeManagerFromPaths(keyPath, configPath string) (*NodeManager, error)` | Creates a node manager from explicit key and config paths. | +| `NewPeerRateLimiter` | `func NewPeerRateLimiter(maxTokens, refillRate int) *PeerRateLimiter` | Creates a token bucket seeded with `maxTokens` and refilled at `refillRate` tokens per second. | +| `NewPeerRegistry` | `func NewPeerRegistry() (*PeerRegistry, error)` | Resolves the XDG peers path, loads any persisted peers, and builds the selection KD-tree. 
| +| `NewPeerRegistryFromPath` | `func NewPeerRegistryFromPath(peersPath string) (*PeerRegistry, error)` | Creates a peer registry bound to `peersPath` with open authentication mode and an empty public-key allowlist. | +| `NewTransport` | `func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportConfig) *Transport` | Creates a transport with lifecycle context, a 5-minute message deduplicator, and a WebSocket upgrader that only accepts local origins. | +| `NewWorker` | `func NewWorker(node *NodeManager, transport *Transport) *Worker` | Creates a worker, records its start time for uptime reporting, and defaults `DataDir` to `xdg.DataHome`. | + +### `RawMessage` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `MarshalJSON` | `func (m RawMessage) MarshalJSON() ([]byte, error)` | Emits raw payload bytes unchanged, or `null` when the receiver is `nil`. | +| `UnmarshalJSON` | `func (m *RawMessage) UnmarshalJSON(data []byte) error` | Copies `data` into the receiver without decoding it. Passing a `nil` receiver returns an error. | + +### `*Message` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `Reply` | `func (m *Message) Reply(msgType MessageType, payload any) (*Message, error)` | Creates a reply message that swaps `From` and `To` and sets `ReplyTo` to `m.ID`. | +| `ParsePayload` | `func (m *Message) ParsePayload(v any) error` | JSON-decodes `Payload` into `v`. A `nil` payload is treated as a no-op. | + +### `*NodeManager` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `HasIdentity` | `func (n *NodeManager) HasIdentity() bool` | Returns whether an identity is currently loaded in memory. | +| `GetIdentity` | `func (n *NodeManager) GetIdentity() *NodeIdentity` | Returns a copy of the loaded public identity, or `nil` when no identity is initialised. 
| +| `GenerateIdentity` | `func (n *NodeManager) GenerateIdentity(name string, role NodeRole) error` | Generates a new X25519 keypair, derives the node ID from the public key hash, stores the public identity, and persists both key and config to disk. | +| `DeriveSharedSecret` | `func (n *NodeManager) DeriveSharedSecret(peerPubKeyBase64 string) ([]byte, error)` | Decodes the peer public key, performs X25519 ECDH with the node private key, hashes the result with SHA-256, and returns the symmetric key material. | +| `Delete` | `func (n *NodeManager) Delete() error` | Removes persisted key/config files when they exist and clears the in-memory identity and key state. | + +### `*Controller` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `GetRemoteStats` | `func (c *Controller) GetRemoteStats(peerID string) (*StatsPayload, error)` | Sends `MsgGetStats` to `peerID`, waits for a response, and decodes the resulting `MsgStats` payload. | +| `StartRemoteMiner` | `func (c *Controller) StartRemoteMiner(peerID, minerType, profileID string, configOverride RawMessage) error` | Validates `minerType`, sends `MsgStartMiner`, waits for `MsgMinerAck`, and returns an error when the remote ack reports failure. | +| `StopRemoteMiner` | `func (c *Controller) StopRemoteMiner(peerID, minerName string) error` | Sends `MsgStopMiner`, waits for `MsgMinerAck`, and returns an error when the remote ack reports failure. | +| `GetRemoteLogs` | `func (c *Controller) GetRemoteLogs(peerID, minerName string, lines int) ([]string, error)` | Requests `MsgLogs` from a remote miner and returns the decoded log lines. | +| `GetAllStats` | `func (c *Controller) GetAllStats() map[string]*StatsPayload` | Requests stats from every currently connected peer and returns the successful responses keyed by peer ID. 
| +| `PingPeer` | `func (c *Controller) PingPeer(peerID string) (float64, error)` | Sends a ping, measures round-trip time in milliseconds, and updates the peer registry metrics for that peer. | +| `ConnectToPeer` | `func (c *Controller) ConnectToPeer(peerID string) error` | Looks up `peerID` in the registry and establishes a transport connection. | +| `DisconnectFromPeer` | `func (c *Controller) DisconnectFromPeer(peerID string) error` | Gracefully closes an active transport connection for `peerID`. | + +### `*Dispatcher` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `RegisterHandler` | `func (d *Dispatcher) RegisterHandler(intentID byte, handler IntentHandler)` | Associates `handler` with `intentID`, replacing any existing handler for that intent. | +| `Handlers` | `func (d *Dispatcher) Handlers() iter.Seq2[byte, IntentHandler]` | Returns an iterator over the currently registered intent handlers. | +| `Dispatch` | `func (d *Dispatcher) Dispatch(pkt *ueps.ParsedPacket) error` | Rejects `nil` packets, drops packets whose `ThreatScore` exceeds `ThreatScoreThreshold`, rejects unknown intents, and otherwise invokes the matching handler. | + +### `*MessageDeduplicator` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `IsDuplicate` | `func (d *MessageDeduplicator) IsDuplicate(msgID string) bool` | Returns whether `msgID` is still present in the deduplicator's TTL window. | +| `Mark` | `func (d *MessageDeduplicator) Mark(msgID string)` | Records `msgID` with the current time. | +| `Cleanup` | `func (d *MessageDeduplicator) Cleanup()` | Removes expired message IDs whose age exceeds the configured TTL. | + +### `*PeerRateLimiter` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `Allow` | `func (r *PeerRateLimiter) Allow() bool` | Refills tokens according to elapsed whole seconds and returns whether one token could be consumed for the current message. 
| + +### `*PeerRegistry` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `SetAuthMode` | `func (r *PeerRegistry) SetAuthMode(mode PeerAuthMode)` | Replaces the current peer admission mode. | +| `GetAuthMode` | `func (r *PeerRegistry) GetAuthMode() PeerAuthMode` | Returns the current peer admission mode. | +| `AllowPublicKey` | `func (r *PeerRegistry) AllowPublicKey(publicKey string)` | Adds `publicKey` to the explicit allowlist. | +| `RevokePublicKey` | `func (r *PeerRegistry) RevokePublicKey(publicKey string)` | Removes `publicKey` from the explicit allowlist. | +| `IsPublicKeyAllowed` | `func (r *PeerRegistry) IsPublicKeyAllowed(publicKey string) bool` | Returns whether `publicKey` is currently allowlisted. | +| `IsPeerAllowed` | `func (r *PeerRegistry) IsPeerAllowed(peerID string, publicKey string) bool` | Returns `true` in open mode, or in allowlist mode when the peer is already registered or the supplied public key is allowlisted. | +| `ListAllowedPublicKeys` | `func (r *PeerRegistry) ListAllowedPublicKeys() []string` | Returns a slice snapshot of allowlisted public keys. | +| `AllowedPublicKeys` | `func (r *PeerRegistry) AllowedPublicKeys() iter.Seq[string]` | Returns an iterator over allowlisted public keys. | +| `AddPeer` | `func (r *PeerRegistry) AddPeer(peer *Peer) error` | Validates the peer, sets `AddedAt` when zero, defaults `Score` to `50`, adds it to the registry, rebuilds the KD-tree, and schedules a debounced save. | +| `UpdatePeer` | `func (r *PeerRegistry) UpdatePeer(peer *Peer) error` | Replaces an existing peer entry, rebuilds the KD-tree, and schedules a debounced save. | +| `RemovePeer` | `func (r *PeerRegistry) RemovePeer(id string) error` | Deletes an existing peer, rebuilds the KD-tree, and schedules a debounced save. | +| `GetPeer` | `func (r *PeerRegistry) GetPeer(id string) *Peer` | Returns a copy of the peer identified by `id`, or `nil` when absent. 
| +| `ListPeers` | `func (r *PeerRegistry) ListPeers() []*Peer` | Returns a slice of peer copies. | +| `Peers` | `func (r *PeerRegistry) Peers() iter.Seq[*Peer]` | Returns an iterator over peer copies so callers cannot mutate registry state directly. | +| `UpdateMetrics` | `func (r *PeerRegistry) UpdateMetrics(id string, pingMS, geoKM float64, hops int) error` | Updates latency, distance, hop count, and `LastSeen`, rebuilds the KD-tree, and schedules a debounced save. | +| `UpdateScore` | `func (r *PeerRegistry) UpdateScore(id string, score float64) error` | Clamps `score` into `[0,100]`, updates the peer, rebuilds the KD-tree, and schedules a debounced save. | +| `SetConnected` | `func (r *PeerRegistry) SetConnected(id string, connected bool)` | Updates the connection flag for a peer and refreshes `LastSeen` when marking the peer connected. | +| `RecordSuccess` | `func (r *PeerRegistry) RecordSuccess(id string)` | Increases the peer score by `ScoreSuccessIncrement` up to `ScoreMaximum`, updates `LastSeen`, and schedules a save. | +| `RecordFailure` | `func (r *PeerRegistry) RecordFailure(id string)` | Decreases the peer score by `ScoreFailureDecrement` down to `ScoreMinimum` and schedules a save. | +| `RecordTimeout` | `func (r *PeerRegistry) RecordTimeout(id string)` | Decreases the peer score by `ScoreTimeoutDecrement` down to `ScoreMinimum` and schedules a save. | +| `GetPeersByScore` | `func (r *PeerRegistry) GetPeersByScore() []*Peer` | Returns peers sorted by descending score. | +| `PeersByScore` | `func (r *PeerRegistry) PeersByScore() iter.Seq[*Peer]` | Returns an iterator over peers sorted by descending score. | +| `SelectOptimalPeer` | `func (r *PeerRegistry) SelectOptimalPeer() *Peer` | Uses the KD-tree to find the peer closest to the ideal metrics vector and returns a copy of that peer. 
| +| `SelectNearestPeers` | `func (r *PeerRegistry) SelectNearestPeers(n int) []*Peer` | Returns copies of the `n` nearest peers from the KD-tree according to the weighted metrics. | +| `GetConnectedPeers` | `func (r *PeerRegistry) GetConnectedPeers() []*Peer` | Returns a slice of copies for peers whose `Connected` flag is true. | +| `ConnectedPeers` | `func (r *PeerRegistry) ConnectedPeers() iter.Seq[*Peer]` | Returns an iterator over connected peer copies. | +| `Count` | `func (r *PeerRegistry) Count() int` | Returns the number of registered peers. | +| `Close` | `func (r *PeerRegistry) Close() error` | Stops any pending save timer and immediately flushes dirty peer data to disk when needed. | + +### `*ResponseHandler` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `ValidateResponse` | `func (h *ResponseHandler) ValidateResponse(resp *Message, expectedType MessageType) error` | Rejects `nil` responses, unwraps `MsgError` into a `ProtocolError`, and checks that `resp.Type` matches `expectedType`. | +| `ParseResponse` | `func (h *ResponseHandler) ParseResponse(resp *Message, expectedType MessageType, target any) error` | Runs `ValidateResponse` and then decodes the payload into `target` when `target` is not `nil`. | + +### `*Transport` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `Start` | `func (t *Transport) Start() error` | Starts the WebSocket listener and begins accepting inbound peer connections. | +| `Stop` | `func (t *Transport) Stop() error` | Cancels transport context, closes active connections, and shuts down the listener. | +| `OnMessage` | `func (t *Transport) OnMessage(handler MessageHandler)` | Installs the inbound message callback used after decryption. It must be set before `Start` to avoid races. | +| `Connect` | `func (t *Transport) Connect(peer *Peer) (*PeerConnection, error)` | Dials `peer`, performs the handshake, derives the shared secret, and returns the active peer connection. 
| +| `Send` | `func (t *Transport) Send(peerID string, msg *Message) error` | Looks up the active connection for `peerID` and sends `msg` over it. | +| `Connections` | `func (t *Transport) Connections() iter.Seq[*PeerConnection]` | Returns an iterator over active peer connections. | +| `Broadcast` | `func (t *Transport) Broadcast(msg *Message) error` | Sends `msg` to every connected peer except the sender identified by `msg.From`. | +| `GetConnection` | `func (t *Transport) GetConnection(peerID string) *PeerConnection` | Returns the active connection for `peerID`, or `nil` when not connected. | +| `ConnectedPeerCount` | `func (t *Transport) ConnectedPeerCount() int` | Returns the number of active peer connections. | + +### `*PeerConnection` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `Send` | `func (pc *PeerConnection) Send(msg *Message) error` | Encrypts and writes a message over the WebSocket connection. | +| `Close` | `func (pc *PeerConnection) Close() error` | Closes the underlying connection once and releases transport state for that peer. | +| `GracefulClose` | `func (pc *PeerConnection) GracefulClose(reason string, code int) error` | Sends a `MsgDisconnect` notification before closing the connection. | + +### `*Worker` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `SetMinerManager` | `func (w *Worker) SetMinerManager(manager MinerManager)` | Installs the miner manager used for start, stop, stats, and log requests. | +| `SetProfileManager` | `func (w *Worker) SetProfileManager(manager ProfileManager)` | Installs the profile manager used during deployment handling. | +| `HandleMessage` | `func (w *Worker) HandleMessage(conn *PeerConnection, msg *Message)` | Dispatches supported message types, sends normal replies on success, and emits `MsgError` responses when a handled command fails. 
| +| `RegisterOnTransport` | `func (w *Worker) RegisterOnTransport()` | Registers `HandleMessage` as the transport's inbound message callback. | + +### `*ProtocolError` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `Error` | `func (e *ProtocolError) Error() string` | Formats the remote error as `remote error (code): message`. | diff --git a/specs/ueps.md b/specs/ueps.md new file mode 100644 index 0000000..0be6353 --- /dev/null +++ b/specs/ueps.md @@ -0,0 +1,67 @@ +# ueps + +**Import:** `dappco.re/go/core/p2p/ueps` + +**Files:** 2 + +## Types + +### `UEPSHeader` +```go +type UEPSHeader struct { + Version uint8 + CurrentLayer uint8 + TargetLayer uint8 + IntentID uint8 + ThreatScore uint16 +} +``` + +Routing and integrity metadata carried in UEPS frames. + +- `Version`: protocol version byte. `NewBuilder` initialises this to `0x09`. +- `CurrentLayer`: source layer byte. `NewBuilder` initialises this to `5`. +- `TargetLayer`: destination layer byte. `NewBuilder` initialises this to `5`. +- `IntentID`: semantic intent token. +- `ThreatScore`: unsigned 16-bit risk score. + +### `PacketBuilder` +```go +type PacketBuilder struct { + Header UEPSHeader + Payload []byte +} +``` + +Mutable packet assembly state used to produce a signed UEPS frame. + +- `Header`: TLV metadata written before the payload. +- `Payload`: raw payload bytes appended as the terminal TLV. + +### `ParsedPacket` +```go +type ParsedPacket struct { + Header UEPSHeader + Payload []byte +} +``` + +Verified packet returned by `ReadAndVerify`. + +- `Header`: decoded UEPS header values reconstructed from the stream. +- `Payload`: payload bytes from the `TagPayload` TLV. + +## Functions + +### Top-level + +| Name | Signature | Description | +| --- | --- | --- | +| `NewBuilder` | `func NewBuilder(intentID uint8, payload []byte) *PacketBuilder` | Creates a packet builder with default header values (`Version=0x09`, `CurrentLayer=5`, `TargetLayer=5`, `ThreatScore=0`) and the supplied intent and payload. 
| +| `ReadAndVerify` | `func ReadAndVerify(r *bufio.Reader, sharedSecret []byte) (*ParsedPacket, error)` | Reads TLVs from `r` until `TagPayload`, reconstructs the signed header bytes, and verifies the HMAC-SHA256 over headers plus payload using `sharedSecret`. Missing signatures, truncated data, and HMAC mismatches return errors. | + +### `*PacketBuilder` methods + +| Name | Signature | Description | +| --- | --- | --- | +| `MarshalAndSign` | `func (p *PacketBuilder) MarshalAndSign(sharedSecret []byte) ([]byte, error)` | Serialises header TLVs `0x01` through `0x05`, signs those bytes plus `Payload` with HMAC-SHA256, appends the `TagHMAC` TLV, then writes the terminal `TagPayload` TLV. All TLV lengths are encoded as 2-byte big-endian unsigned integers. | diff --git a/ueps/packet.go b/ueps/packet.go index 0fb590c..ad50f9c 100644 --- a/ueps/packet.go +++ b/ueps/packet.go @@ -7,120 +7,105 @@ import ( "encoding/binary" "io" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" ) -// TLV Types const ( - TagVersion = 0x01 - TagCurrentLay = 0x02 - TagTargetLay = 0x03 - TagIntent = 0x04 - TagThreatScore = 0x05 - TagHMAC = 0x06 // The Signature - TagPayload = 0xFF // The Data + TagVersion = 0x01 + TagCurrentLayer = 0x02 + TagTargetLayer = 0x03 + TagIntent = 0x04 + TagThreatScore = 0x05 + TagHMAC = 0x06 + TagPayload = 0xFF ) -// UEPSHeader represents the conscious routing metadata +// header := UEPSHeader{Version: 0x09, CurrentLayer: 5, TargetLayer: 5, IntentID: 0x01} type UEPSHeader struct { - Version uint8 // Default 0x09 + Version uint8 CurrentLayer uint8 TargetLayer uint8 - IntentID uint8 // Semantic Token - ThreatScore uint16 // 0-65535 + IntentID uint8 + ThreatScore uint16 } -// PacketBuilder helps construct a signed UEPS frame +// builder := NewBuilder(0x20, []byte("hello")) type PacketBuilder struct { Header UEPSHeader Payload []byte } -// NewBuilder creates a packet context for a specific intent +// builder := NewBuilder(0x20, []byte("hello")) func 
NewBuilder(intentID uint8, payload []byte) *PacketBuilder { return &PacketBuilder{ Header: UEPSHeader{ - Version: 0x09, // IPv9 - CurrentLayer: 5, // Application - TargetLayer: 5, // Application + Version: 0x09, + CurrentLayer: 5, + TargetLayer: 5, IntentID: intentID, - ThreatScore: 0, // Assumed innocent until proven guilty + ThreatScore: 0, }, Payload: payload, } } -// MarshalAndSign generates the final byte stream using the shared secret +// frame, err := builder.MarshalAndSign(sharedSecret) func (p *PacketBuilder) MarshalAndSign(sharedSecret []byte) ([]byte, error) { - buf := new(bytes.Buffer) + buffer := new(bytes.Buffer) - // 1. Write Standard Header Tags (0x01 - 0x05) - // We write these first because they are part of what we sign. - if err := writeTLV(buf, TagVersion, []byte{p.Header.Version}); err != nil { + if err := writeTLV(buffer, TagVersion, []byte{p.Header.Version}); err != nil { return nil, err } - if err := writeTLV(buf, TagCurrentLay, []byte{p.Header.CurrentLayer}); err != nil { + if err := writeTLV(buffer, TagCurrentLayer, []byte{p.Header.CurrentLayer}); err != nil { return nil, err } - if err := writeTLV(buf, TagTargetLay, []byte{p.Header.TargetLayer}); err != nil { + if err := writeTLV(buffer, TagTargetLayer, []byte{p.Header.TargetLayer}); err != nil { return nil, err } - if err := writeTLV(buf, TagIntent, []byte{p.Header.IntentID}); err != nil { - return nil, err - } - - // Threat Score is uint16, needs binary packing - tsBuf := make([]byte, 2) - binary.BigEndian.PutUint16(tsBuf, p.Header.ThreatScore) - if err := writeTLV(buf, TagThreatScore, tsBuf); err != nil { + if err := writeTLV(buffer, TagIntent, []byte{p.Header.IntentID}); err != nil { + return nil, err + } + + threatScoreBytes := make([]byte, 2) + binary.BigEndian.PutUint16(threatScoreBytes, p.Header.ThreatScore) + if err := writeTLV(buffer, TagThreatScore, threatScoreBytes); err != nil { return nil, err } - // 2. 
Calculate HMAC - // The signature covers: Existing Header TLVs + The Payload - // It does NOT cover the HMAC TLV tag itself (obviously) mac := hmac.New(sha256.New, sharedSecret) - mac.Write(buf.Bytes()) // The headers so far - mac.Write(p.Payload) // The data + mac.Write(buffer.Bytes()) + mac.Write(p.Payload) signature := mac.Sum(nil) - // 3. Write HMAC TLV (0x06) - // Length is 32 bytes for SHA256 - if err := writeTLV(buf, TagHMAC, signature); err != nil { + if err := writeTLV(buffer, TagHMAC, signature); err != nil { return nil, err } - // 4. Write Payload TLV (0xFF) - // Fixed: Now uses writeTLV which provides a 2-byte length prefix. - // This prevents the io.ReadAll DoS and allows multiple packets in a stream. - if err := writeTLV(buf, TagPayload, p.Payload); err != nil { + if err := writeTLV(buffer, TagPayload, p.Payload); err != nil { return nil, err } - return buf.Bytes(), nil + return buffer.Bytes(), nil } -// Helper to write a simple TLV. -// Now uses 2-byte big-endian length (uint16) to support up to 64KB payloads. 
-func writeTLV(w io.Writer, tag uint8, value []byte) error { - // Check length constraint (2 byte length = max 65535 bytes) +// writeTLV(&buffer, TagPayload, []byte("hello")) +func writeTLV(writer io.Writer, tag uint8, value []byte) error { if len(value) > 65535 { - return coreerr.E("ueps.writeTLV", "TLV value too large for 2-byte length header", nil) + return core.E("ueps.writeTLV", "TLV value too large for 2-byte length header", nil) } - if _, err := w.Write([]byte{tag}); err != nil { + if _, err := writer.Write([]byte{tag}); err != nil { return err } - + lenBuf := make([]byte, 2) binary.BigEndian.PutUint16(lenBuf, uint16(len(value))) - if _, err := w.Write(lenBuf); err != nil { + if _, err := writer.Write(lenBuf); err != nil { return err } - - if _, err := w.Write(value); err != nil { + + if _, err := writer.Write(value); err != nil { return err } return nil } - diff --git a/ueps/packet_coverage_test.go b/ueps/packet_coverage_test.go index 6e1595c..ffd2572 100644 --- a/ueps/packet_coverage_test.go +++ b/ueps/packet_coverage_test.go @@ -6,10 +6,10 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/binary" - "errors" "io" "testing" + core "dappco.re/go/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -22,7 +22,7 @@ type failWriter struct { func (f *failWriter) Write(p []byte) (int, error) { if f.remaining <= 0 { - return 0, errors.New("write failed") + return 0, core.NewError("write failed") } f.remaining-- return len(p), nil @@ -30,7 +30,7 @@ func (f *failWriter) Write(p []byte) (int, error) { // TestWriteTLV_TagWriteFails verifies writeTLV returns an error // when the very first Write (the tag byte) fails. 
-func TestWriteTLV_TagWriteFails(t *testing.T) { +func TestPacketCoverage_WriteTLV_TagWriteFails_Bad(t *testing.T) { w := &failWriter{remaining: 0} err := writeTLV(w, TagVersion, []byte{0x09}) @@ -40,7 +40,7 @@ func TestWriteTLV_TagWriteFails(t *testing.T) { // TestWriteTLV_LengthWriteFails verifies writeTLV returns an error // when the second Write (the length byte) fails. -func TestWriteTLV_LengthWriteFails(t *testing.T) { +func TestPacketCoverage_WriteTLV_LengthWriteFails_Bad(t *testing.T) { w := &failWriter{remaining: 1} err := writeTLV(w, TagVersion, []byte{0x09}) @@ -50,7 +50,7 @@ func TestWriteTLV_LengthWriteFails(t *testing.T) { // TestWriteTLV_ValueWriteFails verifies writeTLV returns an error // when the third Write (the value bytes) fails. -func TestWriteTLV_ValueWriteFails(t *testing.T) { +func TestPacketCoverage_WriteTLV_ValueWriteFails_Bad(t *testing.T) { w := &failWriter{remaining: 2} err := writeTLV(w, TagVersion, []byte{0x09}) @@ -81,7 +81,7 @@ func (r *errorAfterNReader) Read(p []byte) (int, error) { // TestReadAndVerify_PayloadReadError exercises the error branch at // reader.go:51-53 where io.ReadAll fails after the 0xFF tag byte // has been successfully read. -func TestReadAndVerify_PayloadReadError(t *testing.T) { +func TestPacketCoverage_ReadAndVerify_PayloadReadError_Bad(t *testing.T) { // Build a valid packet so we have genuine TLV headers + HMAC. 
payload := []byte("coverage test") builder := NewBuilder(0x20, payload) @@ -104,7 +104,7 @@ func TestReadAndVerify_PayloadReadError(t *testing.T) { prefix := frame[:payloadTagIdx+1] r := &errorAfterNReader{ data: prefix, - err: errors.New("connection reset"), + err: core.NewError("connection reset"), } _, err = ReadAndVerify(bufio.NewReader(r), testSecret) @@ -115,7 +115,7 @@ func TestReadAndVerify_PayloadReadError(t *testing.T) { // TestReadAndVerify_PayloadReadError_EOF ensures that a truncated payload // (missing bytes after TagPayload) is handled as an I/O error (UnexpectedEOF) // because ReadAndVerify now uses io.ReadFull with the expected length prefix. -func TestReadAndVerify_PayloadReadError_EOF(t *testing.T) { +func TestPacketCoverage_ReadAndVerify_PayloadReadError_EOF_Bad(t *testing.T) { payload := []byte("eof test") builder := NewBuilder(0x20, payload) frame, err := builder.MarshalAndSign(testSecret) @@ -141,7 +141,7 @@ func TestReadAndVerify_PayloadReadError_EOF(t *testing.T) { // TestWriteTLV_AllWritesSucceed confirms the happy path still works // after exercising all error branches — a simple sanity check using // failWriter with enough remaining writes. -func TestWriteTLV_AllWritesSucceed(t *testing.T) { +func TestPacketCoverage_WriteTLV_AllWritesSucceed_Good(t *testing.T) { var buf bytes.Buffer err := writeTLV(&buf, TagVersion, []byte{0x09}) require.NoError(t, err) @@ -149,10 +149,9 @@ func TestWriteTLV_AllWritesSucceed(t *testing.T) { assert.Equal(t, []byte{TagVersion, 0x00, 0x01, 0x09}, buf.Bytes()) } - // TestWriteTLV_FailWriterTable runs the three failure scenarios in // a table-driven fashion for completeness. -func TestWriteTLV_FailWriterTable(t *testing.T) { +func TestPacketCoverage_WriteTLV_FailWriterTable_Bad(t *testing.T) { tests := []struct { name string remaining int @@ -177,14 +176,14 @@ func TestWriteTLV_FailWriterTable(t *testing.T) { // HMAC computation independently of the builder. 
This also serves as // a cross-check that our errorAfterNReader is not accidentally // corrupting the prefix bytes. -func TestReadAndVerify_ManualPacket_PayloadReadError(t *testing.T) { +func TestPacketCoverage_ReadAndVerify_ManualPacket_PayloadReadError_Bad(t *testing.T) { payload := []byte("manual test") // Build header TLVs var hdr bytes.Buffer require.NoError(t, writeTLV(&hdr, TagVersion, []byte{0x09})) - require.NoError(t, writeTLV(&hdr, TagCurrentLay, []byte{5})) - require.NoError(t, writeTLV(&hdr, TagTargetLay, []byte{5})) + require.NoError(t, writeTLV(&hdr, TagCurrentLayer, []byte{5})) + require.NoError(t, writeTLV(&hdr, TagTargetLayer, []byte{5})) require.NoError(t, writeTLV(&hdr, TagIntent, []byte{0x20})) tsBuf := make([]byte, 2) binary.BigEndian.PutUint16(tsBuf, 0) @@ -212,3 +211,32 @@ func TestReadAndVerify_ManualPacket_PayloadReadError(t *testing.T) { require.Error(t, err) assert.Equal(t, io.ErrUnexpectedEOF, err) } + +// TestPacketCoverage_ReadAndVerify_MalformedHeaderTLV_Bad verifies malformed header values +// return an error instead of panicking during TLV reconstruction. 
+func TestPacketCoverage_ReadAndVerify_MalformedHeaderTLV_Bad(t *testing.T) { + tests := []struct { + name string + frame []byte + wantErr string + }{ + { + name: "ZeroLengthVersion", + frame: []byte{TagVersion, 0x00, 0x00}, + wantErr: "malformed version TLV", + }, + { + name: "ShortThreatScore", + frame: []byte{TagThreatScore, 0x00, 0x01, 0xFF}, + wantErr: "malformed threat score TLV", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := ReadAndVerify(bufio.NewReader(bytes.NewReader(tc.frame)), testSecret) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + }) + } +} diff --git a/ueps/packet_test.go b/ueps/packet_test.go index cff2f39..449819e 100644 --- a/ueps/packet_test.go +++ b/ueps/packet_test.go @@ -7,14 +7,15 @@ import ( "crypto/sha256" "encoding/binary" "io" - "strings" "testing" + + core "dappco.re/go/core" ) // testSecret is a deterministic shared secret for reproducible tests. var testSecret = []byte("test-shared-secret-32-bytes!!!!!") -func TestPacketBuilder_RoundTrip(t *testing.T) { +func TestPacket_Builder_RoundTrip_Ugly(t *testing.T) { tests := []struct { name string intentID uint8 @@ -84,7 +85,7 @@ func TestPacketBuilder_RoundTrip(t *testing.T) { } } -func TestHMACVerification_TamperedPayload(t *testing.T) { +func TestPacket_HMACVerification_TamperedPayload_Bad(t *testing.T) { builder := NewBuilder(0x20, []byte("original payload")) frame, err := builder.MarshalAndSign(testSecret) if err != nil { @@ -100,12 +101,12 @@ func TestHMACVerification_TamperedPayload(t *testing.T) { if err == nil { t.Fatal("Expected HMAC mismatch error for tampered payload") } - if !strings.Contains(err.Error(), "integrity violation") { + if !core.Contains(err.Error(), "integrity violation") { t.Errorf("Expected integrity violation error, got: %v", err) } } -func TestHMACVerification_TamperedHeader(t *testing.T) { +func TestPacket_HMACVerification_TamperedHeader_Bad(t *testing.T) { builder := NewBuilder(0x20, 
[]byte("test payload")) frame, err := builder.MarshalAndSign(testSecret) if err != nil { @@ -122,12 +123,12 @@ func TestHMACVerification_TamperedHeader(t *testing.T) { if err == nil { t.Fatal("Expected HMAC mismatch error for tampered header") } - if !strings.Contains(err.Error(), "integrity violation") { + if !core.Contains(err.Error(), "integrity violation") { t.Errorf("Expected integrity violation error, got: %v", err) } } -func TestHMACVerification_WrongSharedSecret(t *testing.T) { +func TestPacket_HMACVerification_WrongSharedSecret_Bad(t *testing.T) { builder := NewBuilder(0x20, []byte("secret data")) frame, err := builder.MarshalAndSign([]byte("key-A-used-for-signing!!!!!!!!!!")) if err != nil { @@ -138,12 +139,12 @@ func TestHMACVerification_WrongSharedSecret(t *testing.T) { if err == nil { t.Fatal("Expected HMAC mismatch error for wrong shared secret") } - if !strings.Contains(err.Error(), "integrity violation") { + if !core.Contains(err.Error(), "integrity violation") { t.Errorf("Expected integrity violation error, got: %v", err) } } -func TestEmptyPayload(t *testing.T) { +func TestPacket_EmptyPayload_Ugly(t *testing.T) { tests := []struct { name string payload []byte @@ -175,7 +176,7 @@ func TestEmptyPayload(t *testing.T) { } } -func TestMaxThreatScoreBoundary(t *testing.T) { +func TestPacket_MaxThreatScoreBoundary_Ugly(t *testing.T) { builder := NewBuilder(0x20, []byte("threat boundary")) builder.Header.ThreatScore = 65535 // uint16 max @@ -194,14 +195,14 @@ func TestMaxThreatScoreBoundary(t *testing.T) { } } -func TestMissingHMACTag(t *testing.T) { +func TestPacket_MissingHMACTag_Bad(t *testing.T) { // Craft a packet manually: header TLVs + payload tag, but no HMAC (0x06) var buf bytes.Buffer // Write header TLVs writeTLV(&buf, TagVersion, []byte{0x09}) - writeTLV(&buf, TagCurrentLay, []byte{5}) - writeTLV(&buf, TagTargetLay, []byte{5}) + writeTLV(&buf, TagCurrentLayer, []byte{5}) + writeTLV(&buf, TagTargetLayer, []byte{5}) writeTLV(&buf, TagIntent, 
[]byte{0x20}) tsBuf := make([]byte, 2) binary.BigEndian.PutUint16(tsBuf, 0) @@ -214,24 +215,24 @@ func TestMissingHMACTag(t *testing.T) { if err == nil { t.Fatal("Expected 'missing HMAC' error") } - if !strings.Contains(err.Error(), "missing HMAC") { + if !core.Contains(err.Error(), "missing HMAC") { t.Errorf("Expected 'missing HMAC' error, got: %v", err) } } -func TestWriteTLV_ValueTooLarge(t *testing.T) { +func TestPacket_WriteTLV_ValueTooLarge_Bad(t *testing.T) { var buf bytes.Buffer oversized := make([]byte, 65536) // 1 byte over the 65535 limit err := writeTLV(&buf, TagVersion, oversized) if err == nil { t.Fatal("Expected error for TLV value > 65535 bytes") } - if !strings.Contains(err.Error(), "TLV value too large") { + if !core.Contains(err.Error(), "TLV value too large") { t.Errorf("Expected 'TLV value too large' error, got: %v", err) } } -func TestTruncatedPacket(t *testing.T) { +func TestPacket_TruncatedPacket_Bad(t *testing.T) { builder := NewBuilder(0x20, []byte("full payload")) frame, err := builder.MarshalAndSign(testSecret) if err != nil { @@ -256,7 +257,7 @@ func TestTruncatedPacket(t *testing.T) { { name: "CutMidHMAC", cutAt: 20, // Somewhere inside the header TLVs or HMAC - wantErr: "", // Any io error + wantErr: "", // Any io error }, } @@ -267,14 +268,14 @@ func TestTruncatedPacket(t *testing.T) { if err == nil { t.Fatal("Expected error for truncated packet") } - if tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) { + if tc.wantErr != "" && !core.Contains(err.Error(), tc.wantErr) { t.Errorf("Expected error containing %q, got: %v", tc.wantErr, err) } }) } } -func TestUnknownTLVTag(t *testing.T) { +func TestPacket_UnknownTLVTag_Bad(t *testing.T) { // Build a valid packet, then inject an unknown tag before the HMAC. // The unknown tag must be included in signedData for HMAC to pass. 
payload := []byte("tagged payload") @@ -284,8 +285,8 @@ func TestUnknownTLVTag(t *testing.T) { // Standard header TLVs writeTLV(&headerBuf, TagVersion, []byte{0x09}) - writeTLV(&headerBuf, TagCurrentLay, []byte{5}) - writeTLV(&headerBuf, TagTargetLay, []byte{5}) + writeTLV(&headerBuf, TagCurrentLayer, []byte{5}) + writeTLV(&headerBuf, TagTargetLayer, []byte{5}) writeTLV(&headerBuf, TagIntent, []byte{0x20}) tsBuf := make([]byte, 2) binary.BigEndian.PutUint16(tsBuf, 0) @@ -324,7 +325,7 @@ func TestUnknownTLVTag(t *testing.T) { } } -func TestNewBuilder_Defaults(t *testing.T) { +func TestPacket_NewBuilder_Defaults_Good(t *testing.T) { builder := NewBuilder(0x20, []byte("data")) if builder.Header.Version != 0x09 { @@ -344,7 +345,7 @@ func TestNewBuilder_Defaults(t *testing.T) { } } -func TestThreatScoreBoundaries(t *testing.T) { +func TestPacket_ThreatScoreBoundaries_Good(t *testing.T) { tests := []struct { name string score uint16 @@ -378,7 +379,7 @@ func TestThreatScoreBoundaries(t *testing.T) { } } -func TestWriteTLV_BoundaryLengths(t *testing.T) { +func TestPacket_WriteTLV_BoundaryLengths_Ugly(t *testing.T) { tests := []struct { name string length int @@ -407,9 +408,8 @@ func TestWriteTLV_BoundaryLengths(t *testing.T) { } } - // TestReadAndVerify_EmptyReader verifies behaviour on completely empty input. 
-func TestReadAndVerify_EmptyReader(t *testing.T) { +func TestPacket_ReadAndVerify_EmptyReader_Ugly(t *testing.T) { _, err := ReadAndVerify(bufio.NewReader(bytes.NewReader(nil)), testSecret) if err == nil { t.Fatal("Expected error for empty reader") diff --git a/ueps/reader.go b/ueps/reader.go index dcd1fe7..a024c9c 100644 --- a/ueps/reader.go +++ b/ueps/reader.go @@ -8,83 +8,86 @@ import ( "encoding/binary" "io" - coreerr "dappco.re/go/core/log" + core "dappco.re/go/core" ) -// ParsedPacket holds the verified data +// packet := &ParsedPacket{Header: UEPSHeader{IntentID: 0x01}} type ParsedPacket struct { Header UEPSHeader Payload []byte } -// ReadAndVerify reads a UEPS frame from the stream and validates the HMAC. -// It consumes the stream up to the end of the packet. +// packet, err := ReadAndVerify(bufio.NewReader(bytes.NewReader(frame)), sharedSecret) func ReadAndVerify(r *bufio.Reader, sharedSecret []byte) (*ParsedPacket, error) { - // Buffer to reconstruct the data for HMAC verification var signedData bytes.Buffer header := UEPSHeader{} var signature []byte var payload []byte - // Loop through TLVs for { - // 1. Read Tag tag, err := r.ReadByte() if err != nil { return nil, err } - // 2. Read Length (2-byte big-endian uint16) lenBuf := make([]byte, 2) if _, err := io.ReadFull(r, lenBuf); err != nil { return nil, err } length := int(binary.BigEndian.Uint16(lenBuf)) - // 3. Read Value value := make([]byte, length) if _, err := io.ReadFull(r, value); err != nil { return nil, err } - // 4. 
Handle Tag switch tag { case TagVersion: + if len(value) != 1 { + return nil, core.E("ueps.ReadAndVerify", "malformed version TLV", nil) + } header.Version = value[0] signedData.WriteByte(tag) signedData.Write(lenBuf) signedData.Write(value) - case TagCurrentLay: + case TagCurrentLayer: + if len(value) != 1 { + return nil, core.E("ueps.ReadAndVerify", "malformed current layer TLV", nil) + } header.CurrentLayer = value[0] signedData.WriteByte(tag) signedData.Write(lenBuf) signedData.Write(value) - case TagTargetLay: + case TagTargetLayer: + if len(value) != 1 { + return nil, core.E("ueps.ReadAndVerify", "malformed target layer TLV", nil) + } header.TargetLayer = value[0] signedData.WriteByte(tag) signedData.Write(lenBuf) signedData.Write(value) case TagIntent: + if len(value) != 1 { + return nil, core.E("ueps.ReadAndVerify", "malformed intent TLV", nil) + } header.IntentID = value[0] signedData.WriteByte(tag) signedData.Write(lenBuf) signedData.Write(value) case TagThreatScore: + if len(value) != 2 { + return nil, core.E("ueps.ReadAndVerify", "malformed threat score TLV", nil) + } header.ThreatScore = binary.BigEndian.Uint16(value) signedData.WriteByte(tag) signedData.Write(lenBuf) signedData.Write(value) case TagHMAC: signature = value - // HMAC tag itself is not part of the signed data case TagPayload: payload = value - // Exit loop after payload (last tag in UEPS frame) - // Note: The HMAC covers the Payload but NOT the TagPayload/Length bytes - // to match the PacketBuilder.MarshalAndSign logic. 
goto verify default: - // Unknown tag (future proofing), verify it but ignore semantics signedData.WriteByte(tag) signedData.Write(lenBuf) signedData.Write(value) @@ -93,18 +96,16 @@ func ReadAndVerify(r *bufio.Reader, sharedSecret []byte) (*ParsedPacket, error) verify: if len(signature) == 0 { - return nil, coreerr.E("ueps.ReadAndVerify", "UEPS packet missing HMAC signature", nil) + return nil, core.E("ueps.ReadAndVerify", "UEPS packet missing HMAC signature", nil) } - // 5. Verify HMAC - // Reconstruct: Headers (signedData) + Payload mac := hmac.New(sha256.New, sharedSecret) mac.Write(signedData.Bytes()) mac.Write(payload) expectedMAC := mac.Sum(nil) if !hmac.Equal(signature, expectedMAC) { - return nil, coreerr.E("ueps.ReadAndVerify", "integrity violation: HMAC mismatch (ThreatScore +100)", nil) + return nil, core.E("ueps.ReadAndVerify", "integrity violation: HMAC mismatch (ThreatScore +100)", nil) } return &ParsedPacket{ @@ -112,4 +113,3 @@ verify: Payload: payload, }, nil } -