diff --git a/cmd/ansible/cmd_test.go b/cmd/ansible/cmd_test.go index 71fac28..0d55d05 100644 --- a/cmd/ansible/cmd_test.go +++ b/cmd/ansible/cmd_test.go @@ -20,4 +20,3 @@ func TestRegister_AnsibleCommandExposesDiffFlag(t *testing.T) { require.True(t, cmd.Flags.Has("diff")) require.False(t, cmd.Flags.Bool("diff")) } - diff --git a/executor.go b/executor.go index f1ebe83..734adf5 100644 --- a/executor.go +++ b/executor.go @@ -347,8 +347,8 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p return e.runLoop(ctx, host, client, task, play) } - // Execute the task - result, err := e.executeModule(ctx, host, client, task, play) + // Execute the task, retrying when an until condition is present. + result, err := e.runTaskWithUntil(ctx, host, client, task, play) if err != nil { result = &TaskResult{Failed: true, Msg: err.Error()} } @@ -375,6 +375,139 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p return nil } +// runTaskWithUntil executes a task once or retries it until the until +// condition evaluates to true. 
+func (e *Executor) runTaskWithUntil(ctx context.Context, host string, client *SSHClient, task *Task, play *Play) (*TaskResult, error) { + if task.Until == "" { + return e.executeModule(ctx, host, client, task, play) + } + + retries := task.Retries + if retries <= 0 { + retries = 3 + } + + delay := task.Delay + if delay <= 0 { + delay = 1 + } + + restoreAlias := task.Register != "result" + var ( + previousAlias *TaskResult + hadAlias bool + ) + if restoreAlias { + e.mu.RLock() + if hostResults, ok := e.results[host]; ok { + previousAlias, hadAlias = hostResults["result"] + } + e.mu.RUnlock() + defer func() { + e.mu.Lock() + defer e.mu.Unlock() + if e.results[host] == nil { + e.results[host] = make(map[string]*TaskResult) + } + if hadAlias { + e.results[host]["result"] = previousAlias + } else { + delete(e.results[host], "result") + } + }() + } + + lastResult, lastErr := retryTask(ctx, retries, delay, func() (*TaskResult, error) { + result, err := e.executeModule(ctx, host, client, task, play) + if err != nil { + result = &TaskResult{Failed: true, Msg: err.Error()} + } + if result == nil { + result = &TaskResult{} + } + e.setTempResult(host, task.Register, result) + return result, nil + }, func(result *TaskResult) bool { + return e.evaluateWhen(task.Until, host, task) + }) + + if lastErr != nil { + return lastResult, lastErr + } + + if lastResult == nil { + lastResult = &TaskResult{} + } + lastResult.Failed = true + if lastResult.Msg == "" { + lastResult.Msg = fmt.Sprintf("until condition not met: %s", task.Until) + } + return lastResult, nil +} + +// setTempResult exposes the latest task result for until evaluation without +// leaving stale state behind when a different register name is used. 
+func (e *Executor) setTempResult(host string, register string, result *TaskResult) { + e.mu.Lock() + defer e.mu.Unlock() + + if e.results[host] == nil { + e.results[host] = make(map[string]*TaskResult) + } + e.results[host]["result"] = result + + if register != "" { + e.results[host][register] = result + } +} + +// retryTask runs fn until done returns true or the retry budget is exhausted. +func retryTask(ctx context.Context, retries, delay int, fn func() (*TaskResult, error), done func(*TaskResult) bool) (*TaskResult, error) { + if retries < 0 { + retries = 0 + } + if delay < 0 { + delay = 0 + } + + var lastResult *TaskResult + var lastErr error + + for attempt := 0; attempt <= retries; attempt++ { + lastResult, lastErr = fn() + if lastErr != nil { + return lastResult, lastErr + } + if lastResult == nil { + lastResult = &TaskResult{} + } + if done == nil || done(lastResult) { + return lastResult, nil + } + if attempt < retries && delay > 0 { + if err := sleepWithContext(ctx, time.Duration(delay)*time.Second); err != nil { + return lastResult, err + } + } + } + + return lastResult, nil +} + +// sleepWithContext pauses for the requested duration or stops early when the +// context is cancelled. +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + // runLoop handles task loops. 
func (e *Executor) runLoop(ctx context.Context, host string, client *SSHClient, task *Task, play *Play) error { items := e.resolveLoop(task.Loop, host) diff --git a/executor_test.go b/executor_test.go index 002af3f..2d79235 100644 --- a/executor_test.go +++ b/executor_test.go @@ -157,6 +157,37 @@ func TestExecutor_RunTaskOnHost_Good_LoopControlPause(t *testing.T) { assert.GreaterOrEqual(t, time.Since(start), 900*time.Millisecond) } +func TestExecutor_SetTempResult_Good_ResultAliasSupportsUntil(t *testing.T) { + e := NewExecutor("/tmp") + + e.setTempResult("host1", "", &TaskResult{Failed: true, RC: 1}) + assert.False(t, e.evaluateWhen("result is success", "host1", &Task{})) + + e.setTempResult("host1", "", &TaskResult{Failed: false, RC: 0}) + assert.True(t, e.evaluateWhen("result is success", "host1", &Task{})) +} + +func TestRetryTask_Good_RetriesAndWaits(t *testing.T) { + attempts := 0 + start := time.Now() + + result, err := retryTask(context.Background(), 1, 1, func() (*TaskResult, error) { + attempts++ + if attempts == 1 { + return &TaskResult{Failed: true, RC: 1}, nil + } + return &TaskResult{Failed: false, RC: 0}, nil + }, func(result *TaskResult) bool { + return result.RC == 0 + }) + + require.NoError(t, err) + assert.Equal(t, 2, attempts) + assert.NotNil(t, result) + assert.False(t, result.Failed) + assert.GreaterOrEqual(t, time.Since(start), 900*time.Millisecond) +} + // --- matchesTags --- func TestExecutor_MatchesTags_Good_NoTagsFilter(t *testing.T) {