diff --git a/backend.go b/backend.go
index 27b354a..12b3fad 100644
--- a/backend.go
+++ b/backend.go
@@ -5,6 +5,7 @@ package rocm
 import (
     "fmt"
     "os"
+    "strings"
 
     "forge.lthn.ai/core/go-inference"
     "forge.lthn.ai/core/go-rocm/internal/gguf"
@@ -54,8 +55,43 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
         return nil, err
     }
 
+    // Map quantisation file type to bit width.
+    quantBits := 0
+    quantGroup := 0
+    ftName := gguf.FileTypeName(meta.FileType)
+    switch {
+    case strings.HasPrefix(ftName, "Q4_"):
+        quantBits = 4
+        quantGroup = 32
+    case strings.HasPrefix(ftName, "Q5_"):
+        quantBits = 5
+        quantGroup = 32
+    case strings.HasPrefix(ftName, "Q8_"):
+        quantBits = 8
+        quantGroup = 32
+    case strings.HasPrefix(ftName, "Q2_"):
+        quantBits = 2
+        quantGroup = 16
+    case strings.HasPrefix(ftName, "Q3_"):
+        quantBits = 3
+        quantGroup = 32
+    case strings.HasPrefix(ftName, "Q6_"):
+        quantBits = 6
+        quantGroup = 64
+    case ftName == "F16":
+        quantBits = 16
+    case ftName == "F32":
+        quantBits = 32
+    }
+
     return &rocmModel{
         srv:       srv,
         modelType: meta.Architecture,
+        modelInfo: inference.ModelInfo{
+            Architecture: meta.Architecture,
+            NumLayers:    int(meta.BlockCount),
+            QuantBits:    quantBits,
+            QuantGroup:   quantGroup,
+        },
     }, nil
 }
diff --git a/go.sum b/go.sum
index 309ca2e..072f3cd 100644
--- a/go.sum
+++ b/go.sum
@@ -1,7 +1,11 @@
+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
 github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
 github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
diff --git a/model.go b/model.go
index 5fb1dfb..d01e08f 100644
--- a/model.go
+++ b/model.go
@@ -6,7 +6,9 @@ import (
     "context"
     "fmt"
     "iter"
+    "strings"
     "sync"
+    "time"
 
     "forge.lthn.ai/core/go-inference"
     "forge.lthn.ai/core/go-rocm/internal/llamacpp"
@@ -16,9 +18,11 @@ import (
 type rocmModel struct {
     srv       *server
     modelType string
+    modelInfo inference.ModelInfo
 
     mu      sync.Mutex
     lastErr error
+    metrics inference.GenerateMetrics
 }
 
 // Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
@@ -28,13 +32,7 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
     m.mu.Unlock()
 
     if !m.srv.alive() {
-        m.mu.Lock()
-        if m.srv.exitErr != nil {
-            m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
-        } else {
-            m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
-        }
-        m.mu.Unlock()
+        m.setServerExitErr()
         return func(yield func(inference.Token) bool) {}
     }
 
@@ -49,10 +47,14 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
         RepeatPenalty: cfg.RepeatPenalty,
     }
 
+    start := time.Now()
     chunks, errFn := m.srv.client.Complete(ctx, req)
 
     return func(yield func(inference.Token) bool) {
+        var count int
+        decodeStart := time.Now()
         for text := range chunks {
+            count++
             if !yield(inference.Token{Text: text}) {
                 break
             }
@@ -62,6 +64,7 @@
             m.lastErr = err
             m.mu.Unlock()
         }
+        m.recordMetrics(0, count, start, decodeStart)
     }
 }
 
@@ -72,13 +75,7 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
     m.mu.Unlock()
 
     if !m.srv.alive() {
-        m.mu.Lock()
-        if m.srv.exitErr != nil {
-            m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
-        } else {
-            m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
-        }
-        m.mu.Unlock()
+        m.setServerExitErr()
         return func(yield func(inference.Token) bool) {}
     }
 
@@ -101,10 +98,14 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
         RepeatPenalty: cfg.RepeatPenalty,
     }
 
+    start := time.Now()
     chunks, errFn := m.srv.client.ChatComplete(ctx, req)
 
     return func(yield func(inference.Token) bool) {
+        var count int
+        decodeStart := time.Now()
         for text := range chunks {
+            count++
             if !yield(inference.Token{Text: text}) {
                 break
             }
@@ -114,12 +115,108 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
             m.lastErr = err
             m.mu.Unlock()
         }
+        m.recordMetrics(0, count, start, decodeStart)
     }
 }
 
+// Classify runs batched prefill-only inference via llama-server.
+// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
+// llama-server has no native classify endpoint, so this simulates it.
+func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
+    if !m.srv.alive() {
+        m.setServerExitErr()
+        return nil, m.Err()
+    }
+
+    start := time.Now()
+    results := make([]inference.ClassifyResult, len(prompts))
+
+    for i, prompt := range prompts {
+        if ctx.Err() != nil {
+            return nil, ctx.Err()
+        }
+
+        req := llamacpp.CompletionRequest{
+            Prompt:      prompt,
+            MaxTokens:   1,
+            Temperature: 0,
+        }
+
+        chunks, errFn := m.srv.client.Complete(ctx, req)
+        var text strings.Builder
+        for chunk := range chunks {
+            text.WriteString(chunk)
+        }
+        if err := errFn(); err != nil {
+            return nil, fmt.Errorf("rocm: classify prompt %d: %w", i, err)
+        }
+
+        results[i] = inference.ClassifyResult{
+            Token: inference.Token{Text: text.String()},
+        }
+    }
+
+    m.recordMetrics(len(prompts), len(prompts), start, start)
+    return results, nil
+}
+
+// BatchGenerate runs batched autoregressive generation via llama-server.
+// Each prompt is decoded sequentially up to MaxTokens.
+func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
+    if !m.srv.alive() {
+        m.setServerExitErr()
+        return nil, m.Err()
+    }
+
+    cfg := inference.ApplyGenerateOpts(opts)
+    start := time.Now()
+    results := make([]inference.BatchResult, len(prompts))
+    var totalGenerated int
+
+    for i, prompt := range prompts {
+        if ctx.Err() != nil {
+            results[i].Err = ctx.Err()
+            continue
+        }
+
+        req := llamacpp.CompletionRequest{
+            Prompt:        prompt,
+            MaxTokens:     cfg.MaxTokens,
+            Temperature:   cfg.Temperature,
+            TopK:          cfg.TopK,
+            TopP:          cfg.TopP,
+            RepeatPenalty: cfg.RepeatPenalty,
+        }
+
+        chunks, errFn := m.srv.client.Complete(ctx, req)
+        var tokens []inference.Token
+        for text := range chunks {
+            tokens = append(tokens, inference.Token{Text: text})
+        }
+        if err := errFn(); err != nil {
+            results[i].Err = fmt.Errorf("rocm: batch prompt %d: %w", i, err)
+        }
+        results[i].Tokens = tokens
+        totalGenerated += len(tokens)
+    }
+
+    m.recordMetrics(len(prompts), totalGenerated, start, start)
+    return results, nil
+}
+
 // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
 func (m *rocmModel) ModelType() string { return m.modelType }
 
+// Info returns metadata about the loaded model.
+func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo }
+
+// Metrics returns performance metrics from the last inference operation.
+func (m *rocmModel) Metrics() inference.GenerateMetrics {
+    m.mu.Lock()
+    defer m.mu.Unlock()
+    return m.metrics
+}
+
 // Err returns the error from the last Generate/Chat call, if any.
 func (m *rocmModel) Err() error {
     m.mu.Lock()
@@ -131,3 +228,46 @@ func (m *rocmModel) Err() error {
 func (m *rocmModel) Close() error {
     return m.srv.stop()
 }
+
+// setServerExitErr stores an appropriate error when the server is dead.
+func (m *rocmModel) setServerExitErr() {
+    m.mu.Lock()
+    defer m.mu.Unlock()
+    if m.srv.exitErr != nil {
+        m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
+    } else {
+        m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
+    }
+}
+
+// recordMetrics captures timing data from an inference operation.
+func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, decodeStart time.Time) {
+    now := time.Now()
+    total := now.Sub(start)
+    decode := now.Sub(decodeStart)
+    prefill := total - decode
+
+    met := inference.GenerateMetrics{
+        PromptTokens:    promptTokens,
+        GeneratedTokens: generatedTokens,
+        PrefillDuration: prefill,
+        DecodeDuration:  decode,
+        TotalDuration:   total,
+    }
+    if prefill > 0 && promptTokens > 0 {
+        met.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
+    }
+    if decode > 0 && generatedTokens > 0 {
+        met.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
+    }
+
+    // Try to get VRAM stats; best effort.
+    if vram, err := GetVRAMInfo(); err == nil {
+        met.PeakMemoryBytes = vram.Used
+        met.ActiveMemoryBytes = vram.Used
+    }
+
+    m.mu.Lock()
+    m.metrics = met
+    m.mu.Unlock()
+}
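Reviewer note, with a usage sketch (illustrative only, not part of the patch): the example below exercises the new Info, Classify, and Metrics surface added above. The rocm.NewBackend constructor, the top-level import path, and the model path are assumptions made for the sketch, as is the premise that LoadModel's returned value exposes these new methods; the method names and struct fields themselves come from the diff.

package main

import (
    "context"
    "fmt"
    "log"

    rocm "forge.lthn.ai/core/go-rocm" // import path assumed from the module layout
)

func main() {
    b := rocm.NewBackend() // assumed constructor name; substitute the real one

    m, err := b.LoadModel("/models/example-q4_k_m.gguf") // path is illustrative
    if err != nil {
        log.Fatal(err)
    }
    defer m.Close()

    // Quantisation metadata is now populated by LoadModel.
    info := m.Info()
    fmt.Printf("%s: %d layers, %d-bit quant (group %d)\n",
        info.Architecture, info.NumLayers, info.QuantBits, info.QuantGroup)

    // Prefill-only classification: one forced greedy token per prompt.
    res, err := m.Classify(context.Background(), []string{
        "Sentiment of 'works great', positive or negative?\nAnswer:",
    })
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println("label token:", res[0].Token.Text)

    // Metrics from the last operation, including best-effort VRAM figures.
    met := m.Metrics()
    fmt.Printf("%d prompts in %s, peak VRAM %d bytes\n",
        met.PromptTokens, met.TotalDuration, met.PeakMemoryBytes)
}

One caveat on the timing model worth flagging: for streaming Generate/Chat, prefill works out to decodeStart minus start, the gap between dispatching the request and the consumer beginning iteration, so any server-side prefill that overlaps the wait for the first chunk is attributed to decode. Classify and BatchGenerate pass start for both arguments, so their entire runtime is reported as decode.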