From 8e640e2d42c635eba12844ae2b850bb637305c0d Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 30 Mar 2026 22:03:34 +0100 Subject: [PATCH] feat(node): advertise agent identity in transport Co-Authored-By: Virgil --- node/transport.go | 82 ++++++++++++++++++++++++++++++++++++++++-- node/transport_test.go | 66 ++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 3 deletions(-) diff --git a/node/transport.go b/node/transport.go index 2cb5aed..2d94b20 100644 --- a/node/transport.go +++ b/node/transport.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "slices" + "strings" "sync" "sync/atomic" "time" @@ -29,6 +30,9 @@ const debugLogInterval = 100 // DefaultMaxMessageSize is the default maximum message size (1MB) const DefaultMaxMessageSize int64 = 1 << 20 // 1MB +// agentUserAgentPrefix identifies this tool in request headers. +const agentUserAgentPrefix = "agent-go-p2p" + // TransportConfig configures the WebSocket transport. // // cfg := DefaultTransportConfig() @@ -181,6 +185,7 @@ type PeerConnection struct { Conn *websocket.Conn SharedSecret []byte // Derived via X25519 ECDH, used for SMSG LastActivity time.Time + UserAgent string // Request identity advertised by the peer writeMu sync.Mutex // Serialize WebSocket writes transport *Transport closeOnce sync.Once // Ensure Close() is only called once @@ -222,6 +227,57 @@ func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportCon } } +// agentHeaderToken converts free-form identity data into a stable header token. +func agentHeaderToken(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "unknown" + } + + var sb strings.Builder + sb.Grow(len(value)) + for _, r := range value { + switch { + case r >= 'a' && r <= 'z': + sb.WriteRune(r) + case r >= 'A' && r <= 'Z': + sb.WriteRune(r) + case r >= '0' && r <= '9': + sb.WriteRune(r) + case r == '-' || r == '_' || r == '.': + sb.WriteRune(r) + case r == ' ': + sb.WriteByte('_') + default: + sb.WriteByte('_') + } + } + + token := sb.String() + if token == "" { + return "unknown" + } + + return token +} + +// agentUserAgent returns a transparent identity string for request headers. +func (t *Transport) agentUserAgent() string { + identity := t.node.Identity() + if identity == nil { + return core.Sprintf("%s proto=%s", agentUserAgentPrefix, ProtocolVersion) + } + + return core.Sprintf( + "%s id=%s name=%s role=%s proto=%s", + agentUserAgentPrefix, + identity.ID, + agentHeaderToken(identity.Name), + identity.Role, + ProtocolVersion, + ) +} + // Start opens the WebSocket listener and background maintenance loops. func (t *Transport) Start() error { mux := http.NewServeMux() @@ -333,12 +389,15 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) { scheme = "wss" } peerURL := url.URL{Scheme: scheme, Host: peer.Address, Path: t.config.WSPath} + userAgent := t.agentUserAgent() // Dial the peer with timeout to prevent hanging on unresponsive peers dialer := websocket.Dialer{ HandshakeTimeout: 10 * time.Second, } - conn, _, err := dialer.Dial(peerURL.String(), nil) + conn, _, err := dialer.Dial(peerURL.String(), http.Header{ + "User-Agent": []string{userAgent}, + }) if err != nil { return nil, core.E("Transport.Connect", "failed to connect to peer", err) } @@ -347,6 +406,7 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) { Peer: peer, Conn: conn, LastActivity: time.Now(), + UserAgent: userAgent, transport: t, rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill } @@ -364,6 +424,10 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) { t.mu.Unlock() logging.Debug("connected to peer", logging.Fields{"peer_id": pc.Peer.ID, "secret_len": len(pc.SharedSecret)}) + logging.Debug("connected peer metadata", logging.Fields{ + "peer_id": pc.Peer.ID, + "user_agent": pc.UserAgent, + }) // Update registry t.registry.MarkConnected(pc.Peer.ID, true) @@ -442,6 +506,8 @@ func (t *Transport) GetConnection(peerID string) *PeerConnection { // handleWSUpgrade handles incoming WebSocket connections. func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { + userAgent := r.Header.Get("User-Agent") + // Enforce MaxConns limit (including pending connections during handshake) t.mu.RLock() currentConns := len(t.conns) @@ -505,6 +571,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { "peer_version": payload.Version, "supported_versions": SupportedProtocolVersions, "peer_id": payload.Identity.ID, + "user_agent": userAgent, }) identity := t.node.Identity() if identity != nil { @@ -535,6 +602,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { "peer_id": payload.Identity.ID, "peer_name": payload.Identity.Name, "public_key": safeKeyPrefix(payload.Identity.PublicKey), + "user_agent": userAgent, }) // Send rejection before closing identity := t.node.Identity() @@ -577,6 +645,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { Conn: conn, SharedSecret: sharedSecret, LastActivity: time.Now(), + UserAgent: userAgent, transport: t, rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill } @@ -626,6 +695,12 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) { // Update registry t.registry.MarkConnected(peer.ID, true) + logging.Debug("accepted peer connection", logging.Fields{ + "peer_id": peer.ID, + "peer_name": peer.Name, + "user_agent": userAgent, + }) + // Start read loop t.wg.Add(1) go t.readLoop(pc) @@ -733,8 +808,9 @@ func (t *Transport) performHandshake(pc *PeerConnection) error { } logging.Debug("handshake completed with challenge-response verification", logging.Fields{ - "peer_id": pc.Peer.ID, - "peer_name": pc.Peer.Name, + "peer_id": pc.Peer.ID, + "peer_name": pc.Peer.Name, + "user_agent": pc.UserAgent, }) return nil diff --git a/node/transport_test.go b/node/transport_test.go index 4fe6e1f..56ad9e7 100644 --- a/node/transport_test.go +++ b/node/transport_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "sync" "sync/atomic" "testing" @@ -236,6 +237,71 @@ func TestTransport_FullHandshake_Good(t *testing.T) { } } +func TestTransport_ConnectSendsAgentUserAgent_Good(t *testing.T) { + serverNM := testNode(t, "ua-server", RoleWorker) + clientNM := testNode(t, "ua-client", RoleController) + serverReg := testRegistry(t) + clientReg := testRegistry(t) + + serverCfg := DefaultTransportConfig() + clientCfg := DefaultTransportConfig() + serverTransport := NewTransport(serverNM, serverReg, serverCfg) + clientTransport := NewTransport(clientNM, clientReg, clientCfg) + + var capturedUserAgent atomic.Value + + mux := http.NewServeMux() + mux.HandleFunc(serverCfg.WSPath, func(w http.ResponseWriter, r *http.Request) { + capturedUserAgent.Store(r.Header.Get("User-Agent")) + serverTransport.handleWSUpgrade(w, r) + }) + + ts := httptest.NewServer(mux) + t.Cleanup(func() { + clientTransport.Stop() + serverTransport.Stop() + ts.Close() + }) + + u, _ := url.Parse(ts.URL) + serverAddr := u.Host + + peer := &Peer{ + ID: serverNM.GetIdentity().ID, + Name: "server", + Address: serverAddr, + Role: RoleWorker, + } + clientReg.AddPeer(peer) + + pc, err := clientTransport.Connect(peer) + if err != nil { + t.Fatalf("client connect failed: %v", err) + } + + ua, ok := capturedUserAgent.Load().(string) + if !ok || ua == "" { + t.Fatal("expected user-agent to be captured during websocket upgrade") + } + if !strings.HasPrefix(ua, agentUserAgentPrefix) { + t.Fatalf("user-agent prefix: got %q, want prefix %q", ua, agentUserAgentPrefix) + } + if !strings.Contains(ua, "id="+clientNM.GetIdentity().ID) { + t.Fatalf("user-agent should include client identity, got %q", ua) + } + if pc.UserAgent != ua { + t.Fatalf("client connection user-agent: got %q, want %q", pc.UserAgent, ua) + } + + serverConn := serverTransport.Connection(clientNM.GetIdentity().ID) + if serverConn == nil { + t.Fatal("server should retain the accepted connection") + } + if serverConn.UserAgent != ua { + t.Fatalf("server connection user-agent: got %q, want %q", serverConn.UserAgent, ua) + } +} + func TestTransport_HandshakeRejectWrongVersion_Bad(t *testing.T) { tp := setupTestTransportPair(t)