diff --git a/pkg/cli/daemon.go b/pkg/cli/daemon.go index 90b2fd28..e43df9f1 100644 --- a/pkg/cli/daemon.go +++ b/pkg/cli/daemon.go @@ -74,14 +74,13 @@ func IsStderrTTY() bool { // PIDFile manages a process ID file for single-instance enforcement. type PIDFile struct { - medium io.Medium - path string - mu sync.Mutex + path string + mu sync.Mutex } // NewPIDFile creates a PID file manager. -func NewPIDFile(m io.Medium, path string) *PIDFile { - return &PIDFile{medium: m, path: path} +func NewPIDFile(path string) *PIDFile { + return &PIDFile{path: path} } // Acquire writes the current PID to the file. @@ -91,7 +90,7 @@ func (p *PIDFile) Acquire() error { defer p.mu.Unlock() // Check if PID file exists - if data, err := p.medium.Read(p.path); err == nil { + if data, err := io.Local.Read(p.path); err == nil { pid, err := strconv.Atoi(data) if err == nil && pid > 0 { // Check if process is still running @@ -102,19 +101,19 @@ func (p *PIDFile) Acquire() error { } } // Stale PID file, remove it - _ = p.medium.Delete(p.path) + _ = io.Local.Delete(p.path) } // Ensure directory exists if dir := filepath.Dir(p.path); dir != "." { - if err := p.medium.EnsureDir(dir); err != nil { + if err := io.Local.EnsureDir(dir); err != nil { return fmt.Errorf("failed to create PID directory: %w", err) } } // Write current PID pid := os.Getpid() - if err := p.medium.Write(p.path, strconv.Itoa(pid)); err != nil { + if err := io.Local.Write(p.path, strconv.Itoa(pid)); err != nil { return fmt.Errorf("failed to write PID file: %w", err) } @@ -125,7 +124,7 @@ func (p *PIDFile) Acquire() error { func (p *PIDFile) Release() error { p.mu.Lock() defer p.mu.Unlock() - return p.medium.Delete(p.path) + return io.Local.Delete(p.path) } // Path returns the PID file path. @@ -247,9 +246,6 @@ func (h *HealthServer) Addr() string { // DaemonOptions configures daemon mode execution. type DaemonOptions struct { - // Medium is the filesystem abstraction. - Medium io.Medium - // PIDFile path for single-instance enforcement. // Leave empty to skip PID file management. PIDFile string @@ -287,17 +283,13 @@ func NewDaemon(opts DaemonOptions) *Daemon { opts.ShutdownTimeout = 30 * time.Second } - if opts.Medium == nil { - opts.Medium = io.Local - } - d := &Daemon{ opts: opts, reload: make(chan struct{}, 1), } if opts.PIDFile != "" { - d.pid = NewPIDFile(opts.Medium, opts.PIDFile) + d.pid = NewPIDFile(opts.PIDFile) } if opts.HealthAddr != "" { diff --git a/pkg/cli/daemon_test.go b/pkg/cli/daemon_test.go index d128b5e2..5eb51329 100644 --- a/pkg/cli/daemon_test.go +++ b/pkg/cli/daemon_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -32,7 +31,7 @@ func TestPIDFile(t *testing.T) { tmpDir := t.TempDir() pidPath := filepath.Join(tmpDir, "test.pid") - pid := NewPIDFile(io.Local, pidPath) + pid := NewPIDFile(pidPath) // Acquire should succeed err := pid.Acquire() @@ -59,7 +58,7 @@ func TestPIDFile(t *testing.T) { err := os.WriteFile(pidPath, []byte("999999999"), 0644) require.NoError(t, err) - pid := NewPIDFile(io.Local, pidPath) + pid := NewPIDFile(pidPath) // Should acquire successfully (stale PID removed) err = pid.Acquire() @@ -73,7 +72,7 @@ func TestPIDFile(t *testing.T) { tmpDir := t.TempDir() pidPath := filepath.Join(tmpDir, "subdir", "nested", "test.pid") - pid := NewPIDFile(io.Local, pidPath) + pid := NewPIDFile(pidPath) err := pid.Acquire() require.NoError(t, err) @@ -86,26 +85,9 @@ func TestPIDFile(t *testing.T) { }) t.Run("path getter", func(t *testing.T) { - pid := NewPIDFile(io.Local, "/tmp/test.pid") + pid := NewPIDFile("/tmp/test.pid") assert.Equal(t, "/tmp/test.pid", pid.Path()) }) - - t.Run("with mock medium", func(t *testing.T) { - mock := io.NewMockMedium() - pidPath := "/tmp/mock.pid" - pid := NewPIDFile(mock, pidPath) - - err := pid.Acquire() - require.NoError(t, err) - - assert.True(t, mock.Exists(pidPath)) - data, _ := mock.Read(pidPath) - assert.NotEmpty(t, data) - - err = pid.Release() - require.NoError(t, err) - assert.False(t, mock.Exists(pidPath)) - }) } func TestHealthServer(t *testing.T) { @@ -262,26 +244,6 @@ func TestDaemon(t *testing.T) { d := NewDaemon(DaemonOptions{}) assert.Equal(t, 30*time.Second, d.opts.ShutdownTimeout) }) - - t.Run("with mock medium", func(t *testing.T) { - mock := io.NewMockMedium() - pidPath := "/tmp/daemon.pid" - - d := NewDaemon(DaemonOptions{ - Medium: mock, - PIDFile: pidPath, - HealthAddr: "127.0.0.1:0", - }) - - err := d.Start() - require.NoError(t, err) - - assert.True(t, mock.Exists(pidPath)) - - err = d.Stop() - require.NoError(t, err) - assert.False(t, mock.Exists(pidPath)) - }) } func TestRunWithTimeout(t *testing.T) { diff --git a/pkg/container/linuxkit.go b/pkg/container/linuxkit.go index 252b864a..2f2780af 100644 --- a/pkg/container/linuxkit.go +++ b/pkg/container/linuxkit.go @@ -52,10 +52,6 @@ func NewLinuxKitManagerWithHypervisor(state *State, hypervisor Hypervisor) *Linu // Run starts a new LinuxKit VM from the given image. func (m *LinuxKitManager) Run(ctx context.Context, image string, opts RunOptions) (*Container, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - // Validate image exists if !io.Local.IsFile(image) { return nil, fmt.Errorf("image not found: %s", image) @@ -236,10 +232,6 @@ func (m *LinuxKitManager) waitForExit(id string, cmd *exec.Cmd) { // Stop stops a running container by sending SIGTERM. func (m *LinuxKitManager) Stop(ctx context.Context, id string) error { - if err := ctx.Err(); err != nil { - return err - } - container, ok := m.state.Get(id) if !ok { return fmt.Errorf("container not found: %s", id) @@ -298,10 +290,6 @@ func (m *LinuxKitManager) Stop(ctx context.Context, id string) error { // List returns all known containers, verifying process state. func (m *LinuxKitManager) List(ctx context.Context) ([]*Container, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - containers := m.state.All() // Verify each running container's process is still alive @@ -331,10 +319,6 @@ func isProcessRunning(pid int) bool { // Logs returns a reader for the container's log output. func (m *LinuxKitManager) Logs(ctx context.Context, id string, follow bool) (goio.ReadCloser, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - _, ok := m.state.Get(id) if !ok { return nil, fmt.Errorf("container not found: %s", id) @@ -419,10 +403,6 @@ func (f *followReader) Close() error { // Exec executes a command inside the container via SSH. func (m *LinuxKitManager) Exec(ctx context.Context, id string, cmd []string) error { - if err := ctx.Err(); err != nil { - return err - } - container, ok := m.state.Get(id) if !ok { return fmt.Errorf("container not found: %s", id) diff --git a/pkg/io/local/client.go b/pkg/io/local/client.go index 14cb826f..03b9e7ac 100644 --- a/pkg/io/local/client.go +++ b/pkg/io/local/client.go @@ -24,34 +24,70 @@ func New(root string) (*Medium, error) { } // path sanitizes and returns the full path. -// Replaces .. with . to prevent traversal, then joins with root. // Absolute paths are sandboxed under root (unless root is "/"). func (m *Medium) path(p string) string { if p == "" { return m.root } - clean := strings.ReplaceAll(p, "..", ".") - if filepath.IsAbs(clean) { - // If root is "/", allow absolute paths through - if m.root == "/" { - return filepath.Clean(clean) - } - // Otherwise, sandbox absolute paths by stripping volume + leading separators - vol := filepath.VolumeName(clean) - clean = strings.TrimPrefix(clean, vol) - cutset := string(os.PathSeparator) - if os.PathSeparator != '/' { - cutset += "/" - } - clean = strings.TrimLeft(clean, cutset) - return filepath.Join(m.root, clean) + // Use filepath.Clean with a leading slash to resolve all .. and . internally + // before joining with the root. This is a standard way to sandbox paths. + clean := filepath.Clean("/" + p) + + // If root is "/", allow absolute paths through + if m.root == "/" { + return clean } + + // Join cleaned relative path with root return filepath.Join(m.root, clean) } +// validatePath ensures the path is within the sandbox, following symlinks if they exist. +func (m *Medium) validatePath(p string) (string, error) { + if m.root == "/" { + return m.path(p), nil + } + + // Split the cleaned path into components + parts := strings.Split(filepath.Clean("/"+p), string(os.PathSeparator)) + current := m.root + + for _, part := range parts { + if part == "" { + continue + } + + next := filepath.Join(current, part) + realNext, err := filepath.EvalSymlinks(next) + if err != nil { + if os.IsNotExist(err) { + // Part doesn't exist, we can't follow symlinks anymore. + // Since the path is already Cleaned and current is safe, + // appending a component to current will not escape. + current = next + continue + } + return "", err + } + + // Verify the resolved part is still within the root + rel, err := filepath.Rel(m.root, realNext) + if err != nil || strings.HasPrefix(rel, "..") { + return "", os.ErrPermission // Path escapes sandbox + } + current = realNext + } + + return current, nil +} + // Read returns file contents as string. func (m *Medium) Read(p string) (string, error) { - data, err := os.ReadFile(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return "", err + } + data, err := os.ReadFile(full) if err != nil { return "", err } @@ -60,7 +96,10 @@ func (m *Medium) Read(p string) (string, error) { // Write saves content to file, creating parent directories as needed. func (m *Medium) Write(p, content string) error { - full := m.path(p) + full, err := m.validatePath(p) + if err != nil { + return err + } if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { return err } @@ -69,7 +108,11 @@ func (m *Medium) Write(p, content string) error { // EnsureDir creates directory if it doesn't exist. func (m *Medium) EnsureDir(p string) error { - return os.MkdirAll(m.path(p), 0755) + full, err := m.validatePath(p) + if err != nil { + return err + } + return os.MkdirAll(full, 0755) } // IsDir returns true if path is a directory. @@ -77,7 +120,11 @@ func (m *Medium) IsDir(p string) bool { if p == "" { return false } - info, err := os.Stat(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return false + } + info, err := os.Stat(full) return err == nil && info.IsDir() } @@ -86,29 +133,48 @@ func (m *Medium) IsFile(p string) bool { if p == "" { return false } - info, err := os.Stat(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return false + } + info, err := os.Stat(full) return err == nil && info.Mode().IsRegular() } // Exists returns true if path exists. func (m *Medium) Exists(p string) bool { - _, err := os.Stat(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return false + } + _, err = os.Stat(full) return err == nil } // List returns directory entries. func (m *Medium) List(p string) ([]fs.DirEntry, error) { - return os.ReadDir(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return nil, err + } + return os.ReadDir(full) } // Stat returns file info. func (m *Medium) Stat(p string) (fs.FileInfo, error) { - return os.Stat(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return nil, err + } + return os.Stat(full) } // Delete removes a file or empty directory. func (m *Medium) Delete(p string) error { - full := m.path(p) + full, err := m.validatePath(p) + if err != nil { + return err + } if len(full) < 3 { return nil } @@ -117,7 +183,10 @@ func (m *Medium) Delete(p string) error { // DeleteAll removes a file or directory recursively. func (m *Medium) DeleteAll(p string) error { - full := m.path(p) + full, err := m.validatePath(p) + if err != nil { + return err + } if len(full) < 3 { return nil } @@ -126,7 +195,15 @@ func (m *Medium) DeleteAll(p string) error { // Rename moves a file or directory. func (m *Medium) Rename(oldPath, newPath string) error { - return os.Rename(m.path(oldPath), m.path(newPath)) + oldFull, err := m.validatePath(oldPath) + if err != nil { + return err + } + newFull, err := m.validatePath(newPath) + if err != nil { + return err + } + return os.Rename(oldFull, newFull) } // FileGet is an alias for Read. diff --git a/pkg/io/local/client_test.go b/pkg/io/local/client_test.go index 9e2a1e14..5308cdb1 100644 --- a/pkg/io/local/client_test.go +++ b/pkg/io/local/client_test.go @@ -25,9 +25,9 @@ func TestPath(t *testing.T) { // Empty returns root assert.Equal(t, "/home/user", m.path("")) - // Traversal attempts get sanitized (.. becomes ., then cleaned by Join) + // Traversal attempts get sanitized assert.Equal(t, "/home/user/file.txt", m.path("../file.txt")) - assert.Equal(t, "/home/user/dir/file.txt", m.path("dir/../file.txt")) + assert.Equal(t, "/home/user/file.txt", m.path("dir/../file.txt")) // Absolute paths are constrained to sandbox (no escape) assert.Equal(t, "/home/user/etc/passwd", m.path("/etc/passwd")) diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 9f07dbc8..0d3dba0d 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/host-uk/core/pkg/io" + "github.com/host-uk/core/pkg/io/local" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -40,7 +41,7 @@ func WithWorkspaceRoot(root string) Option { if err != nil { return fmt.Errorf("invalid workspace root: %w", err) } - m, err := io.NewSandboxed(abs) + m, err := local.New(abs) if err != nil { return fmt.Errorf("failed to create workspace medium: %w", err) } @@ -69,7 +70,7 @@ func New(opts ...Option) (*Service, error) { return nil, fmt.Errorf("failed to get working directory: %w", err) } s.workspaceRoot = cwd - m, err := io.NewSandboxed(cwd) + m, err := local.New(cwd) if err != nil { return nil, fmt.Errorf("failed to create sandboxed medium: %w", err) } @@ -310,11 +311,8 @@ func (s *Service) listDirectory(ctx context.Context, req *mcp.CallToolRequest, i size = info.Size() } result = append(result, DirectoryEntry{ - Name: e.Name(), - Path: filepath.Join(input.Path, e.Name()), // Note: This might be relative path, client might expect absolute? - // Issue 103 says "Replace ... with local.Medium sandboxing". - // Previous code returned `filepath.Join(input.Path, e.Name())`. - // If input.Path is relative, this preserves it. + Name: e.Name(), + Path: filepath.Join(input.Path, e.Name()), IsDir: e.IsDir(), Size: size, }) @@ -344,21 +342,18 @@ func (s *Service) renameFile(ctx context.Context, req *mcp.CallToolRequest, inpu } func (s *Service) fileExists(ctx context.Context, req *mcp.CallToolRequest, input FileExistsInput) (*mcp.CallToolResult, FileExistsOutput, error) { - exists := s.medium.IsFile(input.Path) - if exists { - return nil, FileExistsOutput{Exists: true, IsDir: false, Path: input.Path}, nil + info, err := s.medium.Stat(input.Path) + if err != nil { + // Any error from Stat (e.g., not found, permission denied) is treated as "does not exist" + // for the purpose of this tool. + return nil, FileExistsOutput{Exists: false, IsDir: false, Path: input.Path}, nil } - // Check if it's a directory by attempting to list it - // List might fail if it's a file too (but we checked IsFile) or if doesn't exist. - _, err := s.medium.List(input.Path) - isDir := err == nil - // If List failed, it might mean it doesn't exist OR it's a special file or permissions. - // Assuming if List works, it's a directory. - - // Refinement: If it doesn't exist, List returns error. - - return nil, FileExistsOutput{Exists: isDir, IsDir: isDir, Path: input.Path}, nil + return nil, FileExistsOutput{ + Exists: true, + IsDir: info.IsDir(), + Path: input.Path, + }, nil } func (s *Service) detectLanguage(ctx context.Context, req *mcp.CallToolRequest, input DetectLanguageInput) (*mcp.CallToolResult, DetectLanguageOutput, error) { diff --git a/pkg/mcp/mcp_test.go b/pkg/mcp/mcp_test.go index 544d2da2..2172abda 100644 --- a/pkg/mcp/mcp_test.go +++ b/pkg/mcp/mcp_test.go @@ -144,12 +144,15 @@ func TestSandboxing_Traversal_Sanitized(t *testing.T) { t.Error("Expected error (file not found)") } - // Absolute paths are allowed through - they access the real filesystem. - // This is intentional for full filesystem access. Callers wanting sandboxing - // should validate inputs before calling Medium. + // Absolute paths are also sandboxed under the root directory. + // For example, /etc/passwd becomes /etc/passwd. + _, err = s.medium.Read("/etc/passwd") + if err == nil { + t.Error("Expected error (file not found in sandbox)") + } } -func TestSandboxing_Symlinks_Followed(t *testing.T) { +func TestSandboxing_Symlinks_Blocked(t *testing.T) { tmpDir := t.TempDir() outsideDir := t.TempDir() @@ -170,14 +173,15 @@ func TestSandboxing_Symlinks_Followed(t *testing.T) { t.Fatalf("Failed to create service: %v", err) } - // Symlinks are followed - no traversal blocking at Medium level. - // This is intentional for simplicity. Callers wanting to block symlinks - // should validate inputs before calling Medium. - content, err := s.medium.Read("link") - if err != nil { - t.Errorf("Expected symlink to be followed, got error: %v", err) + // Symlinks that escape the sandbox should be blocked. + _, err = s.medium.Read("link") + if err == nil { + t.Error("Expected error for symlink escaping sandbox, got nil") } - if content != "secret" { - t.Errorf("Expected 'secret', got '%s'", content) + + // Symlinks that escape the sandbox should be blocked even if target doesn't exist. + _, err = s.medium.Read("link/nonexistent") + if err == nil { + t.Error("Expected error for symlink/nonexistent escaping sandbox, got nil") } }