feat(ax): apply RFC-025 AX compliance review

Principle 1 — Predictable Names:
- rocmModel.srv → rocmModel.server (struct field)
- recordMetrics: met → metrics (local var)
- backend.go/model.go: cfg → config (local vars)
- gguf.go: tc/kc → tensorCount32/kvCount32 (v2 count reads)

Principle 2 — Comments as Usage Examples:
- Added concrete usage examples to all exported functions:
  VRAMInfo, ModelInfo, DiscoverModels, GetVRAMInfo,
  ROCmAvailable, LoadModel, Available, NewClient, Health,
  ChatComplete, Complete, ReadMetadata, FileTypeName

Principle 5 — Test naming (_Good/_Bad/_Ugly):
- All test functions renamed to AX-7 convention across:
  discover_test.go, vram_test.go, server_test.go,
  internal/gguf/gguf_test.go, internal/llamacpp/client_test.go,
  internal/llamacpp/health_test.go

Also: add the missing go.sum entry for the dappco.re/go/core transitive
dependency (pulled in by the go-inference replace directive).

All tests pass: go test ./... -short -count=1

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Claude 2026-03-31 07:33:47 +01:00
parent 2fa87bfeb6
commit 41b34b6779
No known key found for this signature in database
GPG key ID: AF404715446AEB41
18 changed files with 169 additions and 102 deletions

View file

@ -18,6 +18,9 @@ func (b *rocmBackend) Name() string { return "rocm" }
// Available reports whether ROCm GPU inference can run on this machine.
// Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary.
//
// b := inference.FindBackend("rocm")
// if b.Available() { /* safe to LoadModel */ }
func (b *rocmBackend) Available() bool {
if _, err := os.Stat("/dev/kfd"); err != nil {
return false
@ -32,8 +35,14 @@ func (b *rocmBackend) Available() bool {
// Model architecture is read from GGUF metadata (replacing filename-based guessing).
// If no context length is specified, defaults to min(model_context_length, 4096)
// to prevent VRAM exhaustion on models with 128K+ native context.
//
// m, err := backend.LoadModel("/data/lem/gguf/model.gguf",
// inference.WithContextLen(4096),
// inference.WithGPULayers(-1),
// )
// defer m.Close()
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
cfg := inference.ApplyLoadOpts(opts)
config := inference.ApplyLoadOpts(opts)
binary, err := findLlamaServer()
if err != nil {
@ -45,12 +54,12 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
return nil, coreerr.E("rocm.LoadModel", "read model metadata", err)
}
ctxLen := cfg.ContextLen
ctxLen := config.ContextLen
if ctxLen == 0 && meta.ContextLength > 0 {
ctxLen = int(min(meta.ContextLength, 4096))
}
srv, err := startServer(binary, path, cfg.GPULayers, ctxLen, cfg.ParallelSlots)
srv, err := startServer(binary, path, config.GPULayers, ctxLen, config.ParallelSlots)
if err != nil {
return nil, err
}
@ -85,7 +94,7 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
}
return &rocmModel{
srv: srv,
server: srv,
modelType: meta.Architecture,
modelInfo: inference.ModelInfo{
Architecture: meta.Architecture,

View file

@ -6,8 +6,13 @@ import (
"forge.lthn.ai/core/go-rocm/internal/gguf"
)
// DiscoverModels scans a directory for GGUF model files and returns
// structured information about each. Files that cannot be parsed are skipped.
// DiscoverModels scans a directory for GGUF model files.
// Files that cannot be parsed are silently skipped.
//
// models, err := rocm.DiscoverModels("/data/lem/gguf")
// for _, m := range models {
// fmt.Printf("%s: %s %s ctx=%d\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen)
// }
func DiscoverModels(dir string) ([]ModelInfo, error) {
matches, err := filepath.Glob(filepath.Join(dir, "*.gguf"))
if err != nil {

View file

@ -63,7 +63,7 @@ func writeDiscoverKV(t *testing.T, f *os.File, key string, val any) {
}
}
func TestDiscoverModels(t *testing.T) {
func TestDiscoverModels_Good(t *testing.T) {
dir := t.TempDir()
// Create two valid GGUF model files.
@ -112,7 +112,7 @@ func TestDiscoverModels(t *testing.T) {
assert.Greater(t, llama.FileSize, int64(0))
}
func TestDiscoverModels_EmptyDir(t *testing.T) {
func TestDiscoverModels_Good_EmptyDir(t *testing.T) {
dir := t.TempDir()
models, err := DiscoverModels(dir)
@ -120,7 +120,7 @@ func TestDiscoverModels_EmptyDir(t *testing.T) {
assert.Empty(t, models)
}
func TestDiscoverModels_NotFound(t *testing.T) {
func TestDiscoverModels_Good_NonExistentDir(t *testing.T) {
// filepath.Glob returns nil, nil for a pattern matching no files,
// even when the directory does not exist.
models, err := DiscoverModels("/nonexistent/dir")
@ -128,7 +128,7 @@ func TestDiscoverModels_NotFound(t *testing.T) {
assert.Empty(t, models)
}
func TestDiscoverModels_SkipsCorruptFile(t *testing.T) {
func TestDiscoverModels_Ugly_SkipsCorruptFile(t *testing.T) {
dir := t.TempDir()
// Create a valid GGUF file.

5
go.mod
View file

@ -7,7 +7,10 @@ require (
forge.lthn.ai/core/go-log v0.0.4
)
require github.com/kr/text v0.2.0 // indirect
require (
dappco.re/go/core v0.8.0-alpha.1 // indirect
github.com/kr/text v0.2.0 // indirect
)
require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect

2
go.sum
View file

@ -1,3 +1,5 @@
dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=

View file

@ -72,6 +72,10 @@ var fileTypeNames = map[uint32]string{
// FileTypeName returns a human-readable name for a GGML quantisation file type.
// Unknown types return "type_N" where N is the numeric value.
//
// gguf.FileTypeName(15) // "Q4_K_M"
// gguf.FileTypeName(7) // "Q8_0"
// gguf.FileTypeName(1) // "F16"
func FileTypeName(ft uint32) string {
if name, ok := fileTypeNames[ft]; ok {
return name
@ -81,6 +85,9 @@ func FileTypeName(ft uint32) string {
// ReadMetadata reads the GGUF header from the file at path and returns the
// extracted metadata. Only metadata KV pairs are read; tensor data is not loaded.
//
// meta, err := gguf.ReadMetadata("/data/model.gguf")
// fmt.Printf("arch=%s ctx=%d quant=%s", meta.Architecture, meta.ContextLength, gguf.FileTypeName(meta.FileType))
func ReadMetadata(path string) (Metadata, error) {
f, err := os.Open(path)
if err != nil {
@ -123,15 +130,15 @@ func ReadMetadata(path string) (Metadata, error) {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
}
} else {
var tc, kc uint32
if err := binary.Read(r, binary.LittleEndian, &tc); err != nil {
var tensorCount32, kvCount32 uint32
if err := binary.Read(r, binary.LittleEndian, &tensorCount32); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
}
if err := binary.Read(r, binary.LittleEndian, &kc); err != nil {
if err := binary.Read(r, binary.LittleEndian, &kvCount32); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
}
tensorCount = uint64(tc)
kvCount = uint64(kc)
tensorCount = uint64(tensorCount32)
kvCount = uint64(kvCount32)
}
_ = tensorCount // we only read metadata KVs

View file

@ -109,7 +109,7 @@ func writeTestGGUFV2(t *testing.T, kvs [][2]any) string {
return path
}
func TestReadMetadata_Gemma3(t *testing.T) {
func TestReadMetadata_Good_Gemma3(t *testing.T) {
path := writeTestGGUFOrdered(t, [][2]any{
{"general.architecture", "gemma3"},
{"general.name", "Test Gemma3 1B"},
@ -131,7 +131,7 @@ func TestReadMetadata_Gemma3(t *testing.T) {
assert.Greater(t, m.FileSize, int64(0))
}
func TestReadMetadata_Llama(t *testing.T) {
func TestReadMetadata_Good_Llama(t *testing.T) {
path := writeTestGGUFOrdered(t, [][2]any{
{"general.architecture", "llama"},
{"general.name", "Test Llama 8B"},
@ -153,7 +153,7 @@ func TestReadMetadata_Llama(t *testing.T) {
assert.Greater(t, m.FileSize, int64(0))
}
func TestReadMetadata_ArchAfterContextLength(t *testing.T) {
func TestReadMetadata_Ugly_ArchAfterContextLength(t *testing.T) {
// Architecture key comes AFTER the arch-specific keys.
// The parser must handle deferred resolution of arch-prefixed keys.
path := writeTestGGUFOrdered(t, [][2]any{
@ -174,7 +174,7 @@ func TestReadMetadata_ArchAfterContextLength(t *testing.T) {
assert.Equal(t, uint32(32), m.BlockCount)
}
func TestReadMetadata_InvalidMagic(t *testing.T) {
func TestReadMetadata_Bad_InvalidMagic(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "notgguf.bin")
@ -186,12 +186,12 @@ func TestReadMetadata_InvalidMagic(t *testing.T) {
assert.Contains(t, err.Error(), "invalid magic")
}
func TestReadMetadata_FileNotFound(t *testing.T) {
func TestReadMetadata_Bad_FileNotFound(t *testing.T) {
_, err := ReadMetadata("/nonexistent/path/model.gguf")
require.Error(t, err)
}
func TestFileTypeName(t *testing.T) {
func TestFileTypeName_Good(t *testing.T) {
assert.Equal(t, "Q4_K_M", FileTypeName(15))
assert.Equal(t, "Q5_K_M", FileTypeName(17))
assert.Equal(t, "Q8_0", FileTypeName(7))
@ -199,7 +199,7 @@ func TestFileTypeName(t *testing.T) {
assert.Equal(t, "type_999", FileTypeName(999))
}
func TestReadMetadata_V2(t *testing.T) {
func TestReadMetadata_Good_V2(t *testing.T) {
// GGUF v2 uses uint32 for tensor and KV counts (instead of uint64 in v3).
path := writeTestGGUFV2(t, [][2]any{
{"general.architecture", "llama"},
@ -219,7 +219,7 @@ func TestReadMetadata_V2(t *testing.T) {
assert.Equal(t, uint32(16), m.BlockCount)
}
func TestReadMetadata_UnsupportedVersion(t *testing.T) {
func TestReadMetadata_Bad_UnsupportedVersion(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad_version.gguf")
@ -235,7 +235,7 @@ func TestReadMetadata_UnsupportedVersion(t *testing.T) {
assert.Contains(t, err.Error(), "unsupported GGUF version")
}
func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) {
func TestReadMetadata_Ugly_SkipsUnknownValueTypes(t *testing.T) {
// Tests skipValue for uint8, int16, float32, uint64, bool, and array types.
// These are stored under uninteresting keys so ReadMetadata skips them.
dir := t.TempDir()
@ -299,7 +299,7 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) {
assert.Equal(t, "Skip Test Model", m.Name)
}
func TestReadMetadata_Uint64ContextLength(t *testing.T) {
func TestReadMetadata_Ugly_Uint64ContextLength(t *testing.T) {
// context_length stored as uint64 that fits in uint32 — readTypedValue
// should downcast it to uint32.
path := writeTestGGUFOrdered(t, [][2]any{
@ -315,7 +315,7 @@ func TestReadMetadata_Uint64ContextLength(t *testing.T) {
assert.Equal(t, uint32(32), m.BlockCount)
}
func TestReadMetadata_TruncatedFile(t *testing.T) {
func TestReadMetadata_Bad_TruncatedFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "truncated.gguf")
@ -330,7 +330,7 @@ func TestReadMetadata_TruncatedFile(t *testing.T) {
assert.Contains(t, err.Error(), "reading version")
}
func TestReadMetadata_SkipsStringValue(t *testing.T) {
func TestReadMetadata_Ugly_SkipsStringValue(t *testing.T) {
// Tests skipValue for string type (type 8) on an uninteresting key.
dir := t.TempDir()
path := filepath.Join(dir, "skip_string.gguf")

View file

@ -60,8 +60,14 @@ type completionChunkResponse struct {
}
// ChatComplete sends a streaming chat completion request to /v1/chat/completions.
// It returns an iterator over text chunks and a function that returns any error
// that occurred during the request or while reading the stream.
// Returns an iterator over text chunks and an error accessor called after ranging.
//
// chunks, errFn := client.ChatComplete(ctx, llamacpp.ChatRequest{
// Messages: []llamacpp.ChatMessage{{Role: "user", Content: "Hello"}},
// MaxTokens: 128, Temperature: 0.7,
// })
// for text := range chunks { fmt.Print(text) }
// if err := errFn(); err != nil { /* handle */ }
func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
req.Stream = true
@ -125,8 +131,13 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
}
// Complete sends a streaming completion request to /v1/completions.
// It returns an iterator over text chunks and a function that returns any error
// that occurred during the request or while reading the stream.
// Returns an iterator over text chunks and an error accessor called after ranging.
//
// chunks, errFn := client.Complete(ctx, llamacpp.CompletionRequest{
// Prompt: "The capital of France is", MaxTokens: 16, Temperature: 0.0,
// })
// for text := range chunks { fmt.Print(text) }
// if err := errFn(); err != nil { /* handle */ }
func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
req.Stream = true

View file

@ -27,7 +27,7 @@ func sseLines(w http.ResponseWriter, lines []string) {
}
}
func TestChatComplete_Streaming(t *testing.T) {
func TestChatComplete_Good_Streaming(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/chat/completions", r.URL.Path)
assert.Equal(t, "POST", r.Method)
@ -56,7 +56,7 @@ func TestChatComplete_Streaming(t *testing.T) {
assert.Equal(t, []string{"Hello", " world"}, got)
}
func TestChatComplete_EmptyResponse(t *testing.T) {
func TestChatComplete_Good_EmptyResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sseLines(w, []string{"[DONE]"})
}))
@ -77,7 +77,7 @@ func TestChatComplete_EmptyResponse(t *testing.T) {
assert.Empty(t, got)
}
func TestChatComplete_HTTPError(t *testing.T) {
func TestChatComplete_Bad_HTTPError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "internal server error", http.StatusInternalServerError)
}))
@ -100,7 +100,7 @@ func TestChatComplete_HTTPError(t *testing.T) {
assert.Contains(t, err.Error(), "500")
}
func TestChatComplete_ContextCancelled(t *testing.T) {
func TestChatComplete_Ugly_ContextCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -140,7 +140,7 @@ func TestChatComplete_ContextCancelled(t *testing.T) {
assert.Equal(t, []string{"Hello"}, got)
}
func TestComplete_Streaming(t *testing.T) {
func TestComplete_Good_Streaming(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/completions", r.URL.Path)
assert.Equal(t, "POST", r.Method)
@ -170,7 +170,7 @@ func TestComplete_Streaming(t *testing.T) {
assert.Equal(t, []string{"Once", " upon", " a time"}, got)
}
func TestComplete_HTTPError(t *testing.T) {
func TestComplete_Bad_HTTPError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
}))

View file

@ -18,6 +18,11 @@ type Client struct {
}
// NewClient creates a client for the llama-server at the given base URL.
//
// client := llamacpp.NewClient("http://127.0.0.1:8080")
// if err := client.Health(ctx); err != nil {
// // server not ready
// }
func NewClient(baseURL string) *Client {
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
@ -30,6 +35,11 @@ type healthResponse struct {
}
// Health checks whether the llama-server is ready to accept requests.
// Returns nil when the server responds with {"status":"ok"}.
//
// if err := client.Health(ctx); err != nil {
// log.Printf("server not ready: %v", err)
// }
func (c *Client) Health(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
if err != nil {

View file

@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestHealth_OK(t *testing.T) {
func TestHealth_Good(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/health", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
@ -23,7 +23,7 @@ func TestHealth_OK(t *testing.T) {
require.NoError(t, err)
}
func TestHealth_NotReady(t *testing.T) {
func TestHealth_Bad_NotReady(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"loading model"}`))
@ -35,7 +35,7 @@ func TestHealth_NotReady(t *testing.T) {
assert.ErrorContains(t, err, "not ready")
}
func TestHealth_Loading(t *testing.T) {
func TestHealth_Bad_Loading(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
@ -48,7 +48,7 @@ func TestHealth_Loading(t *testing.T) {
assert.ErrorContains(t, err, "503")
}
func TestHealth_ServerDown(t *testing.T) {
func TestHealth_Bad_ServerDown(t *testing.T) {
c := NewClient("http://127.0.0.1:1") // nothing listening
err := c.Health(context.Background())
assert.Error(t, err)

View file

@ -17,7 +17,7 @@ import (
// rocmModel implements inference.TextModel using a llama-server subprocess.
type rocmModel struct {
srv *server
server *server
modelType string
modelInfo inference.ModelInfo
@ -32,24 +32,24 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
m.lastErr = nil
m.mu.Unlock()
if !m.srv.alive() {
if !m.server.alive() {
m.setServerExitErr()
return func(yield func(inference.Token) bool) {}
}
cfg := inference.ApplyGenerateOpts(opts)
config := inference.ApplyGenerateOpts(opts)
req := llamacpp.CompletionRequest{
Prompt: prompt,
MaxTokens: cfg.MaxTokens,
Temperature: cfg.Temperature,
TopK: cfg.TopK,
TopP: cfg.TopP,
RepeatPenalty: cfg.RepeatPenalty,
MaxTokens: config.MaxTokens,
Temperature: config.Temperature,
TopK: config.TopK,
TopP: config.TopP,
RepeatPenalty: config.RepeatPenalty,
}
start := time.Now()
chunks, errFn := m.srv.client.Complete(ctx, req)
chunks, errFn := m.server.client.Complete(ctx, req)
return func(yield func(inference.Token) bool) {
var count int
@ -75,12 +75,12 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
m.lastErr = nil
m.mu.Unlock()
if !m.srv.alive() {
if !m.server.alive() {
m.setServerExitErr()
return func(yield func(inference.Token) bool) {}
}
cfg := inference.ApplyGenerateOpts(opts)
config := inference.ApplyGenerateOpts(opts)
chatMsgs := make([]llamacpp.ChatMessage, len(messages))
for i, msg := range messages {
@ -92,15 +92,15 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
req := llamacpp.ChatRequest{
Messages: chatMsgs,
MaxTokens: cfg.MaxTokens,
Temperature: cfg.Temperature,
TopK: cfg.TopK,
TopP: cfg.TopP,
RepeatPenalty: cfg.RepeatPenalty,
MaxTokens: config.MaxTokens,
Temperature: config.Temperature,
TopK: config.TopK,
TopP: config.TopP,
RepeatPenalty: config.RepeatPenalty,
}
start := time.Now()
chunks, errFn := m.srv.client.ChatComplete(ctx, req)
chunks, errFn := m.server.client.ChatComplete(ctx, req)
return func(yield func(inference.Token) bool) {
var count int
@ -124,7 +124,7 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
// llama-server has no native classify endpoint, so this simulates it.
func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
if !m.srv.alive() {
if !m.server.alive() {
m.setServerExitErr()
return nil, m.Err()
}
@ -143,7 +143,7 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
Temperature: 0,
}
chunks, errFn := m.srv.client.Complete(ctx, req)
chunks, errFn := m.server.client.Complete(ctx, req)
var text strings.Builder
for chunk := range chunks {
text.WriteString(chunk)
@ -164,12 +164,12 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
// BatchGenerate runs batched autoregressive generation via llama-server.
// Each prompt is decoded sequentially up to MaxTokens.
func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
if !m.srv.alive() {
if !m.server.alive() {
m.setServerExitErr()
return nil, m.Err()
}
cfg := inference.ApplyGenerateOpts(opts)
config := inference.ApplyGenerateOpts(opts)
start := time.Now()
results := make([]inference.BatchResult, len(prompts))
var totalGenerated int
@ -182,14 +182,14 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ..
req := llamacpp.CompletionRequest{
Prompt: prompt,
MaxTokens: cfg.MaxTokens,
Temperature: cfg.Temperature,
TopK: cfg.TopK,
TopP: cfg.TopP,
RepeatPenalty: cfg.RepeatPenalty,
MaxTokens: config.MaxTokens,
Temperature: config.Temperature,
TopK: config.TopK,
TopP: config.TopP,
RepeatPenalty: config.RepeatPenalty,
}
chunks, errFn := m.srv.client.Complete(ctx, req)
chunks, errFn := m.server.client.Complete(ctx, req)
var tokens []inference.Token
for text := range chunks {
tokens = append(tokens, inference.Token{Text: text})
@ -227,15 +227,15 @@ func (m *rocmModel) Err() error {
// Close releases the llama-server subprocess and all associated resources.
func (m *rocmModel) Close() error {
return m.srv.stop()
return m.server.stop()
}
// setServerExitErr stores an appropriate error when the server is dead.
func (m *rocmModel) setServerExitErr() {
m.mu.Lock()
defer m.mu.Unlock()
if m.srv.exitErr != nil {
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.srv.exitErr)
if m.server.exitErr != nil {
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.server.exitErr)
} else {
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil)
}
@ -248,7 +248,7 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco
decode := now.Sub(decodeStart)
prefill := total - decode
met := inference.GenerateMetrics{
metrics := inference.GenerateMetrics{
PromptTokens: promptTokens,
GeneratedTokens: generatedTokens,
PrefillDuration: prefill,
@ -256,19 +256,19 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco
TotalDuration: total,
}
if prefill > 0 && promptTokens > 0 {
met.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
metrics.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
}
if decode > 0 && generatedTokens > 0 {
met.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
metrics.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
}
// Try to get VRAM stats — best effort.
if vram, err := GetVRAMInfo(); err == nil {
met.PeakMemoryBytes = vram.Used
met.ActiveMemoryBytes = vram.Used
metrics.PeakMemoryBytes = vram.Used
metrics.ActiveMemoryBytes = vram.Used
}
m.mu.Lock()
m.metrics = met
m.metrics = metrics
m.mu.Unlock()
}

View file

@ -8,5 +8,10 @@ func init() {
inference.Register(&rocmBackend{})
}
// ROCmAvailable reports whether ROCm GPU inference is available.
// ROCmAvailable reports whether ROCm GPU inference is available on this machine.
// Returns true only on linux/amd64 with /dev/kfd present and llama-server findable.
//
// if rocm.ROCmAvailable() {
// m, err := inference.LoadModel("/data/model.gguf")
// }
func ROCmAvailable() bool { return true }

View file

@ -25,6 +25,9 @@
package rocm
// VRAMInfo reports GPU video memory usage in bytes.
//
// info, err := rocm.GetVRAMInfo()
// fmt.Printf("VRAM: %d MiB used / %d MiB total", info.Used/(1024*1024), info.Total/(1024*1024))
type VRAMInfo struct {
Total uint64
Used uint64
@ -32,6 +35,11 @@ type VRAMInfo struct {
}
// ModelInfo describes a GGUF model file discovered on disk.
//
// models, _ := rocm.DiscoverModels("/data/lem/gguf")
// for _, m := range models {
// fmt.Printf("%s (%s %s, ctx=%d)\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen)
// }
type ModelInfo struct {
Path string // full path to .gguf file
Architecture string // GGUF architecture (e.g. "gemma3", "llama", "qwen2")

View file

@ -4,8 +4,12 @@ package rocm
import coreerr "forge.lthn.ai/core/go-log"
// ROCmAvailable reports whether ROCm GPU inference is available.
// ROCmAvailable reports whether ROCm GPU inference is available on this machine.
// Returns false on non-Linux or non-amd64 platforms.
//
// if rocm.ROCmAvailable() {
// m, err := inference.LoadModel("/data/model.gguf")
// }
func ROCmAvailable() bool { return false }
// GetVRAMInfo is not available on non-Linux/non-amd64 platforms.

View file

@ -14,34 +14,34 @@ import (
"github.com/stretchr/testify/require"
)
func TestFindLlamaServer_InPATH(t *testing.T) {
func TestFindLlamaServer_Good_InPATH(t *testing.T) {
// llama-server is at /usr/local/bin/llama-server on this machine.
path, err := findLlamaServer()
require.NoError(t, err)
assert.Contains(t, path, "llama-server")
}
func TestFindLlamaServer_EnvOverride(t *testing.T) {
func TestFindLlamaServer_Good_EnvOverride(t *testing.T) {
t.Setenv("ROCM_LLAMA_SERVER_PATH", "/usr/local/bin/llama-server")
path, err := findLlamaServer()
require.NoError(t, err)
assert.Equal(t, "/usr/local/bin/llama-server", path)
}
func TestFindLlamaServer_EnvNotFound(t *testing.T) {
func TestFindLlamaServer_Bad_EnvPathMissing(t *testing.T) {
t.Setenv("ROCM_LLAMA_SERVER_PATH", "/nonexistent/llama-server")
_, err := findLlamaServer()
assert.ErrorContains(t, err, "not found")
}
func TestFreePort(t *testing.T) {
func TestFreePort_Good(t *testing.T) {
port, err := freePort()
require.NoError(t, err)
assert.Greater(t, port, 0)
assert.Less(t, port, 65536)
}
func TestFreePort_UniquePerCall(t *testing.T) {
func TestFreePort_Good_UniquePerCall(t *testing.T) {
p1, err := freePort()
require.NoError(t, err)
p2, err := freePort()
@ -50,7 +50,7 @@ func TestFreePort_UniquePerCall(t *testing.T) {
_ = p2
}
func TestServerEnv_HIPVisibleDevices(t *testing.T) {
func TestServerEnv_Good_SetsHIPVisibleDevices(t *testing.T) {
env := serverEnv()
var hipVals []string
for _, e := range env {
@ -61,7 +61,7 @@ func TestServerEnv_HIPVisibleDevices(t *testing.T) {
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}
func TestServerEnv_FiltersExistingHIP(t *testing.T) {
func TestServerEnv_Good_FiltersExistingHIPVisibleDevices(t *testing.T) {
t.Setenv("HIP_VISIBLE_DEVICES", "1")
env := serverEnv()
var hipVals []string
@ -73,7 +73,7 @@ func TestServerEnv_FiltersExistingHIP(t *testing.T) {
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}
func TestAvailable(t *testing.T) {
func TestAvailable_Good(t *testing.T) {
b := &rocmBackend{}
if _, err := os.Stat("/dev/kfd"); err != nil {
t.Skip("no ROCm hardware")
@ -81,27 +81,26 @@ func TestAvailable(t *testing.T) {
assert.True(t, b.Available())
}
func TestServerAlive_Running(t *testing.T) {
func TestServerAlive_Good_Running(t *testing.T) {
s := &server{exited: make(chan struct{})}
assert.True(t, s.alive())
}
func TestServerAlive_Exited(t *testing.T) {
func TestServerAlive_Good_Exited(t *testing.T) {
exited := make(chan struct{})
close(exited)
s := &server{exited: exited, exitErr: coreerr.E("test", "process killed", nil)}
assert.False(t, s.alive())
}
func TestGenerate_ServerDead(t *testing.T) {
func TestGenerate_Bad_ServerDead(t *testing.T) {
exited := make(chan struct{})
close(exited)
s := &server{
exited: exited,
exitErr: coreerr.E("test", "process killed", nil),
}
m := &rocmModel{srv: s}
m := &rocmModel{server: s}
var count int
for range m.Generate(context.Background(), "hello") {
@ -111,7 +110,7 @@ func TestGenerate_ServerDead(t *testing.T) {
assert.ErrorContains(t, m.Err(), "server has exited")
}
func TestStartServer_RetriesOnProcessExit(t *testing.T) {
func TestStartServer_Ugly_RetriesOnProcessExit(t *testing.T) {
// /bin/false starts successfully but exits immediately with code 1.
// startServer should retry up to 3 times, then fail.
_, err := startServer("/bin/false", "/nonexistent/model.gguf", 999, 0, 0)
@ -119,14 +118,14 @@ func TestStartServer_RetriesOnProcessExit(t *testing.T) {
assert.Contains(t, err.Error(), "failed after 3 attempts")
}
func TestChat_ServerDead(t *testing.T) {
func TestChat_Bad_ServerDead(t *testing.T) {
exited := make(chan struct{})
close(exited)
s := &server{
exited: exited,
exitErr: coreerr.E("test", "process killed", nil),
}
m := &rocmModel{srv: s}
m := &rocmModel{server: s}
msgs := []inference.Message{{Role: "user", Content: "hello"}}
var count int

View file

@ -17,6 +17,10 @@ import (
//
// Note: total and used are read non-atomically from sysfs; transient
// inconsistencies are possible under heavy allocation churn.
//
// info, err := rocm.GetVRAMInfo()
// fmt.Printf("VRAM: %d MiB used / %d MiB total (free: %d MiB)",
// info.Used/(1024*1024), info.Total/(1024*1024), info.Free/(1024*1024))
func GetVRAMInfo() (VRAMInfo, error) {
cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
if err != nil {

View file

@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestReadSysfsUint64(t *testing.T) {
func TestReadSysfsUint64_Good(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test_value")
require.NoError(t, os.WriteFile(path, []byte("17163091968\n"), 0644))
@ -21,12 +21,12 @@ func TestReadSysfsUint64(t *testing.T) {
assert.Equal(t, uint64(17163091968), val)
}
func TestReadSysfsUint64_NotFound(t *testing.T) {
func TestReadSysfsUint64_Bad_NotFound(t *testing.T) {
_, err := readSysfsUint64("/nonexistent/path")
assert.Error(t, err)
}
func TestReadSysfsUint64_InvalidContent(t *testing.T) {
func TestReadSysfsUint64_Bad_InvalidContent(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad_value")
require.NoError(t, os.WriteFile(path, []byte("not-a-number\n"), 0644))
@ -35,7 +35,7 @@ func TestReadSysfsUint64_InvalidContent(t *testing.T) {
assert.Error(t, err)
}
func TestReadSysfsUint64_EmptyFile(t *testing.T) {
func TestReadSysfsUint64_Bad_EmptyFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "empty_value")
require.NoError(t, os.WriteFile(path, []byte(""), 0644))
@ -44,7 +44,7 @@ func TestReadSysfsUint64_EmptyFile(t *testing.T) {
assert.Error(t, err)
}
func TestGetVRAMInfo(t *testing.T) {
func TestGetVRAMInfo_Good(t *testing.T) {
info, err := GetVRAMInfo()
if err != nil {
t.Skipf("no VRAM sysfs info available: %v", err)