diff --git a/modules.go b/modules.go index 31acfb6..f36b489 100644 --- a/modules.go +++ b/modules.go @@ -1412,7 +1412,11 @@ func (e *Executor) moduleGetURL(ctx context.Context, client sshExecutorClient, a content := []byte(stdout) if checksumSpec != "" { - if err := verifyGetURLChecksum(content, checksumSpec); err != nil { + checksumValue, err := resolveGetURLChecksumValue(ctx, client, checksumSpec, dest) + if err != nil { + return &TaskResult{Failed: true, Msg: err.Error(), Stdout: stdout, Stderr: stderr, RC: rc}, nil + } + if err := verifyGetURLChecksum(content, checksumValue); err != nil { return &TaskResult{Failed: true, Msg: err.Error(), Stdout: stdout, Stderr: stderr, RC: rc}, nil } } @@ -1432,10 +1436,116 @@ func (e *Executor) moduleGetURL(ctx context.Context, client sshExecutorClient, a return &TaskResult{Changed: true}, nil } -func verifyGetURLChecksum(content []byte, checksumSpec string) error { - parts := strings.SplitN(checksumSpec, ":", 2) +func resolveGetURLChecksumValue(ctx context.Context, client sshExecutorClient, checksumSpec, dest string) (string, error) { algorithm := "sha256" expected := checksumSpec + if idx := strings.Index(checksumSpec, ":"); idx > 0 { + candidateAlgorithm := lower(corexTrimSpace(checksumSpec[:idx])) + if isChecksumAlgorithm(candidateAlgorithm) { + algorithm = candidateAlgorithm + expected = corexTrimSpace(checksumSpec[idx+1:]) + } + } + + expected = corexTrimSpace(expected) + if expected == "" { + return "", coreerr.E("Executor.moduleGetURL", "checksum required", nil) + } + + if strings.Contains(expected, "://") { + cmd := sprintf("curl -fsSL %q || wget -q -O - %q", expected, expected) + stdout, stderr, rc, err := client.Run(ctx, cmd) + if err != nil || rc != 0 { + msg := stderr + if msg == "" && err != nil { + msg = err.Error() + } + return "", coreerr.E("Executor.moduleGetURL", "download checksum file: "+msg, err) + } + expected, err = parseGetURLChecksumFile(stdout, dest, algorithm) + if err != nil { + return "", err + } + } + + return sprintf("%s:%s", algorithm, strings.ToLower(expected)), nil +} + +func isChecksumAlgorithm(value string) bool { + switch value { + case "", "sha1", "sha224", "sha256", "sha384", "sha512": + return true + default: + return false + } +} + +func parseGetURLChecksumFile(content, dest, algorithm string) (string, error) { + lines := strings.Split(content, "\n") + base := path.Base(dest) + + for _, line := range lines { + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + + candidate := strings.ToLower(fields[0]) + if !isHexDigest(candidate) { + continue + } + + if len(fields) == 1 { + return candidate, nil + } + + for _, field := range fields[1:] { + cleaned := strings.TrimPrefix(field, "*") + cleaned = path.Base(strings.TrimSpace(cleaned)) + if cleaned == base || cleaned == path.Base(dest) { + return candidate, nil + } + } + } + + for _, line := range lines { + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + candidate := strings.ToLower(fields[0]) + if isHexDigest(candidate) { + return candidate, nil + } + } + + return "", coreerr.E("Executor.moduleGetURL", sprintf("could not parse checksum file for %s", algorithm), nil) +} + +func isHexDigest(value string) bool { + if value == "" { + return false + } + for _, r := range value { + switch { + case r >= '0' && r <= '9': + case r >= 'a' && r <= 'f': + default: + return false + } + } + return true +} + +func verifyGetURLChecksum(content []byte, checksumValue string) error { + checksumValue = strings.ToLower(corexTrimSpace(checksumValue)) + if checksumValue == "" { + return coreerr.E("Executor.moduleGetURL", "checksum required", nil) + } + + parts := strings.SplitN(checksumValue, ":", 2) + algorithm := "sha256" + expected := checksumValue if len(parts) == 2 { algorithm = lower(corexTrimSpace(parts[0])) expected = corexTrimSpace(parts[1]) diff --git a/modules_file_test.go b/modules_file_test.go index 93bacfc..741e652 100644 --- a/modules_file_test.go +++ b/modules_file_test.go @@ -1650,6 +1650,31 @@ func TestModulesFile_ModuleGetURL_Good_Sha512Checksum(t *testing.T) { assert.Equal(t, []byte(payload), up.Content) } +func TestModulesFile_ModuleGetURL_Good_ChecksumFileURL(t *testing.T) { + e, mock := newTestExecutorWithMock("host1") + payload := "downloaded artifact" + sum := sha256.Sum256([]byte(payload)) + checksumURL := "https://downloads.example.com/app.tgz.sha256" + + mock.expectCommand(`curl.*https://downloads\.example\.com/app\.tgz\.sha256(?:["\s]|$)`, hex.EncodeToString(sum[:])+" app.tgz\n", "", 0) + mock.expectCommand(`curl.*https://downloads\.example\.com/app\.tgz(?:["\s]|$)`, payload, "", 0) + + result, err := e.moduleGetURL(context.Background(), mock, map[string]any{ + "url": "https://downloads.example.com/app.tgz", + "dest": "/tmp/app.tgz", + "checksum": "sha256:" + checksumURL, + }) + + require.NoError(t, err) + assert.True(t, result.Changed) + assert.Equal(t, 1, mock.uploadCount()) + + up := mock.lastUpload() + require.NotNil(t, up) + assert.Equal(t, "/tmp/app.tgz", up.Remote) + assert.Equal(t, []byte(payload), up.Content) +} + func TestModulesFile_ModuleGetURL_Bad_ChecksumMismatch(t *testing.T) { e, mock := newTestExecutorWithMock("host1") mock.expectCommand(`curl.*https://downloads\.example\.com/app\.tgz`, "downloaded artifact", "", 0)