From f416f7e529f41c2ed4dce47c7fb01325511088fb Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Feb 2026 02:16:30 +0000 Subject: [PATCH] debug: add shape logging and stderr error handler for inference debugging --- pkg/mlx/mlx.go | 2 ++ pkg/mlx/model/gemma3.go | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/pkg/mlx/mlx.go b/pkg/mlx/mlx.go index 3067bcb0..31445dd4 100644 --- a/pkg/mlx/mlx.go +++ b/pkg/mlx/mlx.go @@ -23,12 +23,14 @@ package mlx #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate #cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib +#include <stdio.h> #include <stdlib.h> #include "mlx/c/mlx.h" static const char *last_mlx_error = NULL; static void mlx_go_error_handler(const char *msg, void *data) { + fprintf(stderr, "MLX ERROR: %s\n", msg); last_mlx_error = msg; } diff --git a/pkg/mlx/model/gemma3.go b/pkg/mlx/model/gemma3.go index 4892218c..0bb1e4eb 100644 --- a/pkg/mlx/model/gemma3.go +++ b/pkg/mlx/model/gemma3.go @@ -95,6 +95,9 @@ type MLP struct { DownProj *mlx.Linear } +// debugOnce is a temporary flag for shape debugging (remove after fix). +var debugOnce = true + // compiledGELU is a singleton for the compiled GELU function. var compiledGELU *mlx.CompiledFunc @@ -356,6 +359,13 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b k := a.KProj.Forward(x) v := a.VProj.Forward(x) + if debugOnce { + slog.Info("mlx: debug", "x_shape", x.Shape(), "q_shape", q.Shape(), "k_shape", k.Shape(), "v_shape", v.Shape(), + "B", B, "L", L, "num_heads", cfg.NumAttentionHeads, "num_kv_heads", cfg.NumKeyValueHeads, "head_dim", cfg.HeadDim, + "q_scales", a.QProj.Scales != nil) + debugOnce = false + } + // Reshape to [B, num_heads, L, head_dim] q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)