diff --git a/modules.go b/modules.go index b2f7b03..fc56387 100644 --- a/modules.go +++ b/modules.go @@ -3663,6 +3663,8 @@ func (e *Executor) moduleAuthorizedKey(ctx context.Context, client sshExecutorCl exclusive := getBoolArg(args, "exclusive", false) manageDir := getBoolArg(args, "manage_dir", true) pathArg := getStringArg(args, "path", "") + keyOptions := getStringArg(args, "key_options", "") + comment := getStringArg(args, "comment", "") if user == "" || key == "" { return nil, coreerr.E("Executor.moduleAuthorizedKey", "user and key required", nil) @@ -3693,19 +3695,36 @@ func (e *Executor) moduleAuthorizedKey(ctx context.Context, client sshExecutorCl authKeysPath = joinPath(home, ".ssh", "authorized_keys") } + line := authorizedKeyLine(key, keyOptions, comment) + base := authorizedKeyBase(line) + if state == "absent" { - if content, ok := remoteFileText(ctx, client, authKeysPath); !ok || !fileContainsExactLine(content, key) { + content, ok := remoteFileText(ctx, client, authKeysPath) + if !ok || !authorizedKeyContainsBase(content, base) { return &TaskResult{Changed: false}, nil } - // Remove the exact key line when present. - cmd := sprintf("if [ -f %q ]; then sed -i '\\|^%s$|d' %q; fi", - authKeysPath, sedExactLinePattern(key), authKeysPath) - _, _, _, _ = client.Run(ctx, cmd) + + updated, changed := rewriteAuthorizedKeyContent(content, base, "") + if !changed { + return &TaskResult{Changed: false}, nil + } + if err := client.Upload(ctx, newReader(updated), authKeysPath, 0600); err != nil { + return nil, coreerr.E("Executor.moduleAuthorizedKey", "upload authorised keys", err) + } return &TaskResult{Changed: true}, nil } - if content, ok := remoteFileText(ctx, client, authKeysPath); ok && fileContainsExactLine(content, key) { - return &TaskResult{Changed: false, Msg: sprintf("already up to date: %s", authKeysPath)}, nil + if content, ok := remoteFileText(ctx, client, authKeysPath); ok { + updated, changed := rewriteAuthorizedKeyContent(content, base, line) + if !changed { + return &TaskResult{Changed: false, Msg: sprintf("already up to date: %s", authKeysPath)}, nil + } + if err := client.Upload(ctx, newReader(updated), authKeysPath, 0600); err != nil { + return nil, coreerr.E("Executor.moduleAuthorizedKey", "upload authorised keys", err) + } + _, _, _, _ = client.Run(ctx, sprintf("chmod 600 %q && chown %s:%s %q", + authKeysPath, user, user, authKeysPath)) + return &TaskResult{Changed: true}, nil } if manageDir { @@ -3715,37 +3734,149 @@ func (e *Executor) moduleAuthorizedKey(ctx context.Context, client sshExecutorCl } if exclusive { - cmd := sprintf("printf '%%s\\n' %q > %q", key, authKeysPath) - stdout, stderr, rc, err := client.Run(ctx, cmd) - if err != nil || rc != 0 { - return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + if err := client.Upload(ctx, newReader(line+"\n"), authKeysPath, 0600); err != nil { + return nil, coreerr.E("Executor.moduleAuthorizedKey", "upload authorised keys", err) } - _, _, _, _ = client.Run(ctx, sprintf("chmod 600 %q && chown %s:%s %q", authKeysPath, user, user, authKeysPath)) return &TaskResult{Changed: true}, nil } - // Add the key if it is not already present. - cmd := sprintf("grep -qF %q %q 2>/dev/null || echo %q >> %q", - key, authKeysPath, key, authKeysPath) - stdout, stderr, rc, err := client.Run(ctx, cmd) - if err != nil || rc != 0 { - return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil + var updated string + if content, ok := remoteFileText(ctx, client, authKeysPath); ok { + updated, _ = rewriteAuthorizedKeyContent(content, base, line) + } else { + updated = line + "\n" + } + if err := client.Upload(ctx, newReader(updated), authKeysPath, 0600); err != nil { + return nil, coreerr.E("Executor.moduleAuthorizedKey", "upload authorised keys", err) } - - // Fix permissions (best-effort) _, _, _, _ = client.Run(ctx, sprintf("chmod 600 %q && chown %s:%s %q", authKeysPath, user, user, authKeysPath)) return &TaskResult{Changed: true}, nil } +func authorizedKeyLine(key, keyOptions, comment string) string { + key = corexTrimSpace(key) + keyOptions = corexTrimSpace(keyOptions) + comment = corexTrimSpace(comment) + + if keyOptions == "" && comment == "" { + return key + } + + base := authorizedKeyBase(key) + if base == "" { + base = key + } + + parts := make([]string, 0, 3) + if keyOptions != "" { + parts = append(parts, keyOptions) + } + if base != "" { + parts = append(parts, base) + } + if comment != "" { + parts = append(parts, comment) + } + return join(" ", parts) +} + +func authorizedKeyBase(line string) string { + line = corexTrimSpace(line) + if line == "" { + return "" + } + + fields := strings.Fields(line) + for i, field := range fields { + if isAuthorizedKeyType(field) { + if i+1 >= len(fields) { + return field + } + return field + " " + fields[i+1] + } + } + + return line +} + +func isAuthorizedKeyType(value string) bool { + return strings.HasPrefix(value, "ssh-") || + strings.HasPrefix(value, "ecdsa-") || + strings.HasPrefix(value, "sk-") +} + +func authorizedKeyContainsBase(content, base string) bool { + if content == "" || base == "" { + return false + } + + for _, line := range strings.Split(content, "\n") { + if authorizedKeyBase(line) == base { + return true + } + } + + return false +} + func sedExactLinePattern(value string) string { pattern := regexp.QuoteMeta(value) return replaceAll(pattern, "|", "\\|") } +func rewriteAuthorizedKeyContent(content, base, line string) (string, bool) { + if base == "" { + base = authorizedKeyBase(line) + } + + lines := strings.Split(content, "\n") + matches := 0 + exactMatches := 0 + for _, current := range lines { + if current == "" { + continue + } + if authorizedKeyBase(current) != base { + continue + } + matches++ + if current == line { + exactMatches++ + } + } + + if line != "" && matches == 1 && exactMatches == 1 { + return content, false + } + if line == "" && matches == 0 { + return content, false + } + + kept := make([]string, 0, len(lines)+1) + for _, current := range lines { + if current == "" { + continue + } + if authorizedKeyBase(current) == base { + continue + } + kept = append(kept, current) + } + + if line != "" { + kept = append(kept, line) + } + if len(kept) == 0 { + return "", true + } + + return join("\n", kept) + "\n", true +} + func (e *Executor) moduleDockerCompose(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { projectSrc := getStringArg(args, "project_src", "") state := getStringArg(args, "state", "present") diff --git a/modules_adv_test.go b/modules_adv_test.go index bb69ebb..fb65847 100644 --- a/modules_adv_test.go +++ b/modules_adv_test.go @@ -614,6 +614,31 @@ func TestModulesAdv_ModuleAuthorizedKey_Good_KeyAlreadyExists(t *testing.T) { assert.Contains(t, result.Msg, "already up to date") } +func TestModulesAdv_ModuleAuthorizedKey_Good_RewritesKeyOptionsAndComment(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + testKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA user@host" + authPath := "/home/deploy/.ssh/authorized_keys" + mock.addFile(authPath, []byte(testKey+"\n")) + mock.expectCommand(`getent passwd deploy`, "/home/deploy", "", 0) + + result, err := e.moduleAuthorizedKey(context.Background(), mock, map[string]any{ + "user": "deploy", + "key": testKey, + "key_options": "command=\"/usr/local/bin/backup-only\"", + "comment": "backup access", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.False(t, result.Failed) + assert.True(t, mock.hasExecuted(`chmod 600`)) + + content, err := mock.Download(context.Background(), authPath) + require.NoError(t, err) + assert.Contains(t, string(content), `command="/usr/local/bin/backup-only" ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA backup access`) + assert.NotContains(t, string(content), testKey) +} + func TestModulesAdv_ModuleAuthorizedKey_Good_ExclusiveRewritesFile(t *testing.T) { e := NewExecutor("/tmp") mock := NewMockSSHClient() @@ -621,7 +646,6 @@ func TestModulesAdv_ModuleAuthorizedKey_Good_ExclusiveRewritesFile(t *testing.T) mock.expectCommand(`getent passwd deploy`, "/home/deploy", "", 0) mock.expectCommand(`mkdir -p`, "", "", 0) - mock.expectCommand(`printf '%s\\n'`, "", "", 0) mock.expectCommand(`chmod 600`, "", "", 0) result, err := e.moduleAuthorizedKey(context.Background(), mock, map[string]any{ @@ -633,7 +657,9 @@ func TestModulesAdv_ModuleAuthorizedKey_Good_ExclusiveRewritesFile(t *testing.T) require.NoError(t, err) assert.True(t, result.Changed) assert.False(t, result.Failed) - assert.True(t, mock.hasExecuted(`printf '%s\\n'`)) + content, err := mock.Download(context.Background(), "/home/deploy/.ssh/authorized_keys") + require.NoError(t, err) + assert.Equal(t, testKey+"\n", string(content)) assert.False(t, mock.hasExecuted(`grep -qF`)) }