feat: add RemoteTransport interface for SSH abstraction
Introduce RemoteTransport interface (Run, CopyFrom, CopyTo) with SSHTransport implementation using ssh/scp binaries. AgentConfig gains a Transport field with lazy initialization from M3 credentials. All internal callers (DiscoverCheckpoints, processMLXNative, processWithConversion) now use cfg.transport() instead of global SSHCommand/SCPFrom. Old functions preserved as deprecated wrappers. Enables mock injection for testing agent loop without real SSH. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
33939fe038
commit
1c2a6a6902
4 changed files with 159 additions and 43 deletions
|
|
@ -22,6 +22,20 @@ type AgentConfig struct {
|
|||
Force bool
|
||||
OneShot bool
|
||||
DryRun bool
|
||||
|
||||
// Transport is the remote transport used for SSH commands and file transfers.
|
||||
// If nil, an SSHTransport is created from M3Host/M3User/M3SSHKey.
|
||||
Transport RemoteTransport
|
||||
}
|
||||
|
||||
// transport returns the configured RemoteTransport, lazily creating an
|
||||
// SSHTransport from the M3 fields if none was set.
|
||||
func (c *AgentConfig) transport() RemoteTransport {
|
||||
if c.Transport != nil {
|
||||
return c.Transport
|
||||
}
|
||||
c.Transport = NewSSHTransport(c.M3Host, c.M3User, c.M3SSHKey)
|
||||
return c.Transport
|
||||
}
|
||||
|
||||
// Checkpoint represents a discovered adapter checkpoint on M3.
|
||||
|
|
|
|||
|
|
@ -88,10 +88,12 @@ func processMLXNative(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) err
|
|||
localSF := filepath.Join(localAdapterDir, cp.Filename)
|
||||
localCfg := filepath.Join(localAdapterDir, "adapter_config.json")
|
||||
|
||||
if err := SCPFrom(cfg, remoteSF, localSF); err != nil {
|
||||
ctx := context.Background()
|
||||
t := cfg.transport()
|
||||
if err := t.CopyFrom(ctx, remoteSF, localSF); err != nil {
|
||||
return fmt.Errorf("scp safetensors: %w", err)
|
||||
}
|
||||
if err := SCPFrom(cfg, remoteCfg, localCfg); err != nil {
|
||||
if err := t.CopyFrom(ctx, remoteCfg, localCfg); err != nil {
|
||||
return fmt.Errorf("scp config: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -105,8 +107,6 @@ func processMLXNative(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) err
|
|||
return fmt.Errorf("ollama create: %w", err)
|
||||
}
|
||||
log.Printf("Ollama model %s ready", tempModel)
|
||||
|
||||
ctx := context.Background()
|
||||
probeBackend := NewHTTPBackend(cfg.JudgeURL, tempModel)
|
||||
|
||||
const baseTS int64 = 1739577600
|
||||
|
|
@ -170,10 +170,12 @@ func processWithConversion(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint
|
|||
remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename)
|
||||
remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir)
|
||||
|
||||
if err := SCPFrom(cfg, remoteSF, localSF); err != nil {
|
||||
ctx := context.Background()
|
||||
t := cfg.transport()
|
||||
if err := t.CopyFrom(ctx, remoteSF, localSF); err != nil {
|
||||
return fmt.Errorf("scp safetensors: %w", err)
|
||||
}
|
||||
if err := SCPFrom(cfg, remoteCfg, localCfg); err != nil {
|
||||
if err := t.CopyFrom(ctx, remoteCfg, localCfg); err != nil {
|
||||
return fmt.Errorf("scp config: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -184,7 +186,6 @@ func processWithConversion(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint
|
|||
}
|
||||
|
||||
log.Println("Running 23 capability probes...")
|
||||
ctx := context.Background()
|
||||
modelName := cfg.Model
|
||||
if modelName == "" {
|
||||
modelName = cp.ModelTag
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
|
@ -96,7 +97,9 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
|
|||
if cfg.Filter != "" {
|
||||
pattern = "adapters-" + cfg.Filter + "*"
|
||||
}
|
||||
out, err := SSHCommand(cfg, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern))
|
||||
t := cfg.transport()
|
||||
ctx := context.Background()
|
||||
out, err := t.Run(ctx, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list adapter dirs: %w", err)
|
||||
}
|
||||
|
|
@ -109,7 +112,7 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
|
|||
if dirpath == "" {
|
||||
continue
|
||||
}
|
||||
subOut, subErr := SSHCommand(cfg, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath))
|
||||
subOut, subErr := t.Run(ctx, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath))
|
||||
if subErr == nil && strings.TrimSpace(subOut) != "" {
|
||||
for _, sub := range strings.Split(strings.TrimSpace(subOut), "\n") {
|
||||
if sub != "" {
|
||||
|
|
@ -124,7 +127,7 @@ func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
|
|||
for _, dirpath := range adapterDirs {
|
||||
dirname := strings.TrimPrefix(dirpath, cfg.M3AdapterBase+"/")
|
||||
|
||||
filesOut, err := SSHCommand(cfg, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath))
|
||||
filesOut, err := t.Run(ctx, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
|
|
|||
164
agent_ssh.go
164
agent_ssh.go
|
|
@ -1,65 +1,163 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SSHCommand executes a command on M3 via SSH.
|
||||
func SSHCommand(cfg *AgentConfig, cmd string) (string, error) {
|
||||
sshArgs := []string{
|
||||
"-o", "ConnectTimeout=10",
|
||||
// RemoteTransport abstracts remote command execution and file transfer.
|
||||
// Implementations may use SSH/SCP, Docker exec, or in-memory fakes for testing.
|
||||
type RemoteTransport interface {
|
||||
// Run executes a command on the remote host and returns combined output.
|
||||
Run(ctx context.Context, cmd string) (string, error)
|
||||
|
||||
// CopyFrom copies a file from the remote host to a local path.
|
||||
CopyFrom(ctx context.Context, remote, local string) error
|
||||
|
||||
// CopyTo copies a local file to the remote host.
|
||||
CopyTo(ctx context.Context, local, remote string) error
|
||||
}
|
||||
|
||||
// SSHTransport implements RemoteTransport using the ssh and scp binaries.
|
||||
type SSHTransport struct {
|
||||
Host string
|
||||
User string
|
||||
KeyPath string
|
||||
Port string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// SSHOption configures an SSHTransport.
|
||||
type SSHOption func(*SSHTransport)
|
||||
|
||||
// WithPort sets a non-default SSH port.
|
||||
func WithPort(port string) SSHOption {
|
||||
return func(t *SSHTransport) {
|
||||
t.Port = port
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout sets the SSH connection timeout.
|
||||
func WithTimeout(d time.Duration) SSHOption {
|
||||
return func(t *SSHTransport) {
|
||||
t.Timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
// NewSSHTransport creates an SSHTransport with the given credentials and options.
|
||||
func NewSSHTransport(host, user, keyPath string, opts ...SSHOption) *SSHTransport {
|
||||
t := &SSHTransport{
|
||||
Host: host,
|
||||
User: user,
|
||||
KeyPath: keyPath,
|
||||
Port: "22",
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// commonArgs returns the shared SSH options for both ssh and scp.
|
||||
func (t *SSHTransport) commonArgs() []string {
|
||||
timeout := int(t.Timeout.Seconds())
|
||||
if timeout < 1 {
|
||||
timeout = 10
|
||||
}
|
||||
args := []string{
|
||||
"-o", fmt.Sprintf("ConnectTimeout=%d", timeout),
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", cfg.M3SSHKey,
|
||||
fmt.Sprintf("%s@%s", cfg.M3User, cfg.M3Host),
|
||||
cmd,
|
||||
"-i", t.KeyPath,
|
||||
}
|
||||
result, err := exec.Command("ssh", sshArgs...).CombinedOutput()
|
||||
if t.Port != "" && t.Port != "22" {
|
||||
args = append(args, "-P", t.Port)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// sshPortArgs returns the port flag for ssh (uses -p, not -P).
|
||||
func (t *SSHTransport) sshPortArgs() []string {
|
||||
timeout := int(t.Timeout.Seconds())
|
||||
if timeout < 1 {
|
||||
timeout = 10
|
||||
}
|
||||
args := []string{
|
||||
"-o", fmt.Sprintf("ConnectTimeout=%d", timeout),
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", t.KeyPath,
|
||||
}
|
||||
if t.Port != "" && t.Port != "22" {
|
||||
args = append(args, "-p", t.Port)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// Run executes a command on the remote host via ssh.
|
||||
func (t *SSHTransport) Run(ctx context.Context, cmd string) (string, error) {
|
||||
args := t.sshPortArgs()
|
||||
args = append(args, fmt.Sprintf("%s@%s", t.User, t.Host), cmd)
|
||||
|
||||
c := exec.CommandContext(ctx, "ssh", args...)
|
||||
result, err := c.CombinedOutput()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ssh %q: %w: %s", cmd, err, strings.TrimSpace(string(result)))
|
||||
}
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
// SCPFrom copies a file from M3 to a local path.
|
||||
func SCPFrom(cfg *AgentConfig, remotePath, localPath string) error {
|
||||
os.MkdirAll(filepath.Dir(localPath), 0755)
|
||||
scpArgs := []string{
|
||||
"-o", "ConnectTimeout=10",
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", cfg.M3SSHKey,
|
||||
fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath),
|
||||
localPath,
|
||||
}
|
||||
result, err := exec.Command("scp", scpArgs...).CombinedOutput()
|
||||
// CopyFrom copies a file from the remote host to a local path via scp.
|
||||
func (t *SSHTransport) CopyFrom(ctx context.Context, remote, local string) error {
|
||||
os.MkdirAll(filepath.Dir(local), 0755)
|
||||
args := t.commonArgs()
|
||||
args = append(args, fmt.Sprintf("%s@%s:%s", t.User, t.Host, remote), local)
|
||||
|
||||
c := exec.CommandContext(ctx, "scp", args...)
|
||||
result, err := c.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scp %s: %w: %s", remotePath, err, strings.TrimSpace(string(result)))
|
||||
return fmt.Errorf("scp %s: %w: %s", remote, err, strings.TrimSpace(string(result)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SCPTo copies a local file to M3.
|
||||
func SCPTo(cfg *AgentConfig, localPath, remotePath string) error {
|
||||
scpArgs := []string{
|
||||
"-o", "ConnectTimeout=10",
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", cfg.M3SSHKey,
|
||||
localPath,
|
||||
fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath),
|
||||
}
|
||||
result, err := exec.Command("scp", scpArgs...).CombinedOutput()
|
||||
// CopyTo copies a local file to the remote host via scp.
|
||||
func (t *SSHTransport) CopyTo(ctx context.Context, local, remote string) error {
|
||||
args := t.commonArgs()
|
||||
args = append(args, local, fmt.Sprintf("%s@%s:%s", t.User, t.Host, remote))
|
||||
|
||||
c := exec.CommandContext(ctx, "scp", args...)
|
||||
result, err := c.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scp to %s: %w: %s", remotePath, err, strings.TrimSpace(string(result)))
|
||||
return fmt.Errorf("scp to %s: %w: %s", remote, err, strings.TrimSpace(string(result)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SSHCommand executes a command on M3 via SSH.
|
||||
// Deprecated: Use AgentConfig.Transport.Run() instead.
|
||||
func SSHCommand(cfg *AgentConfig, cmd string) (string, error) {
|
||||
return cfg.transport().Run(context.Background(), cmd)
|
||||
}
|
||||
|
||||
// SCPFrom copies a file from M3 to a local path.
|
||||
// Deprecated: Use AgentConfig.Transport.CopyFrom() instead.
|
||||
func SCPFrom(cfg *AgentConfig, remotePath, localPath string) error {
|
||||
return cfg.transport().CopyFrom(context.Background(), remotePath, localPath)
|
||||
}
|
||||
|
||||
// SCPTo copies a local file to M3.
|
||||
// Deprecated: Use AgentConfig.Transport.CopyTo() instead.
|
||||
func SCPTo(cfg *AgentConfig, localPath, remotePath string) error {
|
||||
return cfg.transport().CopyTo(context.Background(), localPath, remotePath)
|
||||
}
|
||||
|
||||
// fileBase returns the last component of a path.
|
||||
func fileBase(path string) string {
|
||||
if i := strings.LastIndexAny(path, "/\\"); i >= 0 {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue