fix: correct SDPA mask mode and slice logits to last position

parent 70c32135d0
commit b6fbb88bfb

3 changed files with 18 additions and 12 deletions
@@ -87,6 +87,8 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
 	}
 
 	logits := b.model.Forward(input, b.caches)
+	// Take last position: [B, L, V] → [B, V]
+	logits = lastPosition(logits)
 	next := sampler.Sample(logits)
 	mlx.Materialize(next)
 
@@ -101,6 +103,19 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
 	return b.tok.Decode(output), nil
 }
 
+// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V].
+func lastPosition(logits *mlx.Array) *mlx.Array {
+	shape := logits.Shape()
+	if len(shape) == 3 && shape[1] > 1 {
+		L := shape[1]
+		logits = mlx.Slice(logits, []int32{0, L - 1, 0}, []int32{shape[0], L, shape[2]})
+		logits = mlx.Reshape(logits, shape[0], shape[2])
+	} else if len(shape) == 3 && shape[1] == 1 {
+		logits = mlx.Reshape(logits, shape[0], shape[2])
+	}
+	return logits
+}
+
 // Chat formats messages and generates a response.
 func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
 	// Format as Gemma chat
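Note on lastPosition: it is the Go-side equivalent of a logits[:, -1, :] slice. As a self-contained illustration of the index arithmetic (standard library only; the real code operates on mlx arrays, and the shapes below are made-up examples, not values from this diff):

package main

import "fmt"

// lastPositionFlat mimics lastPosition on a row-major [B, L, V] tensor
// stored in a flat slice: keep only the final sequence position → [B, V].
func lastPositionFlat(logits []float32, B, L, V int) []float32 {
	out := make([]float32, B*V)
	for b := 0; b < B; b++ {
		// Row (b, L-1) starts at offset b*L*V + (L-1)*V in row-major order.
		copy(out[b*V:(b+1)*V], logits[b*L*V+(L-1)*V:b*L*V+L*V])
	}
	return out
}

func main() {
	B, L, V := 2, 3, 4
	logits := make([]float32, B*L*V)
	for i := range logits {
		logits[i] = float32(i)
	}
	fmt.Println(lastPositionFlat(logits, B, L, V)) // [8 9 10 11 20 21 22 23]
}

After prefill, Forward returns logits for every prompt position ([B, L, V]) and only the last row feeds the sampler; during decode L is 1, so lastPosition skips the slice and just drops the middle axis.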
@@ -148,6 +163,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
 	}
 
 	logits := b.model.Forward(input, b.caches)
+	logits = lastPosition(logits)
 	next := sampler.Sample(logits)
 	mlx.Materialize(next)
 
@@ -48,7 +48,7 @@ func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, off
 
 // ScaledDotProductAttention computes attention using a fused Metal kernel.
 func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
-	mode := "none"
+	mode := ""
 	if causal {
 		mode = "causal"
 	}
@@ -67,7 +67,7 @@ func ScaledDotProductAttention(query, key, value *Array, scale float32, causal b
 
 // ScaledDotProductAttentionWithMask computes attention with an explicit mask.
 func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
-	cMode := C.CString("none")
+	cMode := C.CString("array")
 	defer C.free(unsafe.Pointer(cMode))
 
 	sinksArr := C.mlx_array_new()
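Note: both fixes replace the literal "none" with the mode strings the fused kernel actually recognizes — judging by this diff, "" for no masking, "causal" for fused causal masking, and "array" when an explicit mask array is supplied. A minimal sketch of that convention (maskMode is a hypothetical helper, not part of this commit):

package main

import "fmt"

// maskMode picks the SDPA mask-mode string under the convention above:
// an explicit mask wins, then causal, then no masking at all.
func maskMode(causal, hasMask bool) string {
	switch {
	case hasMask:
		return "array"
	case causal:
		return "causal"
	default:
		return ""
	}
}

func main() {
	fmt.Printf("%q\n", maskMode(true, false))  // "causal"
	fmt.Printf("%q\n", maskMode(false, true))  // "array"
	fmt.Printf("%q\n", maskMode(false, false)) // ""
}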
@@ -95,9 +95,6 @@ type MLP struct {
 	DownProj *mlx.Linear
 }
 
-// debugOnce is a temporary flag for shape debugging (remove after fix).
-var debugOnce = true
-
 // compiledGELU is a singleton for the compiled GELU function.
 var compiledGELU *mlx.CompiledFunc
 
@@ -372,13 +369,6 @@ func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
 	k := a.KProj.Forward(x)
 	v := a.VProj.Forward(x)
 
-	if debugOnce {
-		slog.Info("mlx: debug", "x_shape", x.Shape(), "q_shape", q.Shape(), "k_shape", k.Shape(), "v_shape", v.Shape(),
-			"B", B, "L", L, "num_heads", cfg.NumAttentionHeads, "num_kv_heads", cfg.NumKeyValueHeads, "head_dim", cfg.HeadDim,
-			"q_scales", a.QProj.Scales != nil)
-		debugOnce = false
-	}
-
 	// Reshape to [B, num_heads, L, head_dim]
 	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
 		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
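Aside on the unchanged context: the AsStrided call reinterprets the row-major [B, L, H*D] projection as [B, H, L, D] without copying, using strides {L*H*D, D, H*D, 1}. A self-contained sketch of that index arithmetic (illustrative sizes only, plain Go in place of mlx arrays):

package main

import "fmt"

func main() {
	B, L, H, D := 1, 2, 2, 3
	buf := make([]int, B*L*H*D)
	for i := range buf {
		buf[i] = i
	}
	// Element (b, h, l, d) of the strided view reads
	// buf[b*L*H*D + h*D + l*H*D + d]: each head h takes columns
	// h*D .. h*D+D-1 of every sequence row, with no data movement.
	at := func(b, h, l, d int) int {
		return buf[b*L*H*D+h*D+l*H*D+d]
	}
	fmt.Println(at(0, 0, 0, 0), at(0, 0, 0, 2)) // 0 2  (head 0, pos 0)
	fmt.Println(at(0, 1, 0, 0), at(0, 1, 0, 2)) // 3 5  (head 1, pos 0)
	fmt.Println(at(0, 0, 1, 0), at(0, 0, 1, 2)) // 6 8  (head 0, pos 1)
}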