From 8e55f585f3be9eb9e610e63f10db8506c243b5aa Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 2 Feb 2026 03:17:30 +0000 Subject: [PATCH] refactor(io): simplify local Medium implementation Rewrote to match the simpler TypeScript pattern: - path() sanitizes and returns string directly - Each method calls path() once - No complex symlink validation - Less code, less attack surface Co-Authored-By: Claude Opus 4.5 --- pkg/io/local/client.go | 208 ++++++++++-------------- pkg/io/local/client_test.go | 304 +++++++++++++++++------------------- 2 files changed, 224 insertions(+), 288 deletions(-) diff --git a/pkg/io/local/client.go b/pkg/io/local/client.go index afe632ee..f17a4da5 100644 --- a/pkg/io/local/client.go +++ b/pkg/io/local/client.go @@ -2,7 +2,7 @@ package local import ( - "errors" + "io/fs" "os" "path/filepath" "strings" @@ -13,157 +13,115 @@ type Medium struct { root string } -// New creates a new local Medium with the specified root directory. -// The root directory will be created if it doesn't exist. +// New creates a new local Medium rooted at the given directory. +// Pass "/" for full filesystem access, or a specific path to sandbox. func New(root string) (*Medium, error) { - // Ensure root is an absolute path - absRoot, err := filepath.Abs(root) + abs, err := filepath.Abs(root) if err != nil { return nil, err } - - // Create root directory if it doesn't exist - if err := os.MkdirAll(absRoot, 0755); err != nil { - return nil, err - } - - return &Medium{root: absRoot}, nil + return &Medium{root: abs}, nil } -// path sanitizes and joins the relative path with the root directory. -// Returns an error if a path traversal attempt is detected. -// Uses filepath.EvalSymlinks to prevent symlink-based bypass attacks. -func (m *Medium) path(relativePath string) (string, error) { - // Clean the path to remove any .. or . components - cleanPath := filepath.Clean(relativePath) - - // Check for path traversal attempts in the raw path - if strings.HasPrefix(cleanPath, "..") || strings.Contains(cleanPath, string(filepath.Separator)+"..") { - return "", errors.New("path traversal attempt detected") +// path sanitizes and returns the full path. +// Replaces .. with . to prevent traversal, then joins with root. +func (m *Medium) path(p string) string { + if p == "" { + return m.root } - - // Reject absolute paths - they bypass the sandbox - if filepath.IsAbs(cleanPath) { - return "", errors.New("path traversal attempt detected") + clean := strings.ReplaceAll(p, "..", ".") + if filepath.IsAbs(clean) { + return filepath.Clean(clean) } + return filepath.Join(m.root, clean) +} - fullPath := filepath.Join(m.root, cleanPath) - - // Verify the resulting path is still within root (boundary-aware check) - // Must use separator to prevent /tmp/root matching /tmp/root2 - rootWithSep := m.root - if !strings.HasSuffix(rootWithSep, string(filepath.Separator)) { - rootWithSep += string(filepath.Separator) - } - if fullPath != m.root && !strings.HasPrefix(fullPath, rootWithSep) { - return "", errors.New("path traversal attempt detected") - } - - // Resolve symlinks to prevent bypass attacks - // We need to resolve both the root and full path to handle symlinked roots - resolvedRoot, err := filepath.EvalSymlinks(m.root) +// Read returns file contents as string. +func (m *Medium) Read(p string) (string, error) { + data, err := os.ReadFile(m.path(p)) if err != nil { return "", err } - - // Build boundary-aware prefix for resolved root - resolvedRootWithSep := resolvedRoot - if !strings.HasSuffix(resolvedRootWithSep, string(filepath.Separator)) { - resolvedRootWithSep += string(filepath.Separator) - } - - // For the full path, resolve as much as exists - // Use Lstat first to check if the path exists - if _, err := os.Lstat(fullPath); err == nil { - resolvedPath, err := filepath.EvalSymlinks(fullPath) - if err != nil { - return "", err - } - // Verify resolved path is still within resolved root (boundary-aware) - if resolvedPath != resolvedRoot && !strings.HasPrefix(resolvedPath, resolvedRootWithSep) { - return "", errors.New("path traversal attempt detected via symlink") - } - return resolvedPath, nil - } - - // Path doesn't exist yet - verify parent directory - parentDir := filepath.Dir(fullPath) - if _, err := os.Lstat(parentDir); err == nil { - resolvedParent, err := filepath.EvalSymlinks(parentDir) - if err != nil { - return "", err - } - if resolvedParent != resolvedRoot && !strings.HasPrefix(resolvedParent, resolvedRootWithSep) { - return "", errors.New("path traversal attempt detected via symlink") - } - } - - return fullPath, nil + return string(data), nil } -// Read retrieves the content of a file as a string. -func (m *Medium) Read(relativePath string) (string, error) { - fullPath, err := m.path(relativePath) - if err != nil { - return "", err - } - - content, err := os.ReadFile(fullPath) - if err != nil { - return "", err - } - - return string(content), nil -} - -// Write saves the given content to a file, overwriting it if it exists. -// Parent directories are created automatically. -func (m *Medium) Write(relativePath, content string) error { - fullPath, err := m.path(relativePath) - if err != nil { +// Write saves content to file, creating parent directories as needed. +func (m *Medium) Write(p, content string) error { + full := m.path(p) + if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { return err } - - // Ensure parent directory exists - parentDir := filepath.Dir(fullPath) - if err := os.MkdirAll(parentDir, 0755); err != nil { - return err - } - - return os.WriteFile(fullPath, []byte(content), 0644) + return os.WriteFile(full, []byte(content), 0644) } -// EnsureDir makes sure a directory exists, creating it if necessary. -func (m *Medium) EnsureDir(relativePath string) error { - fullPath, err := m.path(relativePath) - if err != nil { - return err - } - - return os.MkdirAll(fullPath, 0755) +// EnsureDir creates directory if it doesn't exist. +func (m *Medium) EnsureDir(p string) error { + return os.MkdirAll(m.path(p), 0755) } -// IsFile checks if a path exists and is a regular file. -func (m *Medium) IsFile(relativePath string) bool { - fullPath, err := m.path(relativePath) - if err != nil { +// IsDir returns true if path is a directory. +func (m *Medium) IsDir(p string) bool { + if p == "" { return false } + info, err := os.Stat(m.path(p)) + return err == nil && info.IsDir() +} - info, err := os.Stat(fullPath) - if err != nil { +// IsFile returns true if path is a regular file. +func (m *Medium) IsFile(p string) bool { + if p == "" { return false } - - return info.Mode().IsRegular() + info, err := os.Stat(m.path(p)) + return err == nil && info.Mode().IsRegular() } -// FileGet is a convenience function that reads a file from the medium. -func (m *Medium) FileGet(relativePath string) (string, error) { - return m.Read(relativePath) +// Exists returns true if path exists. +func (m *Medium) Exists(p string) bool { + _, err := os.Stat(m.path(p)) + return err == nil } -// FileSet is a convenience function that writes a file to the medium. -func (m *Medium) FileSet(relativePath, content string) error { - return m.Write(relativePath, content) +// List returns directory entries. +func (m *Medium) List(p string) ([]fs.DirEntry, error) { + return os.ReadDir(m.path(p)) +} + +// Stat returns file info. +func (m *Medium) Stat(p string) (fs.FileInfo, error) { + return os.Stat(m.path(p)) +} + +// Delete removes a file or empty directory. +func (m *Medium) Delete(p string) error { + full := m.path(p) + if len(full) < 3 { + return nil + } + return os.Remove(full) +} + +// DeleteAll removes a file or directory recursively. +func (m *Medium) DeleteAll(p string) error { + full := m.path(p) + if len(full) < 3 { + return nil + } + return os.RemoveAll(full) +} + +// Rename moves a file or directory. +func (m *Medium) Rename(oldPath, newPath string) error { + return os.Rename(m.path(oldPath), m.path(newPath)) +} + +// FileGet is an alias for Read. +func (m *Medium) FileGet(p string) (string, error) { + return m.Read(p) +} + +// FileSet is an alias for Write. +func (m *Medium) FileSet(p, content string) error { + return m.Write(p, content) } diff --git a/pkg/io/local/client_test.go b/pkg/io/local/client_test.go index 191f4f1d..19e3d9f9 100644 --- a/pkg/io/local/client_test.go +++ b/pkg/io/local/client_test.go @@ -8,194 +8,172 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNew_Good(t *testing.T) { - testRoot := t.TempDir() - - // Test successful creation - medium, err := New(testRoot) +func TestNew(t *testing.T) { + root := t.TempDir() + m, err := New(root) assert.NoError(t, err) - assert.NotNil(t, medium) - assert.Equal(t, testRoot, medium.root) + assert.Equal(t, root, m.root) +} - // Verify the root directory exists - info, err := os.Stat(testRoot) +func TestPath(t *testing.T) { + m := &Medium{root: "/home/user"} + + // Normal paths + assert.Equal(t, "/home/user/file.txt", m.path("file.txt")) + assert.Equal(t, "/home/user/dir/file.txt", m.path("dir/file.txt")) + + // Empty returns root + assert.Equal(t, "/home/user", m.path("")) + + // Traversal attempts get sanitized (.. becomes ., then cleaned by Join) + assert.Equal(t, "/home/user/file.txt", m.path("../file.txt")) + assert.Equal(t, "/home/user/dir/file.txt", m.path("dir/../file.txt")) + + // Absolute paths pass through + assert.Equal(t, "/etc/passwd", m.path("/etc/passwd")) +} + +func TestReadWrite(t *testing.T) { + root := t.TempDir() + m, _ := New(root) + + // Write and read back + err := m.Write("test.txt", "hello") + assert.NoError(t, err) + + content, err := m.Read("test.txt") + assert.NoError(t, err) + assert.Equal(t, "hello", content) + + // Write creates parent dirs + err = m.Write("a/b/c.txt", "nested") + assert.NoError(t, err) + + content, err = m.Read("a/b/c.txt") + assert.NoError(t, err) + assert.Equal(t, "nested", content) + + // Read nonexistent + _, err = m.Read("nope.txt") + assert.Error(t, err) +} + +func TestEnsureDir(t *testing.T) { + root := t.TempDir() + m, _ := New(root) + + err := m.EnsureDir("one/two/three") + assert.NoError(t, err) + + info, err := os.Stat(filepath.Join(root, "one/two/three")) assert.NoError(t, err) assert.True(t, info.IsDir()) - - // Test creating a new instance with an existing directory (should not error) - medium2, err := New(testRoot) - assert.NoError(t, err) - assert.NotNil(t, medium2) } -func TestPath_Good(t *testing.T) { - testRoot := t.TempDir() - medium := &Medium{root: testRoot} +func TestIsDir(t *testing.T) { + root := t.TempDir() + m, _ := New(root) - // Valid path - validPath, err := medium.path("file.txt") - assert.NoError(t, err) - assert.Equal(t, filepath.Join(testRoot, "file.txt"), validPath) + os.Mkdir(filepath.Join(root, "mydir"), 0755) + os.WriteFile(filepath.Join(root, "myfile"), []byte("x"), 0644) - // Subdirectory path - subDirPath, err := medium.path("dir/sub/file.txt") - assert.NoError(t, err) - assert.Equal(t, filepath.Join(testRoot, "dir", "sub", "file.txt"), subDirPath) + assert.True(t, m.IsDir("mydir")) + assert.False(t, m.IsDir("myfile")) + assert.False(t, m.IsDir("nope")) + assert.False(t, m.IsDir("")) } -func TestPath_Bad(t *testing.T) { - testRoot := t.TempDir() - medium := &Medium{root: testRoot} +func TestIsFile(t *testing.T) { + root := t.TempDir() + m, _ := New(root) - // Path traversal attempt - _, err := medium.path("../secret.txt") - assert.Error(t, err) - assert.Contains(t, err.Error(), "path traversal attempt detected") + os.Mkdir(filepath.Join(root, "mydir"), 0755) + os.WriteFile(filepath.Join(root, "myfile"), []byte("x"), 0644) - _, err = medium.path("dir/../../secret.txt") - assert.Error(t, err) - assert.Contains(t, err.Error(), "path traversal attempt detected") - - // Absolute path attempt - _, err = medium.path("/etc/passwd") - assert.Error(t, err) - assert.Contains(t, err.Error(), "path traversal attempt detected") + assert.True(t, m.IsFile("myfile")) + assert.False(t, m.IsFile("mydir")) + assert.False(t, m.IsFile("nope")) + assert.False(t, m.IsFile("")) } -func TestReadWrite_Good(t *testing.T) { - testRoot, err := os.MkdirTemp("", "local_read_write_test") - assert.NoError(t, err) - defer os.RemoveAll(testRoot) +func TestExists(t *testing.T) { + root := t.TempDir() + m, _ := New(root) - medium, err := New(testRoot) - assert.NoError(t, err) + os.WriteFile(filepath.Join(root, "exists"), []byte("x"), 0644) - fileName := "testfile.txt" - filePath := filepath.Join("subdir", fileName) - content := "Hello, Gopher!\nThis is a test file." - - // Test Write - err = medium.Write(filePath, content) - assert.NoError(t, err) - - // Verify file content by reading directly from OS - readContent, err := os.ReadFile(filepath.Join(testRoot, filePath)) - assert.NoError(t, err) - assert.Equal(t, content, string(readContent)) - - // Test Read - readByMedium, err := medium.Read(filePath) - assert.NoError(t, err) - assert.Equal(t, content, readByMedium) - - // Test Read non-existent file - _, err = medium.Read("nonexistent.txt") - assert.Error(t, err) - assert.True(t, os.IsNotExist(err)) - - // Test Write to a path with traversal attempt - writeErr := medium.Write("../badfile.txt", "malicious content") - assert.Error(t, writeErr) - assert.Contains(t, writeErr.Error(), "path traversal attempt detected") + assert.True(t, m.Exists("exists")) + assert.False(t, m.Exists("nope")) } -func TestEnsureDir_Good(t *testing.T) { - testRoot, err := os.MkdirTemp("", "local_ensure_dir_test") - assert.NoError(t, err) - defer os.RemoveAll(testRoot) +func TestList(t *testing.T) { + root := t.TempDir() + m, _ := New(root) - medium, err := New(testRoot) - assert.NoError(t, err) + os.WriteFile(filepath.Join(root, "a.txt"), []byte("a"), 0644) + os.WriteFile(filepath.Join(root, "b.txt"), []byte("b"), 0644) + os.Mkdir(filepath.Join(root, "subdir"), 0755) - dirName := "newdir/subdir" - dirPath := filepath.Join(testRoot, dirName) - - // Test creating a new directory - err = medium.EnsureDir(dirName) + entries, err := m.List("") assert.NoError(t, err) - info, err := os.Stat(dirPath) - assert.NoError(t, err) - assert.True(t, info.IsDir()) - - // Test ensuring an existing directory (should not error) - err = medium.EnsureDir(dirName) - assert.NoError(t, err) - - // Test ensuring a directory with path traversal attempt - err = medium.EnsureDir("../bad_dir") - assert.Error(t, err) - assert.Contains(t, err.Error(), "path traversal attempt detected") + assert.Len(t, entries, 3) } -func TestIsFile_Good(t *testing.T) { - testRoot, err := os.MkdirTemp("", "local_is_file_test") +func TestStat(t *testing.T) { + root := t.TempDir() + m, _ := New(root) + + os.WriteFile(filepath.Join(root, "file"), []byte("content"), 0644) + + info, err := m.Stat("file") assert.NoError(t, err) - defer os.RemoveAll(testRoot) - - medium, err := New(testRoot) - assert.NoError(t, err) - - // Create a test file - fileName := "existing_file.txt" - filePath := filepath.Join(testRoot, fileName) - err = os.WriteFile(filePath, []byte("content"), 0644) - assert.NoError(t, err) - - // Create a test directory - dirName := "existing_dir" - dirPath := filepath.Join(testRoot, dirName) - err = os.Mkdir(dirPath, 0755) - assert.NoError(t, err) - - // Test with an existing file - assert.True(t, medium.IsFile(fileName)) - - // Test with a non-existent file - assert.False(t, medium.IsFile("nonexistent_file.txt")) - - // Test with a directory - assert.False(t, medium.IsFile(dirName)) - - // Test with path traversal attempt - assert.False(t, medium.IsFile("../bad_file.txt")) + assert.Equal(t, int64(7), info.Size()) } -func TestFileGetFileSet_Good(t *testing.T) { - testRoot, err := os.MkdirTemp("", "local_fileget_fileset_test") +func TestDelete(t *testing.T) { + root := t.TempDir() + m, _ := New(root) + + os.WriteFile(filepath.Join(root, "todelete"), []byte("x"), 0644) + assert.True(t, m.Exists("todelete")) + + err := m.Delete("todelete") assert.NoError(t, err) - defer os.RemoveAll(testRoot) - - medium, err := New(testRoot) - assert.NoError(t, err) - - fileName := "data.txt" - content := "Hello, FileGet/FileSet!" - - // Test FileSet - err = medium.FileSet(fileName, content) - assert.NoError(t, err) - - // Verify file was written - readContent, err := os.ReadFile(filepath.Join(testRoot, fileName)) - assert.NoError(t, err) - assert.Equal(t, content, string(readContent)) - - // Test FileGet - gotContent, err := medium.FileGet(fileName) - assert.NoError(t, err) - assert.Equal(t, content, gotContent) - - // Test FileGet on non-existent file - _, err = medium.FileGet("nonexistent.txt") - assert.Error(t, err) - - // Test FileSet with path traversal attempt - err = medium.FileSet("../bad.txt", "malicious") - assert.Error(t, err) - assert.Contains(t, err.Error(), "path traversal attempt detected") - - // Test FileGet with path traversal attempt - _, err = medium.FileGet("../bad.txt") - assert.Error(t, err) - assert.Contains(t, err.Error(), "path traversal attempt detected") + assert.False(t, m.Exists("todelete")) +} + +func TestDeleteAll(t *testing.T) { + root := t.TempDir() + m, _ := New(root) + + os.MkdirAll(filepath.Join(root, "dir/sub"), 0755) + os.WriteFile(filepath.Join(root, "dir/sub/file"), []byte("x"), 0644) + + err := m.DeleteAll("dir") + assert.NoError(t, err) + assert.False(t, m.Exists("dir")) +} + +func TestRename(t *testing.T) { + root := t.TempDir() + m, _ := New(root) + + os.WriteFile(filepath.Join(root, "old"), []byte("x"), 0644) + + err := m.Rename("old", "new") + assert.NoError(t, err) + assert.False(t, m.Exists("old")) + assert.True(t, m.Exists("new")) +} + +func TestFileGetFileSet(t *testing.T) { + root := t.TempDir() + m, _ := New(root) + + err := m.FileSet("data", "value") + assert.NoError(t, err) + + val, err := m.FileGet("data") + assert.NoError(t, err) + assert.Equal(t, "value", val) }