go-rocm/backend.go
Claude 41b34b6779
feat(ax): apply RFC-025 AX compliance review
Principle 1 — Predictable Names:
- rocmModel.srv → rocmModel.server (struct field)
- recordMetrics: met → metrics (local var)
- backend.go/model.go: cfg → config (local vars)
- gguf.go: tc/kc → tensorCount32/kvCount32 (v2 count reads)

Principle 2 — Comments as Usage Examples:
- Added concrete usage examples to all exported functions:
  VRAMInfo, ModelInfo, DiscoverModels, GetVRAMInfo,
  ROCmAvailable, LoadModel, Available, NewClient, Health,
  ChatComplete, Complete, ReadMetadata, FileTypeName

Principle 5 — Test naming (_Good/_Bad/_Ugly):
- All test functions renamed to AX-7 convention across:
  discover_test.go, vram_test.go, server_test.go,
  internal/gguf/gguf_test.go, internal/llamacpp/client_test.go,
  internal/llamacpp/health_test.go

Also: fix go.sum missing entry for dappco.re/go/core transitive dep
(pulled in by go-inference replace directive).

All tests pass: go test ./... -short -count=1

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 07:33:47 +01:00

106 lines
2.7 KiB
Go

//go:build linux && amd64
package rocm
import (
"os"
"strings"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/gguf"
)
// rocmBackend is the AMD ROCm GPU implementation of inference.Backend.
type rocmBackend struct{}

// Name returns the backend identifier used for backend lookup.
func (b *rocmBackend) Name() string {
	return "rocm"
}
// Available reports whether ROCm GPU inference can run on this machine.
// Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary.
//
//	b := inference.FindBackend("rocm")
//	if b.Available() { /* safe to LoadModel */ }
func (b *rocmBackend) Available() bool {
	// The ROCm kernel driver exposes /dev/kfd; without it no GPU work is possible.
	if _, statErr := os.Stat("/dev/kfd"); statErr != nil {
		return false
	}
	// Usable only if a llama-server binary can also be located.
	_, findErr := findLlamaServer()
	return findErr == nil
}
// LoadModel loads a GGUF model onto the AMD GPU via llama-server.
// Model architecture is read from GGUF metadata (replacing filename-based guessing).
// If no context length is specified, defaults to min(model_context_length, 4096)
// to prevent VRAM exhaustion on models with 128K+ native context.
//
//	m, err := backend.LoadModel("/data/lem/gguf/model.gguf",
//		inference.WithContextLen(4096),
//		inference.WithGPULayers(-1),
//	)
//	defer m.Close()
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
	config := inference.ApplyLoadOpts(opts)
	binary, err := findLlamaServer()
	if err != nil {
		return nil, err
	}
	meta, err := gguf.ReadMetadata(path)
	if err != nil {
		return nil, coreerr.E("rocm.LoadModel", "read model metadata", err)
	}
	ctxLen := config.ContextLen
	if ctxLen == 0 && meta.ContextLength > 0 {
		// Cap at 4096 when the caller gave no preference, so 128K+ native
		// context models do not exhaust VRAM on load.
		ctxLen = int(min(meta.ContextLength, 4096))
	}
	srv, err := startServer(binary, path, config.GPULayers, ctxLen, config.ParallelSlots)
	if err != nil {
		return nil, err
	}
	quantBits, quantGroup := quantParams(gguf.FileTypeName(meta.FileType))
	return &rocmModel{
		server:    srv,
		modelType: meta.Architecture,
		modelInfo: inference.ModelInfo{
			Architecture: meta.Architecture,
			NumLayers:    int(meta.BlockCount),
			QuantBits:    quantBits,
			QuantGroup:   quantGroup,
		},
	}, nil
}

// quantParams maps a GGUF quantisation file-type name (e.g. "Q4_K_M", "F16")
// to its bit width and group size. Unknown or unrecognised names yield (0, 0).
func quantParams(ftName string) (bits, group int) {
	switch {
	case strings.HasPrefix(ftName, "Q2_"):
		return 2, 16
	case strings.HasPrefix(ftName, "Q3_"):
		return 3, 32
	case strings.HasPrefix(ftName, "Q4_"):
		return 4, 32
	case strings.HasPrefix(ftName, "Q5_"):
		return 5, 32
	case strings.HasPrefix(ftName, "Q6_"):
		return 6, 64
	case strings.HasPrefix(ftName, "Q8_"):
		return 8, 32
	case ftName == "F16":
		return 16, 0
	case ftName == "F32":
		return 32, 0
	}
	return 0, 0
}