go-rocm/backend.go
Claude 1d8d65f55b
feat: Backend Available() and LoadModel() with GPU detection
Replace stub backend with real implementation: Available() checks
/dev/kfd and llama-server presence, LoadModel() wires up server
lifecycle to return a rocmModel. Add guessModelType() for architecture
detection from GGUF filenames (handles hyphenated variants like
Llama-3). Add TestAvailable and TestGuessModelType.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:12:02 +00:00

67 lines
1.6 KiB
Go

//go:build linux && amd64
package rocm
import (
"fmt"
"os"
"path/filepath"
"strings"
"forge.lthn.ai/core/go-inference"
)
// rocmBackend is the inference.Backend implementation for AMD GPUs driven
// through ROCm. It is an empty struct and carries no state of its own.
type rocmBackend struct{}

// Name returns the fixed identifier of this backend, "rocm".
func (b *rocmBackend) Name() string {
	const backendName = "rocm"
	return backendName
}
// Available reports whether ROCm GPU inference can run on this machine.
// Two conditions must both hold:
//   - the ROCm kernel driver device node /dev/kfd can be stat'ed, and
//   - findLlamaServer succeeds in locating a llama-server binary.
//
// The device check runs first, so findLlamaServer is not consulted on
// machines without /dev/kfd.
func (b *rocmBackend) Available() bool {
	if _, statErr := os.Stat("/dev/kfd"); statErr != nil {
		return false
	}
	_, lookupErr := findLlamaServer()
	return lookupErr == nil
}
// LoadModel loads the GGUF model at path onto the AMD GPU by starting a
// llama-server process for it, and returns a TextModel backed by that server.
//
// Load options are resolved with inference.ApplyLoadOpts; the resulting
// GPULayers and ContextLen are forwarded to startServer. Errors from
// findLlamaServer and startServer are returned unchanged, while a failure to
// obtain a free port is wrapped with a "rocm: find free port" prefix.
// The returned model's type hint comes from guessModelType(path).
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
	loadCfg := inference.ApplyLoadOpts(opts)

	serverBin, err := findLlamaServer()
	if err != nil {
		return nil, err
	}

	// NOTE(review): freePort presumably returns a currently unused port that
	// startServer then binds; another process could claim it in between — confirm
	// startServer surfaces that as an error.
	listenPort, err := freePort()
	if err != nil {
		return nil, fmt.Errorf("rocm: find free port: %w", err)
	}

	server, err := startServer(serverBin, path, listenPort, loadCfg.GPULayers, loadCfg.ContextLen)
	if err != nil {
		return nil, err
	}

	model := &rocmModel{srv: server, modelType: guessModelType(path)}
	return model, nil
}
// guessModelType extracts a model architecture hint from the filename of the
// model at path. Only the final path element is inspected, and matching is
// case-insensitive.
//
// Hints are tried in priority order and the first match wins:
// gemma3, gemma2, gemma, qwen3, qwen2, qwen, llama3, llama, mistral, phi.
// A versioned hint such as "llama3" matches the family name followed by the
// generation digit, optionally separated by one '-' or '_'. So "Llama-3",
// "llama_3" and "llama3" all yield "llama3". The digit must not be the start
// of a parameter count. For example, "gemma-2b" (Gemma 1, 2B) yields "gemma"
// and "LLaMA-30B" yields "llama", rather than "gemma2" / "llama3".
//
// It returns "unknown" when no hint matches.
func guessModelType(path string) string {
	name := strings.ToLower(filepath.Base(path))
	hints := []struct {
		family  string
		version string // "" means the bare family name
	}{
		{"gemma", "3"}, {"gemma", "2"}, {"gemma", ""},
		{"qwen", "3"}, {"qwen", "2"}, {"qwen", ""},
		{"llama", "3"}, {"llama", ""},
		{"mistral", ""},
		{"phi", ""},
	}
	for _, h := range hints {
		if h.version == "" {
			if strings.Contains(name, h.family) {
				return h.family
			}
			continue
		}
		if hasVersionedArch(name, h.family, h.version) {
			return h.family + h.version
		}
	}
	return "unknown"
}

// hasVersionedArch reports whether name contains family, an optional single
// '-' or '_' separator, and then version. The version must not be immediately
// followed by a digit or by a 'b' size suffix (as in "2b" or "30b").
func hasVersionedArch(name, family, version string) bool {
	for start := 0; start < len(name); {
		idx := strings.Index(name[start:], family)
		if idx < 0 {
			return false
		}
		rest := name[start+idx+len(family):]
		if rest != "" && (rest[0] == '-' || rest[0] == '_') {
			rest = rest[1:]
		}
		if strings.HasPrefix(rest, version) && !isParamCountPrefix(rest[len(version):]) {
			return true
		}
		// Keep scanning: a later occurrence of family may carry the version.
		start += idx + 1
	}
	return false
}

// isParamCountPrefix reports whether s (the text right after a candidate
// generation digit) shows that the digit was really part of a parameter count.
// That is the case when s starts with another digit ("30b" leaves "0b"), or
// with 'b' not followed by another letter ("2b-it" leaves "b-it").
func isParamCountPrefix(s string) bool {
	if s == "" {
		return false
	}
	c := s[0]
	if c >= '0' && c <= '9' {
		return true
	}
	if c == 'b' {
		return len(s) == 1 || s[1] < 'a' || s[1] > 'z'
	}
	return false
}