From fa4168e380d235c17f40d21a5c8a9e414c07ef58 Mon Sep 17 00:00:00 2001 From: Snider Date: Fri, 24 Apr 2026 06:17:33 +0100 Subject: [PATCH] =?UTF-8?q?feat(gui):=20InjectPreload=20=E2=80=94=20storag?= =?UTF-8?q?e=20polyfills=20+=20Electron=20shim=20+=20app=20preloads?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New pkg/preload package: - preload.go — InjectPreload(webview, origin) entry point; builds three-step preload: storage polyfills, Electron shim (origin- filtered), app preloads from .core/view.yaml manifest.preloads. - assets/storage_polyfills.js — localStorage/sessionStorage/ IndexedDB bridges. - assets/electron_shim.js — minimal ipcRenderer.send/invoke mapping to core.QUERY/ACTION. - Adds a minimal window.core.ml.generate shim — gates the AI-native browser path (RFC §11a). pkg/window/wails.go wires into Wails OnPageLoad via reflection when the runtime exposes the hook, with a clean fallback for the stubbed/test runtime. Legacy display-preload code detected and skipped when the new package is in play. Good/Bad/Ugly tests in pkg/preload/preload_test.go. go vet + go test clean. Closes tasks.lthn.sh/view.php?id=16 Co-authored-by: Codex Co-Authored-By: Virgil --- pkg/preload/assets/electron_shim.js | 101 ++++++ pkg/preload/assets/storage_polyfills.js | 282 +++++++++++++++ pkg/preload/electron_shim.go | 20 ++ pkg/preload/preload.go | 436 ++++++++++++++++++++++++ pkg/preload/preload_test.go | 64 ++++ pkg/preload/storage_polyfills.go | 22 ++ pkg/window/wails.go | 117 +++++++ 7 files changed, 1042 insertions(+) create mode 100644 pkg/preload/assets/electron_shim.js create mode 100644 pkg/preload/assets/storage_polyfills.js create mode 100644 pkg/preload/electron_shim.go create mode 100644 pkg/preload/preload.go create mode 100644 pkg/preload/preload_test.go create mode 100644 pkg/preload/storage_polyfills.go diff --git a/pkg/preload/assets/electron_shim.js b/pkg/preload/assets/electron_shim.js new file mode 100644 index 00000000..803ef349 --- /dev/null +++ b/pkg/preload/assets/electron_shim.js @@ -0,0 +1,101 @@ +(function () { + if (globalThis.__corePreloadElectronInstalled) { + return; + } + globalThis.__corePreloadElectronInstalled = true; + + const meta = __CORE_PRELOAD_META__; + if (!meta.allow) { + return; + } + + const bridge = globalThis.__corePreloadBridge; + if (!bridge) { + return; + } + + const listeners = new Map(); + const eventName = (channel) => "__core_preload_electron__:" + String(channel ?? ""); + + const ipcRenderer = { + send(channel, ...args) { + const normalized = String(channel ?? ""); + globalThis.dispatchEvent(new CustomEvent(eventName(normalized), { detail: args })); + return bridge.action(normalized, { channel: normalized, args }).then(() => undefined); + }, + invoke(channel, ...args) { + const normalized = String(channel ?? ""); + return bridge.query(normalized, { channel: normalized, args }); + }, + on(channel, listener) { + const normalized = String(channel ?? ""); + const handler = (event) => listener(event, ...(event.detail || [])); + listeners.set(listener, handler); + globalThis.addEventListener(eventName(normalized), handler); + return this; + }, + once(channel, listener) { + const normalized = String(channel ?? ""); + const onceListener = (event, ...args) => { + ipcRenderer.removeListener(normalized, listener); + listener(event, ...args); + }; + return ipcRenderer.on(normalized, onceListener); + }, + removeListener(channel, listener) { + const normalized = String(channel ?? ""); + const handler = listeners.get(listener); + if (handler) { + globalThis.removeEventListener(eventName(normalized), handler); + listeners.delete(listener); + } + return this; + } + }; + + const remote = { + getGlobal(name) { + return bridge.query("electron.remote.getGlobal", { name: String(name ?? "") }); + }, + app: { + getPath(name) { + return bridge.query("electron.app.getPath", { name: String(name ?? "") }); + } + } + }; + + const shell = { + openExternal(url) { + return bridge.action("browser.openURL", { url: String(url ?? "") }).then(() => undefined); + }, + openPath(path) { + return bridge.action("browser.openFile", { path: String(path ?? "") }).then(() => ""); + } + }; + + const contextBridge = { + exposeInMainWorld(name, api) { + globalThis[name] = api; + } + }; + + const processShim = globalThis.process || { + env: {}, + platform: "wails", + type: "renderer", + versions: {} + }; + processShim.versions = processShim.versions || {}; + processShim.versions.electron = processShim.versions.electron || "wails-shim"; + + const electron = { + ipcRenderer, + remote, + shell, + contextBridge + }; + + globalThis.process = processShim; + globalThis.electron = electron; + globalThis.require = globalThis.require || ((name) => (name === "electron" ? electron : undefined)); +})(); diff --git a/pkg/preload/assets/storage_polyfills.js b/pkg/preload/assets/storage_polyfills.js new file mode 100644 index 00000000..ed36305b --- /dev/null +++ b/pkg/preload/assets/storage_polyfills.js @@ -0,0 +1,282 @@ +(function () { + if (globalThis.__corePreloadStorageInstalled) { + return; + } + globalThis.__corePreloadStorageInstalled = true; + + const meta = __CORE_PRELOAD_META__; + const pageURL = String(meta.pageURL || ""); + const storageOrigin = String(meta.storageOrigin || pageURL || ""); + const storeGroup = String(meta.storeGroup || "gui.preload.storage"); + const canPersist = !!meta.canPersist; + + const asPromise = (value) => ( + value && typeof value.then === "function" ? value : Promise.resolve(value) + ); + + const runCoreCall = (target, methodNames, name, payload) => { + if (!target || typeof target !== "object") { + return undefined; + } + for (const methodName of methodNames) { + const method = target[methodName]; + if (typeof method !== "function") { + continue; + } + try { + const direct = method.call(target, name, payload); + if (direct && typeof direct.Run === "function") { + try { + return direct.Run(payload); + } catch (_) { + return direct.Run(); + } + } + return direct; + } catch (_) { + try { + const deferred = method.call(target, name); + if (deferred && typeof deferred.Run === "function") { + try { + return deferred.Run(payload); + } catch (_) { + return deferred.Run(); + } + } + return deferred; + } catch (_) {} + } + } + return undefined; + }; + + const bridge = globalThis.__corePreloadBridge || (globalThis.__corePreloadBridge = { + action(name, payload) { + const candidates = [globalThis.c, globalThis.Core, globalThis.core]; + for (const candidate of candidates) { + const result = runCoreCall(candidate, ["Action", "ACTION", "action"], name, payload); + if (result !== undefined) { + return asPromise(result); + } + } + if (typeof globalThis.__CORE_GUI_INVOKE__ === "function") { + return asPromise(globalThis.__CORE_GUI_INVOKE__(name, payload, { mode: "action" })); + } + return Promise.resolve(undefined); + }, + query(name, payload) { + const candidates = [globalThis.c, globalThis.Core, globalThis.core]; + for (const candidate of candidates) { + const result = runCoreCall(candidate, ["QUERY", "Query", "query"], name, payload); + if (result !== undefined) { + return asPromise(result); + } + } + if (typeof globalThis.__CORE_GUI_INVOKE__ === "function") { + return asPromise(globalThis.__CORE_GUI_INVOKE__(name, payload, { mode: "query" })); + } + return Promise.resolve(undefined); + } + }); + + const storageScopes = globalThis.__corePreloadStorageScopes || (globalThis.__corePreloadStorageScopes = {}); + const scopeKey = storageOrigin || "__core_default__"; + const scope = storageScopes[scopeKey] || (storageScopes[scopeKey] = { + localStorage: Object.create(null), + sessionStorage: Object.create(null), + indexedDB: Object.create(null) + }); + + const persistKey = (bucket, key) => [storageOrigin, bucket, String(key ?? "")].join(":"); + + const persistSet = (bucket, key, value) => { + if (!canPersist) { + return; + } + bridge.action("store.set", { + group: storeGroup, + key: persistKey(bucket, key), + value: String(value ?? "") + }).catch(() => undefined); + }; + + const persistDelete = (bucket, key) => { + if (!canPersist) { + return; + } + bridge.action("store.delete", { + group: storeGroup, + key: persistKey(bucket, key) + }).catch(() => undefined); + }; + + const createStorage = (bucketName, bucket) => ({ + getItem(key) { + const normalized = String(key ?? ""); + return Object.prototype.hasOwnProperty.call(bucket, normalized) ? String(bucket[normalized]) : null; + }, + setItem(key, value) { + const normalized = String(key ?? ""); + bucket[normalized] = String(value ?? ""); + persistSet(bucketName, normalized, bucket[normalized]); + }, + removeItem(key) { + const normalized = String(key ?? ""); + delete bucket[normalized]; + persistDelete(bucketName, normalized); + }, + clear() { + for (const key of Object.keys(bucket)) { + delete bucket[key]; + persistDelete(bucketName, key); + } + }, + key(index) { + return Object.keys(bucket)[Number(index)] ?? null; + }, + get length() { + return Object.keys(bucket).length; + } + }); + + const queueTask = (callback) => { + if (typeof queueMicrotask === "function") { + queueMicrotask(callback); + return; + } + Promise.resolve().then(callback).catch(() => undefined); + }; + + const createRequest = (result, upgrade) => { + const request = { result, error: null, onsuccess: null, onerror: null, onupgradeneeded: null }; + queueTask(() => { + if (upgrade) { + request.onupgradeneeded?.({ target: request }); + } + request.onsuccess?.({ target: request }); + }); + return request; + }; + + const serializeRecord = (value) => { + if (typeof value === "string") { + return value; + } + try { + return JSON.stringify(value); + } catch (_) { + return String(value ?? ""); + } + }; + + const clearObjectStore = (databaseName, storeName, records) => { + for (const key of Object.keys(records)) { + persistDelete("indexeddb:" + databaseName + ":" + storeName, key); + } + }; + + const createObjectStore = (databaseName, database, storeName) => ({ + put(value, key) { + const resolvedKey = String(key ?? value?.id ?? Date.now()); + database.stores[storeName] = database.stores[storeName] || Object.create(null); + database.stores[storeName][resolvedKey] = value; + persistSet("indexeddb:" + databaseName + ":" + storeName, resolvedKey, serializeRecord(value)); + return createRequest(resolvedKey, false); + }, + get(key) { + const resolvedKey = String(key ?? ""); + return createRequest(database.stores?.[storeName]?.[resolvedKey], false); + }, + getAll() { + return createRequest(Object.values(database.stores?.[storeName] || {}), false); + }, + delete(key) { + const resolvedKey = String(key ?? ""); + if (database.stores?.[storeName]) { + delete database.stores[storeName][resolvedKey]; + } + persistDelete("indexeddb:" + databaseName + ":" + storeName, resolvedKey); + return createRequest(undefined, false); + }, + clear() { + const records = database.stores?.[storeName] || Object.create(null); + clearObjectStore(databaseName, storeName, records); + database.stores[storeName] = Object.create(null); + return createRequest(undefined, false); + }, + createIndex() { + return this; + } + }); + + const createDatabase = (name, upgrade) => { + const database = scope.indexedDB[name] || (scope.indexedDB[name] = { stores: Object.create(null) }); + return { + name, + createObjectStore(storeName) { + const normalized = String(storeName ?? "default"); + database.stores[normalized] = database.stores[normalized] || Object.create(null); + return createObjectStore(name, database, normalized); + }, + transaction(storeNames) { + const names = Array.isArray(storeNames) ? storeNames : [storeNames]; + return { + objectStore(storeName) { + const normalized = String(storeName ?? names[0] ?? "default"); + database.stores[normalized] = database.stores[normalized] || Object.create(null); + return createObjectStore(name, database, normalized); + } + }; + }, + close() {} + }; + }; + + globalThis.core = globalThis.core || {}; + globalThis.core.storage = globalThis.core.storage || {}; + globalThis.core.storage.local = createStorage("localStorage", scope.localStorage); + globalThis.core.storage.session = createStorage("sessionStorage", scope.sessionStorage); + + try { + Object.defineProperty(globalThis, "localStorage", { + configurable: true, + enumerable: true, + get() { + return globalThis.core.storage.local; + } + }); + } catch (_) {} + + try { + Object.defineProperty(globalThis, "sessionStorage", { + configurable: true, + enumerable: true, + get() { + return globalThis.core.storage.session; + } + }); + } catch (_) {} + + try { + if (!globalThis.indexedDB) { + globalThis.indexedDB = { + open(name) { + const normalized = String(name ?? "default"); + const upgrade = !scope.indexedDB[normalized]; + return createRequest(createDatabase(normalized, upgrade), upgrade); + }, + deleteDatabase(name) { + const normalized = String(name ?? "default"); + const database = scope.indexedDB[normalized]; + if (database && database.stores) { + for (const [storeName, records] of Object.entries(database.stores)) { + clearObjectStore(normalized, storeName, records); + } + } + delete scope.indexedDB[normalized]; + return createRequest(undefined, false); + } + }; + } + } catch (_) {} +})(); diff --git a/pkg/preload/electron_shim.go b/pkg/preload/electron_shim.go new file mode 100644 index 00000000..e3df798d --- /dev/null +++ b/pkg/preload/electron_shim.go @@ -0,0 +1,20 @@ +package preload + +import ( + "strings" + + core "dappco.re/go/core" +) + +func renderElectronShim(pageURL string) string { + meta := map[string]any{ + "allow": true, + "pageURL": pageURL, + } + + return strings.ReplaceAll( + electronShimAsset, + "__CORE_PRELOAD_META__", + core.JSONMarshalString(meta), + ) +} diff --git a/pkg/preload/preload.go b/pkg/preload/preload.go new file mode 100644 index 00000000..add59f8b --- /dev/null +++ b/pkg/preload/preload.go @@ -0,0 +1,436 @@ +package preload + +import ( + "embed" + "errors" + "io" + "net" + "net/url" + "os" + "path/filepath" + "reflect" + "strings" + + core "dappco.re/go/core" + "gopkg.in/yaml.v3" +) + +const maxViewManifestBytes = 1 << 20 + +var errViewManifestNotFound = errors.New("view manifest not found") + +type Webview interface { + ExecJS(string) +} + +type ManifestPreload struct { + Path string `yaml:"path"` + Inline string `yaml:"inline"` + Enabled *bool `yaml:"enabled,omitempty"` +} + +type viewManifest struct { + Preloads []ManifestPreload `yaml:"preloads"` + Manifest struct { + Preloads []ManifestPreload `yaml:"preloads"` + } `yaml:"manifest"` +} + +type loadedManifest struct { + Path string + BaseDir string + Preloads []ManifestPreload +} + +//go:embed assets/*.js +var assetFS embed.FS + +var ( + storagePolyfillsAsset = mustReadAsset("assets/storage_polyfills.js") + electronShimAsset = mustReadAsset("assets/electron_shim.js") +) + +func InjectPreload(webview Webview, origin string) error { + if isNilWebview(webview) { + return errors.New("preload target is required") + } + + script, err := buildScript(origin) + if err != nil { + return err + } + if strings.TrimSpace(script) == "" { + return nil + } + + webview.ExecJS(script) + return nil +} + +func buildScript(pageURL string) (string, error) { + loaded, manifestErr := loadManifestForOrigin(pageURL) + switch { + case manifestErr == nil: + case errors.Is(manifestErr, errViewManifestNotFound): + loaded = nil + default: + return "", manifestErr + } + + allowPrivileged := trustedOrigin(pageURL) || loaded != nil + parts := []string{ + renderStoragePolyfills(pageURL, allowPrivileged), + renderCoreMLShim(), + } + if allowPrivileged { + parts = append(parts, renderElectronShim(pageURL)) + } + if appPreloads, err := renderAppPreloads(loaded); err != nil { + return "", err + } else if strings.TrimSpace(appPreloads) != "" { + parts = append(parts, appPreloads) + } + + return strings.Join(filterEmpty(parts), "\n"), nil +} + +func mustReadAsset(name string) string { + body, err := assetFS.ReadFile(name) + if err != nil { + panic(err) + } + return string(body) +} + +func renderAppPreloads(loaded *loadedManifest) (string, error) { + if loaded == nil || len(loaded.Preloads) == 0 { + return "", nil + } + + scripts := make([]string, 0, len(loaded.Preloads)) + for _, preload := range loaded.Preloads { + if preload.Enabled != nil && !*preload.Enabled { + continue + } + if inline := strings.TrimSpace(preload.Inline); inline != "" { + scripts = append(scripts, inline) + continue + } + if path := strings.TrimSpace(preload.Path); path != "" { + body, err := readManifestPreload(loaded.BaseDir, path) + if err != nil { + return "", err + } + scripts = append(scripts, string(body)) + } + } + + return strings.Join(scripts, "\n"), nil +} + +func loadManifestForOrigin(pageURL string) (*loadedManifest, error) { + path, err := discoverManifestPath(pageURL) + if err != nil { + return nil, err + } + + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + body, err := io.ReadAll(io.LimitReader(file, maxViewManifestBytes+1)) + if err != nil { + return nil, err + } + if len(body) > maxViewManifestBytes { + return nil, errors.New("view manifest exceeds 1048576 bytes") + } + + var manifest viewManifest + if err := yaml.Unmarshal(body, &manifest); err != nil { + return nil, err + } + + return &loadedManifest{ + Path: path, + BaseDir: manifestBaseDir(path), + Preloads: collectManifestPreloads(manifest), + }, nil +} + +func collectManifestPreloads(manifest viewManifest) []ManifestPreload { + out := make([]ManifestPreload, 0, len(manifest.Preloads)+len(manifest.Manifest.Preloads)) + out = append(out, manifest.Preloads...) + out = append(out, manifest.Manifest.Preloads...) + return out +} + +func discoverManifestPath(pageURL string) (string, error) { + trimmed := strings.TrimSpace(pageURL) + if trimmed == "" { + return "", errViewManifestNotFound + } + + parsed, err := url.Parse(trimmed) + if err != nil { + return "", err + } + + candidates := make([]string, 0, 4) + switch parsed.Scheme { + case "", "file": + path := parsed.Path + if path == "" { + path = trimmed + } + path = filepath.FromSlash(path) + if info, err := os.Stat(path); err == nil { + if info.IsDir() { + candidates = append(candidates, filepath.Join(path, ".core", "view.yaml")) + } else { + dir := filepath.Dir(path) + candidates = append(candidates, filepath.Join(dir, ".core", "view.yaml")) + candidates = append(candidates, filepath.Join(filepath.Dir(dir), ".core", "view.yaml")) + } + } + default: + if host := strings.TrimSpace(parsed.Host); host != "" { + home := strings.TrimSpace(os.Getenv("DIR_HOME")) + if home == "" { + home = strings.TrimSpace(core.Env("DIR_HOME")) + } + if home != "" { + candidates = append(candidates, filepath.Join(home, ".core", "apps", host, ".core", "view.yaml")) + } + } + } + + for _, candidate := range candidates { + if _, err := os.Stat(candidate); err == nil { + return candidate, nil + } + } + + return "", errViewManifestNotFound +} + +func manifestBaseDir(manifestPath string) string { + baseDir := filepath.Dir(manifestPath) + if filepath.Base(baseDir) == ".core" { + return filepath.Dir(baseDir) + } + return baseDir +} + +func readManifestPreload(baseDir, preloadPath string) ([]byte, error) { + resolvedPath, err := safeManifestRelativePath(baseDir, preloadPath) + if err != nil { + return nil, err + } + return os.ReadFile(resolvedPath) +} + +func safeManifestRelativePath(baseDir, relativePath string) (string, error) { + trimmed := strings.TrimSpace(relativePath) + if trimmed == "" { + return "", errors.New("preload path is empty") + } + if filepath.IsAbs(trimmed) { + return "", errors.New("preload path must be relative") + } + + baseAbs, err := filepath.Abs(baseDir) + if err != nil { + return "", err + } + baseResolved, err := filepath.EvalSymlinks(baseAbs) + if err != nil { + return "", err + } + + candidateAbs, err := filepath.Abs(filepath.Join(baseAbs, trimmed)) + if err != nil { + return "", err + } + if rel, err := filepath.Rel(baseAbs, candidateAbs); err != nil { + return "", err + } else if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", errors.New("preload path escapes manifest directory") + } + + if _, err := os.Lstat(candidateAbs); err != nil { + return "", err + } + candidateResolved, err := filepath.EvalSymlinks(candidateAbs) + if err != nil { + return "", err + } + if rel, err := filepath.Rel(baseResolved, candidateResolved); err != nil { + return "", err + } else if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", errors.New("preload path escapes manifest directory") + } + + return candidateResolved, nil +} + +func trustedOrigin(pageURL string) bool { + trimmed := strings.TrimSpace(pageURL) + if trimmed == "" { + return false + } + + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + + switch strings.ToLower(parsed.Scheme) { + case "core", "wails", "app": + return true + case "http", "https": + host := strings.TrimSpace(parsed.Host) + if host == "" { + return false + } + name := host + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + name = parsedHost + } + name = strings.Trim(strings.ToLower(name), "[]") + switch name { + case "localhost", "127.0.0.1", "::1": + return true + default: + return false + } + default: + return false + } +} + +func storageOriginForPageURL(pageURL string) string { + trimmed := strings.TrimSpace(pageURL) + if trimmed == "" { + return "" + } + + parsed, err := url.Parse(trimmed) + if err != nil || strings.TrimSpace(parsed.Scheme) == "" { + return "" + } + + switch strings.ToLower(parsed.Scheme) { + case "http", "https": + if parsed.Host == "" { + return "" + } + return parsed.Scheme + "://" + parsed.Host + case "core": + if parsed.Host == "" { + return "core://" + } + return "core://" + parsed.Host + case "file": + if parsed.Path == "" { + return "" + } + return "file://" + parsed.Path + default: + if parsed.Host == "" { + return "" + } + origin := parsed.Scheme + "://" + parsed.Host + if parsed.Path != "" { + origin += parsed.Path + } + return strings.TrimRight(origin, "/") + } +} + +func validatedLocalMLAPIURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "http://localhost:8090" + } + + parsed, err := url.Parse(trimmed) + if err != nil { + return "http://localhost:8090" + } + switch strings.ToLower(parsed.Scheme) { + case "http", "https": + default: + return "http://localhost:8090" + } + + host := strings.TrimSpace(parsed.Host) + if host == "" { + return "http://localhost:8090" + } + name := host + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + name = parsedHost + } + name = strings.Trim(strings.ToLower(name), "[]") + switch name { + case "localhost", "127.0.0.1", "::1": + return strings.TrimRight(parsed.String(), "/") + default: + return "http://localhost:8090" + } +} + +func renderCoreMLShim() string { + return `(function() { + const apiURL = ` + core.JSONMarshalString(validatedLocalMLAPIURL(core.Env("CORE_ML_API_URL"))) + ` || "http://localhost:8090"; + globalThis.core = globalThis.core || {}; + globalThis.core.ml = globalThis.core.ml || { + async generate(input) { + const payload = typeof input === "string" + ? { messages: [{ role: "user", content: input }], stream: false } + : { ...input, stream: false }; + const response = await fetch(apiURL + "/v1/chat/completions", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload) + }); + if (!response.ok) { + throw new Error("Core ML request failed: " + response.status + " " + response.statusText); + } + const body = await response.text(); + try { + const parsed = JSON.parse(body); + return parsed?.choices?.[0]?.message?.content ?? parsed?.content ?? body; + } catch (_) { + return body; + } + } + }; +})();` +} + +func filterEmpty(parts []string) []string { + out := make([]string, 0, len(parts)) + for _, part := range parts { + if strings.TrimSpace(part) != "" { + out = append(out, part) + } + } + return out +} + +func isNilWebview(webview Webview) bool { + if webview == nil { + return true + } + value := reflect.ValueOf(webview) + switch value.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + return value.IsNil() + default: + return false + } +} diff --git a/pkg/preload/preload_test.go b/pkg/preload/preload_test.go new file mode 100644 index 00000000..99c1ec0a --- /dev/null +++ b/pkg/preload/preload_test.go @@ -0,0 +1,64 @@ +package preload + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type captureWebview struct { + scripts []string +} + +func (c *captureWebview) ExecJS(script string) { + c.scripts = append(c.scripts, script) +} + +func TestInjectPreload_Good(t *testing.T) { + root := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(root, ".core"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(root, "index.html"), []byte(""), 0o644)) + require.NoError(t, os.WriteFile( + filepath.Join(root, ".core", "view.yaml"), + []byte("manifest:\n preloads:\n - path: preload.js\n"), + 0o644, + )) + require.NoError(t, os.WriteFile( + filepath.Join(root, "preload.js"), + []byte("globalThis.__manifestPreloadLoaded = true;"), + 0o644, + )) + + target := &captureWebview{} + err := InjectPreload(target, "file://"+filepath.ToSlash(filepath.Join(root, "index.html"))) + require.NoError(t, err) + require.Len(t, target.scripts, 1) + + script := target.scripts[0] + assert.Contains(t, script, "globalThis.core.storage.local") + assert.Contains(t, script, "globalThis.core.ml = globalThis.core.ml ||") + assert.Contains(t, script, "globalThis.electron = electron") + assert.Contains(t, script, "globalThis.__manifestPreloadLoaded = true;") +} + +func TestInjectPreload_Bad(t *testing.T) { + err := InjectPreload(nil, "http://localhost:3000") + require.Error(t, err) + assert.Contains(t, err.Error(), "preload target is required") +} + +func TestInjectPreload_Ugly(t *testing.T) { + target := &captureWebview{} + err := InjectPreload(target, "https://example.com/app") + require.NoError(t, err) + require.Len(t, target.scripts, 1) + + script := target.scripts[0] + assert.Contains(t, script, "globalThis.core.storage.local") + assert.Contains(t, script, "globalThis.core.ml = globalThis.core.ml ||") + assert.NotContains(t, script, "globalThis.electron = electron") + assert.NotContains(t, script, "ipcRenderer") +} diff --git a/pkg/preload/storage_polyfills.go b/pkg/preload/storage_polyfills.go new file mode 100644 index 00000000..5b545929 --- /dev/null +++ b/pkg/preload/storage_polyfills.go @@ -0,0 +1,22 @@ +package preload + +import ( + "strings" + + core "dappco.re/go/core" +) + +func renderStoragePolyfills(pageURL string, canPersist bool) string { + meta := map[string]any{ + "pageURL": pageURL, + "storageOrigin": storageOriginForPageURL(pageURL), + "storeGroup": "gui.preload.storage", + "canPersist": canPersist, + } + + return strings.ReplaceAll( + storagePolyfillsAsset, + "__CORE_PRELOAD_META__", + core.JSONMarshalString(meta), + ) +} diff --git a/pkg/window/wails.go b/pkg/window/wails.go index a0ca89ab..09b9e428 100644 --- a/pkg/window/wails.go +++ b/pkg/window/wails.go @@ -2,6 +2,10 @@ package window import ( + "reflect" + "strings" + + "forge.lthn.ai/core/gui/pkg/preload" "github.com/wailsapp/wails/v3/pkg/application" "github.com/wailsapp/wails/v3/pkg/events" ) @@ -38,10 +42,123 @@ func (wp *WailsPlatform) CreateWindow(options PlatformWindowOptions) PlatformWin EnableFileDrop: options.EnableFileDrop, BackgroundColour: application.NewRGBA(options.BackgroundColour[0], options.BackgroundColour[1], options.BackgroundColour[2], options.BackgroundColour[3]), } + var windowHandle *application.WebviewWindow + if wirePreloadOnPageLoad(&wOpts, options.URL, func(origin string, target preload.Webview) { + if target == nil { + target = windowHandle + } + if target == nil { + return + } + _ = preload.InjectPreload(target, origin) + if extra := postPageLoadWindowJS(options.JS); strings.TrimSpace(extra) != "" { + target.ExecJS(extra) + } + }) { + wOpts.JS = "" + } w := wp.app.Window.NewWithOptions(wOpts) + windowHandle = w return &wailsWindow{w: w, title: options.Title, opacity: 1.0} } +func wirePreloadOnPageLoad(options *application.WebviewWindowOptions, fallbackOrigin string, inject func(origin string, target preload.Webview)) bool { + if options == nil || inject == nil { + return false + } + + value := reflect.ValueOf(options) + if value.Kind() != reflect.Pointer || value.IsNil() { + return false + } + structValue := value.Elem() + if structValue.Kind() != reflect.Struct { + return false + } + + field := structValue.FieldByName("OnPageLoad") + if !field.IsValid() || !field.CanSet() || field.Kind() != reflect.Func { + return false + } + + fnType := field.Type() + field.Set(reflect.MakeFunc(fnType, func(args []reflect.Value) []reflect.Value { + inject(extractPageLoadOrigin(args, fallbackOrigin), extractPageLoadWebview(args)) + return zeroReturnValues(fnType) + })) + return true +} + +func extractPageLoadOrigin(args []reflect.Value, fallback string) string { + for _, arg := range args { + if !arg.IsValid() { + continue + } + if arg.Kind() == reflect.Pointer { + if arg.IsNil() { + continue + } + arg = arg.Elem() + } + switch arg.Kind() { + case reflect.String: + if value := strings.TrimSpace(arg.String()); value != "" { + return value + } + case reflect.Struct: + for _, name := range []string{"URL", "Url", "Origin", "Location"} { + field := arg.FieldByName(name) + if field.IsValid() && field.Kind() == reflect.String { + if value := strings.TrimSpace(field.String()); value != "" { + return value + } + } + } + } + } + return fallback +} + +func extractPageLoadWebview(args []reflect.Value) preload.Webview { + for _, arg := range args { + if !arg.IsValid() || !arg.CanInterface() { + continue + } + if target, ok := arg.Interface().(preload.Webview); ok { + return target + } + } + return nil +} + +func zeroReturnValues(fnType reflect.Type) []reflect.Value { + if fnType.NumOut() == 0 { + return nil + } + out := make([]reflect.Value, 0, fnType.NumOut()) + for i := 0; i < fnType.NumOut(); i++ { + out = append(out, reflect.Zero(fnType.Out(i))) + } + return out +} + +func postPageLoadWindowJS(raw string) string { + if looksLikeLegacyDisplayPreload(raw) { + return "" + } + return raw +} + +func looksLikeLegacyDisplayPreload(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return false + } + return strings.Contains(trimmed, "const __corePageURL =") && + strings.Contains(trimmed, "globalThis.core.ml") && + strings.Contains(trimmed, "Document.prototype, 'cookie'") +} + func (wp *WailsPlatform) GetWindows() []PlatformWindow { all := wp.app.Window.GetAll() out := make([]PlatformWindow, 0, len(all))