diff --git a/executor_become_test.go b/executor_become_test.go index 4d1fe52..f2fc5e6 100644 --- a/executor_become_test.go +++ b/executor_become_test.go @@ -58,6 +58,11 @@ func (c *becomeRecordingClient) SetBecome(become bool, user, password string) { c.mu.Lock() defer c.mu.Unlock() c.become = become + if !become { + c.becomeUser = "" + c.becomePass = "" + return + } if user != "" { c.becomeUser = user } diff --git a/local_client.go b/local_client.go index c54f3d2..5a1934c 100644 --- a/local_client.go +++ b/local_client.go @@ -35,6 +35,11 @@ func (c *localClient) SetBecome(become bool, user, password string) { c.mu.Lock() defer c.mu.Unlock() c.become = become + if !become { + c.becomeUser = "" + c.becomePass = "" + return + } if user != "" { c.becomeUser = user } diff --git a/local_client_test.go b/local_client_test.go index 1a70c05..a507132 100644 --- a/local_client_test.go +++ b/local_client_test.go @@ -78,3 +78,19 @@ func TestExecutor_GatherFacts_Good_LocalConnection(t *testing.T) { assert.NotEmpty(t, facts.Hostname) assert.NotEmpty(t, facts.Kernel) } + +func TestLocalClient_Good_SetBecomeResetsStateWhenDisabled(t *testing.T) { + client := newLocalClient() + + client.SetBecome(true, "admin", "secret") + become, user, password := client.BecomeState() + assert.True(t, become) + assert.Equal(t, "admin", user) + assert.Equal(t, "secret", password) + + client.SetBecome(false, "", "") + become, user, password = client.BecomeState() + assert.False(t, become) + assert.Empty(t, user) + assert.Empty(t, password) +} diff --git a/mock_ssh_test.go b/mock_ssh_test.go index 6b35834..b230add 100644 --- a/mock_ssh_test.go +++ b/mock_ssh_test.go @@ -258,6 +258,11 @@ func (m *MockSSHClient) SetBecome(become bool, user, password string) { m.mu.Lock() defer m.mu.Unlock() m.become = become + if !become { + m.becomeUser = "" + m.becomePass = "" + return + } if user != "" { m.becomeUser = user } diff --git a/modules_infra_test.go b/modules_infra_test.go index fc63151..ecf8b9a 100644 --- a/modules_infra_test.go +++ b/modules_infra_test.go @@ -741,9 +741,8 @@ func TestModulesInfra_Become_Good_DisableAfterEnable(t *testing.T) { client.SetBecome(false, "", "") assert.False(t, client.become) - // becomeUser and becomePass are only updated if non-empty - assert.Equal(t, "root", client.becomeUser) - assert.Equal(t, "secret", client.becomePass) + assert.Empty(t, client.becomeUser) + assert.Empty(t, client.becomePass) } func TestModulesInfra_Become_Good_MockBecomeTracking(t *testing.T) { diff --git a/ssh.go b/ssh.go index bd4e382..acb24e1 100644 --- a/ssh.go +++ b/ssh.go @@ -498,6 +498,11 @@ func (c *SSHClient) SetBecome(become bool, user, password string) { c.mu.Lock() defer c.mu.Unlock() c.become = become + if !become { + c.becomeUser = "" + c.becomePass = "" + return + } if user != "" { c.becomeUser = user } diff --git a/ssh_test.go b/ssh_test.go index 5b12e92..cf5fac2 100644 --- a/ssh_test.go +++ b/ssh_test.go @@ -34,3 +34,19 @@ func TestSSH_NewSSHClient_Good_Defaults(t *testing.T) { assert.Equal(t, "root", client.user) assert.Equal(t, 30*time.Second, client.timeout) } + +func TestSSH_SetBecome_Good_DisablesAndClearsState(t *testing.T) { + client := &SSHClient{} + + client.SetBecome(true, "admin", "secret") + become, user, password := client.BecomeState() + assert.True(t, become) + assert.Equal(t, "admin", user) + assert.Equal(t, "secret", password) + + client.SetBecome(false, "", "") + become, user, password = client.BecomeState() + assert.False(t, become) + assert.Empty(t, user) + assert.Empty(t, password) +}