feat: Backend Available() and LoadModel() with GPU detection
Replace stub backend with real implementation: Available() checks /dev/kfd and llama-server presence, LoadModel() wires up server lifecycle to return a rocmModel. Add guessModelType() for architecture detection from GGUF filenames (handles hyphenated variants like Llama-3). Add TestAvailable and TestGuessModelType. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a8c494771d
commit
1d8d65f55b
2 changed files with 79 additions and 13 deletions
65
backend.go
65
backend.go
|
|
@ -2,27 +2,66 @@
|
|||
|
||||
package rocm
|
||||
|
||||
import "forge.lthn.ai/core/go-inference"
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
)
|
||||
|
||||
// rocmBackend implements inference.Backend for AMD ROCm GPUs.
// Uses llama-server (llama.cpp built with HIP) as the inference engine.
type rocmBackend struct{}

// Name returns the backend identifier, "rocm".
func (b *rocmBackend) Name() string {
	return "rocm"
}
|
||||
|
||||
// Available reports whether ROCm GPU inference can run on this machine.
|
||||
// Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary.
|
||||
func (b *rocmBackend) Available() bool {
|
||||
// TODO: Check for ROCm runtime + GPU presence
|
||||
// - /dev/kfd exists (ROCm kernel driver)
|
||||
// - rocm-smi detects a GPU
|
||||
// - llama-server binary is findable
|
||||
return false // Stub until Phase 1 implementation
|
||||
if _, err := os.Stat("/dev/kfd"); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := findLlamaServer(); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// LoadModel loads a GGUF model onto the AMD GPU via llama-server.
|
||||
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
|
||||
// TODO: Phase 1 implementation
|
||||
// 1. Find llama-server binary (PATH or configured location)
|
||||
// 2. Spawn llama-server with --model path --port <free> --n-gpu-layers cfg.GPULayers
|
||||
// 3. Wait for health endpoint to respond
|
||||
// 4. Return rocmModel wrapping the HTTP client
|
||||
return nil, nil
|
||||
cfg := inference.ApplyLoadOpts(opts)
|
||||
|
||||
binary, err := findLlamaServer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
port, err := freePort()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rocm: find free port: %w", err)
|
||||
}
|
||||
|
||||
srv, err := startServer(binary, path, port, cfg.GPULayers, cfg.ContextLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rocmModel{
|
||||
srv: srv,
|
||||
modelType: guessModelType(path),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// guessModelType extracts a model type hint from the filename.
// Hyphens are stripped so that "Llama-3" matches "llama3".
func guessModelType(path string) string {
	// Ordered most-specific first so "gemma3" wins over "gemma", etc.
	known := []string{
		"gemma3", "gemma2", "gemma",
		"qwen3", "qwen2", "qwen",
		"llama3", "llama",
		"mistral", "phi",
	}
	base := strings.ReplaceAll(strings.ToLower(filepath.Base(path)), "-", "")
	for _, candidate := range known {
		if strings.Contains(base, candidate) {
			return candidate
		}
	}
	return "unknown"
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
package rocm
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
|
@ -68,3 +69,29 @@ func TestServerEnv_FiltersExistingHIP(t *testing.T) {
|
|||
}
|
||||
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
|
||||
}
|
||||
|
||||
func TestAvailable(t *testing.T) {
|
||||
b := &rocmBackend{}
|
||||
if _, err := os.Stat("/dev/kfd"); err != nil {
|
||||
t.Skip("no ROCm hardware")
|
||||
}
|
||||
assert.True(t, b.Available())
|
||||
}
|
||||
|
||||
func TestGuessModelType(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{"/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf", "gemma3"},
|
||||
{"/data/models/Qwen3-8B-Q4_K_M.gguf", "qwen3"},
|
||||
{"/data/models/Llama-3.1-8B-Q4_K_M.gguf", "llama3"},
|
||||
{"/data/models/Mistral-7B-v0.3-Q4_K_M.gguf", "mistral"},
|
||||
{"/data/models/random-model.gguf", "unknown"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, guessModelType(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue