diff --git a/model.go b/model.go
new file mode 100644
index 0000000..bc07d07
--- /dev/null
+++ b/model.go
@@ -0,0 +1,102 @@
+//go:build linux && amd64
+
+package rocm
+
+import (
+	"context"
+	"iter"
+	"sync"
+
+	"forge.lthn.ai/core/go-inference"
+	"forge.lthn.ai/core/go-rocm/internal/llamacpp"
+)
+
+// rocmModel implements inference.TextModel using a llama-server subprocess.
+type rocmModel struct {
+	srv       *server
+	modelType string
+
+	mu      sync.Mutex
+	lastErr error
+}
+
+// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
+func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
+	cfg := inference.ApplyGenerateOpts(opts)
+
+	req := llamacpp.CompletionRequest{
+		Prompt:        prompt,
+		MaxTokens:     cfg.MaxTokens,
+		Temperature:   cfg.Temperature,
+		TopK:          cfg.TopK,
+		TopP:          cfg.TopP,
+		RepeatPenalty: cfg.RepeatPenalty,
+	}
+
+	chunks, errFn := m.srv.client.Complete(ctx, req)
+
+	return func(yield func(inference.Token) bool) {
+		for text := range chunks {
+			if !yield(inference.Token{Text: text}) {
+				break
+			}
+		}
+		if err := errFn(); err != nil {
+			m.mu.Lock()
+			m.lastErr = err
+			m.mu.Unlock()
+		}
+	}
+}
+
+// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
+func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
+	cfg := inference.ApplyGenerateOpts(opts)
+
+	chatMsgs := make([]llamacpp.ChatMessage, len(messages))
+	for i, msg := range messages {
+		chatMsgs[i] = llamacpp.ChatMessage{
+			Role:    msg.Role,
+			Content: msg.Content,
+		}
+	}
+
+	req := llamacpp.ChatRequest{
+		Messages:      chatMsgs,
+		MaxTokens:     cfg.MaxTokens,
+		Temperature:   cfg.Temperature,
+		TopK:          cfg.TopK,
+		TopP:          cfg.TopP,
+		RepeatPenalty: cfg.RepeatPenalty,
+	}
+
+	chunks, errFn := m.srv.client.ChatComplete(ctx, req)
+
+	return func(yield func(inference.Token) bool) {
+		for text := range chunks {
+			if !yield(inference.Token{Text: text}) {
+				break
+			}
+		}
+		if err := errFn(); err != nil {
+			m.mu.Lock()
+			m.lastErr = err
+			m.mu.Unlock()
+		}
+	}
+}
+
+// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
+func (m *rocmModel) ModelType() string { return m.modelType }
+
+// Err returns the error from the last Generate/Chat call, if any.
+func (m *rocmModel) Err() error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return m.lastErr
+}
+
+// Close releases the llama-server subprocess and all associated resources.
+func (m *rocmModel) Close() error {
+	return m.srv.stop()
+}
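
For reference, a minimal usage sketch of the streaming API this diff adds. The rocm.New constructor, the model path, and the inference.WithMaxTokens option are assumptions for illustration only; the diff shows the model type itself, not how instances are created or which option constructors go-inference exports.

	package main

	import (
		"context"
		"fmt"
		"log"

		"forge.lthn.ai/core/go-inference"
		rocm "forge.lthn.ai/core/go-rocm"
	)

	func main() {
		ctx := context.Background()

		// Hypothetical constructor; not part of this diff.
		model, err := rocm.New(ctx, "path/to/model.gguf")
		if err != nil {
			log.Fatal(err)
		}
		defer model.Close()

		// Generate returns an iter.Seq[inference.Token]; ranging over it
		// streams tokens as llama-server produces them. WithMaxTokens is
		// an assumed option constructor.
		for tok := range model.Generate(ctx, "Hello", inference.WithMaxTokens(64)) {
			fmt.Print(tok.Text)
		}

		// The iterator cannot yield an error, so check Err once the
		// stream ends; this matches the lastErr bookkeeping above.
		if err := model.Err(); err != nil {
			log.Fatal(err)
		}
	}

Because iter.Seq carries no error channel, Err is the only way to observe a mid-stream failure, so callers should check it after every drained stream.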