From 754d6e2f934f95bd3af9bef269da19a6e01210c9 Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 20:59:37 +0000 Subject: [PATCH] =?UTF-8?q?fix(metal):=20error=20handling=20audit=20?= =?UTF-8?q?=E2=80=94=20propagate=20MLX=20errors=20instead=20of=20swallowin?= =?UTF-8?q?g?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace checkError() log+swallow with lastError() that returns real MLX error messages. Add Eval/EvalAsync as error-returning variants of Materialize. Generate loop now propagates GPU errors via model.Err(). LoadAllSafetensors returns (map, error). Model loaders check lastError() after safetensors load. 180 tests passing. Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- FINDINGS.md | 39 ++++++++++++++++ TODO.md | 2 +- internal/metal/error_test.go | 46 +++++++++++++++++++ internal/metal/gemma3.go | 3 ++ internal/metal/generate.go | 16 +++++-- internal/metal/grad.go | 18 +++++--- internal/metal/io.go | 19 ++++++-- internal/metal/lora.go | 6 ++- internal/metal/lora_test.go | 10 ++++- internal/metal/metal.go | 86 ++++++++++++++++++++++++++---------- internal/metal/qwen3.go | 3 ++ 11 files changed, 207 insertions(+), 41 deletions(-) create mode 100644 internal/metal/error_test.go diff --git a/FINDINGS.md b/FINDINGS.md index 05b6835..f83ad9e 100644 --- a/FINDINGS.md +++ b/FINDINGS.md @@ -415,3 +415,42 @@ Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses 3. **SDPA efficient**: 32-head seq=512 attention at 1.8ms is practical for real-time inference. 4. **Sampling overhead**: Full chain (TopP+MinP+TopK) adds ~560μs over greedy — acceptable per token. 5. **Linear layer**: Single-token forward through 2048→2048 at 181μs suggests ~5500 layers/sec ceiling for per-token decode. + +--- + +## 2026-02-19: Error Handling Audit — COMPLETED + +### What changed + +The old `checkError()` function logged errors via `slog.Error` and swallowed them. 
The mlx-c error handler (`mlx_go_error_handler`) stored the error string in a C static variable, but no Go code read it back as an error value. + +### New error model + +1. **`lastError() error`** — reads and clears the C-level error string. Returns `fmt.Errorf("mlx: %s", msg)` or nil. All callers now get real MLX error messages instead of generic "failed" strings. + +2. **`Eval(...*Array) error`** — error-returning variant of Materialize. Checks `mlx_eval` return code and surfaces errors through `lastError()`. The key error surface for GPU computation failures (OOM, invalid graph, etc.). + +3. **`EvalAsync(...*Array) error`** — same for async evaluation. + +4. **`Materialize()` signature unchanged** — now delegates to `Eval()` and logs any error instead of returning it. 237 existing callsites (mostly tests) unchanged. + +### Error propagation paths + +| Path | Before | After | +|------|--------|-------| +| Generate loop | Silent failure | `Eval()` → `m.lastErr` → `model.Err()` | +| Safetensors load | Silent zero results | `LoadAllSafetensors` returns error; `LoadSafetensors` checks rc | +| Model load (gemma3/qwen3) | Missing weights panic | Check `lastError()` after each safetensors file | +| VJP/JVP/ValueAndGrad | Generic "vjp failed" | Real MLX error message via `lastError()` | +| LoRA save | Generic "save failed" | Real MLX error message via `lastError()` | + +### C-level changes + +- Removed `fprintf(stderr, ...)` from error handler — errors now surface through Go, not stderr +- Added `get_and_clear_last_error()` — reads the stored error string, then resets it to NULL (two plain assignments; no locking or atomics, so not safe for concurrent callers) +- Old `get_last_error()` removed (never cleared, so stale errors persisted across operations) + +### Test results + +- 180 tests passing (176 existing + 4 new error handling tests) +- New tests: Eval success, Eval nil safety, lastError no-error, LoadAllSafetensors missing file diff --git a/TODO.md b/TODO.md index a14fbfa..9f0ae16 100644 --- a/TODO.md +++ b/TODO.md @@ -38,7 +38,7 @@ Implementation plan: 
`docs/plans/2026-02-19-backend-abstraction-plan.md` - [x] **All CGO moved to `internal/metal/`** — 19 source files, 10 test files, 148 tests passing. - [x] **Public API: `TextModel`, `Backend`, functional options** — Clean root package, compiles on all platforms. - [x] **Integration tests** — 7 tests for public API (backend registration, options, LoadModel paths). -- [ ] **Error handling audit** — `checkError()` still logs + swallows. Needs conversion to error returns (return code 0/1 + stored error string pattern confirmed by CLion Claude). Low priority — existing behaviour, not a regression. +- [x] **Error handling audit** — ✅ `checkError()` replaced with `lastError() error` (reads + clears C-level error string). Added `Eval(...*Array) error` and `EvalAsync(...*Array) error` as error-returning variants of Materialize. Generate loop propagates errors via `m.lastErr`. `LoadAllSafetensors` returns `(map, error)`. Model loaders (gemma3, qwen3) check `lastError()` after safetensors load. grad.go/lora.go now surface real MLX error messages. 4 new tests in error_test.go. - [ ] **Memory management — deterministic cleanup** — `Close()` stub in place. CLion Claude confirmed `mlx_array_free()` is safe on graph-referenced arrays (refcounted via shared_ptr). Double-free is UB. Can now implement per-step cleanup. - [ ] **Documentation** — Public API has godoc but needs examples for common workflows. 
diff --git a/internal/metal/error_test.go b/internal/metal/error_test.go new file mode 100644 index 0000000..ba26906 --- /dev/null +++ b/internal/metal/error_test.go @@ -0,0 +1,46 @@ +//go:build darwin && arm64 + +package metal + +import ( + "testing" +) + +func TestEval_Success(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 3) + b := FromValues([]float32{4, 5, 6}, 3) + c := Add(a, b) + + if err := Eval(c); err != nil { + t.Fatalf("Eval should succeed: %v", err) + } + + got := c.Floats() + want := []float32{5, 7, 9} + for i := range got { + if got[i] != want[i] { + t.Errorf("got[%d] = %f, want %f", i, got[i], want[i]) + } + } +} + +func TestEval_NilArray(t *testing.T) { + // Eval should handle nil arrays gracefully. + if err := Eval(nil); err != nil { + t.Fatalf("Eval(nil) should not error: %v", err) + } +} + +func TestLastError_NoError(t *testing.T) { + // When no error has occurred, lastError should return nil. + if err := lastError(); err != nil { + t.Errorf("lastError should be nil when no error occurred, got: %v", err) + } +} + +func TestLoadAllSafetensors_MissingFile(t *testing.T) { + _, err := LoadAllSafetensors("/nonexistent/path/model.safetensors") + if err == nil { + t.Fatal("LoadAllSafetensors should fail for missing file") + } +} diff --git a/internal/metal/gemma3.go b/internal/metal/gemma3.go index 6429fae..9da21fc 100644 --- a/internal/metal/gemma3.go +++ b/internal/metal/gemma3.go @@ -182,6 +182,9 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) { for name, arr := range LoadSafetensors(path) { weights[name] = arr } + if err := lastError(); err != nil { + return nil, fmt.Errorf("gemma3: load weights %s: %w", filepath.Base(path), err) + } } // Helper to resolve weight with language_model. 
prefix fallback diff --git a/internal/metal/generate.go b/internal/metal/generate.go index 3b6cd6a..7c090e5 100644 --- a/internal/metal/generate.go +++ b/internal/metal/generate.go @@ -4,6 +4,7 @@ package metal import ( "context" + "fmt" "iter" ) @@ -68,7 +69,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) input := FromValues(tokens, len(tokens)) input = Reshape(input, 1, int32(len(tokens))) logits := m.model.Forward(input, caches) - Materialize(logits) + if err := Eval(logits); err != nil { + m.lastErr = fmt.Errorf("prefill: %w", err) + return + } for i := 0; i < cfg.MaxTokens; i++ { select { @@ -82,7 +86,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1))) lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2))) next := sampler.Sample(lastPos) - Materialize(next) + if err := Eval(next); err != nil { + m.lastErr = fmt.Errorf("sample step %d: %w", i, err) + return + } id := int32(next.Int()) @@ -105,7 +112,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) nextInput := FromValues([]int32{id}, 1) nextInput = Reshape(nextInput, 1, 1) logits = m.model.Forward(nextInput, caches) - Materialize(logits) + if err := Eval(logits); err != nil { + m.lastErr = fmt.Errorf("decode step %d: %w", i, err) + return + } } } } diff --git a/internal/metal/grad.go b/internal/metal/grad.go index 55a588c..0199cce 100644 --- a/internal/metal/grad.go +++ b/internal/metal/grad.go @@ -101,8 +101,10 @@ func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (out rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec) if rc != 0 { - checkError() - return nil, nil, fmt.Errorf("mlx: vjp failed") + if err := lastError(); err != nil { + return nil, nil, err + } + return nil, nil, fmt.Errorf("mlx: vjp failed (rc=%d)", rc) } outputs = vectorToArrays(outVec) @@ -143,8 +145,10 @@ func 
JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outpu rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec) if rc != 0 { - checkError() - return nil, nil, fmt.Errorf("mlx: jvp failed") + if err := lastError(); err != nil { + return nil, nil, err + } + return nil, nil, fmt.Errorf("mlx: jvp failed (rc=%d)", rc) } outputs = vectorToArrays(outVec) @@ -209,8 +213,10 @@ func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err e rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec) if rc != 0 { - checkError() - return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed") + if err := lastError(); err != nil { + return nil, nil, err + } + return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed (rc=%d)", rc) } values = vectorToArrays(valVec) diff --git a/internal/metal/io.go b/internal/metal/io.go index 8698dc9..73f08b6 100644 --- a/internal/metal/io.go +++ b/internal/metal/io.go @@ -9,6 +9,7 @@ package metal import "C" import ( + "fmt" "iter" "runtime" "unsafe" @@ -16,6 +17,7 @@ import ( // LoadSafetensors loads tensors from a .safetensors file, returning an iterator // over (name, array) pairs. Tensors are loaded lazily on the CPU stream. +// Use [LoadAllSafetensors] for an error-returning variant. func LoadSafetensors(path string) iter.Seq2[string, *Array] { Init() return func(yield func(string, *Array) bool) { @@ -31,7 +33,11 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] { cpu := C.mlx_default_cpu_stream_new() defer C.mlx_stream_free(cpu) - C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu) + rc := C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu) + if rc != 0 { + // Error will surface via lastError(); caller iterates zero tensors. 
+ return + } it := C.mlx_map_string_to_array_iterator_new(string2array) defer C.mlx_map_string_to_array_iterator_free(it) @@ -54,10 +60,17 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] { } // LoadAllSafetensors loads all tensors from a .safetensors file into a map. -func LoadAllSafetensors(path string) map[string]*Array { +// Returns an error if the file cannot be loaded. +func LoadAllSafetensors(path string) (map[string]*Array, error) { tensors := make(map[string]*Array) for name, arr := range LoadSafetensors(path) { tensors[name] = arr } - return tensors + if len(tensors) == 0 { + if err := lastError(); err != nil { + return nil, err + } + return nil, fmt.Errorf("mlx: no tensors loaded from %s", path) + } + return tensors, nil } diff --git a/internal/metal/lora.go b/internal/metal/lora.go index 9f9a423..9f649ee 100644 --- a/internal/metal/lora.go +++ b/internal/metal/lora.go @@ -230,8 +230,10 @@ func SaveSafetensors(path string, weights map[string]*Array) error { rc := C.mlx_save_safetensors(cPath, cMap, cMeta) if rc != 0 { - checkError() - return fmt.Errorf("mlx: save safetensors failed: %s", path) + if err := lastError(); err != nil { + return err + } + return fmt.Errorf("mlx: save safetensors failed: %s (rc=%d)", path, rc) } return nil } diff --git a/internal/metal/lora_test.go b/internal/metal/lora_test.go index 10b16e8..64828ac 100644 --- a/internal/metal/lora_test.go +++ b/internal/metal/lora_test.go @@ -250,7 +250,10 @@ func TestSaveSafetensors(t *testing.T) { } // Load it back - loaded := LoadAllSafetensors(path) + loaded, err := LoadAllSafetensors(path) + if err != nil { + t.Fatalf("LoadAllSafetensors: %v", err) + } Materialize(loaded["layer.lora_a"], loaded["layer.lora_b"]) aLoaded := loaded["layer.lora_a"].Floats() @@ -290,7 +293,10 @@ func TestLoRAAdapter_Save(t *testing.T) { } // Load and verify - loaded := LoadAllSafetensors(path) + loaded, err := LoadAllSafetensors(path) + if err != nil { + t.Fatalf("LoadAllSafetensors: %v", err) + 
} aKey := "model.layers.0.self_attn.q_proj.lora_a" bKey := "model.layers.0.self_attn.q_proj.lora_b" diff --git a/internal/metal/metal.go b/internal/metal/metal.go index d997407..818ceef 100644 --- a/internal/metal/metal.go +++ b/internal/metal/metal.go @@ -18,7 +18,6 @@ package metal static const char *last_mlx_error = NULL; static void mlx_go_error_handler(const char *msg, void *data) { - fprintf(stderr, "MLX ERROR: %s\n", msg); last_mlx_error = msg; } @@ -26,13 +25,16 @@ static void set_error_handler() { mlx_set_error_handler(&mlx_go_error_handler, NULL, NULL); } -static const char* get_last_error() { - return last_mlx_error; +static const char* get_and_clear_last_error() { + const char *msg = last_mlx_error; + last_mlx_error = NULL; + return msg; } */ import "C" import ( + "fmt" "log/slog" "sync" ) @@ -43,29 +45,24 @@ var initOnce sync.Once func Init() { initOnce.Do(func() { C.set_error_handler() - slog.Debug("mlx: initialized with Metal backend") + slog.Debug("mlx: initialised with Metal backend") }) } -// checkError logs the last MLX error if any occurred. -func checkError() { - if msg := C.get_last_error(); msg != nil { - slog.Error("mlx", "error", C.GoString(msg)) +// lastError reads and clears the most recent MLX-C error. +// Returns nil if no error occurred since the last call. +func lastError() error { + msg := C.get_and_clear_last_error() + if msg == nil { + return nil } + return fmt.Errorf("mlx: %s", C.GoString(msg)) } -// Materialize synchronously evaluates arrays, computing their values on the GPU. -// This is the MLX equivalent of forcing lazy computation to complete. -func Materialize(outputs ...*Array) { - doMaterialize(outputs, false) -} - -// MaterializeAsync queues arrays for asynchronous GPU evaluation. -func MaterializeAsync(outputs ...*Array) { - doMaterialize(outputs, true) -} - -func doMaterialize(outputs []*Array, async bool) { +// Eval synchronously evaluates arrays on the GPU and returns any error. 
+// This is the error-returning variant of Materialize. Use it in code paths +// that need to propagate errors (e.g. the generate loop). +func Eval(outputs ...*Array) error { Init() vector := C.mlx_vector_array_new() defer C.mlx_vector_array_free(vector) @@ -76,10 +73,51 @@ func doMaterialize(outputs []*Array, async bool) { } } - if async { - C.mlx_async_eval(vector) - } else { - C.mlx_eval(vector) + rc := C.mlx_eval(vector) + if rc != 0 { + if err := lastError(); err != nil { + return err + } + return fmt.Errorf("mlx: eval failed (rc=%d)", rc) + } + return nil +} + +// EvalAsync queues arrays for asynchronous GPU evaluation and returns any error. +func EvalAsync(outputs ...*Array) error { + Init() + vector := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vector) + + for _, output := range outputs { + if output != nil && output.Valid() { + C.mlx_vector_array_append_value(vector, output.ctx) + } + } + + rc := C.mlx_async_eval(vector) + if rc != 0 { + if err := lastError(); err != nil { + return err + } + return fmt.Errorf("mlx: async eval failed (rc=%d)", rc) + } + return nil +} + +// Materialize synchronously evaluates arrays, computing their values on the GPU. +// Errors are logged but not returned — use [Eval] when error propagation is needed. +func Materialize(outputs ...*Array) { + if err := Eval(outputs...); err != nil { + slog.Error("mlx: materialize", "error", err) + } +} + +// MaterializeAsync queues arrays for asynchronous GPU evaluation. +// Errors are logged but not returned — use [EvalAsync] when error propagation is needed. 
+func MaterializeAsync(outputs ...*Array) { + if err := EvalAsync(outputs...); err != nil { + slog.Error("mlx: materialize async", "error", err) } } diff --git a/internal/metal/qwen3.go b/internal/metal/qwen3.go index e288d0a..33f32c5 100644 --- a/internal/metal/qwen3.go +++ b/internal/metal/qwen3.go @@ -122,6 +122,9 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { for name, arr := range LoadSafetensors(path) { weights[name] = arr } + if err := lastError(); err != nil { + return nil, fmt.Errorf("qwen3: load weights %s: %w", filepath.Base(path), err) + } } w := func(name string) *Array { return resolveWeight(weights, name) }