From ae0677a04618bf359e9ba9bd86f24fe9fc699a1c Mon Sep 17 00:00:00 2001 From: Virgil Date: Mon, 23 Mar 2026 14:32:11 +0000 Subject: [PATCH] fix(security): harden installer, marketplace, and sync path handling Co-Authored-By: Virgil --- cmd/forge/cmd_sync.go | 54 +++++++++++++++--- cmd/forge/cmd_sync_test.go | 53 ++++++++++++++++++ cmd/gitea/cmd_sync.go | 55 +++++++++++++++---- cmd/gitea/cmd_sync_test.go | 53 ++++++++++++++++++ forge/prs.go | 24 ++++++-- forge/prs_test.go | 49 +++++++++++++++++ jobrunner/handlers/dispatch.go | 69 +++++++++++++++++++---- jobrunner/handlers/dispatch_test.go | 85 +++++++++++++++++++++++++++++ marketplace/installer.go | 59 ++++++++++++++------ marketplace/installer_test.go | 58 +++++++++++++++++++- pkg/api/provider.go | 40 ++++++++++++-- pkg/api/provider_security_test.go | 26 +++++++++ pkg/api/provider_test.go | 12 ++++ plugin/installer.go | 58 +++++++++++++++----- plugin/installer_test.go | 46 ++++++++++++++++ 15 files changed, 669 insertions(+), 72 deletions(-) create mode 100644 cmd/forge/cmd_sync_test.go create mode 100644 cmd/gitea/cmd_sync_test.go create mode 100644 pkg/api/provider_security_test.go diff --git a/cmd/forge/cmd_sync.go b/cmd/forge/cmd_sync.go index 390d8c9..7a4176c 100644 --- a/cmd/forge/cmd_sync.go +++ b/cmd/forge/cmd_sync.go @@ -2,16 +2,18 @@ package forge import ( "fmt" + "net/url" "os" "os/exec" "path/filepath" "strings" - forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" - - "forge.lthn.ai/core/cli/pkg/cli" coreerr "dappco.re/go/core/log" + "dappco.re/go/core/scm/agentci" fg "dappco.re/go/core/scm/forge" + + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + "forge.lthn.ai/core/cli/pkg/cli" ) // Sync command flags. @@ -95,11 +97,14 @@ func buildSyncRepoList(client *fg.Client, args []string, basePath string) ([]syn if len(args) > 0 { for _, arg := range args { - name := arg - if parts := strings.SplitN(arg, "/", 2); len(parts) == 2 { - name = parts[1] + name, err := syncRepoNameFromArg(arg) + if err != nil { + return nil, coreerr.E("forge.buildSyncRepoList", "invalid repo argument", err) + } + _, localPath, err := agentci.ResolvePathWithinRoot(basePath, name) + if err != nil { + return nil, coreerr.E("forge.buildSyncRepoList", "resolve local path", err) } - localPath := filepath.Join(basePath, name) branch := syncDetectDefaultBranch(localPath) repos = append(repos, syncRepoEntry{ name: name, @@ -113,10 +118,17 @@ func buildSyncRepoList(client *fg.Client, args []string, basePath string) ([]syn return nil, err } for _, r := range orgRepos { - localPath := filepath.Join(basePath, r.Name) + name, err := agentci.ValidatePathElement(r.Name) + if err != nil { + return nil, coreerr.E("forge.buildSyncRepoList", "invalid repo name from org list", err) + } + _, localPath, err := agentci.ResolvePathWithinRoot(basePath, name) + if err != nil { + return nil, coreerr.E("forge.buildSyncRepoList", "resolve local path", err) + } branch := syncDetectDefaultBranch(localPath) repos = append(repos, syncRepoEntry{ - name: r.Name, + name: name, localPath: localPath, defaultBranch: branch, }) @@ -333,3 +345,27 @@ func syncCreateMainFromUpstream(client *fg.Client, org, repo string) error { return nil } + +func syncRepoNameFromArg(arg string) (string, error) { + decoded, err := url.PathUnescape(arg) + if err != nil { + return "", coreerr.E("forge.syncRepoNameFromArg", "decode repo argument", err) + } + + parts := strings.Split(decoded, "/") + switch len(parts) { + case 1: + return agentci.ValidatePathElement(parts[0]) + case 2: + if _, err := agentci.ValidatePathElement(parts[0]); err != nil { + return "", coreerr.E("forge.syncRepoNameFromArg", "invalid repo owner", err) + } + name, err := agentci.ValidatePathElement(parts[1]) + if err != nil { + return "", coreerr.E("forge.syncRepoNameFromArg", "invalid repo name", err) + } + return name, nil + default: + return "", coreerr.E("forge.syncRepoNameFromArg", "repo argument must be repo or owner/repo", nil) + } +} diff --git a/cmd/forge/cmd_sync_test.go b/cmd/forge/cmd_sync_test.go new file mode 100644 index 0000000..c75d74b --- /dev/null +++ b/cmd/forge/cmd_sync_test.go @@ -0,0 +1,53 @@ +package forge + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildSyncRepoList_Good(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + repos, err := buildSyncRepoList(nil, []string{"host-uk/core"}, basePath) + require.NoError(t, err) + require.Len(t, repos, 1) + assert.Equal(t, "core", repos[0].name) + assert.Equal(t, filepath.Join(basePath, "core"), repos[0].localPath) +} + +func TestBuildSyncRepoList_Bad_PathTraversal(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + _, err := buildSyncRepoList(nil, []string{"../escape"}, basePath) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid repo argument") +} + +func TestBuildSyncRepoList_Good_OwnerRepo(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + repos, err := buildSyncRepoList(nil, []string{"Host-UK/core"}, basePath) + require.NoError(t, err) + require.Len(t, repos, 1) + assert.Equal(t, "core", repos[0].name) + assert.Equal(t, filepath.Join(basePath, "core"), repos[0].localPath) +} + +func TestBuildSyncRepoList_Bad_PathTraversal_OwnerRepo(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + _, err := buildSyncRepoList(nil, []string{"host-uk/../escape"}, basePath) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid repo argument") +} + +func TestBuildSyncRepoList_Bad_PathTraversal_OwnerRepoEncoded(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + _, err := buildSyncRepoList(nil, []string{"host-uk%2F..%2Fescape"}, basePath) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid repo argument") +} diff --git a/cmd/gitea/cmd_sync.go b/cmd/gitea/cmd_sync.go index b9b4c8f..de14da3 100644 --- a/cmd/gitea/cmd_sync.go +++ b/cmd/gitea/cmd_sync.go @@ -2,16 +2,18 @@ package gitea import ( "fmt" + "net/url" "os" "os/exec" "path/filepath" "strings" - "code.gitea.io/sdk/gitea" - - "forge.lthn.ai/core/cli/pkg/cli" coreerr "dappco.re/go/core/log" + "dappco.re/go/core/scm/agentci" gt "dappco.re/go/core/scm/gitea" + + "code.gitea.io/sdk/gitea" + "forge.lthn.ai/core/cli/pkg/cli" ) // Sync command flags. @@ -96,12 +98,14 @@ func buildRepoList(client *gt.Client, args []string, basePath string) ([]repoEnt if len(args) > 0 { // Specific repos from args for _, arg := range args { - name := arg - // Strip owner/ prefix if given - if parts := strings.SplitN(arg, "/", 2); len(parts) == 2 { - name = parts[1] + name, err := repoNameFromArg(arg) + if err != nil { + return nil, coreerr.E("gitea.buildRepoList", "invalid repo argument", err) + } + _, localPath, err := agentci.ResolvePathWithinRoot(basePath, name) + if err != nil { + return nil, coreerr.E("gitea.buildRepoList", "resolve local path", err) } - localPath := filepath.Join(basePath, name) branch := detectDefaultBranch(localPath) repos = append(repos, repoEntry{ name: name, @@ -116,10 +120,17 @@ func buildRepoList(client *gt.Client, args []string, basePath string) ([]repoEnt return nil, err } for _, r := range orgRepos { - localPath := filepath.Join(basePath, r.Name) + name, err := agentci.ValidatePathElement(r.Name) + if err != nil { + return nil, coreerr.E("gitea.buildRepoList", "invalid repo name from org list", err) + } + _, localPath, err := agentci.ResolvePathWithinRoot(basePath, name) + if err != nil { + return nil, coreerr.E("gitea.buildRepoList", "resolve local path", err) + } branch := detectDefaultBranch(localPath) repos = append(repos, repoEntry{ - name: r.Name, + name: name, localPath: localPath, defaultBranch: branch, }) @@ -352,3 +363,27 @@ func createMainFromUpstream(client *gt.Client, org, repo string) error { } func strPtr(s string) *string { return &s } + +func repoNameFromArg(arg string) (string, error) { + decoded, err := url.PathUnescape(arg) + if err != nil { + return "", coreerr.E("gitea.repoNameFromArg", "decode repo argument", err) + } + + parts := strings.Split(decoded, "/") + switch len(parts) { + case 1: + return agentci.ValidatePathElement(parts[0]) + case 2: + if _, err := agentci.ValidatePathElement(parts[0]); err != nil { + return "", coreerr.E("gitea.repoNameFromArg", "invalid repo owner", err) + } + name, err := agentci.ValidatePathElement(parts[1]) + if err != nil { + return "", coreerr.E("gitea.repoNameFromArg", "invalid repo name", err) + } + return name, nil + default: + return "", coreerr.E("gitea.repoNameFromArg", "repo argument must be repo or owner/repo", nil) + } +} diff --git a/cmd/gitea/cmd_sync_test.go b/cmd/gitea/cmd_sync_test.go new file mode 100644 index 0000000..e21e712 --- /dev/null +++ b/cmd/gitea/cmd_sync_test.go @@ -0,0 +1,53 @@ +package gitea + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildRepoList_Good(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + repos, err := buildRepoList(nil, []string{"host-uk/core"}, basePath) + require.NoError(t, err) + require.Len(t, repos, 1) + assert.Equal(t, "core", repos[0].name) + assert.Equal(t, filepath.Join(basePath, "core"), repos[0].localPath) +} + +func TestBuildRepoList_Bad_PathTraversal(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + _, err := buildRepoList(nil, []string{"../escape"}, basePath) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid repo argument") +} + +func TestBuildRepoList_Good_OwnerRepo(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + repos, err := buildRepoList(nil, []string{"Host-UK/core"}, basePath) + require.NoError(t, err) + require.Len(t, repos, 1) + assert.Equal(t, "core", repos[0].name) + assert.Equal(t, filepath.Join(basePath, "core"), repos[0].localPath) +} + +func TestBuildRepoList_Bad_PathTraversal_OwnerRepo(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + _, err := buildRepoList(nil, []string{"host-uk/../escape"}, basePath) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid repo argument") +} + +func TestBuildRepoList_Bad_PathTraversal_OwnerRepoEncoded(t *testing.T) { + basePath := filepath.Join(t.TempDir(), "repos") + + _, err := buildRepoList(nil, []string{"host-uk%2F..%2Fescape"}, basePath) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid repo argument") +} diff --git a/forge/prs.go b/forge/prs.go index 070662f..d8d92f7 100644 --- a/forge/prs.go +++ b/forge/prs.go @@ -5,10 +5,13 @@ import ( "encoding/json" "fmt" "net/http" - - forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + "net/url" + "strconv" "dappco.re/go/core/log" + "dappco.re/go/core/scm/agentci" + + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" ) // MergePullRequest merges a pull request with the given method ("squash", "rebase", "merge"). @@ -38,14 +41,27 @@ func (c *Client) MergePullRequest(owner, repo string, index int64, method string // The Forgejo SDK v2.2.0 doesn't expose the draft field on EditPullRequestOption, // so we use a raw HTTP PATCH request. func (c *Client) SetPRDraft(owner, repo string, index int64, draft bool) error { + safeOwner, err := agentci.ValidatePathElement(owner) + if err != nil { + return log.E("forge.SetPRDraft", "invalid owner", err) + } + safeRepo, err := agentci.ValidatePathElement(repo) + if err != nil { + return log.E("forge.SetPRDraft", "invalid repo", err) + } + payload := map[string]bool{"draft": draft} body, err := json.Marshal(payload) if err != nil { return log.E("forge.SetPRDraft", "marshal payload", err) } - url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d", c.url, owner, repo, index) - req, err := http.NewRequest(http.MethodPatch, url, bytes.NewReader(body)) + path, err := url.JoinPath(c.url, "api", "v1", "repos", safeOwner, safeRepo, "pulls", strconv.FormatInt(index, 10)) + if err != nil { + return log.E("forge.SetPRDraft", "failed to build request path", err) + } + + req, err := http.NewRequest(http.MethodPatch, path, bytes.NewReader(body)) if err != nil { return log.E("forge.SetPRDraft", "create request", err) } diff --git a/forge/prs_test.go b/forge/prs_test.go index 14f30be..aabe584 100644 --- a/forge/prs_test.go +++ b/forge/prs_test.go @@ -1,6 +1,9 @@ package forge import ( + "encoding/json" + "net/http" + "net/http/httptest" "strings" "testing" @@ -98,3 +101,49 @@ func TestClient_DismissReview_Bad_ServerError(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "failed to dismiss review") } + +func TestClient_SetPRDraft_Good_Request(t *testing.T) { + var method, path string + var payload map[string]any + + mux := http.NewServeMux() + mux.HandleFunc("/api/v1/version", func(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, map[string]string{"version": "1.21.0"}) + }) + mux.HandleFunc("/api/v1/repos/test-org/org-repo/pulls/3", func(w http.ResponseWriter, r *http.Request) { + method = r.Method + path = r.URL.Path + require.NoError(t, json.NewDecoder(r.Body).Decode(&payload)) + jsonResponse(w, map[string]any{"number": 3}) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + client, err := New(srv.URL, "test-token") + require.NoError(t, err) + + err = client.SetPRDraft("test-org", "org-repo", 3, false) + assert.NoError(t, err) + assert.Equal(t, http.MethodPatch, method) + assert.Equal(t, "/api/v1/repos/test-org/org-repo/pulls/3", path) + assert.Equal(t, false, payload["draft"]) +} + +func TestClient_SetPRDraft_Bad_PathTraversalOwner(t *testing.T) { + client, srv := newTestClient(t) + defer srv.Close() + + err := client.SetPRDraft("../owner", "org-repo", 3, true) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid owner") +} + +func TestClient_SetPRDraft_Bad_PathTraversalRepo(t *testing.T) { + client, srv := newTestClient(t) + defer srv.Close() + + err := client.SetPRDraft("test-org", "..", 3, true) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid repo") +} diff --git a/jobrunner/handlers/dispatch.go b/jobrunner/handlers/dispatch.go index fbd83e2..961a9d9 100644 --- a/jobrunner/handlers/dispatch.go +++ b/jobrunner/handlers/dispatch.go @@ -5,7 +5,8 @@ import ( "context" "encoding/json" "fmt" - "path/filepath" + "path" + "strings" "time" coreerr "dappco.re/go/core/log" @@ -85,6 +86,10 @@ func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.Pipelin if !ok { return nil, coreerr.E("dispatch.Execute", "unknown agent: "+signal.Assignee, nil) } + queueDir, err := agentci.ValidateRemoteDir(agent.QueueDir) + if err != nil { + return nil, coreerr.E("dispatch.Execute", "invalid agent queue dir", err) + } // Sanitize inputs to prevent path traversal. safeOwner, err := agentci.SanitizePath(signal.RepoOwner) @@ -184,7 +189,10 @@ func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.Pipelin } // Transfer ticket JSON. - remoteTicketPath := filepath.Join(agent.QueueDir, ticketName) + remoteTicketPath, err := agentci.JoinRemotePath(queueDir, ticketName) + if err != nil { + return nil, coreerr.E("dispatch.Execute", "ticket path", err) + } if err := h.secureTransfer(ctx, agent, remoteTicketPath, ticketJSON, 0644); err != nil { h.failDispatch(signal, fmt.Sprintf("Ticket transfer failed: %v", err)) return &jobrunner.ActionResult{ @@ -202,10 +210,13 @@ func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.Pipelin // Transfer token via separate .env file with 0600 permissions. envContent := fmt.Sprintf("FORGE_TOKEN=%s\n", h.token) - remoteEnvPath := filepath.Join(agent.QueueDir, fmt.Sprintf(".env.%s", ticketID)) + remoteEnvPath, err := agentci.JoinRemotePath(queueDir, fmt.Sprintf(".env.%s", ticketID)) + if err != nil { + return nil, coreerr.E("dispatch.Execute", "env path", err) + } if err := h.secureTransfer(ctx, agent, remoteEnvPath, []byte(envContent), 0600); err != nil { // Clean up the ticket if env transfer fails. - _ = h.runRemote(ctx, agent, fmt.Sprintf("rm -f %s", agentci.EscapeShellArg(remoteTicketPath))) + _ = h.runRemote(ctx, agent, "rm", "-f", remoteTicketPath) h.failDispatch(signal, fmt.Sprintf("Token transfer failed: %v", err)) return &jobrunner.ActionResult{ Action: "dispatch", @@ -255,8 +266,8 @@ func (h *DispatchHandler) failDispatch(signal *jobrunner.PipelineSignal, reason // secureTransfer writes data to a remote path via SSH stdin, preventing command injection. func (h *DispatchHandler) secureTransfer(ctx context.Context, agent agentci.AgentConfig, remotePath string, data []byte, mode int) error { - safeRemotePath := agentci.EscapeShellArg(remotePath) - remoteCmd := fmt.Sprintf("cat > %s && chmod %o %s", safeRemotePath, mode, safeRemotePath) + safePath := agentci.EscapeShellArg(remotePath) + remoteCmd := fmt.Sprintf("cat > %s && chmod %o %s", safePath, mode, safePath) cmd := agentci.SecureSSHCommand(agent.Host, remoteCmd) cmd.Stdin = bytes.NewReader(data) @@ -269,21 +280,55 @@ func (h *DispatchHandler) secureTransfer(ctx context.Context, agent agentci.Agen } // runRemote executes a command on the agent via SSH. -func (h *DispatchHandler) runRemote(ctx context.Context, agent agentci.AgentConfig, cmdStr string) error { - cmd := agentci.SecureSSHCommand(agent.Host, cmdStr) +func (h *DispatchHandler) runRemote(ctx context.Context, agent agentci.AgentConfig, command string, args ...string) error { + remoteCmd := command + if len(args) > 0 { + escaped := make([]string, 0, 1+len(args)) + escaped = append(escaped, command) + for _, arg := range args { + escaped = append(escaped, agentci.EscapeShellArg(arg)) + } + remoteCmd = strings.Join(escaped, " ") + } + + cmd := agentci.SecureSSHCommand(agent.Host, remoteCmd) return cmd.Run() } // ticketExists checks if a ticket file already exists in queue, active, or done. func (h *DispatchHandler) ticketExists(ctx context.Context, agent agentci.AgentConfig, ticketName string) bool { - safeTicket, err := agentci.SanitizePath(ticketName) + queueDir, err := agentci.ValidateRemoteDir(agent.QueueDir) if err != nil { return false } - qDir := agent.QueueDir + safeTicket, err := agentci.ValidatePathElement(ticketName) + if err != nil { + return false + } + + queuePath, err := agentci.JoinRemotePath(queueDir, safeTicket) + if err != nil { + return false + } + parentDir := queueDir + if queueDir != "/" && queueDir != "~" { + parentDir = path.Dir(queueDir) + } + activePath, err := agentci.JoinRemotePath(parentDir, "active", safeTicket) + if err != nil { + return false + } + donePath, err := agentci.JoinRemotePath(parentDir, "done", safeTicket) + if err != nil { + return false + } + + queuePath = agentci.EscapeShellArg(queuePath) + activePath = agentci.EscapeShellArg(activePath) + donePath = agentci.EscapeShellArg(donePath) checkCmd := fmt.Sprintf( - "test -f %s/%s || test -f %s/../active/%s || test -f %s/../done/%s", - qDir, safeTicket, qDir, safeTicket, qDir, safeTicket, + "test -f %s || test -f %s || test -f %s", + queuePath, activePath, donePath, ) cmd := agentci.SecureSSHCommand(agent.Host, checkCmd) return cmd.Run() == nil diff --git a/jobrunner/handlers/dispatch_test.go b/jobrunner/handlers/dispatch_test.go index f981207..0f733b3 100644 --- a/jobrunner/handlers/dispatch_test.go +++ b/jobrunner/handlers/dispatch_test.go @@ -5,6 +5,9 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" + "path/filepath" + "strconv" "testing" "dappco.re/go/core/scm/agentci" @@ -13,6 +16,18 @@ import ( "github.com/stretchr/testify/require" ) +func writeFakeSSHCommand(t *testing.T, outputPath string) string { + t.Helper() + dir := t.TempDir() + script := filepath.Join(dir, "ssh") + scriptContent := "#!/bin/sh\n" + + "OUT=" + strconv.Quote(outputPath) + "\n" + + "printf '%s\n' \"$@\" >> \"$OUT\"\n" + + "cat >> \"${OUT}.stdin\"\n" + require.NoError(t, os.WriteFile(script, []byte(scriptContent), 0o755)) + return dir +} + // newTestSpinner creates a Spinner with the given agents for testing. func newTestSpinner(agents map[string]agentci.AgentConfig) *agentci.Spinner { return agentci.NewSpinner(agentci.ClothoConfig{Strategy: "direct"}, agents) @@ -127,6 +142,29 @@ func TestDispatch_Execute_Bad_UnknownAgent(t *testing.T) { assert.Contains(t, err.Error(), "unknown agent") } +func TestDispatch_Execute_Bad_InvalidQueueDir(t *testing.T) { + spinner := newTestSpinner(map[string]agentci.AgentConfig{ + "darbs-claude": { + Host: "localhost", + QueueDir: "/tmp/queue; touch /tmp/pwned", + Active: true, + }, + }) + h := NewDispatchHandler(nil, "", "", spinner) + + sig := &jobrunner.PipelineSignal{ + NeedsCoding: true, + Assignee: "darbs-claude", + RepoOwner: "host-uk", + RepoName: "core", + ChildNumber: 1, + } + + _, err := h.Execute(context.Background(), sig) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid agent queue dir") +} + func TestDispatch_TicketJSON_Good(t *testing.T) { ticket := DispatchTicket{ ID: "host-uk-core-5-1234567890", @@ -214,6 +252,53 @@ func TestDispatch_TicketJSON_Good_OmitsEmptyModelRunner(t *testing.T) { assert.False(t, hasRunner, "runner should be omitted when empty") } +func TestDispatch_runRemote_Good_EscapesPath(t *testing.T) { + outputPath := filepath.Join(t.TempDir(), "ssh-output.txt") + toolPath := writeFakeSSHCommand(t, outputPath) + t.Setenv("PATH", toolPath+":"+os.Getenv("PATH")) + + h := NewDispatchHandler(nil, "", "", newTestSpinner(nil)) + dangerousPath := "/tmp/queue with spaces; touch /tmp/pwned" + err := h.runRemote( + context.Background(), + agentci.AgentConfig{Host: "localhost"}, + "rm", + "-f", + dangerousPath, + ) + require.NoError(t, err) + + output, err := os.ReadFile(outputPath) + require.NoError(t, err) + assert.Contains(t, string(output), "rm '-f' '"+dangerousPath+"'\n") +} + +func TestDispatch_secureTransfer_Good_EscapesPath(t *testing.T) { + outputPath := filepath.Join(t.TempDir(), "ssh-output.txt") + toolPath := writeFakeSSHCommand(t, outputPath) + t.Setenv("PATH", toolPath+":"+os.Getenv("PATH")) + + h := NewDispatchHandler(nil, "", "", newTestSpinner(nil)) + dangerousPath := "/tmp/queue with spaces; touch /tmp/pwned" + err := h.secureTransfer( + context.Background(), + agentci.AgentConfig{Host: "localhost"}, + dangerousPath, + []byte("hello"), + 0644, + ) + require.NoError(t, err) + + output, err := os.ReadFile(outputPath) + require.NoError(t, err) + assert.Contains(t, string(output), "cat > '"+dangerousPath+"' && chmod 644 '"+dangerousPath+"'") + + inputPath := outputPath + ".stdin" + input, err := os.ReadFile(inputPath) + require.NoError(t, err) + assert.Equal(t, "hello", string(input)) +} + func TestDispatch_TicketJSON_Good_ModelRunnerVariants(t *testing.T) { tests := []struct { name string diff --git a/marketplace/installer.go b/marketplace/installer.go index 50a9686..e338ce4 100644 --- a/marketplace/installer.go +++ b/marketplace/installer.go @@ -9,10 +9,11 @@ import ( "strings" "time" - coreerr "dappco.re/go/core/log" "dappco.re/go/core/io" - "dappco.re/go/core/scm/manifest" "dappco.re/go/core/io/store" + coreerr "dappco.re/go/core/log" + "dappco.re/go/core/scm/agentci" + "dappco.re/go/core/scm/manifest" ) const storeGroup = "_modules" @@ -47,12 +48,16 @@ type InstalledModule struct { // Install clones a module repo, verifies its manifest signature, and registers it. func (i *Installer) Install(ctx context.Context, mod Module) error { - // Check if already installed - if _, err := i.store.Get(storeGroup, mod.Code); err == nil { - return coreerr.E("marketplace.Installer.Install", "module already installed: "+mod.Code, nil) + safeCode, dest, err := i.resolveModulePath(mod.Code) + if err != nil { + return coreerr.E("marketplace.Installer.Install", "invalid module code", err) + } + + // Check if already installed + if _, err := i.store.Get(storeGroup, safeCode); err == nil { + return coreerr.E("marketplace.Installer.Install", "module already installed: "+safeCode, nil) } - dest := filepath.Join(i.modulesDir, mod.Code) if err := i.medium.EnsureDir(i.modulesDir); err != nil { return coreerr.E("marketplace.Installer.Install", "mkdir", err) } @@ -80,7 +85,7 @@ func (i *Installer) Install(ctx context.Context, mod Module) error { entryPoint := filepath.Join(dest, "main.ts") installed := InstalledModule{ - Code: mod.Code, + Code: safeCode, Name: m.Name, Version: m.Version, Repo: mod.Repo, @@ -95,7 +100,7 @@ func (i *Installer) Install(ctx context.Context, mod Module) error { return coreerr.E("marketplace.Installer.Install", "marshal", err) } - if err := i.store.Set(storeGroup, mod.Code, string(data)); err != nil { + if err := i.store.Set(storeGroup, safeCode, string(data)); err != nil { return coreerr.E("marketplace.Installer.Install", "store", err) } @@ -105,21 +110,32 @@ func (i *Installer) Install(ctx context.Context, mod Module) error { // Remove uninstalls a module by deleting its files and store entry. func (i *Installer) Remove(code string) error { - if _, err := i.store.Get(storeGroup, code); err != nil { - return coreerr.E("marketplace.Installer.Remove", "module not installed: "+code, nil) + safeCode, dest, err := i.resolveModulePath(code) + if err != nil { + return coreerr.E("marketplace.Installer.Remove", "invalid module code", err) } - dest := filepath.Join(i.modulesDir, code) - _ = i.medium.DeleteAll(dest) + if _, err := i.store.Get(storeGroup, safeCode); err != nil { + return coreerr.E("marketplace.Installer.Remove", "module not installed: "+safeCode, nil) + } - return i.store.Delete(storeGroup, code) + if err := i.medium.DeleteAll(dest); err != nil { + return coreerr.E("marketplace.Installer.Remove", "delete module files", err) + } + + return i.store.Delete(storeGroup, safeCode) } // Update pulls latest changes and re-verifies the manifest. func (i *Installer) Update(ctx context.Context, code string) error { - raw, err := i.store.Get(storeGroup, code) + safeCode, dest, err := i.resolveModulePath(code) if err != nil { - return coreerr.E("marketplace.Installer.Update", "module not installed: "+code, nil) + return coreerr.E("marketplace.Installer.Update", "invalid module code", err) + } + + raw, err := i.store.Get(storeGroup, safeCode) + if err != nil { + return coreerr.E("marketplace.Installer.Update", "module not installed: "+safeCode, nil) } var installed InstalledModule @@ -127,8 +143,6 @@ func (i *Installer) Update(ctx context.Context, code string) error { return coreerr.E("marketplace.Installer.Update", "unmarshal", err) } - dest := filepath.Join(i.modulesDir, code) - cmd := exec.CommandContext(ctx, "git", "-C", dest, "pull", "--ff-only") if output, err := cmd.CombinedOutput(); err != nil { return coreerr.E("marketplace.Installer.Update", "pull: "+strings.TrimSpace(string(output)), err) @@ -145,6 +159,7 @@ func (i *Installer) Update(ctx context.Context, code string) error { } // Update stored metadata + installed.Code = safeCode installed.Name = m.Name installed.Version = m.Version installed.Permissions = m.Permissions @@ -154,7 +169,7 @@ func (i *Installer) Update(ctx context.Context, code string) error { return coreerr.E("marketplace.Installer.Update", "marshal", err) } - return i.store.Set(storeGroup, code, string(data)) + return i.store.Set(storeGroup, safeCode, string(data)) } // Installed returns all installed module metadata. @@ -195,3 +210,11 @@ func gitClone(ctx context.Context, repo, dest string) error { } return nil } + +func (i *Installer) resolveModulePath(code string) (string, string, error) { + safeCode, dest, err := agentci.ResolvePathWithinRoot(i.modulesDir, code) + if err != nil { + return "", "", coreerr.E("marketplace.Installer.resolveModulePath", "resolve module path", err) + } + return safeCode, dest, nil +} diff --git a/marketplace/installer_test.go b/marketplace/installer_test.go index 358e69a..ee992fa 100644 --- a/marketplace/installer_test.go +++ b/marketplace/installer_test.go @@ -10,8 +10,8 @@ import ( "testing" "dappco.re/go/core/io" - "dappco.re/go/core/scm/manifest" "dappco.re/go/core/io/store" + "dappco.re/go/core/scm/manifest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -163,6 +163,29 @@ func TestInstall_Bad_InvalidSignature(t *testing.T) { assert.True(t, os.IsNotExist(statErr), "directory should be cleaned up on failure") } +func TestInstall_Bad_PathTraversalCode(t *testing.T) { + repo := createTestRepo(t, "safe-mod", "1.0") + modulesDir := filepath.Join(t.TempDir(), "modules") + + st, err := store.New(":memory:") + require.NoError(t, err) + defer st.Close() + + inst := NewInstaller(io.Local, modulesDir, st) + err = inst.Install(context.Background(), Module{ + Code: "../escape", + Repo: repo, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid module code") + + _, err = st.Get("_modules", "escape") + assert.Error(t, err) + + _, err = os.Stat(filepath.Join(filepath.Dir(modulesDir), "escape")) + assert.True(t, os.IsNotExist(err)) +} + func TestRemove_Good(t *testing.T) { repo := createTestRepo(t, "rm-mod", "1.0") modulesDir := filepath.Join(t.TempDir(), "modules") @@ -197,6 +220,26 @@ func TestRemove_Bad_NotInstalled(t *testing.T) { assert.Contains(t, err.Error(), "not installed") } +func TestRemove_Bad_PathTraversalCode(t *testing.T) { + baseDir := t.TempDir() + modulesDir := filepath.Join(baseDir, "modules") + escapeDir := filepath.Join(baseDir, "escape") + require.NoError(t, os.MkdirAll(escapeDir, 0755)) + + st, err := store.New(":memory:") + require.NoError(t, err) + defer st.Close() + + inst := NewInstaller(io.Local, modulesDir, st) + err = inst.Remove("../escape") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid module code") + + info, statErr := os.Stat(escapeDir) + require.NoError(t, statErr) + assert.True(t, info.IsDir()) +} + func TestInstalled_Good(t *testing.T) { modulesDir := filepath.Join(t.TempDir(), "modules") @@ -262,3 +305,16 @@ func TestUpdate_Good(t *testing.T) { assert.Equal(t, "2.0", installed[0].Version) assert.Equal(t, "Updated Module", installed[0].Name) } + +func TestUpdate_Bad_PathTraversalCode(t *testing.T) { + modulesDir := filepath.Join(t.TempDir(), "modules") + + st, err := store.New(":memory:") + require.NoError(t, err) + defer st.Close() + + inst := NewInstaller(io.Local, modulesDir, st) + err = inst.Update(context.Background(), "../escape") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid module code") +} diff --git a/pkg/api/provider.go b/pkg/api/provider.go index 77475c9..80641e5 100644 --- a/pkg/api/provider.go +++ b/pkg/api/provider.go @@ -10,10 +10,12 @@ import ( "crypto/ed25519" "encoding/hex" "net/http" + "net/url" "dappco.re/go/core/api" "dappco.re/go/core/api/pkg/provider" "dappco.re/go/core/io" + "dappco.re/go/core/scm/agentci" "dappco.re/go/core/scm/manifest" "dappco.re/go/core/scm/marketplace" "dappco.re/go/core/scm/repos" @@ -228,7 +230,10 @@ func (p *ScmProvider) getMarketplaceItem(c *gin.Context) { return } - code := c.Param("code") + code, ok := marketplaceCodeParam(c) + if !ok { + return + } mod, ok := p.index.Find(code) if !ok { c.JSON(http.StatusNotFound, api.Fail("not_found", "provider not found in marketplace")) @@ -243,7 +248,10 @@ func (p *ScmProvider) installItem(c *gin.Context) { return } - code := c.Param("code") + code, ok := marketplaceCodeParam(c) + if !ok { + return + } mod, ok := p.index.Find(code) if !ok { c.JSON(http.StatusNotFound, api.Fail("not_found", "provider not found in marketplace")) @@ -269,7 +277,10 @@ func (p *ScmProvider) removeItem(c *gin.Context) { return } - code := c.Param("code") + code, ok := marketplaceCodeParam(c) + if !ok { + return + } if err := p.installer.Remove(code); err != nil { c.JSON(http.StatusInternalServerError, api.Fail("remove_failed", err.Error())) return @@ -393,7 +404,10 @@ func (p *ScmProvider) updateInstalled(c *gin.Context) { return } - code := c.Param("code") + code, ok := marketplaceCodeParam(c) + if !ok { + return + } if err := p.installer.Update(context.Background(), code); err != nil { c.JSON(http.StatusInternalServerError, api.Fail("update_failed", err.Error())) return @@ -448,3 +462,21 @@ func (p *ScmProvider) emitEvent(channel string, data any) { Data: data, }) } + +func marketplaceCodeParam(c *gin.Context) (string, bool) { + code, err := normaliseMarketplaceCode(c.Param("code")) + if err != nil { + c.JSON(http.StatusBadRequest, api.Fail("invalid_code", "invalid marketplace code")) + return "", false + } + return code, true +} + +func normaliseMarketplaceCode(raw string) (string, error) { + decoded, err := url.PathUnescape(raw) + if err != nil { + return "", err + } + + return agentci.ValidatePathElement(decoded) +} diff --git a/pkg/api/provider_security_test.go b/pkg/api/provider_security_test.go new file mode 100644 index 0000000..066293b --- /dev/null +++ b/pkg/api/provider_security_test.go @@ -0,0 +1,26 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNormaliseMarketplaceCode_Good(t *testing.T) { + code, err := normaliseMarketplaceCode("analytics") + require.NoError(t, err) + assert.Equal(t, "analytics", code) +} + +func TestNormaliseMarketplaceCode_Bad(t *testing.T) { + _, err := normaliseMarketplaceCode("analytics;rm") + assert.Error(t, err) +} + +func TestNormaliseMarketplaceCode_Bad_EncodedTraversal(t *testing.T) { + _, err := normaliseMarketplaceCode("analytics%2f..%2Fescape") + assert.Error(t, err) +} diff --git a/pkg/api/provider_test.go b/pkg/api/provider_test.go index 7e72509..0674f8a 100644 --- a/pkg/api/provider_test.go +++ b/pkg/api/provider_test.go @@ -164,6 +164,18 @@ func TestScmProvider_GetMarketplaceItem_Bad(t *testing.T) { assert.Equal(t, http.StatusNotFound, w.Code) } +func TestScmProvider_GetMarketplaceItem_Bad_PathTraversal(t *testing.T) { + idx := &marketplace.Index{Version: 1} + p := scmapi.NewProvider(idx, nil, nil, nil) + + r := setupRouter(p) + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/api/v1/scm/marketplace/%2e%2e", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + // -- Installed Endpoints ------------------------------------------------------ func TestScmProvider_ListInstalled_NilInstaller_Good(t *testing.T) { diff --git a/plugin/installer.go b/plugin/installer.go index d98c59c..0be3233 100644 --- a/plugin/installer.go +++ b/plugin/installer.go @@ -3,13 +3,15 @@ package plugin import ( "context" "fmt" + "net/url" "os/exec" "path/filepath" "strings" "time" - coreerr "dappco.re/go/core/log" "dappco.re/go/core/io" + coreerr "dappco.re/go/core/log" + "dappco.re/go/core/scm/agentci" ) // Installer handles plugin installation from GitHub. @@ -40,7 +42,10 @@ func (i *Installer) Install(ctx context.Context, source string) error { } // Clone the repository - pluginDir := filepath.Join(i.registry.basePath, repo) + _, pluginDir, err := i.resolvePluginPath(repo) + if err != nil { + return coreerr.E("plugin.Installer.Install", "invalid plugin path", err) + } if err := i.medium.EnsureDir(pluginDir); err != nil { return coreerr.E("plugin.Installer.Install", "failed to create plugin directory", err) } @@ -90,14 +95,15 @@ func (i *Installer) Install(ctx context.Context, source string) error { // Update updates a plugin to the latest version. func (i *Installer) Update(ctx context.Context, name string) error { - cfg, ok := i.registry.Get(name) - if !ok { - return coreerr.E("plugin.Installer.Update", "plugin not found: "+name, nil) + safeName, pluginDir, err := i.resolvePluginPath(name) + if err != nil { + return coreerr.E("plugin.Installer.Update", "invalid plugin name", err) } - // Parse the source to get org/repo - source := strings.TrimPrefix(cfg.Source, "github:") - pluginDir := filepath.Join(i.registry.basePath, name) + cfg, ok := i.registry.Get(safeName) + if !ok { + return coreerr.E("plugin.Installer.Update", "plugin not found: "+safeName, nil) + } // Pull latest changes cmd := exec.CommandContext(ctx, "git", "-C", pluginDir, "pull", "--ff-only") @@ -118,18 +124,21 @@ func (i *Installer) Update(ctx context.Context, name string) error { return coreerr.E("plugin.Installer.Update", "failed to save registry", err) } - _ = source // used for context return nil } // Remove uninstalls a plugin by removing its files and registry entry. func (i *Installer) Remove(name string) error { - if _, ok := i.registry.Get(name); !ok { - return coreerr.E("plugin.Installer.Remove", "plugin not found: "+name, nil) + safeName, pluginDir, err := i.resolvePluginPath(name) + if err != nil { + return coreerr.E("plugin.Installer.Remove", "invalid plugin name", err) + } + + if _, ok := i.registry.Get(safeName); !ok { + return coreerr.E("plugin.Installer.Remove", "plugin not found: "+safeName, nil) } // Delete plugin directory - pluginDir := filepath.Join(i.registry.basePath, name) if i.medium.Exists(pluginDir) { if err := i.medium.DeleteAll(pluginDir); err != nil { return coreerr.E("plugin.Installer.Remove", "failed to delete plugin files", err) @@ -137,7 +146,7 @@ func (i *Installer) Remove(name string) error { } // Remove from registry - if err := i.registry.Remove(name); err != nil { + if err := i.registry.Remove(safeName); err != nil { return coreerr.E("plugin.Installer.Remove", "failed to unregister plugin", err) } @@ -170,6 +179,10 @@ func (i *Installer) cloneRepo(ctx context.Context, org, repo, version, dest stri // - "org/repo" -> org="org", repo="repo", version="" // - "org/repo@v1.0" -> org="org", repo="repo", version="v1.0" func ParseSource(source string) (org, repo, version string, err error) { + source, err = url.PathUnescape(source) + if err != nil { + return "", "", "", coreerr.E("plugin.ParseSource", "invalid source path", err) + } if source == "" { return "", "", "", coreerr.E("plugin.ParseSource", "source is empty", nil) } @@ -191,5 +204,22 @@ func ParseSource(source string) (org, repo, version string, err error) { return "", "", "", coreerr.E("plugin.ParseSource", "source must be in format org/repo[@version]", nil) } - return parts[0], parts[1], version, nil + org, err = agentci.ValidatePathElement(parts[0]) + if err != nil { + return "", "", "", coreerr.E("plugin.ParseSource", "invalid org", err) + } + repo, err = agentci.ValidatePathElement(parts[1]) + if err != nil { + return "", "", "", coreerr.E("plugin.ParseSource", "invalid repo", err) + } + + return org, repo, version, nil +} + +func (i *Installer) resolvePluginPath(name string) (string, string, error) { + safeName, path, err := agentci.ResolvePathWithinRoot(i.registry.basePath, name) + if err != nil { + return "", "", coreerr.E("plugin.Installer.resolvePluginPath", "resolve plugin path", err) + } + return safeName, path, nil } diff --git a/plugin/installer_test.go b/plugin/installer_test.go index 4b57611..bd87b74 100644 --- a/plugin/installer_test.go +++ b/plugin/installer_test.go @@ -44,6 +44,17 @@ func TestInstall_Bad_AlreadyInstalled(t *testing.T) { assert.Contains(t, err.Error(), "already installed") } +func TestInstall_Bad_PathTraversalSource(t *testing.T) { + m := io.NewMockMedium() + reg := NewRegistry(m, "/plugins") + inst := NewInstaller(m, reg) + + err := inst.Install(context.Background(), "../repo") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid source") + assert.False(t, m.Exists("/repo")) +} + // ── Remove ───────────────────────────────────────────────────────── func TestRemove_Good(t *testing.T) { @@ -91,6 +102,19 @@ func TestRemove_Bad_NotFound(t *testing.T) { assert.Contains(t, err.Error(), "plugin not found") } +func TestRemove_Bad_PathTraversalName(t *testing.T) { + m := io.NewMockMedium() + reg := NewRegistry(m, "/plugins") + _ = reg.Add(&PluginConfig{Name: "safe", Version: "1.0.0"}) + _ = m.EnsureDir("/escape") + + inst := NewInstaller(m, reg) + err := inst.Remove("../escape") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid plugin name") + assert.True(t, m.Exists("/escape")) +} + // ── Update error paths ───────────────────────────────────────────── func TestUpdate_Bad_NotFound(t *testing.T) { @@ -164,3 +188,25 @@ func TestParseSource_Bad_EmptyVersion(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "version is empty") } + +func TestParseSource_Bad_PathTraversal(t *testing.T) { + _, _, _, err := ParseSource("org/../repo") + assert.Error(t, err) + assert.Contains(t, err.Error(), "org/repo") +} + +func TestParseSource_Bad_PathTraversalEncoded(t *testing.T) { + _, _, _, err := ParseSource("org%2f..%2frepo") + assert.Error(t, err) + assert.Contains(t, err.Error(), "org/repo") +} + +func TestInstall_Bad_EncodedPathTraversal(t *testing.T) { + m := io.NewMockMedium() + reg := NewRegistry(m, "/plugins") + inst := NewInstaller(m, reg) + + err := inst.Install(context.Background(), "org%2f..%2frepo") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid source") +} -- 2.45.3