diff --git a/modules.go b/modules.go index e4cc26c..59b669d 100644 --- a/modules.go +++ b/modules.go @@ -4,7 +4,10 @@ import ( "bufio" "bytes" "context" + "crypto/sha1" + "crypto/sha256" "encoding/base64" + "encoding/hex" "encoding/json" "fmt" "io/fs" @@ -1397,21 +1400,70 @@ func (e *Executor) moduleGetURL(ctx context.Context, client sshExecutorClient, a return nil, coreerr.E("Executor.moduleGetURL", "url and dest required", nil) } - // Use curl or wget - cmd := sprintf("curl -fsSL -o %q %q || wget -q -O %q %q", dest, url, dest, url) + checksumSpec := corexTrimSpace(getStringArg(args, "checksum", "")) + + // Stream to stdout so we can validate checksums before writing the file. + cmd := sprintf("curl -fsSL %q || wget -q -O - %q", url, url) stdout, stderr, rc, err := client.Run(ctx, cmd) if err != nil || rc != 0 { return &TaskResult{Failed: true, Msg: stderr, Stdout: stdout, RC: rc}, nil } - // Set mode if specified (best-effort) - if mode := getStringArg(args, "mode", ""); mode != "" { - _, _, _, _ = client.Run(ctx, sprintf("chmod %s %q", mode, dest)) + content := []byte(stdout) + if checksumSpec != "" { + if err := verifyGetURLChecksum(content, checksumSpec); err != nil { + return &TaskResult{Failed: true, Msg: err.Error(), Stdout: stdout, Stderr: stderr, RC: rc}, nil + } + } + + mode := fs.FileMode(0644) + // Set mode if specified (best-effort). + if modeArg := getStringArg(args, "mode", ""); modeArg != "" { + if parsed, err := strconv.ParseInt(modeArg, 8, 32); err == nil { + mode = fs.FileMode(parsed) + } + } + + if err := client.Upload(ctx, bytes.NewReader(content), dest, mode); err != nil { + return nil, err } return &TaskResult{Changed: true}, nil } +func verifyGetURLChecksum(content []byte, checksumSpec string) error { + parts := strings.SplitN(checksumSpec, ":", 2) + algorithm := "sha256" + expected := checksumSpec + if len(parts) == 2 { + algorithm = lower(corexTrimSpace(parts[0])) + expected = corexTrimSpace(parts[1]) + } + + expected = strings.ToLower(corexTrimSpace(expected)) + if expected == "" { + return coreerr.E("Executor.moduleGetURL", "checksum required", nil) + } + + var actual string + switch algorithm { + case "", "sha256": + sum := sha256.Sum256(content) + actual = hex.EncodeToString(sum[:]) + case "sha1": + sum := sha1.Sum(content) + actual = hex.EncodeToString(sum[:]) + default: + return coreerr.E("Executor.moduleGetURL", "unsupported checksum algorithm: "+algorithm, nil) + } + + if actual != expected { + return coreerr.E("Executor.moduleGetURL", sprintf("checksum mismatch: expected %s but got %s", expected, actual), nil) + } + + return nil +} + // --- Package Modules --- func (e *Executor) moduleApt(ctx context.Context, client sshExecutorClient, args map[string]any) (*TaskResult, error) { diff --git a/modules_file_test.go b/modules_file_test.go index 6977af4..48dcfce 100644 --- a/modules_file_test.go +++ b/modules_file_test.go @@ -2,6 +2,8 @@ package ansible import ( "context" + "crypto/sha256" + "encoding/hex" "io" "io/fs" "regexp" @@ -1600,3 +1602,43 @@ func TestModulesFile_ModuleFile_Good_TemplatedPath(t *testing.T) { assert.True(t, mock.hasExecuted(`mkdir -p "/var/www/html/uploads"`)) assert.True(t, mock.hasExecuted(`chown www-data "/var/www/html/uploads"`)) } + +func TestModulesFile_ModuleGetURL_Good_Checksum(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + payload := "downloaded artifact" + mock.expectCommand(`curl.*https://downloads\.example\.com/app\.tgz`, payload, "", 0) + + sum := sha256.Sum256([]byte(payload)) + result, err := e.moduleGetURL(context.Background(), mock, map[string]any{ + "url": "https://downloads.example.com/app.tgz", + "dest": "/tmp/app.tgz", + "checksum": "sha256:" + hex.EncodeToString(sum[:]), + "mode": "0600", + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, 1, mock.uploadCount()) + + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, "/tmp/app.tgz", up.Remote) + assert.Equal(t, []byte(payload), up.Content) + assert.Equal(t, fs.FileMode(0600), up.Mode) +} + +func TestModulesFile_ModuleGetURL_Bad_ChecksumMismatch(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + mock.expectCommand(`curl.*https://downloads\.example\.com/app\.tgz`, "downloaded artifact", "", 0) + + result, err := e.moduleGetURL(context.Background(), mock, map[string]any{ + "url": "https://downloads.example.com/app.tgz", + "dest": "/tmp/app.tgz", + "checksum": "sha256:deadbeef", + }) + + require.NoError(t, err) + assert.True(t, result.Failed) + assert.Contains(t, result.Msg, "checksum mismatch") + assert.Equal(t, 0, mock.uploadCount()) +}