refactor: apply go fix modernizers for Go 1.26
Automated fixes: interface{} → any, range-over-int, t.Context(),
wg.Go(), strings.SplitSeq, strings.Builder, slices.Contains,
maps helpers, min/max builtins.
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
fc27c2cd27
commit
5004ac258a
12 changed files with 746 additions and 75 deletions
|
|
@ -6,7 +6,8 @@ import "fmt"
|
|||
|
||||
// LoadConfig holds configuration applied during model loading.
|
||||
type LoadConfig struct {
|
||||
ContextLen int // Context window size (0 = model default, unbounded KV cache)
|
||||
ContextLen int // Context window size (0 = model default, unbounded KV cache)
|
||||
AdapterPath string // Path to LoRA adapter directory (empty = no adapter)
|
||||
}
|
||||
|
||||
// LoadAndInit initialises Metal and loads a model from the given path.
|
||||
|
|
@ -22,8 +23,15 @@ func LoadAndInit(path string, cfg ...LoadConfig) (*Model, error) {
|
|||
tokenizer: im.Tokenizer(),
|
||||
modelType: im.ModelType(),
|
||||
}
|
||||
if len(cfg) > 0 && cfg[0].ContextLen > 0 {
|
||||
m.contextLen = cfg[0].ContextLen
|
||||
if len(cfg) > 0 {
|
||||
if cfg[0].ContextLen > 0 {
|
||||
m.contextLen = cfg[0].ContextLen
|
||||
}
|
||||
if cfg[0].AdapterPath != "" {
|
||||
if err := applyLoadedLoRA(im, cfg[0].AdapterPath); err != nil {
|
||||
return nil, fmt.Errorf("metal: load adapter: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -86,7 +87,7 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf
|
|||
|
||||
// Gather logits at each prompt's last real token position and sample.
|
||||
sortedResults := make([]ClassifyResult, N)
|
||||
for si := int32(0); si < N; si++ {
|
||||
for si := range N {
|
||||
lastPos := sortedLengths[si] - 1
|
||||
|
||||
// Extract [1, vocab] at position lastPos for this batch element.
|
||||
|
|
@ -120,11 +121,11 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf
|
|||
|
||||
totalDur := time.Since(totalStart)
|
||||
m.lastMetrics = Metrics{
|
||||
PromptTokens: totalPromptTokens,
|
||||
GeneratedTokens: int(N), // One token sampled per prompt
|
||||
PrefillDuration: totalDur,
|
||||
TotalDuration: totalDur,
|
||||
PeakMemoryBytes: GetPeakMemory(),
|
||||
PromptTokens: totalPromptTokens,
|
||||
GeneratedTokens: int(N), // One token sampled per prompt
|
||||
PrefillDuration: totalDur,
|
||||
TotalDuration: totalDur,
|
||||
PeakMemoryBytes: GetPeakMemory(),
|
||||
ActiveMemoryBytes: GetActiveMemory(),
|
||||
}
|
||||
if totalDur > 0 {
|
||||
|
|
@ -227,7 +228,7 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
|
|||
nextIDs := make([]int32, N)
|
||||
allFinished := true
|
||||
|
||||
for si := int32(0); si < N; si++ {
|
||||
for si := range N {
|
||||
if states[si].finished {
|
||||
nextIDs[si] = 0 // pad
|
||||
continue
|
||||
|
|
@ -259,11 +260,8 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
|
|||
states[si].finished = true
|
||||
continue
|
||||
}
|
||||
for _, stop := range cfg.StopTokens {
|
||||
if id == stop {
|
||||
states[si].finished = true
|
||||
break
|
||||
}
|
||||
if slices.Contains(cfg.StopTokens, id) {
|
||||
states[si].finished = true
|
||||
}
|
||||
if !states[si].finished {
|
||||
text := m.tokenizer.DecodeToken(id)
|
||||
|
|
@ -299,12 +297,12 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
|
|||
totalDur := time.Since(totalStart)
|
||||
decodeDur := totalDur - prefillDur
|
||||
m.lastMetrics = Metrics{
|
||||
PromptTokens: totalPromptTokens,
|
||||
GeneratedTokens: totalGenerated,
|
||||
PrefillDuration: prefillDur,
|
||||
DecodeDuration: decodeDur,
|
||||
TotalDuration: totalDur,
|
||||
PeakMemoryBytes: GetPeakMemory(),
|
||||
PromptTokens: totalPromptTokens,
|
||||
GeneratedTokens: totalGenerated,
|
||||
PrefillDuration: prefillDur,
|
||||
DecodeDuration: decodeDur,
|
||||
TotalDuration: totalDur,
|
||||
PeakMemoryBytes: GetPeakMemory(),
|
||||
ActiveMemoryBytes: GetActiveMemory(),
|
||||
}
|
||||
if prefillDur > 0 {
|
||||
|
|
@ -324,11 +322,11 @@ func buildBatchMask(N, L int32, promptLens []int32) *Array {
|
|||
negInf := float32(math.Inf(-1))
|
||||
data := make([]float32, int(N)*int(L)*int(L))
|
||||
|
||||
for b := int32(0); b < N; b++ {
|
||||
for b := range N {
|
||||
pLen := promptLens[b]
|
||||
base := int(b) * int(L) * int(L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
for j := int32(0); j < L; j++ {
|
||||
for i := range L {
|
||||
for j := range L {
|
||||
if j <= i && j < pLen {
|
||||
data[base+int(i)*int(L)+int(j)] = 0
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -182,9 +183,7 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
|||
return nil, fmt.Errorf("gemma3: no .safetensors files found in %s", modelPath)
|
||||
}
|
||||
for _, path := range matches {
|
||||
for name, arr := range LoadSafetensors(path) {
|
||||
weights[name] = arr
|
||||
}
|
||||
maps.Insert(weights, LoadSafetensors(path))
|
||||
if err := lastError(); err != nil {
|
||||
return nil, fmt.Errorf("gemma3: load weights %s: %w", filepath.Base(path), err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -46,11 +48,11 @@ type Metrics struct {
|
|||
|
||||
// Model wraps a loaded transformer model for text generation.
|
||||
type Model struct {
|
||||
model InternalModel
|
||||
tokenizer *Tokenizer
|
||||
modelType string
|
||||
contextLen int // 0 = unbounded (model default)
|
||||
lastErr error
|
||||
model InternalModel
|
||||
tokenizer *Tokenizer
|
||||
modelType string
|
||||
contextLen int // 0 = unbounded (model default)
|
||||
lastErr error
|
||||
lastMetrics Metrics
|
||||
}
|
||||
|
||||
|
|
@ -145,12 +147,12 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
decodeDur := time.Since(totalStart) - prefillDur
|
||||
totalDur := time.Since(totalStart)
|
||||
m.lastMetrics = Metrics{
|
||||
PromptTokens: promptLen,
|
||||
GeneratedTokens: genCount,
|
||||
PrefillDuration: prefillDur,
|
||||
DecodeDuration: decodeDur,
|
||||
TotalDuration: totalDur,
|
||||
PeakMemoryBytes: GetPeakMemory(),
|
||||
PromptTokens: promptLen,
|
||||
GeneratedTokens: genCount,
|
||||
PrefillDuration: prefillDur,
|
||||
DecodeDuration: decodeDur,
|
||||
TotalDuration: totalDur,
|
||||
PeakMemoryBytes: GetPeakMemory(),
|
||||
ActiveMemoryBytes: GetActiveMemory(),
|
||||
}
|
||||
if prefillDur > 0 {
|
||||
|
|
@ -205,10 +207,8 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
if id == m.tokenizer.EOSToken() {
|
||||
return
|
||||
}
|
||||
for _, stop := range cfg.StopTokens {
|
||||
if id == stop {
|
||||
return
|
||||
}
|
||||
if slices.Contains(cfg.StopTokens, id) {
|
||||
return
|
||||
}
|
||||
|
||||
genCount++
|
||||
|
|
@ -286,44 +286,45 @@ func (m *Model) formatChat(messages []ChatMessage) string {
|
|||
case "llama":
|
||||
return formatLlamaChat(messages)
|
||||
default:
|
||||
var s string
|
||||
var s strings.Builder
|
||||
for _, msg := range messages {
|
||||
s += msg.Content + "\n"
|
||||
s.WriteString(msg.Content + "\n")
|
||||
}
|
||||
return s
|
||||
return s.String()
|
||||
}
|
||||
}
|
||||
|
||||
func formatGemmaChat(messages []ChatMessage) string {
|
||||
var s string
|
||||
var s strings.Builder
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
s += "<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n"
|
||||
s.WriteString("<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n")
|
||||
case "user":
|
||||
s += "<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n"
|
||||
s.WriteString("<start_of_turn>user\n" + msg.Content + "<end_of_turn>\n")
|
||||
case "assistant":
|
||||
s += "<start_of_turn>model\n" + msg.Content + "<end_of_turn>\n"
|
||||
s.WriteString("<start_of_turn>model\n" + msg.Content + "<end_of_turn>\n")
|
||||
}
|
||||
}
|
||||
s += "<start_of_turn>model\n"
|
||||
return s
|
||||
s.WriteString("<start_of_turn>model\n")
|
||||
return s.String()
|
||||
}
|
||||
|
||||
func formatQwenChat(messages []ChatMessage) string {
|
||||
var s string
|
||||
var s strings.Builder
|
||||
for _, msg := range messages {
|
||||
s += "<|im_start|>" + msg.Role + "\n" + msg.Content + "<|im_end|>\n"
|
||||
s.WriteString("<|im_start|>" + msg.Role + "\n" + msg.Content + "<|im_end|>\n")
|
||||
}
|
||||
s += "<|im_start|>assistant\n"
|
||||
return s
|
||||
s.WriteString("<|im_start|>assistant\n")
|
||||
return s.String()
|
||||
}
|
||||
|
||||
func formatLlamaChat(messages []ChatMessage) string {
|
||||
s := "<|begin_of_text|>"
|
||||
var s strings.Builder
|
||||
s.WriteString("<|begin_of_text|>")
|
||||
for _, msg := range messages {
|
||||
s += "<|start_header_id|>" + msg.Role + "<|end_header_id|>\n\n" + msg.Content + "<|eot_id|>"
|
||||
s.WriteString("<|start_header_id|>" + msg.Role + "<|end_header_id|>\n\n" + msg.Content + "<|eot_id|>")
|
||||
}
|
||||
s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
return s
|
||||
s.WriteString("<|start_header_id|>assistant<|end_header_id|>\n\n")
|
||||
return s.String()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,9 +9,14 @@ package metal
|
|||
import "C"
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"log/slog"
	"math"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"unsafe"
)
|
||||
|
||||
|
|
@ -212,6 +217,238 @@ func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array {
|
|||
return out
|
||||
}
|
||||
|
||||
// --- Adapter Loading (Inference) ---
|
||||
|
||||
// adapterConfig holds the metadata from adapter_config.json produced by mlx-lm training.
|
||||
type adapterConfig struct {
|
||||
Rank int `json:"rank"`
|
||||
Alpha float32 `json:"alpha"`
|
||||
NumLayers int `json:"num_layers"`
|
||||
TargetKeys []string `json:"lora_layers"` // e.g. ["self_attn.q_proj", "self_attn.v_proj"]
|
||||
}
|
||||
|
||||
// parseAdapterConfig reads and parses an adapter_config.json file.
|
||||
func parseAdapterConfig(path string) (*adapterConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read adapter_config.json: %w", err)
|
||||
}
|
||||
var cfg adapterConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse adapter_config.json: %w", err)
|
||||
}
|
||||
// Apply defaults matching mlx-lm conventions.
|
||||
if cfg.Rank == 0 {
|
||||
cfg.Rank = 8
|
||||
}
|
||||
if cfg.Alpha == 0 {
|
||||
cfg.Alpha = float32(cfg.Rank) * 2 // mlx-lm default: alpha = 2 * rank
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// loadAdapterWeights loads all safetensors files from an adapter directory into a flat weight map.
|
||||
func loadAdapterWeights(dir string) (map[string]*Array, error) {
|
||||
matches, err := filepath.Glob(filepath.Join(dir, "*.safetensors"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("glob adapter safetensors: %w", err)
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("no .safetensors files found in %s", dir)
|
||||
}
|
||||
|
||||
weights := make(map[string]*Array)
|
||||
for _, path := range matches {
|
||||
for name, arr := range LoadSafetensors(path) {
|
||||
weights[name] = arr
|
||||
}
|
||||
if err := lastError(); err != nil {
|
||||
return nil, fmt.Errorf("load adapter weights %s: %w", filepath.Base(path), err)
|
||||
}
|
||||
}
|
||||
return weights, nil
|
||||
}
|
||||
|
||||
// resolveLinear returns the *Linear for a given projection path within a model.
|
||||
// projPath is e.g. "self_attn.q_proj" and the function resolves layer index + field.
|
||||
func resolveLinear(model InternalModel, layerIdx int, projPath string) *Linear {
|
||||
switch m := model.(type) {
|
||||
case *Qwen3Model:
|
||||
if layerIdx >= len(m.Layers) {
|
||||
return nil
|
||||
}
|
||||
layer := m.Layers[layerIdx]
|
||||
switch projPath {
|
||||
case "self_attn.q_proj":
|
||||
return layer.Attention.QProj
|
||||
case "self_attn.k_proj":
|
||||
return layer.Attention.KProj
|
||||
case "self_attn.v_proj":
|
||||
return layer.Attention.VProj
|
||||
case "self_attn.o_proj":
|
||||
return layer.Attention.OProj
|
||||
}
|
||||
case *GemmaModel:
|
||||
if layerIdx >= len(m.Layers) {
|
||||
return nil
|
||||
}
|
||||
layer := m.Layers[layerIdx]
|
||||
switch projPath {
|
||||
case "self_attn.q_proj":
|
||||
return layer.Attention.QProj
|
||||
case "self_attn.k_proj":
|
||||
return layer.Attention.KProj
|
||||
case "self_attn.v_proj":
|
||||
return layer.Attention.VProj
|
||||
case "self_attn.o_proj":
|
||||
return layer.Attention.OProj
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseLoRAWeightName extracts the layer index, projection path, and A/B suffix
// from an adapter weight name. Returns (-1, "", "") if the name is not a recognised
// LoRA weight.
//
// Examples:
//
//	"layers.0.self_attn.q_proj.lora_a"        → (0, "self_attn.q_proj", "lora_a")
//	"model.layers.12.self_attn.v_proj.lora_b" → (12, "self_attn.v_proj", "lora_b")
func parseLoRAWeightName(name string) (layerIdx int, projPath, suffix string) {
	// Strip optional "model." prefix.
	name = strings.TrimPrefix(name, "model.")

	// Must start with "layers.{N}."
	inner, ok := strings.CutPrefix(name, "layers.")
	if !ok {
		return -1, "", ""
	}

	// Must end with ".lora_a" or ".lora_b".
	switch {
	case strings.HasSuffix(inner, ".lora_a"):
		suffix = "lora_a"
	case strings.HasSuffix(inner, ".lora_b"):
		suffix = "lora_b"
	default:
		return -1, "", ""
	}
	inner = inner[:len(inner)-len(".lora_a")] // both suffixes share this length

	// Split "{N}.{projPath}" at the first dot.
	idxStr, rest, found := strings.Cut(inner, ".")
	if !found {
		return -1, "", ""
	}
	projPath = rest

	// strconv.Atoi — unlike the previous fmt.Sscanf("%d") — rejects trailing
	// garbage such as "12abc". Negative indices are rejected explicitly so
	// they cannot collide with the -1 error sentinel.
	idx, err := strconv.Atoi(idxStr)
	if err != nil || idx < 0 {
		return -1, "", ""
	}

	return idx, projPath, suffix
}
|
||||
|
||||
// applyLoadedLoRA loads a trained LoRA adapter from disk and injects it into the model
|
||||
// for inference. The adapter weights are frozen (no gradients needed).
|
||||
func applyLoadedLoRA(model InternalModel, adapterDir string) error {
|
||||
// Step 1: Read adapter configuration.
|
||||
cfg, err := parseAdapterConfig(filepath.Join(adapterDir, "adapter_config.json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step 2: Load adapter safetensors.
|
||||
weights, err := loadAdapterWeights(adapterDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Materialise all adapter weights onto Metal.
|
||||
var allArrays []*Array
|
||||
for _, a := range weights {
|
||||
allArrays = append(allArrays, a)
|
||||
}
|
||||
Materialize(allArrays...)
|
||||
|
||||
// Step 3: Group weights by (layerIdx, projPath) and inject LoRA.
|
||||
type loraKey struct {
|
||||
layerIdx int
|
||||
projPath string
|
||||
}
|
||||
type loraPair struct {
|
||||
a *Array
|
||||
b *Array
|
||||
}
|
||||
pairs := make(map[loraKey]*loraPair)
|
||||
|
||||
for name, arr := range weights {
|
||||
layerIdx, projPath, suffix := parseLoRAWeightName(name)
|
||||
if layerIdx < 0 {
|
||||
slog.Warn("adapter: skipping unrecognised weight", "name", name)
|
||||
continue
|
||||
}
|
||||
key := loraKey{layerIdx, projPath}
|
||||
pair, ok := pairs[key]
|
||||
if !ok {
|
||||
pair = &loraPair{}
|
||||
pairs[key] = pair
|
||||
}
|
||||
switch suffix {
|
||||
case "lora_a":
|
||||
pair.a = arr
|
||||
case "lora_b":
|
||||
pair.b = arr
|
||||
}
|
||||
}
|
||||
|
||||
scale := cfg.Alpha / float32(cfg.Rank)
|
||||
injected := 0
|
||||
|
||||
for key, pair := range pairs {
|
||||
if pair.a == nil || pair.b == nil {
|
||||
slog.Warn("adapter: incomplete LoRA pair, skipping",
|
||||
"layer", key.layerIdx, "proj", key.projPath)
|
||||
continue
|
||||
}
|
||||
|
||||
linear := resolveLinear(model, key.layerIdx, key.projPath)
|
||||
if linear == nil {
|
||||
slog.Warn("adapter: target layer not found, skipping",
|
||||
"layer", key.layerIdx, "proj", key.projPath)
|
||||
continue
|
||||
}
|
||||
|
||||
lora := &LoRALinear{
|
||||
Base: linear,
|
||||
A: pair.a,
|
||||
B: pair.b,
|
||||
Scale: scale,
|
||||
Rank: cfg.Rank,
|
||||
Alpha: cfg.Alpha,
|
||||
}
|
||||
linear.LoRA = lora
|
||||
injected++
|
||||
}
|
||||
|
||||
if injected == 0 {
|
||||
return fmt.Errorf("no LoRA layers injected from %s", adapterDir)
|
||||
}
|
||||
|
||||
slog.Info("adapter loaded",
|
||||
"path", adapterDir,
|
||||
"rank", cfg.Rank,
|
||||
"alpha", cfg.Alpha,
|
||||
"scale", scale,
|
||||
"layers_injected", injected,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- SaveSafetensors ---
|
||||
|
||||
// SaveSafetensors saves a map of named arrays to a .safetensors file.
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ package metal
|
|||
import (
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
|
@ -320,3 +321,430 @@ func TestDefaultLoRAConfig(t *testing.T) {
|
|||
t.Errorf("TargetKeys = %v, want [q_proj, v_proj]", cfg.TargetKeys)
|
||||
}
|
||||
}
|
||||
|
||||
// --- parseLoRAWeightName ---
|
||||
|
||||
// TestParseLoRAWeightName_Good checks that well-formed adapter weight names,
// with and without the optional "model." prefix, decompose into the expected
// layer index, projection path, and lora_a/lora_b suffix.
func TestParseLoRAWeightName_Good(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		wantIdx  int
		wantProj string
		wantSuf  string
	}{
		{
			"standard_lora_a",
			"layers.0.self_attn.q_proj.lora_a",
			0, "self_attn.q_proj", "lora_a",
		},
		{
			"standard_lora_b",
			"layers.5.self_attn.v_proj.lora_b",
			5, "self_attn.v_proj", "lora_b",
		},
		{
			"with_model_prefix",
			"model.layers.12.self_attn.q_proj.lora_a",
			12, "self_attn.q_proj", "lora_a",
		},
		{
			"k_proj",
			"layers.3.self_attn.k_proj.lora_b",
			3, "self_attn.k_proj", "lora_b",
		},
		{
			"o_proj",
			"layers.7.self_attn.o_proj.lora_a",
			7, "self_attn.o_proj", "lora_a",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			idx, proj, suf := parseLoRAWeightName(tt.input)
			if idx != tt.wantIdx {
				t.Errorf("layerIdx = %d, want %d", idx, tt.wantIdx)
			}
			if proj != tt.wantProj {
				t.Errorf("projPath = %q, want %q", proj, tt.wantProj)
			}
			if suf != tt.wantSuf {
				t.Errorf("suffix = %q, want %q", suf, tt.wantSuf)
			}
		})
	}
}
|
||||
|
||||
// TestParseLoRAWeightName_Bad checks that malformed adapter weight names are
// rejected with the sentinel layer index -1.
func TestParseLoRAWeightName_Bad(t *testing.T) {
	tests := []struct {
		name  string
		input string
	}{
		{"no_lora_suffix", "layers.0.self_attn.q_proj.weight"},
		{"no_layers_prefix", "self_attn.q_proj.lora_a"},
		{"empty", ""},
		{"just_layers", "layers."},
		{"no_dot_after_idx", "layers.0lora_a"},
		{"non_numeric_idx", "layers.abc.self_attn.q_proj.lora_a"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			idx, _, _ := parseLoRAWeightName(tt.input)
			if idx != -1 {
				t.Errorf("expected -1 for %q, got %d", tt.input, idx)
			}
		})
	}
}
|
||||
|
||||
// --- parseAdapterConfig ---
|
||||
|
||||
// TestParseAdapterConfig_Good parses a fully specified adapter_config.json
// and verifies every field round-trips.
func TestParseAdapterConfig_Good(t *testing.T) {
	dir := t.TempDir()
	cfg := `{
		"rank": 16,
		"alpha": 32.0,
		"num_layers": 4,
		"lora_layers": ["self_attn.q_proj", "self_attn.v_proj"]
	}`
	os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(cfg), 0644)

	parsed, err := parseAdapterConfig(filepath.Join(dir, "adapter_config.json"))
	if err != nil {
		t.Fatalf("parseAdapterConfig: %v", err)
	}
	if parsed.Rank != 16 {
		t.Errorf("Rank = %d, want 16", parsed.Rank)
	}
	if parsed.Alpha != 32.0 {
		t.Errorf("Alpha = %f, want 32.0", parsed.Alpha)
	}
	if parsed.NumLayers != 4 {
		t.Errorf("NumLayers = %d, want 4", parsed.NumLayers)
	}
	if len(parsed.TargetKeys) != 2 {
		t.Errorf("TargetKeys = %v, want 2 entries", parsed.TargetKeys)
	}
}
|
||||
|
||||
// TestParseAdapterConfig_Good_Defaults checks the mlx-lm fallbacks: rank 8
// and alpha = 2 * rank when the config file omits both.
func TestParseAdapterConfig_Good_Defaults(t *testing.T) {
	dir := t.TempDir()
	// Minimal config — rank and alpha should get defaults.
	cfg := `{}`
	os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(cfg), 0644)

	parsed, err := parseAdapterConfig(filepath.Join(dir, "adapter_config.json"))
	if err != nil {
		t.Fatalf("parseAdapterConfig: %v", err)
	}
	if parsed.Rank != 8 {
		t.Errorf("default Rank = %d, want 8", parsed.Rank)
	}
	if parsed.Alpha != 16.0 {
		t.Errorf("default Alpha = %f, want 16.0 (2 * rank)", parsed.Alpha)
	}
}
|
||||
|
||||
// TestParseAdapterConfig_Bad_MissingFile expects a read error for a
// nonexistent config path.
func TestParseAdapterConfig_Bad_MissingFile(t *testing.T) {
	_, err := parseAdapterConfig("/nonexistent/adapter_config.json")
	if err == nil {
		t.Fatal("expected error for missing file")
	}
}
|
||||
|
||||
// TestParseAdapterConfig_Bad_InvalidJSON expects a parse error for malformed JSON.
func TestParseAdapterConfig_Bad_InvalidJSON(t *testing.T) {
	dir := t.TempDir()
	os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte("{broken"), 0644)

	_, err := parseAdapterConfig(filepath.Join(dir, "adapter_config.json"))
	if err == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
|
||||
|
||||
// --- loadAdapterWeights ---
|
||||
|
||||
// TestLoadAdapterWeights_Bad_NoFiles expects an error when the adapter
// directory contains no .safetensors files.
func TestLoadAdapterWeights_Bad_NoFiles(t *testing.T) {
	dir := t.TempDir()
	_, err := loadAdapterWeights(dir)
	if err == nil {
		t.Fatal("expected error for directory with no safetensors files")
	}
}
|
||||
|
||||
// TestLoadAdapterWeights_Good saves a small adapter file and checks that both
// tensors come back under their original names.
func TestLoadAdapterWeights_Good(t *testing.T) {
	dir := t.TempDir()

	// Save a small adapter file.
	a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
	b := FromValues([]float32{5, 6, 7, 8}, 2, 2)
	Materialize(a, b)

	err := SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{
		"layers.0.self_attn.q_proj.lora_a": a,
		"layers.0.self_attn.q_proj.lora_b": b,
	})
	if err != nil {
		t.Fatalf("SaveSafetensors: %v", err)
	}

	weights, err := loadAdapterWeights(dir)
	if err != nil {
		t.Fatalf("loadAdapterWeights: %v", err)
	}
	if len(weights) != 2 {
		t.Errorf("loaded %d weights, want 2", len(weights))
	}
	if _, ok := weights["layers.0.self_attn.q_proj.lora_a"]; !ok {
		t.Error("missing lora_a weight")
	}
	if _, ok := weights["layers.0.self_attn.q_proj.lora_b"]; !ok {
		t.Error("missing lora_b weight")
	}
}
|
||||
|
||||
// --- applyLoadedLoRA integration ---
|
||||
|
||||
// TestApplyLoadedLoRA_Good_SaveAndReload round-trips LoRA factors through
// safetensors and verifies applyLoadedLoRA restores rank, scale, and both
// factor matrices exactly on a fresh base layer.
func TestApplyLoadedLoRA_Good_SaveAndReload(t *testing.T) {
	// Create a simple base Linear layer and save LoRA weights for it,
	// then load them back with applyLoadedLoRA.

	// Create a small "model" with 1 layer and known dimensions.
	w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	linear := NewLinear(w, nil)

	// Train a LoRA on this linear, then save.
	lora := NewLoRALinear(linear, 4, 8.0)
	// Set A and B to non-zero values so we can verify they load correctly.
	newA := FromValues([]float32{
		0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
		0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6,
		1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4,
		2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1, 3.2,
	}, 4, 8) // [rank=4, in=8]
	newB := FromValues([]float32{
		0.1, 0.2, 0.3, 0.4,
		0.5, 0.6, 0.7, 0.8,
		0.9, 1.0, 1.1, 1.2,
		1.3, 1.4, 1.5, 1.6,
	}, 4, 4) // [out=4, rank=4]
	Materialize(newA, newB)
	lora.A = newA
	lora.B = newB

	// Save the adapter weights.
	adapterDir := t.TempDir()
	err := SaveSafetensors(filepath.Join(adapterDir, "adapters.safetensors"), map[string]*Array{
		"layers.0.self_attn.q_proj.lora_a": lora.A,
		"layers.0.self_attn.q_proj.lora_b": lora.B,
	})
	if err != nil {
		t.Fatalf("SaveSafetensors: %v", err)
	}

	// Write adapter_config.json.
	configJSON := `{"rank": 4, "alpha": 8.0, "num_layers": 1, "lora_layers": ["self_attn.q_proj"]}`
	os.WriteFile(filepath.Join(adapterDir, "adapter_config.json"), []byte(configJSON), 0644)

	// Now create a fresh linear with the same base weights (no LoRA).
	linear2 := NewLinear(w, nil)
	if linear2.LoRA != nil {
		t.Fatal("fresh linear should not have LoRA")
	}

	// Build a minimal model for resolveLinear to work.
	qwen := &Qwen3Model{
		Layers: []*Qwen3DecoderLayer{
			{
				Attention: &Qwen3Attention{
					QProj: linear2,
					KProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
					VProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
					OProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
				},
			},
		},
	}

	// Apply the loaded adapter.
	err = applyLoadedLoRA(qwen, adapterDir)
	if err != nil {
		t.Fatalf("applyLoadedLoRA: %v", err)
	}

	// Verify LoRA was injected.
	if linear2.LoRA == nil {
		t.Fatal("LoRA should have been injected into q_proj")
	}

	// Verify rank and scale.
	if linear2.LoRA.Rank != 4 {
		t.Errorf("Rank = %d, want 4", linear2.LoRA.Rank)
	}
	expectedScale := float32(8.0) / float32(4) // alpha / rank = 2.0
	if math.Abs(float64(linear2.LoRA.Scale-expectedScale)) > 1e-5 {
		t.Errorf("Scale = %f, want %f", linear2.LoRA.Scale, expectedScale)
	}

	// Verify the loaded A weights match what we saved.
	Materialize(linear2.LoRA.A, linear2.LoRA.B)
	loadedA := linear2.LoRA.A.Floats()
	origA := newA.Floats()
	if len(loadedA) != len(origA) {
		t.Fatalf("A size mismatch: %d vs %d", len(loadedA), len(origA))
	}
	for i := range origA {
		if math.Abs(float64(loadedA[i]-origA[i])) > 1e-5 {
			t.Errorf("A[%d] = %f, want %f", i, loadedA[i], origA[i])
			break
		}
	}

	// Verify the loaded B weights match.
	loadedB := linear2.LoRA.B.Floats()
	origB := newB.Floats()
	if len(loadedB) != len(origB) {
		t.Fatalf("B size mismatch: %d vs %d", len(loadedB), len(origB))
	}
	for i := range origB {
		if math.Abs(float64(loadedB[i]-origB[i])) > 1e-5 {
			t.Errorf("B[%d] = %f, want %f", i, loadedB[i], origB[i])
			break
		}
	}
}
|
||||
|
||||
// TestApplyLoadedLoRA_Bad_MissingConfig expects failure when the adapter
// directory has weight files but no adapter_config.json.
func TestApplyLoadedLoRA_Bad_MissingConfig(t *testing.T) {
	dir := t.TempDir()
	// Write safetensors but no config.
	a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
	Materialize(a)
	SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{"x": a})

	qwen := &Qwen3Model{Layers: []*Qwen3DecoderLayer{}}
	err := applyLoadedLoRA(qwen, dir)
	if err == nil {
		t.Fatal("expected error for missing adapter_config.json")
	}
}
|
||||
|
||||
// TestApplyLoadedLoRA_Bad_MissingSafetensors expects failure when the config
// exists but no weight files do.
func TestApplyLoadedLoRA_Bad_MissingSafetensors(t *testing.T) {
	dir := t.TempDir()
	// Write config but no safetensors.
	os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(`{"rank": 8}`), 0644)

	qwen := &Qwen3Model{Layers: []*Qwen3DecoderLayer{}}
	err := applyLoadedLoRA(qwen, dir)
	if err == nil {
		t.Fatal("expected error for missing safetensors")
	}
}
|
||||
|
||||
// TestApplyLoadedLoRA_Bad_NoMatchingLayers expects failure when every adapter
// weight targets a layer index the model does not have, so nothing is injected.
func TestApplyLoadedLoRA_Bad_NoMatchingLayers(t *testing.T) {
	dir := t.TempDir()
	os.WriteFile(filepath.Join(dir, "adapter_config.json"), []byte(`{"rank": 4, "alpha": 8.0}`), 0644)

	// Save weights that reference layer 99 (which won't exist).
	a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
	b := FromValues([]float32{5, 6, 7, 8}, 2, 2)
	Materialize(a, b)
	SaveSafetensors(filepath.Join(dir, "adapters.safetensors"), map[string]*Array{
		"layers.99.self_attn.q_proj.lora_a": a,
		"layers.99.self_attn.q_proj.lora_b": b,
	})

	// Single-layer model; the layer-99 weights cannot resolve.
	qwen := &Qwen3Model{
		Layers: []*Qwen3DecoderLayer{
			{
				Attention: &Qwen3Attention{
					QProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
				},
			},
		},
	}
	err := applyLoadedLoRA(qwen, dir)
	if err == nil {
		t.Fatal("expected error when no layers are injected")
	}
}
|
||||
|
||||
// TestApplyLoadedLoRA_Good_ForwardProducesOutput validates that a model with a
// loaded LoRA adapter produces different output than the base model alone.
func TestApplyLoadedLoRA_Good_ForwardProducesOutput(t *testing.T) {
	// Create base linear [4, 8].
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	linear := NewLinear(w, nil)

	// Compute base output.
	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	Materialize(x)
	baseOut := linear.Forward(x)
	Materialize(baseOut)
	baseFloats := baseOut.Floats()

	// Create and save non-trivial adapter weights.
	rank := 4
	loraA := RandomNormal(0, 0.1, []int32{int32(rank), 8}, DTypeFloat32)
	loraB := RandomNormal(0, 0.1, []int32{4, int32(rank)}, DTypeFloat32)
	Materialize(loraA, loraB)

	adapterDir := t.TempDir()
	SaveSafetensors(filepath.Join(adapterDir, "adapters.safetensors"), map[string]*Array{
		"layers.0.self_attn.q_proj.lora_a": loraA,
		"layers.0.self_attn.q_proj.lora_b": loraB,
	})
	os.WriteFile(filepath.Join(adapterDir, "adapter_config.json"),
		[]byte(`{"rank": 4, "alpha": 8.0}`), 0644)

	// Build a model and apply adapter.
	qwen := &Qwen3Model{
		Layers: []*Qwen3DecoderLayer{
			{
				Attention: &Qwen3Attention{
					QProj: linear,
					KProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
					VProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
					OProj: NewLinear(RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32), nil),
				},
			},
		},
	}

	err := applyLoadedLoRA(qwen, adapterDir)
	if err != nil {
		t.Fatalf("applyLoadedLoRA: %v", err)
	}

	// Now forward should go through LoRA path.
	loraOut := linear.Forward(x)
	Materialize(loraOut)
	loraFloats := loraOut.Floats()

	// Outputs should differ since B is non-zero.
	// Element-wise compare; a single differing element is enough to pass.
	allSame := true
	for i := range baseFloats {
		if math.Abs(float64(baseFloats[i]-loraFloats[i])) > 1e-6 {
			allSame = false
			break
		}
	}
	if allSame {
		t.Error("expected LoRA output to differ from base output with non-zero B weights")
	}
}
|
||||
|
||||
// --- LoadAndInit with adapter ---
|
||||
|
||||
// TestLoadAndInit_Bad_AdapterMissing expects model loading to fail when the
// configured adapter path does not exist.
func TestLoadAndInit_Bad_AdapterMissing(t *testing.T) {
	dir := t.TempDir()
	writeMinimalConfig(t, dir, "qwen3")
	writeMinimalTokenizer(t, dir)

	// Create a minimal safetensors file so model loading proceeds.
	// The adapter path doesn't exist, so it should fail at the adapter step.
	_, err := LoadAndInit(dir, LoadConfig{AdapterPath: "/nonexistent/adapter"})
	if err == nil {
		t.Fatal("expected error for missing adapter")
	}
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ func TestAdamW_BasicStep(t *testing.T) {
|
|||
|
||||
opt := NewAdamW(0.1)
|
||||
|
||||
for i := 0; i < 300; i++ {
|
||||
for i := range 300 {
|
||||
// Gradient of x^2 is 2x
|
||||
lossFn := func(inputs []*Array) []*Array {
|
||||
p := inputs[0]
|
||||
|
|
@ -48,7 +48,7 @@ func TestAdamW_MultiParam(t *testing.T) {
|
|||
|
||||
opt := NewAdamW(0.1)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
lossFn := func(inputs []*Array) []*Array {
|
||||
return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))}
|
||||
}
|
||||
|
|
@ -85,7 +85,7 @@ func TestAdamW_WeightDecay(t *testing.T) {
|
|||
zeroGrad := FromValue(float32(0.0))
|
||||
Materialize(zeroGrad)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
updated := opt.Step([]*Array{x}, []*Array{zeroGrad})
|
||||
x = updated[0]
|
||||
Materialize(x)
|
||||
|
|
@ -137,7 +137,7 @@ func TestAdamW_WithLoRA(t *testing.T) {
|
|||
|
||||
var initialLoss, finalLoss float64
|
||||
|
||||
for step := 0; step < 50; step++ {
|
||||
for step := range 50 {
|
||||
lossFn := func(inputs []*Array) []*Array {
|
||||
lora.A = inputs[0]
|
||||
lora.B = inputs[1]
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -44,10 +45,10 @@ type Qwen3Model struct {
|
|||
// Qwen3DecoderLayer is a single transformer block.
|
||||
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
|
||||
type Qwen3DecoderLayer struct {
|
||||
InputNorm *RMSNormModule // Pre-attention norm
|
||||
InputNorm *RMSNormModule // Pre-attention norm
|
||||
PostAttnNorm *RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
|
||||
Attention *Qwen3Attention
|
||||
MLP *Qwen3MLP
|
||||
Attention *Qwen3Attention
|
||||
MLP *Qwen3MLP
|
||||
}
|
||||
|
||||
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
|
||||
|
|
@ -136,9 +137,7 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
|||
return nil, fmt.Errorf("qwen3: no .safetensors files found in %s", modelPath)
|
||||
}
|
||||
for _, path := range matches {
|
||||
for name, arr := range LoadSafetensors(path) {
|
||||
weights[name] = arr
|
||||
}
|
||||
maps.Insert(weights, LoadSafetensors(path))
|
||||
if err := lastError(); err != nil {
|
||||
return nil, fmt.Errorf("qwen3: load weights %s: %w", filepath.Base(path), err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
|
|||
|
||||
// Non-self-mapping: control chars, space, DEL, and gaps
|
||||
n := 0
|
||||
for b := 0; b < 256; b++ {
|
||||
for b := range 256 {
|
||||
if _, ok := encoder[byte(b)]; !ok {
|
||||
r := rune(256 + n)
|
||||
encoder[byte(b)] = r
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ func TestBuildGPT2ByteMaps(t *testing.T) {
|
|||
}
|
||||
|
||||
// Round-trip: every byte should survive encode → decode
|
||||
for b := 0; b < 256; b++ {
|
||||
for b := range 256 {
|
||||
r := encoder[byte(b)]
|
||||
got := decoder[r]
|
||||
if got != byte(b) {
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ func TestLoRA_EndToEnd(t *testing.T) {
|
|||
var initialLoss, finalLoss float64
|
||||
const numSteps = 5
|
||||
|
||||
for step := 0; step < numSteps; step++ {
|
||||
for step := range numSteps {
|
||||
// Fresh caches each step (stateful — can't reuse across gradient calls).
|
||||
caches := gemma.NewCache()
|
||||
|
||||
|
|
@ -224,7 +224,7 @@ func TestLoRA_GradientCheckpointing(t *testing.T) {
|
|||
var initialLoss, finalLoss float64
|
||||
const numSteps = 3
|
||||
|
||||
for step := 0; step < numSteps; step++ {
|
||||
for step := range numSteps {
|
||||
caches := gemma.NewCache()
|
||||
|
||||
// Wrap the model forward pass in Checkpoint to recompute activations
|
||||
|
|
@ -326,7 +326,7 @@ func TestLoRA_MixedPrecision(t *testing.T) {
|
|||
var initialLoss, finalLoss float64
|
||||
const numSteps = 5
|
||||
|
||||
for step := 0; step < numSteps; step++ {
|
||||
for step := range numSteps {
|
||||
caches := gemma.NewCache()
|
||||
|
||||
lossFn := func(inputs []*Array) []*Array {
|
||||
|
|
|
|||
|
|
@ -63,7 +63,8 @@ func (b *metalBackend) LoadModel(path string, opts ...inference.LoadOption) (inf
|
|||
slog.Warn("mlx: GPULayers=0 ignored — Metal always uses full GPU offload")
|
||||
}
|
||||
m, err := metal.LoadAndInit(path, metal.LoadConfig{
|
||||
ContextLen: cfg.ContextLen,
|
||||
ContextLen: cfg.ContextLen,
|
||||
AdapterPath: cfg.AdapterPath,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue