Harden GUI security boundaries
This commit is contained in:
parent
65ccf50c2b
commit
723116acb7
6 changed files with 219 additions and 21 deletions
|
|
@ -754,7 +754,7 @@ func (s *Service) handleWSMessage(msg WSMessage) core.Result {
|
|||
}
|
||||
return c.QUERY(webview.QueryTitle{Window: w})
|
||||
default:
|
||||
return core.Result{}
|
||||
return core.Result{Value: coreerr.E("display.handleWSMessage", "unknown websocket action: "+msg.Action, nil), OK: false}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ type WSEventManager struct {
|
|||
// clientState tracks a client's subscriptions.
|
||||
type clientState struct {
|
||||
subscriptions map[string]*Subscription
|
||||
writeMu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
|
|
@ -124,10 +125,13 @@ func trustedWebSocketOrigin(r *http.Request) bool {
|
|||
if !trustedWebSocketHost(r.Host) {
|
||||
return false
|
||||
}
|
||||
if !trustedWSRequestOrigin(r.RemoteAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin == "" || strings.EqualFold(origin, "null") {
|
||||
return trustedWSRequestOrigin(r.RemoteAddr)
|
||||
return true
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(origin)
|
||||
|
|
@ -221,10 +225,10 @@ func (em *WSEventManager) clientSubscribed(state *clientState, eventType EventTy
|
|||
// sendEvent sends an event to a specific client.
|
||||
func (em *WSEventManager) sendEvent(conn *websocket.Conn, event Event) {
|
||||
em.mu.RLock()
|
||||
_, exists := em.clients[conn]
|
||||
state, exists := em.clients[conn]
|
||||
em.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
if !exists || state == nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -234,8 +238,11 @@ func (em *WSEventManager) sendEvent(conn *websocket.Conn, event Event) {
|
|||
}
|
||||
data, _ := marshalResult.Value.([]byte)
|
||||
|
||||
state.writeMu.Lock()
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
err := conn.WriteMessage(websocket.TextMessage, data)
|
||||
state.writeMu.Unlock()
|
||||
if err != nil {
|
||||
em.removeClient(conn)
|
||||
}
|
||||
}
|
||||
|
|
@ -258,6 +265,8 @@ func (em *WSEventManager) HandleWebSocket(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
em.mu.Unlock()
|
||||
|
||||
conn.SetReadLimit(64 * 1024)
|
||||
|
||||
// Handle incoming messages
|
||||
go em.handleMessages(conn)
|
||||
}
|
||||
|
|
@ -279,9 +288,11 @@ func (em *WSEventManager) handleMessages(conn *websocket.Conn) {
|
|||
}
|
||||
|
||||
if unmarshalResult := core.JSONUnmarshal(message, &msg); !unmarshalResult.OK {
|
||||
continue
|
||||
em.closeWithPolicyViolation(conn, "invalid websocket message")
|
||||
return
|
||||
}
|
||||
|
||||
handled := true
|
||||
switch msg.Action {
|
||||
case "subscribe":
|
||||
em.subscribe(conn, msg.ID, msg.EventTypes)
|
||||
|
|
@ -289,10 +300,28 @@ func (em *WSEventManager) handleMessages(conn *websocket.Conn) {
|
|||
em.unsubscribe(conn, msg.ID)
|
||||
case "list":
|
||||
em.listSubscriptions(conn)
|
||||
default:
|
||||
handled = false
|
||||
}
|
||||
if !handled {
|
||||
em.closeWithPolicyViolation(conn, "unknown websocket action")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (em *WSEventManager) closeWithPolicyViolation(conn *websocket.Conn, reason string) {
|
||||
em.mu.RLock()
|
||||
state, exists := em.clients[conn]
|
||||
em.mu.RUnlock()
|
||||
if !exists || state == nil {
|
||||
return
|
||||
}
|
||||
state.writeMu.Lock()
|
||||
defer state.writeMu.Unlock()
|
||||
_ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, reason), time.Now().Add(2*time.Second))
|
||||
}
|
||||
|
||||
// subscribe adds a subscription for a client.
|
||||
func (em *WSEventManager) subscribe(conn *websocket.Conn, id string, eventTypes []EventType) {
|
||||
em.mu.RLock()
|
||||
|
|
@ -326,7 +355,7 @@ func (em *WSEventManager) subscribe(conn *websocket.Conn, id string, eventTypes
|
|||
}
|
||||
if marshalResult := core.JSONMarshal(response); marshalResult.OK {
|
||||
responseData, _ := marshalResult.Value.([]byte)
|
||||
conn.WriteMessage(websocket.TextMessage, responseData)
|
||||
em.writeClientMessage(state, conn, responseData)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -351,7 +380,7 @@ func (em *WSEventManager) unsubscribe(conn *websocket.Conn, id string) {
|
|||
}
|
||||
if marshalResult := core.JSONMarshal(response); marshalResult.OK {
|
||||
responseData, _ := marshalResult.Value.([]byte)
|
||||
conn.WriteMessage(websocket.TextMessage, responseData)
|
||||
em.writeClientMessage(state, conn, responseData)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -378,10 +407,17 @@ func (em *WSEventManager) listSubscriptions(conn *websocket.Conn) {
|
|||
}
|
||||
if marshalResult := core.JSONMarshal(response); marshalResult.OK {
|
||||
responseData, _ := marshalResult.Value.([]byte)
|
||||
conn.WriteMessage(websocket.TextMessage, responseData)
|
||||
em.writeClientMessage(state, conn, responseData)
|
||||
}
|
||||
}
|
||||
|
||||
func (em *WSEventManager) writeClientMessage(state *clientState, conn *websocket.Conn, data []byte) {
|
||||
state.writeMu.Lock()
|
||||
defer state.writeMu.Unlock()
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
_ = conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
// removeClient removes a client and its subscriptions.
|
||||
func (em *WSEventManager) removeClient(conn *websocket.Conn) {
|
||||
em.mu.Lock()
|
||||
|
|
|
|||
|
|
@ -153,6 +153,41 @@ func TestWSEventManager_HandleWebSocket_RejectsRemoteOrigin(t *testing.T) {
|
|||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
}
|
||||
|
||||
func TestWSEventManager_HandleWebSocket_RejectsLoopbackSpoofedOrigin(t *testing.T) {
|
||||
em := NewWSEventManager()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/events", nil)
|
||||
req.RemoteAddr = "203.0.113.10:12345"
|
||||
req.Header.Set("Origin", "file://malicious")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
em.HandleWebSocket(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
}
|
||||
|
||||
func TestWSEventManager_HandleWebSocket_ClosesOnMalformedMessage(t *testing.T) {
|
||||
em := NewWSEventManager()
|
||||
conn, cleanup := dialWSEventManager(t, em)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"action":`)))
|
||||
|
||||
_, _, err := conn.ReadMessage()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWSEventManager_HandleWebSocket_ClosesOnUnknownAction(t *testing.T) {
|
||||
em := NewWSEventManager()
|
||||
conn, cleanup := dialWSEventManager(t, em)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"action":"bogus"}`)))
|
||||
|
||||
_, _, err := conn.ReadMessage()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWSEventManager_Emit_Ugly(t *testing.T) {
|
||||
em := &WSEventManager{
|
||||
clients: map[*websocket.Conn]*clientState{},
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package display
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
|
|
@ -39,12 +41,47 @@ func (s *Service) BuildPreloadScript(pageURL string) (string, error) {
|
|||
s.injectCoreMLShim(),
|
||||
s.buildHLCRFComponents(pageURL),
|
||||
}
|
||||
if appPreloads, err := s.injectAppPreloads(pageURL); err == nil && strings.TrimSpace(appPreloads) != "" {
|
||||
if appPreloads, err := s.injectAppPreloads(pageURL); err != nil {
|
||||
if !strings.Contains(err.Error(), "view manifest not found") {
|
||||
return "", err
|
||||
}
|
||||
} else if strings.TrimSpace(appPreloads) != "" {
|
||||
parts = append(parts, appPreloads)
|
||||
}
|
||||
return strings.Join(parts, "\n"), nil
|
||||
}
|
||||
|
||||
func validatedLocalMLAPIURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "http://localhost:8090"
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return "http://localhost:8090"
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
default:
|
||||
return "http://localhost:8090"
|
||||
}
|
||||
host := strings.TrimSpace(parsed.Host)
|
||||
if host == "" {
|
||||
return "http://localhost:8090"
|
||||
}
|
||||
name := host
|
||||
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||
name = parsedHost
|
||||
}
|
||||
name = strings.Trim(strings.ToLower(name), "[]")
|
||||
switch name {
|
||||
case "localhost", "127.0.0.1", "::1":
|
||||
return strings.TrimRight(parsed.String(), "/")
|
||||
default:
|
||||
return "http://localhost:8090"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) injectStoragePolyfills(pageOrigin string, bootstrap map[string]map[string]string) string {
|
||||
return `(function() {
|
||||
const __corePageURL = ` + core.JSONMarshalString(pageOrigin) + `;
|
||||
|
|
@ -569,7 +606,7 @@ func (s *Service) injectBackgroundServiceShims() string {
|
|||
|
||||
func (s *Service) injectCoreMLShim() string {
|
||||
return `(function() {
|
||||
const __coreMLApiURL = ` + core.JSONMarshalString(strings.TrimRight(core.Env("CORE_ML_API_URL"), "/")) + ` || "http://localhost:8090";
|
||||
const __coreMLApiURL = ` + core.JSONMarshalString(validatedLocalMLAPIURL(core.Env("CORE_ML_API_URL"))) + ` || "http://localhost:8090";
|
||||
globalThis.core = globalThis.core || {};
|
||||
globalThis.core.ml = globalThis.core.ml || {
|
||||
async generate(input) {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import (
|
|||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
|
@ -40,6 +41,8 @@ type Installer struct {
|
|||
InstallDir string
|
||||
}
|
||||
|
||||
const maxManifestBytes = 1 << 20
|
||||
|
||||
func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manifest, error) {
|
||||
client := i.HTTPClient
|
||||
if client == nil {
|
||||
|
|
@ -57,10 +60,13 @@ func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manif
|
|||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
return Manifest{}, fmt.Errorf("manifest fetch failed: %s", resp.Status)
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxManifestBytes+1))
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
if len(body) > maxManifestBytes {
|
||||
return Manifest{}, fmt.Errorf("manifest fetch failed: manifest exceeds %d bytes", maxManifestBytes)
|
||||
}
|
||||
var manifest Manifest
|
||||
if err := yaml.Unmarshal(body, &manifest); err != nil {
|
||||
return Manifest{}, err
|
||||
|
|
@ -72,8 +78,11 @@ func (i Installer) FetchManifest(ctx context.Context, manifestURL string) (Manif
|
|||
}
|
||||
|
||||
func VerifyManifest(manifest Manifest) error {
|
||||
if strings.ToLower(strings.TrimSpace(manifest.Signature.Algorithm)) != "ed25519" {
|
||||
return errors.New("manifest signature algorithm must be ed25519")
|
||||
}
|
||||
if manifest.Signature.Value == "" || manifest.Signature.PublicKey == "" {
|
||||
return nil
|
||||
return errors.New("manifest signature is required")
|
||||
}
|
||||
payload := manifest.Name + "\n" + manifest.Version + "\n" + manifest.Repository + "\n" + manifest.Ref
|
||||
signature, err := base64.StdEncoding.DecodeString(manifest.Signature.Value)
|
||||
|
|
@ -84,6 +93,12 @@ func VerifyManifest(manifest Manifest) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(signature) != ed25519.SignatureSize {
|
||||
return errors.New("manifest signature has invalid size")
|
||||
}
|
||||
if len(publicKey) != ed25519.PublicKeySize {
|
||||
return errors.New("manifest public key has invalid size")
|
||||
}
|
||||
if !ed25519.Verify(ed25519.PublicKey(publicKey), []byte(payload), signature) {
|
||||
return errors.New("manifest signature verification failed")
|
||||
}
|
||||
|
|
@ -91,16 +106,34 @@ func VerifyManifest(manifest Manifest) error {
|
|||
}
|
||||
|
||||
func (i Installer) Install(ctx context.Context, manifest Manifest) (string, error) {
|
||||
if strings.TrimSpace(i.InstallDir) == "" {
|
||||
return "", errors.New("install dir is required")
|
||||
}
|
||||
if err := VerifyManifest(manifest); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(i.InstallDir) == "" {
|
||||
return "", errors.New("install dir is required")
|
||||
if err := validateManifestName(manifest.Name); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := os.MkdirAll(i.InstallDir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
targetDir := filepath.Join(i.InstallDir, safeName(manifest.Name))
|
||||
rootAbs, err := filepath.Abs(i.InstallDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
targetAbs, err := filepath.Abs(targetDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rel, err := filepath.Rel(rootAbs, targetAbs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
|
||||
return "", errors.New("install path escapes install dir")
|
||||
}
|
||||
_ = os.RemoveAll(targetDir)
|
||||
args := []string{"clone", "--depth", "1"}
|
||||
if manifest.Ref != "" {
|
||||
|
|
@ -118,6 +151,20 @@ func (i Installer) Install(ctx context.Context, manifest Manifest) (string, erro
|
|||
return targetDir, nil
|
||||
}
|
||||
|
||||
func validateManifestName(value string) error {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return errors.New("manifest name is required")
|
||||
}
|
||||
if strings.ContainsAny(trimmed, `/\`) {
|
||||
return errors.New("manifest name must not contain path separators")
|
||||
}
|
||||
if strings.Contains(trimmed, "..") {
|
||||
return errors.New("manifest name must not contain path traversal segments")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DigestManifest(manifest Manifest) string {
|
||||
hash := sha256.Sum256([]byte(manifest.Name + ":" + manifest.Version + ":" + manifest.Repository + ":" + manifest.Ref))
|
||||
return hex.EncodeToString(hash[:])
|
||||
|
|
@ -125,9 +172,29 @@ func DigestManifest(manifest Manifest) string {
|
|||
|
||||
func safeName(value string) string {
|
||||
value = strings.TrimSpace(strings.ToLower(value))
|
||||
value = strings.ReplaceAll(value, " ", "-")
|
||||
if value == "" {
|
||||
return "module"
|
||||
}
|
||||
return value
|
||||
var builder strings.Builder
|
||||
lastDash := false
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case unicode.IsLetter(r), unicode.IsDigit(r):
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
default:
|
||||
if !lastDash {
|
||||
builder.WriteRune('-')
|
||||
lastDash = true
|
||||
}
|
||||
}
|
||||
}
|
||||
cleaned := strings.Trim(builder.String(), "-._")
|
||||
if cleaned == "" {
|
||||
return "module"
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,6 +112,17 @@ func TestMarketplace_VerifyManifest_Ugly(t *testing.T) {
|
|||
require.Error(t, VerifyManifest(manifest))
|
||||
}
|
||||
|
||||
func TestMarketplace_VerifyManifest_RequiresSignature(t *testing.T) {
|
||||
manifest := Manifest{
|
||||
Name: "core-ui",
|
||||
Version: "1.2.3",
|
||||
Repository: "https://example.com/core-ui.git",
|
||||
Ref: "main",
|
||||
}
|
||||
|
||||
require.Error(t, VerifyManifest(manifest))
|
||||
}
|
||||
|
||||
func TestMarketplace_Install_Good(t *testing.T) {
|
||||
scriptDir := t.TempDir()
|
||||
logFile := filepath.Join(scriptDir, "git.log")
|
||||
|
|
@ -125,12 +136,12 @@ func TestMarketplace_Install_Good(t *testing.T) {
|
|||
InstallDir: targetRoot,
|
||||
}
|
||||
|
||||
targetDir, err := installer.Install(context.Background(), Manifest{
|
||||
targetDir, err := installer.Install(context.Background(), signedManifest(t, Manifest{
|
||||
Name: "Core UI",
|
||||
Version: "1.2.3",
|
||||
Repository: "https://example.com/core-ui.git",
|
||||
Ref: "main",
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(targetRoot, "core-ui"), targetDir)
|
||||
_, err = os.Stat(targetDir)
|
||||
|
|
@ -150,6 +161,18 @@ func TestMarketplace_Install_Bad(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "install dir is required")
|
||||
}
|
||||
|
||||
func TestMarketplace_Install_RejectsTraversalName(t *testing.T) {
|
||||
installer := Installer{InstallDir: t.TempDir()}
|
||||
_, err := installer.Install(context.Background(), signedManifest(t, Manifest{
|
||||
Name: "../../escape",
|
||||
Version: "1.2.3",
|
||||
Repository: "https://example.com/core-ui.git",
|
||||
Ref: "main",
|
||||
}))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "path separators")
|
||||
}
|
||||
|
||||
func TestMarketplace_Install_Ugly(t *testing.T) {
|
||||
scriptDir := t.TempDir()
|
||||
scriptPath := filepath.Join(scriptDir, "git")
|
||||
|
|
@ -160,10 +183,10 @@ func TestMarketplace_Install_Ugly(t *testing.T) {
|
|||
InstallDir: t.TempDir(),
|
||||
}
|
||||
|
||||
_, err := installer.Install(context.Background(), Manifest{
|
||||
_, err := installer.Install(context.Background(), signedManifest(t, Manifest{
|
||||
Name: "core-ui",
|
||||
Repository: "https://example.com/core-ui.git",
|
||||
})
|
||||
}))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "git clone failed")
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue