go-rocm/backend.go
Claude 72120bb200
feat: pass --parallel N to llama-server for concurrent inference slots
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 23:13:19 +00:00

61 lines
1.6 KiB
Go

//go:build linux && amd64
package rocm
import (
"fmt"
"os"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/gguf"
)
// rocmBackend implements inference.Backend for AMD ROCm GPUs.
type rocmBackend struct{}

// Name returns the backend identifier "rocm".
func (b *rocmBackend) Name() string { return "rocm" }
// Available reports whether ROCm GPU inference can run on this machine.
// It requires both the ROCm kernel driver (/dev/kfd) and a findable
// llama-server binary.
func (b *rocmBackend) Available() bool {
	// No /dev/kfd means no ROCm-capable kernel driver is loaded.
	if _, statErr := os.Stat("/dev/kfd"); statErr != nil {
		return false
	}
	// The driver alone is not enough; we also need the server binary.
	_, findErr := findLlamaServer()
	return findErr == nil
}
// LoadModel loads a GGUF model onto the AMD GPU via llama-server.
// Model architecture is read from GGUF metadata (replacing filename-based guessing).
// If no context length is specified, defaults to min(model_context_length, 4096)
// to prevent VRAM exhaustion on models with 128K+ native context.
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
cfg := inference.ApplyLoadOpts(opts)
binary, err := findLlamaServer()
if err != nil {
return nil, err
}
meta, err := gguf.ReadMetadata(path)
if err != nil {
return nil, fmt.Errorf("rocm: read model metadata: %w", err)
}
ctxLen := cfg.ContextLen
if ctxLen == 0 && meta.ContextLength > 0 {
ctxLen = int(min(meta.ContextLength, 4096))
}
srv, err := startServer(binary, path, cfg.GPULayers, ctxLen, cfg.ParallelSlots)
if err != nil {
return nil, err
}
return &rocmModel{
srv: srv,
modelType: meta.Architecture,
}, nil
}