Harden p2p bridge cleanup
Some checks are pending
Security Scan / security (push) Waiting to run
Test / test (push) Waiting to run

This commit is contained in:
Snider 2026-04-17 20:12:18 +01:00
parent 17fd6e9cdf
commit 9eb87cdc68
5 changed files with 115 additions and 26 deletions

View file

@ -1 +1 @@
- @hardening pkg/display/p2p.go:15 — `attachP2PBridge` ignores `router.Subscribe` failure and never unregisters the bridge, so display events can silently stop without cleanup.
- @hardening pkg/display/display.go:195 — The P2P bridge is attached before `WSEventManager` is initialized, so early bridged events can be dropped during startup.

View file

@ -45,18 +45,19 @@ type WindowInfo = window.WindowInfo
// Bridges IPC actions to WebSocket events for TypeScript apps.
type Service struct {
*core.ServiceRuntime[Options]
wailsApp *application.App
app App
configData map[string]map[string]any
configFile *config.Config // config instance for file persistence
mode container.AppMode
events *WSEventManager
schemeHandlers map[string]SchemeHandler
manifestCache map[string]*loadedManifest
manifestMu sync.Mutex
storage *StorageRegistry
background *BackgroundRegistry
sidecar *deno.Manager
wailsApp *application.App
app App
configData map[string]map[string]any
configFile *config.Config // config instance for file persistence
mode container.AppMode
events *WSEventManager
p2pBridgeCancel context.CancelFunc
schemeHandlers map[string]SchemeHandler
manifestCache map[string]*loadedManifest
manifestMu sync.Mutex
storage *StorageRegistry
background *BackgroundRegistry
sidecar *deno.Manager
}
// New returns a display Service with empty config sections.
@ -203,9 +204,14 @@ func (s *Service) OnStartup(_ context.Context) core.Result {
}
func (s *Service) OnShutdown(ctx context.Context) core.Result {
if s.events != nil {
s.events.Close()
s.events = nil
events := s.events
s.events = nil
if s.p2pBridgeCancel != nil {
s.p2pBridgeCancel()
s.p2pBridgeCancel = nil
}
if events != nil {
events.Close()
}
var shutdownErr error
if s.storage != nil {

View file

@ -12,7 +12,12 @@ func (s *Service) attachP2PBridge() {
if !ok || router == nil {
return
}
_ = router.Subscribe(context.Background(), "display", func(envelope p2p.Envelope) {
if s.p2pBridgeCancel != nil {
s.p2pBridgeCancel()
s.p2pBridgeCancel = nil
}
ctx, cancel := context.WithCancel(context.Background())
if err := router.Subscribe(ctx, "display", func(envelope p2p.Envelope) {
if s.events == nil {
return
}
@ -26,5 +31,12 @@ func (s *Service) attachP2PBridge() {
"payload": envelope.Payload,
},
})
})
}); err != nil {
cancel()
if s.app != nil {
s.app.Logger().Info("p2p bridge subscribe failed", "err", err)
}
return
}
s.p2pBridgeCancel = cancel
}

View file

@ -20,7 +20,7 @@ type TCPDriver struct {
options TCPOptions
mu sync.RWMutex
listener net.Listener
subscriptions map[string][]func(Envelope)
subscriptions map[string][]*subscription
}
func NewTCPDriver(options TCPOptions) *TCPDriver {
@ -30,7 +30,7 @@ func NewTCPDriver(options TCPOptions) *TCPDriver {
PeerAddrs: append([]string(nil), options.PeerAddrs...),
NodeID: strings.TrimSpace(options.NodeID),
},
subscriptions: make(map[string][]func(Envelope)),
subscriptions: make(map[string][]*subscription),
}
}
@ -43,7 +43,13 @@ func (d *TCPDriver) ListenAddr() string {
return d.options.ListenAddr
}
func (d *TCPDriver) Subscribe(_ context.Context, topic string, handler func(Envelope)) error {
func (d *TCPDriver) Subscribe(ctx context.Context, topic string, handler func(Envelope)) error {
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return err
}
topic = strings.TrimSpace(topic)
if topic == "" {
return errors.New("topic is required")
@ -51,10 +57,30 @@ func (d *TCPDriver) Subscribe(_ context.Context, topic string, handler func(Enve
if handler == nil {
return errors.New("handler is required")
}
if err := d.ensureListener(); err != nil {
return err
}
sub := &subscription{
handler: func(envelope Envelope) {
select {
case <-ctx.Done():
return
default:
handler(envelope)
}
},
}
d.mu.Lock()
d.subscriptions[topic] = append(d.subscriptions[topic], handler)
d.subscriptions[topic] = append(d.subscriptions[topic], sub)
d.mu.Unlock()
return d.ensureListener()
context.AfterFunc(ctx, func() {
d.removeSubscription(topic, sub)
})
return nil
}
func (d *TCPDriver) Publish(ctx context.Context, envelope Envelope) error {
@ -94,10 +120,12 @@ func (d *TCPDriver) Close() error {
d.mu.Lock()
defer d.mu.Unlock()
if d.listener == nil {
d.subscriptions = make(map[string][]*subscription)
return nil
}
err := d.listener.Close()
d.listener = nil
d.subscriptions = make(map[string][]*subscription)
return err
}
@ -141,10 +169,37 @@ func (d *TCPDriver) readConn(conn net.Conn) {
func (d *TCPDriver) dispatch(envelope Envelope) {
d.mu.RLock()
handlers := append([]func(Envelope){}, d.subscriptions[envelope.Topic]...)
handlers := append([]*subscription{}, d.subscriptions[envelope.Topic]...)
handlers = append(handlers, d.subscriptions["*"]...)
d.mu.RUnlock()
for _, handler := range handlers {
handler(envelope)
for _, sub := range handlers {
if sub == nil {
continue
}
sub.handler(envelope)
}
}
func (d *TCPDriver) removeSubscription(topic string, target *subscription) {
d.mu.Lock()
defer d.mu.Unlock()
subs := d.subscriptions[topic]
for i, sub := range subs {
if sub == target {
copy(subs[i:], subs[i+1:])
subs[len(subs)-1] = nil
subs = subs[:len(subs)-1]
break
}
}
if len(subs) == 0 {
delete(d.subscriptions, topic)
return
}
d.subscriptions[topic] = subs
}
type subscription struct {
handler func(Envelope)
}

View file

@ -12,6 +12,22 @@ import (
"github.com/stretchr/testify/require"
)
func TestTCPDriver_Subscribe_CancelRemovesHandler(t *testing.T) {
driver := NewTCPDriver(TCPOptions{})
ctx, cancel := context.WithCancel(context.Background())
calls := 0
err := driver.Subscribe(ctx, "updates", func(Envelope) {
calls++
})
require.NoError(t, err)
cancel()
err = driver.Publish(context.Background(), Envelope{Topic: "updates"})
require.NoError(t, err)
assert.Zero(t, calls)
}
func TestTCPDriver_Publish_ContinuesAfterPeerFailure(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)