go-rocm/model.go
Claude 41b34b6779
feat(ax): apply RFC-025 AX compliance review
Principle 1 — Predictable Names:
- rocmModel.srv → rocmModel.server (struct field)
- recordMetrics: met → metrics (local var)
- backend.go/model.go: cfg → config (local vars)
- gguf.go: tc/kc → tensorCount32/kvCount32 (v2 count reads)

Principle 2 — Comments as Usage Examples:
- Added concrete usage examples to all exported functions:
  VRAMInfo, ModelInfo, DiscoverModels, GetVRAMInfo,
  ROCmAvailable, LoadModel, Available, NewClient, Health,
  ChatComplete, Complete, ReadMetadata, FileTypeName

Principle 5 — Test naming (_Good/_Bad/_Ugly):
- All test functions renamed to AX-7 convention across:
  discover_test.go, vram_test.go, server_test.go,
  internal/gguf/gguf_test.go, internal/llamacpp/client_test.go,
  internal/llamacpp/health_test.go

Also: fix go.sum missing entry for dappco.re/go/core transitive dep
(pulled in by go-inference replace directive).

All tests pass: go test ./... -short -count=1

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 07:33:47 +01:00

//go:build linux && amd64

package rocm

import (
    "context"
    "fmt"
    "iter"
    "strings"
    "sync"
    "time"

    "forge.lthn.ai/core/go-inference"
    coreerr "forge.lthn.ai/core/go-log"
    "forge.lthn.ai/core/go-rocm/internal/llamacpp"
)

// rocmModel implements inference.TextModel using a llama-server subprocess.
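//
// A rocmModel is not constructed directly; callers obtain one from this
// package's LoadModel. The snippet below is an illustrative sketch only: the
// exact LoadModel signature and the model path are assumptions, not the real API.
//
//    model, err := rocm.LoadModel(ctx, "/models/example-Q4_K_M.gguf")
//    if err != nil {
//        return err
//    }
//    defer model.Close()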
type rocmModel struct {
    server    *server
    modelType string
    modelInfo inference.ModelInfo
    mu        sync.Mutex
    lastErr   error
    metrics   inference.GenerateMetrics
}

// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
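//
// Illustrative usage (sketch; "model" is a model returned by this package's
// LoadModel, and generate options are omitted because their constructor names
// are defined in go-inference, not shown here):
//
//    for tok := range model.Generate(ctx, "Explain ROCm in one sentence.") {
//        fmt.Print(tok.Text)
//    }
//    if err := model.Err(); err != nil {
//        log.Printf("generate failed: %v", err)
//    }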
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
    m.mu.Lock()
    m.lastErr = nil
    m.mu.Unlock()

    if !m.server.alive() {
        m.setServerExitErr()
        return func(yield func(inference.Token) bool) {}
    }

    config := inference.ApplyGenerateOpts(opts)
    req := llamacpp.CompletionRequest{
        Prompt: prompt,
        MaxTokens: config.MaxTokens,
        Temperature: config.Temperature,
        TopK: config.TopK,
        TopP: config.TopP,
        RepeatPenalty: config.RepeatPenalty,
    }

    start := time.Now()
    chunks, errFn := m.server.client.Complete(ctx, req)

    return func(yield func(inference.Token) bool) {
        var count int
        decodeStart := time.Now()
        for text := range chunks {
            count++
            if !yield(inference.Token{Text: text}) {
                break
            }
        }
        if err := errFn(); err != nil {
            m.mu.Lock()
            m.lastErr = err
            m.mu.Unlock()
        }
        m.recordMetrics(0, count, start, decodeStart)
    }
}

// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
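//
// Illustrative usage (sketch; message construction mirrors the
// inference.Message fields used below):
//
//    msgs := []inference.Message{
//        {Role: "system", Content: "You are a concise assistant."},
//        {Role: "user", Content: "What is HIP?"},
//    }
//    for tok := range model.Chat(ctx, msgs) {
//        fmt.Print(tok.Text)
//    }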
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
    m.mu.Lock()
    m.lastErr = nil
    m.mu.Unlock()

    if !m.server.alive() {
        m.setServerExitErr()
        return func(yield func(inference.Token) bool) {}
    }

    config := inference.ApplyGenerateOpts(opts)
    chatMsgs := make([]llamacpp.ChatMessage, len(messages))
    for i, msg := range messages {
        chatMsgs[i] = llamacpp.ChatMessage{
            Role: msg.Role,
            Content: msg.Content,
        }
    }

    req := llamacpp.ChatRequest{
        Messages: chatMsgs,
        MaxTokens: config.MaxTokens,
        Temperature: config.Temperature,
        TopK: config.TopK,
        TopP: config.TopP,
        RepeatPenalty: config.RepeatPenalty,
    }

    start := time.Now()
    chunks, errFn := m.server.client.ChatComplete(ctx, req)

    return func(yield func(inference.Token) bool) {
        var count int
        decodeStart := time.Now()
        for text := range chunks {
            count++
            if !yield(inference.Token{Text: text}) {
                break
            }
        }
        if err := errFn(); err != nil {
            m.mu.Lock()
            m.lastErr = err
            m.mu.Unlock()
        }
        m.recordMetrics(0, count, start, decodeStart)
    }
}

// Classify runs batched prefill-only inference via llama-server.
// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
// llama-server has no native classify endpoint, so this simulates it.
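//
// Illustrative usage (sketch):
//
//    results, err := model.Classify(ctx, []string{"spam or ham: ...", "spam or ham: ..."})
//    if err != nil {
//        return err
//    }
//    for i, r := range results {
//        fmt.Printf("prompt %d -> %q\n", i, r.Token.Text)
//    }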
func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
    if !m.server.alive() {
        m.setServerExitErr()
        return nil, m.Err()
    }

    start := time.Now()
    results := make([]inference.ClassifyResult, len(prompts))
    for i, prompt := range prompts {
        if ctx.Err() != nil {
            return nil, ctx.Err()
        }
        req := llamacpp.CompletionRequest{
            Prompt: prompt,
            MaxTokens: 1,
            Temperature: 0,
        }
        chunks, errFn := m.server.client.Complete(ctx, req)
        var text strings.Builder
        for chunk := range chunks {
            text.WriteString(chunk)
        }
        if err := errFn(); err != nil {
            return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify prompt %d", i), err)
        }
        results[i] = inference.ClassifyResult{
            Token: inference.Token{Text: text.String()},
        }
    }

    m.recordMetrics(len(prompts), len(prompts), start, start)
    return results, nil
}

// BatchGenerate runs batched autoregressive generation via llama-server.
// Each prompt is decoded sequentially up to MaxTokens.
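//
// Illustrative usage (sketch; per-prompt failures are reported in
// BatchResult.Err rather than aborting the whole batch):
//
//    batch, err := model.BatchGenerate(ctx, prompts)
//    if err != nil {
//        return err
//    }
//    for i, r := range batch {
//        if r.Err != nil {
//            log.Printf("prompt %d failed: %v", i, r.Err)
//            continue
//        }
//        fmt.Printf("prompt %d produced %d tokens\n", i, len(r.Tokens))
//    }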
func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
    if !m.server.alive() {
        m.setServerExitErr()
        return nil, m.Err()
    }

    config := inference.ApplyGenerateOpts(opts)
    start := time.Now()
    results := make([]inference.BatchResult, len(prompts))
    var totalGenerated int
    for i, prompt := range prompts {
        if ctx.Err() != nil {
            results[i].Err = ctx.Err()
            continue
        }
        req := llamacpp.CompletionRequest{
            Prompt: prompt,
            MaxTokens: config.MaxTokens,
            Temperature: config.Temperature,
            TopK: config.TopK,
            TopP: config.TopP,
            RepeatPenalty: config.RepeatPenalty,
        }
        chunks, errFn := m.server.client.Complete(ctx, req)
        var tokens []inference.Token
        for text := range chunks {
            tokens = append(tokens, inference.Token{Text: text})
        }
        if err := errFn(); err != nil {
            results[i].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", i), err)
        }
        results[i].Tokens = tokens
        totalGenerated += len(tokens)
    }

    m.recordMetrics(len(prompts), totalGenerated, start, start)
    return results, nil
}

// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
func (m *rocmModel) ModelType() string { return m.modelType }

// Info returns metadata about the loaded model.
func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo }

// Metrics returns performance metrics from the last inference operation.
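//
// Illustrative usage (sketch):
//
//    met := model.Metrics()
//    fmt.Printf("decode: %.1f tok/s over %s\n", met.DecodeTokensPerSec, met.TotalDuration)
//
// The returned snapshot reflects the most recently completed operation and is
// overwritten by the next Generate, Chat, Classify, or BatchGenerate call.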
func (m *rocmModel) Metrics() inference.GenerateMetrics {
    m.mu.Lock()
    defer m.mu.Unlock()
    return m.metrics
}

// Err returns the error from the last Generate/Chat call, if any.
func (m *rocmModel) Err() error {
    m.mu.Lock()
    defer m.mu.Unlock()
    return m.lastErr
}

// Close releases the llama-server subprocess and all associated resources.
func (m *rocmModel) Close() error {
    return m.server.stop()
}

// setServerExitErr stores an appropriate error when the server is dead.
func (m *rocmModel) setServerExitErr() {
    m.mu.Lock()
    defer m.mu.Unlock()
    if m.server.exitErr != nil {
        m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.server.exitErr)
    } else {
        m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil)
    }
}

// recordMetrics captures timing data from an inference operation.
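// Prefill time is approximated as total elapsed time minus decode time; callers
// that pass decodeStart equal to start (Classify, BatchGenerate) therefore
// report a zero prefill duration.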
func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, decodeStart time.Time) {
    now := time.Now()
    total := now.Sub(start)
    decode := now.Sub(decodeStart)
    prefill := total - decode

    metrics := inference.GenerateMetrics{
        PromptTokens: promptTokens,
        GeneratedTokens: generatedTokens,
        PrefillDuration: prefill,
        DecodeDuration: decode,
        TotalDuration: total,
    }
    if prefill > 0 && promptTokens > 0 {
        metrics.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
    }
    if decode > 0 && generatedTokens > 0 {
        metrics.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
    }

    // Try to get VRAM stats (best effort).
    if vram, err := GetVRAMInfo(); err == nil {
        metrics.PeakMemoryBytes = vram.Used
        metrics.ActiveMemoryBytes = vram.Used
    }

    m.mu.Lock()
    m.metrics = metrics
    m.mu.Unlock()
}