From 23047aaa6f48a6dd7c834ab974767ea3682fce86 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 21:25:32 +0000 Subject: [PATCH] Handle meta end_batch actions --- executor.go | 20 ++++++++++++ executor_extra_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/executor.go b/executor.go index 5708f39..6d85a12 100644 --- a/executor.go +++ b/executor.go @@ -20,6 +20,7 @@ import ( var errEndPlay = errors.New("end play") var errEndHost = errors.New("end host") +var errEndBatch = errors.New("end batch") // sshExecutorClient is the client contract used by the executor. type sshExecutorClient interface { @@ -213,6 +214,9 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error { if errors.Is(err, errEndPlay) { return nil } + if errors.Is(err, errEndBatch) { + goto nextBatch + } return err } } @@ -223,6 +227,9 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error { if errors.Is(err, errEndPlay) { return nil } + if errors.Is(err, errEndBatch) { + goto nextBatch + } return err } } @@ -233,6 +240,9 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error { if errors.Is(err, errEndPlay) { return nil } + if errors.Is(err, errEndBatch) { + goto nextBatch + } return err } } @@ -243,6 +253,9 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error { if errors.Is(err, errEndPlay) { return nil } + if errors.Is(err, errEndBatch) { + goto nextBatch + } return err } } @@ -252,8 +265,13 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error { if errors.Is(err, errEndPlay) { return nil } + if errors.Is(err, errEndBatch) { + goto nextBatch + } return err } + + nextBatch: } return nil @@ -1951,6 +1969,8 @@ func (e *Executor) handleMetaAction(ctx context.Context, host string, hosts []st return e.refreshInventory() case "end_play": return errEndPlay + case "end_batch": + return errEndBatch case "end_host": e.markHostEnded(host) return errEndHost diff --git a/executor_extra_test.go b/executor_extra_test.go index 8a9b791..d87c69f 100644 --- a/executor_extra_test.go +++ b/executor_extra_test.go @@ -344,6 +344,16 @@ func TestExecutorExtra_ModuleMeta_Good_EndHost(t *testing.T) { assert.Equal(t, "end_host", result.Data["action"]) } +func TestExecutorExtra_ModuleMeta_Good_EndBatch(t *testing.T) { + e := NewExecutor("/tmp") + result, err := e.moduleMeta(map[string]any{"_raw_params": "end_batch"}) + + require.NoError(t, err) + assert.False(t, result.Changed) + require.NotNil(t, result.Data) + assert.Equal(t, "end_batch", result.Data["action"]) +} + func TestExecutorExtra_HandleMetaAction_Good_ClearFacts(t *testing.T) { e := NewExecutor("/tmp") e.facts["host1"] = &Facts{Hostname: "web01"} @@ -369,6 +379,17 @@ func TestExecutorExtra_HandleMetaAction_Good_EndHost(t *testing.T) { assert.False(t, e.isHostEnded("host2")) } +func TestExecutorExtra_HandleMetaAction_Good_EndBatch(t *testing.T) { + e := NewExecutor("/tmp") + + result := &TaskResult{Data: map[string]any{"action": "end_batch"}} + err := e.handleMetaAction(context.Background(), "host1", []string{"host1", "host2"}, nil, result) + + require.ErrorIs(t, err, errEndBatch) + assert.False(t, e.isHostEnded("host1")) + assert.False(t, e.isHostEnded("host2")) +} + func TestExecutorExtra_HandleMetaAction_Good_ResetConnection(t *testing.T) { e := NewExecutor("/tmp") mock := NewMockSSHClient() @@ -426,6 +447,55 @@ func TestExecutorExtra_RunTaskOnHosts_Good_EndHostSkipsFutureTasks(t *testing.T) assert.NotContains(t, started, "host1:Follow-up") } +func TestExecutorExtra_RunPlay_Good_MetaEndBatchAdvancesToNextSerialBatch(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": {Vars: map[string]any{"end_batch": true}}, + "host2": {Vars: map[string]any{"end_batch": false}}, + "host3": {Vars: map[string]any{"end_batch": false}}, + }, + }, + }) + e.clients["host1"] = NewMockSSHClient() + e.clients["host2"] = NewMockSSHClient() + e.clients["host3"] = NewMockSSHClient() + + serial := 1 + gatherFacts := false + play := &Play{ + Hosts: "all", + Serial: serial, + GatherFacts: &gatherFacts, + Tasks: []Task{ + { + Name: "end current batch", + Module: "meta", + Args: map[string]any{"_raw_params": "end_batch"}, + When: "end_batch", + }, + { + Name: "follow-up", + Module: "debug", + Args: map[string]any{"msg": "next batch"}, + }, + }, + } + + var started []string + e.OnTaskStart = func(host string, task *Task) { + started = append(started, host+":"+task.Name) + } + + require.NoError(t, e.runPlay(context.Background(), play)) + + assert.Contains(t, started, "host1:end current batch") + assert.NotContains(t, started, "host1:follow-up") + assert.Contains(t, started, "host2:follow-up") + assert.Contains(t, started, "host3:follow-up") +} + // ============================================================ // Tests for handleLookup (0% coverage) // ============================================================