gui/pkg/preload/preload.go
Snider fa4168e380 feat(gui): InjectPreload — storage polyfills + Electron shim + app preloads
New pkg/preload package:
- preload.go — InjectPreload(webview, origin) entry point; builds
  three-step preload: storage polyfills, Electron shim (origin-
  filtered), app preloads from .core/view.yaml manifest.preloads.
- assets/storage_polyfills.js — localStorage/sessionStorage/
  IndexedDB bridges.
- assets/electron_shim.js — minimal ipcRenderer.send/invoke
  mapping to core.QUERY/ACTION.
- Adds a minimal window.core.ml.generate shim — gates the
  AI-native browser path (RFC §11a).

pkg/window/wails.go wires into Wails OnPageLoad via reflection when
the runtime exposes the hook, with a clean fallback for the
stubbed/test runtime. Legacy display-preload code detected and
skipped when the new package is in play.

Good/Bad/Ugly tests in pkg/preload/preload_test.go. go vet + go
test clean.

Closes tasks.lthn.sh/view.php?id=16

Co-authored-by: Codex <noreply@openai.com>
Co-Authored-By: Virgil <virgil@lethean.io>
2026-04-24 06:17:34 +01:00

436 lines
10 KiB
Go

package preload
import (
"embed"
"errors"
"io"
"net"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
core "dappco.re/go/core"
"gopkg.in/yaml.v3"
)
// errViewManifestNotFound is the sentinel returned by manifest discovery when
// no view manifest applies to a page. buildScript treats it as "no manifest"
// rather than as a failure.
var errViewManifestNotFound = errors.New("view manifest not found")

// maxViewManifestBytes caps how much of a view manifest file is read (1 MiB).
// Larger manifests are rejected outright instead of being truncated.
const maxViewManifestBytes = 1024 * 1024
// Webview is the page host that InjectPreload writes into. Only ExecJS is
// required: InjectPreload passes it the whole assembled preload script as one
// string.
// NOTE(review): presumably ExecJS evaluates the script in the current page
// context — confirm against the concrete implementations.
type Webview interface {
ExecJS(string)
}
// ManifestPreload is one entry of a view manifest's `preloads` list.
type ManifestPreload struct {
// Path names a script file relative to the manifest's base directory. It must
// not be absolute, and it must not resolve (lexically or through symlinks)
// outside that directory.
Path string `yaml:"path"`
// Inline is script source used verbatim after trimming whitespace. When it is
// non-blank it takes precedence over Path.
Inline string `yaml:"inline"`
// Enabled disables the entry only when explicitly false. An omitted value
// (nil) means the entry is enabled.
Enabled *bool `yaml:"enabled,omitempty"`
}
// viewManifest is the YAML shape decoded from .core/view.yaml. Preloads may be
// listed at the top level, under a nested `manifest:` key, or both.
// collectManifestPreloads concatenates them, top-level entries first. Other
// keys in the file are ignored when decoding.
type viewManifest struct {
Preloads []ManifestPreload `yaml:"preloads"`
Manifest struct {
Preloads []ManifestPreload `yaml:"preloads"`
} `yaml:"manifest"`
}
// loadedManifest is a parsed view manifest together with where it was found.
type loadedManifest struct {
// Path is the manifest file that was read.
Path string
// BaseDir is the directory that preload paths are resolved against, as
// computed by manifestBaseDir. It is normally the parent of the .core
// directory holding the manifest.
BaseDir string
// Preloads holds the merged top-level and nested preload entries.
Preloads []ManifestPreload
}
// assetFS embeds the JavaScript shims bundled with this package.
//go:embed assets/*.js
var assetFS embed.FS
// The shim sources are read once, during package-level variable
// initialisation. mustReadAsset panics if either file is missing from the
// embedded set, so a packaging mistake fails at program start.
var (
storagePolyfillsAsset = mustReadAsset("assets/storage_polyfills.js")
electronShimAsset = mustReadAsset("assets/electron_shim.js")
)
// InjectPreload builds the preload script for the page at origin and runs it
// in webview via ExecJS.
//
// origin is the page URL. It drives trust decisions and manifest lookup, and
// the full URL (not just its origin component) is used.
//
// It returns an error in two cases:
//   - webview is nil, including an interface holding a typed nil.
//   - the script cannot be built, e.g. a malformed manifest or an unreadable
//     preload file.
//
// A script that is blank after trimming is not executed.
func InjectPreload(webview Webview, origin string) error {
	if isNilWebview(webview) {
		return errors.New("preload target is required")
	}
	script, err := buildScript(origin)
	switch {
	case err != nil:
		return err
	case strings.TrimSpace(script) != "":
		webview.ExecJS(script)
	}
	return nil
}
// buildScript assembles the complete preload script for pageURL.
//
// The script is built, in order, from:
//   - the storage polyfills;
//   - the window.core.ml shim;
//   - the Electron shim, only for privileged pages;
//   - any app preloads declared by the page's view manifest.
//
// A page is privileged when trustedOrigin accepts its URL or when a view
// manifest was found for it. A missing manifest is not an error. Any other
// manifest or preload failure is returned, and no script is produced.
//
// Parts are joined with "\n;\n" rather than a bare newline. Each part comes
// from a different source, and a part that ends without a semicolon, followed
// by one that begins with "(" or "[", would otherwise be fused by JavaScript's
// automatic semicolon insertion into a single call or index expression. The
// standalone ";" is an empty statement and is harmless.
func buildScript(pageURL string) (string, error) {
	loaded, err := loadManifestForOrigin(pageURL)
	if err != nil {
		if !errors.Is(err, errViewManifestNotFound) {
			return "", err
		}
		loaded = nil
	}
	allowPrivileged := trustedOrigin(pageURL) || loaded != nil

	parts := []string{
		renderStoragePolyfills(pageURL, allowPrivileged),
		renderCoreMLShim(),
	}
	if allowPrivileged {
		parts = append(parts, renderElectronShim(pageURL))
	}
	appPreloads, err := renderAppPreloads(loaded)
	if err != nil {
		return "", err
	}
	// Blank parts (including an empty appPreloads) are dropped by filterEmpty.
	parts = append(parts, appPreloads)
	return strings.Join(filterEmpty(parts), "\n;\n"), nil
}
// mustReadAsset returns the embedded asset called name as a string, panicking
// if it cannot be read. In this file it is only used to initialise
// package-level variables, where a missing embedded file is a build defect
// rather than a runtime condition.
func mustReadAsset(name string) string {
	data, readErr := assetFS.ReadFile(name)
	if readErr == nil {
		return string(data)
	}
	panic(readErr)
}
// renderAppPreloads returns the manifest's enabled app preload scripts, in
// declaration order, combined into one script. A nil manifest, or one without
// preloads, yields "".
//
// Each entry is handled as follows:
//   - Enabled == false skips the entry.
//   - A non-blank Inline script is used verbatim (trimmed) and takes
//     precedence over Path.
//   - Otherwise a non-blank Path is read relative to the manifest's base
//     directory and must not escape it. Files whose contents are blank are
//     skipped.
//   - Entries with neither Inline nor Path are ignored.
//
// The first unreadable or rejected path aborts rendering with its error.
//
// Scripts are joined with "\n;\n" so that a script ending without a
// semicolon cannot be fused with a following script that begins with "(" or
// "[" by JavaScript's automatic semicolon insertion.
func renderAppPreloads(loaded *loadedManifest) (string, error) {
	if loaded == nil || len(loaded.Preloads) == 0 {
		return "", nil
	}
	scripts := make([]string, 0, len(loaded.Preloads))
	for _, preload := range loaded.Preloads {
		if preload.Enabled != nil && !*preload.Enabled {
			continue
		}
		if inline := strings.TrimSpace(preload.Inline); inline != "" {
			scripts = append(scripts, inline)
			continue
		}
		path := strings.TrimSpace(preload.Path)
		if path == "" {
			continue
		}
		body, err := readManifestPreload(loaded.BaseDir, path)
		if err != nil {
			return "", err
		}
		script := string(body)
		if strings.TrimSpace(script) == "" {
			continue
		}
		scripts = append(scripts, script)
	}
	return strings.Join(scripts, "\n;\n"), nil
}
// loadManifestForOrigin finds and parses the view manifest that applies to
// pageURL.
//
// It returns:
//   - errViewManifestNotFound (from discovery) when no manifest exists;
//   - an error when the file is larger than maxViewManifestBytes;
//   - any I/O or YAML decoding error encountered.
//
// Preloads from the top level and from the nested `manifest:` section are
// merged.
func loadManifestForOrigin(pageURL string) (*loadedManifest, error) {
	manifestPath, err := discoverManifestPath(pageURL)
	if err != nil {
		return nil, err
	}

	f, err := os.Open(manifestPath)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	// Read one byte past the cap so an oversized file can be told apart from
	// one that is exactly maxViewManifestBytes long.
	limited := io.LimitReader(f, maxViewManifestBytes+1)
	raw, err := io.ReadAll(limited)
	switch {
	case err != nil:
		return nil, err
	case len(raw) > maxViewManifestBytes:
		return nil, errors.New("view manifest exceeds 1048576 bytes")
	}

	var parsed viewManifest
	if err := yaml.Unmarshal(raw, &parsed); err != nil {
		return nil, err
	}
	result := &loadedManifest{
		Path:     manifestPath,
		BaseDir:  manifestBaseDir(manifestPath),
		Preloads: collectManifestPreloads(parsed),
	}
	return result, nil
}
// collectManifestPreloads flattens the top-level and nested `manifest:`
// preload lists into one slice, top-level entries first. The result is never
// nil.
func collectManifestPreloads(manifest viewManifest) []ManifestPreload {
	sources := [][]ManifestPreload{manifest.Preloads, manifest.Manifest.Preloads}
	total := 0
	for _, group := range sources {
		total += len(group)
	}
	combined := make([]ManifestPreload, 0, total)
	for _, group := range sources {
		combined = append(combined, group...)
	}
	return combined
}
// discoverManifestPath returns the path of the .core/view.yaml manifest that
// applies to pageURL, or errViewManifestNotFound when none exists.
//
// Lookup rules:
//   - file URLs and scheme-less paths: if the path is a directory, look in
//     <dir>/.core/view.yaml. If it is a file, look in the .core directory
//     beside it, then one directory up.
//   - any other URL with a host: look in
//     $DIR_HOME/.core/apps/<host>/.core/view.yaml. DIR_HOME is taken from the
//     process environment or, failing that, from core.Env.
//
// The first candidate that exists and is not a directory wins. A URL that
// fails to parse is returned as an error.
func discoverManifestPath(pageURL string) (string, error) {
	trimmed := strings.TrimSpace(pageURL)
	if trimmed == "" {
		return "", errViewManifestNotFound
	}
	parsed, err := url.Parse(trimmed)
	if err != nil {
		return "", err
	}

	var candidates []string
	switch parsed.Scheme {
	case "", "file":
		path := parsed.Path
		if path == "" {
			path = trimmed
		}
		path = filepath.FromSlash(path)
		if info, err := os.Stat(path); err == nil {
			if info.IsDir() {
				candidates = append(candidates, filepath.Join(path, ".core", "view.yaml"))
			} else {
				dir := filepath.Dir(path)
				candidates = append(candidates,
					filepath.Join(dir, ".core", "view.yaml"),
					filepath.Join(filepath.Dir(dir), ".core", "view.yaml"),
				)
			}
		}
	default:
		host := strings.TrimSpace(parsed.Host)
		// The host becomes a path segment under the per-app directory. "." and
		// ".." would resolve outside it (e.g. http://../ maps to
		// $DIR_HOME/.core/.core/view.yaml), so they never name an installed app.
		if host == "" || host == "." || host == ".." {
			break
		}
		home := strings.TrimSpace(os.Getenv("DIR_HOME"))
		if home == "" {
			home = strings.TrimSpace(core.Env("DIR_HOME"))
		}
		if home != "" {
			candidates = append(candidates, filepath.Join(home, ".core", "apps", host, ".core", "view.yaml"))
		}
	}

	for _, candidate := range candidates {
		// A directory named view.yaml would pass a bare Stat check and then
		// fail when read, aborting the whole preload, so skip it instead.
		if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
			return candidate, nil
		}
	}
	return "", errViewManifestNotFound
}
// manifestBaseDir returns the directory that a manifest's preload paths are
// resolved against. If the manifest lives inside a ".core" directory, that is
// the parent of ".core"; otherwise it is the manifest's own directory.
func manifestBaseDir(manifestPath string) string {
	dir := filepath.Dir(manifestPath)
	if filepath.Base(dir) != ".core" {
		return dir
	}
	return filepath.Dir(dir)
}
// readManifestPreload returns the contents of the preload script at
// preloadPath. The path must be relative and must stay inside baseDir once
// resolved (including symlinks). Validation errors and read errors are
// returned unchanged.
func readManifestPreload(baseDir, preloadPath string) ([]byte, error) {
	target, err := safeManifestRelativePath(baseDir, preloadPath)
	if err == nil {
		return os.ReadFile(target)
	}
	return nil, err
}
// safeManifestRelativePath resolves relativePath against baseDir and returns
// the absolute, symlink-resolved path of the target.
//
// It rejects:
//   - blank paths;
//   - absolute paths;
//   - any path that leaves baseDir, either lexically (via "..") or after
//     symlinks on both sides are resolved.
//
// Both baseDir and the target must exist.
func safeManifestRelativePath(baseDir, relativePath string) (string, error) {
	cleaned := strings.TrimSpace(relativePath)
	switch {
	case cleaned == "":
		return "", errors.New("preload path is empty")
	case filepath.IsAbs(cleaned):
		return "", errors.New("preload path must be relative")
	}

	// checkWithin fails when target is not root itself or a descendant of it.
	// A filepath.Rel failure is passed through.
	checkWithin := func(root, target string) error {
		rel, relErr := filepath.Rel(root, target)
		if relErr != nil {
			return relErr
		}
		if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			return errors.New("preload path escapes manifest directory")
		}
		return nil
	}

	absBase, err := filepath.Abs(baseDir)
	if err != nil {
		return "", err
	}
	realBase, err := filepath.EvalSymlinks(absBase)
	if err != nil {
		return "", err
	}
	absTarget, err := filepath.Abs(filepath.Join(absBase, cleaned))
	if err != nil {
		return "", err
	}
	// Lexical check first, so "../x" is rejected before touching the target.
	if err := checkWithin(absBase, absTarget); err != nil {
		return "", err
	}
	if _, err := os.Lstat(absTarget); err != nil {
		return "", err
	}
	realTarget, err := filepath.EvalSymlinks(absTarget)
	if err != nil {
		return "", err
	}
	// Second check catches symlinks inside baseDir that point outside it.
	if err := checkWithin(realBase, realTarget); err != nil {
		return "", err
	}
	return realTarget, nil
}
// trustedOrigin reports whether pageURL belongs to the application itself.
//
// A URL is trusted when it is:
//   - any core://, wails:// or app:// URL; or
//   - an http(s) URL whose host is localhost, 127.0.0.1 or ::1, with or
//     without a port.
//
// Blank or unparsable URLs are untrusted.
func trustedOrigin(pageURL string) bool {
	candidate := strings.TrimSpace(pageURL)
	if candidate == "" {
		return false
	}
	u, err := url.Parse(candidate)
	if err != nil {
		return false
	}
	scheme := strings.ToLower(u.Scheme)
	if scheme == "core" || scheme == "wails" || scheme == "app" {
		return true
	}
	if scheme != "http" && scheme != "https" {
		return false
	}
	hostPort := strings.TrimSpace(u.Host)
	if hostPort == "" {
		return false
	}
	hostname := hostPort
	if h, _, splitErr := net.SplitHostPort(hostPort); splitErr == nil {
		hostname = h
	}
	// Bracketed IPv6 without a port (e.g. "[::1]") is not split above.
	hostname = strings.Trim(strings.ToLower(hostname), "[]")
	return hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1"
}
// storageOriginForPageURL derives the key used to partition storage for
// pageURL:
//   - http(s): scheme://host[:port];
//   - core: core://host, or "core://" when there is no host;
//   - file: file://<path>;
//   - any other scheme: scheme://host<path>, with trailing slashes removed.
//
// It returns "" for blank, unparsable or scheme-less URLs, and for URLs that
// lack the host (or, for file, the path) their scheme needs.
func storageOriginForPageURL(pageURL string) string {
	raw := strings.TrimSpace(pageURL)
	if raw == "" {
		return ""
	}
	u, err := url.Parse(raw)
	if err != nil || strings.TrimSpace(u.Scheme) == "" {
		return ""
	}
	switch strings.ToLower(u.Scheme) {
	case "core":
		// An absent host still yields the bare "core://" origin.
		return "core://" + u.Host
	case "file":
		if u.Path == "" {
			return ""
		}
		return "file://" + u.Path
	case "http", "https":
		if u.Host == "" {
			return ""
		}
		return u.Scheme + "://" + u.Host
	default:
		if u.Host == "" {
			return ""
		}
		return strings.TrimRight(u.Scheme+"://"+u.Host+u.Path, "/")
	}
}
// validatedLocalMLAPIURL returns the base URL that the window.core.ml shim
// posts to.
//
// raw is accepted only when it is an http(s) URL whose host is localhost,
// 127.0.0.1 or ::1 (any port). Anything else falls back to
// http://localhost:8090: a blank or unparsable value, another scheme, or a
// non-local host.
//
// The accepted URL is reduced to scheme, host and path, and trailing slashes
// are trimmed. The shim appends "/v1/chat/completions" directly, so a query
// or fragment would swallow that path, and credentials embedded in the URL
// make fetch() throw.
func validatedLocalMLAPIURL(raw string) string {
	const fallback = "http://localhost:8090"

	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return fallback
	}
	parsed, err := url.Parse(trimmed)
	if err != nil {
		return fallback
	}
	switch strings.ToLower(parsed.Scheme) {
	case "http", "https":
	default:
		return fallback
	}
	host := strings.TrimSpace(parsed.Host)
	if host == "" {
		return fallback
	}
	name := host
	if h, _, splitErr := net.SplitHostPort(host); splitErr == nil {
		name = h
	}
	// Bracketed IPv6 without a port (e.g. "[::1]") is not split above.
	name = strings.Trim(strings.ToLower(name), "[]")
	switch name {
	case "localhost", "127.0.0.1", "::1":
	default:
		return fallback
	}

	base := url.URL{
		Scheme:  parsed.Scheme,
		Host:    parsed.Host,
		Path:    parsed.Path,
		RawPath: parsed.RawPath,
	}
	return strings.TrimRight(base.String(), "/")
}
// renderCoreMLShim returns a self-invoking JavaScript snippet. The snippet
// installs globalThis.core.ml, with a generate method, unless the page already
// defines core.ml.
//
// generate(input) POSTs JSON to <apiURL>/v1/chat/completions with stream forced
// to false:
//   - A string input is wrapped as a single user message; any other input is
//     spread into the request body.
//   - A response whose `ok` flag is false throws an Error carrying the status.
//   - Otherwise it returns choices[0].message.content, else .content, else the
//     raw response text. The raw text is also returned when the body is not
//     valid JSON.
//
// apiURL comes from the CORE_ML_API_URL setting (read via core.Env) after
// validatedLocalMLAPIURL restricts it to a localhost / 127.0.0.1 / ::1 http(s)
// endpoint, defaulting to http://localhost:8090.
//
// NOTE(review): the JS `|| "http://localhost:8090"` fallback looks redundant,
// since validatedLocalMLAPIURL never returns an empty string. This assumes
// core.JSONMarshalString emits a quoted JSON string literal — confirm.
func renderCoreMLShim() string {
// The JavaScript below is returned verbatim except for the spliced-in apiURL
// literal; edit it as JS, not Go.
return `(function() {
const apiURL = ` + core.JSONMarshalString(validatedLocalMLAPIURL(core.Env("CORE_ML_API_URL"))) + ` || "http://localhost:8090";
globalThis.core = globalThis.core || {};
globalThis.core.ml = globalThis.core.ml || {
async generate(input) {
const payload = typeof input === "string"
? { messages: [{ role: "user", content: input }], stream: false }
: { ...input, stream: false };
const response = await fetch(apiURL + "/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(payload)
});
if (!response.ok) {
throw new Error("Core ML request failed: " + response.status + " " + response.statusText);
}
const body = await response.text();
try {
const parsed = JSON.parse(body);
return parsed?.choices?.[0]?.message?.content ?? parsed?.content ?? body;
} catch (_) {
return body;
}
}
};
})();`
}
// filterEmpty returns, in their original order, the elements of parts that
// contain something other than whitespace. Kept elements are not trimmed, and
// the result is never nil.
func filterEmpty(parts []string) []string {
	kept := make([]string, 0, len(parts))
	for i := range parts {
		if strings.TrimSpace(parts[i]) == "" {
			continue
		}
		kept = append(kept, parts[i])
	}
	return kept
}
// isNilWebview reports whether webview is unusable because it is nil. That
// covers a nil interface, and also an interface holding a nil pointer, map,
// slice, channel or func (a "typed nil" that does not compare equal to nil).
func isNilWebview(webview Webview) bool {
	if webview == nil {
		return true
	}
	switch v := reflect.ValueOf(webview); v.Kind() {
	case reflect.Pointer, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
		return v.IsNil()
	}
	return false
}