From e1edbc1f9ba4dfe11493a116631e3e2dc41ce4ab Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 11:34:19 +0000 Subject: [PATCH] fix(cli): make tracker iterators snapshot-safe Co-Authored-By: Virgil --- pkg/cli/tracker.go | 14 ++++++++++---- pkg/cli/tracker_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/pkg/cli/tracker.go b/pkg/cli/tracker.go index b9bf9a5..060fbd6 100644 --- a/pkg/cli/tracker.go +++ b/pkg/cli/tracker.go @@ -89,8 +89,11 @@ type TaskTracker struct { func (tr *TaskTracker) Tasks() iter.Seq[*TrackedTask] { return func(yield func(*TrackedTask) bool) { tr.mu.Lock() - defer tr.mu.Unlock() - for _, t := range tr.tasks { + tasks := make([]*TrackedTask, len(tr.tasks)) + copy(tasks, tr.tasks) + tr.mu.Unlock() + + for _, t := range tasks { if !yield(t) { return } @@ -102,8 +105,11 @@ func (tr *TaskTracker) Tasks() iter.Seq[*TrackedTask] { func (tr *TaskTracker) Snapshots() iter.Seq2[string, string] { return func(yield func(string, string) bool) { tr.mu.Lock() - defer tr.mu.Unlock() - for _, t := range tr.tasks { + tasks := make([]*TrackedTask, len(tr.tasks)) + copy(tasks, tr.tasks) + tr.mu.Unlock() + + for _, t := range tasks { name, status, _ := t.snapshot() if !yield(name, status) { return diff --git a/pkg/cli/tracker_test.go b/pkg/cli/tracker_test.go index 8bbe798..138db76 100644 --- a/pkg/cli/tracker_test.go +++ b/pkg/cli/tracker_test.go @@ -189,6 +189,35 @@ func TestTaskTracker_Good(t *testing.T) { assert.NotContains(t, out, "✓") assert.NotContains(t, out, "✗") }) + + t.Run("iterators tolerate mutation during iteration", func(t *testing.T) { + tr := NewTaskTracker() + tr.out = &bytes.Buffer{} + + tr.Add("first") + tr.Add("second") + + done := make(chan struct{}) + go func() { + defer close(done) + for task := range tr.Tasks() { + task.Update("visited") + } + }() + + require.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, time.Second, 10*time.Millisecond) + + for name, status := range tr.Snapshots() { + assert.Equal(t, "visited", status, name) + } + }) } func TestTaskTracker_Bad(t *testing.T) {