diff --git a/pkg/cli/daemon.go b/pkg/cli/daemon.go index e43df9f..90b2fd2 100644 --- a/pkg/cli/daemon.go +++ b/pkg/cli/daemon.go @@ -74,13 +74,14 @@ func IsStderrTTY() bool { // PIDFile manages a process ID file for single-instance enforcement. type PIDFile struct { - path string - mu sync.Mutex + medium io.Medium + path string + mu sync.Mutex } // NewPIDFile creates a PID file manager. -func NewPIDFile(path string) *PIDFile { - return &PIDFile{path: path} +func NewPIDFile(m io.Medium, path string) *PIDFile { + return &PIDFile{medium: m, path: path} } // Acquire writes the current PID to the file. @@ -90,7 +91,7 @@ func (p *PIDFile) Acquire() error { defer p.mu.Unlock() // Check if PID file exists - if data, err := io.Local.Read(p.path); err == nil { + if data, err := p.medium.Read(p.path); err == nil { pid, err := strconv.Atoi(data) if err == nil && pid > 0 { // Check if process is still running @@ -101,19 +102,19 @@ func (p *PIDFile) Acquire() error { } } // Stale PID file, remove it - _ = io.Local.Delete(p.path) + _ = p.medium.Delete(p.path) } // Ensure directory exists if dir := filepath.Dir(p.path); dir != "." { - if err := io.Local.EnsureDir(dir); err != nil { + if err := p.medium.EnsureDir(dir); err != nil { return fmt.Errorf("failed to create PID directory: %w", err) } } // Write current PID pid := os.Getpid() - if err := io.Local.Write(p.path, strconv.Itoa(pid)); err != nil { + if err := p.medium.Write(p.path, strconv.Itoa(pid)); err != nil { return fmt.Errorf("failed to write PID file: %w", err) } @@ -124,7 +125,7 @@ func (p *PIDFile) Acquire() error { func (p *PIDFile) Release() error { p.mu.Lock() defer p.mu.Unlock() - return io.Local.Delete(p.path) + return p.medium.Delete(p.path) } // Path returns the PID file path. @@ -246,6 +247,9 @@ func (h *HealthServer) Addr() string { // DaemonOptions configures daemon mode execution. type DaemonOptions struct { + // Medium is the filesystem abstraction. + Medium io.Medium + // PIDFile path for single-instance enforcement. // Leave empty to skip PID file management. PIDFile string @@ -283,13 +287,17 @@ func NewDaemon(opts DaemonOptions) *Daemon { opts.ShutdownTimeout = 30 * time.Second } + if opts.Medium == nil { + opts.Medium = io.Local + } + d := &Daemon{ opts: opts, reload: make(chan struct{}, 1), } if opts.PIDFile != "" { - d.pid = NewPIDFile(opts.PIDFile) + d.pid = NewPIDFile(opts.Medium, opts.PIDFile) } if opts.HealthAddr != "" { diff --git a/pkg/cli/daemon_test.go b/pkg/cli/daemon_test.go index 5eb5132..d128b5e 100644 --- a/pkg/cli/daemon_test.go +++ b/pkg/cli/daemon_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -31,7 +32,7 @@ func TestPIDFile(t *testing.T) { tmpDir := t.TempDir() pidPath := filepath.Join(tmpDir, "test.pid") - pid := NewPIDFile(pidPath) + pid := NewPIDFile(io.Local, pidPath) // Acquire should succeed err := pid.Acquire() @@ -58,7 +59,7 @@ func TestPIDFile(t *testing.T) { err := os.WriteFile(pidPath, []byte("999999999"), 0644) require.NoError(t, err) - pid := NewPIDFile(pidPath) + pid := NewPIDFile(io.Local, pidPath) // Should acquire successfully (stale PID removed) err = pid.Acquire() @@ -72,7 +73,7 @@ func TestPIDFile(t *testing.T) { tmpDir := t.TempDir() pidPath := filepath.Join(tmpDir, "subdir", "nested", "test.pid") - pid := NewPIDFile(pidPath) + pid := NewPIDFile(io.Local, pidPath) err := pid.Acquire() require.NoError(t, err) @@ -85,9 +86,26 @@ func TestPIDFile(t *testing.T) { }) t.Run("path getter", func(t *testing.T) { - pid := NewPIDFile("/tmp/test.pid") + pid := NewPIDFile(io.Local, "/tmp/test.pid") assert.Equal(t, "/tmp/test.pid", pid.Path()) }) + + t.Run("with mock medium", func(t *testing.T) { + mock := io.NewMockMedium() + pidPath := "/tmp/mock.pid" + pid := NewPIDFile(mock, pidPath) + + err := pid.Acquire() + require.NoError(t, err) + + assert.True(t, mock.Exists(pidPath)) + data, _ := mock.Read(pidPath) + assert.NotEmpty(t, data) + + err = pid.Release() + require.NoError(t, err) + assert.False(t, mock.Exists(pidPath)) + }) } func TestHealthServer(t *testing.T) { @@ -244,6 +262,26 @@ func TestDaemon(t *testing.T) { d := NewDaemon(DaemonOptions{}) assert.Equal(t, 30*time.Second, d.opts.ShutdownTimeout) }) + + t.Run("with mock medium", func(t *testing.T) { + mock := io.NewMockMedium() + pidPath := "/tmp/daemon.pid" + + d := NewDaemon(DaemonOptions{ + Medium: mock, + PIDFile: pidPath, + HealthAddr: "127.0.0.1:0", + }) + + err := d.Start() + require.NoError(t, err) + + assert.True(t, mock.Exists(pidPath)) + + err = d.Stop() + require.NoError(t, err) + assert.False(t, mock.Exists(pidPath)) + }) } func TestRunWithTimeout(t *testing.T) { diff --git a/pkg/container/linuxkit.go b/pkg/container/linuxkit.go index 2f2780a..252b864 100644 --- a/pkg/container/linuxkit.go +++ b/pkg/container/linuxkit.go @@ -52,6 +52,10 @@ func NewLinuxKitManagerWithHypervisor(state *State, hypervisor Hypervisor) *Linu // Run starts a new LinuxKit VM from the given image. func (m *LinuxKitManager) Run(ctx context.Context, image string, opts RunOptions) (*Container, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + // Validate image exists if !io.Local.IsFile(image) { return nil, fmt.Errorf("image not found: %s", image) @@ -232,6 +236,10 @@ func (m *LinuxKitManager) waitForExit(id string, cmd *exec.Cmd) { // Stop stops a running container by sending SIGTERM. func (m *LinuxKitManager) Stop(ctx context.Context, id string) error { + if err := ctx.Err(); err != nil { + return err + } + container, ok := m.state.Get(id) if !ok { return fmt.Errorf("container not found: %s", id) @@ -290,6 +298,10 @@ func (m *LinuxKitManager) Stop(ctx context.Context, id string) error { // List returns all known containers, verifying process state. func (m *LinuxKitManager) List(ctx context.Context) ([]*Container, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + containers := m.state.All() // Verify each running container's process is still alive @@ -319,6 +331,10 @@ func isProcessRunning(pid int) bool { // Logs returns a reader for the container's log output. func (m *LinuxKitManager) Logs(ctx context.Context, id string, follow bool) (goio.ReadCloser, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + _, ok := m.state.Get(id) if !ok { return nil, fmt.Errorf("container not found: %s", id) @@ -403,6 +419,10 @@ func (f *followReader) Close() error { // Exec executes a command inside the container via SSH. func (m *LinuxKitManager) Exec(ctx context.Context, id string, cmd []string) error { + if err := ctx.Err(); err != nil { + return err + } + container, ok := m.state.Get(id) if !ok { return fmt.Errorf("container not found: %s", id)