- Add ForwardMasked to InternalModel, Gemma3 and Qwen3 architectures - Thread attention mask through decoder layers and SDPA calls - Use ScaledDotProductAttentionWithMask when explicit mask provided - Create batch.go with padded batching, mask construction, Classify (prefill-only) and BatchGenerate (autoregressive) implementations - Wire Classify/BatchGenerate through metalAdapter to go-inference - Tests: mask unit tests (shape, values, multi-batch), Classify with 4 prompts (152 prompts/s), WithLogits, BatchGenerate with 2 prompts Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
105 lines
2.5 KiB
Go
105 lines
2.5 KiB
Go
//go:build darwin && arm64
|
|
|
|
package metal
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
)
|
|
|
|
func TestBuildBatchMask_Shape(t *testing.T) {
|
|
// 2 prompts, max length 4, prompt lengths [3, 2].
|
|
mask := buildBatchMask(2, 4, []int32{3, 2})
|
|
if err := Eval(mask); err != nil {
|
|
t.Fatalf("Eval mask: %v", err)
|
|
}
|
|
|
|
shape := mask.Shape()
|
|
want := []int32{2, 1, 4, 4}
|
|
if len(shape) != 4 {
|
|
t.Fatalf("mask ndim = %d, want 4", len(shape))
|
|
}
|
|
for i, s := range shape {
|
|
if s != want[i] {
|
|
t.Errorf("mask shape[%d] = %d, want %d", i, s, want[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestBuildBatchMask_Values(t *testing.T) {
|
|
// Single prompt of length 3, padded to 4.
|
|
// Expected mask [1, 1, 4, 4]:
|
|
// row 0: [0, -inf, -inf, -inf] (can only attend to pos 0)
|
|
// row 1: [0, 0, -inf, -inf] (attend to pos 0,1)
|
|
// row 2: [0, 0, 0, -inf] (attend to pos 0,1,2)
|
|
// row 3: [0, 0, 0, -inf] (row 3 is padding — causal says j<=3 but j<3 caps it)
|
|
mask := buildBatchMask(1, 4, []int32{3})
|
|
if err := Eval(mask); err != nil {
|
|
t.Fatalf("Eval mask: %v", err)
|
|
}
|
|
|
|
// Flatten to get values.
|
|
flat := Reshape(mask, 16)
|
|
if err := Eval(flat); err != nil {
|
|
t.Fatalf("Eval flat: %v", err)
|
|
}
|
|
vals := flat.Floats()
|
|
|
|
negInf := float32(math.Inf(-1))
|
|
expected := []float32{
|
|
// row 0: attend j=0 only
|
|
0, negInf, negInf, negInf,
|
|
// row 1: attend j=0,1
|
|
0, 0, negInf, negInf,
|
|
// row 2: attend j=0,1,2
|
|
0, 0, 0, negInf,
|
|
// row 3: padding row — causal allows j<=3 but padding caps at j<3
|
|
0, 0, 0, negInf,
|
|
}
|
|
|
|
for i, v := range vals {
|
|
e := expected[i]
|
|
if math.IsInf(float64(e), -1) {
|
|
if !math.IsInf(float64(v), -1) {
|
|
t.Errorf("vals[%d] = %f, want -inf", i, v)
|
|
}
|
|
} else if v != e {
|
|
t.Errorf("vals[%d] = %f, want %f", i, v, e)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestBuildBatchMask_MultipleBatches(t *testing.T) {
|
|
// 2 prompts: lengths [2, 1], max length 2.
|
|
mask := buildBatchMask(2, 2, []int32{2, 1})
|
|
if err := Eval(mask); err != nil {
|
|
t.Fatalf("Eval mask: %v", err)
|
|
}
|
|
|
|
flat := Reshape(mask, 8)
|
|
if err := Eval(flat); err != nil {
|
|
t.Fatalf("Eval flat: %v", err)
|
|
}
|
|
vals := flat.Floats()
|
|
|
|
negInf := float32(math.Inf(-1))
|
|
expected := []float32{
|
|
// batch 0 (len=2): full causal, no padding
|
|
0, negInf,
|
|
0, 0,
|
|
// batch 1 (len=1): only first position is real
|
|
0, negInf,
|
|
0, negInf, // row 1: causal allows j<=1 but padding caps at j<1
|
|
}
|
|
|
|
for i, v := range vals {
|
|
e := expected[i]
|
|
if math.IsInf(float64(e), -1) {
|
|
if !math.IsInf(float64(v), -1) {
|
|
t.Errorf("batch vals[%d] = %f, want -inf", i, v)
|
|
}
|
|
} else if v != e {
|
|
t.Errorf("batch vals[%d] = %f, want %f", i, v, e)
|
|
}
|
|
}
|
|
}
|