diff --git a/backend.go b/backend.go
index 088b30e..bfeb35d 100644
--- a/backend.go
+++ b/backend.go
@@ -2,27 +2,66 @@
 package rocm
 
-import "forge.lthn.ai/core/go-inference"
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"forge.lthn.ai/core/go-inference"
+)
 
 // rocmBackend implements inference.Backend for AMD ROCm GPUs.
-// Uses llama-server (llama.cpp built with HIP) as the inference engine.
 type rocmBackend struct{}
 
 func (b *rocmBackend) Name() string { return "rocm" }
 
+// Available reports whether ROCm GPU inference can run on this machine.
+// Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary.
 func (b *rocmBackend) Available() bool {
-	// TODO: Check for ROCm runtime + GPU presence
-	// - /dev/kfd exists (ROCm kernel driver)
-	// - rocm-smi detects a GPU
-	// - llama-server binary is findable
-	return false // Stub until Phase 1 implementation
+	if _, err := os.Stat("/dev/kfd"); err != nil {
+		return false
+	}
+	if _, err := findLlamaServer(); err != nil {
+		return false
+	}
+	return true
 }
 
+// LoadModel loads a GGUF model onto the AMD GPU via llama-server.
 func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
-	// TODO: Phase 1 implementation
-	// 1. Find llama-server binary (PATH or configured location)
-	// 2. Spawn llama-server with --model path --port --n-gpu-layers cfg.GPULayers
-	// 3. Wait for health endpoint to respond
-	// 4. Return rocmModel wrapping the HTTP client
-	return nil, nil
+	cfg := inference.ApplyLoadOpts(opts)
+
+	binary, err := findLlamaServer()
+	if err != nil {
+		return nil, err
+	}
+
+	port, err := freePort()
+	if err != nil {
+		return nil, fmt.Errorf("rocm: find free port: %w", err)
+	}
+
+	srv, err := startServer(binary, path, port, cfg.GPULayers, cfg.ContextLen)
+	if err != nil {
+		return nil, err
+	}
+
+	return &rocmModel{
+		srv:       srv,
+		modelType: guessModelType(path),
+	}, nil
+}
+
+// guessModelType extracts a model type hint from the filename.
+// Hyphens are stripped so that "Llama-3" matches "llama3".
+func guessModelType(path string) string {
+	name := strings.ToLower(filepath.Base(path))
+	name = strings.ReplaceAll(name, "-", "")
+	for _, arch := range []string{"gemma3", "gemma2", "gemma", "qwen3", "qwen2", "qwen", "llama3", "llama", "mistral", "phi"} {
+		if strings.Contains(name, arch) {
+			return arch
+		}
+	}
+	return "unknown"
+}
 
diff --git a/server_test.go b/server_test.go
index 84dcc83..0976322 100644
--- a/server_test.go
+++ b/server_test.go
@@ -3,6 +3,7 @@
 package rocm
 
 import (
+	"os"
 	"strings"
 	"testing"
 
@@ -68,3 +69,29 @@ func TestServerEnv_FiltersExistingHIP(t *testing.T) {
 	}
 	assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
 }
+
+func TestAvailable(t *testing.T) {
+	b := &rocmBackend{}
+	if _, err := os.Stat("/dev/kfd"); err != nil {
+		t.Skip("no ROCm hardware")
+	}
+	assert.True(t, b.Available())
+}
+
+func TestGuessModelType(t *testing.T) {
+	tests := []struct {
+		path     string
+		expected string
+	}{
+		{"/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf", "gemma3"},
+		{"/data/models/Qwen3-8B-Q4_K_M.gguf", "qwen3"},
+		{"/data/models/Llama-3.1-8B-Q4_K_M.gguf", "llama3"},
+		{"/data/models/Mistral-7B-v0.3-Q4_K_M.gguf", "mistral"},
+		{"/data/models/random-model.gguf", "unknown"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.expected, func(t *testing.T) {
+			assert.Equal(t, tt.expected, guessModelType(tt.path))
+		})
+	}
+}