From 8321e16969f973c0f2dfe41f6a48e3035fc00c41 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 19:18:52 +0000 Subject: [PATCH] fix(ansible): apply task result conditions Co-authored-by: Virgil --- executor.go | 260 +++++++++++++++++++++++++++++++++++++++++++---- executor_test.go | 46 +++++++++ 2 files changed, 287 insertions(+), 19 deletions(-) diff --git a/executor.go b/executor.go index 85bc9f0..f5d84a5 100644 --- a/executor.go +++ b/executor.go @@ -302,6 +302,7 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p result = &TaskResult{Failed: true, Msg: err.Error()} } result.Duration = time.Since(start) + e.applyTaskResultConditions(host, task, result) // Store result if task.Register != "" { @@ -358,6 +359,7 @@ func (e *Executor) runLoop(ctx context.Context, host string, client *SSHClient, if err != nil { result = &TaskResult{Failed: true, Msg: err.Error()} } + e.applyTaskResultConditions(host, task, result) results = append(results, *result) if result.Failed && !task.IgnoreErrors { @@ -611,11 +613,15 @@ func (e *Executor) gatherFacts(ctx context.Context, host string, play *Play) err // evaluateWhen evaluates a when condition. func (e *Executor) evaluateWhen(when any, host string, task *Task) bool { + return e.evaluateWhenWithLocals(when, host, task, nil) +} + +func (e *Executor) evaluateWhenWithLocals(when any, host string, task *Task, locals map[string]any) bool { conditions := normalizeConditions(when) for _, cond := range conditions { cond = e.templateString(cond, host, task) - if !e.evalCondition(cond, host) { + if !e.evalConditionWithLocals(cond, host, task, locals) { return false } } @@ -643,11 +649,39 @@ func normalizeConditions(when any) []string { // evalCondition evaluates a single condition. func (e *Executor) evalCondition(cond string, host string) bool { + return e.evalConditionWithLocals(cond, host, nil, nil) +} + +func (e *Executor) evalConditionWithLocals(cond string, host string, task *Task, locals map[string]any) bool { cond = corexTrimSpace(cond) // Handle negation if corexHasPrefix(cond, "not ") { - return !e.evalCondition(corexTrimPrefix(cond, "not "), host) + return !e.evalConditionWithLocals(corexTrimPrefix(cond, "not "), host, task, locals) + } + + // Handle equality/inequality + if contains(cond, "==") { + parts := splitN(cond, "==", 2) + if len(parts) == 2 { + left, leftOK := e.resolveConditionOperand(parts[0], host, task, locals) + right, rightOK := e.resolveConditionOperand(parts[1], host, task, locals) + if !leftOK || !rightOK { + return true + } + return left == right + } + } + if contains(cond, "!=") { + parts := splitN(cond, "!=", 2) + if len(parts) == 2 { + left, leftOK := e.resolveConditionOperand(parts[0], host, task, locals) + right, rightOK := e.resolveConditionOperand(parts[1], host, task, locals) + if !leftOK || !rightOK { + return true + } + return left != right + } } // Handle boolean literals @@ -665,25 +699,46 @@ func (e *Executor) evalCondition(cond string, host string) bool { varName := corexTrimSpace(parts[0]) check := corexTrimSpace(parts[1]) - result := e.getRegisteredVar(host, varName) - if result == nil { - return check == "not defined" || check == "undefined" + if result, ok := e.lookupConditionValue(varName, host, task, locals); ok { + switch v := result.(type) { + case *TaskResult: + switch check { + case "defined": + return true + case "not defined", "undefined": + return false + case "success", "succeeded": + return !v.Failed + case "failed": + return v.Failed + case "changed": + return v.Changed + case "skipped": + return v.Skipped + } + case TaskResult: + switch check { + case "defined": + return true + case "not defined", "undefined": + return false + case "success", "succeeded": + return !v.Failed + case "failed": + return v.Failed + case "changed": + return v.Changed + case "skipped": + return v.Skipped + } + } + return true } - switch check { - case "defined": + if check == "not defined" || check == "undefined" { return true - case "not defined", "undefined": - return false - case "success", "succeeded": - return !result.Failed - case "failed": - return result.Failed - case "changed": - return result.Changed - case "skipped": - return result.Skipped } + return false } // Handle simple var checks @@ -697,8 +752,23 @@ func (e *Executor) evalCondition(cond string, host string) bool { } // Check if it's a variable that should be truthy - if result := e.getRegisteredVar(host, cond); result != nil { - return !result.Failed && !result.Skipped + if result, ok := e.lookupConditionValue(cond, host, task, locals); ok { + switch v := result.(type) { + case *TaskResult: + return !v.Failed && !v.Skipped + case TaskResult: + return !v.Failed && !v.Skipped + case bool: + return v + case string: + return v != "" && v != "false" && v != "False" + case int: + return v != 0 + case int64: + return v != 0 + case float64: + return v != 0 + } } // Check vars @@ -717,6 +787,158 @@ func (e *Executor) evalCondition(cond string, host string) bool { return true } +func (e *Executor) lookupConditionValue(name string, host string, task *Task, locals map[string]any) (any, bool) { + name = corexTrimSpace(name) + + if locals != nil { + if val, ok := locals[name]; ok { + return val, true + } + + parts := splitN(name, ".", 2) + if len(parts) == 2 { + if base, ok := locals[parts[0]]; ok { + if value, ok := taskResultField(base, parts[1]); ok { + return value, true + } + } + } + } + + if result := e.getRegisteredVar(host, name); result != nil { + if len(splitN(name, ".", 2)) == 2 { + parts := splitN(name, ".", 2) + if value, ok := taskResultField(result, parts[1]); ok { + return value, true + } + } + return result, true + } + + if val, ok := e.vars[name]; ok { + return val, true + } + + if task != nil { + if val, ok := task.Vars[name]; ok { + return val, true + } + } + + if e.inventory != nil { + hostVars := GetHostVars(e.inventory, host) + if val, ok := hostVars[name]; ok { + return val, true + } + } + + if facts, ok := e.facts[host]; ok { + switch name { + case "ansible_hostname": + return facts.Hostname, true + case "ansible_fqdn": + return facts.FQDN, true + case "ansible_os_family": + return facts.OS, true + case "ansible_memtotal_mb": + return facts.Memory, true + case "ansible_processor_vcpus": + return facts.CPUs, true + case "ansible_default_ipv4_address": + return facts.IPv4, true + case "ansible_distribution": + return facts.Distribution, true + case "ansible_distribution_version": + return facts.Version, true + case "ansible_architecture": + return facts.Architecture, true + case "ansible_kernel": + return facts.Kernel, true + } + } + + return nil, false +} + +func taskResultField(value any, field string) (any, bool) { + switch v := value.(type) { + case *TaskResult: + return taskResultField(*v, field) + case TaskResult: + switch field { + case "stdout": + return v.Stdout, true + case "stderr": + return v.Stderr, true + case "rc": + return v.RC, true + case "changed": + return v.Changed, true + case "failed": + return v.Failed, true + case "skipped": + return v.Skipped, true + case "msg": + return v.Msg, true + } + case map[string]any: + if val, ok := v[field]; ok { + return val, true + } + } + return nil, false +} + +func (e *Executor) resolveConditionOperand(expr string, host string, task *Task, locals map[string]any) (string, bool) { + expr = corexTrimSpace(expr) + + if expr == "true" || expr == "True" || expr == "false" || expr == "False" { + return expr, true + } + if len(expr) > 0 && expr[0] >= '0' && expr[0] <= '9' { + return expr, true + } + if (len(expr) >= 2 && expr[0] == '\'' && expr[len(expr)-1] == '\'') || (len(expr) >= 2 && expr[0] == '"' && expr[len(expr)-1] == '"') { + return expr[1 : len(expr)-1], true + } + + if value, ok := e.lookupConditionValue(expr, host, task, locals); ok { + return sprintf("%v", value), true + } + + return expr, false +} + +func (e *Executor) applyTaskResultConditions(host string, task *Task, result *TaskResult) { + if result == nil || task == nil { + return + } + + locals := map[string]any{ + "result": result, + "stdout": result.Stdout, + "stderr": result.Stderr, + "rc": result.RC, + "changed": result.Changed, + "failed": result.Failed, + "skipped": result.Skipped, + "msg": result.Msg, + "duration": result.Duration, + } + + if task.ChangedWhen != nil { + result.Changed = e.evaluateWhenWithLocals(task.ChangedWhen, host, task, locals) + locals["changed"] = result.Changed + locals["result"] = result + } + + if task.FailedWhen != nil { + result.Failed = e.evaluateWhenWithLocals(task.FailedWhen, host, task, locals) + locals["failed"] = result.Failed + locals["result"] = result + } +} + // getRegisteredVar gets a registered task result. func (e *Executor) getRegisteredVar(host string, name string) *TaskResult { e.mu.RLock() diff --git a/executor_test.go b/executor_test.go index 0e67bd7..f197d39 100644 --- a/executor_test.go +++ b/executor_test.go @@ -261,6 +261,52 @@ func TestExecutor_EvaluateWhen_Good_MultipleConditions(t *testing.T) { assert.False(t, e.evaluateWhen([]any{"true", "false"}, "host1", nil)) } +func TestExecutor_ApplyTaskResultConditions_Good_ChangedWhen(t *testing.T) { + e := NewExecutor("/tmp") + task := &Task{ + ChangedWhen: "stdout == 'expected'", + } + result := &TaskResult{ + Changed: true, + Stdout: "actual", + } + + e.applyTaskResultConditions("host1", task, result) + + assert.False(t, result.Changed) +} + +func TestExecutor_ApplyTaskResultConditions_Good_FailedWhen(t *testing.T) { + e := NewExecutor("/tmp") + task := &Task{ + FailedWhen: []any{"rc != 0", "stdout == 'expected'"}, + } + result := &TaskResult{ + Failed: true, + Stdout: "expected", + RC: 0, + } + + e.applyTaskResultConditions("host1", task, result) + + assert.False(t, result.Failed) +} + +func TestExecutor_ApplyTaskResultConditions_Good_DottedResultAccess(t *testing.T) { + e := NewExecutor("/tmp") + task := &Task{ + ChangedWhen: "result.rc == 0", + } + result := &TaskResult{ + Changed: false, + RC: 0, + } + + e.applyTaskResultConditions("host1", task, result) + + assert.True(t, result.Changed) +} + // --- templateString --- func TestExecutor_TemplateString_Good_SimpleVar(t *testing.T) {