From 1c2a6a6902c689155d9eabd52521ee7e1b1c598e Mon Sep 17 00:00:00 2001 From: Snider Date: Fri, 20 Feb 2026 03:09:42 +0000 Subject: [PATCH] feat: add RemoteTransport interface for SSH abstraction Introduce RemoteTransport interface (Run, CopyFrom, CopyTo) with SSHTransport implementation using ssh/scp binaries. AgentConfig gains a Transport field with lazy initialization from M3 credentials. All internal callers (DiscoverCheckpoints, processMLXNative, processWithConversion) now use cfg.transport() instead of global SSHCommand/SCPFrom. Old functions preserved as deprecated wrappers. Enables mock injection for testing agent loop without real SSH. Co-Authored-By: Virgil --- agent_config.go | 14 ++++ agent_eval.go | 15 +++-- agent_execute.go | 9 ++- agent_ssh.go | 164 +++++++++++++++++++++++++++++++++++++---------- 4 files changed, 159 insertions(+), 43 deletions(-) diff --git a/agent_config.go b/agent_config.go index 342e888..a34c278 100644 --- a/agent_config.go +++ b/agent_config.go @@ -22,6 +22,20 @@ type AgentConfig struct { Force bool OneShot bool DryRun bool + + // Transport is the remote transport used for SSH commands and file transfers. + // If nil, an SSHTransport is created from M3Host/M3User/M3SSHKey. + Transport RemoteTransport +} + +// transport returns the configured RemoteTransport, lazily creating an +// SSHTransport from the M3 fields if none was set. +func (c *AgentConfig) transport() RemoteTransport { + if c.Transport != nil { + return c.Transport + } + c.Transport = NewSSHTransport(c.M3Host, c.M3User, c.M3SSHKey) + return c.Transport } // Checkpoint represents a discovered adapter checkpoint on M3. diff --git a/agent_eval.go b/agent_eval.go index 8267978..946dc8d 100644 --- a/agent_eval.go +++ b/agent_eval.go @@ -88,10 +88,12 @@ func processMLXNative(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) err localSF := filepath.Join(localAdapterDir, cp.Filename) localCfg := filepath.Join(localAdapterDir, "adapter_config.json") - if err := SCPFrom(cfg, remoteSF, localSF); err != nil { + ctx := context.Background() + t := cfg.transport() + if err := t.CopyFrom(ctx, remoteSF, localSF); err != nil { return fmt.Errorf("scp safetensors: %w", err) } - if err := SCPFrom(cfg, remoteCfg, localCfg); err != nil { + if err := t.CopyFrom(ctx, remoteCfg, localCfg); err != nil { return fmt.Errorf("scp config: %w", err) } @@ -105,8 +107,6 @@ func processMLXNative(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) err return fmt.Errorf("ollama create: %w", err) } log.Printf("Ollama model %s ready", tempModel) - - ctx := context.Background() probeBackend := NewHTTPBackend(cfg.JudgeURL, tempModel) const baseTS int64 = 1739577600 @@ -170,10 +170,12 @@ func processWithConversion(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename) remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir) - if err := SCPFrom(cfg, remoteSF, localSF); err != nil { + ctx := context.Background() + t := cfg.transport() + if err := t.CopyFrom(ctx, remoteSF, localSF); err != nil { return fmt.Errorf("scp safetensors: %w", err) } - if err := SCPFrom(cfg, remoteCfg, localCfg); err != nil { + if err := t.CopyFrom(ctx, remoteCfg, localCfg); err != nil { return fmt.Errorf("scp config: %w", err) } @@ -184,7 +186,6 @@ func processWithConversion(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint } log.Println("Running 23 capability probes...") - ctx := context.Background() modelName := cfg.Model if modelName == "" { modelName = cp.ModelTag diff --git a/agent_execute.go b/agent_execute.go index 6157f8e..42b2f02 100644 --- a/agent_execute.go +++ b/agent_execute.go @@ -1,6 +1,7 @@ package ml import ( + "context" "fmt" "log" "os" @@ -96,7 +97,9 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) { if cfg.Filter != "" { pattern = "adapters-" + cfg.Filter + "*" } - out, err := SSHCommand(cfg, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern)) + t := cfg.transport() + ctx := context.Background() + out, err := t.Run(ctx, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern)) if err != nil { return nil, fmt.Errorf("list adapter dirs: %w", err) } @@ -109,7 +112,7 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) { if dirpath == "" { continue } - subOut, subErr := SSHCommand(cfg, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath)) + subOut, subErr := t.Run(ctx, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath)) if subErr == nil && strings.TrimSpace(subOut) != "" { for _, sub := range strings.Split(strings.TrimSpace(subOut), "\n") { if sub != "" { @@ -124,7 +127,7 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) { for _, dirpath := range adapterDirs { dirname := strings.TrimPrefix(dirpath, cfg.M3AdapterBase+"/") - filesOut, err := SSHCommand(cfg, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath)) + filesOut, err := t.Run(ctx, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath)) if err != nil { continue } diff --git a/agent_ssh.go b/agent_ssh.go index 770b769..634f594 100644 --- a/agent_ssh.go +++ b/agent_ssh.go @@ -1,65 +1,163 @@ package ml import ( + "context" "fmt" "os" "os/exec" "path/filepath" "strings" + "time" ) -// SSHCommand executes a command on M3 via SSH. -func SSHCommand(cfg *AgentConfig, cmd string) (string, error) { - sshArgs := []string{ - "-o", "ConnectTimeout=10", +// RemoteTransport abstracts remote command execution and file transfer. +// Implementations may use SSH/SCP, Docker exec, or in-memory fakes for testing. +type RemoteTransport interface { + // Run executes a command on the remote host and returns combined output. + Run(ctx context.Context, cmd string) (string, error) + + // CopyFrom copies a file from the remote host to a local path. + CopyFrom(ctx context.Context, remote, local string) error + + // CopyTo copies a local file to the remote host. + CopyTo(ctx context.Context, local, remote string) error +} + +// SSHTransport implements RemoteTransport using the ssh and scp binaries. +type SSHTransport struct { + Host string + User string + KeyPath string + Port string + Timeout time.Duration +} + +// SSHOption configures an SSHTransport. +type SSHOption func(*SSHTransport) + +// WithPort sets a non-default SSH port. +func WithPort(port string) SSHOption { + return func(t *SSHTransport) { + t.Port = port + } +} + +// WithTimeout sets the SSH connection timeout. +func WithTimeout(d time.Duration) SSHOption { + return func(t *SSHTransport) { + t.Timeout = d + } +} + +// NewSSHTransport creates an SSHTransport with the given credentials and options. +func NewSSHTransport(host, user, keyPath string, opts ...SSHOption) *SSHTransport { + t := &SSHTransport{ + Host: host, + User: user, + KeyPath: keyPath, + Port: "22", + Timeout: 10 * time.Second, + } + for _, o := range opts { + o(t) + } + return t +} + +// commonArgs returns the shared SSH options for both ssh and scp. +func (t *SSHTransport) commonArgs() []string { + timeout := int(t.Timeout.Seconds()) + if timeout < 1 { + timeout = 10 + } + args := []string{ + "-o", fmt.Sprintf("ConnectTimeout=%d", timeout), "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no", - "-i", cfg.M3SSHKey, - fmt.Sprintf("%s@%s", cfg.M3User, cfg.M3Host), - cmd, + "-i", t.KeyPath, } - result, err := exec.Command("ssh", sshArgs...).CombinedOutput() + if t.Port != "" && t.Port != "22" { + args = append(args, "-P", t.Port) + } + return args +} + +// sshPortArgs returns the port flag for ssh (uses -p, not -P). +func (t *SSHTransport) sshPortArgs() []string { + timeout := int(t.Timeout.Seconds()) + if timeout < 1 { + timeout = 10 + } + args := []string{ + "-o", fmt.Sprintf("ConnectTimeout=%d", timeout), + "-o", "BatchMode=yes", + "-o", "StrictHostKeyChecking=no", + "-i", t.KeyPath, + } + if t.Port != "" && t.Port != "22" { + args = append(args, "-p", t.Port) + } + return args +} + +// Run executes a command on the remote host via ssh. +func (t *SSHTransport) Run(ctx context.Context, cmd string) (string, error) { + args := t.sshPortArgs() + args = append(args, fmt.Sprintf("%s@%s", t.User, t.Host), cmd) + + c := exec.CommandContext(ctx, "ssh", args...) + result, err := c.CombinedOutput() if err != nil { return "", fmt.Errorf("ssh %q: %w: %s", cmd, err, strings.TrimSpace(string(result))) } return string(result), nil } -// SCPFrom copies a file from M3 to a local path. -func SCPFrom(cfg *AgentConfig, remotePath, localPath string) error { - os.MkdirAll(filepath.Dir(localPath), 0755) - scpArgs := []string{ - "-o", "ConnectTimeout=10", - "-o", "BatchMode=yes", - "-o", "StrictHostKeyChecking=no", - "-i", cfg.M3SSHKey, - fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath), - localPath, - } - result, err := exec.Command("scp", scpArgs...).CombinedOutput() +// CopyFrom copies a file from the remote host to a local path via scp. +func (t *SSHTransport) CopyFrom(ctx context.Context, remote, local string) error { + os.MkdirAll(filepath.Dir(local), 0755) + args := t.commonArgs() + args = append(args, fmt.Sprintf("%s@%s:%s", t.User, t.Host, remote), local) + + c := exec.CommandContext(ctx, "scp", args...) + result, err := c.CombinedOutput() if err != nil { - return fmt.Errorf("scp %s: %w: %s", remotePath, err, strings.TrimSpace(string(result))) + return fmt.Errorf("scp %s: %w: %s", remote, err, strings.TrimSpace(string(result))) } return nil } -// SCPTo copies a local file to M3. -func SCPTo(cfg *AgentConfig, localPath, remotePath string) error { - scpArgs := []string{ - "-o", "ConnectTimeout=10", - "-o", "BatchMode=yes", - "-o", "StrictHostKeyChecking=no", - "-i", cfg.M3SSHKey, - localPath, - fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath), - } - result, err := exec.Command("scp", scpArgs...).CombinedOutput() +// CopyTo copies a local file to the remote host via scp. +func (t *SSHTransport) CopyTo(ctx context.Context, local, remote string) error { + args := t.commonArgs() + args = append(args, local, fmt.Sprintf("%s@%s:%s", t.User, t.Host, remote)) + + c := exec.CommandContext(ctx, "scp", args...) + result, err := c.CombinedOutput() if err != nil { - return fmt.Errorf("scp to %s: %w: %s", remotePath, err, strings.TrimSpace(string(result))) + return fmt.Errorf("scp to %s: %w: %s", remote, err, strings.TrimSpace(string(result))) } return nil } +// SSHCommand executes a command on M3 via SSH. +// Deprecated: Use AgentConfig.Transport.Run() instead. +func SSHCommand(cfg *AgentConfig, cmd string) (string, error) { + return cfg.transport().Run(context.Background(), cmd) +} + +// SCPFrom copies a file from M3 to a local path. +// Deprecated: Use AgentConfig.Transport.CopyFrom() instead. +func SCPFrom(cfg *AgentConfig, remotePath, localPath string) error { + return cfg.transport().CopyFrom(context.Background(), remotePath, localPath) +} + +// SCPTo copies a local file to M3. +// Deprecated: Use AgentConfig.Transport.CopyTo() instead. +func SCPTo(cfg *AgentConfig, localPath, remotePath string) error { + return cfg.transport().CopyTo(context.Background(), localPath, remotePath) +} + // fileBase returns the last component of a path. func fileBase(path string) string { if i := strings.LastIndexAny(path, "/\\"); i >= 0 {