refactor(mlx): drop mlx build tag, auto-enable on darwin/arm64
Remove the manual -tags mlx requirement. MLX is now automatically compiled on darwin/arm64 via build constraints. Stubs remain for other platforms. No functional change. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
d0cbd5065e
commit
92c6282d50
19 changed files with 655 additions and 153 deletions
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package ml
|
||||
|
||||
|
|
@ -16,9 +16,9 @@ import (
|
|||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// MLXBackend implements Backend for native Metal inference via mlx-c.
|
||||
// MLXBackend implements Backend and StreamingBackend for native Metal inference.
|
||||
type MLXBackend struct {
|
||||
model *model.GemmaModel
|
||||
model model.Model
|
||||
tok *tokenizer.Tokenizer
|
||||
caches []cache.Cache
|
||||
sampler sample.Sampler
|
||||
|
|
@ -26,6 +26,9 @@ type MLXBackend struct {
|
|||
modelBytes uint64 // model size at load time, for memory budget
|
||||
}
|
||||
|
||||
// Compile-time check that MLXBackend satisfies StreamingBackend.
|
||||
var _ StreamingBackend = (*MLXBackend)(nil)
|
||||
|
||||
// NewMLXBackend loads a model from a safetensors directory and creates
|
||||
// a native Metal inference backend.
|
||||
func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
||||
|
|
@ -34,13 +37,12 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
|||
}
|
||||
|
||||
slog.Info("mlx: loading model", "path", modelPath)
|
||||
m, err := model.LoadGemma3(modelPath)
|
||||
m, err := model.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mlx: load model: %w", err)
|
||||
}
|
||||
|
||||
// Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
|
||||
// This prevents runaway memory growth from killing the system.
|
||||
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache
|
||||
mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
|
||||
|
||||
|
|
@ -54,31 +56,27 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
|||
model: m,
|
||||
tok: m.Tokenizer(),
|
||||
caches: m.NewCache(),
|
||||
sampler: sample.New(0.1, 0, 0, 0), // default low temp
|
||||
sampler: sample.New(0.1, 0, 0, 0),
|
||||
modelBytes: mlx.GetActiveMemory(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate produces text from a prompt using native Metal inference.
|
||||
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
|
||||
// generate is the core token generation loop. If cb is non-nil, each token's
|
||||
// text is sent to it (streaming mode). Returns the full output text.
|
||||
func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts, cb TokenCallback) (string, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// Reset caches for new generation
|
||||
for _, c := range b.caches {
|
||||
c.Reset()
|
||||
}
|
||||
|
||||
// Set up sampler based on opts
|
||||
temp := float32(opts.Temperature)
|
||||
if temp == 0 {
|
||||
temp = 0.1
|
||||
}
|
||||
sampler := sample.New(temp, 0, 0, 0)
|
||||
|
||||
// Tokenize
|
||||
formatted := tokenizer.FormatGemmaPrompt(prompt)
|
||||
tokens := b.tok.Encode(formatted)
|
||||
input := mlx.FromValues(tokens, 1, len(tokens))
|
||||
|
||||
maxTokens := opts.MaxTokens
|
||||
|
|
@ -86,8 +84,6 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
|
|||
maxTokens = 2048
|
||||
}
|
||||
|
||||
// Generation loop — force Go GC every 4 tokens so finalizers release
|
||||
// intermediate C array handles that Go GC cannot see as memory pressure.
|
||||
var output []int32
|
||||
for i := 0; i < maxTokens; i++ {
|
||||
select {
|
||||
|
|
@ -110,20 +106,58 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
|
|||
output = append(output, nextToken)
|
||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||
|
||||
// Force GC to collect intermediate arrays + release Metal allocator cache
|
||||
// Stream the token text to the callback
|
||||
if cb != nil {
|
||||
tokenText := b.tok.Decode([]int32{nextToken})
|
||||
if err := cb(tokenText); err != nil {
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
return b.tok.Decode(output), err
|
||||
}
|
||||
}
|
||||
|
||||
if i%4 == 3 {
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup between requests
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
b.checkMemory()
|
||||
return b.tok.Decode(output), nil
|
||||
}
|
||||
|
||||
// Generate produces text from a prompt using native Metal inference.
|
||||
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
|
||||
formatted := formatPrompt(b.model.ModelType(), prompt)
|
||||
tokens := b.tok.Encode(formatted)
|
||||
return b.generate(ctx, tokens, opts, nil)
|
||||
}
|
||||
|
||||
// Chat formats messages and generates a response.
|
||||
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
|
||||
prompt := formatChat(b.model.ModelType(), messages)
|
||||
tokens := b.tok.Encode(prompt)
|
||||
return b.generate(ctx, tokens, opts, nil)
|
||||
}
|
||||
|
||||
// GenerateStream streams tokens from a single prompt via the callback.
|
||||
func (b *MLXBackend) GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error {
|
||||
formatted := formatPrompt(b.model.ModelType(), prompt)
|
||||
tokens := b.tok.Encode(formatted)
|
||||
_, err := b.generate(ctx, tokens, opts, cb)
|
||||
return err
|
||||
}
|
||||
|
||||
// ChatStream streams tokens from a chat conversation via the callback.
|
||||
func (b *MLXBackend) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error {
|
||||
prompt := formatChat(b.model.ModelType(), messages)
|
||||
tokens := b.tok.Encode(prompt)
|
||||
_, err := b.generate(ctx, tokens, opts, cb)
|
||||
return err
|
||||
}
|
||||
|
||||
// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V].
|
||||
func lastPosition(logits *mlx.Array) *mlx.Array {
|
||||
shape := logits.Shape()
|
||||
|
|
@ -137,9 +171,49 @@ func lastPosition(logits *mlx.Array) *mlx.Array {
|
|||
return logits
|
||||
}
|
||||
|
||||
// Chat formats messages and generates a response.
|
||||
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
|
||||
// Format as Gemma chat
|
||||
// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget.
|
||||
func (b *MLXBackend) checkMemory() {
|
||||
active := mlx.GetActiveMemory()
|
||||
budget := b.modelBytes * 3
|
||||
if active > budget {
|
||||
slog.Warn("mlx: memory over budget, forcing cleanup",
|
||||
"active_mb", active/1024/1024,
|
||||
"model_mb", b.modelBytes/1024/1024,
|
||||
"peak_mb", mlx.GetPeakMemory()/1024/1024,
|
||||
)
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the backend identifier.
|
||||
func (b *MLXBackend) Name() string { return "mlx" }
|
||||
|
||||
// Available reports whether Metal GPU is ready.
|
||||
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }
|
||||
|
||||
// formatPrompt wraps a raw prompt in the model's chat template for single-turn generation.
|
||||
func formatPrompt(modelType, prompt string) string {
|
||||
switch modelType {
|
||||
case "qwen3":
|
||||
return fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n", prompt)
|
||||
default:
|
||||
return tokenizer.FormatGemmaPrompt(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// formatChat builds a multi-turn chat prompt from messages using the model's template.
|
||||
func formatChat(modelType string, messages []Message) string {
|
||||
switch modelType {
|
||||
case "qwen3":
|
||||
return formatQwen3Chat(messages)
|
||||
default:
|
||||
return formatGemmaChat(messages)
|
||||
}
|
||||
}
|
||||
|
||||
func formatGemmaChat(messages []Message) string {
|
||||
var prompt string
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
|
|
@ -152,83 +226,21 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
|
|||
}
|
||||
}
|
||||
prompt += "<start_of_turn>model\n"
|
||||
|
||||
// Use raw prompt (already formatted)
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
for _, c := range b.caches {
|
||||
c.Reset()
|
||||
return prompt
|
||||
}
|
||||
|
||||
temp := float32(opts.Temperature)
|
||||
if temp == 0 {
|
||||
temp = 0.1
|
||||
}
|
||||
sampler := sample.New(temp, 0, 0, 0)
|
||||
|
||||
tokens := b.tok.Encode(prompt)
|
||||
input := mlx.FromValues(tokens, 1, len(tokens))
|
||||
|
||||
maxTokens := opts.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 2048
|
||||
}
|
||||
|
||||
var output []int32
|
||||
for i := 0; i < maxTokens; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
return b.tok.Decode(output), ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logits := b.model.Forward(input, b.caches)
|
||||
logits = lastPosition(logits)
|
||||
next := sampler.Sample(logits)
|
||||
mlx.Materialize(next)
|
||||
|
||||
nextToken := int32(next.Int())
|
||||
if nextToken == b.tok.EOSToken() {
|
||||
break
|
||||
}
|
||||
output = append(output, nextToken)
|
||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||
|
||||
// Force GC to collect intermediate arrays + release Metal allocator cache
|
||||
if i%4 == 3 {
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
func formatQwen3Chat(messages []Message) string {
|
||||
var prompt string
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
prompt += fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", msg.Content)
|
||||
case "user":
|
||||
prompt += fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", msg.Content)
|
||||
case "assistant":
|
||||
prompt += fmt.Sprintf("<|im_start|>assistant\n%s<|im_end|>\n", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup between requests
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
b.checkMemory()
|
||||
return b.tok.Decode(output), nil
|
||||
prompt += "<|im_start|>assistant\n"
|
||||
return prompt
|
||||
}
|
||||
|
||||
// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget.
|
||||
func (b *MLXBackend) checkMemory() {
|
||||
active := mlx.GetActiveMemory()
|
||||
budget := b.modelBytes * 3 // 3× model size = danger zone
|
||||
if active > budget {
|
||||
slog.Warn("mlx: memory over budget, forcing cleanup",
|
||||
"active_mb", active/1024/1024,
|
||||
"model_mb", b.modelBytes/1024/1024,
|
||||
"peak_mb", mlx.GetPeakMemory()/1024/1024,
|
||||
)
|
||||
runtime.GC()
|
||||
runtime.GC() // double GC to run finalizers
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the backend identifier.
|
||||
func (b *MLXBackend) Name() string { return "mlx" }
|
||||
|
||||
// Available reports whether Metal GPU is ready.
|
||||
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
2
mlx/cache/cache.go
vendored
2
mlx/cache/cache.go
vendored
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
// Package cache provides KV cache implementations for transformer inference.
|
||||
package cache
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
||||
//
|
||||
|
|
@ -6,9 +6,9 @@
|
|||
//
|
||||
// cd pkg/mlx && go generate ./...
|
||||
//
|
||||
// Build with MLX enabled:
|
||||
// Build (MLX is auto-enabled on darwin/arm64):
|
||||
//
|
||||
// go build -tags mlx -o core .
|
||||
// go build -o core .
|
||||
package mlx
|
||||
|
||||
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build !(darwin && arm64 && mlx)
|
||||
//go:build !(darwin && arm64)
|
||||
|
||||
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
||||
// This stub file is used on non-darwin/non-arm64 platforms or when the
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
// Package model provides transformer model architectures for MLX inference.
|
||||
package model
|
||||
|
|
@ -16,12 +16,6 @@ import (
|
|||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// QuantizationConfig holds quantization parameters from config.json.
|
||||
type QuantizationConfig struct {
|
||||
GroupSize int `json:"group_size"`
|
||||
Bits int `json:"bits"`
|
||||
}
|
||||
|
||||
// TextConfig holds Gemma 3 text model configuration.
|
||||
type TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
|
|
@ -168,17 +162,6 @@ func parseConfig(data []byte) (*TextConfig, error) {
|
|||
return &cfg, nil
|
||||
}
|
||||
|
||||
// resolveWeight looks up a weight with optional "language_model." prefix.
|
||||
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
|
||||
if w, ok := weights[name]; ok {
|
||||
return w
|
||||
}
|
||||
if w, ok := weights["language_model."+name]; ok {
|
||||
return w
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadGemma3 loads a Gemma 3 text model from a directory.
|
||||
func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
|
|
@ -428,3 +411,6 @@ func (m *GemmaModel) NumLayers() int { return len(m.Layers) }
|
|||
|
||||
// Tokenizer returns the model's tokenizer.
|
||||
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
|
||||
|
||||
// ModelType returns the architecture identifier.
|
||||
func (m *GemmaModel) ModelType() string { return "gemma3" }
|
||||
|
|
|
|||
74
mlx/model/model.go
Normal file
74
mlx/model/model.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package model provides transformer model architectures for MLX inference.
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mlx"
|
||||
"forge.lthn.ai/core/go-ai/mlx/cache"
|
||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the common interface for all transformer model architectures.
|
||||
type Model interface {
|
||||
// Forward runs the model forward pass on token IDs with KV caches.
|
||||
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
|
||||
// NewCache creates per-layer KV caches for generation.
|
||||
NewCache() []cache.Cache
|
||||
|
||||
// NumLayers returns the number of transformer layers.
|
||||
NumLayers() int
|
||||
|
||||
// Tokenizer returns the model's tokenizer.
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
|
||||
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
|
||||
ModelType() string
|
||||
}
|
||||
|
||||
// QuantizationConfig holds quantization parameters from config.json.
|
||||
type QuantizationConfig struct {
|
||||
GroupSize int `json:"group_size"`
|
||||
Bits int `json:"bits"`
|
||||
}
|
||||
|
||||
// resolveWeight looks up a weight with optional "language_model." prefix.
|
||||
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
|
||||
if w, ok := weights[name]; ok {
|
||||
return w
|
||||
}
|
||||
if w, ok := weights["language_model."+name]; ok {
|
||||
return w
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadModel auto-detects the model architecture from config.json and loads it.
|
||||
func LoadModel(modelPath string) (Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model: load config: %w", err)
|
||||
}
|
||||
|
||||
var probe struct {
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, fmt.Errorf("model: parse model_type: %w", err)
|
||||
}
|
||||
|
||||
switch probe.ModelType {
|
||||
case "qwen3":
|
||||
return LoadQwen3(modelPath)
|
||||
case "gemma3", "gemma2":
|
||||
return LoadGemma3(modelPath)
|
||||
default:
|
||||
return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType)
|
||||
}
|
||||
}
|
||||
305
mlx/model/qwen3.go
Normal file
305
mlx/model/qwen3.go
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mlx"
|
||||
"forge.lthn.ai/core/go-ai/mlx/cache"
|
||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// Qwen3Config holds Qwen 3 model configuration.
|
||||
type Qwen3Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
|
||||
Quantization *QuantizationConfig `json:"-"`
|
||||
Scale float32 `json:"-"` // 1/sqrt(head_dim)
|
||||
}
|
||||
|
||||
// Qwen3Model is the Qwen 3 text model.
|
||||
type Qwen3Model struct {
|
||||
EmbedTokens *mlx.Embedding
|
||||
Layers []*Qwen3DecoderLayer
|
||||
Norm *mlx.RMSNormModule
|
||||
Output *mlx.Linear
|
||||
|
||||
Tok *tokenizer.Tokenizer
|
||||
Cfg *Qwen3Config
|
||||
}
|
||||
|
||||
// Qwen3DecoderLayer is a single transformer block.
|
||||
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
|
||||
type Qwen3DecoderLayer struct {
|
||||
InputNorm *mlx.RMSNormModule // Pre-attention norm
|
||||
PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
|
||||
Attention *Qwen3Attention
|
||||
MLP *Qwen3MLP
|
||||
}
|
||||
|
||||
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
|
||||
type Qwen3Attention struct {
|
||||
QProj *mlx.Linear
|
||||
KProj *mlx.Linear
|
||||
VProj *mlx.Linear
|
||||
OProj *mlx.Linear
|
||||
QNorm *mlx.RMSNormModule
|
||||
KNorm *mlx.RMSNormModule
|
||||
}
|
||||
|
||||
// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)).
|
||||
type Qwen3MLP struct {
|
||||
GateProj *mlx.Linear
|
||||
UpProj *mlx.Linear
|
||||
DownProj *mlx.Linear
|
||||
}
|
||||
|
||||
func parseQwen3Config(data []byte) (*Qwen3Config, error) {
|
||||
var cfg Qwen3Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Top-level quantization
|
||||
var wrapper struct {
|
||||
Quantization *QuantizationConfig `json:"quantization"`
|
||||
}
|
||||
json.Unmarshal(data, &wrapper)
|
||||
cfg.Quantization = wrapper.Quantization
|
||||
|
||||
// Compute scale
|
||||
if cfg.HeadDim == 0 {
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
// Defaults
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.VocabSize == 0 {
|
||||
cfg.VocabSize = 151936
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// LoadQwen3 loads a Qwen 3 model from a safetensors directory.
|
||||
func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qwen3: load config: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := parseQwen3Config(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qwen3: parse config: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qwen3: load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Load weights from all safetensors files
|
||||
weights := make(map[string]*mlx.Array)
|
||||
matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
|
||||
for _, path := range matches {
|
||||
for name, arr := range mlx.LoadSafetensors(path) {
|
||||
weights[name] = arr
|
||||
}
|
||||
}
|
||||
|
||||
w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
|
||||
|
||||
// Quantization setup
|
||||
q := cfg.Quantization
|
||||
if q != nil {
|
||||
slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
|
||||
}
|
||||
linear := func(prefix string) *mlx.Linear {
|
||||
weight := w(prefix + ".weight")
|
||||
scales := w(prefix + ".scales")
|
||||
biases := w(prefix + ".biases")
|
||||
bias := w(prefix + ".bias")
|
||||
if scales != nil && q != nil {
|
||||
return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
|
||||
}
|
||||
return mlx.NewLinear(weight, bias)
|
||||
}
|
||||
|
||||
// Embedding
|
||||
embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
|
||||
if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
|
||||
embed.Scales = embedScales
|
||||
embed.Biases = w("model.embed_tokens.biases")
|
||||
embed.GroupSize = q.GroupSize
|
||||
embed.Bits = q.Bits
|
||||
}
|
||||
|
||||
m := &Qwen3Model{
|
||||
EmbedTokens: embed,
|
||||
Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
|
||||
Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")},
|
||||
Tok: tok,
|
||||
Cfg: cfg,
|
||||
}
|
||||
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
p := fmt.Sprintf("model.layers.%d", i)
|
||||
m.Layers[i] = &Qwen3DecoderLayer{
|
||||
InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
|
||||
PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
|
||||
Attention: &Qwen3Attention{
|
||||
QProj: linear(p + ".self_attn.q_proj"),
|
||||
KProj: linear(p + ".self_attn.k_proj"),
|
||||
VProj: linear(p + ".self_attn.v_proj"),
|
||||
OProj: linear(p + ".self_attn.o_proj"),
|
||||
QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
|
||||
KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
|
||||
},
|
||||
MLP: &Qwen3MLP{
|
||||
GateProj: linear(p + ".mlp.gate_proj"),
|
||||
UpProj: linear(p + ".mlp.up_proj"),
|
||||
DownProj: linear(p + ".mlp.down_proj"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Output head — Qwen 3 has tie_word_embeddings=false, so lm_head is separate
|
||||
lmHeadWeight := w("lm_head.weight")
|
||||
if lmHeadWeight != nil {
|
||||
lmHeadScales := w("lm_head.scales")
|
||||
if lmHeadScales != nil && q != nil {
|
||||
m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
|
||||
} else {
|
||||
m.Output = mlx.NewLinear(lmHeadWeight, nil)
|
||||
}
|
||||
} else {
|
||||
m.Output = m.EmbedTokens.AsLinear()
|
||||
}
|
||||
|
||||
// Materialise all weights onto Metal
|
||||
var allArrays []*mlx.Array
|
||||
for _, a := range weights {
|
||||
allArrays = append(allArrays, a)
|
||||
}
|
||||
mlx.Materialize(allArrays...)
|
||||
|
||||
slog.Info("qwen3: model loaded",
|
||||
"layers", cfg.NumHiddenLayers,
|
||||
"hidden", cfg.HiddenSize,
|
||||
"heads", cfg.NumAttentionHeads,
|
||||
"kv_heads", cfg.NumKeyValueHeads,
|
||||
"head_dim", cfg.HeadDim,
|
||||
"vocab", cfg.VocabSize,
|
||||
)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the Qwen 3 forward pass.
|
||||
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size).
|
||||
func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
shape := tokens.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.forward(h, caches[i], B, L, m.Cfg)
|
||||
}
|
||||
|
||||
return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
|
||||
// Pre-attention norm → attention → residual add
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, cfg)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
// Pre-MLP norm → MLP → residual add
|
||||
normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed)
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
// Q/K RMS normalization (Qwen 3 has this)
|
||||
q = a.QNorm.Forward(q, cfg.RMSNormEps)
|
||||
k = a.KNorm.Forward(k, cfg.RMSNormEps)
|
||||
|
||||
// RoPE — single theta for all layers (no sliding window)
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
|
||||
// Update KV cache
|
||||
k, v = c.Update(k, v, int(L))
|
||||
|
||||
// GQA: repeat K/V heads to match Q heads
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = mlx.RepeatKV(k, repeatFactor)
|
||||
v = mlx.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// Scaled dot-product attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
|
||||
func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.
|
||||
func (m *Qwen3Model) NewCache() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// NumLayers returns the number of transformer layers.
|
||||
func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// Tokenizer returns the model's tokenizer.
|
||||
func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
|
||||
|
||||
// ModelType returns the architecture identifier.
|
||||
func (m *Qwen3Model) ModelType() string { return "qwen3" }
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
14
mlx/ops.go
14
mlx/ops.go
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
@ -68,6 +68,18 @@ func Exp(a *Array) *Array {
|
|||
return out
|
||||
}
|
||||
|
||||
// Sigmoid returns element-wise 1/(1+exp(-a)).
|
||||
func Sigmoid(a *Array) *Array {
|
||||
out := New("SIGMOID", a)
|
||||
C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// SiLU returns element-wise x * sigmoid(x) (Swish activation).
|
||||
func SiLU(a *Array) *Array {
|
||||
return Mul(a, Sigmoid(a))
|
||||
}
|
||||
|
||||
// Tanh returns element-wise tanh(a).
|
||||
func Tanh(a *Array) *Array {
|
||||
out := New("TANH", a)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
// Package sample provides composable token sampling strategies.
|
||||
package sample
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
//go:build darwin && arm64
|
||||
|
||||
// Package tokenizer provides BPE/SentencePiece tokenization for Gemma models.
|
||||
// Package tokenizer provides BPE tokenization for transformer models.
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
|
|
@ -19,6 +19,11 @@ type Tokenizer struct {
|
|||
|
||||
bosToken int32
|
||||
eosToken int32
|
||||
|
||||
// GPT-2 byte-level BPE support (used by Qwen, GPT, Llama, etc.)
|
||||
isGPT2BPE bool
|
||||
gpt2Decoder map[rune]byte // Unicode char → original byte
|
||||
gpt2Encoder map[byte]rune // original byte → Unicode char
|
||||
}
|
||||
|
||||
type mergePair struct {
|
||||
|
|
@ -71,7 +76,6 @@ func Load(path string) (*Tokenizer, error) {
|
|||
|
||||
// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats
|
||||
if len(tj.Model.Merges) > 0 {
|
||||
// Try array-of-strings first
|
||||
var stringMerges []string
|
||||
if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
|
||||
for rank, merge := range stringMerges {
|
||||
|
|
@ -81,7 +85,6 @@ func Load(path string) (*Tokenizer, error) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
// Try array-of-arrays: [["a","b"], ...]
|
||||
var arrayMerges [][]string
|
||||
if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
|
||||
for rank, pair := range arrayMerges {
|
||||
|
|
@ -102,37 +105,77 @@ func Load(path string) (*Tokenizer, error) {
|
|||
t.invVocab[tok.ID] = tok.Content
|
||||
}
|
||||
|
||||
// Set BOS/EOS
|
||||
// Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space)
|
||||
if _, ok := t.vocab["Ġ"]; ok {
|
||||
t.isGPT2BPE = true
|
||||
t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
|
||||
}
|
||||
|
||||
// Set BOS/EOS — detect model family from special tokens
|
||||
if id, ok := t.special["<bos>"]; ok {
|
||||
t.bosToken = id
|
||||
}
|
||||
if id, ok := t.special["<eos>"]; ok {
|
||||
t.eosToken = id
|
||||
}
|
||||
// Gemma: <end_of_turn> is the generation stop token
|
||||
if id, ok := t.special["<end_of_turn>"]; ok {
|
||||
t.eosToken = id // Gemma uses end_of_turn as EOS
|
||||
t.eosToken = id
|
||||
}
|
||||
// Qwen3: <|im_end|> is the generation stop token
|
||||
if id, ok := t.special["<|im_end|>"]; ok {
|
||||
t.eosToken = id
|
||||
}
|
||||
// Qwen3 BOS: <|im_start|>
|
||||
if id, ok := t.special["<|im_start|>"]; ok {
|
||||
t.bosToken = id
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// buildGPT2ByteMaps creates the GPT-2 byte-level BPE encoding/decoding maps.
|
||||
// GPT-2 maps all 256 bytes to printable Unicode characters to avoid control chars
|
||||
// in the vocabulary. Printable ASCII + Latin-1 Supplement map to themselves;
|
||||
// everything else (0-32, 127-160, 173) maps to U+0100 onwards.
|
||||
func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
|
||||
encoder = make(map[byte]rune, 256)
|
||||
decoder = make(map[rune]byte, 256)
|
||||
|
||||
// Self-mapping ranges: printable ASCII + Latin-1 Supplement
|
||||
// Use int loop variable to avoid byte overflow at 255.
|
||||
selfMap := func(lo, hi int) {
|
||||
for b := lo; b <= hi; b++ {
|
||||
encoder[byte(b)] = rune(b)
|
||||
decoder[rune(b)] = byte(b)
|
||||
}
|
||||
}
|
||||
selfMap(33, 126) // ! through ~
|
||||
selfMap(161, 172) // ¡ through ¬
|
||||
selfMap(174, 255) // ® through ÿ
|
||||
|
||||
// Non-self-mapping: control chars, space, DEL, and gaps
|
||||
n := 0
|
||||
for b := 0; b < 256; b++ {
|
||||
if _, ok := encoder[byte(b)]; !ok {
|
||||
r := rune(256 + n)
|
||||
encoder[byte(b)] = r
|
||||
decoder[r] = byte(b)
|
||||
n++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Encode converts text to token IDs. Prepends BOS token.
|
||||
func (t *Tokenizer) Encode(text string) []int32 {
|
||||
tokens := []int32{t.bosToken}
|
||||
|
||||
// Simple BPE encoding — split into characters then merge
|
||||
// This is a simplified version. Full implementation handles
|
||||
// Unicode, byte fallback, and efficient BPE merging.
|
||||
chars := []string{}
|
||||
for _, r := range text {
|
||||
s := string(r)
|
||||
if s == " " {
|
||||
s = "▁" // SentencePiece space marker
|
||||
}
|
||||
chars = append(chars, s)
|
||||
if t.isGPT2BPE {
|
||||
return t.encodeGPT2(text)
|
||||
}
|
||||
|
||||
// Check for special tokens first
|
||||
// SentencePiece style encoding
|
||||
remaining := text
|
||||
for remaining != "" {
|
||||
found := false
|
||||
|
|
@ -145,7 +188,6 @@ func (t *Tokenizer) Encode(text string) []int32 {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
// Encode character by character (simplified BPE)
|
||||
r := []rune(remaining)
|
||||
ch := "▁" + string(r[0])
|
||||
if id, ok := t.vocab[ch]; ok {
|
||||
|
|
@ -160,24 +202,95 @@ func (t *Tokenizer) Encode(text string) []int32 {
|
|||
return tokens
|
||||
}
|
||||
|
||||
// encodeGPT2 encodes text using GPT-2 byte-level BPE.
|
||||
func (t *Tokenizer) encodeGPT2(text string) []int32 {
|
||||
tokens := []int32{t.bosToken}
|
||||
|
||||
// Convert text bytes to GPT-2 Unicode representation
|
||||
var encoded strings.Builder
|
||||
for _, b := range []byte(text) {
|
||||
if r, ok := t.gpt2Encoder[b]; ok {
|
||||
encoded.WriteRune(r)
|
||||
}
|
||||
}
|
||||
gpt2Text := encoded.String()
|
||||
|
||||
// Scan for special tokens and regular text
|
||||
remaining := gpt2Text
|
||||
for remaining != "" {
|
||||
// Check special tokens (these are stored as-is, not byte-encoded)
|
||||
found := false
|
||||
for tok, id := range t.special {
|
||||
// Special tokens in GPT-2 tokenizers are stored in their original form
|
||||
// Convert the special token to GPT-2 encoding for matching
|
||||
var encTok strings.Builder
|
||||
for _, b := range []byte(tok) {
|
||||
if r, ok := t.gpt2Encoder[b]; ok {
|
||||
encTok.WriteRune(r)
|
||||
}
|
||||
}
|
||||
encStr := encTok.String()
|
||||
if strings.HasPrefix(remaining, encStr) {
|
||||
tokens = append(tokens, id)
|
||||
remaining = remaining[len(encStr):]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// Character-by-character lookup (simplified BPE)
|
||||
r := []rune(remaining)
|
||||
ch := string(r[0])
|
||||
if id, ok := t.vocab[ch]; ok {
|
||||
tokens = append(tokens, id)
|
||||
}
|
||||
remaining = string(r[1:])
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to text.
|
||||
func (t *Tokenizer) Decode(tokens []int32) string {
|
||||
var sb strings.Builder
|
||||
for _, id := range tokens {
|
||||
if text, ok := t.invVocab[id]; ok {
|
||||
// Replace SentencePiece space marker
|
||||
text = strings.ReplaceAll(text, "▁", " ")
|
||||
// Skip special tokens in decode output
|
||||
if _, isSpecial := t.special[text]; isSpecial {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(text)
|
||||
}
|
||||
}
|
||||
result := sb.String()
|
||||
// Trim leading space from SentencePiece encoding
|
||||
raw := sb.String()
|
||||
|
||||
if t.isGPT2BPE {
|
||||
return t.decodeGPT2Bytes(raw)
|
||||
}
|
||||
|
||||
// SentencePiece style
|
||||
result := strings.ReplaceAll(raw, "▁", " ")
|
||||
if strings.HasPrefix(result, " ") {
|
||||
result = result[1:]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
|
||||
func (t *Tokenizer) decodeGPT2Bytes(s string) string {
|
||||
var buf []byte
|
||||
for _, r := range s {
|
||||
if b, ok := t.gpt2Decoder[r]; ok {
|
||||
buf = append(buf, b)
|
||||
} else {
|
||||
// Non-mapped runes pass through as UTF-8
|
||||
buf = append(buf, []byte(string(r))...)
|
||||
}
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// BOSToken returns the beginning-of-sequence token ID.
|
||||
func (t *Tokenizer) BOSToken() int32 { return t.bosToken }
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue