go-ml/agent_ssh.go
Snider 1c2a6a6902 feat: add RemoteTransport interface for SSH abstraction
Introduce RemoteTransport interface (Run, CopyFrom, CopyTo) with
SSHTransport implementation using ssh/scp binaries. AgentConfig gains
a Transport field with lazy initialization from M3 credentials.

All internal callers (DiscoverCheckpoints, processMLXNative,
processWithConversion) now use cfg.transport() instead of global
SSHCommand/SCPFrom. Old functions preserved as deprecated wrappers.

Enables mock injection for testing agent loop without real SSH.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-20 03:09:42 +00:00

200 lines
5.2 KiB
Go

package ml
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
)
// RemoteTransport abstracts remote command execution and file transfer.
// Implementations may use SSH/SCP, Docker exec, or in-memory fakes for testing.
type RemoteTransport interface {
// Run executes a command on the remote host and returns combined output.
Run(ctx context.Context, cmd string) (string, error)
// CopyFrom copies a file from the remote host to a local path.
CopyFrom(ctx context.Context, remote, local string) error
// CopyTo copies a local file to the remote host.
CopyTo(ctx context.Context, local, remote string) error
}
// SSHTransport implements RemoteTransport using the ssh and scp binaries.
type SSHTransport struct {
Host string
User string
KeyPath string
Port string
Timeout time.Duration
}
// SSHOption configures an SSHTransport.
type SSHOption func(*SSHTransport)
// WithPort sets a non-default SSH port.
func WithPort(port string) SSHOption {
return func(t *SSHTransport) {
t.Port = port
}
}
// WithTimeout sets the SSH connection timeout.
func WithTimeout(d time.Duration) SSHOption {
return func(t *SSHTransport) {
t.Timeout = d
}
}
// NewSSHTransport creates an SSHTransport with the given credentials and options.
func NewSSHTransport(host, user, keyPath string, opts ...SSHOption) *SSHTransport {
t := &SSHTransport{
Host: host,
User: user,
KeyPath: keyPath,
Port: "22",
Timeout: 10 * time.Second,
}
for _, o := range opts {
o(t)
}
return t
}
// commonArgs returns the shared SSH options for both ssh and scp.
func (t *SSHTransport) commonArgs() []string {
timeout := int(t.Timeout.Seconds())
if timeout < 1 {
timeout = 10
}
args := []string{
"-o", fmt.Sprintf("ConnectTimeout=%d", timeout),
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
"-i", t.KeyPath,
}
if t.Port != "" && t.Port != "22" {
args = append(args, "-P", t.Port)
}
return args
}
// sshPortArgs returns the port flag for ssh (uses -p, not -P).
func (t *SSHTransport) sshPortArgs() []string {
timeout := int(t.Timeout.Seconds())
if timeout < 1 {
timeout = 10
}
args := []string{
"-o", fmt.Sprintf("ConnectTimeout=%d", timeout),
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
"-i", t.KeyPath,
}
if t.Port != "" && t.Port != "22" {
args = append(args, "-p", t.Port)
}
return args
}
// Run executes a command on the remote host via ssh.
func (t *SSHTransport) Run(ctx context.Context, cmd string) (string, error) {
args := t.sshPortArgs()
args = append(args, fmt.Sprintf("%s@%s", t.User, t.Host), cmd)
c := exec.CommandContext(ctx, "ssh", args...)
result, err := c.CombinedOutput()
if err != nil {
return "", fmt.Errorf("ssh %q: %w: %s", cmd, err, strings.TrimSpace(string(result)))
}
return string(result), nil
}
// CopyFrom copies a file from the remote host to a local path via scp.
func (t *SSHTransport) CopyFrom(ctx context.Context, remote, local string) error {
os.MkdirAll(filepath.Dir(local), 0755)
args := t.commonArgs()
args = append(args, fmt.Sprintf("%s@%s:%s", t.User, t.Host, remote), local)
c := exec.CommandContext(ctx, "scp", args...)
result, err := c.CombinedOutput()
if err != nil {
return fmt.Errorf("scp %s: %w: %s", remote, err, strings.TrimSpace(string(result)))
}
return nil
}
// CopyTo copies a local file to the remote host via scp.
func (t *SSHTransport) CopyTo(ctx context.Context, local, remote string) error {
args := t.commonArgs()
args = append(args, local, fmt.Sprintf("%s@%s:%s", t.User, t.Host, remote))
c := exec.CommandContext(ctx, "scp", args...)
result, err := c.CombinedOutput()
if err != nil {
return fmt.Errorf("scp to %s: %w: %s", remote, err, strings.TrimSpace(string(result)))
}
return nil
}
// SSHCommand executes a command on M3 via SSH.
// Deprecated: Use AgentConfig.Transport.Run() instead.
func SSHCommand(cfg *AgentConfig, cmd string) (string, error) {
return cfg.transport().Run(context.Background(), cmd)
}
// SCPFrom copies a file from M3 to a local path.
// Deprecated: Use AgentConfig.Transport.CopyFrom() instead.
func SCPFrom(cfg *AgentConfig, remotePath, localPath string) error {
return cfg.transport().CopyFrom(context.Background(), remotePath, localPath)
}
// SCPTo copies a local file to M3.
// Deprecated: Use AgentConfig.Transport.CopyTo() instead.
func SCPTo(cfg *AgentConfig, localPath, remotePath string) error {
return cfg.transport().CopyTo(context.Background(), localPath, remotePath)
}
// fileBase returns the last component of a path.
func fileBase(path string) string {
if i := strings.LastIndexAny(path, "/\\"); i >= 0 {
return path[i+1:]
}
return path
}
// EnvOr returns the environment variable value or a fallback.
func EnvOr(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
// IntEnvOr returns the integer environment variable value or a fallback.
func IntEnvOr(key string, fallback int) int {
v := os.Getenv(key)
if v == "" {
return fallback
}
var n int
fmt.Sscanf(v, "%d", &n)
if n == 0 {
return fallback
}
return n
}
// ExpandHome expands ~ to the user's home directory.
func ExpandHome(path string) string {
if strings.HasPrefix(path, "~/") {
home, err := os.UserHomeDir()
if err == nil {
return filepath.Join(home, path[2:])
}
}
return path
}