feat(node): advertise agent identity in transport
Some checks failed
Security Scan / security (push) Successful in 8s
Test / test (push) Failing after 36s

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-03-30 22:03:34 +01:00
parent ca885ff386
commit 8e640e2d42
2 changed files with 145 additions and 3 deletions

View file

@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@ -29,6 +30,9 @@ const debugLogInterval = 100
// DefaultMaxMessageSize is the default maximum message size (1MB)
const DefaultMaxMessageSize int64 = 1 << 20 // 1MB
// agentUserAgentPrefix identifies this tool in request headers.
const agentUserAgentPrefix = "agent-go-p2p"
// TransportConfig configures the WebSocket transport.
//
// cfg := DefaultTransportConfig()
@ -181,6 +185,7 @@ type PeerConnection struct {
Conn *websocket.Conn
SharedSecret []byte // Derived via X25519 ECDH, used for SMSG
LastActivity time.Time
UserAgent string // Request identity advertised by the peer
writeMu sync.Mutex // Serialize WebSocket writes
transport *Transport
closeOnce sync.Once // Ensure Close() is only called once
@ -222,6 +227,57 @@ func NewTransport(node *NodeManager, registry *PeerRegistry, config TransportCon
}
}
// agentHeaderToken converts free-form identity data into a stable header token.
func agentHeaderToken(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return "unknown"
}
var sb strings.Builder
sb.Grow(len(value))
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
sb.WriteRune(r)
case r >= 'A' && r <= 'Z':
sb.WriteRune(r)
case r >= '0' && r <= '9':
sb.WriteRune(r)
case r == '-' || r == '_' || r == '.':
sb.WriteRune(r)
case r == ' ':
sb.WriteByte('_')
default:
sb.WriteByte('_')
}
}
token := sb.String()
if token == "" {
return "unknown"
}
return token
}
// agentUserAgent returns a transparent identity string for request headers.
func (t *Transport) agentUserAgent() string {
identity := t.node.Identity()
if identity == nil {
return core.Sprintf("%s proto=%s", agentUserAgentPrefix, ProtocolVersion)
}
return core.Sprintf(
"%s id=%s name=%s role=%s proto=%s",
agentUserAgentPrefix,
identity.ID,
agentHeaderToken(identity.Name),
identity.Role,
ProtocolVersion,
)
}
// Start opens the WebSocket listener and background maintenance loops.
func (t *Transport) Start() error {
mux := http.NewServeMux()
@ -333,12 +389,15 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
scheme = "wss"
}
peerURL := url.URL{Scheme: scheme, Host: peer.Address, Path: t.config.WSPath}
userAgent := t.agentUserAgent()
// Dial the peer with timeout to prevent hanging on unresponsive peers
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
conn, _, err := dialer.Dial(peerURL.String(), nil)
conn, _, err := dialer.Dial(peerURL.String(), http.Header{
"User-Agent": []string{userAgent},
})
if err != nil {
return nil, core.E("Transport.Connect", "failed to connect to peer", err)
}
@ -347,6 +406,7 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
Peer: peer,
Conn: conn,
LastActivity: time.Now(),
UserAgent: userAgent,
transport: t,
rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
}
@ -364,6 +424,10 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
t.mu.Unlock()
logging.Debug("connected to peer", logging.Fields{"peer_id": pc.Peer.ID, "secret_len": len(pc.SharedSecret)})
logging.Debug("connected peer metadata", logging.Fields{
"peer_id": pc.Peer.ID,
"user_agent": pc.UserAgent,
})
// Update registry
t.registry.MarkConnected(pc.Peer.ID, true)
@ -442,6 +506,8 @@ func (t *Transport) GetConnection(peerID string) *PeerConnection {
// handleWSUpgrade handles incoming WebSocket connections.
func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
userAgent := r.Header.Get("User-Agent")
// Enforce MaxConns limit (including pending connections during handshake)
t.mu.RLock()
currentConns := len(t.conns)
@ -505,6 +571,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
"peer_version": payload.Version,
"supported_versions": SupportedProtocolVersions,
"peer_id": payload.Identity.ID,
"user_agent": userAgent,
})
identity := t.node.Identity()
if identity != nil {
@ -535,6 +602,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
"peer_id": payload.Identity.ID,
"peer_name": payload.Identity.Name,
"public_key": safeKeyPrefix(payload.Identity.PublicKey),
"user_agent": userAgent,
})
// Send rejection before closing
identity := t.node.Identity()
@ -577,6 +645,7 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
Conn: conn,
SharedSecret: sharedSecret,
LastActivity: time.Now(),
UserAgent: userAgent,
transport: t,
rateLimiter: NewPeerRateLimiter(100, 50), // 100 burst, 50/sec refill
}
@ -626,6 +695,12 @@ func (t *Transport) handleWSUpgrade(w http.ResponseWriter, r *http.Request) {
// Update registry
t.registry.MarkConnected(peer.ID, true)
logging.Debug("accepted peer connection", logging.Fields{
"peer_id": peer.ID,
"peer_name": peer.Name,
"user_agent": userAgent,
})
// Start read loop
t.wg.Add(1)
go t.readLoop(pc)
@ -733,8 +808,9 @@ func (t *Transport) performHandshake(pc *PeerConnection) error {
}
logging.Debug("handshake completed with challenge-response verification", logging.Fields{
"peer_id": pc.Peer.ID,
"peer_name": pc.Peer.Name,
"peer_id": pc.Peer.ID,
"peer_name": pc.Peer.Name,
"user_agent": pc.UserAgent,
})
return nil

View file

@ -4,6 +4,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
@ -236,6 +237,71 @@ func TestTransport_FullHandshake_Good(t *testing.T) {
}
}
func TestTransport_ConnectSendsAgentUserAgent_Good(t *testing.T) {
serverNM := testNode(t, "ua-server", RoleWorker)
clientNM := testNode(t, "ua-client", RoleController)
serverReg := testRegistry(t)
clientReg := testRegistry(t)
serverCfg := DefaultTransportConfig()
clientCfg := DefaultTransportConfig()
serverTransport := NewTransport(serverNM, serverReg, serverCfg)
clientTransport := NewTransport(clientNM, clientReg, clientCfg)
var capturedUserAgent atomic.Value
mux := http.NewServeMux()
mux.HandleFunc(serverCfg.WSPath, func(w http.ResponseWriter, r *http.Request) {
capturedUserAgent.Store(r.Header.Get("User-Agent"))
serverTransport.handleWSUpgrade(w, r)
})
ts := httptest.NewServer(mux)
t.Cleanup(func() {
clientTransport.Stop()
serverTransport.Stop()
ts.Close()
})
u, _ := url.Parse(ts.URL)
serverAddr := u.Host
peer := &Peer{
ID: serverNM.GetIdentity().ID,
Name: "server",
Address: serverAddr,
Role: RoleWorker,
}
clientReg.AddPeer(peer)
pc, err := clientTransport.Connect(peer)
if err != nil {
t.Fatalf("client connect failed: %v", err)
}
ua, ok := capturedUserAgent.Load().(string)
if !ok || ua == "" {
t.Fatal("expected user-agent to be captured during websocket upgrade")
}
if !strings.HasPrefix(ua, agentUserAgentPrefix) {
t.Fatalf("user-agent prefix: got %q, want prefix %q", ua, agentUserAgentPrefix)
}
if !strings.Contains(ua, "id="+clientNM.GetIdentity().ID) {
t.Fatalf("user-agent should include client identity, got %q", ua)
}
if pc.UserAgent != ua {
t.Fatalf("client connection user-agent: got %q, want %q", pc.UserAgent, ua)
}
serverConn := serverTransport.Connection(clientNM.GetIdentity().ID)
if serverConn == nil {
t.Fatal("server should retain the accepted connection")
}
if serverConn.UserAgent != ua {
t.Fatalf("server connection user-agent: got %q, want %q", serverConn.UserAgent, ua)
}
}
func TestTransport_HandshakeRejectWrongVersion_Bad(t *testing.T) {
tp := setupTestTransportPair(t)