feat: Backend Available() and LoadModel() with GPU detection

Replace stub backend with real implementation: Available() checks
/dev/kfd and llama-server presence, LoadModel() wires up server
lifecycle to return a rocmModel. Add guessModelType() for architecture
detection from GGUF filenames (handles hyphenated variants like
Llama-3). Add TestAvailable and TestGuessModelType.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-19 21:12:02 +00:00
parent a8c494771d
commit 1d8d65f55b
No known key found for this signature in database
GPG key ID: AF404715446AEB41
2 changed files with 79 additions and 13 deletions

View file

@ -2,27 +2,66 @@
package rocm

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"forge.lthn.ai/core/go-inference"
)
// rocmBackend implements inference.Backend for AMD ROCm GPUs.
// Uses llama-server (llama.cpp built with HIP) as the inference engine.
// The zero value is ready to use; the type itself carries no state.
type rocmBackend struct{}
func (b *rocmBackend) Name() string { return "rocm" }
// Available reports whether ROCm GPU inference can run on this machine.
// Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary.
// Available reports whether ROCm GPU inference can run on this machine.
// It requires both the ROCm kernel driver (/dev/kfd) and a findable
// llama-server binary.
func (b *rocmBackend) Available() bool {
	// /dev/kfd is created by the ROCm kernel driver; without it there is
	// no GPU compute access.
	if _, err := os.Stat("/dev/kfd"); err != nil {
		return false
	}
	// llama-server (llama.cpp built with HIP) is the actual inference engine.
	if _, err := findLlamaServer(); err != nil {
		return false
	}
	return true
}
// LoadModel loads a GGUF model onto the AMD GPU via llama-server.
//
// It locates the llama-server binary, picks a free local port, spawns the
// server with cfg.GPULayers layers offloaded and a cfg.ContextLen context
// window, and returns a rocmModel wrapping the running server. The model
// type hint is guessed from the filename (see guessModelType).
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
	cfg := inference.ApplyLoadOpts(opts)
	binary, err := findLlamaServer()
	if err != nil {
		return nil, err
	}
	port, err := freePort()
	if err != nil {
		return nil, fmt.Errorf("rocm: find free port: %w", err)
	}
	srv, err := startServer(binary, path, port, cfg.GPULayers, cfg.ContextLen)
	if err != nil {
		return nil, err
	}
	return &rocmModel{
		srv:       srv,
		modelType: guessModelType(path),
	}, nil
}
// guessModelType extracts a model type hint from the filename.
// Hyphens are stripped so that "Llama-3" matches "llama3".
func guessModelType(path string) string {
	// Normalize: basename only, lowercase, hyphens removed.
	base := strings.ReplaceAll(strings.ToLower(filepath.Base(path)), "-", "")
	// Versioned names (gemma3, llama3, ...) come before their bare prefixes
	// so the most specific architecture wins.
	known := []string{
		"gemma3", "gemma2", "gemma",
		"qwen3", "qwen2", "qwen",
		"llama3", "llama",
		"mistral", "phi",
	}
	for _, candidate := range known {
		if strings.Contains(base, candidate) {
			return candidate
		}
	}
	return "unknown"
}

View file

@ -3,6 +3,7 @@
package rocm
import (
"os"
"strings"
"testing"
@ -68,3 +69,29 @@ func TestServerEnv_FiltersExistingHIP(t *testing.T) {
}
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}
// TestAvailable checks that Available reports true on a machine with ROCm
// hardware. Skipped when /dev/kfd is absent (no ROCm kernel driver).
func TestAvailable(t *testing.T) {
	if _, err := os.Stat("/dev/kfd"); err != nil {
		t.Skip("no ROCm hardware")
	}
	backend := &rocmBackend{}
	assert.True(t, backend.Available())
}
// TestGuessModelType exercises architecture detection from GGUF filenames,
// including hyphenated variants like "Llama-3.1".
func TestGuessModelType(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		{"/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf", "gemma3"},
		{"/data/models/Qwen3-8B-Q4_K_M.gguf", "qwen3"},
		{"/data/models/Llama-3.1-8B-Q4_K_M.gguf", "llama3"},
		{"/data/models/Mistral-7B-v0.3-Q4_K_M.gguf", "mistral"},
		{"/data/models/random-model.gguf", "unknown"},
	}
	for _, tc := range cases {
		t.Run(tc.want, func(t *testing.T) {
			assert.Equal(t, tc.want, guessModelType(tc.in))
		})
	}
}