fix(metal): error handling audit — propagate MLX errors instead of swallowing
Replace checkError() log+swallow with lastError() that returns real MLX error messages. Add Eval/EvalAsync as error-returning variants of Materialize. Generate loop now propagates GPU errors via model.Err(). LoadAllSafetensors returns (map, error). Model loaders check lastError() after safetensors load. 180 tests passing.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent ff01175a62
commit 754d6e2f93
11 changed files with 207 additions and 41 deletions
FINDINGS.md (39 lines changed)
@@ -415,3 +415,42 @@ Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses
 3. **SDPA efficient**: 32-head seq=512 attention at 1.8ms is practical for real-time inference.
 4. **Sampling overhead**: Full chain (TopP+MinP+TopK) adds ~560μs over greedy — acceptable per token.
 5. **Linear layer**: Single-token forward through 2048→2048 at 181μs suggests ~5500 layers/sec ceiling for per-token decode.
+
+---
+
+## 2026-02-19: Error Handling Audit — COMPLETED
+
+### What changed
+
+The old `checkError()` function logged errors via `slog.Error` and swallowed them. The mlx-c error handler (`mlx_go_error_handler`) stored the error string in a C static variable, but no Go code read it back as an error value.
+
+### New error model
+
+1. **`lastError() error`** — reads and clears the C-level error string. Returns `fmt.Errorf("mlx: %s", msg)` or nil. All callers now get real MLX error messages instead of generic "failed" strings.
+
+2. **`Eval(...*Array) error`** — error-returning variant of Materialize. Checks `mlx_eval` return code and surfaces errors through `lastError()`. The key error surface for GPU computation failures (OOM, invalid graph, etc.).
+
+3. **`EvalAsync(...*Array) error`** — same for async evaluation.
+
+4. **`Materialize()` unchanged** — delegates to `Eval()`, logs errors. 237 existing callsites (mostly tests) unchanged.
+
+### Error propagation paths
+
+| Path | Before | After |
+|------|--------|-------|
+| Generate loop | Silent failure | `Eval()` → `m.lastErr` → `model.Err()` |
+| Safetensors load | Silent zero results | `LoadAllSafetensors` returns error; `LoadSafetensors` checks rc |
+| Model load (gemma3/qwen3) | Missing weights panic | Check `lastError()` after each safetensors file |
+| VJP/JVP/ValueAndGrad | Generic "vjp failed" | Real MLX error message via `lastError()` |
+| LoRA save | Generic "save failed" | Real MLX error message via `lastError()` |
+
+### C-level changes
+
+- Removed `fprintf(stderr, ...)` from error handler — errors now surface through Go, not stderr
+- Added `get_and_clear_last_error()` — reads and atomically clears the stored error string
+- Old `get_last_error()` removed (never cleared, so stale errors persisted across operations)
+
+### Test results
+
+- 180 tests passing (176 existing + 4 new error handling tests)
+- New tests: Eval success, Eval nil safety, lastError no-error, LoadAllSafetensors missing file
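The read-and-clear behaviour that the FINDINGS.md entry describes for `lastError()` and `get_and_clear_last_error()` can be illustrated in pure Go, without cgo. The following is a minimal sketch, not the package's code: `errorSlot`, `record` and `take` are hypothetical names, and the mutex is an assumption added for the sketch (the C helper in this commit uses a plain static pointer).

```go
package main

import (
	"fmt"
	"sync"
)

// errorSlot is a hypothetical stand-in for the C-level static error string.
type errorSlot struct {
	mu  sync.Mutex
	msg string
	set bool
}

// record plays the role of the error handler: it stores the latest message.
func (s *errorSlot) record(msg string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.msg, s.set = msg, true
}

// take mirrors the idea behind lastError(): it returns the stored message as
// an error and clears it, so a later call reports nil unless a new error was
// recorded in between.
func (s *errorSlot) take() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.set {
		return nil
	}
	err := fmt.Errorf("mlx: %s", s.msg)
	s.msg, s.set = "", false
	return err
}

func main() {
	var slot errorSlot
	fmt.Println(slot.take()) // <nil>
	slot.record("out of memory")
	fmt.Println(slot.take()) // mlx: out of memory
	fmt.Println(slot.take()) // <nil>
}
```

The third call returning nil is the property the removed `get_last_error()` lacked: without clearing, the same message would be reported again for an unrelated later operation.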
TODO.md (2 lines changed)
@@ -38,7 +38,7 @@ Implementation plan: `docs/plans/2026-02-19-backend-abstraction-plan.md`
 - [x] **All CGO moved to `internal/metal/`** — 19 source files, 10 test files, 148 tests passing.
 - [x] **Public API: `TextModel`, `Backend`, functional options** — Clean root package, compiles on all platforms.
 - [x] **Integration tests** — 7 tests for public API (backend registration, options, LoadModel paths).
-- [ ] **Error handling audit** — `checkError()` still logs + swallows. Needs conversion to error returns (return code 0/1 + stored error string pattern confirmed by CLion Claude). Low priority — existing behaviour, not a regression.
+- [x] **Error handling audit** — ✅ `checkError()` replaced with `lastError() error` (reads + clears C-level error string). Added `Eval(...*Array) error` and `EvalAsync(...*Array) error` as error-returning variants of Materialize. Generate loop propagates errors via `m.lastErr`. `LoadAllSafetensors` returns `(map, error)`. Model loaders (gemma3, qwen3) check `lastError()` after safetensors load. grad.go/lora.go now surface real MLX error messages. 4 new tests in error_test.go.
 - [ ] **Memory management — deterministic cleanup** — `Close()` stub in place. CLion Claude confirmed `mlx_array_free()` is safe on graph-referenced arrays (refcounted via shared_ptr). Double-free is UB. Can now implement per-step cleanup.
 - [ ] **Documentation** — Public API has godoc but needs examples for common workflows.
internal/metal/error_test.go (new file, 46 lines added)
@@ -0,0 +1,46 @@
+//go:build darwin && arm64
+
+package metal
+
+import (
+	"testing"
+)
+
+func TestEval_Success(t *testing.T) {
+	a := FromValues([]float32{1, 2, 3}, 3)
+	b := FromValues([]float32{4, 5, 6}, 3)
+	c := Add(a, b)
+
+	if err := Eval(c); err != nil {
+		t.Fatalf("Eval should succeed: %v", err)
+	}
+
+	got := c.Floats()
+	want := []float32{5, 7, 9}
+	for i := range got {
+		if got[i] != want[i] {
+			t.Errorf("got[%d] = %f, want %f", i, got[i], want[i])
+		}
+	}
+}
+
+func TestEval_NilArray(t *testing.T) {
+	// Eval should handle nil arrays gracefully.
+	if err := Eval(nil); err != nil {
+		t.Fatalf("Eval(nil) should not error: %v", err)
+	}
+}
+
+func TestLastError_NoError(t *testing.T) {
+	// When no error has occurred, lastError should return nil.
+	if err := lastError(); err != nil {
+		t.Errorf("lastError should be nil when no error occurred, got: %v", err)
+	}
+}
+
+func TestLoadAllSafetensors_MissingFile(t *testing.T) {
+	_, err := LoadAllSafetensors("/nonexistent/path/model.safetensors")
+	if err == nil {
+		t.Fatal("LoadAllSafetensors should fail for missing file")
+	}
+}

@@ -182,6 +182,9 @@ func LoadGemma3(modelPath string) (*GemmaModel, error) {
 		for name, arr := range LoadSafetensors(path) {
 			weights[name] = arr
 		}
+		if err := lastError(); err != nil {
+			return nil, fmt.Errorf("gemma3: load weights %s: %w", filepath.Base(path), err)
+		}
 	}

 	// Helper to resolve weight with language_model. prefix fallback

@@ -4,6 +4,7 @@ package metal

 import (
 	"context"
+	"fmt"
 	"iter"
 )

@@ -68,7 +69,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 		input := FromValues(tokens, len(tokens))
 		input = Reshape(input, 1, int32(len(tokens)))
 		logits := m.model.Forward(input, caches)
-		Materialize(logits)
+		if err := Eval(logits); err != nil {
+			m.lastErr = fmt.Errorf("prefill: %w", err)
+			return
+		}

 		for i := 0; i < cfg.MaxTokens; i++ {
 			select {
@@ -82,7 +86,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 			lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
 			lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2)))
 			next := sampler.Sample(lastPos)
-			Materialize(next)
+			if err := Eval(next); err != nil {
+				m.lastErr = fmt.Errorf("sample step %d: %w", i, err)
+				return
+			}

 			id := int32(next.Int())

@@ -105,7 +112,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 			nextInput := FromValues([]int32{id}, 1)
 			nextInput = Reshape(nextInput, 1, 1)
 			logits = m.model.Forward(nextInput, caches)
-			Materialize(logits)
+			if err := Eval(logits); err != nil {
+				m.lastErr = fmt.Errorf("decode step %d: %w", i, err)
+				return
+			}
 		}
 	}
 }

@@ -101,8 +101,10 @@ func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (out

 	rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec)
 	if rc != 0 {
-		checkError()
-		return nil, nil, fmt.Errorf("mlx: vjp failed")
+		if err := lastError(); err != nil {
+			return nil, nil, err
+		}
+		return nil, nil, fmt.Errorf("mlx: vjp failed (rc=%d)", rc)
 	}

 	outputs = vectorToArrays(outVec)
@@ -143,8 +145,10 @@ func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outpu

 	rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec)
 	if rc != 0 {
-		checkError()
-		return nil, nil, fmt.Errorf("mlx: jvp failed")
+		if err := lastError(); err != nil {
+			return nil, nil, err
+		}
+		return nil, nil, fmt.Errorf("mlx: jvp failed (rc=%d)", rc)
 	}

 	outputs = vectorToArrays(outVec)
@@ -209,8 +213,10 @@ func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err e

 	rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec)
 	if rc != 0 {
-		checkError()
-		return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed")
+		if err := lastError(); err != nil {
+			return nil, nil, err
+		}
+		return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed (rc=%d)", rc)
 	}

 	values = vectorToArrays(valVec)

@@ -9,6 +9,7 @@ package metal
 import "C"

 import (
+	"fmt"
 	"iter"
 	"runtime"
 	"unsafe"
@@ -16,6 +17,7 @@ import (

 // LoadSafetensors loads tensors from a .safetensors file, returning an iterator
 // over (name, array) pairs. Tensors are loaded lazily on the CPU stream.
+// Use [LoadAllSafetensors] for an error-returning variant.
 func LoadSafetensors(path string) iter.Seq2[string, *Array] {
 	Init()
 	return func(yield func(string, *Array) bool) {
@@ -31,7 +33,11 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
 		cpu := C.mlx_default_cpu_stream_new()
 		defer C.mlx_stream_free(cpu)

-		C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
+		rc := C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
+		if rc != 0 {
+			// Error will surface via lastError(); caller iterates zero tensors.
+			return
+		}

 		it := C.mlx_map_string_to_array_iterator_new(string2array)
 		defer C.mlx_map_string_to_array_iterator_free(it)
@@ -54,10 +60,17 @@ func LoadSafetensors(path string) iter.Seq2[string, *Array] {
 }

 // LoadAllSafetensors loads all tensors from a .safetensors file into a map.
-func LoadAllSafetensors(path string) map[string]*Array {
+// Returns an error if the file cannot be loaded.
+func LoadAllSafetensors(path string) (map[string]*Array, error) {
 	tensors := make(map[string]*Array)
 	for name, arr := range LoadSafetensors(path) {
 		tensors[name] = arr
 	}
-	return tensors
+	if len(tensors) == 0 {
+		if err := lastError(); err != nil {
+			return nil, err
+		}
+		return nil, fmt.Errorf("mlx: no tensors loaded from %s", path)
+	}
+	return tensors, nil
 }

@@ -230,8 +230,10 @@ func SaveSafetensors(path string, weights map[string]*Array) error {

 	rc := C.mlx_save_safetensors(cPath, cMap, cMeta)
 	if rc != 0 {
-		checkError()
-		return fmt.Errorf("mlx: save safetensors failed: %s", path)
+		if err := lastError(); err != nil {
+			return err
+		}
+		return fmt.Errorf("mlx: save safetensors failed: %s (rc=%d)", path, rc)
 	}
 	return nil
 }

@@ -250,7 +250,10 @@ func TestSaveSafetensors(t *testing.T) {
 	}

 	// Load it back
-	loaded := LoadAllSafetensors(path)
+	loaded, err := LoadAllSafetensors(path)
+	if err != nil {
+		t.Fatalf("LoadAllSafetensors: %v", err)
+	}
 	Materialize(loaded["layer.lora_a"], loaded["layer.lora_b"])

 	aLoaded := loaded["layer.lora_a"].Floats()
@@ -290,7 +293,10 @@ func TestLoRAAdapter_Save(t *testing.T) {
 	}

 	// Load and verify
-	loaded := LoadAllSafetensors(path)
+	loaded, err := LoadAllSafetensors(path)
+	if err != nil {
+		t.Fatalf("LoadAllSafetensors: %v", err)
+	}
 	aKey := "model.layers.0.self_attn.q_proj.lora_a"
 	bKey := "model.layers.0.self_attn.q_proj.lora_b"

@@ -18,7 +18,6 @@ package metal
 static const char *last_mlx_error = NULL;

 static void mlx_go_error_handler(const char *msg, void *data) {
-    fprintf(stderr, "MLX ERROR: %s\n", msg);
     last_mlx_error = msg;
 }

@@ -26,13 +25,16 @@ static void set_error_handler() {
     mlx_set_error_handler(&mlx_go_error_handler, NULL, NULL);
 }

-static const char* get_last_error() {
-    return last_mlx_error;
+static const char* get_and_clear_last_error() {
+    const char *msg = last_mlx_error;
+    last_mlx_error = NULL;
+    return msg;
 }
 */
 import "C"

 import (
+	"fmt"
 	"log/slog"
 	"sync"
 )
@@ -43,29 +45,24 @@ var initOnce sync.Once
 func Init() {
 	initOnce.Do(func() {
 		C.set_error_handler()
-		slog.Debug("mlx: initialized with Metal backend")
+		slog.Debug("mlx: initialised with Metal backend")
 	})
 }

-// checkError logs the last MLX error if any occurred.
-func checkError() {
-	if msg := C.get_last_error(); msg != nil {
-		slog.Error("mlx", "error", C.GoString(msg))
+// lastError reads and clears the most recent MLX-C error.
+// Returns nil if no error occurred since the last call.
+func lastError() error {
+	msg := C.get_and_clear_last_error()
+	if msg == nil {
+		return nil
 	}
+	return fmt.Errorf("mlx: %s", C.GoString(msg))
 }

-// Materialize synchronously evaluates arrays, computing their values on the GPU.
-// This is the MLX equivalent of forcing lazy computation to complete.
-func Materialize(outputs ...*Array) {
-	doMaterialize(outputs, false)
-}
-
-// MaterializeAsync queues arrays for asynchronous GPU evaluation.
-func MaterializeAsync(outputs ...*Array) {
-	doMaterialize(outputs, true)
-}
-
-func doMaterialize(outputs []*Array, async bool) {
+// Eval synchronously evaluates arrays on the GPU and returns any error.
+// This is the error-returning variant of Materialize. Use it in code paths
+// that need to propagate errors (e.g. the generate loop).
+func Eval(outputs ...*Array) error {
 	Init()
 	vector := C.mlx_vector_array_new()
 	defer C.mlx_vector_array_free(vector)
@@ -76,10 +73,51 @@ func doMaterialize(outputs []*Array, async bool) {
 		}
 	}

-	if async {
-		C.mlx_async_eval(vector)
-	} else {
-		C.mlx_eval(vector)
+	rc := C.mlx_eval(vector)
+	if rc != 0 {
+		if err := lastError(); err != nil {
+			return err
+		}
+		return fmt.Errorf("mlx: eval failed (rc=%d)", rc)
+	}
+	return nil
+}
+
+// EvalAsync queues arrays for asynchronous GPU evaluation and returns any error.
+func EvalAsync(outputs ...*Array) error {
+	Init()
+	vector := C.mlx_vector_array_new()
+	defer C.mlx_vector_array_free(vector)
+
+	for _, output := range outputs {
+		if output != nil && output.Valid() {
+			C.mlx_vector_array_append_value(vector, output.ctx)
+		}
+	}
+
+	rc := C.mlx_async_eval(vector)
+	if rc != 0 {
+		if err := lastError(); err != nil {
+			return err
+		}
+		return fmt.Errorf("mlx: async eval failed (rc=%d)", rc)
+	}
+	return nil
+}
+
+// Materialize synchronously evaluates arrays, computing their values on the GPU.
+// Errors are logged but not returned — use [Eval] when error propagation is needed.
+func Materialize(outputs ...*Array) {
+	if err := Eval(outputs...); err != nil {
+		slog.Error("mlx: materialize", "error", err)
+	}
+}
+
+// MaterializeAsync queues arrays for asynchronous GPU evaluation.
+// Errors are logged but not returned — use [EvalAsync] when error propagation is needed.
+func MaterializeAsync(outputs ...*Array) {
+	if err := EvalAsync(outputs...); err != nil {
+		slog.Error("mlx: materialize async", "error", err)
 	}
 }

@@ -122,6 +122,9 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
 		for name, arr := range LoadSafetensors(path) {
 			weights[name] = arr
 		}
+		if err := lastError(); err != nil {
+			return nil, fmt.Errorf("qwen3: load weights %s: %w", filepath.Base(path), err)
+		}
 	}

 	w := func(name string) *Array { return resolveWeight(weights, name) }