fix(cdp): resolve issue 2 audit findings

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-03-23 07:34:16 +00:00
parent c6d1ccba7d
commit dff3d576fa
7 changed files with 1255 additions and 288 deletions

View file

@ -1,3 +1,4 @@
// SPDX-License-Identifier: EUPL-1.2
package webview
import (

View file

@ -1,7 +1,9 @@
// SPDX-License-Identifier: EUPL-1.2
package webview
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
@ -93,6 +95,21 @@ func (ah *AngularHelper) isAngularApp(ctx context.Context) (bool, error) {
func (ah *AngularHelper) waitForZoneStability(ctx context.Context) error {
script := `
new Promise((resolve, reject) => {
const pollZone = () => {
if (!window.Zone || !window.Zone.current) {
resolve(true);
return;
}
const inner = window.Zone.current._inner || window.Zone.current;
if (!inner._hasPendingMicrotasks && !inner._hasPendingMacrotasks) {
resolve(true);
return;
}
setTimeout(pollZone, 50);
};
// Get the root elements
const roots = window.getAllAngularRootElements ? window.getAllAngularRootElements() : [];
if (roots.length === 0) {
@ -121,28 +138,7 @@ func (ah *AngularHelper) waitForZoneStability(ctx context.Context) error {
}
if (!zone) {
// Fallback: check window.Zone
if (window.Zone && window.Zone.current && window.Zone.current._inner) {
const isStable = !window.Zone.current._inner._hasPendingMicrotasks &&
!window.Zone.current._inner._hasPendingMacrotasks;
if (isStable) {
resolve(true);
} else {
// Poll for stability
let attempts = 0;
const poll = setInterval(() => {
attempts++;
const stable = !window.Zone.current._inner._hasPendingMicrotasks &&
!window.Zone.current._inner._hasPendingMacrotasks;
if (stable || attempts > 100) {
clearInterval(poll);
resolve(stable);
}
}, 50);
}
} else {
resolve(true);
}
pollZone();
return;
}
@ -153,30 +149,28 @@ func (ah *AngularHelper) waitForZoneStability(ctx context.Context) error {
}
// Wait for stability
const sub = zone.onStable.subscribe(() => {
sub.unsubscribe();
resolve(true);
});
// Timeout fallback
setTimeout(() => {
sub.unsubscribe();
resolve(zone.isStable);
}, 5000);
try {
const sub = zone.onStable.subscribe(() => {
sub.unsubscribe();
resolve(true);
});
} catch (e) {
pollZone();
}
})
`
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
// First evaluate the promise
_, err := ah.wv.evaluate(ctx, script)
result, err := ah.wv.evaluate(ctx, script)
if err != nil {
// If the script fails, fall back to simple polling
return ah.pollForStability(ctx)
}
return nil
if stable, ok := result.(bool); ok && stable {
return nil
}
return ah.pollForStability(ctx)
}
// pollForStability polls for Angular stability as a fallback.
@ -333,18 +327,20 @@ func (ah *AngularHelper) GetComponentProperty(selector, propertyName string) (an
defer cancel()
script := fmt.Sprintf(`
(function() {
const element = document.querySelector(%q);
if (!element) {
throw new Error('Element not found: %s');
}
const component = window.ng.probe(element).componentInstance;
if (!component) {
throw new Error('No Angular component found on element');
}
return component[%q];
})()
`, selector, selector, propertyName)
(function() {
const selector = %s;
const propertyName = %s;
const element = document.querySelector(selector);
if (!element) {
throw new Error('Element not found: ' + selector);
}
const component = window.ng.probe(element).componentInstance;
if (!component) {
throw new Error('No Angular component found on element');
}
return component[propertyName];
})()
`, formatJSValue(selector), formatJSValue(propertyName))
return ah.wv.evaluate(ctx, script)
}
@ -355,26 +351,28 @@ func (ah *AngularHelper) SetComponentProperty(selector, propertyName string, val
defer cancel()
script := fmt.Sprintf(`
(function() {
const element = document.querySelector(%q);
if (!element) {
throw new Error('Element not found: %s');
}
const component = window.ng.probe(element).componentInstance;
if (!component) {
throw new Error('No Angular component found on element');
}
component[%q] = %v;
(function() {
const selector = %s;
const propertyName = %s;
const element = document.querySelector(selector);
if (!element) {
throw new Error('Element not found: ' + selector);
}
const component = window.ng.probe(element).componentInstance;
if (!component) {
throw new Error('No Angular component found on element');
}
component[propertyName] = %s;
// Trigger change detection
const injector = window.ng.probe(element).injector;
const appRef = injector.get(window.ng.coreTokens.ApplicationRef || 'ApplicationRef');
if (appRef) {
// Trigger change detection
const injector = window.ng.probe(element).injector;
const appRef = injector.get(window.ng.coreTokens.ApplicationRef || 'ApplicationRef');
if (appRef) {
appRef.tick();
}
return true;
})()
`, selector, selector, propertyName, formatJSValue(value))
}
return true;
})()
`, formatJSValue(selector), formatJSValue(propertyName), formatJSValue(value))
_, err := ah.wv.evaluate(ctx, script)
return err
@ -394,29 +392,31 @@ func (ah *AngularHelper) CallComponentMethod(selector, methodName string, args .
}
script := fmt.Sprintf(`
(function() {
const element = document.querySelector(%q);
if (!element) {
throw new Error('Element not found: %s');
}
const component = window.ng.probe(element).componentInstance;
if (!component) {
throw new Error('No Angular component found on element');
}
if (typeof component[%q] !== 'function') {
throw new Error('Method not found: %s');
}
const result = component[%q](%s);
(function() {
const selector = %s;
const methodName = %s;
const element = document.querySelector(selector);
if (!element) {
throw new Error('Element not found: ' + selector);
}
const component = window.ng.probe(element).componentInstance;
if (!component) {
throw new Error('No Angular component found on element');
}
if (typeof component[methodName] !== 'function') {
throw new Error('Method not found: ' + methodName);
}
const result = component[methodName](%s);
// Trigger change detection
const injector = window.ng.probe(element).injector;
const appRef = injector.get(window.ng.coreTokens.ApplicationRef || 'ApplicationRef');
if (appRef) {
// Trigger change detection
const injector = window.ng.probe(element).injector;
const appRef = injector.get(window.ng.coreTokens.ApplicationRef || 'ApplicationRef');
if (appRef) {
appRef.tick();
}
return result;
})()
`, selector, selector, methodName, methodName, methodName, argsStr.String())
}
return result;
})()
`, formatJSValue(selector), formatJSValue(methodName), argsStr.String())
return ah.wv.evaluate(ctx, script)
}
@ -524,16 +524,18 @@ func (ah *AngularHelper) DispatchEvent(selector, eventName string, detail any) e
}
script := fmt.Sprintf(`
(function() {
const element = document.querySelector(%q);
if (!element) {
throw new Error('Element not found: %s');
}
const event = new CustomEvent(%q, { bubbles: true, detail: %s });
element.dispatchEvent(event);
return true;
})()
`, selector, selector, eventName, detailStr)
(function() {
const selector = %s;
const eventName = %s;
const element = document.querySelector(selector);
if (!element) {
throw new Error('Element not found: ' + selector);
}
const event = new CustomEvent(eventName, { bubbles: true, detail: %s });
element.dispatchEvent(event);
return true;
})()
`, formatJSValue(selector), formatJSValue(eventName), detailStr)
_, err := ah.wv.evaluate(ctx, script)
return err
@ -572,17 +574,18 @@ func (ah *AngularHelper) SetNgModel(selector string, value any) error {
defer cancel()
script := fmt.Sprintf(`
(function() {
const element = document.querySelector(%q);
if (!element) {
throw new Error('Element not found: %s');
}
(function() {
const selector = %s;
const element = document.querySelector(selector);
if (!element) {
throw new Error('Element not found: ' + selector);
}
element.value = %v;
element.dispatchEvent(new Event('input', { bubbles: true }));
element.dispatchEvent(new Event('change', { bubbles: true }));
element.value = %s;
element.dispatchEvent(new Event('input', { bubbles: true }));
element.dispatchEvent(new Event('change', { bubbles: true }));
// Trigger change detection
// Trigger change detection
const roots = window.getAllAngularRootElements ? window.getAllAngularRootElements() : [];
for (const root of roots) {
try {
@ -595,9 +598,9 @@ func (ah *AngularHelper) SetNgModel(selector string, value any) error {
} catch (e) {}
}
return true;
})()
`, selector, selector, formatJSValue(value))
return true;
})()
`, formatJSValue(selector), formatJSValue(value))
_, err := ah.wv.evaluate(ctx, script)
return err
@ -613,17 +616,15 @@ func getString(m map[string]any, key string) string {
}
func formatJSValue(v any) string {
switch val := v.(type) {
case string:
return fmt.Sprintf("%q", val)
case bool:
if val {
return "true"
}
return "false"
case nil:
return "null"
default:
return fmt.Sprintf("%v", val)
data, err := json.Marshal(v)
if err == nil {
return string(data)
}
fallback, fallbackErr := json.Marshal(fmt.Sprint(v))
if fallbackErr == nil {
return string(fallback)
}
return "null"
}

673
audit_issue2_test.go Normal file
View file

@ -0,0 +1,673 @@
// SPDX-License-Identifier: EUPL-1.2
package webview
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
)
type fakeCDPServer struct {
t *testing.T
server *httptest.Server
mu sync.Mutex
nextTarget int
targets map[string]*fakeCDPTarget
}
type fakeCDPTarget struct {
server *fakeCDPServer
id string
onConnect func(*fakeCDPTarget)
onMessage func(*fakeCDPTarget, cdpMessage)
connMu sync.Mutex
conn *websocket.Conn
received chan cdpMessage
connected chan struct{}
closed chan struct{}
connectedOnce sync.Once
closedOnce sync.Once
}
func newFakeCDPServer(t *testing.T) *fakeCDPServer {
t.Helper()
server := &fakeCDPServer{
t: t,
targets: make(map[string]*fakeCDPTarget),
}
server.server = httptest.NewServer(http.HandlerFunc(server.handle))
server.addTarget("target-1")
t.Cleanup(server.Close)
return server
}
func (s *fakeCDPServer) Close() {
s.server.Close()
}
func (s *fakeCDPServer) DebugURL() string {
return s.server.URL
}
func (s *fakeCDPServer) addTarget(id string) *fakeCDPTarget {
s.mu.Lock()
defer s.mu.Unlock()
target := &fakeCDPTarget{
server: s,
id: id,
received: make(chan cdpMessage, 16),
connected: make(chan struct{}),
closed: make(chan struct{}),
}
s.targets[id] = target
return target
}
func (s *fakeCDPServer) newTarget() *fakeCDPTarget {
s.mu.Lock()
s.nextTarget++
id := fmt.Sprintf("target-%d", s.nextTarget+1)
s.mu.Unlock()
return s.addTarget(id)
}
func (s *fakeCDPServer) primaryTarget() *fakeCDPTarget {
s.mu.Lock()
defer s.mu.Unlock()
return s.targets["target-1"]
}
func (s *fakeCDPServer) handle(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/json":
s.handleListTargets(w)
case r.URL.Path == "/json/new":
s.handleNewTarget(w)
case r.URL.Path == "/json/version":
s.writeJSON(w, map[string]string{
"Browser": "Chrome/123.0",
})
case strings.HasPrefix(r.URL.Path, "/devtools/page/"):
s.handleWebSocket(w, r, strings.TrimPrefix(r.URL.Path, "/devtools/page/"))
default:
http.NotFound(w, r)
}
}
func (s *fakeCDPServer) handleListTargets(w http.ResponseWriter) {
s.mu.Lock()
targets := make([]TargetInfo, 0, len(s.targets))
for id := range s.targets {
targets = append(targets, TargetInfo{
ID: id,
Type: "page",
Title: id,
URL: "about:blank",
WebSocketDebuggerURL: s.webSocketURL(id),
})
}
s.mu.Unlock()
s.writeJSON(w, targets)
}
func (s *fakeCDPServer) handleNewTarget(w http.ResponseWriter) {
target := s.newTarget()
s.writeJSON(w, TargetInfo{
ID: target.id,
Type: "page",
Title: target.id,
URL: "about:blank",
WebSocketDebuggerURL: s.webSocketURL(target.id),
})
}
func (s *fakeCDPServer) handleWebSocket(w http.ResponseWriter, r *http.Request, id string) {
s.mu.Lock()
target := s.targets[id]
s.mu.Unlock()
if target == nil {
http.NotFound(w, r)
return
}
upgrader := websocket.Upgrader{
CheckOrigin: func(*http.Request) bool { return true },
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
s.t.Fatalf("failed to upgrade test WebSocket: %v", err)
}
target.attach(conn)
}
func (s *fakeCDPServer) writeJSON(w http.ResponseWriter, value any) {
s.t.Helper()
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(value); err != nil {
s.t.Fatalf("failed to encode JSON: %v", err)
}
}
func (s *fakeCDPServer) webSocketURL(id string) string {
wsURL, err := url.Parse(s.server.URL)
if err != nil {
s.t.Fatalf("failed to parse test server URL: %v", err)
}
if wsURL.Scheme == "http" {
wsURL.Scheme = "ws"
} else {
wsURL.Scheme = "wss"
}
wsURL.Path = "/devtools/page/" + id
wsURL.RawQuery = ""
wsURL.Fragment = ""
return wsURL.String()
}
func (tgt *fakeCDPTarget) attach(conn *websocket.Conn) {
tgt.connMu.Lock()
tgt.conn = conn
tgt.connMu.Unlock()
tgt.connectedOnce.Do(func() {
close(tgt.connected)
})
go tgt.readLoop()
if tgt.onConnect != nil {
go tgt.onConnect(tgt)
}
}
func (tgt *fakeCDPTarget) readLoop() {
defer tgt.closedOnce.Do(func() {
close(tgt.closed)
})
for {
_, data, err := tgt.conn.ReadMessage()
if err != nil {
return
}
var msg cdpMessage
if err := json.Unmarshal(data, &msg); err != nil {
continue
}
select {
case tgt.received <- msg:
default:
}
if tgt.onMessage != nil {
tgt.onMessage(tgt, msg)
}
}
}
func (tgt *fakeCDPTarget) reply(id int64, result map[string]any) {
tgt.writeJSON(cdpResponse{
ID: id,
Result: result,
})
}
func (tgt *fakeCDPTarget) replyError(id int64, message string) {
tgt.writeJSON(cdpResponse{
ID: id,
Error: &cdpError{
Message: message,
},
})
}
func (tgt *fakeCDPTarget) replyValue(id int64, value any) {
tgt.reply(id, map[string]any{
"result": map[string]any{
"value": value,
},
})
}
func (tgt *fakeCDPTarget) writeJSON(value any) {
tgt.server.t.Helper()
tgt.connMu.Lock()
defer tgt.connMu.Unlock()
if tgt.conn == nil {
tgt.server.t.Fatal("test WebSocket connection was not established")
}
if err := tgt.conn.WriteJSON(value); err != nil {
tgt.server.t.Fatalf("failed to write test WebSocket message: %v", err)
}
}
func (tgt *fakeCDPTarget) closeWebSocket() {
tgt.connMu.Lock()
defer tgt.connMu.Unlock()
if tgt.conn != nil {
_ = tgt.conn.Close()
}
}
func (tgt *fakeCDPTarget) waitForMessage(tb testing.TB) cdpMessage {
tb.Helper()
select {
case msg := <-tgt.received:
return msg
case <-time.After(time.Second):
tb.Fatal("timed out waiting for CDP message")
return cdpMessage{}
}
}
func (tgt *fakeCDPTarget) waitConnected(tb testing.TB) {
tb.Helper()
select {
case <-tgt.connected:
case <-time.After(time.Second):
tb.Fatal("timed out waiting for WebSocket connection")
}
}
func (tgt *fakeCDPTarget) waitClosed(tb testing.TB) {
tb.Helper()
select {
case <-tgt.closed:
case <-time.After(time.Second):
tb.Fatal("timed out waiting for WebSocket closure")
}
}
func TestCDPClientClose_Good_UnblocksReadLoop(t *testing.T) {
server := newFakeCDPServer(t)
target := server.primaryTarget()
client, err := NewCDPClient(server.DebugURL())
if err != nil {
t.Fatalf("NewCDPClient returned error: %v", err)
}
target.waitConnected(t)
done := make(chan error, 1)
go func() {
done <- client.Close()
}()
select {
case err := <-done:
if err != nil {
t.Fatalf("Close returned error: %v", err)
}
case <-time.After(time.Second):
t.Fatal("Close blocked waiting for readLoop")
}
}
func TestCDPClientReadLoop_Ugly_StopsOnTerminalReadError(t *testing.T) {
server := newFakeCDPServer(t)
target := server.primaryTarget()
target.onConnect = func(target *fakeCDPTarget) {
target.closeWebSocket()
}
client, err := NewCDPClient(server.DebugURL())
if err != nil {
t.Fatalf("NewCDPClient returned error: %v", err)
}
select {
case <-client.done:
case <-time.After(time.Second):
t.Fatal("readLoop did not stop after terminal read error")
}
}
func TestCDPClientCloseTab_Good_ClosesTargetOnly(t *testing.T) {
server := newFakeCDPServer(t)
target := server.primaryTarget()
target.onMessage = func(target *fakeCDPTarget, msg cdpMessage) {
if msg.Method != "Target.closeTarget" {
t.Fatalf("CloseTab sent %q, want Target.closeTarget", msg.Method)
}
if got := msg.Params["targetId"]; got != target.id {
t.Fatalf("Target.closeTarget targetId = %v, want %q", got, target.id)
}
target.reply(msg.ID, map[string]any{"success": true})
go func() {
time.Sleep(10 * time.Millisecond)
target.closeWebSocket()
}()
}
client, err := NewCDPClient(server.DebugURL())
if err != nil {
t.Fatalf("NewCDPClient returned error: %v", err)
}
if err := client.CloseTab(); err != nil {
t.Fatalf("CloseTab returned error: %v", err)
}
msg := target.waitForMessage(t)
if msg.Method == "Browser.close" {
t.Fatal("CloseTab closed the whole browser")
}
}
func TestCDPClientDispatchEvent_Good_HandlerParamsAreIsolated(t *testing.T) {
client := &CDPClient{
handlers: make(map[string][]func(map[string]any)),
}
firstDone := make(chan map[string]any, 1)
secondDone := make(chan map[string]any, 1)
client.OnEvent("Runtime.testEvent", func(params map[string]any) {
params["value"] = "mutated"
params["nested"].(map[string]any)["count"] = 1
params["items"].([]any)[0].(map[string]any)["id"] = "changed"
firstDone <- params
})
client.OnEvent("Runtime.testEvent", func(params map[string]any) {
secondDone <- params
})
original := map[string]any{
"nested": map[string]any{"count": 0},
"items": []any{map[string]any{"id": "original"}},
}
client.dispatchEvent("Runtime.testEvent", original)
select {
case <-firstDone:
case <-time.After(time.Second):
t.Fatal("first handler did not run")
}
var secondParams map[string]any
select {
case secondParams = <-secondDone:
case <-time.After(time.Second):
t.Fatal("second handler did not run")
}
if _, ok := secondParams["value"]; ok {
t.Fatal("second handler observed first handler mutation")
}
if got := secondParams["nested"].(map[string]any)["count"]; got != 0 {
t.Fatalf("second handler nested count = %v, want 0", got)
}
if got := secondParams["items"].([]any)[0].(map[string]any)["id"]; got != "original" {
t.Fatalf("second handler slice payload = %v, want %q", got, "original")
}
if got := original["nested"].(map[string]any)["count"]; got != 0 {
t.Fatalf("original params were mutated: nested count = %v", got)
}
}
func TestNewCDPClient_Bad_RejectsCrossHostWebSocket(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/json" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode([]TargetInfo{{
ID: "target-1",
Type: "page",
WebSocketDebuggerURL: "ws://example.com/devtools/page/target-1",
}}); err != nil {
t.Fatalf("failed to encode targets: %v", err)
}
}))
defer server.Close()
_, err := NewCDPClient(server.URL)
if err == nil {
t.Fatal("NewCDPClient succeeded with a cross-host WebSocket URL")
}
if !strings.Contains(err.Error(), "invalid target WebSocket URL") {
t.Fatalf("NewCDPClient error = %v, want cross-host WebSocket validation failure", err)
}
}
func TestWebviewNew_Bad_ClosesClientWhenEnableConsoleFails(t *testing.T) {
server := newFakeCDPServer(t)
target := server.primaryTarget()
target.onMessage = func(target *fakeCDPTarget, msg cdpMessage) {
if msg.Method != "Runtime.enable" {
t.Fatalf("enableConsole sent %q before Runtime.enable failed", msg.Method)
}
target.replyError(msg.ID, "runtime disabled")
}
_, err := New(
WithTimeout(250*time.Millisecond),
WithDebugURL(server.DebugURL()),
)
if err == nil {
t.Fatal("New succeeded when Runtime.enable failed")
}
target.waitClosed(t)
}
func TestAngularHelperWaitForZoneStability_Good_AwaitsPromise(t *testing.T) {
server := newFakeCDPServer(t)
target := server.primaryTarget()
target.onMessage = func(target *fakeCDPTarget, msg cdpMessage) {
if msg.Method != "Runtime.evaluate" {
t.Fatalf("unexpected method %q", msg.Method)
}
target.replyValue(msg.ID, true)
}
client, err := NewCDPClient(server.DebugURL())
if err != nil {
t.Fatalf("NewCDPClient returned error: %v", err)
}
defer func() { _ = client.Close() }()
wv := &Webview{
client: client,
ctx: context.Background(),
timeout: time.Second,
}
ah := NewAngularHelper(wv)
if err := ah.waitForZoneStability(context.Background()); err != nil {
t.Fatalf("waitForZoneStability returned error: %v", err)
}
msg := target.waitForMessage(t)
if got := msg.Params["awaitPromise"]; got != true {
t.Fatalf("Runtime.evaluate awaitPromise = %v, want true", got)
}
if got := msg.Params["returnByValue"]; got != true {
t.Fatalf("Runtime.evaluate returnByValue = %v, want true", got)
}
}
func TestAngularHelperSetNgModel_Good_EscapesSelectorAndValue(t *testing.T) {
server := newFakeCDPServer(t)
target := server.primaryTarget()
target.onMessage = func(target *fakeCDPTarget, msg cdpMessage) {
if msg.Method != "Runtime.evaluate" {
t.Fatalf("unexpected method %q", msg.Method)
}
target.replyValue(msg.ID, true)
}
client, err := NewCDPClient(server.DebugURL())
if err != nil {
t.Fatalf("NewCDPClient returned error: %v", err)
}
defer func() { _ = client.Close() }()
wv := &Webview{
client: client,
ctx: context.Background(),
timeout: time.Second,
}
ah := NewAngularHelper(wv)
selector := `input[name="x'];window.hacked=true;//"]`
value := `";window.hacked=true;//`
if err := ah.SetNgModel(selector, value); err != nil {
t.Fatalf("SetNgModel returned error: %v", err)
}
expression, _ := target.waitForMessage(t).Params["expression"].(string)
if !strings.Contains(expression, "const selector = "+formatJSValue(selector)+";") {
t.Fatalf("expression did not contain safely quoted selector: %s", expression)
}
if !strings.Contains(expression, "element.value = "+formatJSValue(value)+";") {
t.Fatalf("expression did not contain safely quoted value: %s", expression)
}
if strings.Contains(expression, "throw new Error('Element not found: "+selector+"')") {
t.Fatalf("expression still embedded selector directly in error text: %s", expression)
}
}
func TestConsoleWatcherWaitForMessage_Good_IsolatesTemporaryHandlers(t *testing.T) {
cw := &ConsoleWatcher{
messages: make([]ConsoleMessage, 0),
filters: make([]ConsoleFilter, 0),
limit: 1000,
handlers: make([]consoleHandlerRegistration, 0),
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
results := make(chan string, 2)
errorsCh := make(chan error, 2)
go func() {
msg, err := cw.WaitForMessage(ctx, ConsoleFilter{Type: "error"})
if err != nil {
errorsCh <- err
return
}
results <- "error:" + msg.Text
}()
go func() {
msg, err := cw.WaitForMessage(ctx, ConsoleFilter{Type: "log"})
if err != nil {
errorsCh <- err
return
}
results <- "log:" + msg.Text
}()
time.Sleep(20 * time.Millisecond)
cw.addMessage(ConsoleMessage{Type: "error", Text: "first"})
time.Sleep(20 * time.Millisecond)
cw.addMessage(ConsoleMessage{Type: "log", Text: "second"})
got := make(map[string]bool, 2)
for range 2 {
select {
case err := <-errorsCh:
t.Fatalf("WaitForMessage returned error: %v", err)
case result := <-results:
got[result] = true
case <-time.After(time.Second):
t.Fatal("timed out waiting for console waiter results")
}
}
if !got["error:first"] || !got["log:second"] {
t.Fatalf("unexpected console waiter results: %#v", got)
}
if len(cw.handlers) != 0 {
t.Fatalf("temporary handlers leaked: %d", len(cw.handlers))
}
}
func TestExceptionWatcherWaitForException_Good_PreservesExistingHandlers(t *testing.T) {
ew := &ExceptionWatcher{
exceptions: make([]ExceptionInfo, 0),
handlers: make([]exceptionHandlerRegistration, 0),
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
waitDone := make(chan error, 1)
go func() {
_, err := ew.WaitForException(ctx)
waitDone <- err
}()
time.Sleep(20 * time.Millisecond)
var mu sync.Mutex
count := 0
ew.AddHandler(func(ExceptionInfo) {
mu.Lock()
defer mu.Unlock()
count++
})
ew.handleException(map[string]any{
"exceptionDetails": map[string]any{
"text": "first",
"lineNumber": float64(1),
"columnNumber": float64(1),
"url": "https://example.com/app.js",
},
})
select {
case err := <-waitDone:
if err != nil {
t.Fatalf("WaitForException returned error: %v", err)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for exception waiter")
}
ew.handleException(map[string]any{
"exceptionDetails": map[string]any{
"text": "second",
"lineNumber": float64(2),
"columnNumber": float64(1),
"url": "https://example.com/app.js",
},
})
mu.Lock()
defer mu.Unlock()
if count != 2 {
t.Fatalf("persistent handler count = %d, want 2", count)
}
if len(ew.handlers) != 1 {
t.Fatalf("unexpected handler count after waiter removal: %d", len(ew.handlers))
}
}

485
cdp.go
View file

@ -1,26 +1,46 @@
// SPDX-License-Identifier: EUPL-1.2
package webview
import (
"context"
"encoding/json"
"errors"
"io"
"iter"
"net"
"net/http"
"net/url"
"path"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
coreerr "dappco.re/go/core/log"
)
const debugEndpointTimeout = 10 * time.Second
var (
defaultDebugHTTPClient = &http.Client{
Timeout: debugEndpointTimeout,
CheckRedirect: func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
},
}
errCDPClientClosed = errors.New("cdp client closed")
)
// CDPClient handles communication with Chrome DevTools Protocol via WebSocket.
type CDPClient struct {
mu sync.RWMutex
conn *websocket.Conn
debugURL string
wsURL string
mu sync.RWMutex
conn *websocket.Conn
debugURL string
debugBase *url.URL
wsURL string
// Message tracking
msgID atomic.Int64
@ -32,9 +52,11 @@ type CDPClient struct {
handMu sync.RWMutex
// Lifecycle
ctx context.Context
cancel context.CancelFunc
done chan struct{}
ctx context.Context
cancel context.CancelFunc
done chan struct{}
closeOnce sync.Once
closeErr error
}
// cdpMessage represents a CDP protocol message.
@ -76,51 +98,41 @@ type TargetInfo struct {
// NewCDPClient creates a new CDP client connected to the given debug URL.
// The debug URL should be the Chrome DevTools HTTP endpoint (e.g., http://localhost:9222).
func NewCDPClient(debugURL string) (*CDPClient, error) {
// Get available targets
resp, err := http.Get(debugURL + "/json")
debugBase, err := parseDebugURL(debugURL)
if err != nil {
return nil, coreerr.E("CDPClient.New", "invalid debug URL", err)
}
ctx, cancel := context.WithTimeout(context.Background(), debugEndpointTimeout)
defer cancel()
targets, err := listTargetsAt(ctx, debugBase)
if err != nil {
return nil, coreerr.E("CDPClient.New", "failed to get targets", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, coreerr.E("CDPClient.New", "failed to read targets", err)
}
var targets []TargetInfo
if err := json.Unmarshal(body, &targets); err != nil {
return nil, coreerr.E("CDPClient.New", "failed to parse targets", err)
}
// Find a page target
var wsURL string
for _, t := range targets {
if t.Type == "page" && t.WebSocketDebuggerURL != "" {
wsURL = t.WebSocketDebuggerURL
wsURL, err = validateTargetWebSocketURL(debugBase, t.WebSocketDebuggerURL)
if err != nil {
return nil, coreerr.E("CDPClient.New", "invalid target WebSocket URL", err)
}
break
}
}
if wsURL == "" {
// Try to create a new target
resp, err := http.Get(debugURL + "/json/new")
newTarget, err := createTargetAt(ctx, debugBase, "")
if err != nil {
return nil, coreerr.E("CDPClient.New", "no page targets found and failed to create new", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
wsURL, err = validateTargetWebSocketURL(debugBase, newTarget.WebSocketDebuggerURL)
if err != nil {
return nil, coreerr.E("CDPClient.New", "failed to read new target", err)
return nil, coreerr.E("CDPClient.New", "invalid new target WebSocket URL", err)
}
var newTarget TargetInfo
if err := json.Unmarshal(body, &newTarget); err != nil {
return nil, coreerr.E("CDPClient.New", "failed to parse new target", err)
}
wsURL = newTarget.WebSocketDebuggerURL
}
if wsURL == "" {
@ -133,30 +145,17 @@ func NewCDPClient(debugURL string) (*CDPClient, error) {
return nil, coreerr.E("CDPClient.New", "failed to connect to WebSocket", err)
}
ctx, cancel := context.WithCancel(context.Background())
client := &CDPClient{
conn: conn,
debugURL: debugURL,
wsURL: wsURL,
pending: make(map[int64]chan *cdpResponse),
handlers: make(map[string][]func(map[string]any)),
ctx: ctx,
cancel: cancel,
done: make(chan struct{}),
}
// Start message reader
go client.readLoop()
return client, nil
return newCDPClient(debugBase, wsURL, conn), nil
}
// Close closes the CDP connection.
func (c *CDPClient) Close() error {
c.cancel()
<-c.done // Wait for read loop to finish
return c.conn.Close()
c.close(errCDPClientClosed)
<-c.done
if c.closeErr != nil {
return coreerr.E("CDPClient.Close", "failed to close WebSocket", c.closeErr)
}
return nil
}
// Call sends a CDP method call and waits for the response.
@ -166,7 +165,7 @@ func (c *CDPClient) Call(ctx context.Context, method string, params map[string]a
msg := cdpMessage{
ID: id,
Method: method,
Params: params,
Params: cloneMapAny(params),
}
// Register response channel
@ -193,6 +192,8 @@ func (c *CDPClient) Call(ctx context.Context, method string, params map[string]a
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-c.ctx.Done():
return nil, coreerr.E("CDPClient.Call", "client closed", errCDPClientClosed)
case resp := <-respCh:
if resp.Error != nil {
return nil, coreerr.E("CDPClient.Call", resp.Error.Message, nil)
@ -213,22 +214,23 @@ func (c *CDPClient) readLoop() {
defer close(c.done)
for {
select {
case <-c.ctx.Done():
return
default:
}
_, data, err := c.conn.ReadMessage()
if err != nil {
// Check if context was cancelled
select {
case <-c.ctx.Done():
if c.ctx.Err() != nil {
return
default:
// Log error but continue (could be temporary)
}
if isTerminalReadError(err) {
c.close(err)
return
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
continue
}
c.close(err)
return
}
// Try to parse as response
@ -237,7 +239,10 @@ func (c *CDPClient) readLoop() {
c.pendMu.Lock()
if ch, ok := c.pending[resp.ID]; ok {
respCopy := resp
ch <- &respCopy
select {
case ch <- &respCopy:
default:
}
}
c.pendMu.Unlock()
continue
@ -259,7 +264,8 @@ func (c *CDPClient) dispatchEvent(method string, params map[string]any) {
for _, handler := range handlers {
// Call handler in goroutine to avoid blocking
go handler(params)
handlerParams := cloneMapAny(params)
go handler(handlerParams)
}
}
@ -267,7 +273,7 @@ func (c *CDPClient) dispatchEvent(method string, params map[string]any) {
func (c *CDPClient) Send(method string, params map[string]any) error {
msg := cdpMessage{
Method: method,
Params: params,
Params: cloneMapAny(params),
}
c.mu.Lock()
@ -287,83 +293,70 @@ func (c *CDPClient) WebSocketURL() string {
// NewTab creates a new browser tab and returns a new CDPClient connected to it.
func (c *CDPClient) NewTab(url string) (*CDPClient, error) {
endpoint := c.debugURL + "/json/new"
if url != "" {
endpoint += "?" + url
}
ctx, cancel := context.WithTimeout(c.ctx, debugEndpointTimeout)
defer cancel()
resp, err := http.Get(endpoint)
target, err := createTargetAt(ctx, c.debugBase, url)
if err != nil {
return nil, coreerr.E("CDPClient.NewTab", "failed to create new tab", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, coreerr.E("CDPClient.NewTab", "failed to read response", err)
}
var target TargetInfo
if err := json.Unmarshal(body, &target); err != nil {
return nil, coreerr.E("CDPClient.NewTab", "failed to parse target", err)
}
if target.WebSocketDebuggerURL == "" {
return nil, coreerr.E("CDPClient.NewTab", "no WebSocket URL for new tab", nil)
}
wsURL, err := validateTargetWebSocketURL(c.debugBase, target.WebSocketDebuggerURL)
if err != nil {
return nil, coreerr.E("CDPClient.NewTab", "invalid WebSocket URL for new tab", err)
}
// Connect to new tab
conn, _, err := websocket.DefaultDialer.Dial(target.WebSocketDebuggerURL, nil)
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
return nil, coreerr.E("CDPClient.NewTab", "failed to connect to new tab", err)
}
ctx, cancel := context.WithCancel(context.Background())
client := &CDPClient{
conn: conn,
debugURL: c.debugURL,
wsURL: target.WebSocketDebuggerURL,
pending: make(map[int64]chan *cdpResponse),
handlers: make(map[string][]func(map[string]any)),
ctx: ctx,
cancel: cancel,
done: make(chan struct{}),
}
go client.readLoop()
return client, nil
return newCDPClient(c.debugBase, wsURL, conn), nil
}
// CloseTab closes the current tab (target).
func (c *CDPClient) CloseTab() error {
// Extract target ID from WebSocket URL
// Format: ws://host:port/devtools/page/TARGET_ID
// We'll use the Browser.close target API
targetID, err := targetIDFromWebSocketURL(c.wsURL)
if err != nil {
return coreerr.E("CDPClient.CloseTab", "failed to determine target ID", err)
}
ctx := context.Background()
_, err := c.Call(ctx, "Browser.close", nil)
return err
ctx, cancel := context.WithTimeout(c.ctx, debugEndpointTimeout)
defer cancel()
result, err := c.Call(ctx, "Target.closeTarget", map[string]any{
"targetId": targetID,
})
if err != nil {
return coreerr.E("CDPClient.CloseTab", "failed to close target", err)
}
if success, ok := result["success"].(bool); ok && !success {
return coreerr.E("CDPClient.CloseTab", "target close was not acknowledged", nil)
}
return c.Close()
}
// ListTargets returns all available targets.
func ListTargets(debugURL string) ([]TargetInfo, error) {
resp, err := http.Get(debugURL + "/json")
debugBase, err := parseDebugURL(debugURL)
if err != nil {
return nil, coreerr.E("ListTargets", "invalid debug URL", err)
}
ctx, cancel := context.WithTimeout(context.Background(), debugEndpointTimeout)
defer cancel()
targets, err := listTargetsAt(ctx, debugBase)
if err != nil {
return nil, coreerr.E("ListTargets", "failed to get targets", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, coreerr.E("ListTargets", "failed to read targets", err)
}
var targets []TargetInfo
if err := json.Unmarshal(body, &targets); err != nil {
return nil, coreerr.E("ListTargets", "failed to parse targets", err)
}
return targets, nil
}
@ -385,16 +378,18 @@ func ListTargetsAll(debugURL string) iter.Seq[TargetInfo] {
// GetVersion returns Chrome version information.
func GetVersion(debugURL string) (map[string]string, error) {
resp, err := http.Get(debugURL + "/json/version")
debugBase, err := parseDebugURL(debugURL)
if err != nil {
return nil, coreerr.E("GetVersion", "invalid debug URL", err)
}
ctx, cancel := context.WithTimeout(context.Background(), debugEndpointTimeout)
defer cancel()
body, err := doDebugRequest(ctx, debugBase, "/json/version", "")
if err != nil {
return nil, coreerr.E("GetVersion", "failed to get version", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, coreerr.E("GetVersion", "failed to read version", err)
}
var version map[string]string
if err := json.Unmarshal(body, &version); err != nil {
@ -403,3 +398,241 @@ func GetVersion(debugURL string) (map[string]string, error) {
return version, nil
}
func newCDPClient(debugBase *url.URL, wsURL string, conn *websocket.Conn) *CDPClient {
ctx, cancel := context.WithCancel(context.Background())
baseCopy := *debugBase
client := &CDPClient{
conn: conn,
debugURL: canonicalDebugURL(&baseCopy),
debugBase: &baseCopy,
wsURL: wsURL,
pending: make(map[int64]chan *cdpResponse),
handlers: make(map[string][]func(map[string]any)),
ctx: ctx,
cancel: cancel,
done: make(chan struct{}),
}
go client.readLoop()
return client
}
func parseDebugURL(raw string) (*url.URL, error) {
debugURL, err := url.Parse(raw)
if err != nil {
return nil, err
}
if debugURL.Scheme != "http" && debugURL.Scheme != "https" {
return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL must use http or https", nil)
}
if debugURL.Host == "" {
return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL host is required", nil)
}
if debugURL.User != nil {
return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL must not include credentials", nil)
}
if debugURL.RawQuery != "" || debugURL.Fragment != "" {
return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL must not include query or fragment", nil)
}
if debugURL.Path == "" {
debugURL.Path = "/"
}
if debugURL.Path != "/" {
return nil, coreerr.E("CDPClient.parseDebugURL", "debug URL must point at the DevTools root", nil)
}
return debugURL, nil
}
func canonicalDebugURL(debugURL *url.URL) string {
return strings.TrimSuffix(debugURL.String(), "/")
}
func doDebugRequest(ctx context.Context, debugBase *url.URL, endpoint, rawQuery string) ([]byte, error) {
reqURL := *debugBase
reqURL.Path = endpoint
reqURL.RawPath = ""
reqURL.RawQuery = rawQuery
reqURL.Fragment = ""
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
return nil, err
}
resp, err := defaultDebugHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return nil, coreerr.E("CDPClient.doDebugRequest", "debug endpoint returned "+resp.Status, nil)
}
return body, nil
}
func listTargetsAt(ctx context.Context, debugBase *url.URL) ([]TargetInfo, error) {
body, err := doDebugRequest(ctx, debugBase, "/json", "")
if err != nil {
return nil, err
}
var targets []TargetInfo
if err := json.Unmarshal(body, &targets); err != nil {
return nil, err
}
return targets, nil
}
func createTargetAt(ctx context.Context, debugBase *url.URL, pageURL string) (*TargetInfo, error) {
rawQuery := ""
if pageURL != "" {
rawQuery = url.QueryEscape(pageURL)
}
body, err := doDebugRequest(ctx, debugBase, "/json/new", rawQuery)
if err != nil {
return nil, err
}
var target TargetInfo
if err := json.Unmarshal(body, &target); err != nil {
return nil, err
}
return &target, nil
}
func validateTargetWebSocketURL(debugBase *url.URL, raw string) (string, error) {
wsURL, err := url.Parse(raw)
if err != nil {
return "", err
}
if wsURL.Scheme != "ws" && wsURL.Scheme != "wss" {
return "", coreerr.E("CDPClient.validateTargetWebSocketURL", "target WebSocket URL must use ws or wss", nil)
}
if !sameEndpointHost(debugBase, wsURL) {
return "", coreerr.E("CDPClient.validateTargetWebSocketURL", "target WebSocket URL must match debug URL host", nil)
}
return wsURL.String(), nil
}
func sameEndpointHost(httpURL, wsURL *url.URL) bool {
return strings.EqualFold(httpURL.Hostname(), wsURL.Hostname()) && normalisedPort(httpURL) == normalisedPort(wsURL)
}
func normalisedPort(u *url.URL) string {
if port := u.Port(); port != "" {
return port
}
switch u.Scheme {
case "http", "ws":
return "80"
case "https", "wss":
return "443"
default:
return ""
}
}
func targetIDFromWebSocketURL(raw string) (string, error) {
wsURL, err := url.Parse(raw)
if err != nil {
return "", err
}
targetID := path.Base(strings.TrimSuffix(wsURL.Path, "/"))
if targetID == "." || targetID == "/" || targetID == "" {
return "", coreerr.E("CDPClient.targetIDFromWebSocketURL", "missing target ID in WebSocket URL", nil)
}
return targetID, nil
}
func (c *CDPClient) close(reason error) {
c.closeOnce.Do(func() {
c.cancel()
c.failPending(reason)
c.mu.Lock()
err := c.conn.Close()
c.mu.Unlock()
if err != nil && !isTerminalReadError(err) {
c.closeErr = err
}
})
}
func (c *CDPClient) failPending(err error) {
c.pendMu.Lock()
defer c.pendMu.Unlock()
for id, ch := range c.pending {
resp := &cdpResponse{
ID: id,
Error: &cdpError{
Message: err.Error(),
},
}
select {
case ch <- resp:
default:
}
}
}
func isTerminalReadError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, net.ErrClosed) || errors.Is(err, websocket.ErrCloseSent) {
return true
}
var closeErr *websocket.CloseError
return errors.As(err, &closeErr)
}
func cloneMapAny(src map[string]any) map[string]any {
if src == nil {
return nil
}
dst := make(map[string]any, len(src))
for key, value := range src {
dst[key] = cloneAny(value)
}
return dst
}
func cloneSliceAny(src []any) []any {
if src == nil {
return nil
}
dst := make([]any, len(src))
for i, value := range src {
dst[i] = cloneAny(value)
}
return dst
}
func cloneAny(value any) any {
switch typed := value.(type) {
case map[string]any:
return cloneMapAny(typed)
case []any:
return cloneSliceAny(typed)
default:
return typed
}
}

View file

@ -1,3 +1,4 @@
// SPDX-License-Identifier: EUPL-1.2
package webview
import (
@ -7,17 +8,19 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
)
// ConsoleWatcher provides advanced console message watching capabilities.
type ConsoleWatcher struct {
mu sync.RWMutex
wv *Webview
messages []ConsoleMessage
filters []ConsoleFilter
limit int
handlers []ConsoleHandler
mu sync.RWMutex
wv *Webview
messages []ConsoleMessage
filters []ConsoleFilter
limit int
handlers []consoleHandlerRegistration
nextHandlerID atomic.Int64
}
// ConsoleFilter filters console messages.
@ -29,6 +32,11 @@ type ConsoleFilter struct {
// ConsoleHandler is called when a matching console message is received.
type ConsoleHandler func(msg ConsoleMessage)
type consoleHandlerRegistration struct {
id int64
handler ConsoleHandler
}
// NewConsoleWatcher creates a new console watcher for the webview.
func NewConsoleWatcher(wv *Webview) *ConsoleWatcher {
cw := &ConsoleWatcher{
@ -36,7 +44,7 @@ func NewConsoleWatcher(wv *Webview) *ConsoleWatcher {
messages: make([]ConsoleMessage, 0, 100),
filters: make([]ConsoleFilter, 0),
limit: 1000,
handlers: make([]ConsoleHandler, 0),
handlers: make([]consoleHandlerRegistration, 0),
}
// Subscribe to console events from the webview's client
@ -63,9 +71,30 @@ func (cw *ConsoleWatcher) ClearFilters() {
// AddHandler adds a handler for console messages.
func (cw *ConsoleWatcher) AddHandler(handler ConsoleHandler) {
cw.addHandler(handler)
}
func (cw *ConsoleWatcher) addHandler(handler ConsoleHandler) int64 {
cw.mu.Lock()
defer cw.mu.Unlock()
cw.handlers = append(cw.handlers, handler)
id := cw.nextHandlerID.Add(1)
cw.handlers = append(cw.handlers, consoleHandlerRegistration{
id: id,
handler: handler,
})
return id
}
func (cw *ConsoleWatcher) removeHandler(id int64) {
cw.mu.Lock()
defer cw.mu.Unlock()
for i, registration := range cw.handlers {
if registration.id == id {
cw.handlers = slices.Delete(cw.handlers, i, i+1)
return
}
}
}
// SetLimit sets the maximum number of messages to retain.
@ -187,13 +216,8 @@ func (cw *ConsoleWatcher) WaitForMessage(ctx context.Context, filter ConsoleFilt
}
}
cw.AddHandler(handler)
defer func() {
cw.mu.Lock()
// Remove handler (simple implementation - in production you'd want a handle-based removal)
cw.handlers = cw.handlers[:len(cw.handlers)-1]
cw.mu.Unlock()
}()
handlerID := cw.addHandler(handler)
defer cw.removeHandler(handlerID)
select {
case <-ctx.Done():
@ -302,8 +326,8 @@ func (cw *ConsoleWatcher) addMessage(msg ConsoleMessage) {
cw.mu.Unlock()
// Call handlers
for _, handler := range handlers {
handler(msg)
for _, registration := range handlers {
registration.handler(msg)
}
}
@ -361,10 +385,16 @@ type ExceptionInfo struct {
// ExceptionWatcher watches for JavaScript exceptions.
type ExceptionWatcher struct {
mu sync.RWMutex
wv *Webview
exceptions []ExceptionInfo
handlers []func(ExceptionInfo)
mu sync.RWMutex
wv *Webview
exceptions []ExceptionInfo
handlers []exceptionHandlerRegistration
nextHandlerID atomic.Int64
}
type exceptionHandlerRegistration struct {
id int64
handler func(ExceptionInfo)
}
// NewExceptionWatcher creates a new exception watcher.
@ -372,7 +402,7 @@ func NewExceptionWatcher(wv *Webview) *ExceptionWatcher {
ew := &ExceptionWatcher{
wv: wv,
exceptions: make([]ExceptionInfo, 0),
handlers: make([]func(ExceptionInfo), 0),
handlers: make([]exceptionHandlerRegistration, 0),
}
// Subscribe to exception events
@ -425,9 +455,30 @@ func (ew *ExceptionWatcher) Count() int {
// AddHandler adds a handler for exceptions.
func (ew *ExceptionWatcher) AddHandler(handler func(ExceptionInfo)) {
ew.addHandler(handler)
}
func (ew *ExceptionWatcher) addHandler(handler func(ExceptionInfo)) int64 {
ew.mu.Lock()
defer ew.mu.Unlock()
ew.handlers = append(ew.handlers, handler)
id := ew.nextHandlerID.Add(1)
ew.handlers = append(ew.handlers, exceptionHandlerRegistration{
id: id,
handler: handler,
})
return id
}
func (ew *ExceptionWatcher) removeHandler(id int64) {
ew.mu.Lock()
defer ew.mu.Unlock()
for i, registration := range ew.handlers {
if registration.id == id {
ew.handlers = slices.Delete(ew.handlers, i, i+1)
return
}
}
}
// WaitForException waits for an exception to be thrown.
@ -450,12 +501,8 @@ func (ew *ExceptionWatcher) WaitForException(ctx context.Context) (*ExceptionInf
}
}
ew.AddHandler(handler)
defer func() {
ew.mu.Lock()
ew.handlers = ew.handlers[:len(ew.handlers)-1]
ew.mu.Unlock()
}()
handlerID := ew.addHandler(handler)
defer ew.removeHandler(handlerID)
select {
case <-ctx.Done():
@ -515,8 +562,8 @@ func (ew *ExceptionWatcher) handleException(params map[string]any) {
ew.mu.Unlock()
// Call handlers
for _, handler := range handlers {
handler(info)
for _, registration := range handlers {
registration.handler(info)
}
}

View file

@ -1,3 +1,4 @@
// SPDX-License-Identifier: EUPL-1.2
// Package webview provides browser automation via Chrome DevTools Protocol (CDP).
//
// The package allows controlling Chrome/Chromium browsers for automated testing,
@ -118,9 +119,16 @@ func New(opts ...Option) (*Webview, error) {
consoleLimit: 1000,
}
cleanupOnError := func() {
cancel()
if wv.client != nil {
_ = wv.client.Close()
}
}
for _, opt := range opts {
if err := opt(wv); err != nil {
cancel()
cleanupOnError()
return nil, err
}
}
@ -132,7 +140,7 @@ func New(opts ...Option) (*Webview, error) {
// Enable console capture
if err := wv.enableConsole(); err != nil {
cancel()
cleanupOnError()
return nil, coreerr.E("Webview.New", "failed to enable console capture", err)
}
@ -542,6 +550,7 @@ func (wv *Webview) evaluate(ctx context.Context, script string) (any, error) {
result, err := wv.client.Call(ctx, "Runtime.evaluate", map[string]any{
"expression": script,
"returnByValue": true,
"awaitPromise": true,
})
if err != nil {
return nil, coreerr.E("Webview.evaluate", "failed to evaluate script", err)

View file

@ -1,3 +1,4 @@
// SPDX-License-Identifier: EUPL-1.2
package webview
import (
@ -427,6 +428,8 @@ func TestFormatJSValue_Good(t *testing.T) {
{nil, "null"},
{42, "42"},
{3.14, "3.14"},
{map[string]any{"enabled": true}, `{"enabled":true}`},
{[]any{1, "two"}, `[1,"two"]`},
}
for _, tc := range tests {
@ -512,7 +515,7 @@ func TestConsoleWatcherFilter_Good(t *testing.T) {
messages: make([]ConsoleMessage, 0),
filters: make([]ConsoleFilter, 0),
limit: 1000,
handlers: make([]ConsoleHandler, 0),
handlers: make([]consoleHandlerRegistration, 0),
}
// No filters — everything matches
@ -556,7 +559,7 @@ func TestConsoleWatcherCounts_Good(t *testing.T) {
},
filters: make([]ConsoleFilter, 0),
limit: 1000,
handlers: make([]ConsoleHandler, 0),
handlers: make([]consoleHandlerRegistration, 0),
}
if cw.Count() != 5 {
@ -592,7 +595,7 @@ func TestConsoleWatcherCounts_Good(t *testing.T) {
func TestExceptionWatcher_Good(t *testing.T) {
ew := &ExceptionWatcher{
exceptions: make([]ExceptionInfo, 0),
handlers: make([]func(ExceptionInfo), 0),
handlers: make([]exceptionHandlerRegistration, 0),
}
if ew.HasExceptions() {
@ -682,7 +685,7 @@ func TestConsoleWatcherAddMessage_Good(t *testing.T) {
messages: make([]ConsoleMessage, 0),
filters: make([]ConsoleFilter, 0),
limit: 5,
handlers: make([]ConsoleHandler, 0),
handlers: make([]consoleHandlerRegistration, 0),
}
// Add messages past the limit
@ -704,7 +707,7 @@ func TestConsoleWatcherHandler_Good(t *testing.T) {
messages: make([]ConsoleMessage, 0),
filters: make([]ConsoleFilter, 0),
limit: 1000,
handlers: make([]ConsoleHandler, 0),
handlers: make([]consoleHandlerRegistration, 0),
}
var received ConsoleMessage
@ -729,7 +732,7 @@ func TestConsoleWatcherFilteredMessages_Good(t *testing.T) {
},
filters: []ConsoleFilter{{Type: "error"}},
limit: 1000,
handlers: make([]ConsoleHandler, 0),
handlers: make([]consoleHandlerRegistration, 0),
}
filtered := cw.FilteredMessages()