diff --git a/executor.go b/executor.go index 295f8a0..93e15c3 100644 --- a/executor.go +++ b/executor.go @@ -1948,40 +1948,61 @@ func (e *Executor) getClient(host string, play *Play) (sshExecutorClient, error) cfg.KeyFile = k } + desiredLocal := isLocalConnection(host, play, vars) + desiredBecome := play != nil && play.Become + desiredBecomeUser := "" + desiredBecomePass := "" + if desiredBecome { + desiredBecomeUser = play.BecomeUser + if bp, ok := vars["ansible_become_password"].(string); ok { + desiredBecomePass = bp + } else if cfg.Password != "" { + desiredBecomePass = cfg.Password + } + } + e.mu.Lock() defer e.mu.Unlock() if client, ok := e.clients[host]; ok { - if !isLocalConnection(host, play, vars) { + switch cached := client.(type) { + case *localClient: + if desiredLocal { + cached.SetBecome(desiredBecome, desiredBecomeUser, desiredBecomePass) + return cached, nil + } + _ = cached.Close() + delete(e.clients, host) + case *SSHClient: + if !desiredLocal && + cached.host == cfg.Host && + cached.port == cfg.Port && + cached.user == cfg.User && + cached.password == cfg.Password && + cached.keyFile == cfg.KeyFile { + cached.SetBecome(desiredBecome, desiredBecomeUser, desiredBecomePass) + return cached, nil + } + _ = cached.Close() + delete(e.clients, host) + default: + client.SetBecome(desiredBecome, desiredBecomeUser, desiredBecomePass) return client, nil } } - if isLocalConnection(host, play, vars) { + if desiredLocal { client := newLocalClient() - if play.Become { - becomePass := "" - if bp, ok := vars["ansible_become_password"].(string); ok { - becomePass = bp - } else if p, ok := vars["ansible_password"].(string); ok { - becomePass = p - } - client.SetBecome(true, play.BecomeUser, becomePass) - } + client.SetBecome(desiredBecome, desiredBecomeUser, desiredBecomePass) e.clients[host] = client return client, nil } // Apply play become settings - if play.Become { + if desiredBecome { cfg.Become = true - cfg.BecomeUser = play.BecomeUser - if bp, ok := vars["ansible_become_password"].(string); ok { - cfg.BecomePass = bp - } else if cfg.Password != "" { - // Use SSH password for sudo if no become password specified - cfg.BecomePass = cfg.Password - } + cfg.BecomeUser = desiredBecomeUser + cfg.BecomePass = desiredBecomePass } client, err := NewSSHClient(cfg) diff --git a/executor_become_test.go b/executor_become_test.go index c52a901..b60caa9 100644 --- a/executor_become_test.go +++ b/executor_become_test.go @@ -94,6 +94,6 @@ func TestExecutor_RunTaskOnHost_Good_TaskBecomeFalseOverridesPlayBecome(t *testi require.Len(t, client.runBecomeSeen, 1) assert.False(t, client.runBecomeSeen[0]) assert.True(t, client.become) - assert.Equal(t, "root", client.becomeUser) + assert.Equal(t, "admin", client.becomeUser) assert.Equal(t, "secret", client.becomePass) } diff --git a/executor_test.go b/executor_test.go index e3eea09..daf5242 100644 --- a/executor_test.go +++ b/executor_test.go @@ -128,6 +128,30 @@ func TestExecutor_GetClient_Good_PlayVarsOverrideInventoryVars(t *testing.T) { assert.Equal(t, "play-user", sshClient.user) } +func TestExecutor_GetClient_Good_UpdatesCachedBecomeState(t *testing.T) { + e := NewExecutor("/tmp") + e.SetInventoryDirect(&Inventory{ + All: &InventoryGroup{ + Hosts: map[string]*Host{ + "host1": {AnsibleHost: "127.0.0.1"}, + }, + }, + }) + + cached := &becomeRecordingClient{} + e.clients["host1"] = cached + + play := &Play{Become: true, BecomeUser: "admin"} + client, err := e.getClient("host1", play) + require.NoError(t, err) + require.Same(t, cached, client) + + become, user, pass := cached.BecomeState() + assert.True(t, become) + assert.Equal(t, "admin", user) + assert.Empty(t, pass) +} + // --- matchesTags --- func TestExecutor_MatchesTags_Good_NoTagsFilter(t *testing.T) {