2026-02-19 19:39:40 +00:00
|
|
|
//go:build linux && amd64
|
|
|
|
|
|
|
|
|
|
package rocm
|
|
|
|
|
|
2026-02-19 21:12:02 +00:00
|
|
|
import (
|
|
|
|
|
"os"
|
2026-02-24 18:50:37 +00:00
|
|
|
"strings"
|
2026-02-19 21:12:02 +00:00
|
|
|
|
2026-03-16 21:08:52 +00:00
|
|
|
coreerr "forge.lthn.ai/core/go-log"
|
2026-02-19 21:12:02 +00:00
|
|
|
"forge.lthn.ai/core/go-inference"
|
2026-02-19 22:23:07 +00:00
|
|
|
"forge.lthn.ai/core/go-rocm/internal/gguf"
|
2026-02-19 21:12:02 +00:00
|
|
|
)
|
2026-02-19 19:39:40 +00:00
|
|
|
|
|
|
|
|
// rocmBackend implements inference.Backend for AMD ROCm GPUs.
type rocmBackend struct{}

// Name returns the backend identifier string.
func (b *rocmBackend) Name() string {
	return "rocm"
}
|
|
|
|
|
|
2026-02-19 21:12:02 +00:00
|
|
|
// Available reports whether ROCm GPU inference can run on this machine.
|
|
|
|
|
// Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary.
|
2026-02-19 19:39:40 +00:00
|
|
|
func (b *rocmBackend) Available() bool {
|
2026-02-19 21:12:02 +00:00
|
|
|
if _, err := os.Stat("/dev/kfd"); err != nil {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
if _, err := findLlamaServer(); err != nil {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
return true
|
2026-02-19 19:39:40 +00:00
|
|
|
}
|
|
|
|
|
|
2026-02-19 21:12:02 +00:00
|
|
|
// LoadModel loads a GGUF model onto the AMD GPU via llama-server.
|
2026-02-19 22:23:07 +00:00
|
|
|
// Model architecture is read from GGUF metadata (replacing filename-based guessing).
|
|
|
|
|
// If no context length is specified, defaults to min(model_context_length, 4096)
|
|
|
|
|
// to prevent VRAM exhaustion on models with 128K+ native context.
|
2026-02-19 19:39:40 +00:00
|
|
|
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
|
2026-02-19 21:12:02 +00:00
|
|
|
cfg := inference.ApplyLoadOpts(opts)
|
|
|
|
|
|
|
|
|
|
binary, err := findLlamaServer()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
2026-02-19 22:23:07 +00:00
|
|
|
meta, err := gguf.ReadMetadata(path)
|
|
|
|
|
if err != nil {
|
2026-03-16 21:08:52 +00:00
|
|
|
return nil, coreerr.E("rocm.LoadModel", "read model metadata", err)
|
2026-02-19 22:23:07 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctxLen := cfg.ContextLen
|
|
|
|
|
if ctxLen == 0 && meta.ContextLength > 0 {
|
|
|
|
|
ctxLen = int(min(meta.ContextLength, 4096))
|
|
|
|
|
}
|
|
|
|
|
|
2026-02-19 23:13:19 +00:00
|
|
|
srv, err := startServer(binary, path, cfg.GPULayers, ctxLen, cfg.ParallelSlots)
|
2026-02-19 21:12:02 +00:00
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
2026-02-24 18:50:37 +00:00
|
|
|
// Map quantisation file type to bit width.
|
|
|
|
|
quantBits := 0
|
|
|
|
|
quantGroup := 0
|
|
|
|
|
ftName := gguf.FileTypeName(meta.FileType)
|
|
|
|
|
switch {
|
|
|
|
|
case strings.HasPrefix(ftName, "Q4_"):
|
|
|
|
|
quantBits = 4
|
|
|
|
|
quantGroup = 32
|
|
|
|
|
case strings.HasPrefix(ftName, "Q5_"):
|
|
|
|
|
quantBits = 5
|
|
|
|
|
quantGroup = 32
|
|
|
|
|
case strings.HasPrefix(ftName, "Q8_"):
|
|
|
|
|
quantBits = 8
|
|
|
|
|
quantGroup = 32
|
|
|
|
|
case strings.HasPrefix(ftName, "Q2_"):
|
|
|
|
|
quantBits = 2
|
|
|
|
|
quantGroup = 16
|
|
|
|
|
case strings.HasPrefix(ftName, "Q3_"):
|
|
|
|
|
quantBits = 3
|
|
|
|
|
quantGroup = 32
|
|
|
|
|
case strings.HasPrefix(ftName, "Q6_"):
|
|
|
|
|
quantBits = 6
|
|
|
|
|
quantGroup = 64
|
|
|
|
|
case ftName == "F16":
|
|
|
|
|
quantBits = 16
|
|
|
|
|
case ftName == "F32":
|
|
|
|
|
quantBits = 32
|
|
|
|
|
}
|
|
|
|
|
|
2026-02-19 21:12:02 +00:00
|
|
|
return &rocmModel{
|
|
|
|
|
srv: srv,
|
2026-02-19 22:23:07 +00:00
|
|
|
modelType: meta.Architecture,
|
2026-02-24 18:50:37 +00:00
|
|
|
modelInfo: inference.ModelInfo{
|
|
|
|
|
Architecture: meta.Architecture,
|
|
|
|
|
NumLayers: int(meta.BlockCount),
|
|
|
|
|
QuantBits: quantBits,
|
|
|
|
|
QuantGroup: quantGroup,
|
|
|
|
|
},
|
2026-02-19 21:12:02 +00:00
|
|
|
}, nil
|
|
|
|
|
}
|