diff --git a/executor.go b/executor.go index 29ae00f..c97e5f7 100644 --- a/executor.go +++ b/executor.go @@ -233,20 +233,47 @@ func (e *Executor) runTaskOnHosts(ctx context.Context, hosts []string, task *Tas return nil } + targetHosts := selectTaskHosts(hosts, task) + // Handle block tasks if len(task.Block) > 0 { - return e.runBlock(ctx, hosts, task, play) + return e.runBlock(ctx, targetHosts, task, play) } // Handle include/import if task.IncludeTasks != "" || task.ImportTasks != "" { - return e.runIncludeTasks(ctx, hosts, task, play) + return e.runIncludeTasks(ctx, targetHosts, task, play) } if task.IncludeRole != nil || task.ImportRole != nil { - return e.runIncludeRole(ctx, hosts, task, play) + return e.runIncludeRole(ctx, targetHosts, task, play) } - for _, host := range hosts { + if task.RunOnce { + if len(targetHosts) == 0 { + return nil + } + host := targetHosts[0] + if err := e.runTaskOnHost(ctx, host, task, play); err != nil { + return err + } + + if task.Register != "" && len(hosts) > 1 { + if hostResults, ok := e.results[host]; ok { + if result, ok := hostResults[task.Register]; ok { + for _, otherHost := range hosts[1:] { + if e.results[otherHost] == nil { + e.results[otherHost] = make(map[string]*TaskResult) + } + e.results[otherHost][task.Register] = result + } + } + } + } + + return nil + } + + for _, host := range targetHosts { if err := e.runTaskOnHost(ctx, host, task, play); err != nil { if !task.IgnoreErrors { return err @@ -257,6 +284,13 @@ func (e *Executor) runTaskOnHosts(ctx context.Context, hosts []string, task *Tas return nil } +func selectTaskHosts(hosts []string, task *Task) []string { + if task != nil && task.RunOnce && len(hosts) > 0 { + return hosts[:1] + } + return hosts +} + // runTaskOnHost runs a task on a single host. func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, play *Play) error { start := time.Now() diff --git a/executor_test.go b/executor_test.go index 0e67bd7..a027838 100644 --- a/executor_test.go +++ b/executor_test.go @@ -1,6 +1,7 @@ package ansible import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -98,6 +99,41 @@ func TestExecutor_GetHosts_Good_WithLimit(t *testing.T) { assert.Contains(t, hosts, "host2") } +// --- selectTaskHosts --- + +func TestExecutor_SelectTaskHosts_Good_RunOnce(t *testing.T) { + hosts := selectTaskHosts([]string{"host1", "host2", "host3"}, &Task{RunOnce: true}) + assert.Equal(t, []string{"host1"}, hosts) +} + +func TestExecutor_SelectTaskHosts_Good_NormalTask(t *testing.T) { + hosts := selectTaskHosts([]string{"host1", "host2"}, &Task{}) + assert.Equal(t, []string{"host1", "host2"}, hosts) +} + +// --- runTaskOnHosts --- + +func TestExecutor_RunTaskOnHosts_Good_RunOncePropagatesRegisteredResult(t *testing.T) { + e := NewExecutor("/tmp") + task := &Task{ + Name: "Skip once", + RunOnce: true, + Register: "result", + When: "false", + } + play := &Play{} + + err := e.runTaskOnHosts(context.Background(), []string{"host1", "host2"}, task, play) + + assert.NoError(t, err) + assert.Contains(t, e.results, "host1") + assert.Contains(t, e.results, "host2") + assert.Contains(t, e.results["host1"], "result") + assert.Contains(t, e.results["host2"], "result") + assert.True(t, e.results["host1"]["result"].Skipped) + assert.True(t, e.results["host2"]["result"].Skipped) +} + // --- matchesTags --- func TestExecutor_MatchesTags_Good_NoTagsFilter(t *testing.T) {