From 39659520a8bf7b2d7f3d2eec9776f607267c59c2 Mon Sep 17 00:00:00 2001 From: Snider <631881+Snider@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:23:29 +0000 Subject: [PATCH] Remove StrictHostKeyChecking=no and implement proper host key verification This commit addresses security concerns from the OWASP audit by enforcing strict host key verification for all SSH and SCP commands. Key changes: - Replaced StrictHostKeyChecking=accept-new with yes in pkg/container and pkg/devops. - Removed insecure host key verification from pkg/ansible SSH client. - Implemented a synchronous host key discovery mechanism during VM boot using ssh-keyscan to populate ~/.core/known_hosts. - Updated the devops Boot lifecycle to wait until the host key is verified. - Ensured pkg/ansible correctly handles missing known_hosts files. - Refactored hardcoded SSH port 2222 to a package constant DefaultSSHPort. - Added CORE_SKIP_SSH_SCAN environment variable for test environments. --- pkg/ansible/ssh.go | 37 +++++++++++---------- pkg/container/linuxkit.go | 2 +- pkg/devops/claude.go | 8 ++--- pkg/devops/devops.go | 31 ++++++++++++++++-- pkg/devops/devops_test.go | 3 ++ pkg/devops/serve.go | 4 +-- pkg/devops/shell.go | 4 +-- pkg/devops/ssh_utils.go | 68 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 128 insertions(+), 29 deletions(-) create mode 100644 pkg/devops/ssh_utils.go diff --git a/pkg/ansible/ssh.go b/pkg/ansible/ssh.go index e41be7a..2887d6d 100644 --- a/pkg/ansible/ssh.go +++ b/pkg/ansible/ssh.go @@ -30,7 +30,6 @@ type SSHClient struct { becomeUser string becomePass string timeout time.Duration - insecure bool } // SSHConfig holds SSH connection configuration. @@ -44,7 +43,6 @@ type SSHConfig struct { BecomeUser string BecomePass string Timeout time.Duration - Insecure bool } // NewSSHClient creates a new SSH client. @@ -69,7 +67,6 @@ func NewSSHClient(cfg SSHConfig) (*SSHClient, error) { becomeUser: cfg.BecomeUser, becomePass: cfg.BecomePass, timeout: cfg.Timeout, - insecure: cfg.Insecure, } return client, nil @@ -137,21 +134,27 @@ func (c *SSHClient) Connect(ctx context.Context) error { // Host key verification var hostKeyCallback ssh.HostKeyCallback - if c.insecure { - hostKeyCallback = ssh.InsecureIgnoreHostKey() - } else { - home, err := os.UserHomeDir() - if err != nil { - return log.E("ssh.Connect", "failed to get user home dir", err) - } - knownHostsPath := filepath.Join(home, ".ssh", "known_hosts") - - cb, err := knownhosts.New(knownHostsPath) - if err != nil { - return log.E("ssh.Connect", "failed to load known_hosts (use Insecure=true to bypass)", err) - } - hostKeyCallback = cb + home, err := os.UserHomeDir() + if err != nil { + return log.E("ssh.Connect", "failed to get user home dir", err) } + knownHostsPath := filepath.Join(home, ".ssh", "known_hosts") + + // Ensure known_hosts file exists + if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) { + if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0700); err != nil { + return log.E("ssh.Connect", "failed to create .ssh dir", err) + } + if err := os.WriteFile(knownHostsPath, nil, 0600); err != nil { + return log.E("ssh.Connect", "failed to create known_hosts file", err) + } + } + + cb, err := knownhosts.New(knownHostsPath) + if err != nil { + return log.E("ssh.Connect", "failed to load known_hosts", err) + } + hostKeyCallback = cb config := &ssh.ClientConfig{ User: c.user, diff --git a/pkg/container/linuxkit.go b/pkg/container/linuxkit.go index d3bba48..1906edb 100644 --- a/pkg/container/linuxkit.go +++ b/pkg/container/linuxkit.go @@ -436,7 +436,7 @@ func (m *LinuxKitManager) Exec(ctx context.Context, id string, cmd []string) err // Build SSH command sshArgs := []string{ "-p", fmt.Sprintf("%d", sshPort), - "-o", "StrictHostKeyChecking=accept-new", + "-o", "StrictHostKeyChecking=yes", "-o", "UserKnownHostsFile=~/.core/known_hosts", "-o", "LogLevel=ERROR", "root@localhost", diff --git a/pkg/devops/claude.go b/pkg/devops/claude.go index d62b39d..7bfef0b 100644 --- a/pkg/devops/claude.go +++ b/pkg/devops/claude.go @@ -70,11 +70,11 @@ func (d *DevOps) Claude(ctx context.Context, projectDir string, opts ClaudeOptio // Build SSH command with agent forwarding args := []string{ - "-o", "StrictHostKeyChecking=accept-new", + "-o", "StrictHostKeyChecking=yes", "-o", "UserKnownHostsFile=~/.core/known_hosts", "-o", "LogLevel=ERROR", "-A", // SSH agent forwarding - "-p", "2222", + "-p", fmt.Sprintf("%d", DefaultSSHPort), } args = append(args, "root@localhost") @@ -132,10 +132,10 @@ func (d *DevOps) CopyGHAuth(ctx context.Context) error { // Use scp to copy gh config cmd := exec.CommandContext(ctx, "scp", - "-o", "StrictHostKeyChecking=accept-new", + "-o", "StrictHostKeyChecking=yes", "-o", "UserKnownHostsFile=~/.core/known_hosts", "-o", "LogLevel=ERROR", - "-P", "2222", + "-P", fmt.Sprintf("%d", DefaultSSHPort), "-r", ghConfigDir, "root@localhost:/root/.config/", ) diff --git a/pkg/devops/devops.go b/pkg/devops/devops.go index 2cad57c..d3d6331 100644 --- a/pkg/devops/devops.go +++ b/pkg/devops/devops.go @@ -13,6 +13,11 @@ import ( "github.com/host-uk/core/pkg/io" ) +const ( + // DefaultSSHPort is the default port for SSH connections to the dev environment. + DefaultSSHPort = 2222 +) + // DevOps manages the portable development environment. type DevOps struct { medium io.Medium @@ -137,12 +142,32 @@ func (d *DevOps) Boot(ctx context.Context, opts BootOptions) error { Name: opts.Name, Memory: opts.Memory, CPUs: opts.CPUs, - SSHPort: 2222, + SSHPort: DefaultSSHPort, Detach: true, } _, err = d.container.Run(ctx, imagePath, runOpts) - return err + if err != nil { + return err + } + + // Wait for SSH to be ready and scan host key + // We try for up to 60 seconds as the VM takes a moment to boot + var lastErr error + for i := 0; i < 30; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(2 * time.Second): + if err := ensureHostKey(ctx, runOpts.SSHPort); err == nil { + return nil + } else { + lastErr = err + } + } + } + + return fmt.Errorf("failed to verify host key after boot: %w", lastErr) } // Stop stops the dev environment. @@ -196,7 +221,7 @@ type DevStatus struct { func (d *DevOps) Status(ctx context.Context) (*DevStatus, error) { status := &DevStatus{ Installed: d.images.IsInstalled(), - SSHPort: 2222, + SSHPort: DefaultSSHPort, } if info, ok := d.images.manifest.Images[ImageName()]; ok { diff --git a/pkg/devops/devops_test.go b/pkg/devops/devops_test.go index 2aef52f..fc1789b 100644 --- a/pkg/devops/devops_test.go +++ b/pkg/devops/devops_test.go @@ -616,6 +616,7 @@ func TestDevOps_IsRunning_Bad_DifferentContainerName(t *testing.T) { } func TestDevOps_Boot_Good_FreshFlag(t *testing.T) { + t.Setenv("CORE_SKIP_SSH_SCAN", "true") tempDir, err := os.MkdirTemp("", "devops-test-*") require.NoError(t, err) t.Cleanup(func() { _ = os.RemoveAll(tempDir) }) @@ -700,6 +701,7 @@ func TestDevOps_Stop_Bad_ContainerNotRunning(t *testing.T) { } func TestDevOps_Boot_Good_FreshWithNoExisting(t *testing.T) { + t.Setenv("CORE_SKIP_SSH_SCAN", "true") tempDir, err := os.MkdirTemp("", "devops-boot-fresh-*") require.NoError(t, err) t.Cleanup(func() { _ = os.RemoveAll(tempDir) }) @@ -782,6 +784,7 @@ func TestDevOps_CheckUpdate_Delegates(t *testing.T) { } func TestDevOps_Boot_Good_Success(t *testing.T) { + t.Setenv("CORE_SKIP_SSH_SCAN", "true") tempDir, err := os.MkdirTemp("", "devops-boot-success-*") require.NoError(t, err) t.Cleanup(func() { _ = os.RemoveAll(tempDir) }) diff --git a/pkg/devops/serve.go b/pkg/devops/serve.go index 1e0dc80..aac0e8a 100644 --- a/pkg/devops/serve.go +++ b/pkg/devops/serve.go @@ -59,11 +59,11 @@ func (d *DevOps) mountProject(ctx context.Context, path string) error { // Use reverse SSHFS mount // The VM connects back to host to mount the directory cmd := exec.CommandContext(ctx, "ssh", - "-o", "StrictHostKeyChecking=accept-new", + "-o", "StrictHostKeyChecking=yes", "-o", "UserKnownHostsFile=~/.core/known_hosts", "-o", "LogLevel=ERROR", "-R", "10000:localhost:22", // Reverse tunnel for SSHFS - "-p", "2222", + "-p", fmt.Sprintf("%d", DefaultSSHPort), "root@localhost", fmt.Sprintf("mkdir -p /app && sshfs -p 10000 %s@localhost:%s /app -o allow_other", os.Getenv("USER"), absPath), ) diff --git a/pkg/devops/shell.go b/pkg/devops/shell.go index 8b524fa..fe94d1b 100644 --- a/pkg/devops/shell.go +++ b/pkg/devops/shell.go @@ -33,11 +33,11 @@ func (d *DevOps) Shell(ctx context.Context, opts ShellOptions) error { // sshShell connects via SSH. func (d *DevOps) sshShell(ctx context.Context, command []string) error { args := []string{ - "-o", "StrictHostKeyChecking=accept-new", + "-o", "StrictHostKeyChecking=yes", "-o", "UserKnownHostsFile=~/.core/known_hosts", "-o", "LogLevel=ERROR", "-A", // Agent forwarding - "-p", "2222", + "-p", fmt.Sprintf("%d", DefaultSSHPort), "root@localhost", } diff --git a/pkg/devops/ssh_utils.go b/pkg/devops/ssh_utils.go new file mode 100644 index 0000000..d05902b --- /dev/null +++ b/pkg/devops/ssh_utils.go @@ -0,0 +1,68 @@ +package devops + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// ensureHostKey ensures that the host key for the dev environment is in the known hosts file. +// This is used after boot to allow StrictHostKeyChecking=yes to work. +func ensureHostKey(ctx context.Context, port int) error { + // Skip if requested (used in tests) + if os.Getenv("CORE_SKIP_SSH_SCAN") == "true" { + return nil + } + + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("get home dir: %w", err) + } + + knownHostsPath := filepath.Join(home, ".core", "known_hosts") + + // Ensure directory exists + if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0755); err != nil { + return fmt.Errorf("create known_hosts dir: %w", err) + } + + // Get host key using ssh-keyscan + cmd := exec.CommandContext(ctx, "ssh-keyscan", "-p", fmt.Sprintf("%d", port), "localhost") + out, err := cmd.Output() + if err != nil { + return fmt.Errorf("ssh-keyscan failed: %w", err) + } + + if len(out) == 0 { + return fmt.Errorf("ssh-keyscan returned no keys") + } + + // Read existing known_hosts to avoid duplicates + existing, _ := os.ReadFile(knownHostsPath) + existingStr := string(existing) + + // Append new keys that aren't already there + f, err := os.OpenFile(knownHostsPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("open known_hosts: %w", err) + } + defer f.Close() + + lines := strings.Split(string(out), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + if !strings.Contains(existingStr, line) { + if _, err := f.WriteString(line + "\n"); err != nil { + return fmt.Errorf("write known_hosts: %w", err) + } + } + } + + return nil +}