Merge pull request '[agent/codex:gpt-5.3-codex-spark] Fix ALL security findings from issue #6. Read CLAUDE.md. Com...' (#10) from agent/fix-all-security-findings-in-issue--6--r into dev
This commit is contained in:
commit
7bdbd1301f
15 changed files with 669 additions and 72 deletions
|
|
@ -2,16 +2,18 @@ package forge
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
"dappco.re/go/core/scm/agentci"
|
||||
fg "dappco.re/go/core/scm/forge"
|
||||
|
||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
)
|
||||
|
||||
// Sync command flags.
|
||||
|
|
@ -95,11 +97,14 @@ func buildSyncRepoList(client *fg.Client, args []string, basePath string) ([]syn
|
|||
|
||||
if len(args) > 0 {
|
||||
for _, arg := range args {
|
||||
name := arg
|
||||
if parts := strings.SplitN(arg, "/", 2); len(parts) == 2 {
|
||||
name = parts[1]
|
||||
name, err := syncRepoNameFromArg(arg)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("forge.buildSyncRepoList", "invalid repo argument", err)
|
||||
}
|
||||
_, localPath, err := agentci.ResolvePathWithinRoot(basePath, name)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("forge.buildSyncRepoList", "resolve local path", err)
|
||||
}
|
||||
localPath := filepath.Join(basePath, name)
|
||||
branch := syncDetectDefaultBranch(localPath)
|
||||
repos = append(repos, syncRepoEntry{
|
||||
name: name,
|
||||
|
|
@ -113,10 +118,17 @@ func buildSyncRepoList(client *fg.Client, args []string, basePath string) ([]syn
|
|||
return nil, err
|
||||
}
|
||||
for _, r := range orgRepos {
|
||||
localPath := filepath.Join(basePath, r.Name)
|
||||
name, err := agentci.ValidatePathElement(r.Name)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("forge.buildSyncRepoList", "invalid repo name from org list", err)
|
||||
}
|
||||
_, localPath, err := agentci.ResolvePathWithinRoot(basePath, name)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("forge.buildSyncRepoList", "resolve local path", err)
|
||||
}
|
||||
branch := syncDetectDefaultBranch(localPath)
|
||||
repos = append(repos, syncRepoEntry{
|
||||
name: r.Name,
|
||||
name: name,
|
||||
localPath: localPath,
|
||||
defaultBranch: branch,
|
||||
})
|
||||
|
|
@ -333,3 +345,27 @@ func syncCreateMainFromUpstream(client *fg.Client, org, repo string) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncRepoNameFromArg(arg string) (string, error) {
|
||||
decoded, err := url.PathUnescape(arg)
|
||||
if err != nil {
|
||||
return "", coreerr.E("forge.syncRepoNameFromArg", "decode repo argument", err)
|
||||
}
|
||||
|
||||
parts := strings.Split(decoded, "/")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
return agentci.ValidatePathElement(parts[0])
|
||||
case 2:
|
||||
if _, err := agentci.ValidatePathElement(parts[0]); err != nil {
|
||||
return "", coreerr.E("forge.syncRepoNameFromArg", "invalid repo owner", err)
|
||||
}
|
||||
name, err := agentci.ValidatePathElement(parts[1])
|
||||
if err != nil {
|
||||
return "", coreerr.E("forge.syncRepoNameFromArg", "invalid repo name", err)
|
||||
}
|
||||
return name, nil
|
||||
default:
|
||||
return "", coreerr.E("forge.syncRepoNameFromArg", "repo argument must be repo or owner/repo", nil)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
53
cmd/forge/cmd_sync_test.go
Normal file
53
cmd/forge/cmd_sync_test.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
package forge
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildSyncRepoList_Good(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
repos, err := buildSyncRepoList(nil, []string{"host-uk/core"}, basePath)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, repos, 1)
|
||||
assert.Equal(t, "core", repos[0].name)
|
||||
assert.Equal(t, filepath.Join(basePath, "core"), repos[0].localPath)
|
||||
}
|
||||
|
||||
func TestBuildSyncRepoList_Bad_PathTraversal(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
_, err := buildSyncRepoList(nil, []string{"../escape"}, basePath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid repo argument")
|
||||
}
|
||||
|
||||
func TestBuildSyncRepoList_Good_OwnerRepo(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
repos, err := buildSyncRepoList(nil, []string{"Host-UK/core"}, basePath)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, repos, 1)
|
||||
assert.Equal(t, "core", repos[0].name)
|
||||
assert.Equal(t, filepath.Join(basePath, "core"), repos[0].localPath)
|
||||
}
|
||||
|
||||
func TestBuildSyncRepoList_Bad_PathTraversal_OwnerRepo(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
_, err := buildSyncRepoList(nil, []string{"host-uk/../escape"}, basePath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid repo argument")
|
||||
}
|
||||
|
||||
func TestBuildSyncRepoList_Bad_PathTraversal_OwnerRepoEncoded(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
_, err := buildSyncRepoList(nil, []string{"host-uk%2F..%2Fescape"}, basePath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid repo argument")
|
||||
}
|
||||
|
|
@ -2,16 +2,18 @@ package gitea
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/sdk/gitea"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
"dappco.re/go/core/scm/agentci"
|
||||
gt "dappco.re/go/core/scm/gitea"
|
||||
|
||||
"code.gitea.io/sdk/gitea"
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
)
|
||||
|
||||
// Sync command flags.
|
||||
|
|
@ -96,12 +98,14 @@ func buildRepoList(client *gt.Client, args []string, basePath string) ([]repoEnt
|
|||
if len(args) > 0 {
|
||||
// Specific repos from args
|
||||
for _, arg := range args {
|
||||
name := arg
|
||||
// Strip owner/ prefix if given
|
||||
if parts := strings.SplitN(arg, "/", 2); len(parts) == 2 {
|
||||
name = parts[1]
|
||||
name, err := repoNameFromArg(arg)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("gitea.buildRepoList", "invalid repo argument", err)
|
||||
}
|
||||
_, localPath, err := agentci.ResolvePathWithinRoot(basePath, name)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("gitea.buildRepoList", "resolve local path", err)
|
||||
}
|
||||
localPath := filepath.Join(basePath, name)
|
||||
branch := detectDefaultBranch(localPath)
|
||||
repos = append(repos, repoEntry{
|
||||
name: name,
|
||||
|
|
@ -116,10 +120,17 @@ func buildRepoList(client *gt.Client, args []string, basePath string) ([]repoEnt
|
|||
return nil, err
|
||||
}
|
||||
for _, r := range orgRepos {
|
||||
localPath := filepath.Join(basePath, r.Name)
|
||||
name, err := agentci.ValidatePathElement(r.Name)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("gitea.buildRepoList", "invalid repo name from org list", err)
|
||||
}
|
||||
_, localPath, err := agentci.ResolvePathWithinRoot(basePath, name)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("gitea.buildRepoList", "resolve local path", err)
|
||||
}
|
||||
branch := detectDefaultBranch(localPath)
|
||||
repos = append(repos, repoEntry{
|
||||
name: r.Name,
|
||||
name: name,
|
||||
localPath: localPath,
|
||||
defaultBranch: branch,
|
||||
})
|
||||
|
|
@ -352,3 +363,27 @@ func createMainFromUpstream(client *gt.Client, org, repo string) error {
|
|||
}
|
||||
|
||||
func strPtr(s string) *string { return &s }
|
||||
|
||||
func repoNameFromArg(arg string) (string, error) {
|
||||
decoded, err := url.PathUnescape(arg)
|
||||
if err != nil {
|
||||
return "", coreerr.E("gitea.repoNameFromArg", "decode repo argument", err)
|
||||
}
|
||||
|
||||
parts := strings.Split(decoded, "/")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
return agentci.ValidatePathElement(parts[0])
|
||||
case 2:
|
||||
if _, err := agentci.ValidatePathElement(parts[0]); err != nil {
|
||||
return "", coreerr.E("gitea.repoNameFromArg", "invalid repo owner", err)
|
||||
}
|
||||
name, err := agentci.ValidatePathElement(parts[1])
|
||||
if err != nil {
|
||||
return "", coreerr.E("gitea.repoNameFromArg", "invalid repo name", err)
|
||||
}
|
||||
return name, nil
|
||||
default:
|
||||
return "", coreerr.E("gitea.repoNameFromArg", "repo argument must be repo or owner/repo", nil)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
53
cmd/gitea/cmd_sync_test.go
Normal file
53
cmd/gitea/cmd_sync_test.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
package gitea
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildRepoList_Good(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
repos, err := buildRepoList(nil, []string{"host-uk/core"}, basePath)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, repos, 1)
|
||||
assert.Equal(t, "core", repos[0].name)
|
||||
assert.Equal(t, filepath.Join(basePath, "core"), repos[0].localPath)
|
||||
}
|
||||
|
||||
func TestBuildRepoList_Bad_PathTraversal(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
_, err := buildRepoList(nil, []string{"../escape"}, basePath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid repo argument")
|
||||
}
|
||||
|
||||
func TestBuildRepoList_Good_OwnerRepo(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
repos, err := buildRepoList(nil, []string{"Host-UK/core"}, basePath)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, repos, 1)
|
||||
assert.Equal(t, "core", repos[0].name)
|
||||
assert.Equal(t, filepath.Join(basePath, "core"), repos[0].localPath)
|
||||
}
|
||||
|
||||
func TestBuildRepoList_Bad_PathTraversal_OwnerRepo(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
_, err := buildRepoList(nil, []string{"host-uk/../escape"}, basePath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid repo argument")
|
||||
}
|
||||
|
||||
func TestBuildRepoList_Bad_PathTraversal_OwnerRepoEncoded(t *testing.T) {
|
||||
basePath := filepath.Join(t.TempDir(), "repos")
|
||||
|
||||
_, err := buildRepoList(nil, []string{"host-uk%2F..%2Fescape"}, basePath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid repo argument")
|
||||
}
|
||||
24
forge/prs.go
24
forge/prs.go
|
|
@ -5,10 +5,13 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"dappco.re/go/core/log"
|
||||
"dappco.re/go/core/scm/agentci"
|
||||
|
||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||
)
|
||||
|
||||
// MergePullRequest merges a pull request with the given method ("squash", "rebase", "merge").
|
||||
|
|
@ -38,14 +41,27 @@ func (c *Client) MergePullRequest(owner, repo string, index int64, method string
|
|||
// The Forgejo SDK v2.2.0 doesn't expose the draft field on EditPullRequestOption,
|
||||
// so we use a raw HTTP PATCH request.
|
||||
func (c *Client) SetPRDraft(owner, repo string, index int64, draft bool) error {
|
||||
safeOwner, err := agentci.ValidatePathElement(owner)
|
||||
if err != nil {
|
||||
return log.E("forge.SetPRDraft", "invalid owner", err)
|
||||
}
|
||||
safeRepo, err := agentci.ValidatePathElement(repo)
|
||||
if err != nil {
|
||||
return log.E("forge.SetPRDraft", "invalid repo", err)
|
||||
}
|
||||
|
||||
payload := map[string]bool{"draft": draft}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return log.E("forge.SetPRDraft", "marshal payload", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d", c.url, owner, repo, index)
|
||||
req, err := http.NewRequest(http.MethodPatch, url, bytes.NewReader(body))
|
||||
path, err := url.JoinPath(c.url, "api", "v1", "repos", safeOwner, safeRepo, "pulls", strconv.FormatInt(index, 10))
|
||||
if err != nil {
|
||||
return log.E("forge.SetPRDraft", "failed to build request path", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPatch, path, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return log.E("forge.SetPRDraft", "create request", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
package forge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
|
@ -98,3 +101,49 @@ func TestClient_DismissReview_Bad_ServerError(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to dismiss review")
|
||||
}
|
||||
|
||||
func TestClient_SetPRDraft_Good_Request(t *testing.T) {
|
||||
var method, path string
|
||||
var payload map[string]any
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/v1/version", func(w http.ResponseWriter, r *http.Request) {
|
||||
jsonResponse(w, map[string]string{"version": "1.21.0"})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/repos/test-org/org-repo/pulls/3", func(w http.ResponseWriter, r *http.Request) {
|
||||
method = r.Method
|
||||
path = r.URL.Path
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&payload))
|
||||
jsonResponse(w, map[string]any{"number": 3})
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
client, err := New(srv.URL, "test-token")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.SetPRDraft("test-org", "org-repo", 3, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.MethodPatch, method)
|
||||
assert.Equal(t, "/api/v1/repos/test-org/org-repo/pulls/3", path)
|
||||
assert.Equal(t, false, payload["draft"])
|
||||
}
|
||||
|
||||
func TestClient_SetPRDraft_Bad_PathTraversalOwner(t *testing.T) {
|
||||
client, srv := newTestClient(t)
|
||||
defer srv.Close()
|
||||
|
||||
err := client.SetPRDraft("../owner", "org-repo", 3, true)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid owner")
|
||||
}
|
||||
|
||||
func TestClient_SetPRDraft_Bad_PathTraversalRepo(t *testing.T) {
|
||||
client, srv := newTestClient(t)
|
||||
defer srv.Close()
|
||||
|
||||
err := client.SetPRDraft("test-org", "..", 3, true)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid repo")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
coreerr "dappco.re/go/core/log"
|
||||
|
|
@ -85,6 +86,10 @@ func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.Pipelin
|
|||
if !ok {
|
||||
return nil, coreerr.E("dispatch.Execute", "unknown agent: "+signal.Assignee, nil)
|
||||
}
|
||||
queueDir, err := agentci.ValidateRemoteDir(agent.QueueDir)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("dispatch.Execute", "invalid agent queue dir", err)
|
||||
}
|
||||
|
||||
// Sanitize inputs to prevent path traversal.
|
||||
safeOwner, err := agentci.SanitizePath(signal.RepoOwner)
|
||||
|
|
@ -184,7 +189,10 @@ func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.Pipelin
|
|||
}
|
||||
|
||||
// Transfer ticket JSON.
|
||||
remoteTicketPath := filepath.Join(agent.QueueDir, ticketName)
|
||||
remoteTicketPath, err := agentci.JoinRemotePath(queueDir, ticketName)
|
||||
if err != nil {
|
||||
return nil, coreerr.E("dispatch.Execute", "ticket path", err)
|
||||
}
|
||||
if err := h.secureTransfer(ctx, agent, remoteTicketPath, ticketJSON, 0644); err != nil {
|
||||
h.failDispatch(signal, fmt.Sprintf("Ticket transfer failed: %v", err))
|
||||
return &jobrunner.ActionResult{
|
||||
|
|
@ -202,10 +210,13 @@ func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.Pipelin
|
|||
|
||||
// Transfer token via separate .env file with 0600 permissions.
|
||||
envContent := fmt.Sprintf("FORGE_TOKEN=%s\n", h.token)
|
||||
remoteEnvPath := filepath.Join(agent.QueueDir, fmt.Sprintf(".env.%s", ticketID))
|
||||
remoteEnvPath, err := agentci.JoinRemotePath(queueDir, fmt.Sprintf(".env.%s", ticketID))
|
||||
if err != nil {
|
||||
return nil, coreerr.E("dispatch.Execute", "env path", err)
|
||||
}
|
||||
if err := h.secureTransfer(ctx, agent, remoteEnvPath, []byte(envContent), 0600); err != nil {
|
||||
// Clean up the ticket if env transfer fails.
|
||||
_ = h.runRemote(ctx, agent, fmt.Sprintf("rm -f %s", agentci.EscapeShellArg(remoteTicketPath)))
|
||||
_ = h.runRemote(ctx, agent, "rm", "-f", remoteTicketPath)
|
||||
h.failDispatch(signal, fmt.Sprintf("Token transfer failed: %v", err))
|
||||
return &jobrunner.ActionResult{
|
||||
Action: "dispatch",
|
||||
|
|
@ -255,8 +266,8 @@ func (h *DispatchHandler) failDispatch(signal *jobrunner.PipelineSignal, reason
|
|||
|
||||
// secureTransfer writes data to a remote path via SSH stdin, preventing command injection.
|
||||
func (h *DispatchHandler) secureTransfer(ctx context.Context, agent agentci.AgentConfig, remotePath string, data []byte, mode int) error {
|
||||
safeRemotePath := agentci.EscapeShellArg(remotePath)
|
||||
remoteCmd := fmt.Sprintf("cat > %s && chmod %o %s", safeRemotePath, mode, safeRemotePath)
|
||||
safePath := agentci.EscapeShellArg(remotePath)
|
||||
remoteCmd := fmt.Sprintf("cat > %s && chmod %o %s", safePath, mode, safePath)
|
||||
|
||||
cmd := agentci.SecureSSHCommand(agent.Host, remoteCmd)
|
||||
cmd.Stdin = bytes.NewReader(data)
|
||||
|
|
@ -269,21 +280,55 @@ func (h *DispatchHandler) secureTransfer(ctx context.Context, agent agentci.Agen
|
|||
}
|
||||
|
||||
// runRemote executes a command on the agent via SSH.
|
||||
func (h *DispatchHandler) runRemote(ctx context.Context, agent agentci.AgentConfig, cmdStr string) error {
|
||||
cmd := agentci.SecureSSHCommand(agent.Host, cmdStr)
|
||||
func (h *DispatchHandler) runRemote(ctx context.Context, agent agentci.AgentConfig, command string, args ...string) error {
|
||||
remoteCmd := command
|
||||
if len(args) > 0 {
|
||||
escaped := make([]string, 0, 1+len(args))
|
||||
escaped = append(escaped, command)
|
||||
for _, arg := range args {
|
||||
escaped = append(escaped, agentci.EscapeShellArg(arg))
|
||||
}
|
||||
remoteCmd = strings.Join(escaped, " ")
|
||||
}
|
||||
|
||||
cmd := agentci.SecureSSHCommand(agent.Host, remoteCmd)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ticketExists checks if a ticket file already exists in queue, active, or done.
|
||||
func (h *DispatchHandler) ticketExists(ctx context.Context, agent agentci.AgentConfig, ticketName string) bool {
|
||||
safeTicket, err := agentci.SanitizePath(ticketName)
|
||||
queueDir, err := agentci.ValidateRemoteDir(agent.QueueDir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
qDir := agent.QueueDir
|
||||
safeTicket, err := agentci.ValidatePathElement(ticketName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
queuePath, err := agentci.JoinRemotePath(queueDir, safeTicket)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
parentDir := queueDir
|
||||
if queueDir != "/" && queueDir != "~" {
|
||||
parentDir = path.Dir(queueDir)
|
||||
}
|
||||
activePath, err := agentci.JoinRemotePath(parentDir, "active", safeTicket)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
donePath, err := agentci.JoinRemotePath(parentDir, "done", safeTicket)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
queuePath = agentci.EscapeShellArg(queuePath)
|
||||
activePath = agentci.EscapeShellArg(activePath)
|
||||
donePath = agentci.EscapeShellArg(donePath)
|
||||
checkCmd := fmt.Sprintf(
|
||||
"test -f %s/%s || test -f %s/../active/%s || test -f %s/../done/%s",
|
||||
qDir, safeTicket, qDir, safeTicket, qDir, safeTicket,
|
||||
"test -f %s || test -f %s || test -f %s",
|
||||
queuePath, activePath, donePath,
|
||||
)
|
||||
cmd := agentci.SecureSSHCommand(agent.Host, checkCmd)
|
||||
return cmd.Run() == nil
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@ import (
|
|||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"dappco.re/go/core/scm/agentci"
|
||||
|
|
@ -13,6 +16,18 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func writeFakeSSHCommand(t *testing.T, outputPath string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
script := filepath.Join(dir, "ssh")
|
||||
scriptContent := "#!/bin/sh\n" +
|
||||
"OUT=" + strconv.Quote(outputPath) + "\n" +
|
||||
"printf '%s\n' \"$@\" >> \"$OUT\"\n" +
|
||||
"cat >> \"${OUT}.stdin\"\n"
|
||||
require.NoError(t, os.WriteFile(script, []byte(scriptContent), 0o755))
|
||||
return dir
|
||||
}
|
||||
|
||||
// newTestSpinner creates a Spinner with the given agents for testing.
|
||||
func newTestSpinner(agents map[string]agentci.AgentConfig) *agentci.Spinner {
|
||||
return agentci.NewSpinner(agentci.ClothoConfig{Strategy: "direct"}, agents)
|
||||
|
|
@ -127,6 +142,29 @@ func TestDispatch_Execute_Bad_UnknownAgent(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "unknown agent")
|
||||
}
|
||||
|
||||
func TestDispatch_Execute_Bad_InvalidQueueDir(t *testing.T) {
|
||||
spinner := newTestSpinner(map[string]agentci.AgentConfig{
|
||||
"darbs-claude": {
|
||||
Host: "localhost",
|
||||
QueueDir: "/tmp/queue; touch /tmp/pwned",
|
||||
Active: true,
|
||||
},
|
||||
})
|
||||
h := NewDispatchHandler(nil, "", "", spinner)
|
||||
|
||||
sig := &jobrunner.PipelineSignal{
|
||||
NeedsCoding: true,
|
||||
Assignee: "darbs-claude",
|
||||
RepoOwner: "host-uk",
|
||||
RepoName: "core",
|
||||
ChildNumber: 1,
|
||||
}
|
||||
|
||||
_, err := h.Execute(context.Background(), sig)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid agent queue dir")
|
||||
}
|
||||
|
||||
func TestDispatch_TicketJSON_Good(t *testing.T) {
|
||||
ticket := DispatchTicket{
|
||||
ID: "host-uk-core-5-1234567890",
|
||||
|
|
@ -214,6 +252,53 @@ func TestDispatch_TicketJSON_Good_OmitsEmptyModelRunner(t *testing.T) {
|
|||
assert.False(t, hasRunner, "runner should be omitted when empty")
|
||||
}
|
||||
|
||||
func TestDispatch_runRemote_Good_EscapesPath(t *testing.T) {
|
||||
outputPath := filepath.Join(t.TempDir(), "ssh-output.txt")
|
||||
toolPath := writeFakeSSHCommand(t, outputPath)
|
||||
t.Setenv("PATH", toolPath+":"+os.Getenv("PATH"))
|
||||
|
||||
h := NewDispatchHandler(nil, "", "", newTestSpinner(nil))
|
||||
dangerousPath := "/tmp/queue with spaces; touch /tmp/pwned"
|
||||
err := h.runRemote(
|
||||
context.Background(),
|
||||
agentci.AgentConfig{Host: "localhost"},
|
||||
"rm",
|
||||
"-f",
|
||||
dangerousPath,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := os.ReadFile(outputPath)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(output), "rm '-f' '"+dangerousPath+"'\n")
|
||||
}
|
||||
|
||||
func TestDispatch_secureTransfer_Good_EscapesPath(t *testing.T) {
|
||||
outputPath := filepath.Join(t.TempDir(), "ssh-output.txt")
|
||||
toolPath := writeFakeSSHCommand(t, outputPath)
|
||||
t.Setenv("PATH", toolPath+":"+os.Getenv("PATH"))
|
||||
|
||||
h := NewDispatchHandler(nil, "", "", newTestSpinner(nil))
|
||||
dangerousPath := "/tmp/queue with spaces; touch /tmp/pwned"
|
||||
err := h.secureTransfer(
|
||||
context.Background(),
|
||||
agentci.AgentConfig{Host: "localhost"},
|
||||
dangerousPath,
|
||||
[]byte("hello"),
|
||||
0644,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := os.ReadFile(outputPath)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(output), "cat > '"+dangerousPath+"' && chmod 644 '"+dangerousPath+"'")
|
||||
|
||||
inputPath := outputPath + ".stdin"
|
||||
input, err := os.ReadFile(inputPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello", string(input))
|
||||
}
|
||||
|
||||
func TestDispatch_TicketJSON_Good_ModelRunnerVariants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
coreerr "dappco.re/go/core/log"
|
||||
"dappco.re/go/core/io"
|
||||
"dappco.re/go/core/scm/manifest"
|
||||
"dappco.re/go/core/io/store"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
"dappco.re/go/core/scm/agentci"
|
||||
"dappco.re/go/core/scm/manifest"
|
||||
)
|
||||
|
||||
const storeGroup = "_modules"
|
||||
|
|
@ -47,12 +48,16 @@ type InstalledModule struct {
|
|||
|
||||
// Install clones a module repo, verifies its manifest signature, and registers it.
|
||||
func (i *Installer) Install(ctx context.Context, mod Module) error {
|
||||
// Check if already installed
|
||||
if _, err := i.store.Get(storeGroup, mod.Code); err == nil {
|
||||
return coreerr.E("marketplace.Installer.Install", "module already installed: "+mod.Code, nil)
|
||||
safeCode, dest, err := i.resolveModulePath(mod.Code)
|
||||
if err != nil {
|
||||
return coreerr.E("marketplace.Installer.Install", "invalid module code", err)
|
||||
}
|
||||
|
||||
// Check if already installed
|
||||
if _, err := i.store.Get(storeGroup, safeCode); err == nil {
|
||||
return coreerr.E("marketplace.Installer.Install", "module already installed: "+safeCode, nil)
|
||||
}
|
||||
|
||||
dest := filepath.Join(i.modulesDir, mod.Code)
|
||||
if err := i.medium.EnsureDir(i.modulesDir); err != nil {
|
||||
return coreerr.E("marketplace.Installer.Install", "mkdir", err)
|
||||
}
|
||||
|
|
@ -80,7 +85,7 @@ func (i *Installer) Install(ctx context.Context, mod Module) error {
|
|||
|
||||
entryPoint := filepath.Join(dest, "main.ts")
|
||||
installed := InstalledModule{
|
||||
Code: mod.Code,
|
||||
Code: safeCode,
|
||||
Name: m.Name,
|
||||
Version: m.Version,
|
||||
Repo: mod.Repo,
|
||||
|
|
@ -95,7 +100,7 @@ func (i *Installer) Install(ctx context.Context, mod Module) error {
|
|||
return coreerr.E("marketplace.Installer.Install", "marshal", err)
|
||||
}
|
||||
|
||||
if err := i.store.Set(storeGroup, mod.Code, string(data)); err != nil {
|
||||
if err := i.store.Set(storeGroup, safeCode, string(data)); err != nil {
|
||||
return coreerr.E("marketplace.Installer.Install", "store", err)
|
||||
}
|
||||
|
||||
|
|
@ -105,21 +110,32 @@ func (i *Installer) Install(ctx context.Context, mod Module) error {
|
|||
|
||||
// Remove uninstalls a module by deleting its files and store entry.
|
||||
func (i *Installer) Remove(code string) error {
|
||||
if _, err := i.store.Get(storeGroup, code); err != nil {
|
||||
return coreerr.E("marketplace.Installer.Remove", "module not installed: "+code, nil)
|
||||
safeCode, dest, err := i.resolveModulePath(code)
|
||||
if err != nil {
|
||||
return coreerr.E("marketplace.Installer.Remove", "invalid module code", err)
|
||||
}
|
||||
|
||||
dest := filepath.Join(i.modulesDir, code)
|
||||
_ = i.medium.DeleteAll(dest)
|
||||
if _, err := i.store.Get(storeGroup, safeCode); err != nil {
|
||||
return coreerr.E("marketplace.Installer.Remove", "module not installed: "+safeCode, nil)
|
||||
}
|
||||
|
||||
return i.store.Delete(storeGroup, code)
|
||||
if err := i.medium.DeleteAll(dest); err != nil {
|
||||
return coreerr.E("marketplace.Installer.Remove", "delete module files", err)
|
||||
}
|
||||
|
||||
return i.store.Delete(storeGroup, safeCode)
|
||||
}
|
||||
|
||||
// Update pulls latest changes and re-verifies the manifest.
|
||||
func (i *Installer) Update(ctx context.Context, code string) error {
|
||||
raw, err := i.store.Get(storeGroup, code)
|
||||
safeCode, dest, err := i.resolveModulePath(code)
|
||||
if err != nil {
|
||||
return coreerr.E("marketplace.Installer.Update", "module not installed: "+code, nil)
|
||||
return coreerr.E("marketplace.Installer.Update", "invalid module code", err)
|
||||
}
|
||||
|
||||
raw, err := i.store.Get(storeGroup, safeCode)
|
||||
if err != nil {
|
||||
return coreerr.E("marketplace.Installer.Update", "module not installed: "+safeCode, nil)
|
||||
}
|
||||
|
||||
var installed InstalledModule
|
||||
|
|
@ -127,8 +143,6 @@ func (i *Installer) Update(ctx context.Context, code string) error {
|
|||
return coreerr.E("marketplace.Installer.Update", "unmarshal", err)
|
||||
}
|
||||
|
||||
dest := filepath.Join(i.modulesDir, code)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "git", "-C", dest, "pull", "--ff-only")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return coreerr.E("marketplace.Installer.Update", "pull: "+strings.TrimSpace(string(output)), err)
|
||||
|
|
@ -145,6 +159,7 @@ func (i *Installer) Update(ctx context.Context, code string) error {
|
|||
}
|
||||
|
||||
// Update stored metadata
|
||||
installed.Code = safeCode
|
||||
installed.Name = m.Name
|
||||
installed.Version = m.Version
|
||||
installed.Permissions = m.Permissions
|
||||
|
|
@ -154,7 +169,7 @@ func (i *Installer) Update(ctx context.Context, code string) error {
|
|||
return coreerr.E("marketplace.Installer.Update", "marshal", err)
|
||||
}
|
||||
|
||||
return i.store.Set(storeGroup, code, string(data))
|
||||
return i.store.Set(storeGroup, safeCode, string(data))
|
||||
}
|
||||
|
||||
// Installed returns all installed module metadata.
|
||||
|
|
@ -195,3 +210,11 @@ func gitClone(ctx context.Context, repo, dest string) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) resolveModulePath(code string) (string, string, error) {
|
||||
safeCode, dest, err := agentci.ResolvePathWithinRoot(i.modulesDir, code)
|
||||
if err != nil {
|
||||
return "", "", coreerr.E("marketplace.Installer.resolveModulePath", "resolve module path", err)
|
||||
}
|
||||
return safeCode, dest, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"dappco.re/go/core/io"
|
||||
"dappco.re/go/core/scm/manifest"
|
||||
"dappco.re/go/core/io/store"
|
||||
"dappco.re/go/core/scm/manifest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -163,6 +163,29 @@ func TestInstall_Bad_InvalidSignature(t *testing.T) {
|
|||
assert.True(t, os.IsNotExist(statErr), "directory should be cleaned up on failure")
|
||||
}
|
||||
|
||||
func TestInstall_Bad_PathTraversalCode(t *testing.T) {
|
||||
repo := createTestRepo(t, "safe-mod", "1.0")
|
||||
modulesDir := filepath.Join(t.TempDir(), "modules")
|
||||
|
||||
st, err := store.New(":memory:")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
inst := NewInstaller(io.Local, modulesDir, st)
|
||||
err = inst.Install(context.Background(), Module{
|
||||
Code: "../escape",
|
||||
Repo: repo,
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid module code")
|
||||
|
||||
_, err = st.Get("_modules", "escape")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = os.Stat(filepath.Join(filepath.Dir(modulesDir), "escape"))
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func TestRemove_Good(t *testing.T) {
|
||||
repo := createTestRepo(t, "rm-mod", "1.0")
|
||||
modulesDir := filepath.Join(t.TempDir(), "modules")
|
||||
|
|
@ -197,6 +220,26 @@ func TestRemove_Bad_NotInstalled(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "not installed")
|
||||
}
|
||||
|
||||
func TestRemove_Bad_PathTraversalCode(t *testing.T) {
|
||||
baseDir := t.TempDir()
|
||||
modulesDir := filepath.Join(baseDir, "modules")
|
||||
escapeDir := filepath.Join(baseDir, "escape")
|
||||
require.NoError(t, os.MkdirAll(escapeDir, 0755))
|
||||
|
||||
st, err := store.New(":memory:")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
inst := NewInstaller(io.Local, modulesDir, st)
|
||||
err = inst.Remove("../escape")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid module code")
|
||||
|
||||
info, statErr := os.Stat(escapeDir)
|
||||
require.NoError(t, statErr)
|
||||
assert.True(t, info.IsDir())
|
||||
}
|
||||
|
||||
func TestInstalled_Good(t *testing.T) {
|
||||
modulesDir := filepath.Join(t.TempDir(), "modules")
|
||||
|
||||
|
|
@ -262,3 +305,16 @@ func TestUpdate_Good(t *testing.T) {
|
|||
assert.Equal(t, "2.0", installed[0].Version)
|
||||
assert.Equal(t, "Updated Module", installed[0].Name)
|
||||
}
|
||||
|
||||
func TestUpdate_Bad_PathTraversalCode(t *testing.T) {
|
||||
modulesDir := filepath.Join(t.TempDir(), "modules")
|
||||
|
||||
st, err := store.New(":memory:")
|
||||
require.NoError(t, err)
|
||||
defer st.Close()
|
||||
|
||||
inst := NewInstaller(io.Local, modulesDir, st)
|
||||
err = inst.Update(context.Background(), "../escape")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid module code")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,10 +10,12 @@ import (
|
|||
"crypto/ed25519"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"dappco.re/go/core/api"
|
||||
"dappco.re/go/core/api/pkg/provider"
|
||||
"dappco.re/go/core/io"
|
||||
"dappco.re/go/core/scm/agentci"
|
||||
"dappco.re/go/core/scm/manifest"
|
||||
"dappco.re/go/core/scm/marketplace"
|
||||
"dappco.re/go/core/scm/repos"
|
||||
|
|
@ -228,7 +230,10 @@ func (p *ScmProvider) getMarketplaceItem(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
code := c.Param("code")
|
||||
code, ok := marketplaceCodeParam(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mod, ok := p.index.Find(code)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, api.Fail("not_found", "provider not found in marketplace"))
|
||||
|
|
@ -243,7 +248,10 @@ func (p *ScmProvider) installItem(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
code := c.Param("code")
|
||||
code, ok := marketplaceCodeParam(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mod, ok := p.index.Find(code)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, api.Fail("not_found", "provider not found in marketplace"))
|
||||
|
|
@ -269,7 +277,10 @@ func (p *ScmProvider) removeItem(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
code := c.Param("code")
|
||||
code, ok := marketplaceCodeParam(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := p.installer.Remove(code); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, api.Fail("remove_failed", err.Error()))
|
||||
return
|
||||
|
|
@ -393,7 +404,10 @@ func (p *ScmProvider) updateInstalled(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
code := c.Param("code")
|
||||
code, ok := marketplaceCodeParam(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := p.installer.Update(context.Background(), code); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, api.Fail("update_failed", err.Error()))
|
||||
return
|
||||
|
|
@ -448,3 +462,21 @@ func (p *ScmProvider) emitEvent(channel string, data any) {
|
|||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func marketplaceCodeParam(c *gin.Context) (string, bool) {
|
||||
code, err := normaliseMarketplaceCode(c.Param("code"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, api.Fail("invalid_code", "invalid marketplace code"))
|
||||
return "", false
|
||||
}
|
||||
return code, true
|
||||
}
|
||||
|
||||
func normaliseMarketplaceCode(raw string) (string, error) {
|
||||
decoded, err := url.PathUnescape(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return agentci.ValidatePathElement(decoded)
|
||||
}
|
||||
|
|
|
|||
26
pkg/api/provider_security_test.go
Normal file
26
pkg/api/provider_security_test.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
// SPDX-Licence-Identifier: EUPL-1.2
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormaliseMarketplaceCode_Good(t *testing.T) {
|
||||
code, err := normaliseMarketplaceCode("analytics")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "analytics", code)
|
||||
}
|
||||
|
||||
func TestNormaliseMarketplaceCode_Bad(t *testing.T) {
|
||||
_, err := normaliseMarketplaceCode("analytics;rm")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNormaliseMarketplaceCode_Bad_EncodedTraversal(t *testing.T) {
|
||||
_, err := normaliseMarketplaceCode("analytics%2f..%2Fescape")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
@ -164,6 +164,18 @@ func TestScmProvider_GetMarketplaceItem_Bad(t *testing.T) {
|
|||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestScmProvider_GetMarketplaceItem_Bad_PathTraversal(t *testing.T) {
|
||||
idx := &marketplace.Index{Version: 1}
|
||||
p := scmapi.NewProvider(idx, nil, nil, nil)
|
||||
|
||||
r := setupRouter(p)
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/api/v1/scm/marketplace/%2e%2e", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
// -- Installed Endpoints ------------------------------------------------------
|
||||
|
||||
func TestScmProvider_ListInstalled_NilInstaller_Good(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -3,13 +3,15 @@ package plugin
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
coreerr "dappco.re/go/core/log"
|
||||
"dappco.re/go/core/io"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
"dappco.re/go/core/scm/agentci"
|
||||
)
|
||||
|
||||
// Installer handles plugin installation from GitHub.
|
||||
|
|
@ -40,7 +42,10 @@ func (i *Installer) Install(ctx context.Context, source string) error {
|
|||
}
|
||||
|
||||
// Clone the repository
|
||||
pluginDir := filepath.Join(i.registry.basePath, repo)
|
||||
_, pluginDir, err := i.resolvePluginPath(repo)
|
||||
if err != nil {
|
||||
return coreerr.E("plugin.Installer.Install", "invalid plugin path", err)
|
||||
}
|
||||
if err := i.medium.EnsureDir(pluginDir); err != nil {
|
||||
return coreerr.E("plugin.Installer.Install", "failed to create plugin directory", err)
|
||||
}
|
||||
|
|
@ -90,14 +95,15 @@ func (i *Installer) Install(ctx context.Context, source string) error {
|
|||
|
||||
// Update updates a plugin to the latest version.
|
||||
func (i *Installer) Update(ctx context.Context, name string) error {
|
||||
cfg, ok := i.registry.Get(name)
|
||||
if !ok {
|
||||
return coreerr.E("plugin.Installer.Update", "plugin not found: "+name, nil)
|
||||
safeName, pluginDir, err := i.resolvePluginPath(name)
|
||||
if err != nil {
|
||||
return coreerr.E("plugin.Installer.Update", "invalid plugin name", err)
|
||||
}
|
||||
|
||||
// Parse the source to get org/repo
|
||||
source := strings.TrimPrefix(cfg.Source, "github:")
|
||||
pluginDir := filepath.Join(i.registry.basePath, name)
|
||||
cfg, ok := i.registry.Get(safeName)
|
||||
if !ok {
|
||||
return coreerr.E("plugin.Installer.Update", "plugin not found: "+safeName, nil)
|
||||
}
|
||||
|
||||
// Pull latest changes
|
||||
cmd := exec.CommandContext(ctx, "git", "-C", pluginDir, "pull", "--ff-only")
|
||||
|
|
@ -118,18 +124,21 @@ func (i *Installer) Update(ctx context.Context, name string) error {
|
|||
return coreerr.E("plugin.Installer.Update", "failed to save registry", err)
|
||||
}
|
||||
|
||||
_ = source // used for context
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove uninstalls a plugin by removing its files and registry entry.
|
||||
func (i *Installer) Remove(name string) error {
|
||||
if _, ok := i.registry.Get(name); !ok {
|
||||
return coreerr.E("plugin.Installer.Remove", "plugin not found: "+name, nil)
|
||||
safeName, pluginDir, err := i.resolvePluginPath(name)
|
||||
if err != nil {
|
||||
return coreerr.E("plugin.Installer.Remove", "invalid plugin name", err)
|
||||
}
|
||||
|
||||
if _, ok := i.registry.Get(safeName); !ok {
|
||||
return coreerr.E("plugin.Installer.Remove", "plugin not found: "+safeName, nil)
|
||||
}
|
||||
|
||||
// Delete plugin directory
|
||||
pluginDir := filepath.Join(i.registry.basePath, name)
|
||||
if i.medium.Exists(pluginDir) {
|
||||
if err := i.medium.DeleteAll(pluginDir); err != nil {
|
||||
return coreerr.E("plugin.Installer.Remove", "failed to delete plugin files", err)
|
||||
|
|
@ -137,7 +146,7 @@ func (i *Installer) Remove(name string) error {
|
|||
}
|
||||
|
||||
// Remove from registry
|
||||
if err := i.registry.Remove(name); err != nil {
|
||||
if err := i.registry.Remove(safeName); err != nil {
|
||||
return coreerr.E("plugin.Installer.Remove", "failed to unregister plugin", err)
|
||||
}
|
||||
|
||||
|
|
@ -170,6 +179,10 @@ func (i *Installer) cloneRepo(ctx context.Context, org, repo, version, dest stri
|
|||
// - "org/repo" -> org="org", repo="repo", version=""
|
||||
// - "org/repo@v1.0" -> org="org", repo="repo", version="v1.0"
|
||||
func ParseSource(source string) (org, repo, version string, err error) {
|
||||
source, err = url.PathUnescape(source)
|
||||
if err != nil {
|
||||
return "", "", "", coreerr.E("plugin.ParseSource", "invalid source path", err)
|
||||
}
|
||||
if source == "" {
|
||||
return "", "", "", coreerr.E("plugin.ParseSource", "source is empty", nil)
|
||||
}
|
||||
|
|
@ -191,5 +204,22 @@ func ParseSource(source string) (org, repo, version string, err error) {
|
|||
return "", "", "", coreerr.E("plugin.ParseSource", "source must be in format org/repo[@version]", nil)
|
||||
}
|
||||
|
||||
return parts[0], parts[1], version, nil
|
||||
org, err = agentci.ValidatePathElement(parts[0])
|
||||
if err != nil {
|
||||
return "", "", "", coreerr.E("plugin.ParseSource", "invalid org", err)
|
||||
}
|
||||
repo, err = agentci.ValidatePathElement(parts[1])
|
||||
if err != nil {
|
||||
return "", "", "", coreerr.E("plugin.ParseSource", "invalid repo", err)
|
||||
}
|
||||
|
||||
return org, repo, version, nil
|
||||
}
|
||||
|
||||
func (i *Installer) resolvePluginPath(name string) (string, string, error) {
|
||||
safeName, path, err := agentci.ResolvePathWithinRoot(i.registry.basePath, name)
|
||||
if err != nil {
|
||||
return "", "", coreerr.E("plugin.Installer.resolvePluginPath", "resolve plugin path", err)
|
||||
}
|
||||
return safeName, path, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,17 @@ func TestInstall_Bad_AlreadyInstalled(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "already installed")
|
||||
}
|
||||
|
||||
func TestInstall_Bad_PathTraversalSource(t *testing.T) {
|
||||
m := io.NewMockMedium()
|
||||
reg := NewRegistry(m, "/plugins")
|
||||
inst := NewInstaller(m, reg)
|
||||
|
||||
err := inst.Install(context.Background(), "../repo")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid source")
|
||||
assert.False(t, m.Exists("/repo"))
|
||||
}
|
||||
|
||||
// ── Remove ─────────────────────────────────────────────────────────
|
||||
|
||||
func TestRemove_Good(t *testing.T) {
|
||||
|
|
@ -91,6 +102,19 @@ func TestRemove_Bad_NotFound(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "plugin not found")
|
||||
}
|
||||
|
||||
func TestRemove_Bad_PathTraversalName(t *testing.T) {
|
||||
m := io.NewMockMedium()
|
||||
reg := NewRegistry(m, "/plugins")
|
||||
_ = reg.Add(&PluginConfig{Name: "safe", Version: "1.0.0"})
|
||||
_ = m.EnsureDir("/escape")
|
||||
|
||||
inst := NewInstaller(m, reg)
|
||||
err := inst.Remove("../escape")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid plugin name")
|
||||
assert.True(t, m.Exists("/escape"))
|
||||
}
|
||||
|
||||
// ── Update error paths ─────────────────────────────────────────────
|
||||
|
||||
func TestUpdate_Bad_NotFound(t *testing.T) {
|
||||
|
|
@ -164,3 +188,25 @@ func TestParseSource_Bad_EmptyVersion(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "version is empty")
|
||||
}
|
||||
|
||||
func TestParseSource_Bad_PathTraversal(t *testing.T) {
|
||||
_, _, _, err := ParseSource("org/../repo")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "org/repo")
|
||||
}
|
||||
|
||||
func TestParseSource_Bad_PathTraversalEncoded(t *testing.T) {
|
||||
_, _, _, err := ParseSource("org%2f..%2frepo")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "org/repo")
|
||||
}
|
||||
|
||||
func TestInstall_Bad_EncodedPathTraversal(t *testing.T) {
|
||||
m := io.NewMockMedium()
|
||||
reg := NewRegistry(m, "/plugins")
|
||||
inst := NewInstaller(m, reg)
|
||||
|
||||
err := inst.Install(context.Background(), "org%2f..%2frepo")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid source")
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue