cli/pkg/ansible/ssh.go
Snider f55ca297a0 fix(deploy): address linter warnings and build errors
- Fix fmt.Sprintf format verb error in ssh.go (remove unused stat command)
- Fix errcheck warnings by explicitly ignoring best-effort operations
- Fix ineffassign warning in cmd_ansible.go

All golangci-lint checks now pass for deploy packages.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:10:13 +00:00

378 lines
8.8 KiB
Go

package ansible
import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
// defaultSSHTimeout bounds the TCP dial and SSH handshake when no explicit
// timeout has been configured.
const defaultSSHTimeout = 30 * time.Second

// SSHClient handles SSH connections to a single remote host. The connection
// is opened lazily by Connect and reused until Close. mu protects client
// (Connect/Close) and the become settings (SetBecome).
type SSHClient struct {
	host     string
	port     int
	user     string
	password string
	keyFile  string
	// timeout bounds the TCP dial and SSH handshake performed by Connect.
	timeout time.Duration

	client *ssh.Client
	mu     sync.Mutex

	// Privilege escalation ("become") settings for remote commands.
	become     bool
	becomeUser string
	becomePass string
}

// SSHConfig holds SSH connection configuration for NewSSHClient.
type SSHConfig struct {
	Host       string
	Port       int    // defaults to 22
	User       string // defaults to "root"
	Password   string // password / keyboard-interactive auth secret
	KeyFile    string // private key path; a leading "~" is expanded to the local home directory
	Become     bool
	BecomeUser string
	BecomePass string
	Timeout    time.Duration // bounds dial + SSH handshake; defaults to 30s
}

// NewSSHClient returns an SSHClient for cfg without connecting; the
// connection is opened lazily by Connect. Zero values are replaced by
// defaults: Port 22, User "root", Timeout 30s (non-positive timeouts also
// fall back to the default). The returned error is currently always nil.
func NewSSHClient(cfg SSHConfig) (*SSHClient, error) {
	if cfg.Port == 0 {
		cfg.Port = 22
	}
	if cfg.User == "" {
		cfg.User = "root"
	}
	if cfg.Timeout <= 0 {
		cfg.Timeout = defaultSSHTimeout
	}

	client := &SSHClient{
		host:       cfg.Host,
		port:       cfg.Port,
		user:       cfg.User,
		password:   cfg.Password,
		keyFile:    cfg.KeyFile,
		timeout:    cfg.Timeout,
		become:     cfg.Become,
		becomeUser: cfg.BecomeUser,
		becomePass: cfg.BecomePass,
	}
	return client, nil
}

// Connect establishes the SSH connection; it is a no-op if one is already
// open.
//
// The TCP dial honours ctx. Both the dial and the SSH handshake are bounded
// by the configured timeout, or by ctx's deadline if that is earlier.
// ssh.ClientConfig.Timeout alone would not bound the handshake, because it
// only applies to ssh.Dial.
//
// Host keys are NOT verified (InsecureIgnoreHostKey).
func (c *SSHClient) Connect(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client != nil {
		return nil
	}

	authMethods, keyErr := c.authMethods()
	if len(authMethods) == 0 {
		if keyErr != nil {
			return fmt.Errorf("no authentication method available: %w", keyErr)
		}
		return fmt.Errorf("no authentication method available")
	}

	timeout := c.timeout
	if timeout <= 0 {
		// Zero-value SSHClient (not built by NewSSHClient).
		timeout = defaultSSHTimeout
	}

	config := &ssh.ClientConfig{
		User:            c.user,
		Auth:            authMethods,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: proper host key checking
		Timeout:         timeout,
	}

	// JoinHostPort brackets IPv6 literals; strip brackets a caller may
	// already have added so they are not doubled.
	host := c.host
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	addr := net.JoinHostPort(host, fmt.Sprint(c.port))

	dialer := net.Dialer{Timeout: timeout}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return fmt.Errorf("dial %s: %w", addr, err)
	}

	// Bound the handshake with a deadline on our own net.Conn.
	deadline := time.Now().Add(timeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err := conn.SetDeadline(deadline); err != nil {
		_ = conn.Close() // best-effort cleanup; report the deadline error
		return fmt.Errorf("set handshake deadline %s: %w", addr, err)
	}

	sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		_ = conn.Close() // best-effort cleanup; report the handshake error
		return fmt.Errorf("ssh connect %s: %w", addr, err)
	}

	// Clear the deadline so the established connection is not cut off later.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = sshConn.Close()
		return fmt.Errorf("clear handshake deadline %s: %w", addr, err)
	}

	c.client = ssh.NewClient(sshConn, chans, reqs)
	return nil
}

// authMethods collects the auth methods to offer, in order:
//  1. the configured key file;
//  2. only if that yielded nothing, the first parseable default key
//     (~/.ssh/id_ed25519, then ~/.ssh/id_rsa);
//  3. password and keyboard-interactive when a password is set.
//
// The returned error explains why an explicitly configured key file could
// not be used. It is informational: a default key or password may still
// allow authentication, and Connect only surfaces it when no method remains.
func (c *SSHClient) authMethods() ([]ssh.AuthMethod, error) {
	var methods []ssh.AuthMethod
	var keyErr error

	loadSigner := func(keyPath string) (ssh.Signer, error) {
		data, err := os.ReadFile(keyPath)
		if err != nil {
			return nil, err
		}
		return ssh.ParsePrivateKey(data)
	}

	home, homeErr := os.UserHomeDir()

	if c.keyFile != "" {
		keyPath := c.keyFile
		if strings.HasPrefix(keyPath, "~") {
			if homeErr != nil {
				keyErr = fmt.Errorf("expand key file %s: %w", c.keyFile, homeErr)
			} else {
				keyPath = filepath.Join(home, keyPath[1:])
			}
		}
		if keyErr == nil {
			if signer, err := loadSigner(keyPath); err != nil {
				keyErr = fmt.Errorf("key file %s: %w", keyPath, err)
			} else {
				methods = append(methods, ssh.PublicKeys(signer))
			}
		}
	}

	// Default keys are only consulted when no explicit key worked, and only
	// when the home directory is known (otherwise the paths would be
	// resolved relative to the working directory).
	if len(methods) == 0 && homeErr == nil {
		for _, name := range []string{"id_ed25519", "id_rsa"} {
			if signer, err := loadSigner(filepath.Join(home, ".ssh", name)); err == nil {
				methods = append(methods, ssh.PublicKeys(signer))
				break
			}
		}
	}

	if c.password != "" {
		password := c.password
		methods = append(methods, ssh.Password(password))
		// Answer every keyboard-interactive question with the password.
		methods = append(methods, ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
			answers := make([]string, len(questions))
			for i := range answers {
				answers[i] = password
			}
			return answers, nil
		}))
	}

	return methods, keyErr
}
// Close tears down the SSH connection if one is open and returns the error
// from closing it. Calling it with no open connection returns nil. Because
// the client is cleared, a later Connect opens a fresh connection.
func (c *SSHClient) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	open := c.client
	if open == nil {
		return nil
	}
	c.client = nil
	return open.Close()
}
// Run executes cmd on the remote host, connecting first if necessary. It
// returns the command's captured stdout and stderr together with its exit
// status.
//
// When privilege escalation is enabled (see SetBecome), cmd is run as
// `sudo -u '<become user>' bash -c '<cmd>'`, with the become user defaulting
// to root.
//
// The sudo password is the become password or, if that is empty, the
// connection password. It is written to sudo's stdin (`sudo -S`) rather than
// embedded in the command line, so it cannot break the shell quoting and
// does not appear in the remote process list. With no password available,
// `sudo -n` is used so the command fails instead of waiting for a prompt.
//
// A non-zero remote exit status is returned in exitCode with a nil err.
// err is non-nil, and exitCode is -1, when:
//   - the connection or session cannot be set up;
//   - session.Run fails for a reason other than a non-zero exit;
//   - ctx is done. In that case the remote command is sent SIGKILL and
//     ctx.Err() is returned.
func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string, exitCode int, err error) {
	if err := c.Connect(ctx); err != nil {
		return "", "", -1, err
	}

	// Snapshot shared state under the lock: Close and SetBecome write these
	// fields while holding mu.
	c.mu.Lock()
	client := c.client
	become, becomeUser := c.become, c.becomeUser
	sudoPass := c.becomePass
	if sudoPass == "" {
		sudoPass = c.password
	}
	c.mu.Unlock()
	if client == nil {
		return "", "", -1, fmt.Errorf("ssh connection closed")
	}

	session, err := client.NewSession()
	if err != nil {
		return "", "", -1, fmt.Errorf("new session: %w", err)
	}
	defer func() { _ = session.Close() }()

	var stdoutBuf, stderrBuf bytes.Buffer
	session.Stdout = &stdoutBuf
	session.Stderr = &stderrBuf

	if become {
		if becomeUser == "" {
			becomeUser = "root"
		}
		// POSIX single-quoting: wrap in '...' and turn each embedded ' into '\''.
		quote := func(s string) string {
			return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
		}
		if sudoPass != "" {
			// sudo -S reads the password line from stdin; the wrapped
			// command inherits whatever stdin remains (nothing).
			session.Stdin = strings.NewReader(sudoPass + "\n")
			cmd = fmt.Sprintf("sudo -S -u %s bash -c %s", quote(becomeUser), quote(cmd))
		} else {
			cmd = fmt.Sprintf("sudo -n -u %s bash -c %s", quote(becomeUser), quote(cmd))
		}
	}

	// Buffered so the goroutine can finish even if we stop listening.
	done := make(chan error, 1)
	go func() {
		done <- session.Run(cmd)
	}()

	select {
	case <-ctx.Done():
		// Best effort: the server may ignore the signal; the deferred
		// session.Close tears the channel down regardless.
		_ = session.Signal(ssh.SIGKILL)
		return "", "", -1, ctx.Err()
	case runErr := <-done:
		if runErr == nil {
			return stdoutBuf.String(), stderrBuf.String(), 0, nil
		}
		exitErr, ok := runErr.(*ssh.ExitError)
		if !ok {
			return stdoutBuf.String(), stderrBuf.String(), -1, runErr
		}
		return stdoutBuf.String(), stderrBuf.String(), exitErr.ExitStatus(), nil
	}
}
// RunScript runs a multi-line script on the remote host by feeding it to
// `bash` through a quoted here-document.
//
// Because the delimiter is quoted, the outer shell performs no expansion on
// the script text. The delimiter is chosen so that it does not occur
// anywhere in the script; otherwise a script line equal to it would end the
// document early, and the remainder would run as ordinary commands.
// Privilege escalation and the result semantics are those of Run.
func (c *SSHClient) RunScript(ctx context.Context, script string) (stdout, stderr string, exitCode int, err error) {
	delim := "ANSIBLE_SCRIPT_EOF"
	// Terminates: a finite script contains only finitely many suffixed variants.
	for i := 1; strings.Contains(script, delim); i++ {
		delim = fmt.Sprintf("ANSIBLE_SCRIPT_EOF_%d", i)
	}
	cmd := fmt.Sprintf("bash <<'%s'\n%s\n%s", delim, script, delim)
	return c.Run(ctx, cmd)
}
// Upload writes everything read from local to the remote path and then
// chmods it to mode (formatted in octal).
//
// Steps:
//  1. local is read fully into memory first, so a read error aborts before
//     anything on the remote host is touched.
//  2. The parent directory is created with `mkdir -p` via Run, which applies
//     privilege escalation itself; a non-zero mkdir exit is reported.
//  3. The content is streamed into `cat` over the stdin of a dedicated
//     session.
//
// When become is enabled, step 3 runs under `sudo -u <become user>`
// (default root). Since stdin carries the file content, no sudo password
// can be supplied there, so that escalation must not require one.
func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, mode os.FileMode) error {
	if err := c.Connect(ctx); err != nil {
		return err
	}

	// POSIX single-quoting; unlike Go's %q (double quotes), the shell does
	// not expand $, ` or \ inside the result.
	shellQuote := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
	}

	content, err := io.ReadAll(local)
	if err != nil {
		return fmt.Errorf("read content: %w", err)
	}

	// Run already wraps the command in sudo when become is enabled.
	dir := filepath.Dir(remote)
	_, dirStderr, dirExit, err := c.Run(ctx, "mkdir -p "+shellQuote(dir))
	if err != nil {
		return fmt.Errorf("create parent dir: %w", err)
	}
	if dirExit != 0 {
		return fmt.Errorf("create parent dir %s: exit %d: %s", dir, dirExit, strings.TrimSpace(dirStderr))
	}

	// Snapshot shared state under the lock (Close/SetBecome write it).
	c.mu.Lock()
	client := c.client
	become, becomeUser := c.become, c.becomeUser
	c.mu.Unlock()
	if client == nil {
		return fmt.Errorf("ssh connection closed")
	}

	writeCmd := fmt.Sprintf("cat > %s && chmod %o %s", shellQuote(remote), mode, shellQuote(remote))
	if become {
		if becomeUser == "" {
			becomeUser = "root"
		}
		writeCmd = fmt.Sprintf("sudo -u %s bash -c %s", shellQuote(becomeUser), shellQuote(writeCmd))
	}

	session, err := client.NewSession()
	if err != nil {
		return fmt.Errorf("new session for write: %w", err)
	}
	defer func() { _ = session.Close() }()

	stdin, err := session.StdinPipe()
	if err != nil {
		return fmt.Errorf("stdin pipe: %w", err)
	}
	var stderrBuf bytes.Buffer
	session.Stderr = &stderrBuf

	if err := session.Start(writeCmd); err != nil {
		return fmt.Errorf("start write: %w", err)
	}

	_, writeErr := stdin.Write(content)
	// Close stdin so cat sees EOF and the remote command can finish; a close
	// failure shows up as a Wait error below.
	_ = stdin.Close()

	// Wait before inspecting writeErr so the remote stderr explaining an
	// early exit (e.g. permission denied) is included in the error.
	if err := session.Wait(); err != nil {
		return fmt.Errorf("write failed: %w (stderr: %s)", err, stderrBuf.String())
	}
	if writeErr != nil {
		return fmt.Errorf("write content: %w", writeErr)
	}
	return nil
}
// Download returns the contents of the remote file at path remote, read with
// `cat` via Run.
//
// When privilege escalation is enabled, the read happens as the become user
// because Run applies sudo itself. The command therefore does not add a
// second, nested sudo. A non-zero exit from cat is reported as an error
// carrying the exit status and remote stderr.
func (c *SSHClient) Download(ctx context.Context, remote string) ([]byte, error) {
	// POSIX single-quoting so $, ` and \ in the path are taken literally.
	quoted := "'" + strings.ReplaceAll(remote, "'", `'\''`) + "'"

	stdout, stderr, exitCode, err := c.Run(ctx, "cat "+quoted)
	if err != nil {
		return nil, err
	}
	if exitCode != 0 {
		return nil, fmt.Errorf("cat failed (exit %d): %s", exitCode, strings.TrimSpace(stderr))
	}
	return []byte(stdout), nil
}
// FileExists reports whether path exists on the remote host, using
// `test -e` (run as the become user when escalation is enabled).
//
// The probe prints "yes" or "no" and, thanks to `|| echo no`, always exits
// 0. A non-zero exit status therefore means the probe never ran (for example,
// sudo refused to escalate). That case is returned as an error instead of
// being mistaken for a missing file.
func (c *SSHClient) FileExists(ctx context.Context, path string) (bool, error) {
	// POSIX single-quoting so $, ` and \ in the path are taken literally.
	quoted := "'" + strings.ReplaceAll(path, "'", `'\''`) + "'"
	cmd := "test -e " + quoted + " && echo yes || echo no"

	stdout, stderr, exitCode, err := c.Run(ctx, cmd)
	if err != nil {
		return false, err
	}
	if exitCode != 0 {
		return false, fmt.Errorf("check %s: exit %d: %s", path, exitCode, strings.TrimSpace(stderr))
	}
	return strings.TrimSpace(stdout) == "yes", nil
}
// Stat returns basic information about path on the remote host as a map of
// booleans.
//
// The result always has an "exists" key. When the path exists, it also has
// an "isdir" key. The probe runs as a small POSIX shell script through Run,
// so escalation applies.
//
// A non-zero exit status (the probe itself always exits 0, so this means it
// never ran, e.g. sudo was refused) or output without an "exists" field is
// returned as an error.
func (c *SSHClient) Stat(ctx context.Context, path string) (map[string]any, error) {
	// POSIX single-quoting so $, ` and \ in the path are taken literally.
	quoted := "'" + strings.ReplaceAll(path, "'", `'\''`) + "'"
	cmd := fmt.Sprintf(`p=%s
if [ -e "$p" ]; then
	if [ -d "$p" ]; then
		echo "exists=true isdir=true"
	else
		echo "exists=true isdir=false"
	fi
else
	echo "exists=false"
fi
`, quoted)

	stdout, stderr, exitCode, err := c.Run(ctx, cmd)
	if err != nil {
		return nil, err
	}
	if exitCode != 0 {
		return nil, fmt.Errorf("stat %s: exit %d: %s", path, exitCode, strings.TrimSpace(stderr))
	}

	result := make(map[string]any)
	for _, field := range strings.Fields(stdout) {
		if key, value, ok := strings.Cut(field, "="); ok {
			result[key] = value == "true"
		}
	}
	if _, ok := result["exists"]; !ok {
		return nil, fmt.Errorf("stat %s: unexpected output %q", path, stdout)
	}
	return result, nil
}
// SetBecome configures privilege escalation for subsequent commands.
// become switches escalation on or off unconditionally.
// user and password replace the stored become user and become password only
// when non-empty; an empty string keeps the previous value.
func (c *SSHClient) SetBecome(become bool, user, password string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	nextUser, nextPass := c.becomeUser, c.becomePass
	if user != "" {
		nextUser = user
	}
	if password != "" {
		nextPass = password
	}
	c.become, c.becomeUser, c.becomePass = become, nextUser, nextPass
}