diff --git a/mock_ssh_test.go b/mock_ssh_test.go index 5c358d4..a3dde88 100644 --- a/mock_ssh_test.go +++ b/mock_ssh_test.go @@ -8,6 +8,7 @@ import ( "path" "regexp" "strconv" + "strings" "sync" "time" @@ -521,6 +522,7 @@ func executeModuleWithMock(e *Executor, mock *MockSSHClient, host string, task * type sshRunner interface { Run(ctx context.Context, cmd string) (string, string, int, error) RunScript(ctx context.Context, script string) (string, string, int, error) + Upload(ctx context.Context, local io.Reader, remote string, mode fs.FileMode) error } type sshFileTransferRunner interface { @@ -1618,6 +1620,7 @@ func moduleCronWithClient(_ *Executor, client sshRunner, args map[string]any) (* user := getStringArg(args, "user", "root") disabled := getBoolArg(args, "disabled", false) specialTime := getStringArg(args, "special_time", "") + backup := getBoolArg(args, "backup", false) minute := getStringArg(args, "minute", "*") hour := getStringArg(args, "hour", "*") @@ -1625,6 +1628,28 @@ func moduleCronWithClient(_ *Executor, client sshRunner, args map[string]any) (* month := getStringArg(args, "month", "*") weekday := getStringArg(args, "weekday", "*") + var backupPath string + if backup { + stdout, _, rc, err := client.Run(context.Background(), sprintf("crontab -u %s -l 2>/dev/null", user)) + if err != nil { + return nil, err + } + if rc == 0 && strings.TrimSpace(stdout) != "" { + backupName := user + if backupName == "" { + backupName = "root" + } + if name != "" { + backupName += "-" + name + } + backupName = sanitizeBackupToken(backupName) + backupPath = path.Join("/tmp", sprintf("ansible-cron-%s.%s.bak", backupName, time.Now().UTC().Format("20060102T150405Z"))) + if err := client.Upload(context.Background(), bytes.NewReader([]byte(stdout)), backupPath, 0600); err != nil { + return nil, err + } + } + } + if state == "absent" { if name != "" { // Remove by name (comment marker) @@ -1632,7 +1657,11 @@ func moduleCronWithClient(_ *Executor, client sshRunner, args map[string]any) (* user, name, job, user) _, _, _, _ = client.Run(context.Background(), cmd) } - return &TaskResult{Changed: true}, nil + result := &TaskResult{Changed: true} + if backupPath != "" { + result.Data = map[string]any{"backup_file": backupPath} + } + return result, nil } // Build cron entry @@ -1653,7 +1682,11 @@ func moduleCronWithClient(_ *Executor, client sshRunner, args map[string]any) (* return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil } - return &TaskResult{Changed: true}, nil + result := &TaskResult{Changed: true} + if backupPath != "" { + result.Data = map[string]any{"backup_file": backupPath} + } + return result, nil } // --- Authorized key module shim --- diff --git a/modules.go b/modules.go index 677db61..212e1d4 100644 --- a/modules.go +++ b/modules.go @@ -303,6 +303,63 @@ func backupRemoteFile(ctx context.Context, client sshExecutorClient, path string return backupPath, true, nil } +func backupCronTab(ctx context.Context, client sshExecutorClient, user, name string) (string, error) { + stdout, _, rc, err := client.Run(ctx, sprintf("crontab -u %s -l 2>/dev/null", user)) + if err != nil { + return "", coreerr.E("Executor.moduleCron", "backup crontab", err) + } + if rc != 0 || strings.TrimSpace(stdout) == "" { + return "", nil + } + + backupName := user + if backupName == "" { + backupName = "root" + } + if name != "" { + backupName += "-" + name + } + backupName = sanitizeBackupToken(backupName) + + backupPath := path.Join("/tmp", sprintf("ansible-cron-%s.%s.bak", backupName, time.Now().UTC().Format("20060102T150405Z"))) + if err := client.Upload(ctx, bytes.NewReader([]byte(stdout)), backupPath, 0600); err != nil { + return "", coreerr.E("Executor.moduleCron", "backup crontab", err) + } + + return backupPath, nil +} + +func sanitizeBackupToken(value string) string { + if value == "" { + return "default" + } + + var b strings.Builder + b.Grow(len(value)) + lastDash := false + for _, r := range value { + switch { + case r >= 'a' && r <= 'z', + r >= 'A' && r <= 'Z', + r >= '0' && r <= '9', + r == '.', r == '_', r == '-': + b.WriteRune(r) + lastDash = false + default: + if !lastDash { + b.WriteByte('-') + lastDash = true + } + } + } + + token := strings.Trim(b.String(), "-") + if token == "" { + return "default" + } + return token +} + // templateArgs templates all string values in args. func (e *Executor) templateArgs(args map[string]any, host string, task *Task) map[string]any { // Set inventory_hostname for templating @@ -2843,6 +2900,7 @@ func (e *Executor) moduleCron(ctx context.Context, client sshExecutorClient, arg user := getStringArg(args, "user", "root") disabled := getBoolArg(args, "disabled", false) specialTime := getStringArg(args, "special_time", "") + backup := getBoolArg(args, "backup", false) minute := getStringArg(args, "minute", "*") hour := getStringArg(args, "hour", "*") @@ -2850,6 +2908,15 @@ func (e *Executor) moduleCron(ctx context.Context, client sshExecutorClient, arg month := getStringArg(args, "month", "*") weekday := getStringArg(args, "weekday", "*") + var backupPath string + if backup { + var err error + backupPath, err = backupCronTab(ctx, client, user, name) + if err != nil { + return nil, err + } + } + if state == "absent" { if name != "" { // Remove by name (comment marker) @@ -2857,7 +2924,11 @@ func (e *Executor) moduleCron(ctx context.Context, client sshExecutorClient, arg user, name, job, user) _, _, _, _ = client.Run(ctx, cmd) } - return &TaskResult{Changed: true}, nil + result := &TaskResult{Changed: true} + if backupPath != "" { + result.Data = map[string]any{"backup_file": backupPath} + } + return result, nil } // Build cron entry @@ -2878,7 +2949,11 @@ func (e *Executor) moduleCron(ctx context.Context, client sshExecutorClient, arg return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil } - return &TaskResult{Changed: true}, nil + result := &TaskResult{Changed: true} + if backupPath != "" { + result.Data = map[string]any{"backup_file": backupPath} + } + return result, nil } func (e *Executor) moduleBlockinfile(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { diff --git a/modules_adv_test.go b/modules_adv_test.go index fb65847..b70350c 100644 --- a/modules_adv_test.go +++ b/modules_adv_test.go @@ -504,6 +504,35 @@ func TestModulesAdv_ModuleCron_Good_SpecialTime(t *testing.T) { assert.True(t, mock.containsSubstring(`@daily /usr/local/bin/backup.sh # daily-backup`)) } +func TestModulesAdv_ModuleCron_Good_BackupCreatesBackupFile(t *testing.T) { + e := NewExecutor("/tmp") + mock := NewMockSSHClient() + mock.expectCommand(`crontab -u root`, "", "", 0) + mock.expectCommand(`crontab -u root -l`, "0 0 * * * /usr/local/bin/backup.sh # daily-backup\n", "", 0) + + result, err := e.moduleCron(context.Background(), mock, map[string]any{ + "name": "daily-backup", + "job": "/usr/local/bin/backup.sh", + "minute": "0", + "hour": "1", + "backup": true, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + require.NotNil(t, result.Data) + backupPath, ok := result.Data["backup_file"].(string) + require.True(t, ok) + assert.Contains(t, backupPath, "/tmp/ansible-cron-root-daily-backup.") + assert.Equal(t, 1, mock.uploadCount()) + lastUpload := mock.lastUpload() + require.NotNil(t, lastUpload) + assert.Equal(t, backupPath, lastUpload.Remote) + assert.Equal(t, []byte("0 0 * * * /usr/local/bin/backup.sh # daily-backup\n"), lastUpload.Content) + assert.True(t, mock.containsSubstring("crontab -u root -l")) + assert.True(t, mock.containsSubstring("crontab -u root")) +} + func TestModulesAdv_ModuleCron_Good_DisabledJobCommentsEntry(t *testing.T) { e, mock := newTestExecutorWithMock("host1") mock.expectCommand(`crontab -u root`, "", "", 0)