496 lines
12 KiB
Go
496 lines
12 KiB
Go
package ansible
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"io/fs"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
coreio "dappco.re/go/core/io"
|
|
coreerr "dappco.re/go/core/log"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/knownhosts"
|
|
)
|
|
|
|
// SSHClient handles SSH connections to remote hosts.
|
|
//
|
|
// Example:
|
|
//
|
|
// client, _ := NewSSHClient(SSHConfig{Host: "web1"})
|
|
type SSHClient struct {
|
|
host string
|
|
port int
|
|
user string
|
|
password string
|
|
keyFile string
|
|
client *ssh.Client
|
|
mu sync.Mutex
|
|
become bool
|
|
becomeUser string
|
|
becomePass string
|
|
timeout time.Duration
|
|
}
|
|
|
|
// SSHConfig holds SSH connection configuration.
//
// Zero values are filled in by NewSSHClient: Port defaults to 22, User to
// "root" and Timeout to 30 seconds.
//
// Example:
//
// cfg := SSHConfig{Host: "web1", User: "deploy", Port: 22}
type SSHConfig struct {
	Host string // remote host name or IP address
	Port int    // TCP port; 0 means 22
	User string // login user; "" means "root"

	// Password enables password and keyboard-interactive auth (and is the
	// sudo password when BecomePass is empty); KeyFile names a private key,
	// with a leading "~" expanded to DIR_HOME, tried before the default keys.
	Password, KeyFile string

	// Become runs commands through sudo as BecomeUser ("" means "root"),
	// supplying BecomePass (or Password) when one is set.
	Become                 bool
	BecomeUser, BecomePass string

	// Timeout is the connection timeout; 0 means 30 seconds.
	Timeout time.Duration
}
|
|
|
|
// NewSSHClient creates a new SSH client.
|
|
//
|
|
// Example:
|
|
//
|
|
// client, err := NewSSHClient(SSHConfig{Host: "web1", User: "deploy"})
|
|
func NewSSHClient(cfg SSHConfig) (*SSHClient, error) {
|
|
if cfg.Port == 0 {
|
|
cfg.Port = 22
|
|
}
|
|
if cfg.User == "" {
|
|
cfg.User = "root"
|
|
}
|
|
if cfg.Timeout == 0 {
|
|
cfg.Timeout = 30 * time.Second
|
|
}
|
|
|
|
client := &SSHClient{
|
|
host: cfg.Host,
|
|
port: cfg.Port,
|
|
user: cfg.User,
|
|
password: cfg.Password,
|
|
keyFile: cfg.KeyFile,
|
|
become: cfg.Become,
|
|
becomeUser: cfg.BecomeUser,
|
|
becomePass: cfg.BecomePass,
|
|
timeout: cfg.Timeout,
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
// Connect establishes the SSH connection.
|
|
//
|
|
// Example:
|
|
//
|
|
// _ = client.Connect(context.Background())
|
|
func (c *SSHClient) Connect(ctx context.Context) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.client != nil {
|
|
return nil
|
|
}
|
|
|
|
var authMethods []ssh.AuthMethod
|
|
|
|
// Try key-based auth first
|
|
if c.keyFile != "" {
|
|
keyPath := c.keyFile
|
|
if corexHasPrefix(keyPath, "~") {
|
|
keyPath = joinPath(env("DIR_HOME"), keyPath[1:])
|
|
}
|
|
|
|
if key, err := coreio.Local.Read(keyPath); err == nil {
|
|
if signer, err := ssh.ParsePrivateKey([]byte(key)); err == nil {
|
|
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Try default SSH keys
|
|
if len(authMethods) == 0 {
|
|
home := env("DIR_HOME")
|
|
defaultKeys := []string{
|
|
joinPath(home, ".ssh", "id_ed25519"),
|
|
joinPath(home, ".ssh", "id_rsa"),
|
|
}
|
|
for _, keyPath := range defaultKeys {
|
|
if key, err := coreio.Local.Read(keyPath); err == nil {
|
|
if signer, err := ssh.ParsePrivateKey([]byte(key)); err == nil {
|
|
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fall back to password auth
|
|
if c.password != "" {
|
|
authMethods = append(authMethods, ssh.Password(c.password))
|
|
authMethods = append(authMethods, ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
|
answers := make([]string, len(questions))
|
|
for i := range questions {
|
|
answers[i] = c.password
|
|
}
|
|
return answers, nil
|
|
}))
|
|
}
|
|
|
|
if len(authMethods) == 0 {
|
|
return coreerr.E("ssh.Connect", "no authentication method available", nil)
|
|
}
|
|
|
|
// Host key verification
|
|
var hostKeyCallback ssh.HostKeyCallback
|
|
|
|
home := env("DIR_HOME")
|
|
if home == "" {
|
|
return coreerr.E("ssh.Connect", "failed to get user home dir", nil)
|
|
}
|
|
knownHostsPath := joinPath(home, ".ssh", "known_hosts")
|
|
|
|
// Ensure known_hosts file exists
|
|
if !coreio.Local.Exists(knownHostsPath) {
|
|
if err := coreio.Local.EnsureDir(pathDir(knownHostsPath)); err != nil {
|
|
return coreerr.E("ssh.Connect", "failed to create .ssh dir", err)
|
|
}
|
|
if err := coreio.Local.Write(knownHostsPath, ""); err != nil {
|
|
return coreerr.E("ssh.Connect", "failed to create known_hosts file", err)
|
|
}
|
|
}
|
|
|
|
cb, err := knownhosts.New(knownHostsPath)
|
|
if err != nil {
|
|
return coreerr.E("ssh.Connect", "failed to load known_hosts", err)
|
|
}
|
|
hostKeyCallback = cb
|
|
|
|
config := &ssh.ClientConfig{
|
|
User: c.user,
|
|
Auth: authMethods,
|
|
HostKeyCallback: hostKeyCallback,
|
|
Timeout: c.timeout,
|
|
}
|
|
|
|
addr := sprintf("%s:%d", c.host, c.port)
|
|
|
|
// Connect with context timeout
|
|
var d net.Dialer
|
|
conn, err := d.DialContext(ctx, "tcp", addr)
|
|
if err != nil {
|
|
return coreerr.E("ssh.Connect", sprintf("dial %s", addr), err)
|
|
}
|
|
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
|
if err != nil {
|
|
// conn is closed by NewClientConn on error
|
|
return coreerr.E("ssh.Connect", sprintf("ssh connect %s", addr), err)
|
|
}
|
|
|
|
c.client = ssh.NewClient(sshConn, chans, reqs)
|
|
return nil
|
|
}
|
|
|
|
// Close closes the SSH connection.
|
|
//
|
|
// Example:
|
|
//
|
|
// _ = client.Close()
|
|
func (c *SSHClient) Close() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.client != nil {
|
|
err := c.client.Close()
|
|
c.client = nil
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Run executes a command on the remote host.
|
|
//
|
|
// Example:
|
|
//
|
|
// stdout, stderr, rc, err := client.Run(context.Background(), "hostname")
|
|
func (c *SSHClient) Run(ctx context.Context, cmd string) (stdout, stderr string, exitCode int, err error) {
|
|
if err := c.Connect(ctx); err != nil {
|
|
return "", "", -1, err
|
|
}
|
|
|
|
session, err := c.client.NewSession()
|
|
if err != nil {
|
|
return "", "", -1, coreerr.E("ssh.Run", "new session", err)
|
|
}
|
|
defer func() { _ = session.Close() }()
|
|
|
|
var stdoutBuf, stderrBuf bytes.Buffer
|
|
session.Stdout = &stdoutBuf
|
|
session.Stderr = &stderrBuf
|
|
|
|
// Apply become if needed
|
|
if c.become {
|
|
becomeUser := c.becomeUser
|
|
if becomeUser == "" {
|
|
becomeUser = "root"
|
|
}
|
|
// Escape single quotes in the command
|
|
escapedCmd := replaceAll(cmd, "'", "'\\''")
|
|
if c.becomePass != "" {
|
|
// Use sudo with password via stdin (-S flag)
|
|
// We launch a goroutine to write the password to stdin
|
|
cmd = sprintf("sudo -S -u %s bash -c '%s'", becomeUser, escapedCmd)
|
|
stdin, err := session.StdinPipe()
|
|
if err != nil {
|
|
return "", "", -1, coreerr.E("ssh.Run", "stdin pipe", err)
|
|
}
|
|
go func() {
|
|
defer func() { _ = stdin.Close() }()
|
|
writeString(stdin, c.becomePass+"\n")
|
|
}()
|
|
} else if c.password != "" {
|
|
// Try using connection password for sudo
|
|
cmd = sprintf("sudo -S -u %s bash -c '%s'", becomeUser, escapedCmd)
|
|
stdin, err := session.StdinPipe()
|
|
if err != nil {
|
|
return "", "", -1, coreerr.E("ssh.Run", "stdin pipe", err)
|
|
}
|
|
go func() {
|
|
defer func() { _ = stdin.Close() }()
|
|
writeString(stdin, c.password+"\n")
|
|
}()
|
|
} else {
|
|
// Try passwordless sudo
|
|
cmd = sprintf("sudo -n -u %s bash -c '%s'", becomeUser, escapedCmd)
|
|
}
|
|
}
|
|
|
|
// Run with context
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- session.Run(cmd)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = session.Signal(ssh.SIGKILL)
|
|
return "", "", -1, ctx.Err()
|
|
case err := <-done:
|
|
exitCode = 0
|
|
if err != nil {
|
|
if exitErr, ok := err.(*ssh.ExitError); ok {
|
|
exitCode = exitErr.ExitStatus()
|
|
} else {
|
|
return stdoutBuf.String(), stderrBuf.String(), -1, err
|
|
}
|
|
}
|
|
return stdoutBuf.String(), stderrBuf.String(), exitCode, nil
|
|
}
|
|
}
|
|
|
|
// RunScript runs a script on the remote host.
|
|
//
|
|
// Example:
|
|
//
|
|
// stdout, stderr, rc, err := client.RunScript(context.Background(), "echo hello")
|
|
func (c *SSHClient) RunScript(ctx context.Context, script string) (stdout, stderr string, exitCode int, err error) {
|
|
// Escape the script for heredoc
|
|
cmd := sprintf("bash <<'ANSIBLE_SCRIPT_EOF'\n%s\nANSIBLE_SCRIPT_EOF", script)
|
|
return c.Run(ctx, cmd)
|
|
}
|
|
|
|
// Upload copies a file to the remote host.
|
|
//
|
|
// Example:
|
|
//
|
|
// err := client.Upload(context.Background(), newReader("hello"), "/tmp/hello.txt", 0644)
|
|
func (c *SSHClient) Upload(ctx context.Context, local io.Reader, remote string, mode fs.FileMode) error {
|
|
if err := c.Connect(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Read content
|
|
content, err := readAllString(local)
|
|
if err != nil {
|
|
return coreerr.E("ssh.Upload", "read content", err)
|
|
}
|
|
|
|
// Create parent directory
|
|
dir := pathDir(remote)
|
|
dirCmd := sprintf("mkdir -p %q", dir)
|
|
if c.become {
|
|
dirCmd = sprintf("sudo mkdir -p %q", dir)
|
|
}
|
|
if _, _, _, err := c.Run(ctx, dirCmd); err != nil {
|
|
return coreerr.E("ssh.Upload", "create parent dir", err)
|
|
}
|
|
|
|
// Use cat to write the file (simpler than SCP)
|
|
writeCmd := sprintf("cat > %q && chmod %o %q", remote, mode, remote)
|
|
|
|
// If become is needed, we construct a command that reads password then content from stdin
|
|
// But we need to be careful with handling stdin for sudo + cat.
|
|
// We'll use a session with piped stdin.
|
|
|
|
session2, err := c.client.NewSession()
|
|
if err != nil {
|
|
return coreerr.E("ssh.Upload", "new session for write", err)
|
|
}
|
|
defer func() { _ = session2.Close() }()
|
|
|
|
stdin, err := session2.StdinPipe()
|
|
if err != nil {
|
|
return coreerr.E("ssh.Upload", "stdin pipe", err)
|
|
}
|
|
|
|
var stderrBuf bytes.Buffer
|
|
session2.Stderr = &stderrBuf
|
|
|
|
if c.become {
|
|
becomeUser := c.becomeUser
|
|
if becomeUser == "" {
|
|
becomeUser = "root"
|
|
}
|
|
|
|
pass := c.becomePass
|
|
if pass == "" {
|
|
pass = c.password
|
|
}
|
|
|
|
if pass != "" {
|
|
// Use sudo -S with password from stdin
|
|
writeCmd = sprintf("sudo -S -u %s bash -c 'cat > %q && chmod %o %q'",
|
|
becomeUser, remote, mode, remote)
|
|
} else {
|
|
// Use passwordless sudo (sudo -n) to avoid consuming file content as password
|
|
writeCmd = sprintf("sudo -n -u %s bash -c 'cat > %q && chmod %o %q'",
|
|
becomeUser, remote, mode, remote)
|
|
}
|
|
|
|
if err := session2.Start(writeCmd); err != nil {
|
|
return coreerr.E("ssh.Upload", "start write", err)
|
|
}
|
|
|
|
go func() {
|
|
defer func() { _ = stdin.Close() }()
|
|
if pass != "" {
|
|
writeString(stdin, pass+"\n")
|
|
}
|
|
_, _ = stdin.Write([]byte(content))
|
|
}()
|
|
} else {
|
|
// Normal write
|
|
if err := session2.Start(writeCmd); err != nil {
|
|
return coreerr.E("ssh.Upload", "start write", err)
|
|
}
|
|
|
|
go func() {
|
|
defer func() { _ = stdin.Close() }()
|
|
_, _ = stdin.Write([]byte(content))
|
|
}()
|
|
}
|
|
|
|
if err := session2.Wait(); err != nil {
|
|
return coreerr.E("ssh.Upload", sprintf("write failed (stderr: %s)", stderrBuf.String()), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Download copies a file from the remote host.
|
|
//
|
|
// Example:
|
|
//
|
|
// data, err := client.Download(context.Background(), "/etc/hostname")
|
|
func (c *SSHClient) Download(ctx context.Context, remote string) ([]byte, error) {
|
|
if err := c.Connect(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cmd := sprintf("cat %q", remote)
|
|
|
|
stdout, stderr, exitCode, err := c.Run(ctx, cmd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if exitCode != 0 {
|
|
return nil, coreerr.E("ssh.Download", sprintf("cat failed: %s", stderr), nil)
|
|
}
|
|
|
|
return []byte(stdout), nil
|
|
}
|
|
|
|
// FileExists checks if a file exists on the remote host.
|
|
//
|
|
// Example:
|
|
//
|
|
// ok, err := client.FileExists(context.Background(), "/etc/hosts")
|
|
func (c *SSHClient) FileExists(ctx context.Context, path string) (bool, error) {
|
|
cmd := sprintf("test -e %q && echo yes || echo no", path)
|
|
stdout, _, exitCode, err := c.Run(ctx, cmd)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if exitCode != 0 {
|
|
// test command failed but didn't error - file doesn't exist
|
|
return false, nil
|
|
}
|
|
return corexTrimSpace(stdout) == "yes", nil
|
|
}
|
|
|
|
// Stat returns file info from the remote host.
|
|
//
|
|
// Example:
|
|
//
|
|
// info, err := client.Stat(context.Background(), "/etc/hosts")
|
|
func (c *SSHClient) Stat(ctx context.Context, path string) (map[string]any, error) {
|
|
// Simple approach - get basic file info
|
|
cmd := sprintf(`
|
|
if [ -e %q ]; then
|
|
if [ -d %q ]; then
|
|
echo "exists=true isdir=true"
|
|
else
|
|
echo "exists=true isdir=false"
|
|
fi
|
|
else
|
|
echo "exists=false"
|
|
fi
|
|
`, path, path)
|
|
|
|
stdout, _, _, err := c.Run(ctx, cmd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make(map[string]any)
|
|
parts := fields(corexTrimSpace(stdout))
|
|
for _, part := range parts {
|
|
kv := splitN(part, "=", 2)
|
|
if len(kv) == 2 {
|
|
result[kv[0]] = kv[1] == "true"
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// SetBecome enables privilege escalation.
|
|
//
|
|
// Example:
|
|
//
|
|
// client.SetBecome(true, "root", "")
|
|
func (c *SSHClient) SetBecome(become bool, user, password string) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.become = become
|
|
if user != "" {
|
|
c.becomeUser = user
|
|
}
|
|
if password != "" {
|
|
c.becomePass = password
|
|
}
|
|
}
|