debug: add shape logging and stderr error handler for inference debugging
parent d92d097a7f
commit f416f7e529

2 changed files with 12 additions and 0 deletions
@@ -23,12 +23,14 @@ package mlx
 #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
 #cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib
 
+#include <stdio.h>
 #include <stdlib.h>
 #include "mlx/c/mlx.h"
 
 static const char *last_mlx_error = NULL;
 
 static void mlx_go_error_handler(const char *msg, void *data) {
+    fprintf(stderr, "MLX ERROR: %s\n", msg);
     last_mlx_error = msg;
 }
 
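For reference, a minimal sketch of how the message captured by mlx_go_error_handler could be surfaced as a Go error. lastMLXError is a hypothetical helper, not part of this commit; it assumes only the last_mlx_error variable declared in the cgo preamble above and an errors import.

// lastMLXError is a hypothetical helper (not in this commit) that converts the
// C-side message captured by mlx_go_error_handler into a Go error.
func lastMLXError() error {
	if C.last_mlx_error == nil {
		return nil
	}
	msg := C.GoString(C.last_mlx_error) // copy the C string into Go-managed memory
	C.last_mlx_error = nil              // clear so the next operation starts clean
	return errors.New("mlx: " + msg)
}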
@@ -95,6 +95,9 @@ type MLP struct {
 	DownProj *mlx.Linear
 }
 
+// debugOnce is a temporary flag for shape debugging (remove after fix).
+var debugOnce = true
+
 // compiledGELU is a singleton for the compiled GELU function.
 var compiledGELU *mlx.CompiledFunc
 
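debugOnce as added above is a plain package-level bool, which is fine for single-threaded debugging but racy if the forward pass ever runs from multiple goroutines. An alternative sketch under that assumption, using only the standard library (debugShapesOnce and logShapesOnce are hypothetical names, not in this commit):

// Hypothetical race-free variant of the one-shot debug log (not in this commit).
var debugShapesOnce sync.Once

func logShapesOnce(attrs ...any) {
	debugShapesOnce.Do(func() {
		slog.Info("mlx: debug", attrs...)
	})
}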
@@ -356,6 +359,13 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
 	k := a.KProj.Forward(x)
 	v := a.VProj.Forward(x)
 
+	if debugOnce {
+		slog.Info("mlx: debug", "x_shape", x.Shape(), "q_shape", q.Shape(), "k_shape", k.Shape(), "v_shape", v.Shape(),
+			"B", B, "L", L, "num_heads", cfg.NumAttentionHeads, "num_kv_heads", cfg.NumKeyValueHeads, "head_dim", cfg.HeadDim,
+			"q_scales", a.QProj.Scales != nil)
+		debugOnce = false
+	}
+
 	// Reshape to [B, num_heads, L, head_dim]
 	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
 		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
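For reference, the stride slice above views the contiguous projection output of shape [B, L, num_heads*head_dim] as [B, num_heads, L, head_dim] without copying: element (b, h, l, d) sits at offset b*L*H*D + l*H*D + h*D + d (with H = num_heads, D = head_dim), so the per-dimension strides are (L*H*D, D, H*D, 1), exactly the slice passed to mlx.AsStrided. This is the usual reshape-to-[B, L, H, D]-then-swap-L-and-H view expressed directly. A minimal sketch of that computation (stridesBHLD is a hypothetical helper, not in this commit):

// stridesBHLD returns the element strides that view a contiguous [B, L, H*D]
// buffer as [B, H, L, D]: offset(b, h, l, d) = b*L*H*D + l*H*D + h*D + d.
func stridesBHLD(H, L, D int64) []int64 {
	return []int64{L * H * D, D, H * D, 1}
}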