fix(ansible): apply task result conditions
Co-authored-by: Virgil <virgil@lethean.io>
This commit is contained in:
parent
f27fb19bed
commit
8321e16969
2 changed files with 287 additions and 19 deletions
260
executor.go
260
executor.go
|
|
@ -302,6 +302,7 @@ func (e *Executor) runTaskOnHost(ctx context.Context, host string, task *Task, p
|
|||
result = &TaskResult{Failed: true, Msg: err.Error()}
|
||||
}
|
||||
result.Duration = time.Since(start)
|
||||
e.applyTaskResultConditions(host, task, result)
|
||||
|
||||
// Store result
|
||||
if task.Register != "" {
|
||||
|
|
@ -358,6 +359,7 @@ func (e *Executor) runLoop(ctx context.Context, host string, client *SSHClient,
|
|||
if err != nil {
|
||||
result = &TaskResult{Failed: true, Msg: err.Error()}
|
||||
}
|
||||
e.applyTaskResultConditions(host, task, result)
|
||||
results = append(results, *result)
|
||||
|
||||
if result.Failed && !task.IgnoreErrors {
|
||||
|
|
@ -611,11 +613,15 @@ func (e *Executor) gatherFacts(ctx context.Context, host string, play *Play) err
|
|||
|
||||
// evaluateWhen evaluates a when condition.
|
||||
func (e *Executor) evaluateWhen(when any, host string, task *Task) bool {
|
||||
return e.evaluateWhenWithLocals(when, host, task, nil)
|
||||
}
|
||||
|
||||
func (e *Executor) evaluateWhenWithLocals(when any, host string, task *Task, locals map[string]any) bool {
|
||||
conditions := normalizeConditions(when)
|
||||
|
||||
for _, cond := range conditions {
|
||||
cond = e.templateString(cond, host, task)
|
||||
if !e.evalCondition(cond, host) {
|
||||
if !e.evalConditionWithLocals(cond, host, task, locals) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -643,11 +649,39 @@ func normalizeConditions(when any) []string {
|
|||
|
||||
// evalCondition evaluates a single condition.
|
||||
func (e *Executor) evalCondition(cond string, host string) bool {
|
||||
return e.evalConditionWithLocals(cond, host, nil, nil)
|
||||
}
|
||||
|
||||
func (e *Executor) evalConditionWithLocals(cond string, host string, task *Task, locals map[string]any) bool {
|
||||
cond = corexTrimSpace(cond)
|
||||
|
||||
// Handle negation
|
||||
if corexHasPrefix(cond, "not ") {
|
||||
return !e.evalCondition(corexTrimPrefix(cond, "not "), host)
|
||||
return !e.evalConditionWithLocals(corexTrimPrefix(cond, "not "), host, task, locals)
|
||||
}
|
||||
|
||||
// Handle equality/inequality
|
||||
if contains(cond, "==") {
|
||||
parts := splitN(cond, "==", 2)
|
||||
if len(parts) == 2 {
|
||||
left, leftOK := e.resolveConditionOperand(parts[0], host, task, locals)
|
||||
right, rightOK := e.resolveConditionOperand(parts[1], host, task, locals)
|
||||
if !leftOK || !rightOK {
|
||||
return true
|
||||
}
|
||||
return left == right
|
||||
}
|
||||
}
|
||||
if contains(cond, "!=") {
|
||||
parts := splitN(cond, "!=", 2)
|
||||
if len(parts) == 2 {
|
||||
left, leftOK := e.resolveConditionOperand(parts[0], host, task, locals)
|
||||
right, rightOK := e.resolveConditionOperand(parts[1], host, task, locals)
|
||||
if !leftOK || !rightOK {
|
||||
return true
|
||||
}
|
||||
return left != right
|
||||
}
|
||||
}
|
||||
|
||||
// Handle boolean literals
|
||||
|
|
@ -665,25 +699,46 @@ func (e *Executor) evalCondition(cond string, host string) bool {
|
|||
varName := corexTrimSpace(parts[0])
|
||||
check := corexTrimSpace(parts[1])
|
||||
|
||||
result := e.getRegisteredVar(host, varName)
|
||||
if result == nil {
|
||||
return check == "not defined" || check == "undefined"
|
||||
if result, ok := e.lookupConditionValue(varName, host, task, locals); ok {
|
||||
switch v := result.(type) {
|
||||
case *TaskResult:
|
||||
switch check {
|
||||
case "defined":
|
||||
return true
|
||||
case "not defined", "undefined":
|
||||
return false
|
||||
case "success", "succeeded":
|
||||
return !v.Failed
|
||||
case "failed":
|
||||
return v.Failed
|
||||
case "changed":
|
||||
return v.Changed
|
||||
case "skipped":
|
||||
return v.Skipped
|
||||
}
|
||||
case TaskResult:
|
||||
switch check {
|
||||
case "defined":
|
||||
return true
|
||||
case "not defined", "undefined":
|
||||
return false
|
||||
case "success", "succeeded":
|
||||
return !v.Failed
|
||||
case "failed":
|
||||
return v.Failed
|
||||
case "changed":
|
||||
return v.Changed
|
||||
case "skipped":
|
||||
return v.Skipped
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch check {
|
||||
case "defined":
|
||||
if check == "not defined" || check == "undefined" {
|
||||
return true
|
||||
case "not defined", "undefined":
|
||||
return false
|
||||
case "success", "succeeded":
|
||||
return !result.Failed
|
||||
case "failed":
|
||||
return result.Failed
|
||||
case "changed":
|
||||
return result.Changed
|
||||
case "skipped":
|
||||
return result.Skipped
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle simple var checks
|
||||
|
|
@ -697,8 +752,23 @@ func (e *Executor) evalCondition(cond string, host string) bool {
|
|||
}
|
||||
|
||||
// Check if it's a variable that should be truthy
|
||||
if result := e.getRegisteredVar(host, cond); result != nil {
|
||||
return !result.Failed && !result.Skipped
|
||||
if result, ok := e.lookupConditionValue(cond, host, task, locals); ok {
|
||||
switch v := result.(type) {
|
||||
case *TaskResult:
|
||||
return !v.Failed && !v.Skipped
|
||||
case TaskResult:
|
||||
return !v.Failed && !v.Skipped
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
return v != "" && v != "false" && v != "False"
|
||||
case int:
|
||||
return v != 0
|
||||
case int64:
|
||||
return v != 0
|
||||
case float64:
|
||||
return v != 0
|
||||
}
|
||||
}
|
||||
|
||||
// Check vars
|
||||
|
|
@ -717,6 +787,158 @@ func (e *Executor) evalCondition(cond string, host string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (e *Executor) lookupConditionValue(name string, host string, task *Task, locals map[string]any) (any, bool) {
|
||||
name = corexTrimSpace(name)
|
||||
|
||||
if locals != nil {
|
||||
if val, ok := locals[name]; ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
parts := splitN(name, ".", 2)
|
||||
if len(parts) == 2 {
|
||||
if base, ok := locals[parts[0]]; ok {
|
||||
if value, ok := taskResultField(base, parts[1]); ok {
|
||||
return value, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result := e.getRegisteredVar(host, name); result != nil {
|
||||
if len(splitN(name, ".", 2)) == 2 {
|
||||
parts := splitN(name, ".", 2)
|
||||
if value, ok := taskResultField(result, parts[1]); ok {
|
||||
return value, true
|
||||
}
|
||||
}
|
||||
return result, true
|
||||
}
|
||||
|
||||
if val, ok := e.vars[name]; ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
if task != nil {
|
||||
if val, ok := task.Vars[name]; ok {
|
||||
return val, true
|
||||
}
|
||||
}
|
||||
|
||||
if e.inventory != nil {
|
||||
hostVars := GetHostVars(e.inventory, host)
|
||||
if val, ok := hostVars[name]; ok {
|
||||
return val, true
|
||||
}
|
||||
}
|
||||
|
||||
if facts, ok := e.facts[host]; ok {
|
||||
switch name {
|
||||
case "ansible_hostname":
|
||||
return facts.Hostname, true
|
||||
case "ansible_fqdn":
|
||||
return facts.FQDN, true
|
||||
case "ansible_os_family":
|
||||
return facts.OS, true
|
||||
case "ansible_memtotal_mb":
|
||||
return facts.Memory, true
|
||||
case "ansible_processor_vcpus":
|
||||
return facts.CPUs, true
|
||||
case "ansible_default_ipv4_address":
|
||||
return facts.IPv4, true
|
||||
case "ansible_distribution":
|
||||
return facts.Distribution, true
|
||||
case "ansible_distribution_version":
|
||||
return facts.Version, true
|
||||
case "ansible_architecture":
|
||||
return facts.Architecture, true
|
||||
case "ansible_kernel":
|
||||
return facts.Kernel, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func taskResultField(value any, field string) (any, bool) {
|
||||
switch v := value.(type) {
|
||||
case *TaskResult:
|
||||
return taskResultField(*v, field)
|
||||
case TaskResult:
|
||||
switch field {
|
||||
case "stdout":
|
||||
return v.Stdout, true
|
||||
case "stderr":
|
||||
return v.Stderr, true
|
||||
case "rc":
|
||||
return v.RC, true
|
||||
case "changed":
|
||||
return v.Changed, true
|
||||
case "failed":
|
||||
return v.Failed, true
|
||||
case "skipped":
|
||||
return v.Skipped, true
|
||||
case "msg":
|
||||
return v.Msg, true
|
||||
}
|
||||
case map[string]any:
|
||||
if val, ok := v[field]; ok {
|
||||
return val, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (e *Executor) resolveConditionOperand(expr string, host string, task *Task, locals map[string]any) (string, bool) {
|
||||
expr = corexTrimSpace(expr)
|
||||
|
||||
if expr == "true" || expr == "True" || expr == "false" || expr == "False" {
|
||||
return expr, true
|
||||
}
|
||||
if len(expr) > 0 && expr[0] >= '0' && expr[0] <= '9' {
|
||||
return expr, true
|
||||
}
|
||||
if (len(expr) >= 2 && expr[0] == '\'' && expr[len(expr)-1] == '\'') || (len(expr) >= 2 && expr[0] == '"' && expr[len(expr)-1] == '"') {
|
||||
return expr[1 : len(expr)-1], true
|
||||
}
|
||||
|
||||
if value, ok := e.lookupConditionValue(expr, host, task, locals); ok {
|
||||
return sprintf("%v", value), true
|
||||
}
|
||||
|
||||
return expr, false
|
||||
}
|
||||
|
||||
func (e *Executor) applyTaskResultConditions(host string, task *Task, result *TaskResult) {
|
||||
if result == nil || task == nil {
|
||||
return
|
||||
}
|
||||
|
||||
locals := map[string]any{
|
||||
"result": result,
|
||||
"stdout": result.Stdout,
|
||||
"stderr": result.Stderr,
|
||||
"rc": result.RC,
|
||||
"changed": result.Changed,
|
||||
"failed": result.Failed,
|
||||
"skipped": result.Skipped,
|
||||
"msg": result.Msg,
|
||||
"duration": result.Duration,
|
||||
}
|
||||
|
||||
if task.ChangedWhen != nil {
|
||||
result.Changed = e.evaluateWhenWithLocals(task.ChangedWhen, host, task, locals)
|
||||
locals["changed"] = result.Changed
|
||||
locals["result"] = result
|
||||
}
|
||||
|
||||
if task.FailedWhen != nil {
|
||||
result.Failed = e.evaluateWhenWithLocals(task.FailedWhen, host, task, locals)
|
||||
locals["failed"] = result.Failed
|
||||
locals["result"] = result
|
||||
}
|
||||
}
|
||||
|
||||
// getRegisteredVar gets a registered task result.
|
||||
func (e *Executor) getRegisteredVar(host string, name string) *TaskResult {
|
||||
e.mu.RLock()
|
||||
|
|
|
|||
|
|
@ -261,6 +261,52 @@ func TestExecutor_EvaluateWhen_Good_MultipleConditions(t *testing.T) {
|
|||
assert.False(t, e.evaluateWhen([]any{"true", "false"}, "host1", nil))
|
||||
}
|
||||
|
||||
func TestExecutor_ApplyTaskResultConditions_Good_ChangedWhen(t *testing.T) {
|
||||
e := NewExecutor("/tmp")
|
||||
task := &Task{
|
||||
ChangedWhen: "stdout == 'expected'",
|
||||
}
|
||||
result := &TaskResult{
|
||||
Changed: true,
|
||||
Stdout: "actual",
|
||||
}
|
||||
|
||||
e.applyTaskResultConditions("host1", task, result)
|
||||
|
||||
assert.False(t, result.Changed)
|
||||
}
|
||||
|
||||
func TestExecutor_ApplyTaskResultConditions_Good_FailedWhen(t *testing.T) {
|
||||
e := NewExecutor("/tmp")
|
||||
task := &Task{
|
||||
FailedWhen: []any{"rc != 0", "stdout == 'expected'"},
|
||||
}
|
||||
result := &TaskResult{
|
||||
Failed: true,
|
||||
Stdout: "expected",
|
||||
RC: 0,
|
||||
}
|
||||
|
||||
e.applyTaskResultConditions("host1", task, result)
|
||||
|
||||
assert.False(t, result.Failed)
|
||||
}
|
||||
|
||||
func TestExecutor_ApplyTaskResultConditions_Good_DottedResultAccess(t *testing.T) {
|
||||
e := NewExecutor("/tmp")
|
||||
task := &Task{
|
||||
ChangedWhen: "result.rc == 0",
|
||||
}
|
||||
result := &TaskResult{
|
||||
Changed: false,
|
||||
RC: 0,
|
||||
}
|
||||
|
||||
e.applyTaskResultConditions("host1", task, result)
|
||||
|
||||
assert.True(t, result.Changed)
|
||||
}
|
||||
|
||||
// --- templateString ---
|
||||
|
||||
func TestExecutor_TemplateString_Good_SimpleVar(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue