debug: add shape logging and stderr error handler for inference debugging
This commit is contained in:
parent
56c6e2fa8d
commit
d3c31aa5a6
2 changed files with 12 additions and 0 deletions
|
|
@ -23,12 +23,14 @@ package mlx
|
|||
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||
#cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
|
||||
static const char *last_mlx_error = NULL;
|
||||
|
||||
static void mlx_go_error_handler(const char *msg, void *data) {
|
||||
fprintf(stderr, "MLX ERROR: %s\n", msg);
|
||||
last_mlx_error = msg;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -95,6 +95,9 @@ type MLP struct {
|
|||
DownProj *mlx.Linear
|
||||
}
|
||||
|
||||
// debugOnce is a temporary flag for shape debugging (remove after fix).
|
||||
var debugOnce = true
|
||||
|
||||
// compiledGELU is a singleton for the compiled GELU function.
|
||||
var compiledGELU *mlx.CompiledFunc
|
||||
|
||||
|
|
@ -356,6 +359,13 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
if debugOnce {
|
||||
slog.Info("mlx: debug", "x_shape", x.Shape(), "q_shape", q.Shape(), "k_shape", k.Shape(), "v_shape", v.Shape(),
|
||||
"B", B, "L", L, "num_heads", cfg.NumAttentionHeads, "num_kv_heads", cfg.NumKeyValueHeads, "head_dim", cfg.HeadDim,
|
||||
"q_scales", a.QProj.Scales != nil)
|
||||
debugOnce = false
|
||||
}
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue