125 lines
3 KiB
Go
125 lines
3 KiB
Go
//go:build linux && amd64
|
|
|
|
package rocm
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"iter"
	"sync"

	"forge.lthn.ai/core/go-inference"
	"forge.lthn.ai/core/go-rocm/internal/llamacpp"
)
|
|
|
|
// rocmModel implements inference.TextModel using a llama-server subprocess.
//
// Errors from streaming calls are not returned through the token sequence;
// they are stored in lastErr and surfaced via Err.
type rocmModel struct {
	srv *server // handle to the llama-server subprocess (liveness, HTTP client, shutdown)
	modelType string // architecture identifier reported by ModelType (e.g. "gemma3")

	mu sync.Mutex // guards lastErr
	lastErr error // most recent error from Generate/Chat, if any; read via Err
}
|
|
|
|
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
|
|
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
|
if !m.srv.alive() {
|
|
m.mu.Lock()
|
|
if m.srv.exitErr != nil {
|
|
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
|
|
} else {
|
|
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
|
|
}
|
|
m.mu.Unlock()
|
|
return func(yield func(inference.Token) bool) {}
|
|
}
|
|
|
|
cfg := inference.ApplyGenerateOpts(opts)
|
|
|
|
req := llamacpp.CompletionRequest{
|
|
Prompt: prompt,
|
|
MaxTokens: cfg.MaxTokens,
|
|
Temperature: cfg.Temperature,
|
|
TopK: cfg.TopK,
|
|
TopP: cfg.TopP,
|
|
RepeatPenalty: cfg.RepeatPenalty,
|
|
}
|
|
|
|
chunks, errFn := m.srv.client.Complete(ctx, req)
|
|
|
|
return func(yield func(inference.Token) bool) {
|
|
for text := range chunks {
|
|
if !yield(inference.Token{Text: text}) {
|
|
break
|
|
}
|
|
}
|
|
if err := errFn(); err != nil {
|
|
m.mu.Lock()
|
|
m.lastErr = err
|
|
m.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
|
|
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
|
if !m.srv.alive() {
|
|
m.mu.Lock()
|
|
if m.srv.exitErr != nil {
|
|
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
|
|
} else {
|
|
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
|
|
}
|
|
m.mu.Unlock()
|
|
return func(yield func(inference.Token) bool) {}
|
|
}
|
|
|
|
cfg := inference.ApplyGenerateOpts(opts)
|
|
|
|
chatMsgs := make([]llamacpp.ChatMessage, len(messages))
|
|
for i, msg := range messages {
|
|
chatMsgs[i] = llamacpp.ChatMessage{
|
|
Role: msg.Role,
|
|
Content: msg.Content,
|
|
}
|
|
}
|
|
|
|
req := llamacpp.ChatRequest{
|
|
Messages: chatMsgs,
|
|
MaxTokens: cfg.MaxTokens,
|
|
Temperature: cfg.Temperature,
|
|
TopK: cfg.TopK,
|
|
TopP: cfg.TopP,
|
|
RepeatPenalty: cfg.RepeatPenalty,
|
|
}
|
|
|
|
chunks, errFn := m.srv.client.ChatComplete(ctx, req)
|
|
|
|
return func(yield func(inference.Token) bool) {
|
|
for text := range chunks {
|
|
if !yield(inference.Token{Text: text}) {
|
|
break
|
|
}
|
|
}
|
|
if err := errFn(); err != nil {
|
|
m.mu.Lock()
|
|
m.lastErr = err
|
|
m.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
|
|
func (m *rocmModel) ModelType() string { return m.modelType }
|
|
|
|
// Err returns the error from the last Generate/Chat call, if any.
|
|
func (m *rocmModel) Err() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.lastErr
|
|
}
|
|
|
|
// Close releases the llama-server subprocess and all associated resources.
|
|
func (m *rocmModel) Close() error {
|
|
return m.srv.stop()
|
|
}
|