diff --git a/executor.go b/executor.go index 7a20b7c..e86dd34 100644 --- a/executor.go +++ b/executor.go @@ -4,6 +4,8 @@ import ( "context" "regexp" "slices" + "strconv" + "strings" "sync" "text/template" "time" @@ -136,55 +138,160 @@ func (e *Executor) runPlay(ctx context.Context, play *Play) error { e.vars[k] = v } - // Gather facts if needed - gatherFacts := play.GatherFacts == nil || *play.GatherFacts - if gatherFacts { - for _, host := range hosts { - if err := e.gatherFacts(ctx, host, play); err != nil { - // Non-fatal - if e.Verbose > 0 { - coreerr.Warn("gather facts failed", "host", host, "err", err) + for _, batch := range splitSerialHosts(hosts, play.Serial) { + if len(batch) == 0 { + continue + } + + // Gather facts if needed + gatherFacts := play.GatherFacts == nil || *play.GatherFacts + if gatherFacts { + for _, host := range batch { + if err := e.gatherFacts(ctx, host, play); err != nil { + // Non-fatal + if e.Verbose > 0 { + coreerr.Warn("gather facts failed", "host", host, "err", err) + } } } } - } - // Execute pre_tasks - for _, task := range play.PreTasks { - if err := e.runTaskOnHosts(ctx, hosts, &task, play); err != nil { + // Execute pre_tasks + for _, task := range play.PreTasks { + if err := e.runTaskOnHosts(ctx, batch, &task, play); err != nil { + return err + } + } + + // Execute roles + for _, roleRef := range play.Roles { + if err := e.runRole(ctx, batch, &roleRef, play); err != nil { + return err + } + } + + // Execute tasks + for _, task := range play.Tasks { + if err := e.runTaskOnHosts(ctx, batch, &task, play); err != nil { + return err + } + } + + // Execute post_tasks + for _, task := range play.PostTasks { + if err := e.runTaskOnHosts(ctx, batch, &task, play); err != nil { + return err + } + } + + // Run notified handlers for this batch. + if err := e.runNotifiedHandlers(ctx, batch, play); err != nil { return err } } - // Execute roles - for _, roleRef := range play.Roles { - if err := e.runRole(ctx, hosts, &roleRef, play); err != nil { - return err - } - } - - // Execute tasks - for _, task := range play.Tasks { - if err := e.runTaskOnHosts(ctx, hosts, &task, play); err != nil { - return err - } - } - - // Execute post_tasks - for _, task := range play.PostTasks { - if err := e.runTaskOnHosts(ctx, hosts, &task, play); err != nil { - return err - } - } - - // Run notified handlers - if err := e.runNotifiedHandlers(ctx, hosts, play); err != nil { - return err - } - return nil } +// splitSerialHosts splits a host list into serial batches. +func splitSerialHosts(hosts []string, serial any) [][]string { + batchSize := resolveSerialBatchSize(serial, len(hosts)) + if batchSize <= 0 || batchSize >= len(hosts) { + if len(hosts) == 0 { + return nil + } + return [][]string{hosts} + } + + batches := make([][]string, 0, (len(hosts)+batchSize-1)/batchSize) + for len(hosts) > 0 { + size := batchSize + if size > len(hosts) { + size = len(hosts) + } + batch := append([]string(nil), hosts[:size]...) + batches = append(batches, batch) + hosts = hosts[size:] + } + return batches +} + +// resolveSerialBatchSize converts a play serial value into a concrete batch size. +func resolveSerialBatchSize(serial any, total int) int { + if total <= 0 { + return 0 + } + + switch v := serial.(type) { + case nil: + return total + case int: + if v > 0 { + return v + } + case int8: + if v > 0 { + return int(v) + } + case int16: + if v > 0 { + return int(v) + } + case int32: + if v > 0 { + return int(v) + } + case int64: + if v > 0 { + return int(v) + } + case uint: + if v > 0 { + return int(v) + } + case uint8: + if v > 0 { + return int(v) + } + case uint16: + if v > 0 { + return int(v) + } + case uint32: + if v > 0 { + return int(v) + } + case uint64: + if v > 0 { + return int(v) + } + case string: + s := corexTrimSpace(v) + if s == "" { + return total + } + if corexHasSuffix(s, "%") { + percent, err := strconv.Atoi(strings.TrimSuffix(s, "%")) + if err == nil && percent > 0 { + size := (total*percent + 99) / 100 + if size < 1 { + size = 1 + } + if size > total { + size = total + } + return size + } + return total + } + if n, err := strconv.Atoi(s); err == nil && n > 0 { + return n + } + } + + return total +} + // runRole executes a role on hosts. func (e *Executor) runRole(ctx context.Context, hosts []string, roleRef *RoleRef, play *Play) error { // Check when condition diff --git a/executor_test.go b/executor_test.go index b575346..6e3b33b 100644 --- a/executor_test.go +++ b/executor_test.go @@ -207,6 +207,46 @@ func TestExecutor_RunTaskOnHosts_Good_RunOnceSharesRegisteredResult(t *testing.T assert.Equal(t, "hello", e.results["host2"]["debug_result"].Msg) } +func TestExecutor_RunPlay_Good_SerialBatchesHosts(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": {}, + "host2": {}, + "host3": {}, + }, + }, + }) + + gatherFacts := false + play := &Play{ + Hosts: "all", + GatherFacts: &gatherFacts, + Serial: 1, + Tasks: []Task{ + {Name: "first", Module: "debug", Args: map[string]any{"msg": "one"}}, + {Name: "second", Module: "debug", Args: map[string]any{"msg": "two"}}, + }, + } + + var got []string + e.OnTaskStart = func(host string, task *Task) { + got = append(got, host+":"+task.Name) + } + + require.NoError(t, e.runPlay(context.Background(), play)) + + assert.Equal(t, []string{ + "host1:first", + "host1:second", + "host2:first", + "host2:second", + "host3:first", + "host3:second", + }, got) +} + func TestExecutor_RunTaskOnHost_Good_LoopControlPause(t *testing.T) { e := NewExecutor("/tmp") task := &Task{