From dff3d576fae2ae88dc934dc15d2450ad80dd4eb4 Mon Sep 17 00:00:00 2001 From: Virgil Date: Mon, 23 Mar 2026 07:34:16 +0000 Subject: [PATCH] fix(cdp): resolve issue 2 audit findings Co-Authored-By: Virgil --- actions.go | 1 + angular.go | 247 ++++++++-------- audit_issue2_test.go | 673 +++++++++++++++++++++++++++++++++++++++++++ cdp.go | 485 +++++++++++++++++++++++-------- console.go | 109 +++++-- webview.go | 13 +- webview_test.go | 15 +- 7 files changed, 1255 insertions(+), 288 deletions(-) create mode 100644 audit_issue2_test.go diff --git a/actions.go b/actions.go index f1fe510..284297c 100644 --- a/actions.go +++ b/actions.go @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: EUPL-1.2 package webview import ( diff --git a/angular.go b/angular.go index 6028a13..aceb235 100644 --- a/angular.go +++ b/angular.go @@ -1,7 +1,9 @@ +// SPDX-License-Identifier: EUPL-1.2 package webview import ( "context" + "encoding/json" "fmt" "strings" "time" @@ -93,6 +95,21 @@ func (ah *AngularHelper) isAngularApp(ctx context.Context) (bool, error) { func (ah *AngularHelper) waitForZoneStability(ctx context.Context) error { script := ` new Promise((resolve, reject) => { + const pollZone = () => { + if (!window.Zone || !window.Zone.current) { + resolve(true); + return; + } + + const inner = window.Zone.current._inner || window.Zone.current; + if (!inner._hasPendingMicrotasks && !inner._hasPendingMacrotasks) { + resolve(true); + return; + } + + setTimeout(pollZone, 50); + }; + // Get the root elements const roots = window.getAllAngularRootElements ? window.getAllAngularRootElements() : []; if (roots.length === 0) { @@ -121,28 +138,7 @@ func (ah *AngularHelper) waitForZoneStability(ctx context.Context) error { } if (!zone) { - // Fallback: check window.Zone - if (window.Zone && window.Zone.current && window.Zone.current._inner) { - const isStable = !window.Zone.current._inner._hasPendingMicrotasks && - !window.Zone.current._inner._hasPendingMacrotasks; - if (isStable) { - resolve(true); - } else { - // Poll for stability - let attempts = 0; - const poll = setInterval(() => { - attempts++; - const stable = !window.Zone.current._inner._hasPendingMicrotasks && - !window.Zone.current._inner._hasPendingMacrotasks; - if (stable || attempts > 100) { - clearInterval(poll); - resolve(stable); - } - }, 50); - } - } else { - resolve(true); - } + pollZone(); return; } @@ -153,30 +149,28 @@ func (ah *AngularHelper) waitForZoneStability(ctx context.Context) error { } // Wait for stability - const sub = zone.onStable.subscribe(() => { - sub.unsubscribe(); - resolve(true); - }); - - // Timeout fallback - setTimeout(() => { - sub.unsubscribe(); - resolve(zone.isStable); - }, 5000); + try { + const sub = zone.onStable.subscribe(() => { + sub.unsubscribe(); + resolve(true); + }); + } catch (e) { + pollZone(); + } }) ` - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - // First evaluate the promise - _, err := ah.wv.evaluate(ctx, script) + result, err := ah.wv.evaluate(ctx, script) if err != nil { // If the script fails, fall back to simple polling return ah.pollForStability(ctx) } - return nil + if stable, ok := result.(bool); ok && stable { + return nil + } + + return ah.pollForStability(ctx) } // pollForStability polls for Angular stability as a fallback. @@ -333,18 +327,20 @@ func (ah *AngularHelper) GetComponentProperty(selector, propertyName string) (an defer cancel() script := fmt.Sprintf(` - (function() { - const element = document.querySelector(%q); - if (!element) { - throw new Error('Element not found: %s'); - } - const component = window.ng.probe(element).componentInstance; - if (!component) { - throw new Error('No Angular component found on element'); - } - return component[%q]; - })() - `, selector, selector, propertyName) + (function() { + const selector = %s; + const propertyName = %s; + const element = document.querySelector(selector); + if (!element) { + throw new Error('Element not found: ' + selector); + } + const component = window.ng.probe(element).componentInstance; + if (!component) { + throw new Error('No Angular component found on element'); + } + return component[propertyName]; + })() + `, formatJSValue(selector), formatJSValue(propertyName)) return ah.wv.evaluate(ctx, script) } @@ -355,26 +351,28 @@ func (ah *AngularHelper) SetComponentProperty(selector, propertyName string, val defer cancel() script := fmt.Sprintf(` - (function() { - const element = document.querySelector(%q); - if (!element) { - throw new Error('Element not found: %s'); - } - const component = window.ng.probe(element).componentInstance; - if (!component) { - throw new Error('No Angular component found on element'); - } - component[%q] = %v; + (function() { + const selector = %s; + const propertyName = %s; + const element = document.querySelector(selector); + if (!element) { + throw new Error('Element not found: ' + selector); + } + const component = window.ng.probe(element).componentInstance; + if (!component) { + throw new Error('No Angular component found on element'); + } + component[propertyName] = %s; - // Trigger change detection - const injector = window.ng.probe(element).injector; - const appRef = injector.get(window.ng.coreTokens.ApplicationRef || 'ApplicationRef'); - if (appRef) { + // Trigger change detection + const injector = window.ng.probe(element).injector; + const appRef = injector.get(window.ng.coreTokens.ApplicationRef || 'ApplicationRef'); + if (appRef) { appRef.tick(); - } - return true; - })() - `, selector, selector, propertyName, formatJSValue(value)) + } + return true; + })() + `, formatJSValue(selector), formatJSValue(propertyName), formatJSValue(value)) _, err := ah.wv.evaluate(ctx, script) return err @@ -394,29 +392,31 @@ func (ah *AngularHelper) CallComponentMethod(selector, methodName string, args . } script := fmt.Sprintf(` - (function() { - const element = document.querySelector(%q); - if (!element) { - throw new Error('Element not found: %s'); - } - const component = window.ng.probe(element).componentInstance; - if (!component) { - throw new Error('No Angular component found on element'); - } - if (typeof component[%q] !== 'function') { - throw new Error('Method not found: %s'); - } - const result = component[%q](%s); + (function() { + const selector = %s; + const methodName = %s; + const element = document.querySelector(selector); + if (!element) { + throw new Error('Element not found: ' + selector); + } + const component = window.ng.probe(element).componentInstance; + if (!component) { + throw new Error('No Angular component found on element'); + } + if (typeof component[methodName] !== 'function') { + throw new Error('Method not found: ' + methodName); + } + const result = component[methodName](%s); - // Trigger change detection - const injector = window.ng.probe(element).injector; - const appRef = injector.get(window.ng.coreTokens.ApplicationRef || 'ApplicationRef'); - if (appRef) { + // Trigger change detection + const injector = window.ng.probe(element).injector; + const appRef = injector.get(window.ng.coreTokens.ApplicationRef || 'ApplicationRef'); + if (appRef) { appRef.tick(); - } - return result; - })() - `, selector, selector, methodName, methodName, methodName, argsStr.String()) + } + return result; + })() + `, formatJSValue(selector), formatJSValue(methodName), argsStr.String()) return ah.wv.evaluate(ctx, script) } @@ -524,16 +524,18 @@ func (ah *AngularHelper) DispatchEvent(selector, eventName string, detail any) e } script := fmt.Sprintf(` - (function() { - const element = document.querySelector(%q); - if (!element) { - throw new Error('Element not found: %s'); - } - const event = new CustomEvent(%q, { bubbles: true, detail: %s }); - element.dispatchEvent(event); - return true; - })() - `, selector, selector, eventName, detailStr) + (function() { + const selector = %s; + const eventName = %s; + const element = document.querySelector(selector); + if (!element) { + throw new Error('Element not found: ' + selector); + } + const event = new CustomEvent(eventName, { bubbles: true, detail: %s }); + element.dispatchEvent(event); + return true; + })() + `, formatJSValue(selector), formatJSValue(eventName), detailStr) _, err := ah.wv.evaluate(ctx, script) return err @@ -572,17 +574,18 @@ func (ah *AngularHelper) SetNgModel(selector string, value any) error { defer cancel() script := fmt.Sprintf(` - (function() { - const element = document.querySelector(%q); - if (!element) { - throw new Error('Element not found: %s'); - } + (function() { + const selector = %s; + const element = document.querySelector(selector); + if (!element) { + throw new Error('Element not found: ' + selector); + } - element.value = %v; - element.dispatchEvent(new Event('input', { bubbles: true })); - element.dispatchEvent(new Event('change', { bubbles: true })); + element.value = %s; + element.dispatchEvent(new Event('input', { bubbles: true })); + element.dispatchEvent(new Event('change', { bubbles: true })); - // Trigger change detection + // Trigger change detection const roots = window.getAllAngularRootElements ? window.getAllAngularRootElements() : []; for (const root of roots) { try { @@ -595,9 +598,9 @@ func (ah *AngularHelper) SetNgModel(selector string, value any) error { } catch (e) {} } - return true; - })() - `, selector, selector, formatJSValue(value)) + return true; + })() + `, formatJSValue(selector), formatJSValue(value)) _, err := ah.wv.evaluate(ctx, script) return err @@ -613,17 +616,15 @@ func getString(m map[string]any, key string) string { } func formatJSValue(v any) string { - switch val := v.(type) { - case string: - return fmt.Sprintf("%q", val) - case bool: - if val { - return "true" - } - return "false" - case nil: - return "null" - default: - return fmt.Sprintf("%v", val) + data, err := json.Marshal(v) + if err == nil { + return string(data) } + + fallback, fallbackErr := json.Marshal(fmt.Sprint(v)) + if fallbackErr == nil { + return string(fallback) + } + + return "null" } diff --git a/audit_issue2_test.go b/audit_issue2_test.go new file mode 100644 index 0000000..ab31c33 --- /dev/null +++ b/audit_issue2_test.go @@ -0,0 +1,673 @@ +// SPDX-License-Identifier: EUPL-1.2 +package webview + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +type fakeCDPServer struct { + t *testing.T + server *httptest.Server + mu sync.Mutex + nextTarget int + targets map[string]*fakeCDPTarget +} + +type fakeCDPTarget struct { + server *fakeCDPServer + id string + onConnect func(*fakeCDPTarget) + onMessage func(*fakeCDPTarget, cdpMessage) + connMu sync.Mutex + conn *websocket.Conn + received chan cdpMessage + connected chan struct{} + closed chan struct{} + connectedOnce sync.Once + closedOnce sync.Once +} + +func newFakeCDPServer(t *testing.T) *fakeCDPServer { + t.Helper() + + server := &fakeCDPServer{ + t: t, + targets: make(map[string]*fakeCDPTarget), + } + server.server = httptest.NewServer(http.HandlerFunc(server.handle)) + server.addTarget("target-1") + t.Cleanup(server.Close) + + return server +} + +func (s *fakeCDPServer) Close() { + s.server.Close() +} + +func (s *fakeCDPServer) DebugURL() string { + return s.server.URL +} + +func (s *fakeCDPServer) addTarget(id string) *fakeCDPTarget { + s.mu.Lock() + defer s.mu.Unlock() + + target := &fakeCDPTarget{ + server: s, + id: id, + received: make(chan cdpMessage, 16), + connected: make(chan struct{}), + closed: make(chan struct{}), + } + s.targets[id] = target + return target +} + +func (s *fakeCDPServer) newTarget() *fakeCDPTarget { + s.mu.Lock() + s.nextTarget++ + id := fmt.Sprintf("target-%d", s.nextTarget+1) + s.mu.Unlock() + + return s.addTarget(id) +} + +func (s *fakeCDPServer) primaryTarget() *fakeCDPTarget { + s.mu.Lock() + defer s.mu.Unlock() + return s.targets["target-1"] +} + +func (s *fakeCDPServer) handle(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/json": + s.handleListTargets(w) + case r.URL.Path == "/json/new": + s.handleNewTarget(w) + case r.URL.Path == "/json/version": + s.writeJSON(w, map[string]string{ + "Browser": "Chrome/123.0", + }) + case strings.HasPrefix(r.URL.Path, "/devtools/page/"): + s.handleWebSocket(w, r, strings.TrimPrefix(r.URL.Path, "/devtools/page/")) + default: + http.NotFound(w, r) + } +} + +func (s *fakeCDPServer) handleListTargets(w http.ResponseWriter) { + s.mu.Lock() + targets := make([]TargetInfo, 0, len(s.targets)) + for id := range s.targets { + targets = append(targets, TargetInfo{ + ID: id, + Type: "page", + Title: id, + URL: "about:blank", + WebSocketDebuggerURL: s.webSocketURL(id), + }) + } + s.mu.Unlock() + + s.writeJSON(w, targets) +} + +func (s *fakeCDPServer) handleNewTarget(w http.ResponseWriter) { + target := s.newTarget() + s.writeJSON(w, TargetInfo{ + ID: target.id, + Type: "page", + Title: target.id, + URL: "about:blank", + WebSocketDebuggerURL: s.webSocketURL(target.id), + }) +} + +func (s *fakeCDPServer) handleWebSocket(w http.ResponseWriter, r *http.Request, id string) { + s.mu.Lock() + target := s.targets[id] + s.mu.Unlock() + if target == nil { + http.NotFound(w, r) + return + } + + upgrader := websocket.Upgrader{ + CheckOrigin: func(*http.Request) bool { return true }, + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + s.t.Fatalf("failed to upgrade test WebSocket: %v", err) + } + + target.attach(conn) +} + +func (s *fakeCDPServer) writeJSON(w http.ResponseWriter, value any) { + s.t.Helper() + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(value); err != nil { + s.t.Fatalf("failed to encode JSON: %v", err) + } +} + +func (s *fakeCDPServer) webSocketURL(id string) string { + wsURL, err := url.Parse(s.server.URL) + if err != nil { + s.t.Fatalf("failed to parse test server URL: %v", err) + } + if wsURL.Scheme == "http" { + wsURL.Scheme = "ws" + } else { + wsURL.Scheme = "wss" + } + wsURL.Path = "/devtools/page/" + id + wsURL.RawQuery = "" + wsURL.Fragment = "" + + return wsURL.String() +} + +func (tgt *fakeCDPTarget) attach(conn *websocket.Conn) { + tgt.connMu.Lock() + tgt.conn = conn + tgt.connMu.Unlock() + + tgt.connectedOnce.Do(func() { + close(tgt.connected) + }) + + go tgt.readLoop() + + if tgt.onConnect != nil { + go tgt.onConnect(tgt) + } +} + +func (tgt *fakeCDPTarget) readLoop() { + defer tgt.closedOnce.Do(func() { + close(tgt.closed) + }) + + for { + _, data, err := tgt.conn.ReadMessage() + if err != nil { + return + } + + var msg cdpMessage + if err := json.Unmarshal(data, &msg); err != nil { + continue + } + + select { + case tgt.received <- msg: + default: + } + + if tgt.onMessage != nil { + tgt.onMessage(tgt, msg) + } + } +} + +func (tgt *fakeCDPTarget) reply(id int64, result map[string]any) { + tgt.writeJSON(cdpResponse{ + ID: id, + Result: result, + }) +} + +func (tgt *fakeCDPTarget) replyError(id int64, message string) { + tgt.writeJSON(cdpResponse{ + ID: id, + Error: &cdpError{ + Message: message, + }, + }) +} + +func (tgt *fakeCDPTarget) replyValue(id int64, value any) { + tgt.reply(id, map[string]any{ + "result": map[string]any{ + "value": value, + }, + }) +} + +func (tgt *fakeCDPTarget) writeJSON(value any) { + tgt.server.t.Helper() + + tgt.connMu.Lock() + defer tgt.connMu.Unlock() + if tgt.conn == nil { + tgt.server.t.Fatal("test WebSocket connection was not established") + } + if err := tgt.conn.WriteJSON(value); err != nil { + tgt.server.t.Fatalf("failed to write test WebSocket message: %v", err) + } +} + +func (tgt *fakeCDPTarget) closeWebSocket() { + tgt.connMu.Lock() + defer tgt.connMu.Unlock() + if tgt.conn != nil { + _ = tgt.conn.Close() + } +} + +func (tgt *fakeCDPTarget) waitForMessage(tb testing.TB) cdpMessage { + tb.Helper() + + select { + case msg := <-tgt.received: + return msg + case <-time.After(time.Second): + tb.Fatal("timed out waiting for CDP message") + return cdpMessage{} + } +} + +func (tgt *fakeCDPTarget) waitConnected(tb testing.TB) { + tb.Helper() + + select { + case <-tgt.connected: + case <-time.After(time.Second): + tb.Fatal("timed out waiting for WebSocket connection") + } +} + +func (tgt *fakeCDPTarget) waitClosed(tb testing.TB) { + tb.Helper() + + select { + case <-tgt.closed: + case <-time.After(time.Second): + tb.Fatal("timed out waiting for WebSocket closure") + } +} + +func TestCDPClientClose_Good_UnblocksReadLoop(t *testing.T) { + server := newFakeCDPServer(t) + target := server.primaryTarget() + + client, err := NewCDPClient(server.DebugURL()) + if err != nil { + t.Fatalf("NewCDPClient returned error: %v", err) + } + + target.waitConnected(t) + + done := make(chan error, 1) + go func() { + done <- client.Close() + }() + + select { + case err := <-done: + if err != nil { + t.Fatalf("Close returned error: %v", err) + } + case <-time.After(time.Second): + t.Fatal("Close blocked waiting for readLoop") + } +} + +func TestCDPClientReadLoop_Ugly_StopsOnTerminalReadError(t *testing.T) { + server := newFakeCDPServer(t) + target := server.primaryTarget() + target.onConnect = func(target *fakeCDPTarget) { + target.closeWebSocket() + } + + client, err := NewCDPClient(server.DebugURL()) + if err != nil { + t.Fatalf("NewCDPClient returned error: %v", err) + } + + select { + case <-client.done: + case <-time.After(time.Second): + t.Fatal("readLoop did not stop after terminal read error") + } +} + +func TestCDPClientCloseTab_Good_ClosesTargetOnly(t *testing.T) { + server := newFakeCDPServer(t) + target := server.primaryTarget() + target.onMessage = func(target *fakeCDPTarget, msg cdpMessage) { + if msg.Method != "Target.closeTarget" { + t.Fatalf("CloseTab sent %q, want Target.closeTarget", msg.Method) + } + if got := msg.Params["targetId"]; got != target.id { + t.Fatalf("Target.closeTarget targetId = %v, want %q", got, target.id) + } + target.reply(msg.ID, map[string]any{"success": true}) + go func() { + time.Sleep(10 * time.Millisecond) + target.closeWebSocket() + }() + } + + client, err := NewCDPClient(server.DebugURL()) + if err != nil { + t.Fatalf("NewCDPClient returned error: %v", err) + } + + if err := client.CloseTab(); err != nil { + t.Fatalf("CloseTab returned error: %v", err) + } + + msg := target.waitForMessage(t) + if msg.Method == "Browser.close" { + t.Fatal("CloseTab closed the whole browser") + } +} + +func TestCDPClientDispatchEvent_Good_HandlerParamsAreIsolated(t *testing.T) { + client := &CDPClient{ + handlers: make(map[string][]func(map[string]any)), + } + + firstDone := make(chan map[string]any, 1) + secondDone := make(chan map[string]any, 1) + + client.OnEvent("Runtime.testEvent", func(params map[string]any) { + params["value"] = "mutated" + params["nested"].(map[string]any)["count"] = 1 + params["items"].([]any)[0].(map[string]any)["id"] = "changed" + firstDone <- params + }) + client.OnEvent("Runtime.testEvent", func(params map[string]any) { + secondDone <- params + }) + + original := map[string]any{ + "nested": map[string]any{"count": 0}, + "items": []any{map[string]any{"id": "original"}}, + } + + client.dispatchEvent("Runtime.testEvent", original) + + select { + case <-firstDone: + case <-time.After(time.Second): + t.Fatal("first handler did not run") + } + + var secondParams map[string]any + select { + case secondParams = <-secondDone: + case <-time.After(time.Second): + t.Fatal("second handler did not run") + } + + if _, ok := secondParams["value"]; ok { + t.Fatal("second handler observed first handler mutation") + } + if got := secondParams["nested"].(map[string]any)["count"]; got != 0 { + t.Fatalf("second handler nested count = %v, want 0", got) + } + if got := secondParams["items"].([]any)[0].(map[string]any)["id"]; got != "original" { + t.Fatalf("second handler slice payload = %v, want %q", got, "original") + } + if got := original["nested"].(map[string]any)["count"]; got != 0 { + t.Fatalf("original params were mutated: nested count = %v", got) + } +} + +func TestNewCDPClient_Bad_RejectsCrossHostWebSocket(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/json" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode([]TargetInfo{{ + ID: "target-1", + Type: "page", + WebSocketDebuggerURL: "ws://example.com/devtools/page/target-1", + }}); err != nil { + t.Fatalf("failed to encode targets: %v", err) + } + })) + defer server.Close() + + _, err := NewCDPClient(server.URL) + if err == nil { + t.Fatal("NewCDPClient succeeded with a cross-host WebSocket URL") + } + if !strings.Contains(err.Error(), "invalid target WebSocket URL") { + t.Fatalf("NewCDPClient error = %v, want cross-host WebSocket validation failure", err) + } +} + +func TestWebviewNew_Bad_ClosesClientWhenEnableConsoleFails(t *testing.T) { + server := newFakeCDPServer(t) + target := server.primaryTarget() + target.onMessage = func(target *fakeCDPTarget, msg cdpMessage) { + if msg.Method != "Runtime.enable" { + t.Fatalf("enableConsole sent %q before Runtime.enable failed", msg.Method) + } + target.replyError(msg.ID, "runtime disabled") + } + + _, err := New( + WithTimeout(250*time.Millisecond), + WithDebugURL(server.DebugURL()), + ) + if err == nil { + t.Fatal("New succeeded when Runtime.enable failed") + } + + target.waitClosed(t) +} + +func TestAngularHelperWaitForZoneStability_Good_AwaitsPromise(t *testing.T) { + server := newFakeCDPServer(t) + target := server.primaryTarget() + target.onMessage = func(target *fakeCDPTarget, msg cdpMessage) { + if msg.Method != "Runtime.evaluate" { + t.Fatalf("unexpected method %q", msg.Method) + } + target.replyValue(msg.ID, true) + } + + client, err := NewCDPClient(server.DebugURL()) + if err != nil { + t.Fatalf("NewCDPClient returned error: %v", err) + } + defer func() { _ = client.Close() }() + + wv := &Webview{ + client: client, + ctx: context.Background(), + timeout: time.Second, + } + ah := NewAngularHelper(wv) + + if err := ah.waitForZoneStability(context.Background()); err != nil { + t.Fatalf("waitForZoneStability returned error: %v", err) + } + + msg := target.waitForMessage(t) + if got := msg.Params["awaitPromise"]; got != true { + t.Fatalf("Runtime.evaluate awaitPromise = %v, want true", got) + } + if got := msg.Params["returnByValue"]; got != true { + t.Fatalf("Runtime.evaluate returnByValue = %v, want true", got) + } +} + +func TestAngularHelperSetNgModel_Good_EscapesSelectorAndValue(t *testing.T) { + server := newFakeCDPServer(t) + target := server.primaryTarget() + target.onMessage = func(target *fakeCDPTarget, msg cdpMessage) { + if msg.Method != "Runtime.evaluate" { + t.Fatalf("unexpected method %q", msg.Method) + } + target.replyValue(msg.ID, true) + } + + client, err := NewCDPClient(server.DebugURL()) + if err != nil { + t.Fatalf("NewCDPClient returned error: %v", err) + } + defer func() { _ = client.Close() }() + + wv := &Webview{ + client: client, + ctx: context.Background(), + timeout: time.Second, + } + ah := NewAngularHelper(wv) + + selector := `input[name="x'];window.hacked=true;//"]` + value := `";window.hacked=true;//` + if err := ah.SetNgModel(selector, value); err != nil { + t.Fatalf("SetNgModel returned error: %v", err) + } + + expression, _ := target.waitForMessage(t).Params["expression"].(string) + if !strings.Contains(expression, "const selector = "+formatJSValue(selector)+";") { + t.Fatalf("expression did not contain safely quoted selector: %s", expression) + } + if !strings.Contains(expression, "element.value = "+formatJSValue(value)+";") { + t.Fatalf("expression did not contain safely quoted value: %s", expression) + } + if strings.Contains(expression, "throw new Error('Element not found: "+selector+"')") { + t.Fatalf("expression still embedded selector directly in error text: %s", expression) + } +} + +func TestConsoleWatcherWaitForMessage_Good_IsolatesTemporaryHandlers(t *testing.T) { + cw := &ConsoleWatcher{ + messages: make([]ConsoleMessage, 0), + filters: make([]ConsoleFilter, 0), + limit: 1000, + handlers: make([]consoleHandlerRegistration, 0), + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + results := make(chan string, 2) + errorsCh := make(chan error, 2) + + go func() { + msg, err := cw.WaitForMessage(ctx, ConsoleFilter{Type: "error"}) + if err != nil { + errorsCh <- err + return + } + results <- "error:" + msg.Text + }() + go func() { + msg, err := cw.WaitForMessage(ctx, ConsoleFilter{Type: "log"}) + if err != nil { + errorsCh <- err + return + } + results <- "log:" + msg.Text + }() + + time.Sleep(20 * time.Millisecond) + cw.addMessage(ConsoleMessage{Type: "error", Text: "first"}) + time.Sleep(20 * time.Millisecond) + cw.addMessage(ConsoleMessage{Type: "log", Text: "second"}) + + got := make(map[string]bool, 2) + for range 2 { + select { + case err := <-errorsCh: + t.Fatalf("WaitForMessage returned error: %v", err) + case result := <-results: + got[result] = true + case <-time.After(time.Second): + t.Fatal("timed out waiting for console waiter results") + } + } + + if !got["error:first"] || !got["log:second"] { + t.Fatalf("unexpected console waiter results: %#v", got) + } + if len(cw.handlers) != 0 { + t.Fatalf("temporary handlers leaked: %d", len(cw.handlers)) + } +} + +func TestExceptionWatcherWaitForException_Good_PreservesExistingHandlers(t *testing.T) { + ew := &ExceptionWatcher{ + exceptions: make([]ExceptionInfo, 0), + handlers: make([]exceptionHandlerRegistration, 0), + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + waitDone := make(chan error, 1) + go func() { + _, err := ew.WaitForException(ctx) + waitDone <- err + }() + + time.Sleep(20 * time.Millisecond) + + var mu sync.Mutex + count := 0 + ew.AddHandler(func(ExceptionInfo) { + mu.Lock() + defer mu.Unlock() + count++ + }) + + ew.handleException(map[string]any{ + "exceptionDetails": map[string]any{ + "text": "first", + "lineNumber": float64(1), + "columnNumber": float64(1), + "url": "https://example.com/app.js", + }, + }) + + select { + case err := <-waitDone: + if err != nil { + t.Fatalf("WaitForException returned error: %v", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for exception waiter") + } + + ew.handleException(map[string]any{ + "exceptionDetails": map[string]any{ + "text": "second", + "lineNumber": float64(2), + "columnNumber": float64(1), + "url": "https://example.com/app.js", + }, + }) + + mu.Lock() + defer mu.Unlock() + if count != 2 { + t.Fatalf("persistent handler count = %d, want 2", count) + } + if len(ew.handlers) != 1 { + t.Fatalf("unexpected handler count after waiter removal: %d", len(ew.handlers)) + } +} diff --git a/cdp.go b/cdp.go index 365c198..444a07a 100644 --- a/cdp.go +++ b/cdp.go @@ -1,26 +1,46 @@ +// SPDX-License-Identifier: EUPL-1.2 package webview import ( "context" "encoding/json" + "errors" "io" "iter" + "net" "net/http" + "net/url" + "path" "slices" + "strings" "sync" "sync/atomic" + "time" "github.com/gorilla/websocket" coreerr "dappco.re/go/core/log" ) +const debugEndpointTimeout = 10 * time.Second + +var ( + defaultDebugHTTPClient = &http.Client{ + Timeout: debugEndpointTimeout, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } + errCDPClientClosed = errors.New("cdp client closed") +) + // CDPClient handles communication with Chrome DevTools Protocol via WebSocket. type CDPClient struct { - mu sync.RWMutex - conn *websocket.Conn - debugURL string - wsURL string + mu sync.RWMutex + conn *websocket.Conn + debugURL string + debugBase *url.URL + wsURL string // Message tracking msgID atomic.Int64 @@ -32,9 +52,11 @@ type CDPClient struct { handMu sync.RWMutex // Lifecycle - ctx context.Context - cancel context.CancelFunc - done chan struct{} + ctx context.Context + cancel context.CancelFunc + done chan struct{} + closeOnce sync.Once + closeErr error } // cdpMessage represents a CDP protocol message. @@ -76,51 +98,41 @@ type TargetInfo struct { // NewCDPClient creates a new CDP client connected to the given debug URL. // The debug URL should be the Chrome DevTools HTTP endpoint (e.g., http://localhost:9222). func NewCDPClient(debugURL string) (*CDPClient, error) { - // Get available targets - resp, err := http.Get(debugURL + "/json") + debugBase, err := parseDebugURL(debugURL) + if err != nil { + return nil, coreerr.E("CDPClient.New", "invalid debug URL", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), debugEndpointTimeout) + defer cancel() + + targets, err := listTargetsAt(ctx, debugBase) if err != nil { return nil, coreerr.E("CDPClient.New", "failed to get targets", err) } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, coreerr.E("CDPClient.New", "failed to read targets", err) - } - - var targets []TargetInfo - if err := json.Unmarshal(body, &targets); err != nil { - return nil, coreerr.E("CDPClient.New", "failed to parse targets", err) - } // Find a page target var wsURL string for _, t := range targets { if t.Type == "page" && t.WebSocketDebuggerURL != "" { - wsURL = t.WebSocketDebuggerURL + wsURL, err = validateTargetWebSocketURL(debugBase, t.WebSocketDebuggerURL) + if err != nil { + return nil, coreerr.E("CDPClient.New", "invalid target WebSocket URL", err) + } break } } if wsURL == "" { - // Try to create a new target - resp, err := http.Get(debugURL + "/json/new") + newTarget, err := createTargetAt(ctx, debugBase, "") if err != nil { return nil, coreerr.E("CDPClient.New", "no page targets found and failed to create new", err) } - defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) + wsURL, err = validateTargetWebSocketURL(debugBase, newTarget.WebSocketDebuggerURL) if err != nil { - return nil, coreerr.E("CDPClient.New", "failed to read new target", err) + return nil, coreerr.E("CDPClient.New", "invalid new target WebSocket URL", err) } - - var newTarget TargetInfo - if err := json.Unmarshal(body, &newTarget); err != nil { - return nil, coreerr.E("CDPClient.New", "failed to parse new target", err) - } - - wsURL = newTarget.WebSocketDebuggerURL } if wsURL == "" { @@ -133,30 +145,17 @@ func NewCDPClient(debugURL string) (*CDPClient, error) { return nil, coreerr.E("CDPClient.New", "failed to connect to WebSocket", err) } - ctx, cancel := context.WithCancel(context.Background()) - - client := &CDPClient{ - conn: conn, - debugURL: debugURL, - wsURL: wsURL, - pending: make(map[int64]chan *cdpResponse), - handlers: make(map[string][]func(map[string]any)), - ctx: ctx, - cancel: cancel, - done: make(chan struct{}), - } - - // Start message reader - go client.readLoop() - - return client, nil + return newCDPClient(debugBase, wsURL, conn), nil } // Close closes the CDP connection. func (c *CDPClient) Close() error { - c.cancel() - <-c.done // Wait for read loop to finish - return c.conn.Close() + c.close(errCDPClientClosed) + <-c.done + if c.closeErr != nil { + return coreerr.E("CDPClient.Close", "failed to close WebSocket", c.closeErr) + } + return nil } // Call sends a CDP method call and waits for the response. @@ -166,7 +165,7 @@ func (c *CDPClient) Call(ctx context.Context, method string, params map[string]a msg := cdpMessage{ ID: id, Method: method, - Params: params, + Params: cloneMapAny(params), } // Register response channel @@ -193,6 +192,8 @@ func (c *CDPClient) Call(ctx context.Context, method string, params map[string]a select { case <-ctx.Done(): return nil, ctx.Err() + case <-c.ctx.Done(): + return nil, coreerr.E("CDPClient.Call", "client closed", errCDPClientClosed) case resp := <-respCh: if resp.Error != nil { return nil, coreerr.E("CDPClient.Call", resp.Error.Message, nil) @@ -213,22 +214,23 @@ func (c *CDPClient) readLoop() { defer close(c.done) for { - select { - case <-c.ctx.Done(): - return - default: - } - _, data, err := c.conn.ReadMessage() if err != nil { - // Check if context was cancelled - select { - case <-c.ctx.Done(): + if c.ctx.Err() != nil { return - default: - // Log error but continue (could be temporary) + } + if isTerminalReadError(err) { + c.close(err) + return + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { continue } + + c.close(err) + return } // Try to parse as response @@ -237,7 +239,10 @@ func (c *CDPClient) readLoop() { c.pendMu.Lock() if ch, ok := c.pending[resp.ID]; ok { respCopy := resp - ch <- &respCopy + select { + case ch <- &respCopy: + default: + } } c.pendMu.Unlock() continue @@ -259,7 +264,8 @@ func (c *CDPClient) dispatchEvent(method string, params map[string]any) { for _, handler := range handlers { // Call handler in goroutine to avoid blocking - go handler(params) + handlerParams := cloneMapAny(params) + go handler(handlerParams) } } @@ -267,7 +273,7 @@ func (c *CDPClient) dispatchEvent(method string, params map[string]any) { func (c *CDPClient) Send(method string, params map[string]any) error { msg := cdpMessage{ Method: method, - Params: params, + Params: cloneMapAny(params), } c.mu.Lock() @@ -287,83 +293,70 @@ func (c *CDPClient) WebSocketURL() string { // NewTab creates a new browser tab and returns a new CDPClient connected to it. func (c *CDPClient) NewTab(url string) (*CDPClient, error) { - endpoint := c.debugURL + "/json/new" - if url != "" { - endpoint += "?" + url - } + ctx, cancel := context.WithTimeout(c.ctx, debugEndpointTimeout) + defer cancel() - resp, err := http.Get(endpoint) + target, err := createTargetAt(ctx, c.debugBase, url) if err != nil { return nil, coreerr.E("CDPClient.NewTab", "failed to create new tab", err) } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, coreerr.E("CDPClient.NewTab", "failed to read response", err) - } - - var target TargetInfo - if err := json.Unmarshal(body, &target); err != nil { - return nil, coreerr.E("CDPClient.NewTab", "failed to parse target", err) - } if target.WebSocketDebuggerURL == "" { return nil, coreerr.E("CDPClient.NewTab", "no WebSocket URL for new tab", nil) } + wsURL, err := validateTargetWebSocketURL(c.debugBase, target.WebSocketDebuggerURL) + if err != nil { + return nil, coreerr.E("CDPClient.NewTab", "invalid WebSocket URL for new tab", err) + } + // Connect to new tab - conn, _, err := websocket.DefaultDialer.Dial(target.WebSocketDebuggerURL, nil) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { return nil, coreerr.E("CDPClient.NewTab", "failed to connect to new tab", err) } - ctx, cancel := context.WithCancel(context.Background()) - - client := &CDPClient{ - conn: conn, - debugURL: c.debugURL, - wsURL: target.WebSocketDebuggerURL, - pending: make(map[int64]chan *cdpResponse), - handlers: make(map[string][]func(map[string]any)), - ctx: ctx, - cancel: cancel, - done: make(chan struct{}), - } - - go client.readLoop() - - return client, nil + return newCDPClient(c.debugBase, wsURL, conn), nil } // CloseTab closes the current tab (target). func (c *CDPClient) CloseTab() error { - // Extract target ID from WebSocket URL - // Format: ws://host:port/devtools/page/TARGET_ID - // We'll use the Browser.close target API + targetID, err := targetIDFromWebSocketURL(c.wsURL) + if err != nil { + return coreerr.E("CDPClient.CloseTab", "failed to determine target ID", err) + } - ctx := context.Background() - _, err := c.Call(ctx, "Browser.close", nil) - return err + ctx, cancel := context.WithTimeout(c.ctx, debugEndpointTimeout) + defer cancel() + + result, err := c.Call(ctx, "Target.closeTarget", map[string]any{ + "targetId": targetID, + }) + if err != nil { + return coreerr.E("CDPClient.CloseTab", "failed to close target", err) + } + + if success, ok := result["success"].(bool); ok && !success { + return coreerr.E("CDPClient.CloseTab", "target close was not acknowledged", nil) + } + + return c.Close() } // ListTargets returns all available targets. func ListTargets(debugURL string) ([]TargetInfo, error) { - resp, err := http.Get(debugURL + "/json") + debugBase, err := parseDebugURL(debugURL) + if err != nil { + return nil, coreerr.E("ListTargets", "invalid debug URL", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), debugEndpointTimeout) + defer cancel() + + targets, err := listTargetsAt(ctx, debugBase) if err != nil { return nil, coreerr.E("ListTargets", "failed to get targets", err) } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, coreerr.E("ListTargets", "failed to read targets", err) - } - - var targets []TargetInfo - if err := json.Unmarshal(body, &targets); err != nil { - return nil, coreerr.E("ListTargets", "failed to parse targets", err) - } return targets, nil } @@ -385,16 +378,18 @@ func ListTargetsAll(debugURL string) iter.Seq[TargetInfo] { // GetVersion returns Chrome version information. func GetVersion(debugURL string) (map[string]string, error) { - resp, err := http.Get(debugURL + "/json/version") + debugBase, err := parseDebugURL(debugURL) + if err != nil { + return nil, coreerr.E("GetVersion", "invalid debug URL", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), debugEndpointTimeout) + defer cancel() + + body, err := doDebugRequest(ctx, debugBase, "/json/version", "") if err != nil { return nil, coreerr.E("GetVersion", "failed to get version", err) } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, coreerr.E("GetVersion", "failed to read version", err) - } var version map[string]string if err := json.Unmarshal(body, &version); err != nil { @@ -403,3 +398,241 @@ func GetVersion(debugURL string) (map[string]string, error) { return version, nil } + +func newCDPClient(debugBase *url.URL, wsURL string, conn *websocket.Conn) *CDPClient { + ctx, cancel := context.WithCancel(context.Background()) + baseCopy := *debugBase + + client := &CDPClient{ + conn: conn, + debugURL: canonicalDebugURL(&baseCopy), + debugBase: &baseCopy, + wsURL: wsURL, + pending: make(map[int64]chan *cdpResponse), + handlers: make(map[string][]func(map[string]any)), + ctx: ctx, + cancel: cancel, + done: make(chan struct{}), + } + + go client.readLoop() + + return client +} + +func parseDebugURL(raw string) (*url.URL, error) { + debugURL, err := url.Parse(raw) + if err != nil { + return nil, err + } + if debugURL.Scheme != "http" && debugURL.Scheme != "https" { + return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL must use http or https", nil) + } + if debugURL.Host == "" { + return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL host is required", nil) + } + if debugURL.User != nil { + return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL must not include credentials", nil) + } + if debugURL.RawQuery != "" || debugURL.Fragment != "" { + return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL must not include query or fragment", nil) + } + if debugURL.Path == "" { + debugURL.Path = "/" + } + if debugURL.Path != "/" { + return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL must point at the DevTools root", nil) + } + return debugURL, nil +} + +func canonicalDebugURL(debugURL *url.URL) string { + return strings.TrimSuffix(debugURL.String(), "/") +} + +func doDebugRequest(ctx context.Context, debugBase *url.URL, endpoint, rawQuery string) ([]byte, error) { + reqURL := *debugBase + reqURL.Path = endpoint + reqURL.RawPath = "" + reqURL.RawQuery = rawQuery + reqURL.Fragment = "" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) + if err != nil { + return nil, err + } + + resp, err := defaultDebugHTTPClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, coreerr.E("CDPClient.doDebugRequest", "debug endpoint returned "+resp.Status, nil) + } + + return body, nil +} + +func listTargetsAt(ctx context.Context, debugBase *url.URL) ([]TargetInfo, error) { + body, err := doDebugRequest(ctx, debugBase, "/json", "") + if err != nil { + return nil, err + } + + var targets []TargetInfo + if err := json.Unmarshal(body, &targets); err != nil { + return nil, err + } + + return targets, nil +} + +func createTargetAt(ctx context.Context, debugBase *url.URL, pageURL string) (*TargetInfo, error) { + rawQuery := "" + if pageURL != "" { + rawQuery = url.QueryEscape(pageURL) + } + + body, err := doDebugRequest(ctx, debugBase, "/json/new", rawQuery) + if err != nil { + return nil, err + } + + var target TargetInfo + if err := json.Unmarshal(body, &target); err != nil { + return nil, err + } + + return &target, nil +} + +func validateTargetWebSocketURL(debugBase *url.URL, raw string) (string, error) { + wsURL, err := url.Parse(raw) + if err != nil { + return "", err + } + if wsURL.Scheme != "ws" && wsURL.Scheme != "wss" { + return "", coreerr.E("CDPClient.validateTargetWebSocketURL", "target WebSocket URL must use ws or wss", nil) + } + if !sameEndpointHost(debugBase, wsURL) { + return "", coreerr.E("CDPClient.validateTargetWebSocketURL", "target WebSocket URL must match debug URL host", nil) + } + return wsURL.String(), nil +} + +func sameEndpointHost(httpURL, wsURL *url.URL) bool { + return strings.EqualFold(httpURL.Hostname(), wsURL.Hostname()) && normalisedPort(httpURL) == normalisedPort(wsURL) +} + +func normalisedPort(u *url.URL) string { + if port := u.Port(); port != "" { + return port + } + + switch u.Scheme { + case "http", "ws": + return "80" + case "https", "wss": + return "443" + default: + return "" + } +} + +func targetIDFromWebSocketURL(raw string) (string, error) { + wsURL, err := url.Parse(raw) + if err != nil { + return "", err + } + + targetID := path.Base(strings.TrimSuffix(wsURL.Path, "/")) + if targetID == "." || targetID == "/" || targetID == "" { + return "", coreerr.E("CDPClient.targetIDFromWebSocketURL", "missing target ID in WebSocket URL", nil) + } + + return targetID, nil +} + +func (c *CDPClient) close(reason error) { + c.closeOnce.Do(func() { + c.cancel() + c.failPending(reason) + + c.mu.Lock() + err := c.conn.Close() + c.mu.Unlock() + if err != nil && !isTerminalReadError(err) { + c.closeErr = err + } + }) +} + +func (c *CDPClient) failPending(err error) { + c.pendMu.Lock() + defer c.pendMu.Unlock() + + for id, ch := range c.pending { + resp := &cdpResponse{ + ID: id, + Error: &cdpError{ + Message: err.Error(), + }, + } + select { + case ch <- resp: + default: + } + } +} + +func isTerminalReadError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, net.ErrClosed) || errors.Is(err, websocket.ErrCloseSent) { + return true + } + var closeErr *websocket.CloseError + return errors.As(err, &closeErr) +} + +func cloneMapAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + + dst := make(map[string]any, len(src)) + for key, value := range src { + dst[key] = cloneAny(value) + } + return dst +} + +func cloneSliceAny(src []any) []any { + if src == nil { + return nil + } + + dst := make([]any, len(src)) + for i, value := range src { + dst[i] = cloneAny(value) + } + return dst +} + +func cloneAny(value any) any { + switch typed := value.(type) { + case map[string]any: + return cloneMapAny(typed) + case []any: + return cloneSliceAny(typed) + default: + return typed + } +} diff --git a/console.go b/console.go index cd5af31..d5d22b4 100644 --- a/console.go +++ b/console.go @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: EUPL-1.2 package webview import ( @@ -7,17 +8,19 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" ) // ConsoleWatcher provides advanced console message watching capabilities. type ConsoleWatcher struct { - mu sync.RWMutex - wv *Webview - messages []ConsoleMessage - filters []ConsoleFilter - limit int - handlers []ConsoleHandler + mu sync.RWMutex + wv *Webview + messages []ConsoleMessage + filters []ConsoleFilter + limit int + handlers []consoleHandlerRegistration + nextHandlerID atomic.Int64 } // ConsoleFilter filters console messages. @@ -29,6 +32,11 @@ type ConsoleFilter struct { // ConsoleHandler is called when a matching console message is received. type ConsoleHandler func(msg ConsoleMessage) +type consoleHandlerRegistration struct { + id int64 + handler ConsoleHandler +} + // NewConsoleWatcher creates a new console watcher for the webview. func NewConsoleWatcher(wv *Webview) *ConsoleWatcher { cw := &ConsoleWatcher{ @@ -36,7 +44,7 @@ func NewConsoleWatcher(wv *Webview) *ConsoleWatcher { messages: make([]ConsoleMessage, 0, 100), filters: make([]ConsoleFilter, 0), limit: 1000, - handlers: make([]ConsoleHandler, 0), + handlers: make([]consoleHandlerRegistration, 0), } // Subscribe to console events from the webview's client @@ -63,9 +71,30 @@ func (cw *ConsoleWatcher) ClearFilters() { // AddHandler adds a handler for console messages. func (cw *ConsoleWatcher) AddHandler(handler ConsoleHandler) { + cw.addHandler(handler) +} + +func (cw *ConsoleWatcher) addHandler(handler ConsoleHandler) int64 { cw.mu.Lock() defer cw.mu.Unlock() - cw.handlers = append(cw.handlers, handler) + id := cw.nextHandlerID.Add(1) + cw.handlers = append(cw.handlers, consoleHandlerRegistration{ + id: id, + handler: handler, + }) + return id +} + +func (cw *ConsoleWatcher) removeHandler(id int64) { + cw.mu.Lock() + defer cw.mu.Unlock() + + for i, registration := range cw.handlers { + if registration.id == id { + cw.handlers = slices.Delete(cw.handlers, i, i+1) + return + } + } } // SetLimit sets the maximum number of messages to retain. @@ -187,13 +216,8 @@ func (cw *ConsoleWatcher) WaitForMessage(ctx context.Context, filter ConsoleFilt } } - cw.AddHandler(handler) - defer func() { - cw.mu.Lock() - // Remove handler (simple implementation - in production you'd want a handle-based removal) - cw.handlers = cw.handlers[:len(cw.handlers)-1] - cw.mu.Unlock() - }() + handlerID := cw.addHandler(handler) + defer cw.removeHandler(handlerID) select { case <-ctx.Done(): @@ -302,8 +326,8 @@ func (cw *ConsoleWatcher) addMessage(msg ConsoleMessage) { cw.mu.Unlock() // Call handlers - for _, handler := range handlers { - handler(msg) + for _, registration := range handlers { + registration.handler(msg) } } @@ -361,10 +385,16 @@ type ExceptionInfo struct { // ExceptionWatcher watches for JavaScript exceptions. type ExceptionWatcher struct { - mu sync.RWMutex - wv *Webview - exceptions []ExceptionInfo - handlers []func(ExceptionInfo) + mu sync.RWMutex + wv *Webview + exceptions []ExceptionInfo + handlers []exceptionHandlerRegistration + nextHandlerID atomic.Int64 +} + +type exceptionHandlerRegistration struct { + id int64 + handler func(ExceptionInfo) } // NewExceptionWatcher creates a new exception watcher. @@ -372,7 +402,7 @@ func NewExceptionWatcher(wv *Webview) *ExceptionWatcher { ew := &ExceptionWatcher{ wv: wv, exceptions: make([]ExceptionInfo, 0), - handlers: make([]func(ExceptionInfo), 0), + handlers: make([]exceptionHandlerRegistration, 0), } // Subscribe to exception events @@ -425,9 +455,30 @@ func (ew *ExceptionWatcher) Count() int { // AddHandler adds a handler for exceptions. func (ew *ExceptionWatcher) AddHandler(handler func(ExceptionInfo)) { + ew.addHandler(handler) +} + +func (ew *ExceptionWatcher) addHandler(handler func(ExceptionInfo)) int64 { ew.mu.Lock() defer ew.mu.Unlock() - ew.handlers = append(ew.handlers, handler) + id := ew.nextHandlerID.Add(1) + ew.handlers = append(ew.handlers, exceptionHandlerRegistration{ + id: id, + handler: handler, + }) + return id +} + +func (ew *ExceptionWatcher) removeHandler(id int64) { + ew.mu.Lock() + defer ew.mu.Unlock() + + for i, registration := range ew.handlers { + if registration.id == id { + ew.handlers = slices.Delete(ew.handlers, i, i+1) + return + } + } } // WaitForException waits for an exception to be thrown. @@ -450,12 +501,8 @@ func (ew *ExceptionWatcher) WaitForException(ctx context.Context) (*ExceptionInf } } - ew.AddHandler(handler) - defer func() { - ew.mu.Lock() - ew.handlers = ew.handlers[:len(ew.handlers)-1] - ew.mu.Unlock() - }() + handlerID := ew.addHandler(handler) + defer ew.removeHandler(handlerID) select { case <-ctx.Done(): @@ -515,8 +562,8 @@ func (ew *ExceptionWatcher) handleException(params map[string]any) { ew.mu.Unlock() // Call handlers - for _, handler := range handlers { - handler(info) + for _, registration := range handlers { + registration.handler(info) } } diff --git a/webview.go b/webview.go index 5305e9c..8ca1a90 100644 --- a/webview.go +++ b/webview.go @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: EUPL-1.2 // Package webview provides browser automation via Chrome DevTools Protocol (CDP). // // The package allows controlling Chrome/Chromium browsers for automated testing, @@ -118,9 +119,16 @@ func New(opts ...Option) (*Webview, error) { consoleLimit: 1000, } + cleanupOnError := func() { + cancel() + if wv.client != nil { + _ = wv.client.Close() + } + } + for _, opt := range opts { if err := opt(wv); err != nil { - cancel() + cleanupOnError() return nil, err } } @@ -132,7 +140,7 @@ func New(opts ...Option) (*Webview, error) { // Enable console capture if err := wv.enableConsole(); err != nil { - cancel() + cleanupOnError() return nil, coreerr.E("Webview.New", "failed to enable console capture", err) } @@ -542,6 +550,7 @@ func (wv *Webview) evaluate(ctx context.Context, script string) (any, error) { result, err := wv.client.Call(ctx, "Runtime.evaluate", map[string]any{ "expression": script, "returnByValue": true, + "awaitPromise": true, }) if err != nil { return nil, coreerr.E("Webview.evaluate", "failed to evaluate script", err) diff --git a/webview_test.go b/webview_test.go index cbecc51..dd32729 100644 --- a/webview_test.go +++ b/webview_test.go @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: EUPL-1.2 package webview import ( @@ -427,6 +428,8 @@ func TestFormatJSValue_Good(t *testing.T) { {nil, "null"}, {42, "42"}, {3.14, "3.14"}, + {map[string]any{"enabled": true}, `{"enabled":true}`}, + {[]any{1, "two"}, `[1,"two"]`}, } for _, tc := range tests { @@ -512,7 +515,7 @@ func TestConsoleWatcherFilter_Good(t *testing.T) { messages: make([]ConsoleMessage, 0), filters: make([]ConsoleFilter, 0), limit: 1000, - handlers: make([]ConsoleHandler, 0), + handlers: make([]consoleHandlerRegistration, 0), } // No filters — everything matches @@ -556,7 +559,7 @@ func TestConsoleWatcherCounts_Good(t *testing.T) { }, filters: make([]ConsoleFilter, 0), limit: 1000, - handlers: make([]ConsoleHandler, 0), + handlers: make([]consoleHandlerRegistration, 0), } if cw.Count() != 5 { @@ -592,7 +595,7 @@ func TestConsoleWatcherCounts_Good(t *testing.T) { func TestExceptionWatcher_Good(t *testing.T) { ew := &ExceptionWatcher{ exceptions: make([]ExceptionInfo, 0), - handlers: make([]func(ExceptionInfo), 0), + handlers: make([]exceptionHandlerRegistration, 0), } if ew.HasExceptions() { @@ -682,7 +685,7 @@ func TestConsoleWatcherAddMessage_Good(t *testing.T) { messages: make([]ConsoleMessage, 0), filters: make([]ConsoleFilter, 0), limit: 5, - handlers: make([]ConsoleHandler, 0), + handlers: make([]consoleHandlerRegistration, 0), } // Add messages past the limit @@ -704,7 +707,7 @@ func TestConsoleWatcherHandler_Good(t *testing.T) { messages: make([]ConsoleMessage, 0), filters: make([]ConsoleFilter, 0), limit: 1000, - handlers: make([]ConsoleHandler, 0), + handlers: make([]consoleHandlerRegistration, 0), } var received ConsoleMessage @@ -729,7 +732,7 @@ func TestConsoleWatcherFilteredMessages_Good(t *testing.T) { }, filters: []ConsoleFilter{{Type: "error"}}, limit: 1000, - handlers: make([]ConsoleHandler, 0), + handlers: make([]consoleHandlerRegistration, 0), } filtered := cw.FilteredMessages()