fix(io): address audit issue 4 findings

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-03-23 07:26:09 +00:00
parent 163692870f
commit 2acfc3d548
11 changed files with 590 additions and 146 deletions

View file

@ -17,14 +17,26 @@ import (
"sync" "sync"
"time" "time"
borgdatanode "forge.lthn.ai/Snider/Borg/pkg/datanode"
coreerr "forge.lthn.ai/core/go-log" coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/Snider/Borg/pkg/datanode" )
var (
dataNodeWalkDir = func(fsys fs.FS, root string, fn fs.WalkDirFunc) error {
return fs.WalkDir(fsys, root, fn)
}
dataNodeOpen = func(dn *borgdatanode.DataNode, name string) (fs.File, error) {
return dn.Open(name)
}
dataNodeReadAll = func(r goio.Reader) ([]byte, error) {
return goio.ReadAll(r)
}
) )
// Medium is an in-memory storage backend backed by a Borg DataNode. // Medium is an in-memory storage backend backed by a Borg DataNode.
// All paths are relative (no leading slash). Thread-safe via RWMutex. // All paths are relative (no leading slash). Thread-safe via RWMutex.
type Medium struct { type Medium struct {
dn *datanode.DataNode dn *borgdatanode.DataNode
dirs map[string]bool // explicit directory tracking dirs map[string]bool // explicit directory tracking
mu sync.RWMutex mu sync.RWMutex
} }
@ -32,14 +44,14 @@ type Medium struct {
// New creates a new empty DataNode Medium. // New creates a new empty DataNode Medium.
func New() *Medium { func New() *Medium {
return &Medium{ return &Medium{
dn: datanode.New(), dn: borgdatanode.New(),
dirs: make(map[string]bool), dirs: make(map[string]bool),
} }
} }
// FromTar creates a Medium from a tarball, restoring all files. // FromTar creates a Medium from a tarball, restoring all files.
func FromTar(data []byte) (*Medium, error) { func FromTar(data []byte) (*Medium, error) {
dn, err := datanode.FromTar(data) dn, err := borgdatanode.FromTar(data)
if err != nil { if err != nil {
return nil, coreerr.E("datanode.FromTar", "failed to restore", err) return nil, coreerr.E("datanode.FromTar", "failed to restore", err)
} }
@ -63,7 +75,7 @@ func (m *Medium) Snapshot() ([]byte, error) {
// Restore replaces the filesystem contents from a tarball. // Restore replaces the filesystem contents from a tarball.
func (m *Medium) Restore(data []byte) error { func (m *Medium) Restore(data []byte) error {
dn, err := datanode.FromTar(data) dn, err := borgdatanode.FromTar(data)
if err != nil { if err != nil {
return coreerr.E("datanode.Restore", "tar failed", err) return coreerr.E("datanode.Restore", "tar failed", err)
} }
@ -76,7 +88,7 @@ func (m *Medium) Restore(data []byte) error {
// DataNode returns the underlying Borg DataNode. // DataNode returns the underlying Borg DataNode.
// Use this to wrap the filesystem in a TIM container. // Use this to wrap the filesystem in a TIM container.
func (m *Medium) DataNode() *datanode.DataNode { func (m *Medium) DataNode() *borgdatanode.DataNode {
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
return m.dn return m.dn
@ -195,7 +207,11 @@ func (m *Medium) Delete(p string) error {
// Check explicit dirs // Check explicit dirs
if m.dirs[p] { if m.dirs[p] {
// Check if dir is empty // Check if dir is empty
if m.hasPrefixLocked(p + "/") { hasChildren, err := m.hasPrefixLocked(p + "/")
if err != nil {
return coreerr.E("datanode.Delete", "failed to inspect directory: "+p, err)
}
if hasChildren {
return coreerr.E("datanode.Delete", "directory not empty: "+p, os.ErrExist) return coreerr.E("datanode.Delete", "directory not empty: "+p, os.ErrExist)
} }
delete(m.dirs, p) delete(m.dirs, p)
@ -205,7 +221,11 @@ func (m *Medium) Delete(p string) error {
} }
if info.IsDir() { if info.IsDir() {
if m.hasPrefixLocked(p + "/") { hasChildren, err := m.hasPrefixLocked(p + "/")
if err != nil {
return coreerr.E("datanode.Delete", "failed to inspect directory: "+p, err)
}
if hasChildren {
return coreerr.E("datanode.Delete", "directory not empty: "+p, os.ErrExist) return coreerr.E("datanode.Delete", "directory not empty: "+p, os.ErrExist)
} }
delete(m.dirs, p) delete(m.dirs, p)
@ -213,7 +233,9 @@ func (m *Medium) Delete(p string) error {
} }
// Remove the file by creating a new DataNode without it // Remove the file by creating a new DataNode without it
m.removeFileLocked(p) if err := m.removeFileLocked(p); err != nil {
return coreerr.E("datanode.Delete", "failed to delete file: "+p, err)
}
return nil return nil
} }
@ -232,15 +254,22 @@ func (m *Medium) DeleteAll(p string) error {
// Check if p itself is a file // Check if p itself is a file
info, err := m.dn.Stat(p) info, err := m.dn.Stat(p)
if err == nil && !info.IsDir() { if err == nil && !info.IsDir() {
m.removeFileLocked(p) if err := m.removeFileLocked(p); err != nil {
return coreerr.E("datanode.DeleteAll", "failed to delete file: "+p, err)
}
found = true found = true
} }
// Remove all files under prefix // Remove all files under prefix
entries, _ := m.collectAllLocked() entries, err := m.collectAllLocked()
if err != nil {
return coreerr.E("datanode.DeleteAll", "failed to inspect tree: "+p, err)
}
for _, name := range entries { for _, name := range entries {
if name == p || strings.HasPrefix(name, prefix) { if name == p || strings.HasPrefix(name, prefix) {
m.removeFileLocked(name) if err := m.removeFileLocked(name); err != nil {
return coreerr.E("datanode.DeleteAll", "failed to delete file: "+name, err)
}
found = true found = true
} }
} }
@ -274,18 +303,15 @@ func (m *Medium) Rename(oldPath, newPath string) error {
if !info.IsDir() { if !info.IsDir() {
// Read old, write new, delete old // Read old, write new, delete old
f, err := m.dn.Open(oldPath) data, err := m.readFileLocked(oldPath)
if err != nil { if err != nil {
return coreerr.E("datanode.Rename", "open failed: "+oldPath, err) return coreerr.E("datanode.Rename", "failed to read source file: "+oldPath, err)
}
data, err := goio.ReadAll(f)
f.Close()
if err != nil {
return coreerr.E("datanode.Rename", "read failed: "+oldPath, err)
} }
m.dn.AddData(newPath, data) m.dn.AddData(newPath, data)
m.ensureDirsLocked(path.Dir(newPath)) m.ensureDirsLocked(path.Dir(newPath))
m.removeFileLocked(oldPath) if err := m.removeFileLocked(oldPath); err != nil {
return coreerr.E("datanode.Rename", "failed to remove source file: "+oldPath, err)
}
return nil return nil
} }
@ -293,18 +319,21 @@ func (m *Medium) Rename(oldPath, newPath string) error {
oldPrefix := oldPath + "/" oldPrefix := oldPath + "/"
newPrefix := newPath + "/" newPrefix := newPath + "/"
entries, _ := m.collectAllLocked() entries, err := m.collectAllLocked()
if err != nil {
return coreerr.E("datanode.Rename", "failed to inspect tree: "+oldPath, err)
}
for _, name := range entries { for _, name := range entries {
if strings.HasPrefix(name, oldPrefix) { if strings.HasPrefix(name, oldPrefix) {
newName := newPrefix + strings.TrimPrefix(name, oldPrefix) newName := newPrefix + strings.TrimPrefix(name, oldPrefix)
f, err := m.dn.Open(name) data, err := m.readFileLocked(name)
if err != nil { if err != nil {
continue return coreerr.E("datanode.Rename", "failed to read source file: "+name, err)
} }
data, _ := goio.ReadAll(f)
f.Close()
m.dn.AddData(newName, data) m.dn.AddData(newName, data)
m.removeFileLocked(name) if err := m.removeFileLocked(name); err != nil {
return coreerr.E("datanode.Rename", "failed to remove source file: "+name, err)
}
} }
} }
@ -416,10 +445,13 @@ func (m *Medium) Append(p string) (goio.WriteCloser, error) {
// Read existing content // Read existing content
var existing []byte var existing []byte
m.mu.RLock() m.mu.RLock()
f, err := m.dn.Open(p) if m.IsFile(p) {
if err == nil { data, err := m.readFileLocked(p)
existing, _ = goio.ReadAll(f) if err != nil {
f.Close() m.mu.RUnlock()
return nil, coreerr.E("datanode.Append", "failed to read existing content: "+p, err)
}
existing = data
} }
m.mu.RUnlock() m.mu.RUnlock()
@ -475,27 +507,30 @@ func (m *Medium) IsDir(p string) bool {
// --- internal helpers --- // --- internal helpers ---
// hasPrefixLocked checks if any file path starts with prefix. Caller holds lock. // hasPrefixLocked checks if any file path starts with prefix. Caller holds lock.
func (m *Medium) hasPrefixLocked(prefix string) bool { func (m *Medium) hasPrefixLocked(prefix string) (bool, error) {
entries, _ := m.collectAllLocked() entries, err := m.collectAllLocked()
if err != nil {
return false, err
}
for _, name := range entries { for _, name := range entries {
if strings.HasPrefix(name, prefix) { if strings.HasPrefix(name, prefix) {
return true return true, nil
} }
} }
for d := range m.dirs { for d := range m.dirs {
if strings.HasPrefix(d, prefix) { if strings.HasPrefix(d, prefix) {
return true return true, nil
} }
} }
return false return false, nil
} }
// collectAllLocked returns all file paths in the DataNode. Caller holds lock. // collectAllLocked returns all file paths in the DataNode. Caller holds lock.
func (m *Medium) collectAllLocked() ([]string, error) { func (m *Medium) collectAllLocked() ([]string, error) {
var names []string var names []string
err := fs.WalkDir(m.dn, ".", func(p string, d fs.DirEntry, err error) error { err := dataNodeWalkDir(m.dn, ".", func(p string, d fs.DirEntry, err error) error {
if err != nil { if err != nil {
return nil return err
} }
if !d.IsDir() { if !d.IsDir() {
names = append(names, p) names = append(names, p)
@ -505,28 +540,43 @@ func (m *Medium) collectAllLocked() ([]string, error) {
return names, err return names, err
} }
func (m *Medium) readFileLocked(name string) ([]byte, error) {
f, err := dataNodeOpen(m.dn, name)
if err != nil {
return nil, err
}
data, readErr := dataNodeReadAll(f)
closeErr := f.Close()
if readErr != nil {
return nil, readErr
}
if closeErr != nil {
return nil, closeErr
}
return data, nil
}
// removeFileLocked removes a single file by rebuilding the DataNode. // removeFileLocked removes a single file by rebuilding the DataNode.
// This is necessary because Borg's DataNode doesn't expose a Remove method. // This is necessary because Borg's DataNode doesn't expose a Remove method.
// Caller must hold m.mu write lock. // Caller must hold m.mu write lock.
func (m *Medium) removeFileLocked(target string) { func (m *Medium) removeFileLocked(target string) error {
entries, _ := m.collectAllLocked() entries, err := m.collectAllLocked()
newDN := datanode.New() if err != nil {
return err
}
newDN := borgdatanode.New()
for _, name := range entries { for _, name := range entries {
if name == target { if name == target {
continue continue
} }
f, err := m.dn.Open(name) data, err := m.readFileLocked(name)
if err != nil { if err != nil {
continue return err
}
data, err := goio.ReadAll(f)
f.Close()
if err != nil {
continue
} }
newDN.AddData(name, data) newDN.AddData(name, data)
} }
m.dn = newDN m.dn = newDN
return nil
} }
// --- writeCloser buffers writes and flushes to DataNode on Close --- // --- writeCloser buffers writes and flushes to DataNode on Close ---

View file

@ -1,7 +1,9 @@
package datanode package datanode
import ( import (
"errors"
"io" "io"
"io/fs"
"testing" "testing"
coreio "dappco.re/go/core/io" coreio "dappco.re/go/core/io"
@ -102,6 +104,23 @@ func TestDelete_Bad(t *testing.T) {
assert.Error(t, m.Delete("dir")) assert.Error(t, m.Delete("dir"))
} }
func TestDelete_Bad_DirectoryInspectionFailure(t *testing.T) {
m := New()
require.NoError(t, m.Write("dir/file.txt", "content"))
original := dataNodeWalkDir
dataNodeWalkDir = func(_ fs.FS, _ string, _ fs.WalkDirFunc) error {
return errors.New("walk failed")
}
t.Cleanup(func() {
dataNodeWalkDir = original
})
err := m.Delete("dir")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to inspect directory")
}
func TestDeleteAll_Good(t *testing.T) { func TestDeleteAll_Good(t *testing.T) {
m := New() m := New()
@ -116,6 +135,41 @@ func TestDeleteAll_Good(t *testing.T) {
assert.True(t, m.Exists("keep.txt")) assert.True(t, m.Exists("keep.txt"))
} }
func TestDeleteAll_Bad_WalkFailure(t *testing.T) {
m := New()
require.NoError(t, m.Write("tree/a.txt", "a"))
original := dataNodeWalkDir
dataNodeWalkDir = func(_ fs.FS, _ string, _ fs.WalkDirFunc) error {
return errors.New("walk failed")
}
t.Cleanup(func() {
dataNodeWalkDir = original
})
err := m.DeleteAll("tree")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to inspect tree")
}
func TestDelete_Bad_RemoveFailure(t *testing.T) {
m := New()
require.NoError(t, m.Write("keep.txt", "keep"))
require.NoError(t, m.Write("bad.txt", "bad"))
original := dataNodeReadAll
dataNodeReadAll = func(_ io.Reader) ([]byte, error) {
return nil, errors.New("read failed")
}
t.Cleanup(func() {
dataNodeReadAll = original
})
err := m.Delete("bad.txt")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to delete file")
}
func TestRename_Good(t *testing.T) { func TestRename_Good(t *testing.T) {
m := New() m := New()
@ -147,6 +201,23 @@ func TestRenameDir_Good(t *testing.T) {
assert.Equal(t, "package b", got) assert.Equal(t, "package b", got)
} }
func TestRenameDir_Bad_ReadFailure(t *testing.T) {
m := New()
require.NoError(t, m.Write("src/a.go", "package a"))
original := dataNodeReadAll
dataNodeReadAll = func(_ io.Reader) ([]byte, error) {
return nil, errors.New("read failed")
}
t.Cleanup(func() {
dataNodeReadAll = original
})
err := m.Rename("src", "dst")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to read source file")
}
func TestList_Good(t *testing.T) { func TestList_Good(t *testing.T) {
m := New() m := New()
@ -230,6 +301,23 @@ func TestCreateAppend_Good(t *testing.T) {
assert.Equal(t, "hello world", got) assert.Equal(t, "hello world", got)
} }
func TestAppend_Bad_ReadFailure(t *testing.T) {
m := New()
require.NoError(t, m.Write("new.txt", "hello"))
original := dataNodeReadAll
dataNodeReadAll = func(_ io.Reader) ([]byte, error) {
return nil, errors.New("read failed")
}
t.Cleanup(func() {
dataNodeReadAll = original
})
_, err := m.Append("new.txt")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to read existing content")
}
func TestStreams_Good(t *testing.T) { func TestStreams_Good(t *testing.T) {
m := New() m := New()

6
go.mod
View file

@ -3,9 +3,8 @@ module dappco.re/go/core/io
go 1.26.0 go 1.26.0
require ( require (
dappco.re/go/core v0.4.7 dappco.re/go/core v0.6.0
forge.lthn.ai/Snider/Borg v0.3.1 forge.lthn.ai/Snider/Borg v0.3.1
forge.lthn.ai/core/go-crypt v0.1.6
forge.lthn.ai/core/go-log v0.0.4 forge.lthn.ai/core/go-log v0.0.4
github.com/aws/aws-sdk-go-v2 v1.41.4 github.com/aws/aws-sdk-go-v2 v1.41.4
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1
@ -15,8 +14,6 @@ require (
) )
require ( require (
forge.lthn.ai/core/go v0.3.0 // indirect
github.com/ProtonMail/go-crypto v1.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
@ -26,7 +23,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect
github.com/aws/smithy-go v1.24.2 // indirect github.com/aws/smithy-go v1.24.2 // indirect
github.com/cloudflare/circl v1.6.3 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect

12
go.sum
View file

@ -1,15 +1,9 @@
dappco.re/go/core v0.4.7 h1:KmIA/2lo6rl1NMtLrKqCWfMlUqpDZYH3q0/d10dTtGA= dappco.re/go/core v0.6.0 h1:0wmuO/UmCWXxJkxQ6XvVLnqkAuWitbd49PhxjCsplyk=
dappco.re/go/core v0.4.7/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A= dappco.re/go/core v0.6.0/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
forge.lthn.ai/Snider/Borg v0.3.1 h1:gfC1ZTpLoZai07oOWJiVeQ8+qJYK8A795tgVGJHbVL8= forge.lthn.ai/Snider/Borg v0.3.1 h1:gfC1ZTpLoZai07oOWJiVeQ8+qJYK8A795tgVGJHbVL8=
forge.lthn.ai/Snider/Borg v0.3.1/go.mod h1:Z7DJD0yHXsxSyM7Mjl6/g4gH1NBsIz44Bf5AFlV76Wg= forge.lthn.ai/Snider/Borg v0.3.1/go.mod h1:Z7DJD0yHXsxSyM7Mjl6/g4gH1NBsIz44Bf5AFlV76Wg=
forge.lthn.ai/core/go v0.3.0 h1:mOG97ApMprwx9Ked62FdWVwXTGSF6JO6m0DrVpoH2Q4=
forge.lthn.ai/core/go v0.3.0/go.mod h1:gE6c8h+PJ2287qNhVUJ5SOe1kopEwHEquvinstpuyJc=
forge.lthn.ai/core/go-crypt v0.1.6 h1:jB7L/28S1NR+91u3GcOYuKfBLzPhhBUY1fRe6WkGVns=
forge.lthn.ai/core/go-crypt v0.1.6/go.mod h1:4VZAGqxlbadhSB66sJkdj54/HSJ+bSxVgwWK5kMMYDo=
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0= forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw= forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
github.com/ProtonMail/go-crypto v1.4.0 h1:Zq/pbM3F5DFgJiMouxEdSVY44MVoQNEKp5d5QxIQceQ=
github.com/ProtonMail/go-crypto v1.4.0/go.mod h1:e1OaTyu5SYVrO9gKOEhTc+5UcXtTUa+P3uLudwcgPqo=
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k= github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY=
@ -32,8 +26,6 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 h1:csi9NLpFZXb9fxY7rS1xVzgPRGMt7
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1/go.mod h1:qXVal5H0ChqXP63t6jze5LmFalc7+ZE7wOdLtZ0LCP0= github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1/go.mod h1:qXVal5H0ChqXP63t6jze5LmFalc7+ZE7wOdLtZ0LCP0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=
github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

10
io.go
View file

@ -4,12 +4,12 @@ import (
goio "io" goio "io"
"io/fs" "io/fs"
"os" "os"
"path/filepath"
"strings" "strings"
"time" "time"
coreerr "forge.lthn.ai/core/go-log" core "dappco.re/go/core"
"dappco.re/go/core/io/local" "dappco.re/go/core/io/local"
coreerr "forge.lthn.ai/core/go-log"
) )
// Medium defines the standard interface for a storage backend. // Medium defines the standard interface for a storage backend.
@ -361,7 +361,7 @@ func (m *MockMedium) Open(path string) (fs.File, error) {
return nil, coreerr.E("io.MockMedium.Open", "file not found: "+path, os.ErrNotExist) return nil, coreerr.E("io.MockMedium.Open", "file not found: "+path, os.ErrNotExist)
} }
return &MockFile{ return &MockFile{
name: filepath.Base(path), name: core.PathBase(path),
content: []byte(content), content: []byte(content),
}, nil }, nil
} }
@ -556,7 +556,7 @@ func (m *MockMedium) Stat(path string) (fs.FileInfo, error) {
modTime = time.Now() modTime = time.Now()
} }
return FileInfo{ return FileInfo{
name: filepath.Base(path), name: core.PathBase(path),
size: int64(len(content)), size: int64(len(content)),
mode: 0644, mode: 0644,
modTime: modTime, modTime: modTime,
@ -564,7 +564,7 @@ func (m *MockMedium) Stat(path string) (fs.FileInfo, error) {
} }
if _, ok := m.Dirs[path]; ok { if _, ok := m.Dirs[path]; ok {
return FileInfo{ return FileInfo{
name: filepath.Base(path), name: core.PathBase(path),
isDir: true, isDir: true,
mode: fs.ModeDir | 0755, mode: fs.ModeDir | 0755,
}, nil }, nil

View file

@ -6,11 +6,10 @@ import (
goio "io" goio "io"
"io/fs" "io/fs"
"os" "os"
"os/user"
"path/filepath"
"strings" "strings"
"time" "time"
core "dappco.re/go/core"
coreerr "forge.lthn.ai/core/go-log" coreerr "forge.lthn.ai/core/go-log"
) )
@ -22,20 +21,163 @@ type Medium struct {
// New creates a new local Medium rooted at the given directory. // New creates a new local Medium rooted at the given directory.
// Pass "/" for full filesystem access, or a specific path to sandbox. // Pass "/" for full filesystem access, or a specific path to sandbox.
func New(root string) (*Medium, error) { func New(root string) (*Medium, error) {
abs, err := filepath.Abs(root) abs := absolutePath(root)
if err != nil {
return nil, err
}
// Resolve symlinks so sandbox checks compare like-for-like. // Resolve symlinks so sandbox checks compare like-for-like.
// On macOS, /var is a symlink to /private/var — without this, // On macOS, /var is a symlink to /private/var — without this,
// EvalSymlinks on child paths resolves to /private/var/... while // resolving child paths resolves to /private/var/... while
// root stays /var/..., causing false sandbox escape detections. // root stays /var/..., causing false sandbox escape detections.
if resolved, err := filepath.EvalSymlinks(abs); err == nil { if resolved, err := resolveSymlinksPath(abs); err == nil {
abs = resolved abs = resolved
} }
return &Medium{root: abs}, nil return &Medium{root: abs}, nil
} }
func dirSeparator() string {
if sep := core.Env("DS"); sep != "" {
return sep
}
return string(os.PathSeparator)
}
func normalisePath(p string) string {
sep := dirSeparator()
if sep == "/" {
return strings.ReplaceAll(p, "\\", sep)
}
return strings.ReplaceAll(p, "/", sep)
}
func currentWorkingDir() string {
if cwd, err := os.Getwd(); err == nil && cwd != "" {
return cwd
}
if cwd := core.Env("DIR_CWD"); cwd != "" {
return cwd
}
return "."
}
func absolutePath(p string) string {
p = normalisePath(p)
if core.PathIsAbs(p) {
return core.Path(p)
}
return core.Path(currentWorkingDir(), p)
}
func cleanSandboxPath(p string) string {
return core.Path(dirSeparator() + normalisePath(p))
}
func splitPathParts(p string) []string {
trimmed := strings.TrimPrefix(p, dirSeparator())
if trimmed == "" {
return nil
}
var parts []string
for _, part := range strings.Split(trimmed, dirSeparator()) {
if part == "" {
continue
}
parts = append(parts, part)
}
return parts
}
func resolveSymlinksPath(p string) (string, error) {
return resolveSymlinksRecursive(absolutePath(p), map[string]struct{}{})
}
func resolveSymlinksRecursive(p string, seen map[string]struct{}) (string, error) {
p = core.Path(p)
if p == dirSeparator() {
return p, nil
}
current := dirSeparator()
for _, part := range splitPathParts(p) {
next := core.Path(current, part)
info, err := os.Lstat(next)
if err != nil {
if os.IsNotExist(err) {
current = next
continue
}
return "", err
}
if info.Mode()&os.ModeSymlink == 0 {
current = next
continue
}
target, err := os.Readlink(next)
if err != nil {
return "", err
}
target = normalisePath(target)
if !core.PathIsAbs(target) {
target = core.Path(current, target)
} else {
target = core.Path(target)
}
if _, ok := seen[target]; ok {
return "", coreerr.E("local.resolveSymlinksPath", "symlink cycle: "+target, os.ErrInvalid)
}
seen[target] = struct{}{}
resolved, err := resolveSymlinksRecursive(target, seen)
delete(seen, target)
if err != nil {
return "", err
}
current = resolved
}
return current, nil
}
func isWithinRoot(root, target string) bool {
root = core.Path(root)
target = core.Path(target)
if root == dirSeparator() {
return true
}
return target == root || strings.HasPrefix(target, root+dirSeparator())
}
func canonicalPath(p string) string {
if p == "" {
return ""
}
if resolved, err := resolveSymlinksPath(p); err == nil {
return resolved
}
return absolutePath(p)
}
func isProtectedPath(full string) bool {
full = canonicalPath(full)
protected := map[string]struct{}{
canonicalPath(dirSeparator()): {},
}
for _, home := range []string{core.Env("HOME"), core.Env("DIR_HOME")} {
if home == "" {
continue
}
protected[canonicalPath(home)] = struct{}{}
}
_, ok := protected[full]
return ok
}
func logSandboxEscape(root, path, attempted string) {
username := core.Env("USER")
if username == "" {
username = "unknown"
}
fmt.Fprintf(os.Stderr, "[%s] SECURITY sandbox escape detected root=%s path=%s attempted=%s user=%s\n",
time.Now().Format(time.RFC3339), root, path, attempted, username)
}
// path sanitises and returns the full path. // path sanitises and returns the full path.
// Absolute paths are sandboxed under root (unless root is "/"). // Absolute paths are sandboxed under root (unless root is "/").
func (m *Medium) path(p string) string { func (m *Medium) path(p string) string {
@ -46,41 +188,36 @@ func (m *Medium) path(p string) string {
// If the path is relative and the medium is rooted at "/", // If the path is relative and the medium is rooted at "/",
// treat it as relative to the current working directory. // treat it as relative to the current working directory.
// This makes io.Local behave more like the standard 'os' package. // This makes io.Local behave more like the standard 'os' package.
if m.root == "/" && !filepath.IsAbs(p) { if m.root == dirSeparator() && !core.PathIsAbs(normalisePath(p)) {
cwd, _ := os.Getwd() return core.Path(currentWorkingDir(), normalisePath(p))
return filepath.Join(cwd, p)
} }
// Use filepath.Clean with a leading slash to resolve all .. and . internally // Use a cleaned absolute path to resolve all .. and . internally
// before joining with the root. This is a standard way to sandbox paths. // before joining with the root. This is a standard way to sandbox paths.
clean := filepath.Clean("/" + p) clean := cleanSandboxPath(p)
// If root is "/", allow absolute paths through // If root is "/", allow absolute paths through
if m.root == "/" { if m.root == dirSeparator() {
return clean return clean
} }
// Join cleaned relative path with root // Join cleaned relative path with root
return filepath.Join(m.root, clean) return core.Path(m.root, strings.TrimPrefix(clean, dirSeparator()))
} }
// validatePath ensures the path is within the sandbox, following symlinks if they exist. // validatePath ensures the path is within the sandbox, following symlinks if they exist.
func (m *Medium) validatePath(p string) (string, error) { func (m *Medium) validatePath(p string) (string, error) {
if m.root == "/" { if m.root == dirSeparator() {
return m.path(p), nil return m.path(p), nil
} }
// Split the cleaned path into components // Split the cleaned path into components
parts := strings.Split(filepath.Clean("/"+p), string(os.PathSeparator)) parts := splitPathParts(cleanSandboxPath(p))
current := m.root current := m.root
for _, part := range parts { for _, part := range parts {
if part == "" { next := core.Path(current, part)
continue realNext, err := resolveSymlinksPath(next)
}
next := filepath.Join(current, part)
realNext, err := filepath.EvalSymlinks(next)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
// Part doesn't exist, we can't follow symlinks anymore. // Part doesn't exist, we can't follow symlinks anymore.
@ -93,15 +230,9 @@ func (m *Medium) validatePath(p string) (string, error) {
} }
// Verify the resolved part is still within the root // Verify the resolved part is still within the root
rel, err := filepath.Rel(m.root, realNext) if !isWithinRoot(m.root, realNext) {
if err != nil || strings.HasPrefix(rel, "..") {
// Security event: sandbox escape attempt // Security event: sandbox escape attempt
username := "unknown" logSandboxEscape(m.root, p, realNext)
if u, err := user.Current(); err == nil {
username = u.Username
}
fmt.Fprintf(os.Stderr, "[%s] SECURITY sandbox escape detected root=%s path=%s attempted=%s user=%s\n",
time.Now().Format(time.RFC3339), m.root, p, realNext, username)
return "", os.ErrPermission // Path escapes sandbox return "", os.ErrPermission // Path escapes sandbox
} }
current = realNext current = realNext
@ -137,7 +268,7 @@ func (m *Medium) WriteMode(p, content string, mode os.FileMode) error {
if err != nil { if err != nil {
return err return err
} }
if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { if err := os.MkdirAll(core.PathDir(full), 0755); err != nil {
return err return err
} }
return os.WriteFile(full, []byte(content), mode) return os.WriteFile(full, []byte(content), mode)
@ -221,7 +352,7 @@ func (m *Medium) Create(p string) (goio.WriteCloser, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { if err := os.MkdirAll(core.PathDir(full), 0755); err != nil {
return nil, err return nil, err
} }
return os.Create(full) return os.Create(full)
@ -233,7 +364,7 @@ func (m *Medium) Append(p string) (goio.WriteCloser, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { if err := os.MkdirAll(core.PathDir(full), 0755); err != nil {
return nil, err return nil, err
} }
return os.OpenFile(full, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) return os.OpenFile(full, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
@ -265,7 +396,7 @@ func (m *Medium) Delete(p string) error {
if err != nil { if err != nil {
return err return err
} }
if full == "/" || full == os.Getenv("HOME") { if isProtectedPath(full) {
return coreerr.E("local.Delete", "refusing to delete protected path: "+full, nil) return coreerr.E("local.Delete", "refusing to delete protected path: "+full, nil)
} }
return os.Remove(full) return os.Remove(full)
@ -277,7 +408,7 @@ func (m *Medium) DeleteAll(p string) error {
if err != nil { if err != nil {
return err return err
} }
if full == "/" || full == os.Getenv("HOME") { if isProtectedPath(full) {
return coreerr.E("local.DeleteAll", "refusing to delete protected path: "+full, nil) return coreerr.E("local.DeleteAll", "refusing to delete protected path: "+full, nil)
} }
return os.RemoveAll(full) return os.RemoveAll(full)

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
@ -170,6 +171,33 @@ func TestDeleteAll(t *testing.T) {
assert.False(t, m.Exists("dir")) assert.False(t, m.Exists("dir"))
} }
func TestDelete_ProtectedHomeViaSymlinkEnv(t *testing.T) {
realHome := t.TempDir()
linkParent := t.TempDir()
homeLink := filepath.Join(linkParent, "home-link")
require.NoError(t, os.Symlink(realHome, homeLink))
t.Setenv("HOME", homeLink)
m, err := New("/")
require.NoError(t, err)
err = m.Delete(realHome)
require.Error(t, err)
assert.DirExists(t, realHome)
}
func TestDeleteAll_ProtectedHomeViaEnv(t *testing.T) {
tempHome := t.TempDir()
t.Setenv("HOME", tempHome)
m, err := New("/")
require.NoError(t, err)
err = m.DeleteAll(tempHome)
require.Error(t, err)
assert.DirExists(t, tempHome)
}
func TestRename(t *testing.T) { func TestRename(t *testing.T) {
root := t.TempDir() root := t.TempDir()
m, _ := New(root) m, _ := New(root)

View file

@ -37,6 +37,29 @@ type Medium struct {
prefix string prefix string
} }
func deleteObjectsError(prefix string, errs []types.Error) error {
if len(errs) == 0 {
return nil
}
details := make([]string, 0, len(errs))
for _, item := range errs {
key := aws.ToString(item.Key)
code := aws.ToString(item.Code)
msg := aws.ToString(item.Message)
switch {
case code != "" && msg != "":
details = append(details, key+": "+code+" "+msg)
case code != "":
details = append(details, key+": "+code)
case msg != "":
details = append(details, key+": "+msg)
default:
details = append(details, key)
}
}
return coreerr.E("s3.DeleteAll", "partial delete failed under "+prefix+": "+strings.Join(details, "; "), nil)
}
// Option configures a Medium. // Option configures a Medium.
type Option func(*Medium) type Option func(*Medium)
@ -197,10 +220,13 @@ func (m *Medium) DeleteAll(p string) error {
} }
// First, try deleting the exact key // First, try deleting the exact key
_, _ = m.client.DeleteObject(context.Background(), &s3.DeleteObjectInput{ _, err := m.client.DeleteObject(context.Background(), &s3.DeleteObjectInput{
Bucket: aws.String(m.bucket), Bucket: aws.String(m.bucket),
Key: aws.String(key), Key: aws.String(key),
}) })
if err != nil {
return coreerr.E("s3.DeleteAll", "failed to delete object: "+key, err)
}
// Then delete all objects under the prefix // Then delete all objects under the prefix
prefix := key prefix := key
@ -230,13 +256,16 @@ func (m *Medium) DeleteAll(p string) error {
objects[i] = types.ObjectIdentifier{Key: obj.Key} objects[i] = types.ObjectIdentifier{Key: obj.Key}
} }
_, err = m.client.DeleteObjects(context.Background(), &s3.DeleteObjectsInput{ deleteOut, err := m.client.DeleteObjects(context.Background(), &s3.DeleteObjectsInput{
Bucket: aws.String(m.bucket), Bucket: aws.String(m.bucket),
Delete: &types.Delete{Objects: objects, Quiet: aws.Bool(true)}, Delete: &types.Delete{Objects: objects, Quiet: aws.Bool(true)},
}) })
if err != nil { if err != nil {
return coreerr.E("s3.DeleteAll", "failed to delete objects", err) return coreerr.E("s3.DeleteAll", "failed to delete objects", err)
} }
if err := deleteObjectsError(prefix, deleteOut.Errors); err != nil {
return err
}
if listOut.IsTruncated != nil && *listOut.IsTruncated { if listOut.IsTruncated != nil && *listOut.IsTruncated {
continuationToken = listOut.NextContinuationToken continuationToken = listOut.NextContinuationToken

View file

@ -3,6 +3,7 @@ package s3
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
goio "io" goio "io"
"io/fs" "io/fs"
@ -21,15 +22,19 @@ import (
// mockS3 is an in-memory mock implementing the s3API interface. // mockS3 is an in-memory mock implementing the s3API interface.
type mockS3 struct { type mockS3 struct {
mu sync.RWMutex mu sync.RWMutex
objects map[string][]byte objects map[string][]byte
mtimes map[string]time.Time mtimes map[string]time.Time
deleteObjectErrors map[string]error
deleteObjectsErrs map[string]types.Error
} }
func newMockS3() *mockS3 { func newMockS3() *mockS3 {
return &mockS3{ return &mockS3{
objects: make(map[string][]byte), objects: make(map[string][]byte),
mtimes: make(map[string]time.Time), mtimes: make(map[string]time.Time),
deleteObjectErrors: make(map[string]error),
deleteObjectsErrs: make(map[string]types.Error),
} }
} }
@ -69,6 +74,9 @@ func (m *mockS3) DeleteObject(_ context.Context, params *s3.DeleteObjectInput, _
defer m.mu.Unlock() defer m.mu.Unlock()
key := aws.ToString(params.Key) key := aws.ToString(params.Key)
if err, ok := m.deleteObjectErrors[key]; ok {
return nil, err
}
delete(m.objects, key) delete(m.objects, key)
delete(m.mtimes, key) delete(m.mtimes, key)
return &s3.DeleteObjectOutput{}, nil return &s3.DeleteObjectOutput{}, nil
@ -78,12 +86,17 @@ func (m *mockS3) DeleteObjects(_ context.Context, params *s3.DeleteObjectsInput,
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
var outErrs []types.Error
for _, obj := range params.Delete.Objects { for _, obj := range params.Delete.Objects {
key := aws.ToString(obj.Key) key := aws.ToString(obj.Key)
if errInfo, ok := m.deleteObjectsErrs[key]; ok {
outErrs = append(outErrs, errInfo)
continue
}
delete(m.objects, key) delete(m.objects, key)
delete(m.mtimes, key) delete(m.mtimes, key)
} }
return &s3.DeleteObjectsOutput{}, nil return &s3.DeleteObjectsOutput{Errors: outErrs}, nil
} }
func (m *mockS3) HeadObject(_ context.Context, params *s3.HeadObjectInput, _ ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { func (m *mockS3) HeadObject(_ context.Context, params *s3.HeadObjectInput, _ ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
@ -350,6 +363,34 @@ func TestDeleteAll_Bad_EmptyPath(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} }
func TestDeleteAll_Bad_DeleteObjectError(t *testing.T) {
m, mock := newTestMedium(t)
mock.deleteObjectErrors["dir"] = errors.New("boom")
err := m.DeleteAll("dir")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to delete object: dir")
}
func TestDeleteAll_Bad_PartialDelete(t *testing.T) {
m, mock := newTestMedium(t)
require.NoError(t, m.Write("dir/file1.txt", "a"))
require.NoError(t, m.Write("dir/file2.txt", "b"))
mock.deleteObjectsErrs["dir/file2.txt"] = types.Error{
Key: aws.String("dir/file2.txt"),
Code: aws.String("AccessDenied"),
Message: aws.String("blocked"),
}
err := m.DeleteAll("dir")
require.Error(t, err)
assert.Contains(t, err.Error(), "partial delete failed")
assert.Contains(t, err.Error(), "dir/file2.txt")
assert.True(t, m.IsFile("dir/file2.txt"))
assert.False(t, m.IsFile("dir/file1.txt"))
}
func TestRename_Good(t *testing.T) { func TestRename_Good(t *testing.T) {
m, _ := newTestMedium(t) m, _ := newTestMedium(t)

View file

@ -4,7 +4,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"os" "os"
"path/filepath" "strings"
"sync" "sync"
core "dappco.re/go/core" core "dappco.re/go/core"
@ -39,11 +39,11 @@ type Service struct {
// New creates a new Workspace service instance. // New creates a new Workspace service instance.
// An optional cryptProvider can be passed to supply PGP key generation. // An optional cryptProvider can be passed to supply PGP key generation.
func New(c *core.Core, crypt ...cryptProvider) (any, error) { func New(c *core.Core, crypt ...cryptProvider) (any, error) {
home, err := os.UserHomeDir() home := workspaceHome()
if err != nil { if home == "" {
return nil, coreerr.E("workspace.New", "failed to determine home directory", err) return nil, coreerr.E("workspace.New", "failed to determine home directory", os.ErrNotExist)
} }
rootPath := filepath.Join(home, ".core", "workspaces") rootPath := core.Path(home, ".core", "workspaces")
s := &Service{ s := &Service{
core: c, core: c,
@ -75,14 +75,17 @@ func (s *Service) CreateWorkspace(identifier, password string) (string, error) {
hash := sha256.Sum256([]byte(identifier)) hash := sha256.Sum256([]byte(identifier))
wsID := hex.EncodeToString(hash[:]) wsID := hex.EncodeToString(hash[:])
wsPath := filepath.Join(s.rootPath, wsID) wsPath, err := s.workspacePath("workspace.CreateWorkspace", wsID)
if err != nil {
return "", err
}
if s.medium.Exists(wsPath) { if s.medium.Exists(wsPath) {
return "", coreerr.E("workspace.CreateWorkspace", "workspace already exists", nil) return "", coreerr.E("workspace.CreateWorkspace", "workspace already exists", nil)
} }
for _, d := range []string{"config", "log", "data", "files", "keys"} { for _, d := range []string{"config", "log", "data", "files", "keys"} {
if err := s.medium.EnsureDir(filepath.Join(wsPath, d)); err != nil { if err := s.medium.EnsureDir(core.Path(wsPath, d)); err != nil {
return "", coreerr.E("workspace.CreateWorkspace", "failed to create directory: "+d, err) return "", coreerr.E("workspace.CreateWorkspace", "failed to create directory: "+d, err)
} }
} }
@ -92,7 +95,7 @@ func (s *Service) CreateWorkspace(identifier, password string) (string, error) {
return "", coreerr.E("workspace.CreateWorkspace", "failed to generate keys", err) return "", coreerr.E("workspace.CreateWorkspace", "failed to generate keys", err)
} }
if err := s.medium.WriteMode(filepath.Join(wsPath, "keys", "private.key"), privKey, 0600); err != nil { if err := s.medium.WriteMode(core.Path(wsPath, "keys", "private.key"), privKey, 0600); err != nil {
return "", coreerr.E("workspace.CreateWorkspace", "failed to save private key", err) return "", coreerr.E("workspace.CreateWorkspace", "failed to save private key", err)
} }
@ -104,12 +107,15 @@ func (s *Service) SwitchWorkspace(name string) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
wsPath := filepath.Join(s.rootPath, name) wsPath, err := s.workspacePath("workspace.SwitchWorkspace", name)
if err != nil {
return err
}
if !s.medium.IsDir(wsPath) { if !s.medium.IsDir(wsPath) {
return coreerr.E("workspace.SwitchWorkspace", "workspace not found: "+name, nil) return coreerr.E("workspace.SwitchWorkspace", "workspace not found: "+name, nil)
} }
s.activeWorkspace = name s.activeWorkspace = core.PathBase(wsPath)
return nil return nil
} }
@ -119,7 +125,15 @@ func (s *Service) activeFilePath(op, filename string) (string, error) {
if s.activeWorkspace == "" { if s.activeWorkspace == "" {
return "", coreerr.E(op, "no active workspace", nil) return "", coreerr.E(op, "no active workspace", nil)
} }
return filepath.Join(s.rootPath, s.activeWorkspace, "files", filename), nil filesRoot := core.Path(s.rootPath, s.activeWorkspace, "files")
path, err := joinWithinRoot(filesRoot, filename)
if err != nil {
return "", coreerr.E(op, "file path escapes workspace files", os.ErrPermission)
}
if path == filesRoot {
return "", coreerr.E(op, "filename is required", os.ErrInvalid)
}
return path, nil
} }
// WorkspaceFileGet retrieves the content of a file from the active workspace. // WorkspaceFileGet retrieves the content of a file from the active workspace.
@ -171,5 +185,38 @@ func (s *Service) HandleIPCEvents(c *core.Core, msg core.Message) core.Result {
return core.Result{OK: true} return core.Result{OK: true}
} }
func workspaceHome() string {
if home := core.Env("CORE_HOME"); home != "" {
return home
}
if home := core.Env("HOME"); home != "" {
return home
}
return core.Env("DIR_HOME")
}
func joinWithinRoot(root string, parts ...string) (string, error) {
candidate := core.Path(append([]string{root}, parts...)...)
sep := core.Env("DS")
if candidate == root || strings.HasPrefix(candidate, root+sep) {
return candidate, nil
}
return "", os.ErrPermission
}
func (s *Service) workspacePath(op, name string) (string, error) {
if name == "" {
return "", coreerr.E(op, "workspace name is required", os.ErrInvalid)
}
path, err := joinWithinRoot(s.rootPath, name)
if err != nil {
return "", coreerr.E(op, "workspace path escapes root", err)
}
if core.PathDir(path) != s.rootPath {
return "", coreerr.E(op, "invalid workspace name: "+name, os.ErrPermission)
}
return path, nil
}
// Ensure Service implements Workspace. // Ensure Service implements Workspace.
var _ Workspace = (*Service)(nil) var _ Workspace = (*Service)(nil)

View file

@ -1,48 +1,90 @@
package workspace package workspace
import ( import (
"path/filepath" "os"
"testing" "testing"
core "dappco.re/go/core" core "dappco.re/go/core"
"forge.lthn.ai/core/go-crypt/crypt/openpgp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestWorkspace(t *testing.T) { type stubCrypt struct {
c := core.New() key string
pgpSvc, err := openpgp.New(nil) err error
assert.NoError(t, err) }
func (s stubCrypt) CreateKeyPair(_, _ string) (string, error) {
if s.err != nil {
return "", s.err
}
return s.key, nil
}
func newTestService(t *testing.T) (*Service, string) {
t.Helper()
tempHome := t.TempDir() tempHome := t.TempDir()
t.Setenv("HOME", tempHome) t.Setenv("HOME", tempHome)
svc, err := New(c, pgpSvc.(cryptProvider)) svc, err := New(core.New(), stubCrypt{key: "private-key"})
assert.NoError(t, err) require.NoError(t, err)
s := svc.(*Service) return svc.(*Service), tempHome
}
func TestWorkspace(t *testing.T) {
s, tempHome := newTestService(t)
// Test CreateWorkspace
id, err := s.CreateWorkspace("test-user", "pass123") id, err := s.CreateWorkspace("test-user", "pass123")
assert.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, id) assert.NotEmpty(t, id)
wsPath := filepath.Join(tempHome, ".core", "workspaces", id) wsPath := core.Path(tempHome, ".core", "workspaces", id)
assert.DirExists(t, wsPath) assert.DirExists(t, wsPath)
assert.DirExists(t, filepath.Join(wsPath, "keys")) assert.DirExists(t, core.Path(wsPath, "keys"))
assert.FileExists(t, filepath.Join(wsPath, "keys", "private.key")) assert.FileExists(t, core.Path(wsPath, "keys", "private.key"))
// Test SwitchWorkspace
err = s.SwitchWorkspace(id) err = s.SwitchWorkspace(id)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, id, s.activeWorkspace) assert.Equal(t, id, s.activeWorkspace)
// Test File operations err = s.WorkspaceFileSet("secret.txt", "top secret info")
filename := "secret.txt" require.NoError(t, err)
content := "top secret info"
err = s.WorkspaceFileSet(filename, content)
assert.NoError(t, err)
got, err := s.WorkspaceFileGet(filename) got, err := s.WorkspaceFileGet("secret.txt")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, content, got) assert.Equal(t, "top secret info", got)
}
func TestSwitchWorkspace_TraversalBlocked(t *testing.T) {
s, tempHome := newTestService(t)
outside := core.Path(tempHome, ".core", "escaped")
require.NoError(t, os.MkdirAll(outside, 0755))
err := s.SwitchWorkspace("../escaped")
require.Error(t, err)
assert.Empty(t, s.activeWorkspace)
}
func TestWorkspaceFileSet_TraversalBlocked(t *testing.T) {
s, tempHome := newTestService(t)
id, err := s.CreateWorkspace("test-user", "pass123")
require.NoError(t, err)
require.NoError(t, s.SwitchWorkspace(id))
keyPath := core.Path(tempHome, ".core", "workspaces", id, "keys", "private.key")
before, err := os.ReadFile(keyPath)
require.NoError(t, err)
err = s.WorkspaceFileSet("../keys/private.key", "hijack")
require.Error(t, err)
after, err := os.ReadFile(keyPath)
require.NoError(t, err)
assert.Equal(t, string(before), string(after))
_, err = s.WorkspaceFileGet("../keys/private.key")
require.Error(t, err)
} }