From 41b34b6779f6bf921e0a2b5a852a0bc3de616951 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 31 Mar 2026 07:33:47 +0100 Subject: [PATCH] feat(ax): apply RFC-025 AX compliance review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Principle 1 — Predictable Names: - rocmModel.srv → rocmModel.server (struct field) - recordMetrics: met → metrics (local var) - backend.go/model.go: cfg → config (local vars) - gguf.go: tc/kc → tensorCount32/kvCount32 (v2 count reads) Principle 2 — Comments as Usage Examples: - Added concrete usage examples to all exported functions: VRAMInfo, ModelInfo, DiscoverModels, GetVRAMInfo, ROCmAvailable, LoadModel, Available, NewClient, Health, ChatComplete, Complete, ReadMetadata, FileTypeName Principle 5 — Test naming (_Good/_Bad/_Ugly): - All test functions renamed to AX-7 convention across: discover_test.go, vram_test.go, server_test.go, internal/gguf/gguf_test.go, internal/llamacpp/client_test.go, internal/llamacpp/health_test.go Also: fix go.sum missing entry for dappco.re/go/core transitive dep (pulled in by go-inference replace directive). All tests pass: go test ./... -short -count=1 Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 (1M context) --- backend.go | 17 ++++++-- discover.go | 9 +++- discover_test.go | 8 ++-- go.mod | 5 ++- go.sum | 2 + internal/gguf/gguf.go | 17 +++++--- internal/gguf/gguf_test.go | 26 ++++++------ internal/llamacpp/client.go | 19 +++++++-- internal/llamacpp/client_test.go | 12 +++--- internal/llamacpp/health.go | 10 +++++ internal/llamacpp/health_test.go | 8 ++-- model.go | 72 ++++++++++++++++---------------- register_rocm.go | 7 +++- rocm.go | 8 ++++ rocm_stub.go | 6 ++- server_test.go | 31 +++++++------- vram.go | 4 ++ vram_test.go | 10 ++--- 18 files changed, 169 insertions(+), 102 deletions(-) diff --git a/backend.go b/backend.go index 8f1bdcf..dde011a 100644 --- a/backend.go +++ b/backend.go @@ -18,6 +18,9 @@ func (b *rocmBackend) Name() string { return "rocm" } // Available reports whether ROCm GPU inference can run on this machine. // Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary. +// +// b := inference.FindBackend("rocm") +// if b.Available() { /* safe to LoadModel */ } func (b *rocmBackend) Available() bool { if _, err := os.Stat("/dev/kfd"); err != nil { return false @@ -32,8 +35,14 @@ func (b *rocmBackend) Available() bool { // Model architecture is read from GGUF metadata (replacing filename-based guessing). // If no context length is specified, defaults to min(model_context_length, 4096) // to prevent VRAM exhaustion on models with 128K+ native context. +// +// m, err := backend.LoadModel("/data/lem/gguf/model.gguf", +// inference.WithContextLen(4096), +// inference.WithGPULayers(-1), +// ) +// defer m.Close() func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) { - cfg := inference.ApplyLoadOpts(opts) + config := inference.ApplyLoadOpts(opts) binary, err := findLlamaServer() if err != nil { @@ -45,12 +54,12 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe return nil, coreerr.E("rocm.LoadModel", "read model metadata", err) } - ctxLen := cfg.ContextLen + ctxLen := config.ContextLen if ctxLen == 0 && meta.ContextLength > 0 { ctxLen = int(min(meta.ContextLength, 4096)) } - srv, err := startServer(binary, path, cfg.GPULayers, ctxLen, cfg.ParallelSlots) + srv, err := startServer(binary, path, config.GPULayers, ctxLen, config.ParallelSlots) if err != nil { return nil, err } @@ -85,7 +94,7 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe } return &rocmModel{ - srv: srv, + server: srv, modelType: meta.Architecture, modelInfo: inference.ModelInfo{ Architecture: meta.Architecture, diff --git a/discover.go b/discover.go index 0d298ca..5b23482 100644 --- a/discover.go +++ b/discover.go @@ -6,8 +6,13 @@ import ( "forge.lthn.ai/core/go-rocm/internal/gguf" ) -// DiscoverModels scans a directory for GGUF model files and returns -// structured information about each. Files that cannot be parsed are skipped. +// DiscoverModels scans a directory for GGUF model files. +// Files that cannot be parsed are silently skipped. +// +// models, err := rocm.DiscoverModels("/data/lem/gguf") +// for _, m := range models { +// fmt.Printf("%s: %s %s ctx=%d\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen) +// } func DiscoverModels(dir string) ([]ModelInfo, error) { matches, err := filepath.Glob(filepath.Join(dir, "*.gguf")) if err != nil { diff --git a/discover_test.go b/discover_test.go index 9a6ce1a..1f74a3c 100644 --- a/discover_test.go +++ b/discover_test.go @@ -63,7 +63,7 @@ func writeDiscoverKV(t *testing.T, f *os.File, key string, val any) { } } -func TestDiscoverModels(t *testing.T) { +func TestDiscoverModels_Good(t *testing.T) { dir := t.TempDir() // Create two valid GGUF model files. @@ -112,7 +112,7 @@ func TestDiscoverModels(t *testing.T) { assert.Greater(t, llama.FileSize, int64(0)) } -func TestDiscoverModels_EmptyDir(t *testing.T) { +func TestDiscoverModels_Good_EmptyDir(t *testing.T) { dir := t.TempDir() models, err := DiscoverModels(dir) @@ -120,7 +120,7 @@ func TestDiscoverModels_EmptyDir(t *testing.T) { assert.Empty(t, models) } -func TestDiscoverModels_NotFound(t *testing.T) { +func TestDiscoverModels_Good_NonExistentDir(t *testing.T) { // filepath.Glob returns nil, nil for a pattern matching no files, // even when the directory does not exist. models, err := DiscoverModels("/nonexistent/dir") @@ -128,7 +128,7 @@ func TestDiscoverModels_NotFound(t *testing.T) { assert.Empty(t, models) } -func TestDiscoverModels_SkipsCorruptFile(t *testing.T) { +func TestDiscoverModels_Ugly_SkipsCorruptFile(t *testing.T) { dir := t.TempDir() // Create a valid GGUF file. diff --git a/go.mod b/go.mod index 12d9d63..027a377 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,10 @@ require ( forge.lthn.ai/core/go-log v0.0.4 ) -require github.com/kr/text v0.2.0 // indirect +require ( + dappco.re/go/core v0.8.0-alpha.1 // indirect + github.com/kr/text v0.2.0 // indirect +) require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/go.sum b/go.sum index f55559e..1e66859 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk= +dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A= forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0= forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= diff --git a/internal/gguf/gguf.go b/internal/gguf/gguf.go index 28a290e..9018dbe 100644 --- a/internal/gguf/gguf.go +++ b/internal/gguf/gguf.go @@ -72,6 +72,10 @@ var fileTypeNames = map[uint32]string{ // FileTypeName returns a human-readable name for a GGML quantisation file type. // Unknown types return "type_N" where N is the numeric value. +// +// gguf.FileTypeName(15) // "Q4_K_M" +// gguf.FileTypeName(7) // "Q8_0" +// gguf.FileTypeName(1) // "F16" func FileTypeName(ft uint32) string { if name, ok := fileTypeNames[ft]; ok { return name @@ -81,6 +85,9 @@ func FileTypeName(ft uint32) string { // ReadMetadata reads the GGUF header from the file at path and returns the // extracted metadata. Only metadata KV pairs are read; tensor data is not loaded. +// +// meta, err := gguf.ReadMetadata("/data/model.gguf") +// fmt.Printf("arch=%s ctx=%d quant=%s", meta.Architecture, meta.ContextLength, gguf.FileTypeName(meta.FileType)) func ReadMetadata(path string) (Metadata, error) { f, err := os.Open(path) if err != nil { @@ -123,15 +130,15 @@ func ReadMetadata(path string) (Metadata, error) { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err) } } else { - var tc, kc uint32 - if err := binary.Read(r, binary.LittleEndian, &tc); err != nil { + var tensorCount32, kvCount32 uint32 + if err := binary.Read(r, binary.LittleEndian, &tensorCount32); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err) } - if err := binary.Read(r, binary.LittleEndian, &kc); err != nil { + if err := binary.Read(r, binary.LittleEndian, &kvCount32); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err) } - tensorCount = uint64(tc) - kvCount = uint64(kc) + tensorCount = uint64(tensorCount32) + kvCount = uint64(kvCount32) } _ = tensorCount // we only read metadata KVs diff --git a/internal/gguf/gguf_test.go b/internal/gguf/gguf_test.go index 5afb7b0..ecc136b 100644 --- a/internal/gguf/gguf_test.go +++ b/internal/gguf/gguf_test.go @@ -109,7 +109,7 @@ func writeTestGGUFV2(t *testing.T, kvs [][2]any) string { return path } -func TestReadMetadata_Gemma3(t *testing.T) { +func TestReadMetadata_Good_Gemma3(t *testing.T) { path := writeTestGGUFOrdered(t, [][2]any{ {"general.architecture", "gemma3"}, {"general.name", "Test Gemma3 1B"}, @@ -131,7 +131,7 @@ func TestReadMetadata_Gemma3(t *testing.T) { assert.Greater(t, m.FileSize, int64(0)) } -func TestReadMetadata_Llama(t *testing.T) { +func TestReadMetadata_Good_Llama(t *testing.T) { path := writeTestGGUFOrdered(t, [][2]any{ {"general.architecture", "llama"}, {"general.name", "Test Llama 8B"}, @@ -153,7 +153,7 @@ func TestReadMetadata_Llama(t *testing.T) { assert.Greater(t, m.FileSize, int64(0)) } -func TestReadMetadata_ArchAfterContextLength(t *testing.T) { +func TestReadMetadata_Ugly_ArchAfterContextLength(t *testing.T) { // Architecture key comes AFTER the arch-specific keys. // The parser must handle deferred resolution of arch-prefixed keys. path := writeTestGGUFOrdered(t, [][2]any{ @@ -174,7 +174,7 @@ func TestReadMetadata_ArchAfterContextLength(t *testing.T) { assert.Equal(t, uint32(32), m.BlockCount) } -func TestReadMetadata_InvalidMagic(t *testing.T) { +func TestReadMetadata_Bad_InvalidMagic(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "notgguf.bin") @@ -186,12 +186,12 @@ func TestReadMetadata_InvalidMagic(t *testing.T) { assert.Contains(t, err.Error(), "invalid magic") } -func TestReadMetadata_FileNotFound(t *testing.T) { +func TestReadMetadata_Bad_FileNotFound(t *testing.T) { _, err := ReadMetadata("/nonexistent/path/model.gguf") require.Error(t, err) } -func TestFileTypeName(t *testing.T) { +func TestFileTypeName_Good(t *testing.T) { assert.Equal(t, "Q4_K_M", FileTypeName(15)) assert.Equal(t, "Q5_K_M", FileTypeName(17)) assert.Equal(t, "Q8_0", FileTypeName(7)) @@ -199,7 +199,7 @@ func TestFileTypeName(t *testing.T) { assert.Equal(t, "type_999", FileTypeName(999)) } -func TestReadMetadata_V2(t *testing.T) { +func TestReadMetadata_Good_V2(t *testing.T) { // GGUF v2 uses uint32 for tensor and KV counts (instead of uint64 in v3). path := writeTestGGUFV2(t, [][2]any{ {"general.architecture", "llama"}, @@ -219,7 +219,7 @@ func TestReadMetadata_V2(t *testing.T) { assert.Equal(t, uint32(16), m.BlockCount) } -func TestReadMetadata_UnsupportedVersion(t *testing.T) { +func TestReadMetadata_Bad_UnsupportedVersion(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "bad_version.gguf") @@ -235,7 +235,7 @@ func TestReadMetadata_UnsupportedVersion(t *testing.T) { assert.Contains(t, err.Error(), "unsupported GGUF version") } -func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) { +func TestReadMetadata_Ugly_SkipsUnknownValueTypes(t *testing.T) { // Tests skipValue for uint8, int16, float32, uint64, bool, and array types. // These are stored under uninteresting keys so ReadMetadata skips them. dir := t.TempDir() @@ -283,7 +283,7 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) { b8 := make([]byte, 8) binary.LittleEndian.PutUint64(b8, 3) // count: 3 arrBuf = append(arrBuf, b8...) - arrBuf = append(arrBuf, 10, 20, 30) // 3 uint8 values + arrBuf = append(arrBuf, 10, 20, 30) // 3 uint8 values writeRawKV(t, f, "custom.array_val", 9, arrBuf) // 7-8. Interesting keys to verify parsing continued correctly. @@ -299,7 +299,7 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) { assert.Equal(t, "Skip Test Model", m.Name) } -func TestReadMetadata_Uint64ContextLength(t *testing.T) { +func TestReadMetadata_Ugly_Uint64ContextLength(t *testing.T) { // context_length stored as uint64 that fits in uint32 — readTypedValue // should downcast it to uint32. path := writeTestGGUFOrdered(t, [][2]any{ @@ -315,7 +315,7 @@ func TestReadMetadata_Uint64ContextLength(t *testing.T) { assert.Equal(t, uint32(32), m.BlockCount) } -func TestReadMetadata_TruncatedFile(t *testing.T) { +func TestReadMetadata_Bad_TruncatedFile(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "truncated.gguf") @@ -330,7 +330,7 @@ func TestReadMetadata_TruncatedFile(t *testing.T) { assert.Contains(t, err.Error(), "reading version") } -func TestReadMetadata_SkipsStringValue(t *testing.T) { +func TestReadMetadata_Ugly_SkipsStringValue(t *testing.T) { // Tests skipValue for string type (type 8) on an uninteresting key. dir := t.TempDir() path := filepath.Join(dir, "skip_string.gguf") diff --git a/internal/llamacpp/client.go b/internal/llamacpp/client.go index 2fb0c11..d91970c 100644 --- a/internal/llamacpp/client.go +++ b/internal/llamacpp/client.go @@ -60,8 +60,14 @@ type completionChunkResponse struct { } // ChatComplete sends a streaming chat completion request to /v1/chat/completions. -// It returns an iterator over text chunks and a function that returns any error -// that occurred during the request or while reading the stream. +// Returns an iterator over text chunks and an error accessor called after ranging. +// +// chunks, errFn := client.ChatComplete(ctx, llamacpp.ChatRequest{ +// Messages: []llamacpp.ChatMessage{{Role: "user", Content: "Hello"}}, +// MaxTokens: 128, Temperature: 0.7, +// }) +// for text := range chunks { fmt.Print(text) } +// if err := errFn(); err != nil { /* handle */ } func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) { req.Stream = true @@ -125,8 +131,13 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st } // Complete sends a streaming completion request to /v1/completions. -// It returns an iterator over text chunks and a function that returns any error -// that occurred during the request or while reading the stream. +// Returns an iterator over text chunks and an error accessor called after ranging. +// +// chunks, errFn := client.Complete(ctx, llamacpp.CompletionRequest{ +// Prompt: "The capital of France is", MaxTokens: 16, Temperature: 0.0, +// }) +// for text := range chunks { fmt.Print(text) } +// if err := errFn(); err != nil { /* handle */ } func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) { req.Stream = true diff --git a/internal/llamacpp/client_test.go b/internal/llamacpp/client_test.go index 20b49c7..5e85b01 100644 --- a/internal/llamacpp/client_test.go +++ b/internal/llamacpp/client_test.go @@ -27,7 +27,7 @@ func sseLines(w http.ResponseWriter, lines []string) { } } -func TestChatComplete_Streaming(t *testing.T) { +func TestChatComplete_Good_Streaming(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v1/chat/completions", r.URL.Path) assert.Equal(t, "POST", r.Method) @@ -56,7 +56,7 @@ func TestChatComplete_Streaming(t *testing.T) { assert.Equal(t, []string{"Hello", " world"}, got) } -func TestChatComplete_EmptyResponse(t *testing.T) { +func TestChatComplete_Good_EmptyResponse(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sseLines(w, []string{"[DONE]"}) })) @@ -77,7 +77,7 @@ func TestChatComplete_EmptyResponse(t *testing.T) { assert.Empty(t, got) } -func TestChatComplete_HTTPError(t *testing.T) { +func TestChatComplete_Bad_HTTPError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal server error", http.StatusInternalServerError) })) @@ -100,7 +100,7 @@ func TestChatComplete_HTTPError(t *testing.T) { assert.Contains(t, err.Error(), "500") } -func TestChatComplete_ContextCancelled(t *testing.T) { +func TestChatComplete_Ugly_ContextCancelled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -140,7 +140,7 @@ func TestChatComplete_ContextCancelled(t *testing.T) { assert.Equal(t, []string{"Hello"}, got) } -func TestComplete_Streaming(t *testing.T) { +func TestComplete_Good_Streaming(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v1/completions", r.URL.Path) assert.Equal(t, "POST", r.Method) @@ -170,7 +170,7 @@ func TestComplete_Streaming(t *testing.T) { assert.Equal(t, []string{"Once", " upon", " a time"}, got) } -func TestComplete_HTTPError(t *testing.T) { +func TestComplete_Bad_HTTPError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad request", http.StatusBadRequest) })) diff --git a/internal/llamacpp/health.go b/internal/llamacpp/health.go index 33ec57b..ecd422c 100644 --- a/internal/llamacpp/health.go +++ b/internal/llamacpp/health.go @@ -18,6 +18,11 @@ type Client struct { } // NewClient creates a client for the llama-server at the given base URL. +// +// client := llamacpp.NewClient("http://127.0.0.1:8080") +// if err := client.Health(ctx); err != nil { +// // server not ready +// } func NewClient(baseURL string) *Client { return &Client{ baseURL: strings.TrimRight(baseURL, "/"), @@ -30,6 +35,11 @@ type healthResponse struct { } // Health checks whether the llama-server is ready to accept requests. +// Returns nil when the server responds with {"status":"ok"}. +// +// if err := client.Health(ctx); err != nil { +// log.Printf("server not ready: %v", err) +// } func (c *Client) Health(ctx context.Context) error { req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil) if err != nil { diff --git a/internal/llamacpp/health_test.go b/internal/llamacpp/health_test.go index 38affcf..bb42e56 100644 --- a/internal/llamacpp/health_test.go +++ b/internal/llamacpp/health_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestHealth_OK(t *testing.T) { +func TestHealth_Good(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/health", r.URL.Path) w.Header().Set("Content-Type", "application/json") @@ -23,7 +23,7 @@ func TestHealth_OK(t *testing.T) { require.NoError(t, err) } -func TestHealth_NotReady(t *testing.T) { +func TestHealth_Bad_NotReady(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"status":"loading model"}`)) @@ -35,7 +35,7 @@ func TestHealth_NotReady(t *testing.T) { assert.ErrorContains(t, err, "not ready") } -func TestHealth_Loading(t *testing.T) { +func TestHealth_Bad_Loading(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusServiceUnavailable) @@ -48,7 +48,7 @@ func TestHealth_Loading(t *testing.T) { assert.ErrorContains(t, err, "503") } -func TestHealth_ServerDown(t *testing.T) { +func TestHealth_Bad_ServerDown(t *testing.T) { c := NewClient("http://127.0.0.1:1") // nothing listening err := c.Health(context.Background()) assert.Error(t, err) diff --git a/model.go b/model.go index 5ae5c47..98643d1 100644 --- a/model.go +++ b/model.go @@ -17,7 +17,7 @@ import ( // rocmModel implements inference.TextModel using a llama-server subprocess. type rocmModel struct { - srv *server + server *server modelType string modelInfo inference.ModelInfo @@ -32,24 +32,24 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen m.lastErr = nil m.mu.Unlock() - if !m.srv.alive() { + if !m.server.alive() { m.setServerExitErr() return func(yield func(inference.Token) bool) {} } - cfg := inference.ApplyGenerateOpts(opts) + config := inference.ApplyGenerateOpts(opts) req := llamacpp.CompletionRequest{ Prompt: prompt, - MaxTokens: cfg.MaxTokens, - Temperature: cfg.Temperature, - TopK: cfg.TopK, - TopP: cfg.TopP, - RepeatPenalty: cfg.RepeatPenalty, + MaxTokens: config.MaxTokens, + Temperature: config.Temperature, + TopK: config.TopK, + TopP: config.TopP, + RepeatPenalty: config.RepeatPenalty, } start := time.Now() - chunks, errFn := m.srv.client.Complete(ctx, req) + chunks, errFn := m.server.client.Complete(ctx, req) return func(yield func(inference.Token) bool) { var count int @@ -75,12 +75,12 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts m.lastErr = nil m.mu.Unlock() - if !m.srv.alive() { + if !m.server.alive() { m.setServerExitErr() return func(yield func(inference.Token) bool) {} } - cfg := inference.ApplyGenerateOpts(opts) + config := inference.ApplyGenerateOpts(opts) chatMsgs := make([]llamacpp.ChatMessage, len(messages)) for i, msg := range messages { @@ -92,15 +92,15 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts req := llamacpp.ChatRequest{ Messages: chatMsgs, - MaxTokens: cfg.MaxTokens, - Temperature: cfg.Temperature, - TopK: cfg.TopK, - TopP: cfg.TopP, - RepeatPenalty: cfg.RepeatPenalty, + MaxTokens: config.MaxTokens, + Temperature: config.Temperature, + TopK: config.TopK, + TopP: config.TopP, + RepeatPenalty: config.RepeatPenalty, } start := time.Now() - chunks, errFn := m.srv.client.ChatComplete(ctx, req) + chunks, errFn := m.server.client.ChatComplete(ctx, req) return func(yield func(inference.Token) bool) { var count int @@ -124,7 +124,7 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts // Each prompt gets a single-token completion (max_tokens=1, temperature=0). // llama-server has no native classify endpoint, so this simulates it. func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) { - if !m.srv.alive() { + if !m.server.alive() { m.setServerExitErr() return nil, m.Err() } @@ -143,7 +143,7 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe Temperature: 0, } - chunks, errFn := m.srv.client.Complete(ctx, req) + chunks, errFn := m.server.client.Complete(ctx, req) var text strings.Builder for chunk := range chunks { text.WriteString(chunk) @@ -164,12 +164,12 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe // BatchGenerate runs batched autoregressive generation via llama-server. // Each prompt is decoded sequentially up to MaxTokens. func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) { - if !m.srv.alive() { + if !m.server.alive() { m.setServerExitErr() return nil, m.Err() } - cfg := inference.ApplyGenerateOpts(opts) + config := inference.ApplyGenerateOpts(opts) start := time.Now() results := make([]inference.BatchResult, len(prompts)) var totalGenerated int @@ -182,14 +182,14 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts .. req := llamacpp.CompletionRequest{ Prompt: prompt, - MaxTokens: cfg.MaxTokens, - Temperature: cfg.Temperature, - TopK: cfg.TopK, - TopP: cfg.TopP, - RepeatPenalty: cfg.RepeatPenalty, + MaxTokens: config.MaxTokens, + Temperature: config.Temperature, + TopK: config.TopK, + TopP: config.TopP, + RepeatPenalty: config.RepeatPenalty, } - chunks, errFn := m.srv.client.Complete(ctx, req) + chunks, errFn := m.server.client.Complete(ctx, req) var tokens []inference.Token for text := range chunks { tokens = append(tokens, inference.Token{Text: text}) @@ -227,15 +227,15 @@ func (m *rocmModel) Err() error { // Close releases the llama-server subprocess and all associated resources. func (m *rocmModel) Close() error { - return m.srv.stop() + return m.server.stop() } // setServerExitErr stores an appropriate error when the server is dead. func (m *rocmModel) setServerExitErr() { m.mu.Lock() defer m.mu.Unlock() - if m.srv.exitErr != nil { - m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.srv.exitErr) + if m.server.exitErr != nil { + m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.server.exitErr) } else { m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil) } @@ -248,7 +248,7 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco decode := now.Sub(decodeStart) prefill := total - decode - met := inference.GenerateMetrics{ + metrics := inference.GenerateMetrics{ PromptTokens: promptTokens, GeneratedTokens: generatedTokens, PrefillDuration: prefill, @@ -256,19 +256,19 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco TotalDuration: total, } if prefill > 0 && promptTokens > 0 { - met.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds() + metrics.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds() } if decode > 0 && generatedTokens > 0 { - met.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds() + metrics.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds() } // Try to get VRAM stats — best effort. if vram, err := GetVRAMInfo(); err == nil { - met.PeakMemoryBytes = vram.Used - met.ActiveMemoryBytes = vram.Used + metrics.PeakMemoryBytes = vram.Used + metrics.ActiveMemoryBytes = vram.Used } m.mu.Lock() - m.metrics = met + m.metrics = metrics m.mu.Unlock() } diff --git a/register_rocm.go b/register_rocm.go index f7b6e3b..24ed8c4 100644 --- a/register_rocm.go +++ b/register_rocm.go @@ -8,5 +8,10 @@ func init() { inference.Register(&rocmBackend{}) } -// ROCmAvailable reports whether ROCm GPU inference is available. +// ROCmAvailable reports whether ROCm GPU inference is available on this machine. +// Returns true only on linux/amd64 with /dev/kfd present and llama-server findable. +// +// if rocm.ROCmAvailable() { +// m, err := inference.LoadModel("/data/model.gguf") +// } func ROCmAvailable() bool { return true } diff --git a/rocm.go b/rocm.go index bea7178..4518ecd 100644 --- a/rocm.go +++ b/rocm.go @@ -25,6 +25,9 @@ package rocm // VRAMInfo reports GPU video memory usage in bytes. +// +// info, err := rocm.GetVRAMInfo() +// fmt.Printf("VRAM: %d MiB used / %d MiB total", info.Used/(1024*1024), info.Total/(1024*1024)) type VRAMInfo struct { Total uint64 Used uint64 @@ -32,6 +35,11 @@ type VRAMInfo struct { } // ModelInfo describes a GGUF model file discovered on disk. +// +// models, _ := rocm.DiscoverModels("/data/lem/gguf") +// for _, m := range models { +// fmt.Printf("%s (%s %s, ctx=%d)\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen) +// } type ModelInfo struct { Path string // full path to .gguf file Architecture string // GGUF architecture (e.g. "gemma3", "llama", "qwen2") diff --git a/rocm_stub.go b/rocm_stub.go index 0947fe7..87ec109 100644 --- a/rocm_stub.go +++ b/rocm_stub.go @@ -4,8 +4,12 @@ package rocm import coreerr "forge.lthn.ai/core/go-log" -// ROCmAvailable reports whether ROCm GPU inference is available. +// ROCmAvailable reports whether ROCm GPU inference is available on this machine. // Returns false on non-Linux or non-amd64 platforms. +// +// if rocm.ROCmAvailable() { +// m, err := inference.LoadModel("/data/model.gguf") +// } func ROCmAvailable() bool { return false } // GetVRAMInfo is not available on non-Linux/non-amd64 platforms. diff --git a/server_test.go b/server_test.go index 0ecbc57..9d4e527 100644 --- a/server_test.go +++ b/server_test.go @@ -14,34 +14,34 @@ import ( "github.com/stretchr/testify/require" ) -func TestFindLlamaServer_InPATH(t *testing.T) { +func TestFindLlamaServer_Good_InPATH(t *testing.T) { // llama-server is at /usr/local/bin/llama-server on this machine. path, err := findLlamaServer() require.NoError(t, err) assert.Contains(t, path, "llama-server") } -func TestFindLlamaServer_EnvOverride(t *testing.T) { +func TestFindLlamaServer_Good_EnvOverride(t *testing.T) { t.Setenv("ROCM_LLAMA_SERVER_PATH", "/usr/local/bin/llama-server") path, err := findLlamaServer() require.NoError(t, err) assert.Equal(t, "/usr/local/bin/llama-server", path) } -func TestFindLlamaServer_EnvNotFound(t *testing.T) { +func TestFindLlamaServer_Bad_EnvPathMissing(t *testing.T) { t.Setenv("ROCM_LLAMA_SERVER_PATH", "/nonexistent/llama-server") _, err := findLlamaServer() assert.ErrorContains(t, err, "not found") } -func TestFreePort(t *testing.T) { +func TestFreePort_Good(t *testing.T) { port, err := freePort() require.NoError(t, err) assert.Greater(t, port, 0) assert.Less(t, port, 65536) } -func TestFreePort_UniquePerCall(t *testing.T) { +func TestFreePort_Good_UniquePerCall(t *testing.T) { p1, err := freePort() require.NoError(t, err) p2, err := freePort() @@ -50,7 +50,7 @@ func TestFreePort_UniquePerCall(t *testing.T) { _ = p2 } -func TestServerEnv_HIPVisibleDevices(t *testing.T) { +func TestServerEnv_Good_SetsHIPVisibleDevices(t *testing.T) { env := serverEnv() var hipVals []string for _, e := range env { @@ -61,7 +61,7 @@ func TestServerEnv_HIPVisibleDevices(t *testing.T) { assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals) } -func TestServerEnv_FiltersExistingHIP(t *testing.T) { +func TestServerEnv_Good_FiltersExistingHIPVisibleDevices(t *testing.T) { t.Setenv("HIP_VISIBLE_DEVICES", "1") env := serverEnv() var hipVals []string @@ -73,7 +73,7 @@ func TestServerEnv_FiltersExistingHIP(t *testing.T) { assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals) } -func TestAvailable(t *testing.T) { +func TestAvailable_Good(t *testing.T) { b := &rocmBackend{} if _, err := os.Stat("/dev/kfd"); err != nil { t.Skip("no ROCm hardware") @@ -81,27 +81,26 @@ func TestAvailable(t *testing.T) { assert.True(t, b.Available()) } - -func TestServerAlive_Running(t *testing.T) { +func TestServerAlive_Good_Running(t *testing.T) { s := &server{exited: make(chan struct{})} assert.True(t, s.alive()) } -func TestServerAlive_Exited(t *testing.T) { +func TestServerAlive_Good_Exited(t *testing.T) { exited := make(chan struct{}) close(exited) s := &server{exited: exited, exitErr: coreerr.E("test", "process killed", nil)} assert.False(t, s.alive()) } -func TestGenerate_ServerDead(t *testing.T) { +func TestGenerate_Bad_ServerDead(t *testing.T) { exited := make(chan struct{}) close(exited) s := &server{ exited: exited, exitErr: coreerr.E("test", "process killed", nil), } - m := &rocmModel{srv: s} + m := &rocmModel{server: s} var count int for range m.Generate(context.Background(), "hello") { @@ -111,7 +110,7 @@ func TestGenerate_ServerDead(t *testing.T) { assert.ErrorContains(t, m.Err(), "server has exited") } -func TestStartServer_RetriesOnProcessExit(t *testing.T) { +func TestStartServer_Ugly_RetriesOnProcessExit(t *testing.T) { // /bin/false starts successfully but exits immediately with code 1. // startServer should retry up to 3 times, then fail. _, err := startServer("/bin/false", "/nonexistent/model.gguf", 999, 0, 0) @@ -119,14 +118,14 @@ func TestStartServer_RetriesOnProcessExit(t *testing.T) { assert.Contains(t, err.Error(), "failed after 3 attempts") } -func TestChat_ServerDead(t *testing.T) { +func TestChat_Bad_ServerDead(t *testing.T) { exited := make(chan struct{}) close(exited) s := &server{ exited: exited, exitErr: coreerr.E("test", "process killed", nil), } - m := &rocmModel{srv: s} + m := &rocmModel{server: s} msgs := []inference.Message{{Role: "user", Content: "hello"}} var count int diff --git a/vram.go b/vram.go index 9f6d1da..99a5c5d 100644 --- a/vram.go +++ b/vram.go @@ -17,6 +17,10 @@ import ( // // Note: total and used are read non-atomically from sysfs; transient // inconsistencies are possible under heavy allocation churn. +// +// info, err := rocm.GetVRAMInfo() +// fmt.Printf("VRAM: %d MiB used / %d MiB total (free: %d MiB)", +// info.Used/(1024*1024), info.Total/(1024*1024), info.Free/(1024*1024)) func GetVRAMInfo() (VRAMInfo, error) { cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total") if err != nil { diff --git a/vram_test.go b/vram_test.go index 36c8eca..325dfbc 100644 --- a/vram_test.go +++ b/vram_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestReadSysfsUint64(t *testing.T) { +func TestReadSysfsUint64_Good(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test_value") require.NoError(t, os.WriteFile(path, []byte("17163091968\n"), 0644)) @@ -21,12 +21,12 @@ func TestReadSysfsUint64(t *testing.T) { assert.Equal(t, uint64(17163091968), val) } -func TestReadSysfsUint64_NotFound(t *testing.T) { +func TestReadSysfsUint64_Bad_NotFound(t *testing.T) { _, err := readSysfsUint64("/nonexistent/path") assert.Error(t, err) } -func TestReadSysfsUint64_InvalidContent(t *testing.T) { +func TestReadSysfsUint64_Bad_InvalidContent(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "bad_value") require.NoError(t, os.WriteFile(path, []byte("not-a-number\n"), 0644)) @@ -35,7 +35,7 @@ func TestReadSysfsUint64_InvalidContent(t *testing.T) { assert.Error(t, err) } -func TestReadSysfsUint64_EmptyFile(t *testing.T) { +func TestReadSysfsUint64_Bad_EmptyFile(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "empty_value") require.NoError(t, os.WriteFile(path, []byte(""), 0644)) @@ -44,7 +44,7 @@ func TestReadSysfsUint64_EmptyFile(t *testing.T) { assert.Error(t, err) } -func TestGetVRAMInfo(t *testing.T) { +func TestGetVRAMInfo_Good(t *testing.T) { info, err := GetVRAMInfo() if err != nil { t.Skipf("no VRAM sysfs info available: %v", err)