diff --git a/rocm.go b/rocm.go
index f63f176..7bdbf66 100644
--- a/rocm.go
+++ b/rocm.go
@@ -23,3 +23,13 @@
 // - ROCm 6.x+ installed
 // - llama-server binary (from llama.cpp built with -DGGML_HIP=ON)
 package rocm
+
+// VRAMInfo reports GPU video memory usage in bytes.
+//
+// As filled in by GetVRAMInfo, Free is Total minus Used, clamped to zero
+// when Used is reported as larger than Total.
+type VRAMInfo struct {
+	Total uint64
+	Used  uint64
+	Free  uint64
+}
diff --git a/rocm_stub.go b/rocm_stub.go
index d7fd23d..34475e5 100644
--- a/rocm_stub.go
+++ b/rocm_stub.go
@@ -2,6 +2,14 @@
 
 package rocm
 
+import "errors"
+
 // ROCmAvailable reports whether ROCm GPU inference is available.
 // Returns false on non-Linux or non-amd64 platforms.
 func ROCmAvailable() bool { return false }
+
+// GetVRAMInfo always returns an error on non-Linux/non-amd64 platforms,
+// where VRAM monitoring is not available.
+func GetVRAMInfo() (VRAMInfo, error) {
+	return VRAMInfo{}, errors.New("rocm: VRAM monitoring not available on this platform")
+}
diff --git a/vram.go b/vram.go
new file mode 100644
index 0000000..26f1385
--- /dev/null
+++ b/vram.go
@@ -0,0 +1,95 @@
+//go:build linux && amd64
+
+package rocm
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+)
+
+// drmSysfsRoot is the directory holding one entry per DRM card.
+const drmSysfsRoot = "/sys/class/drm"
+
+// GetVRAMInfo reads VRAM usage for the discrete GPU from sysfs.
+//
+// It identifies the dGPU by selecting the card with the largest VRAM total,
+// which avoids hardcoding card numbers (e.g. card0=iGPU, card1=dGPU on Ryzen).
+// An error is returned when no card reports a readable, non-zero VRAM total
+// or when the selected card's usage counter cannot be read.
+func GetVRAMInfo() (VRAMInfo, error) {
+	return vramInfoFromSysfs(drmSysfsRoot)
+}
+
+// vramInfoFromSysfs implements GetVRAMInfo against an arbitrary DRM sysfs
+// root, so the card-selection logic can be tested with fixture directories.
+func vramInfoFromSysfs(root string) (VRAMInfo, error) {
+	pattern := filepath.Join(root, "card[0-9]*", "device", "mem_info_vram_total")
+	totalPaths, err := filepath.Glob(pattern)
+	if err != nil {
+		return VRAMInfo{}, fmt.Errorf("rocm: glob vram sysfs: %w", err)
+	}
+	if len(totalPaths) == 0 {
+		return VRAMInfo{}, errors.New("rocm: no GPU VRAM info found in sysfs")
+	}
+
+	var (
+		bestDir   string
+		bestTotal uint64
+		firstErr  error
+	)
+	for _, totalPath := range totalPaths {
+		total, err := readSysfsUint64(totalPath)
+		if err != nil {
+			// One unreadable card must not hide the others; keep the first
+			// failure so it can be reported if no card turns out to be usable.
+			if firstErr == nil {
+				firstErr = err
+			}
+			continue
+		}
+		if total > bestTotal {
+			bestTotal = total
+			bestDir = filepath.Dir(totalPath)
+		}
+	}
+
+	if bestDir == "" {
+		if firstErr != nil {
+			return VRAMInfo{}, fmt.Errorf("rocm: no readable VRAM sysfs entries: %w", firstErr)
+		}
+		return VRAMInfo{}, errors.New("rocm: no readable VRAM sysfs entries")
+	}
+
+	used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used"))
+	if err != nil {
+		return VRAMInfo{}, fmt.Errorf("rocm: read vram used: %w", err)
+	}
+
+	// Total and used are read from separate files, so nothing guarantees
+	// used <= total; clamp rather than let the unsigned subtraction wrap.
+	var free uint64
+	if used < bestTotal {
+		free = bestTotal - used
+	}
+
+	return VRAMInfo{Total: bestTotal, Used: used, Free: free}, nil
+}
+
+// readSysfsUint64 reads a sysfs attribute file that holds a single unsigned
+// decimal integer, ignoring surrounding whitespace such as a trailing newline.
+func readSysfsUint64(path string) (uint64, error) {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return 0, err
+	}
+	value, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
+	if err != nil {
+		// strconv errors do not mention the file, unlike os.ReadFile's.
+		return 0, fmt.Errorf("parse %s: %w", path, err)
+	}
+	return value, nil
+}
diff --git a/vram_test.go b/vram_test.go
new file mode 100644
index 0000000..012399d
--- /dev/null
+++ b/vram_test.go
@@ -0,0 +1,95 @@
+//go:build linux && amd64
+
+package rocm
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// writeFakeCard creates <root>/<card>/device holding the two VRAM counters.
+func writeFakeCard(t *testing.T, root, card, total, used string) {
+	t.Helper()
+	dev := filepath.Join(root, card, "device")
+	require.NoError(t, os.MkdirAll(dev, 0755))
+	require.NoError(t, os.WriteFile(filepath.Join(dev, "mem_info_vram_total"), []byte(total), 0644))
+	require.NoError(t, os.WriteFile(filepath.Join(dev, "mem_info_vram_used"), []byte(used), 0644))
+}
+
+func TestReadSysfsUint64(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "test_value")
+	require.NoError(t, os.WriteFile(path, []byte("17163091968\n"), 0644))
+
+	val, err := readSysfsUint64(path)
+	require.NoError(t, err)
+	assert.Equal(t, uint64(17163091968), val)
+}
+
+func TestReadSysfsUint64_NotFound(t *testing.T) {
+	_, err := readSysfsUint64("/nonexistent/path")
+	assert.Error(t, err)
+}
+
+func TestReadSysfsUint64_Malformed(t *testing.T) {
+	path := filepath.Join(t.TempDir(), "test_value")
+	require.NoError(t, os.WriteFile(path, []byte("not-a-number\n"), 0644))
+
+	_, err := readSysfsUint64(path)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), path)
+}
+
+func TestVRAMInfoFromSysfs_PicksLargestCard(t *testing.T) {
+	root := t.TempDir()
+	writeFakeCard(t, root, "card0", "536870912\n", "1048576\n")
+	writeFakeCard(t, root, "card1", "17163091968\n", "2147483648\n")
+
+	info, err := vramInfoFromSysfs(root)
+	require.NoError(t, err)
+	assert.Equal(t, VRAMInfo{Total: 17163091968, Used: 2147483648, Free: 15015608320}, info)
+}
+
+func TestVRAMInfoFromSysfs_ClampsFree(t *testing.T) {
+	root := t.TempDir()
+	writeFakeCard(t, root, "card0", "1024\n", "4096\n")
+
+	info, err := vramInfoFromSysfs(root)
+	require.NoError(t, err)
+	assert.Equal(t, VRAMInfo{Total: 1024, Used: 4096, Free: 0}, info)
+}
+
+func TestVRAMInfoFromSysfs_SkipsUnreadableCard(t *testing.T) {
+	root := t.TempDir()
+	writeFakeCard(t, root, "card0", "garbage\n", "0\n")
+	writeFakeCard(t, root, "card1", "2048\n", "512\n")
+
+	info, err := vramInfoFromSysfs(root)
+	require.NoError(t, err)
+	assert.Equal(t, VRAMInfo{Total: 2048, Used: 512, Free: 1536}, info)
+}
+
+func TestVRAMInfoFromSysfs_NoCards(t *testing.T) {
+	_, err := vramInfoFromSysfs(t.TempDir())
+	assert.Error(t, err)
+}
+
+func TestGetVRAMInfo(t *testing.T) {
+	info, err := GetVRAMInfo()
+	if err != nil {
+		t.Skipf("no VRAM sysfs info available: %v", err)
+	}
+
+	// Assert only hardware-independent invariants so the test passes on any
+	// machine that has a GPU, not just the one it was written on.
+	assert.NotZero(t, info.Total, "selected card should report a VRAM total")
+	if info.Used <= info.Total {
+		assert.Equal(t, info.Total-info.Used, info.Free, "Free should equal Total-Used")
+	} else {
+		assert.Zero(t, info.Free, "Free should clamp to zero when Used exceeds Total")
+	}
+}