diff --git a/.core/TODO.md b/.core/TODO.md index b1a10b78..de32e97e 100644 --- a/.core/TODO.md +++ b/.core/TODO.md @@ -1 +1 @@ -- @hardening pkg/display/p2p.go:15 — `attachP2PBridge` ignores `router.Subscribe` failure and never unregisters the bridge, so display events can silently stop without cleanup. +- @hardening pkg/display/display.go:195 — The P2P bridge is attached before `WSEventManager` is initialized, so early bridged events can be dropped during startup. diff --git a/pkg/display/display.go b/pkg/display/display.go index 214bf10d..ab6d1a76 100644 --- a/pkg/display/display.go +++ b/pkg/display/display.go @@ -45,18 +45,19 @@ type WindowInfo = window.WindowInfo // Bridges IPC actions to WebSocket events for TypeScript apps. type Service struct { *core.ServiceRuntime[Options] - wailsApp *application.App - app App - configData map[string]map[string]any - configFile *config.Config // config instance for file persistence - mode container.AppMode - events *WSEventManager - schemeHandlers map[string]SchemeHandler - manifestCache map[string]*loadedManifest - manifestMu sync.Mutex - storage *StorageRegistry - background *BackgroundRegistry - sidecar *deno.Manager + wailsApp *application.App + app App + configData map[string]map[string]any + configFile *config.Config // config instance for file persistence + mode container.AppMode + events *WSEventManager + p2pBridgeCancel context.CancelFunc + schemeHandlers map[string]SchemeHandler + manifestCache map[string]*loadedManifest + manifestMu sync.Mutex + storage *StorageRegistry + background *BackgroundRegistry + sidecar *deno.Manager } // New returns a display Service with empty config sections. @@ -203,9 +204,14 @@ func (s *Service) OnStartup(_ context.Context) core.Result { } func (s *Service) OnShutdown(ctx context.Context) core.Result { - if s.events != nil { - s.events.Close() - s.events = nil + events := s.events + s.events = nil + if s.p2pBridgeCancel != nil { + s.p2pBridgeCancel() + s.p2pBridgeCancel = nil + } + if events != nil { + events.Close() } var shutdownErr error if s.storage != nil { diff --git a/pkg/display/p2p.go b/pkg/display/p2p.go index bcd46fd1..67218d1f 100644 --- a/pkg/display/p2p.go +++ b/pkg/display/p2p.go @@ -12,7 +12,12 @@ func (s *Service) attachP2PBridge() { if !ok || router == nil { return } - _ = router.Subscribe(context.Background(), "display", func(envelope p2p.Envelope) { + if s.p2pBridgeCancel != nil { + s.p2pBridgeCancel() + s.p2pBridgeCancel = nil + } + ctx, cancel := context.WithCancel(context.Background()) + if err := router.Subscribe(ctx, "display", func(envelope p2p.Envelope) { if s.events == nil { return } @@ -26,5 +31,12 @@ func (s *Service) attachP2PBridge() { "payload": envelope.Payload, }, }) - }) + }); err != nil { + cancel() + if s.app != nil { + s.app.Logger().Info("p2p bridge subscribe failed", "err", err) + } + return + } + s.p2pBridgeCancel = cancel } diff --git a/pkg/p2p/tcp.go b/pkg/p2p/tcp.go index 6c824440..7e38c6a7 100644 --- a/pkg/p2p/tcp.go +++ b/pkg/p2p/tcp.go @@ -20,7 +20,7 @@ type TCPDriver struct { options TCPOptions mu sync.RWMutex listener net.Listener - subscriptions map[string][]func(Envelope) + subscriptions map[string][]*subscription } func NewTCPDriver(options TCPOptions) *TCPDriver { @@ -30,7 +30,7 @@ func NewTCPDriver(options TCPOptions) *TCPDriver { PeerAddrs: append([]string(nil), options.PeerAddrs...), NodeID: strings.TrimSpace(options.NodeID), }, - subscriptions: make(map[string][]func(Envelope)), + subscriptions: make(map[string][]*subscription), } } @@ -43,7 +43,13 @@ func (d *TCPDriver) ListenAddr() string { return d.options.ListenAddr } -func (d *TCPDriver) Subscribe(_ context.Context, topic string, handler func(Envelope)) error { +func (d *TCPDriver) Subscribe(ctx context.Context, topic string, handler func(Envelope)) error { + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return err + } topic = strings.TrimSpace(topic) if topic == "" { return errors.New("topic is required") @@ -51,10 +57,30 @@ func (d *TCPDriver) Subscribe(_ context.Context, topic string, handler func(Enve if handler == nil { return errors.New("handler is required") } + if err := d.ensureListener(); err != nil { + return err + } + + sub := &subscription{ + handler: func(envelope Envelope) { + select { + case <-ctx.Done(): + return + default: + handler(envelope) + } + }, + } + d.mu.Lock() - d.subscriptions[topic] = append(d.subscriptions[topic], handler) + d.subscriptions[topic] = append(d.subscriptions[topic], sub) d.mu.Unlock() - return d.ensureListener() + + context.AfterFunc(ctx, func() { + d.removeSubscription(topic, sub) + }) + + return nil } func (d *TCPDriver) Publish(ctx context.Context, envelope Envelope) error { @@ -94,10 +120,12 @@ func (d *TCPDriver) Close() error { d.mu.Lock() defer d.mu.Unlock() if d.listener == nil { + d.subscriptions = make(map[string][]*subscription) return nil } err := d.listener.Close() d.listener = nil + d.subscriptions = make(map[string][]*subscription) return err } @@ -141,10 +169,37 @@ func (d *TCPDriver) readConn(conn net.Conn) { func (d *TCPDriver) dispatch(envelope Envelope) { d.mu.RLock() - handlers := append([]func(Envelope){}, d.subscriptions[envelope.Topic]...) + handlers := append([]*subscription{}, d.subscriptions[envelope.Topic]...) handlers = append(handlers, d.subscriptions["*"]...) d.mu.RUnlock() - for _, handler := range handlers { - handler(envelope) + for _, sub := range handlers { + if sub == nil { + continue + } + sub.handler(envelope) } } + +func (d *TCPDriver) removeSubscription(topic string, target *subscription) { + d.mu.Lock() + defer d.mu.Unlock() + + subs := d.subscriptions[topic] + for i, sub := range subs { + if sub == target { + copy(subs[i:], subs[i+1:]) + subs[len(subs)-1] = nil + subs = subs[:len(subs)-1] + break + } + } + if len(subs) == 0 { + delete(d.subscriptions, topic) + return + } + d.subscriptions[topic] = subs +} + +type subscription struct { + handler func(Envelope) +} diff --git a/pkg/p2p/tcp_test.go b/pkg/p2p/tcp_test.go index 71a2532d..9d069aa2 100644 --- a/pkg/p2p/tcp_test.go +++ b/pkg/p2p/tcp_test.go @@ -12,6 +12,22 @@ import ( "github.com/stretchr/testify/require" ) +func TestTCPDriver_Subscribe_CancelRemovesHandler(t *testing.T) { + driver := NewTCPDriver(TCPOptions{}) + ctx, cancel := context.WithCancel(context.Background()) + + calls := 0 + err := driver.Subscribe(ctx, "updates", func(Envelope) { + calls++ + }) + require.NoError(t, err) + + cancel() + err = driver.Publish(context.Background(), Envelope{Topic: "updates"}) + require.NoError(t, err) + assert.Zero(t, calls) +} + func TestTCPDriver_Publish_ContinuesAfterPeerFailure(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err)