refactor: apply go fix modernizers for Go 1.26

Automated fixes: interface{} → any, range-over-int, t.Context(),
wg.Go(), strings.SplitSeq, strings.Builder, slices.Contains,
maps helpers, min/max builtins.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-22 21:00:16 +00:00
parent fc27c2cd27
commit 5004ac258a
12 changed files with 746 additions and 75 deletions

View file

@ -6,7 +6,8 @@ import "fmt"
// LoadConfig holds configuration applied during model loading.
type LoadConfig struct {
ContextLen int // Context window size (0 = model default, unbounded KV cache)
ContextLen int // Context window size (0 = model default, unbounded KV cache)
AdapterPath string // Path to LoRA adapter directory (empty = no adapter)
}
// LoadAndInit initialises Metal and loads a model from the given path.
@ -22,8 +23,15 @@ func LoadAndInit(path string, cfg ...LoadConfig) (*Model, error) {
tokenizer: im.Tokenizer(),
modelType: im.ModelType(),
}
if len(cfg) > 0 && cfg[0].ContextLen > 0 {
m.contextLen = cfg[0].ContextLen
if len(cfg) > 0 {
if cfg[0].ContextLen > 0 {
m.contextLen = cfg[0].ContextLen
}
if cfg[0].AdapterPath != "" {
if err := applyLoadedLoRA(im, cfg[0].AdapterPath); err != nil {
return nil, fmt.Errorf("metal: load adapter: %w", err)
}
}
}
return m, nil
}

View file

@ -6,6 +6,7 @@ import (
"context"
"fmt"
"math"
"slices"
"sort"
"time"
)
@ -86,7 +87,7 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf
// Gather logits at each prompt's last real token position and sample.
sortedResults := make([]ClassifyResult, N)
for si := int32(0); si < N; si++ {
for si := range N {
lastPos := sortedLengths[si] - 1
// Extract [1, vocab] at position lastPos for this batch element.
@ -120,11 +121,11 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf
totalDur := time.Since(totalStart)
m.lastMetrics = Metrics{
PromptTokens: totalPromptTokens,
GeneratedTokens: int(N), // One token sampled per prompt
PrefillDuration: totalDur,
TotalDuration: totalDur,
PeakMemoryBytes: GetPeakMemory(),
PromptTokens: totalPromptTokens,
GeneratedTokens: int(N), // One token sampled per prompt
PrefillDuration: totalDur,
TotalDuration: totalDur,
PeakMemoryBytes: GetPeakMemory(),
ActiveMemoryBytes: GetActiveMemory(),
}
if totalDur > 0 {
@ -227,7 +228,7 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
nextIDs := make([]int32, N)
allFinished := true
for si := int32(0); si < N; si++ {
for si := range N {
if states[si].finished {
nextIDs[si] = 0 // pad
continue
@ -259,11 +260,8 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
states[si].finished = true
continue
}
for _, stop := range cfg.StopTokens {
if id == stop {
states[si].finished = true
break
}
if slices.Contains(cfg.StopTokens, id) {
states[si].finished = true
}
if !states[si].finished {
text := m.tokenizer.DecodeToken(id)
@ -299,12 +297,12 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
totalDur := time.Since(totalStart)
decodeDur := totalDur - prefillDur
m.lastMetrics = Metrics{
PromptTokens: totalPromptTokens,
GeneratedTokens: totalGenerated,
PrefillDuration: prefillDur,
DecodeDuration: decodeDur,
TotalDuration: totalDur,
PeakMemoryBytes: GetPeakMemory(),
PromptTokens: totalPromptTokens,
GeneratedTokens: totalGenerated,
PrefillDuration: prefillDur,
DecodeDuration: decodeDur,
TotalDuration: totalDur,
PeakMemoryBytes: GetPeakMemory(),
ActiveMemoryBytes: GetActiveMemory(),
}
if prefillDur > 0 {
@ -324,11 +322,11 @@ func buildBatchMask(N, L int32, promptLens []int32) *Array {
negInf := float32(math.Inf(-1))
data := make([]float32, int(N)*int(L)*int(L))
for b := int32(0); b < N; b++ {
for b := range N {
pLen := promptLens[b]
base := int(b) * int(L) * int(L)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
for i := range L {
for j := range L {
if j <= i && j < pLen {
data[base+int(i)*int(L)+int(j)] = 0
} else {

View file

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"log/slog"
"maps"
"math"
"os"
"path/filepath"
@ -182,9 +183,7 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
return nil, fmt.Errorf("gemma3: no .safetensors files found in %s", modelPath)
}
for _, path := range matches {
for name, arr := range LoadSafetensors(path) {
weights[name] = arr
}
maps.Insert(weights, LoadSafetensors(path))
if err := lastError(); err != nil {
return nil, fmt.Errorf("gemma3: load weights %s: %w", filepath.Base(path), err)
}

View file

@ -6,6 +6,8 @@ import (
"context"
"fmt"
"iter"
"slices"
"strings"
"time"
)
@ -46,11 +48,11 @@ type Metrics struct {
// Model wraps a loaded transformer model for text generation.
type Model struct {
model InternalModel
tokenizer *Tokenizer
modelType string
contextLen int // 0 = unbounded (model default)
lastErr error
model InternalModel
tokenizer *Tokenizer
modelType string
contextLen int // 0 = unbounded (model default)
lastErr error
lastMetrics Metrics
}
@ -145,12 +147,12 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
decodeDur := time.Since(totalStart) - prefillDur
totalDur := time.Since(totalStart)
m.lastMetrics = Metrics{
PromptTokens: promptLen,
GeneratedTokens: genCount,
PrefillDuration: prefillDur,
DecodeDuration: decodeDur,
TotalDuration: totalDur,
PeakMemoryBytes: GetPeakMemory(),
PromptTokens: promptLen,
GeneratedTokens: genCount,
PrefillDuration: prefillDur,
DecodeDuration: decodeDur,
TotalDuration: totalDur,
PeakMemoryBytes: GetPeakMemory(),
ActiveMemoryBytes: GetActiveMemory(),
}
if prefillDur > 0 {
@ -205,10 +207,8 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
if id == m.tokenizer.EOSToken() {
return
}
for _, stop := range cfg.StopTokens {
if id == stop {
return
}
if slices.Contains(cfg.StopTokens, id) {
return
}
genCount++
@ -286,44 +286,45 @@ func (m *Model) formatChat(messages []ChatMessage) string {
case "llama":
return formatLlamaChat(messages)
default:
var s string
var s strings.Builder
for _, msg := range messages {
s += msg.Content + "\n"
s.WriteString(msg.Content + "\n")
}
return s
return s.String()
}
}
func formatGemmaChat(messages []ChatMessage) string {
var s string
var s strings.Builder
for _, msg := range messages {
switch msg.Role {
case "system":
s += "<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n"
s.WriteString("<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n")
case "user":
s += "<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n"
s.WriteString("<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n")
case "assistant":
s += "<start_of_turn>model\n" + msg.Content + "<end_of_turn>\n"
s.WriteString("<start_of_turn>model\n" + msg.Content + "<end_of_turn>\n")
}
}
s += "<start_of_turn>model\n"
return s
s.WriteString("<start_of_turn>model\n")
return s.String()
}
func formatQwenChat(messages []ChatMessage) string {
var s string
var s strings.Builder
for _, msg := range messages {
s += "<|im_start|>" + msg.Role + "\n" + msg.Content + "<|im_end|>\n"
s.WriteString("<|im_start|>" + msg.Role + "\n" + msg.Content + "<|im_end|>\n")
}
s += "<|im_start|>assistant\n"
return s
s.WriteString("<|im_start|>assistant\n")
return s.String()
}
func formatLlamaChat(messages []ChatMessage) string {
s := "<|begin_of_text|>"
var s strings.Builder
s.WriteString("<|begin_of_text|>")
for _, msg := range messages {
s += "<|start_header_id|>" + msg.Role + "<|end_header_id|>\n\n" + msg.Content + "<|eot_id|>"
s.WriteString("<|start_header_id|>" + msg.Role + "<|end_header_id|>\n\n" + msg.Content + "<|eot_id|>")
}
s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
return s
s.WriteString("<|start_header_id|>assistant<|end_header_id|>\n\n")
return s.String()
}

View file

@ -9,9 +9,14 @@ package metal
import "C"
import (
"encoding/json"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
"sort"
"strings"
"unsafe"
)
@ -212,6 +217,238 @@ func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array {
return out
}
// --- Adapter Loading (Inference) ---
// adapterConfig holds the metadata from adapter_config.json produced by mlx-lm training.
type adapterConfig struct {
Rank int `json:"rank"`
Alpha float32 `json:"alpha"`
NumLayers int `json:"num_layers"`
TargetKeys []string `json:"lora_layers"` // e.g. ["self_attn.q_proj", "self_attn.v_proj"]
}
// parseAdapterConfig reads and parses an adapter_config.json file.
func parseAdapterConfig(path string) (*adapterConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read adapter_config.json: %w", err)
}
var cfg adapterConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse adapter_config.json: %w", err)
}
// Apply defaults matching mlx-lm conventions.
if cfg.Rank == 0 {
cfg.Rank = 8
}
if cfg.Alpha == 0 {
cfg.Alpha = float32(cfg.Rank) * 2 // mlx-lm default: alpha = 2 * rank
}
return &cfg, nil
}
// loadAdapterWeights loads all safetensors files from an adapter directory into a flat weight map.
func loadAdapterWeights(dir string) (map[string]*Array, error) {
	paths, err := filepath.Glob(filepath.Join(dir, "*.safetensors"))
	if err != nil {
		return nil, fmt.Errorf("glob adapter safetensors: %w", err)
	}
	if len(paths) == 0 {
		return nil, fmt.Errorf("no .safetensors files found in %s", dir)
	}
	out := make(map[string]*Array)
	for _, p := range paths {
		for name, arr := range LoadSafetensors(p) {
			out[name] = arr
		}
		// Surface any load failure for this specific file before moving on.
		if err := lastError(); err != nil {
			return nil, fmt.Errorf("load adapter weights %s: %w", filepath.Base(p), err)
		}
	}
	return out, nil
}
// resolveLinear returns the *Linear for a given projection path within a model.
// projPath is e.g. "self_attn.q_proj" and the function resolves layer index + field.
func resolveLinear(model InternalModel, layerIdx int, projPath string) *Linear {
	// pick selects the attention projection matching projPath.
	pick := func(q, k, v, o *Linear) *Linear {
		switch projPath {
		case "self_attn.q_proj":
			return q
		case "self_attn.k_proj":
			return k
		case "self_attn.v_proj":
			return v
		case "self_attn.o_proj":
			return o
		}
		return nil
	}
	switch m := model.(type) {
	case *Qwen3Model:
		if layerIdx < len(m.Layers) {
			a := m.Layers[layerIdx].Attention
			return pick(a.QProj, a.KProj, a.VProj, a.OProj)
		}
	case *GemmaModel:
		if layerIdx < len(m.Layers) {
			a := m.Layers[layerIdx].Attention
			return pick(a.QProj, a.KProj, a.VProj, a.OProj)
		}
	}
	return nil
}
// parseLoRAWeightName extracts the layer index, projection path, and A/B suffix
// from an adapter weight name. Returns (-1, "", "") if the name is not a recognised
// LoRA weight.
//
// Examples:
//
//	"layers.0.self_attn.q_proj.lora_a"        → (0, "self_attn.q_proj", "lora_a")
//	"model.layers.12.self_attn.v_proj.lora_b" → (12, "self_attn.v_proj", "lora_b")
func parseLoRAWeightName(name string) (layerIdx int, projPath, suffix string) {
	// Strip optional "model." prefix.
	name = strings.TrimPrefix(name, "model.")
	// Must start with "layers.{N}."
	inner, ok := strings.CutPrefix(name, "layers.")
	if !ok {
		return -1, "", ""
	}
	// Must end with ".lora_a" or ".lora_b".
	switch {
	case strings.HasSuffix(inner, ".lora_a"):
		suffix = "lora_a"
	case strings.HasSuffix(inner, ".lora_b"):
		suffix = "lora_b"
	default:
		return -1, "", ""
	}
	// Both suffixes are the same length, so one trim covers either.
	inner = inner[:len(inner)-len(".lora_a")]
	// Split off the layer index from the projection path.
	idxStr, proj, ok := strings.Cut(inner, ".")
	if !ok || idxStr == "" {
		return -1, "", ""
	}
	// Strict decimal parse. fmt.Sscanf("%d") would silently accept trailing
	// junk (e.g. "12abc" parses as 12), mis-attributing malformed weights to
	// a real layer instead of rejecting them.
	idx := 0
	for _, c := range idxStr {
		if c < '0' || c > '9' {
			return -1, "", ""
		}
		idx = idx*10 + int(c-'0')
	}
	return idx, proj, suffix
}
// applyLoadedLoRA loads a trained LoRA adapter from disk and injects it into the model
// for inference. The adapter weights are frozen (no gradients needed).
//
// It fails if the config or weights cannot be read, and if not a single
// complete LoRA pair could be matched to a layer in the model.
func applyLoadedLoRA(model InternalModel, adapterDir string) error {
	// Step 1: Read adapter configuration.
	cfg, err := parseAdapterConfig(filepath.Join(adapterDir, "adapter_config.json"))
	if err != nil {
		return err
	}
	// Step 2: Load adapter safetensors.
	weights, err := loadAdapterWeights(adapterDir)
	if err != nil {
		return err
	}
	// Materialise all adapter weights onto Metal in a single batch.
	// Pre-size the slice: the final length is exactly len(weights).
	allArrays := make([]*Array, 0, len(weights))
	for _, a := range weights {
		allArrays = append(allArrays, a)
	}
	Materialize(allArrays...)
	// Step 3: Group weights by (layerIdx, projPath) and inject LoRA.
	type loraKey struct {
		layerIdx int
		projPath string
	}
	type loraPair struct {
		a *Array // lora_a matrix
		b *Array // lora_b matrix
	}
	pairs := make(map[loraKey]*loraPair)
	for name, arr := range weights {
		layerIdx, projPath, suffix := parseLoRAWeightName(name)
		if layerIdx < 0 {
			slog.Warn("adapter: skipping unrecognised weight", "name", name)
			continue
		}
		key := loraKey{layerIdx, projPath}
		pair, ok := pairs[key]
		if !ok {
			pair = &loraPair{}
			pairs[key] = pair
		}
		switch suffix {
		case "lora_a":
			pair.a = arr
		case "lora_b":
			pair.b = arr
		}
	}
	// mlx-lm scaling convention: scale = alpha / rank.
	scale := cfg.Alpha / float32(cfg.Rank)
	injected := 0
	for key, pair := range pairs {
		// A pair needs both halves before it can be injected.
		if pair.a == nil || pair.b == nil {
			slog.Warn("adapter: incomplete LoRA pair, skipping",
				"layer", key.layerIdx, "proj", key.projPath)
			continue
		}
		linear := resolveLinear(model, key.layerIdx, key.projPath)
		if linear == nil {
			slog.Warn("adapter: target layer not found, skipping",
				"layer", key.layerIdx, "proj", key.projPath)
			continue
		}
		lora := &LoRALinear{
			Base:  linear,
			A:     pair.a,
			B:     pair.b,
			Scale: scale,
			Rank:  cfg.Rank,
			Alpha: cfg.Alpha,
		}
		linear.LoRA = lora
		injected++
	}
	if injected == 0 {
		return fmt.Errorf("no LoRA layers injected from %s", adapterDir)
	}
	slog.Info("adapter loaded",
		"path", adapterDir,
		"rank", cfg.Rank,
		"alpha", cfg.Alpha,
		"scale", scale,
		"layers_injected", injected,
	)
	return nil
}
// --- SaveSafetensors ---
// SaveSafetensors saves a map of named arrays to a .safetensors file.

View file

@ -5,6 +5,7 @@ package metal
import (
"math"
"os"
"path/filepath"
"testing"
)
@ -320,3 +321,430 @@ func TestDefaultLoRAConfig(t *testing.T) {
t.Errorf("TargetKeys = %v, want [q_proj, v_proj]", cfg.TargetKeys)
}
}
// --- parseLoRAWeightName ---

// TestParseLoRAWeightName_Good verifies recognised adapter weight names are
// decomposed into layer index, projection path, and lora_a/lora_b suffix.
func TestParseLoRAWeightName_Good(t *testing.T) {
	for _, tc := range []struct {
		label, in, proj, suf string
		idx                  int
	}{
		{"standard_lora_a", "layers.0.self_attn.q_proj.lora_a", "self_attn.q_proj", "lora_a", 0},
		{"standard_lora_b", "layers.5.self_attn.v_proj.lora_b", "self_attn.v_proj", "lora_b", 5},
		{"with_model_prefix", "model.layers.12.self_attn.q_proj.lora_a", "self_attn.q_proj", "lora_a", 12},
		{"k_proj", "layers.3.self_attn.k_proj.lora_b", "self_attn.k_proj", "lora_b", 3},
		{"o_proj", "layers.7.self_attn.o_proj.lora_a", "self_attn.o_proj", "lora_a", 7},
	} {
		t.Run(tc.label, func(t *testing.T) {
			idx, proj, suf := parseLoRAWeightName(tc.in)
			if idx != tc.idx {
				t.Errorf("layerIdx = %d, want %d", idx, tc.idx)
			}
			if proj != tc.proj {
				t.Errorf("projPath = %q, want %q", proj, tc.proj)
			}
			if suf != tc.suf {
				t.Errorf("suffix = %q, want %q", suf, tc.suf)
			}
		})
	}
}
// TestParseLoRAWeightName_Bad verifies malformed weight names are rejected
// with a layer index of -1.
func TestParseLoRAWeightName_Bad(t *testing.T) {
	bad := map[string]string{
		"no_lora_suffix":   "layers.0.self_attn.q_proj.weight",
		"no_layers_prefix": "self_attn.q_proj.lora_a",
		"empty":            "",
		"just_layers":      "layers.",
		"no_dot_after_idx": "layers.0lora_a",
		"non_numeric_idx":  "layers.abc.self_attn.q_proj.lora_a",
	}
	for label, input := range bad {
		t.Run(label, func(t *testing.T) {
			if idx, _, _ := parseLoRAWeightName(input); idx != -1 {
				t.Errorf("expected -1 for %q, got %d", input, idx)
			}
		})
	}
}
// --- parseAdapterConfig ---

// TestParseAdapterConfig_Good parses a fully-populated config and checks
// every field survives.
func TestParseAdapterConfig_Good(t *testing.T) {
	dir := t.TempDir()
	cfg := `{
		"rank": 16,
		"alpha": 32.0,
		"num_layers": 4,
		"lora_layers": ["self_attn.q_proj", "self_attn.v_proj"]
	}`
	path := filepath.Join(dir, "adapter_config.json")
	// Check the fixture write: a silently ignored error here would surface
	// later as a confusing parse failure.
	if err := os.WriteFile(path, []byte(cfg), 0644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	parsed, err := parseAdapterConfig(path)
	if err != nil {
		t.Fatalf("parseAdapterConfig: %v", err)
	}
	if parsed.Rank != 16 {
		t.Errorf("Rank = %d, want 16", parsed.Rank)
	}
	if parsed.Alpha != 32.0 {
		t.Errorf("Alpha = %f, want 32.0", parsed.Alpha)
	}
	if parsed.NumLayers != 4 {
		t.Errorf("NumLayers = %d, want 4", parsed.NumLayers)
	}
	if len(parsed.TargetKeys) != 2 {
		t.Errorf("TargetKeys = %v, want 2 entries", parsed.TargetKeys)
	}
}
// TestParseAdapterConfig_Good_Defaults checks mlx-lm defaults are applied:
// rank 8 and alpha = 2 * rank when the config omits them.
func TestParseAdapterConfig_Good_Defaults(t *testing.T) {
	dir := t.TempDir()
	// Minimal config — rank and alpha should get defaults.
	cfg := `{}`
	path := filepath.Join(dir, "adapter_config.json")
	// Check the fixture write instead of ignoring its error.
	if err := os.WriteFile(path, []byte(cfg), 0644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	parsed, err := parseAdapterConfig(path)
	if err != nil {
		t.Fatalf("parseAdapterConfig: %v", err)
	}
	if parsed.Rank != 8 {
		t.Errorf("default Rank = %d, want 8", parsed.Rank)
	}
	if parsed.Alpha != 16.0 {
		t.Errorf("default Alpha = %f, want 16.0 (2 * rank)", parsed.Alpha)
	}
}
// TestParseAdapterConfig_Bad_MissingFile expects a read error for a path
// that does not exist.
func TestParseAdapterConfig_Bad_MissingFile(t *testing.T) {
	if _, err := parseAdapterConfig("/nonexistent/adapter_config.json"); err == nil {
		t.Fatal("expected error for missing file")
	}
}
// TestParseAdapterConfig_Bad_InvalidJSON expects a parse error for
// malformed JSON.
func TestParseAdapterConfig_Bad_InvalidJSON(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "adapter_config.json")
	// Check the fixture write instead of ignoring its error.
	if err := os.WriteFile(path, []byte("{broken"), 0644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	if _, err := parseAdapterConfig(path); err == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
// --- loadAdapterWeights ---

// TestLoadAdapterWeights_Bad_NoFiles expects an error when the directory
// contains no .safetensors files.
func TestLoadAdapterWeights_Bad_NoFiles(t *testing.T) {
	if _, err := loadAdapterWeights(t.TempDir()); err == nil {
		t.Fatal("expected error for directory with no safetensors files")
	}
}
// TestLoadAdapterWeights_Good round-trips a two-tensor adapter file through
// SaveSafetensors and loadAdapterWeights, and checks both names survive.
func TestLoadAdapterWeights_Good(t *testing.T) {
	dir := t.TempDir()
	// Save a small adapter file.
	a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
	b := FromValues([]float32{5, 6, 7, 8}, 2, 2)
	Materialize(a, b)
	err := SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{
		"layers.0.self_attn.q_proj.lora_a": a,
		"layers.0.self_attn.q_proj.lora_b": b,
	})
	if err != nil {
		t.Fatalf("SaveSafetensors: %v", err)
	}
	weights, err := loadAdapterWeights(dir)
	if err != nil {
		t.Fatalf("loadAdapterWeights: %v", err)
	}
	// Both tensors should come back under their original flat names.
	if len(weights) != 2 {
		t.Errorf("loaded %d weights, want 2", len(weights))
	}
	if _, ok := weights["layers.0.self_attn.q_proj.lora_a"]; !ok {
		t.Error("missing lora_a weight")
	}
	if _, ok := weights["layers.0.self_attn.q_proj.lora_b"]; !ok {
		t.Error("missing lora_b weight")
	}
}
// --- applyLoadedLoRA integration ---

// TestApplyLoadedLoRA_Good_SaveAndReload round-trips a LoRA adapter: save
// known A/B tensors plus adapter_config.json, reload them via applyLoadedLoRA
// into a fresh one-layer Qwen3 model, then verify rank, scale, and that the
// A/B values survived the round trip.
func TestApplyLoadedLoRA_Good_SaveAndReload(t *testing.T) {
	// Create a simple base Linear layer and save LoRA weights for it,
	// then load them back with applyLoadedLoRA.
	// Create a small "model" with 1 layer and known dimensions.
	w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	linear := NewLinear(w, nil)
	// Train a LoRA on this linear, then save.
	lora := NewLoRALinear(linear, 4, 8.0)
	// Set A and B to non-zero values so we can verify they load correctly.
	newA := FromValues([]float32{
		0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
		0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6,
		1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4,
		2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1, 3.2,
	}, 4, 8) // [rank=4, in=8]
	newB := FromValues([]float32{
		0.1, 0.2, 0.3, 0.4,
		0.5, 0.6, 0.7, 0.8,
		0.9, 1.0, 1.1, 1.2,
		1.3, 1.4, 1.5, 1.6,
	}, 4, 4) // [out=4, rank=4]
	Materialize(newA, newB)
	lora.A = newA
	lora.B = newB
	// Save the adapter weights.
	adapterDir := t.TempDir()
	err := SaveSafetensors(filepath.Join(adapterDir, "adapters.safetensors"), map[string]*Array{
		"layers.0.self_attn.q_proj.lora_a": lora.A,
		"layers.0.self_attn.q_proj.lora_b": lora.B,
	})
	if err != nil {
		t.Fatalf("SaveSafetensors: %v", err)
	}
	// Write adapter_config.json.
	configJSON := `{"rank": 4, "alpha": 8.0, "num_layers": 1, "lora_layers": ["self_attn.q_proj"]}`
	os.WriteFile(filepath.Join(adapterDir, "adapter_config.json"), []byte(configJSON), 0644)
	// Now create a fresh linear with the same base weights (no LoRA).
	linear2 := NewLinear(w, nil)
	if linear2.LoRA != nil {
		t.Fatal("fresh linear should not have LoRA")
	}
	// Build a minimal model for resolveLinear to work.
	qwen := &Qwen3Model{
		Layers: []*Qwen3DecoderLayer{
			{
				Attention: &Qwen3Attention{
					QProj: linear2,
					KProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
					VProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
					OProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
				},
			},
		},
	}
	// Apply the loaded adapter.
	err = applyLoadedLoRA(qwen, adapterDir)
	if err != nil {
		t.Fatalf("applyLoadedLoRA: %v", err)
	}
	// Verify LoRA was injected.
	if linear2.LoRA == nil {
		t.Fatal("LoRA should have been injected into q_proj")
	}
	// Verify rank and scale.
	if linear2.LoRA.Rank != 4 {
		t.Errorf("Rank = %d, want 4", linear2.LoRA.Rank)
	}
	expectedScale := float32(8.0) / float32(4) // alpha / rank = 2.0
	if math.Abs(float64(linear2.LoRA.Scale-expectedScale)) > 1e-5 {
		t.Errorf("Scale = %f, want %f", linear2.LoRA.Scale, expectedScale)
	}
	// Verify the loaded A weights match what we saved.
	Materialize(linear2.LoRA.A, linear2.LoRA.B)
	loadedA := linear2.LoRA.A.Floats()
	origA := newA.Floats()
	if len(loadedA) != len(origA) {
		t.Fatalf("A size mismatch: %d vs %d", len(loadedA), len(origA))
	}
	for i := range origA {
		if math.Abs(float64(loadedA[i]-origA[i])) > 1e-5 {
			t.Errorf("A[%d] = %f, want %f", i, loadedA[i], origA[i])
			break
		}
	}
	// Verify the loaded B weights match.
	loadedB := linear2.LoRA.B.Floats()
	origB := newB.Floats()
	if len(loadedB) != len(origB) {
		t.Fatalf("B size mismatch: %d vs %d", len(loadedB), len(origB))
	}
	for i := range origB {
		if math.Abs(float64(loadedB[i]-origB[i])) > 1e-5 {
			t.Errorf("B[%d] = %f, want %f", i, loadedB[i], origB[i])
			break
		}
	}
}
// TestApplyLoadedLoRA_Bad_MissingConfig expects failure when the adapter
// directory has weight files but no adapter_config.json.
func TestApplyLoadedLoRA_Bad_MissingConfig(t *testing.T) {
	dir := t.TempDir()
	// Write safetensors but no config. Check the setup write so a broken
	// fixture fails loudly instead of masquerading as the expected error.
	a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
	Materialize(a)
	if err := SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{"x": a}); err != nil {
		t.Fatalf("SaveSafetensors: %v", err)
	}
	qwen := &Qwen3Model{Layers: []*Qwen3DecoderLayer{}}
	if err := applyLoadedLoRA(qwen, dir); err == nil {
		t.Fatal("expected error for missing adapter_config.json")
	}
}
// TestApplyLoadedLoRA_Bad_MissingSafetensors expects failure when the
// adapter directory has a config but no weight files.
func TestApplyLoadedLoRA_Bad_MissingSafetensors(t *testing.T) {
	dir := t.TempDir()
	// Write config but no safetensors. Check the setup write so a broken
	// fixture fails loudly instead of masquerading as the expected error.
	if err := os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(`{"rank": 8}`), 0644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	qwen := &Qwen3Model{Layers: []*Qwen3DecoderLayer{}}
	if err := applyLoadedLoRA(qwen, dir); err == nil {
		t.Fatal("expected error for missing safetensors")
	}
}
// TestApplyLoadedLoRA_Bad_NoMatchingLayers expects failure when every LoRA
// weight references a layer index outside the model.
func TestApplyLoadedLoRA_Bad_NoMatchingLayers(t *testing.T) {
	dir := t.TempDir()
	// Check setup writes so fixture failures don't masquerade as the
	// expected "no layers injected" error.
	if err := os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(`{"rank": 4, "alpha": 8.0}`), 0644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	// Save weights that reference layer 99 (which won't exist).
	a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
	b := FromValues([]float32{5, 6, 7, 8}, 2, 2)
	Materialize(a, b)
	if err := SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{
		"layers.99.self_attn.q_proj.lora_a": a,
		"layers.99.self_attn.q_proj.lora_b": b,
	}); err != nil {
		t.Fatalf("SaveSafetensors: %v", err)
	}
	qwen := &Qwen3Model{
		Layers: []*Qwen3DecoderLayer{
			{
				Attention: &Qwen3Attention{
					QProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
				},
			},
		},
	}
	if err := applyLoadedLoRA(qwen, dir); err == nil {
		t.Fatal("expected error when no layers are injected")
	}
}
// TestApplyLoadedLoRA_Good_ForwardProducesOutput validates that a model with a
// loaded LoRA adapter produces different output than the base model alone.
func TestApplyLoadedLoRA_Good_ForwardProducesOutput(t *testing.T) {
	// Create base linear [4, 8].
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	linear := NewLinear(w, nil)
	// Compute base output.
	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	Materialize(x)
	baseOut := linear.Forward(x)
	Materialize(baseOut)
	baseFloats := baseOut.Floats()
	// Create and save non-trivial adapter weights. B is random (non-zero)
	// so the injected LoRA delta should be observable in the output.
	rank := 4
	loraA := RandomNormal(0, 0.1, []int32{int32(rank), 8}, DTypeFloat32)
	loraB := RandomNormal(0, 0.1, []int32{4, int32(rank)}, DTypeFloat32)
	Materialize(loraA, loraB)
	adapterDir := t.TempDir()
	SaveSafetensors(filepath.Join(adapterDir, "adapters.safetensors"), map[string]*Array{
		"layers.0.self_attn.q_proj.lora_a": loraA,
		"layers.0.self_attn.q_proj.lora_b": loraB,
	})
	os.WriteFile(filepath.Join(adapterDir, "adapter_config.json"),
		[]byte(`{"rank": 4, "alpha": 8.0}`), 0644)
	// Build a model and apply adapter. QProj is the linear whose output we
	// compare before/after injection.
	qwen := &Qwen3Model{
		Layers: []*Qwen3DecoderLayer{
			{
				Attention: &Qwen3Attention{
					QProj: linear,
					KProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
					VProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
					OProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
				},
			},
		},
	}
	err := applyLoadedLoRA(qwen, adapterDir)
	if err != nil {
		t.Fatalf("applyLoadedLoRA: %v", err)
	}
	// Now forward should go through LoRA path.
	loraOut := linear.Forward(x)
	Materialize(loraOut)
	loraFloats := loraOut.Floats()
	// Outputs should differ since B is non-zero.
	allSame := true
	for i := range baseFloats {
		if math.Abs(float64(baseFloats[i]-loraFloats[i])) > 1e-6 {
			allSame = false
			break
		}
	}
	if allSame {
		t.Error("expected LoRA output to differ from base output with non-zero B weights")
	}
}
// --- LoadAndInit with adapter ---

// TestLoadAndInit_Bad_AdapterMissing checks that LoadAndInit surfaces an
// error when the configured AdapterPath does not exist.
func TestLoadAndInit_Bad_AdapterMissing(t *testing.T) {
	dir := t.TempDir()
	writeMinimalConfig(t, dir, "qwen3")
	writeMinimalTokenizer(t, dir)
	// Create a minimal safetensors file so model loading proceeds.
	// The adapter path doesn't exist, so it should fail at the adapter step.
	_, err := LoadAndInit(dir, LoadConfig{AdapterPath: "/nonexistent/adapter"})
	if err == nil {
		t.Fatal("expected error for missing adapter")
	}
}

View file

@ -14,7 +14,7 @@ func TestAdamW_BasicStep(t *testing.T) {
opt := NewAdamW(0.1)
for i := 0; i < 300; i++ {
for i := range 300 {
// Gradient of x^2 is 2x
lossFn := func(inputs []*Array) []*Array {
p := inputs[0]
@ -48,7 +48,7 @@ func TestAdamW_MultiParam(t *testing.T) {
opt := NewAdamW(0.1)
for i := 0; i < 100; i++ {
for i := range 100 {
lossFn := func(inputs []*Array) []*Array {
return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))}
}
@ -85,7 +85,7 @@ func TestAdamW_WeightDecay(t *testing.T) {
zeroGrad := FromValue(float32(0.0))
Materialize(zeroGrad)
for i := 0; i < 10; i++ {
for range 10 {
updated := opt.Step([]*Array{x}, []*Array{zeroGrad})
x = updated[0]
Materialize(x)
@ -137,7 +137,7 @@ func TestAdamW_WithLoRA(t *testing.T) {
var initialLoss, finalLoss float64
for step := 0; step < 50; step++ {
for step := range 50 {
lossFn := func(inputs []*Array) []*Array {
lora.A = inputs[0]
lora.B = inputs[1]

View file

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"log/slog"
"maps"
"math"
"os"
"path/filepath"
@ -44,10 +45,10 @@ type Qwen3Model struct {
// Qwen3DecoderLayer is a single transformer block.
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
type Qwen3DecoderLayer struct {
InputNorm *RMSNormModule // Pre-attention norm
InputNorm *RMSNormModule // Pre-attention norm
PostAttnNorm *RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
Attention *Qwen3Attention
MLP *Qwen3MLP
Attention *Qwen3Attention
MLP *Qwen3MLP
}
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
@ -136,9 +137,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
return nil, fmt.Errorf("qwen3: no .safetensors files found in %s", modelPath)
}
for _, path := range matches {
for name, arr := range LoadSafetensors(path) {
weights[name] = arr
}
maps.Insert(weights, LoadSafetensors(path))
if err := lastError(); err != nil {
return nil, fmt.Errorf("qwen3: load weights %s: %w", filepath.Base(path), err)
}

View file

@ -173,7 +173,7 @@ func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
// Non-self-mapping: control chars, space, DEL, and gaps
n := 0
for b := 0; b < 256; b++ {
for b := range 256 {
if _, ok := encoder[byte(b)]; !ok {
r := rune(256 + n)
encoder[byte(b)] = r

View file

@ -239,7 +239,7 @@ func TestBuildGPT2ByteMaps(t *testing.T) {
}
// Round-trip: every byte should survive encode → decode
for b := 0; b < 256; b++ {
for b := range 256 {
r := encoder[byte(b)]
got := decoder[r]
if got != byte(b) {

View file

@ -75,7 +75,7 @@ func TestLoRA_EndToEnd(t *testing.T) {
var initialLoss, finalLoss float64
const numSteps = 5
for step := 0; step < numSteps; step++ {
for step := range numSteps {
// Fresh caches each step (stateful — can't reuse across gradient calls).
caches := gemma.NewCache()
@ -224,7 +224,7 @@ func TestLoRA_GradientCheckpointing(t *testing.T) {
var initialLoss, finalLoss float64
const numSteps = 3
for step := 0; step < numSteps; step++ {
for step := range numSteps {
caches := gemma.NewCache()
// Wrap the model forward pass in Checkpoint to recompute activations
@ -326,7 +326,7 @@ func TestLoRA_MixedPrecision(t *testing.T) {
var initialLoss, finalLoss float64
const numSteps = 5
for step := 0; step < numSteps; step++ {
for step := range numSteps {
caches := gemma.NewCache()
lossFn := func(inputs []*Array) []*Array {

View file

@ -63,7 +63,8 @@ func (b *metalBackend) LoadModel(path string, opts ...inference.LoadOption) (inf
slog.Warn("mlx: GPULayers=0 ignored — Metal always uses full GPU offload")
}
m, err := metal.LoadAndInit(path, metal.LoadConfig{
ContextLen: cfg.ContextLen,
ContextLen: cfg.ContextLen,
AdapterPath: cfg.AdapterPath,
})
if err != nil {
return nil, err