fix(metal): error handling audit — propagate MLX errors instead of swallowing

Replace checkError() log+swallow with lastError() that returns real MLX
error messages. Add Eval/EvalAsync as error-returning variants of
Materialize. Generate loop now propagates GPU errors via model.Err().
LoadAllSafetensors returns (map, error). Model loaders check lastError()
after safetensors load. 180 tests passing.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Snider 2026-02-19 20:59:37 +00:00
parent ff01175a62
commit 754d6e2f93
11 changed files with 207 additions and 41 deletions

@ -415,3 +415,42 @@ Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses
3. **SDPA efficient**: 32-head seq=512 attention at 1.8ms is practical for real-time inference.
4. **Sampling overhead**: Full chain (TopP+MinP+TopK) adds ~560μs over greedy — acceptable per token.
5. **Linear layer**: Single-token forward through 2048→2048 at 181μs suggests ~5500 layers/sec ceiling for per-token decode.
---
## 2026-02-19: Error Handling Audit — COMPLETED
### What changed
The old `checkError()` function logged errors via `slog.Error` and swallowed them. The mlx-c error handler (`mlx_go_error_handler`) stored the error string in a C static variable, but no Go code read it back as an error value.
### New error model
1. **`lastError() error`** — reads and clears the C-level error string. Returns `fmt.Errorf("mlx: %s", msg)` or nil. All callers now get real MLX error messages instead of generic "failed" strings.
2. **`Eval(...*Array) error`** — error-returning variant of Materialize. Checks the `mlx_eval` return code and surfaces errors through `lastError()`. This is the main place where GPU computation failures (OOM, invalid graph, etc.) become Go errors.
3. **`EvalAsync(...*Array) error`** — same for async evaluation.
4. **`Materialize()` signature unchanged** — it now delegates to `Eval()` and logs any error instead of returning it. The 237 existing callsites (mostly tests) are untouched.
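
The read-and-clear model in the list above can be sketched in plain Go without cgo. This is only an illustration of the control flow: `errorSlot`, `record`, `takeError` and `evalLike` are hypothetical names standing in for the C static, the error handler, `lastError()` and `Eval()` respectively.

```go
package main

import "fmt"

// errorSlot stands in for the C static that the mlx-c error handler writes to.
// Hypothetical type, used only for this sketch.
type errorSlot struct {
	msg *string
}

// record mimics the error handler storing the most recent message.
func (s *errorSlot) record(msg string) { s.msg = &msg }

// takeError mirrors lastError(): return the stored message as an error and
// clear the slot so a later call cannot report it again.
func (s *errorSlot) takeError() error {
	if s.msg == nil {
		return nil
	}
	err := fmt.Errorf("mlx: %s", *s.msg)
	s.msg = nil
	return err
}

// evalLike mirrors the shape of Eval(): check the return code first, prefer
// the stored message, and fall back to a generic error that carries rc.
func evalLike(s *errorSlot, rc int) error {
	if rc == 0 {
		return nil
	}
	if err := s.takeError(); err != nil {
		return err
	}
	return fmt.Errorf("mlx: eval failed (rc=%d)", rc)
}

func main() {
	slot := &errorSlot{}
	slot.record("out of memory")
	fmt.Println(evalLike(slot, 1)) // stored message is used and cleared
	fmt.Println(evalLike(slot, 1)) // slot is empty, so the rc fallback is used
	fmt.Println(evalLike(slot, 0)) // rc == 0 means success
}
```

The fallback branch matters because a non-zero return code with an empty slot would otherwise produce a nil error for a failed call.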
### Error propagation paths
| Path | Before | After |
|------|--------|-------|
| Generate loop | Silent failure | `Eval()` → `m.lastErr` → `model.Err()` |
| Safetensors load | Silent zero results | `LoadAllSafetensors` returns error; `LoadSafetensors` checks rc |
| Model load (gemma3/qwen3) | Missing weights panic | Check `lastError()` after each safetensors file |
| VJP/JVP/ValueAndGrad | Generic "vjp failed" | Real MLX error message via `lastError()` |
| LoRA save | Generic "save failed" | Real MLX error message via `lastError()` |
### C-level changes
- Removed `fprintf(stderr, ...)` from error handler — errors now surface through Go, not stderr
- Added `get_and_clear_last_error()` — reads the stored error string and clears it in the same call (a plain read-then-reset of the static pointer, with no locking)
- Old `get_last_error()` removed (never cleared, so stale errors persisted across operations)
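
Why the non-clearing getter had to go can be shown with a small Go analogue. `lastMsg`, `peek` and `take` are hypothetical stand-ins for the C static and the two C helpers; the real code lives in the cgo preamble.

```go
package main

import "fmt"

// lastMsg stands in for the C static last_mlx_error; empty means "no error".
var lastMsg string

// peek mimics the removed get_last_error(): it reads without clearing.
func peek() string { return lastMsg }

// take mimics get_and_clear_last_error(): it reads, then resets the slot.
func take() string {
	msg := lastMsg
	lastMsg = ""
	return msg
}

func main() {
	lastMsg = "load failed"
	fmt.Printf("peek 1: %q\n", peek())
	fmt.Printf("peek 2: %q\n", peek()) // same message again: stale for any later operation
	fmt.Printf("take 1: %q\n", take())
	fmt.Printf("take 2: %q\n", take()) // empty: the error was consumed exactly once
}
```

With `peek`, an error from one operation would be blamed on every later operation that checks the slot; with `take`, each stored error is reported at most once.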
### Test results
- 180 tests passing (176 existing + 4 new error handling tests)
- New tests: Eval success, Eval nil safety, lastError no-error, LoadAllSafetensors missing file

@ -38,7 +38,7 @@ Implementation plan: `docs/plans/2026-02-19-backend-abstraction-plan.md`
- [x] **All CGO moved to `internal/metal/`** — 19 source files, 10 test files, 148 tests passing.
- [x] **Public API: `TextModel`, `Backend`, functional options** — Clean root package, compiles on all platforms.
- [x] **Integration tests** — 7 tests for public API (backend registration, options, LoadModel paths).
- [ ] **Error handling audit** — `checkError()` still logs + swallows. Needs conversion to error returns (return code 0/1 + stored error string pattern confirmed by CLion Claude). Low priority — existing behaviour, not a regression.
- [x] **Error handling audit** — ✅ `checkError()` replaced with `lastError() error` (reads + clears C-level error string). Added `Eval(...*Array) error` and `EvalAsync(...*Array) error` as error-returning variants of Materialize. Generate loop propagates errors via `m.lastErr`. `LoadAllSafetensors` returns `(map, error)`. Model loaders (gemma3, qwen3) check `lastError()` after safetensors load. grad.go/lora.go now surface real MLX error messages. 4 new tests in error_test.go.
- [ ] **Memory management — deterministic cleanup** — `Close()` stub in place. CLion Claude confirmed `mlx_array_free()` is safe on graph-referenced arrays (refcounted via shared_ptr). Double-free is UB. Can now implement per-step cleanup.
- [ ] **Documentation** — Public API has godoc but needs examples for common workflows.

@ -0,0 +1,46 @@
//go:build darwin && arm64
package metal
import (
"testing"
)
func TestEval_Success(t *testing.T) {
a := FromValues([]float32{1, 2, 3}, 3)
b := FromValues([]float32{4, 5, 6}, 3)
c := Add(a, b)
if err := Eval(c); err != nil {
t.Fatalf("Eval should succeed: %v", err)
}
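got := c.Floats()
want := []float32{5, 7, 9}
// Guard against a short or long result before comparing element-wise.
if len(got) != len(want) {
t.Fatalf("len(got) = %d, want %d", len(got), len(want))
}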
for i := range got {
if got[i] != want[i] {
t.Errorf("got[%d] = %f, want %f", i, got[i], want[i])
}
}
}
func TestEval_NilArray(t *testing.T) {
// Eval should handle nil arrays gracefully.
if err := Eval(nil); err != nil {
t.Fatalf("Eval(nil) should not error: %v", err)
}
}
func TestLastError_NoError(t *testing.T) {
// When no error has occurred, lastError should return nil.
if err := lastError(); err != nil {
t.Errorf("lastError should be nil when no error occurred, got: %v", err)
}
}
func TestLoadAllSafetensors_MissingFile(t *testing.T) {
_, err := LoadAllSafetensors("/nonexistent/path/model.safetensors")
if err == nil {
t.Fatal("LoadAllSafetensors should fail for missing file")
}
}

@ -182,6 +182,9 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
for name, arr := range LoadSafetensors(path) {
weights[name] = arr
}
if err := lastError(); err != nil {
return nil, fmt.Errorf("gemma3: load weights %s: %w", filepath.Base(path), err)
}
}
// Helper to resolve weight with language_model. prefix fallback

@ -4,6 +4,7 @@ package metal
import (
"context"
"fmt"
"iter"
)
@ -68,7 +69,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
input := FromValues(tokens, len(tokens))
input = Reshape(input, 1, int32(len(tokens)))
logits := m.model.Forward(input, caches)
Materialize(logits)
if err := Eval(logits); err != nil {
m.lastErr = fmt.Errorf("prefill: %w", err)
return
}
for i := 0; i < cfg.MaxTokens; i++ {
select {
@ -82,7 +86,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2)))
next := sampler.Sample(lastPos)
Materialize(next)
if err := Eval(next); err != nil {
m.lastErr = fmt.Errorf("sample step %d: %w", i, err)
return
}
id := int32(next.Int())
@ -105,7 +112,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
nextInput := FromValues([]int32{id}, 1)
nextInput = Reshape(nextInput, 1, 1)
logits = m.model.Forward(nextInput, caches)
Materialize(logits)
if err := Eval(logits); err != nil {
m.lastErr = fmt.Errorf("decode step %d: %w", i, err)
return
}
}
}
}

@ -101,8 +101,10 @@ func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (out
rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: vjp failed")
if err := lastError(); err != nil {
return nil, nil, err
}
return nil, nil, fmt.Errorf("mlx: vjp failed (rc=%d)", rc)
}
outputs = vectorToArrays(outVec)
@ -143,8 +145,10 @@ func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outpu
rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: jvp failed")
if err := lastError(); err != nil {
return nil, nil, err
}
return nil, nil, fmt.Errorf("mlx: jvp failed (rc=%d)", rc)
}
outputs = vectorToArrays(outVec)
@ -209,8 +213,10 @@ func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err e
rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec)
if rc != 0 {
checkError()
return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed")
if err := lastError(); err != nil {
return nil, nil, err
}
return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed (rc=%d)", rc)
}
values = vectorToArrays(valVec)

@ -9,6 +9,7 @@ package metal
import "C"
import (
"fmt"
"iter"
"runtime"
"unsafe"
@ -16,6 +17,7 @@ import (
// LoadSafetensors loads tensors from a .safetensors file, returning an iterator
// over (name, array) pairs. Tensors are loaded lazily on the CPU stream.
// Use [LoadAllSafetensors] for an error-returning variant.
func LoadSafetensors(path string) iter.Seq2[string, *Array] {
Init()
return func(yield func(string, *Array) bool) {
@ -31,7 +33,11 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
cpu := C.mlx_default_cpu_stream_new()
defer C.mlx_stream_free(cpu)
C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
rc := C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
if rc != 0 {
// Error will surface via lastError(); caller iterates zero tensors.
return
}
it := C.mlx_map_string_to_array_iterator_new(string2array)
defer C.mlx_map_string_to_array_iterator_free(it)
@ -54,10 +60,17 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
}
// LoadAllSafetensors loads all tensors from a .safetensors file into a map.
func LoadAllSafetensors(path string) map[string]*Array {
// Returns an error if the file cannot be loaded.
func LoadAllSafetensors(path string) (map[string]*Array, error) {
tensors := make(map[string]*Array)
for name, arr := range LoadSafetensors(path) {
tensors[name] = arr
}
return tensors
if len(tensors) == 0 {
if err := lastError(); err != nil {
return nil, err
}
return nil, fmt.Errorf("mlx: no tensors loaded from %s", path)
}
return tensors, nil
}

@ -230,8 +230,10 @@ func SaveSafetensors(path string, weights map[string]*Array) error {
rc := C.mlx_save_safetensors(cPath, cMap, cMeta)
if rc != 0 {
checkError()
return fmt.Errorf("mlx: save safetensors failed: %s", path)
if err := lastError(); err != nil {
return err
}
return fmt.Errorf("mlx: save safetensors failed: %s (rc=%d)", path, rc)
}
return nil
}

@ -250,7 +250,10 @@ func TestSaveSafetensors(t *testing.T) {
}
// Load it back
loaded := LoadAllSafetensors(path)
loaded, err := LoadAllSafetensors(path)
if err != nil {
t.Fatalf("LoadAllSafetensors: %v", err)
}
Materialize(loaded["layer.lora_a"], loaded["layer.lora_b"])
aLoaded := loaded["layer.lora_a"].Floats()
@ -290,7 +293,10 @@ func TestLoRAAdapter_Save(t *testing.T) {
}
// Load and verify
loaded := LoadAllSafetensors(path)
loaded, err := LoadAllSafetensors(path)
if err != nil {
t.Fatalf("LoadAllSafetensors: %v", err)
}
aKey := "model.layers.0.self_attn.q_proj.lora_a"
bKey := "model.layers.0.self_attn.q_proj.lora_b"

@ -18,7 +18,6 @@ package metal
static const char *last_mlx_error = NULL;
static void mlx_go_error_handler(const char *msg, void *data) {
fprintf(stderr, "MLX ERROR: %s\n", msg);
last_mlx_error = msg;
}
@ -26,13 +25,16 @@ static void set_error_handler() {
mlx_set_error_handler(&mlx_go_error_handler, NULL, NULL);
}
static const char* get_last_error() {
return last_mlx_error;
static const char* get_and_clear_last_error() {
const char *msg = last_mlx_error;
last_mlx_error = NULL;
return msg;
}
*/
import "C"
import (
"fmt"
"log/slog"
"sync"
)
@ -43,29 +45,24 @@ var initOnce sync.Once
func Init() {
initOnce.Do(func() {
C.set_error_handler()
slog.Debug("mlx: initialized with Metal backend")
slog.Debug("mlx: initialised with Metal backend")
})
}
// checkError logs the last MLX error if any occurred.
func checkError() {
if msg := C.get_last_error(); msg != nil {
slog.Error("mlx", "error", C.GoString(msg))
// lastError reads and clears the most recent MLX-C error.
// Returns nil if no error occurred since the last call.
func lastError() error {
msg := C.get_and_clear_last_error()
if msg == nil {
return nil
}
return fmt.Errorf("mlx: %s", C.GoString(msg))
}
// Materialize synchronously evaluates arrays, computing their values on the GPU.
// This is the MLX equivalent of forcing lazy computation to complete.
func Materialize(outputs ...*Array) {
doMaterialize(outputs, false)
}
// MaterializeAsync queues arrays for asynchronous GPU evaluation.
func MaterializeAsync(outputs ...*Array) {
doMaterialize(outputs, true)
}
func doMaterialize(outputs []*Array, async bool) {
// Eval synchronously evaluates arrays on the GPU and returns any error.
// This is the error-returning variant of Materialize. Use it in code paths
// that need to propagate errors (e.g. the generate loop).
func Eval(outputs ...*Array) error {
Init()
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
@ -76,10 +73,51 @@ func doMaterialize(outputs []*Array, async bool) {
}
}
if async {
C.mlx_async_eval(vector)
} else {
C.mlx_eval(vector)
rc := C.mlx_eval(vector)
if rc != 0 {
if err := lastError(); err != nil {
return err
}
return fmt.Errorf("mlx: eval failed (rc=%d)", rc)
}
return nil
}
// EvalAsync queues arrays for asynchronous GPU evaluation and returns any error.
func EvalAsync(outputs ...*Array) error {
Init()
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
if output != nil && output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx)
}
}
rc := C.mlx_async_eval(vector)
if rc != 0 {
if err := lastError(); err != nil {
return err
}
return fmt.Errorf("mlx: async eval failed (rc=%d)", rc)
}
return nil
}
// Materialize synchronously evaluates arrays, computing their values on the GPU.
// Errors are logged but not returned — use [Eval] when error propagation is needed.
func Materialize(outputs ...*Array) {
if err := Eval(outputs...); err != nil {
slog.Error("mlx: materialize", "error", err)
}
}
// MaterializeAsync queues arrays for asynchronous GPU evaluation.
// Errors are logged but not returned — use [EvalAsync] when error propagation is needed.
func MaterializeAsync(outputs ...*Array) {
if err := EvalAsync(outputs...); err != nil {
slog.Error("mlx: materialize async", "error", err)
}
}

@ -122,6 +122,9 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
for name, arr := range LoadSafetensors(path) {
weights[name] = arr
}
if err := lastError(); err != nil {
return nil, fmt.Errorf("qwen3: load weights %s: %w", filepath.Base(path), err)
}
}
w := func(name string) *Array { return resolveWeight(weights, name) }