go-mlx/internal/metal/model.go
Snider 19c4823b04 feat(metal): add Llama 3 model support (Llama 3.1 8B validated)
Llama shares the Qwen3 loader (same decoder: pre-norm, SwiGLU, GQA).
Model type now detected from config.json model_type field instead of
weight-only heuristic. Llama 3 chat template and EOS token added.
Model tests now clear Metal GPU cache between runs.

Llama 3.1 8B Instruct 4-bit: 30 tok/s on M3 Ultra.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:06:43 +00:00

73 lines
2 KiB
Go

//go:build darwin && arm64

package metal

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
)
// InternalModel is the common interface for all transformer model
// architectures supported by this package (see loadModel for the
// concrete loaders that produce implementations).
type InternalModel interface {
	// Forward runs the model forward pass on token IDs with KV caches.
	Forward(tokens *Array, caches []Cache) *Array
	// NewCache creates per-layer KV caches for generation.
	NewCache() []Cache
	// NumLayers returns the number of transformer layers.
	NumLayers() int
	// Tokenizer returns the model's tokenizer.
	Tokenizer() *Tokenizer
	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
	ModelType() string
	// ApplyLoRA wraps target projection layers with LoRA adapters for training.
	// Returns the adapter which holds references to all LoRA layers.
	ApplyLoRA(cfg LoRAConfig) *LoRAAdapter
}
// QuantizationConfig holds quantization parameters from config.json.
type QuantizationConfig struct {
	// GroupSize is the "group_size" value from config.json — presumably the
	// number of weights sharing one quantization scale; confirm against the
	// kernels that consume it.
	GroupSize int `json:"group_size"`
	// Bits is the "bits" value from config.json (quantization bit width,
	// e.g. 4 for the 4-bit checkpoints mentioned in the commit message).
	Bits int `json:"bits"`
}
// resolveWeight looks up a weight with optional "language_model." prefix.
// Some checkpoints (e.g. multimodal exports) nest the text weights under a
// "language_model." namespace; try the bare name first, then the prefixed
// one. Returns nil when neither key is present.
func resolveWeight(weights map[string]*Array, name string) *Array {
	candidates := [...]string{name, "language_model." + name}
	for _, key := range candidates {
		if arr, ok := weights[key]; ok {
			return arr
		}
	}
	return nil
}
// loadModel auto-detects the model architecture from config.json and loads it.
// Detection reads only the "model_type" field; the selected loader re-reads
// the full config itself. Returns an error for unreadable/invalid config
// files or an unrecognized architecture string.
func loadModel(modelPath string) (InternalModel, error) {
	raw, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("model: load config: %w", err)
	}
	// Decode just the architecture tag; everything else is loader-specific.
	var cfg struct {
		ModelType string `json:"model_type"`
	}
	if err := json.Unmarshal(raw, &cfg); err != nil {
		return nil, fmt.Errorf("model: parse model_type: %w", err)
	}
	switch cfg.ModelType {
	case "qwen3", "qwen2", "llama":
		// Llama shares the Qwen3 decoder layout (pre-norm, SwiGLU, GQA).
		return LoadQwen3(modelPath)
	case "gemma3", "gemma3_text", "gemma2":
		return LoadGemma3(modelPath)
	}
	return nil, fmt.Errorf("model: unsupported architecture %q", cfg.ModelType)
}