feat: add RemoteTransport interface for SSH abstraction

Introduce RemoteTransport interface (Run, CopyFrom, CopyTo) with
SSHTransport implementation using ssh/scp binaries. AgentConfig gains
a Transport field with lazy initialization from M3 credentials.

All internal callers (DiscoverCheckpoints, processMLXNative,
processWithConversion) now use cfg.transport() instead of global
SSHCommand/SCPFrom. Old functions preserved as deprecated wrappers.

Enables mock injection for testing agent loop without real SSH.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-20 03:09:42 +00:00
parent 33939fe038
commit 1c2a6a6902
4 changed files with 159 additions and 43 deletions

View file

@ -22,6 +22,20 @@ type AgentConfig struct {
Force bool
OneShot bool
DryRun bool
// Transport is the remote transport used for SSH commands and file transfers.
// If nil, an SSHTransport is created from M3Host/M3User/M3SSHKey.
Transport RemoteTransport
}
// transport returns the configured RemoteTransport, lazily creating an
// SSHTransport from the M3 fields if none was set.
func (c *AgentConfig) transport() RemoteTransport {
if c.Transport != nil {
return c.Transport
}
c.Transport = NewSSHTransport(c.M3Host, c.M3User, c.M3SSHKey)
return c.Transport
}
// Checkpoint represents a discovered adapter checkpoint on M3.

View file

@ -88,10 +88,12 @@ func processMLXNative(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) err
localSF := filepath.Join(localAdapterDir, cp.Filename)
localCfg := filepath.Join(localAdapterDir, "adapter_config.json")
if err := SCPFrom(cfg, remoteSF, localSF); err != nil {
ctx := context.Background()
t := cfg.transport()
if err := t.CopyFrom(ctx, remoteSF, localSF); err != nil {
return fmt.Errorf("scp safetensors: %w", err)
}
if err := SCPFrom(cfg, remoteCfg, localCfg); err != nil {
if err := t.CopyFrom(ctx, remoteCfg, localCfg); err != nil {
return fmt.Errorf("scp config: %w", err)
}
@ -105,8 +107,6 @@ func processMLXNative(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) err
return fmt.Errorf("ollama create: %w", err)
}
log.Printf("Ollama model %s ready", tempModel)
ctx := context.Background()
probeBackend := NewHTTPBackend(cfg.JudgeURL, tempModel)
const baseTS int64 = 1739577600
@ -170,10 +170,12 @@ func processWithConversion(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint
remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename)
remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir)
if err := SCPFrom(cfg, remoteSF, localSF); err != nil {
ctx := context.Background()
t := cfg.transport()
if err := t.CopyFrom(ctx, remoteSF, localSF); err != nil {
return fmt.Errorf("scp safetensors: %w", err)
}
if err := SCPFrom(cfg, remoteCfg, localCfg); err != nil {
if err := t.CopyFrom(ctx, remoteCfg, localCfg); err != nil {
return fmt.Errorf("scp config: %w", err)
}
@ -184,7 +186,6 @@ func processWithConversion(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint
}
log.Println("Running 23 capability probes...")
ctx := context.Background()
modelName := cfg.Model
if modelName == "" {
modelName = cp.ModelTag

View file

@ -1,6 +1,7 @@
package ml
import (
"context"
"fmt"
"log"
"os"
@ -96,7 +97,9 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
if cfg.Filter != "" {
pattern = "adapters-" + cfg.Filter + "*"
}
out, err := SSHCommand(cfg, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern))
t := cfg.transport()
ctx := context.Background()
out, err := t.Run(ctx, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern))
if err != nil {
return nil, fmt.Errorf("list adapter dirs: %w", err)
}
@ -109,7 +112,7 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
if dirpath == "" {
continue
}
subOut, subErr := SSHCommand(cfg, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath))
subOut, subErr := t.Run(ctx, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath))
if subErr == nil && strings.TrimSpace(subOut) != "" {
for _, sub := range strings.Split(strings.TrimSpace(subOut), "\n") {
if sub != "" {
@ -124,7 +127,7 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
for _, dirpath := range adapterDirs {
dirname := strings.TrimPrefix(dirpath, cfg.M3AdapterBase+"/")
filesOut, err := SSHCommand(cfg, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath))
filesOut, err := t.Run(ctx, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath))
if err != nil {
continue
}

View file

@ -1,65 +1,163 @@
package ml
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
)
// SSHCommand executes a command on M3 via SSH.
func SSHCommand(cfg *AgentConfig, cmd string) (string, error) {
sshArgs := []string{
"-o", "ConnectTimeout=10",
// RemoteTransport abstracts remote command execution and file transfer.
// Implementations may use SSH/SCP, Docker exec, or in-memory fakes for testing.
type RemoteTransport interface {
// Run executes a command on the remote host and returns combined output.
Run(ctx context.Context, cmd string) (string, error)
// CopyFrom copies a file from the remote host to a local path.
CopyFrom(ctx context.Context, remote, local string) error
// CopyTo copies a local file to the remote host.
CopyTo(ctx context.Context, local, remote string) error
}
// SSHTransport implements RemoteTransport using the ssh and scp binaries.
type SSHTransport struct {
Host string
User string
KeyPath string
Port string
Timeout time.Duration
}
// SSHOption configures an SSHTransport.
type SSHOption func(*SSHTransport)
// WithPort sets a non-default SSH port.
func WithPort(port string) SSHOption {
return func(t *SSHTransport) {
t.Port = port
}
}
// WithTimeout sets the SSH connection timeout.
func WithTimeout(d time.Duration) SSHOption {
return func(t *SSHTransport) {
t.Timeout = d
}
}
// NewSSHTransport creates an SSHTransport with the given credentials and options.
func NewSSHTransport(host, user, keyPath string, opts ...SSHOption) *SSHTransport {
t := &SSHTransport{
Host: host,
User: user,
KeyPath: keyPath,
Port: "22",
Timeout: 10 * time.Second,
}
for _, o := range opts {
o(t)
}
return t
}
// commonArgs returns the shared SSH options for both ssh and scp.
func (t *SSHTransport) commonArgs() []string {
timeout := int(t.Timeout.Seconds())
if timeout < 1 {
timeout = 10
}
args := []string{
"-o", fmt.Sprintf("ConnectTimeout=%d", timeout),
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
"-i", cfg.M3SSHKey,
fmt.Sprintf("%s@%s", cfg.M3User, cfg.M3Host),
cmd,
"-i", t.KeyPath,
}
result, err := exec.Command("ssh", sshArgs...).CombinedOutput()
if t.Port != "" && t.Port != "22" {
args = append(args, "-P", t.Port)
}
return args
}
// sshPortArgs returns the port flag for ssh (uses -p, not -P).
func (t *SSHTransport) sshPortArgs() []string {
timeout := int(t.Timeout.Seconds())
if timeout < 1 {
timeout = 10
}
args := []string{
"-o", fmt.Sprintf("ConnectTimeout=%d", timeout),
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
"-i", t.KeyPath,
}
if t.Port != "" && t.Port != "22" {
args = append(args, "-p", t.Port)
}
return args
}
// Run executes a command on the remote host via ssh.
func (t *SSHTransport) Run(ctx context.Context, cmd string) (string, error) {
args := t.sshPortArgs()
args = append(args, fmt.Sprintf("%s@%s", t.User, t.Host), cmd)
c := exec.CommandContext(ctx, "ssh", args...)
result, err := c.CombinedOutput()
if err != nil {
return "", fmt.Errorf("ssh %q: %w: %s", cmd, err, strings.TrimSpace(string(result)))
}
return string(result), nil
}
// SCPFrom copies a file from M3 to a local path.
func SCPFrom(cfg *AgentConfig, remotePath, localPath string) error {
os.MkdirAll(filepath.Dir(localPath), 0755)
scpArgs := []string{
"-o", "ConnectTimeout=10",
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
"-i", cfg.M3SSHKey,
fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath),
localPath,
}
result, err := exec.Command("scp", scpArgs...).CombinedOutput()
// CopyFrom copies a file from the remote host to a local path via scp.
func (t *SSHTransport) CopyFrom(ctx context.Context, remote, local string) error {
os.MkdirAll(filepath.Dir(local), 0755)
args := t.commonArgs()
args = append(args, fmt.Sprintf("%s@%s:%s", t.User, t.Host, remote), local)
c := exec.CommandContext(ctx, "scp", args...)
result, err := c.CombinedOutput()
if err != nil {
return fmt.Errorf("scp %s: %w: %s", remotePath, err, strings.TrimSpace(string(result)))
return fmt.Errorf("scp %s: %w: %s", remote, err, strings.TrimSpace(string(result)))
}
return nil
}
// SCPTo copies a local file to M3.
func SCPTo(cfg *AgentConfig, localPath, remotePath string) error {
scpArgs := []string{
"-o", "ConnectTimeout=10",
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=no",
"-i", cfg.M3SSHKey,
localPath,
fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath),
}
result, err := exec.Command("scp", scpArgs...).CombinedOutput()
// CopyTo copies a local file to the remote host via scp.
func (t *SSHTransport) CopyTo(ctx context.Context, local, remote string) error {
args := t.commonArgs()
args = append(args, local, fmt.Sprintf("%s@%s:%s", t.User, t.Host, remote))
c := exec.CommandContext(ctx, "scp", args...)
result, err := c.CombinedOutput()
if err != nil {
return fmt.Errorf("scp to %s: %w: %s", remotePath, err, strings.TrimSpace(string(result)))
return fmt.Errorf("scp to %s: %w: %s", remote, err, strings.TrimSpace(string(result)))
}
return nil
}
// SSHCommand executes a command on M3 via SSH.
// Deprecated: Use AgentConfig.Transport.Run() instead.
func SSHCommand(cfg *AgentConfig, cmd string) (string, error) {
return cfg.transport().Run(context.Background(), cmd)
}
// SCPFrom copies a file from M3 to a local path.
// Deprecated: Use AgentConfig.Transport.CopyFrom() instead.
func SCPFrom(cfg *AgentConfig, remotePath, localPath string) error {
return cfg.transport().CopyFrom(context.Background(), remotePath, localPath)
}
// SCPTo copies a local file to M3.
// Deprecated: Use AgentConfig.Transport.CopyTo() instead.
func SCPTo(cfg *AgentConfig, localPath, remotePath string) error {
return cfg.transport().CopyTo(context.Background(), localPath, remotePath)
}
// fileBase returns the last component of a path.
func fileBase(path string) string {
if i := strings.LastIndexAny(path, "/\\"); i >= 0 {