go-mlx/internal/metal/generate.go
Snider 4d1bff3d78 refactor(api): clean root package — interfaces only, metal auto-registered
Root package now contains only:
- mlx.go: package doc + go:generate directives
- textmodel.go: TextModel, Token, Message interfaces
- options.go: GenerateOption, LoadOption functional options
- backend.go: Backend interface, Register/Get/Default/LoadModel
- register_metal.go: build-tagged init() + adapter + memory delegates
- mlx_stub.go: non-darwin fallback

internal/metal/ has its own Token, GenerateConfig, Model types.
register_metal.go adapts between the two via metalAdapter.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 20:04:19 +00:00


//go:build darwin && arm64

package metal

import (
	"context"
	"iter"
)

// Token represents a single generated token.
type Token struct {
	ID   int32
	Text string
}

// ChatMessage represents a chat turn.
type ChatMessage struct {
	Role    string
	Content string
}

// GenerateConfig holds generation parameters.
type GenerateConfig struct {
	MaxTokens   int
	Temperature float32
	TopK        int
	TopP        float32
	StopTokens  []int32
}

// Model wraps a loaded transformer model for text generation.
type Model struct {
	model     InternalModel
	tokenizer *Tokenizer
	modelType string
	lastErr   error
}

// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
func (m *Model) ModelType() string { return m.modelType }

// Err returns the error from the last Generate/Chat call, if any.
func (m *Model) Err() error { return m.lastErr }

// Close releases all resources.
func (m *Model) Close() error {
	// TODO(task10): explicit Free() on model weights and caches
	return nil
}

// Chat formats messages using the model's native template, then generates.
func (m *Model) Chat(ctx context.Context, messages []ChatMessage, cfg GenerateConfig) iter.Seq[Token] {
	prompt := m.formatChat(messages)
	return m.Generate(ctx, prompt, cfg)
}

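// A minimal usage sketch, assuming a loaded *Model m, a ctx, and a cfg:
//
//	msgs := []ChatMessage{
//		{Role: "system", Content: "You are a helpful assistant."},
//		{Role: "user", Content: "Hi"},
//	}
//	for tok := range m.Chat(ctx, msgs, cfg) {
//		fmt.Print(tok.Text)
//	}
//	if err := m.Err(); err != nil {
//		// generation stopped early, e.g. context cancellation
//	}
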
// Generate streams tokens for the given prompt.
func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) iter.Seq[Token] {
	m.lastErr = nil
	return func(yield func(Token) bool) {
		tokens := m.tokenizer.Encode(prompt)
		caches := m.model.NewCache()
		sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK)

		// Prefill: process the entire prompt in one forward pass.
		input := FromValues(tokens, len(tokens))
		input = Reshape(input, 1, int32(len(tokens)))
		logits := m.model.Forward(input, caches)
		Materialize(logits)

		for i := 0; i < cfg.MaxTokens; i++ {
			select {
			case <-ctx.Done():
				m.lastErr = ctx.Err()
				return
			default:
			}

			// Sample from the logits at the last sequence position.
			lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
			lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2)))
			next := sampler.Sample(lastPos)
			Materialize(next)
			id := int32(next.Int())

			// Check stop conditions.
			if id == m.tokenizer.EOSToken() {
				return
			}
			for _, stop := range cfg.StopTokens {
				if id == stop {
					return
				}
			}

			text := m.tokenizer.DecodeToken(id)
			if !yield(Token{ID: id, Text: text}) {
				return
			}

			// Decode step: feed the sampled token back in as the next input.
			nextInput := FromValues([]int32{id}, 1)
			nextInput = Reshape(nextInput, 1, 1)
			logits = m.model.Forward(nextInput, caches)
			Materialize(logits)
		}
	}
}

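// Because Generate returns an iter.Seq, an ordinary range loop drives it,
// and breaking out of the loop stops generation; a sketch, assuming a
// loaded *Model m, a ctx, and a cfg:
//
//	var sb strings.Builder
//	for tok := range m.Generate(ctx, "Once upon a time", cfg) {
//		sb.WriteString(tok.Text)
//		if sb.Len() > 1024 {
//			break // yield reports false and the generator returns
//		}
//	}
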
// formatChat applies the model's native chat template.
func (m *Model) formatChat(messages []ChatMessage) string {
	switch m.modelType {
	case "gemma3":
		return formatGemmaChat(messages)
	case "qwen3":
		return formatQwenChat(messages)
	default:
		// Unknown architectures fall back to plain concatenation.
		var s string
		for _, msg := range messages {
			s += msg.Content + "\n"
		}
		return s
	}
}

func formatGemmaChat(messages []ChatMessage) string {
	var s string
	for _, msg := range messages {
		switch msg.Role {
		case "system", "user":
			// Gemma's template has no system role; system prompts are
			// rendered as user turns.
			s += "<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n"
		case "assistant":
			s += "<start_of_turn>model\n" + msg.Content + "<end_of_turn>\n"
		}
	}
	s += "<start_of_turn>model\n"
	return s
}

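// For example, a single user message "Hi" renders as:
//
//	<start_of_turn>user
//	Hi<end_of_turn>
//	<start_of_turn>model
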
func formatQwenChat(messages []ChatMessage) string {
	var s string
	for _, msg := range messages {
		s += "<|im_start|>" + msg.Role + "\n" + msg.Content + "<|im_end|>\n"
	}
	s += "<|im_start|>assistant\n"
	return s
}
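
// For example, a system + user exchange renders as ChatML:
//
//	<|im_start|>system
//	You are a helpful assistant.<|im_end|>
//	<|im_start|>user
//	Hi<|im_end|>
//	<|im_start|>assistant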