feat(metal): implement batch inference (Classify, BatchGenerate)
- Add ForwardMasked to InternalModel, Gemma3 and Qwen3 architectures - Thread attention mask through decoder layers and SDPA calls - Use ScaledDotProductAttentionWithMask when explicit mask provided - Create batch.go with padded batching, mask construction, Classify (prefill-only) and BatchGenerate (autoregressive) implementations - Wire Classify/BatchGenerate through metalAdapter to go-inference - Tests: mask unit tests (shape, values, multi-batch), Classify with 4 prompts (152 prompts/s), WithLogits, BatchGenerate with 2 prompts Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ce1acef462
commit
5644857034
8 changed files with 602 additions and 11 deletions
304
internal/metal/batch.go
Normal file
304
internal/metal/batch.go
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package metal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// ClassifyResult holds the output for a single prompt in batch classification.
|
||||
type ClassifyResult struct {
|
||||
Token Token
|
||||
Logits []float32
|
||||
}
|
||||
|
||||
// BatchResult holds the output for a single prompt in batch generation.
|
||||
type BatchResult struct {
|
||||
Tokens []Token
|
||||
Err error
|
||||
}
|
||||
|
||||
// Classify runs batched prefill-only inference. Each prompt gets a single
|
||||
// forward pass and the token at the last position is sampled.
|
||||
func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConfig, returnLogits bool) ([]ClassifyResult, error) {
|
||||
if len(prompts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Tokenise all prompts.
|
||||
encoded := make([][]int32, len(prompts))
|
||||
lengths := make([]int, len(prompts))
|
||||
for i, p := range prompts {
|
||||
encoded[i] = m.tokenizer.Encode(p)
|
||||
lengths[i] = len(encoded[i])
|
||||
}
|
||||
|
||||
// Sort by length descending for minimal padding. Track original indices.
|
||||
indices := make([]int, len(prompts))
|
||||
for i := range indices {
|
||||
indices[i] = i
|
||||
}
|
||||
sort.Slice(indices, func(a, b int) bool {
|
||||
return lengths[indices[a]] > lengths[indices[b]]
|
||||
})
|
||||
|
||||
maxLen := lengths[indices[0]]
|
||||
N := int32(len(prompts))
|
||||
L := int32(maxLen)
|
||||
|
||||
// Build padded token matrix [N, maxLen] and prompt lengths (sorted order).
|
||||
padded := make([]int32, int(N)*int(L))
|
||||
sortedLengths := make([]int32, N)
|
||||
for si, origIdx := range indices {
|
||||
sortedLengths[si] = int32(lengths[origIdx])
|
||||
copy(padded[si*int(L):], encoded[origIdx])
|
||||
// Remaining positions are already 0 (pad token)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Build attention mask [N, 1, L, L].
|
||||
mask := buildBatchMask(N, L, sortedLengths)
|
||||
|
||||
// Single forward pass.
|
||||
tokens := FromValues(padded, int(N), int(L))
|
||||
logits := m.model.ForwardMasked(tokens, mask, m.newCachesN(int(N)))
|
||||
if err := Eval(logits); err != nil {
|
||||
return nil, fmt.Errorf("classify prefill: %w", err)
|
||||
}
|
||||
|
||||
// logits shape: [N, L, vocab]
|
||||
sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK)
|
||||
|
||||
// Gather logits at each prompt's last real token position and sample.
|
||||
sortedResults := make([]ClassifyResult, N)
|
||||
for si := int32(0); si < N; si++ {
|
||||
lastPos := sortedLengths[si] - 1
|
||||
|
||||
// Extract [1, vocab] at position lastPos for this batch element.
|
||||
batchLogits := SliceAxis(logits, 0, si, si+1) // [1, L, vocab]
|
||||
posLogits := SliceAxis(batchLogits, 1, lastPos, lastPos+1)
|
||||
posLogits = Reshape(posLogits, 1, int32(posLogits.Dim(2))) // [1, vocab]
|
||||
|
||||
next := sampler.Sample(posLogits)
|
||||
if err := Eval(next); err != nil {
|
||||
return nil, fmt.Errorf("classify sample %d: %w", si, err)
|
||||
}
|
||||
|
||||
id := int32(next.Int())
|
||||
text := m.tokenizer.DecodeToken(id)
|
||||
sortedResults[si].Token = Token{ID: id, Text: text}
|
||||
|
||||
if returnLogits {
|
||||
posLogits = Reshape(posLogits, int32(posLogits.Dim(1)))
|
||||
if err := Eval(posLogits); err != nil {
|
||||
return nil, fmt.Errorf("classify logits %d: %w", si, err)
|
||||
}
|
||||
sortedResults[si].Logits = posLogits.Floats()
|
||||
}
|
||||
}
|
||||
|
||||
// Unsort results back to original prompt order.
|
||||
results := make([]ClassifyResult, N)
|
||||
for si, origIdx := range indices {
|
||||
results[origIdx] = sortedResults[si]
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// BatchGenerate runs batched autoregressive generation.
|
||||
func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg GenerateConfig) ([]BatchResult, error) {
|
||||
if len(prompts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Tokenise all prompts.
|
||||
encoded := make([][]int32, len(prompts))
|
||||
lengths := make([]int, len(prompts))
|
||||
for i, p := range prompts {
|
||||
encoded[i] = m.tokenizer.Encode(p)
|
||||
lengths[i] = len(encoded[i])
|
||||
}
|
||||
|
||||
// Sort by length descending.
|
||||
indices := make([]int, len(prompts))
|
||||
for i := range indices {
|
||||
indices[i] = i
|
||||
}
|
||||
sort.Slice(indices, func(a, b int) bool {
|
||||
return lengths[indices[a]] > lengths[indices[b]]
|
||||
})
|
||||
|
||||
N := int32(len(prompts))
|
||||
maxLen := lengths[indices[0]]
|
||||
L := int32(maxLen)
|
||||
|
||||
// Build padded token matrix and lengths.
|
||||
padded := make([]int32, int(N)*int(L))
|
||||
sortedLengths := make([]int32, N)
|
||||
for si, origIdx := range indices {
|
||||
sortedLengths[si] = int32(lengths[origIdx])
|
||||
copy(padded[si*int(L):], encoded[origIdx])
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Prefill with mask.
|
||||
mask := buildBatchMask(N, L, sortedLengths)
|
||||
tokens := FromValues(padded, int(N), int(L))
|
||||
caches := m.newCachesN(int(N))
|
||||
logits := m.model.ForwardMasked(tokens, mask, caches)
|
||||
if err := Eval(logits); err != nil {
|
||||
return nil, fmt.Errorf("batch prefill: %w", err)
|
||||
}
|
||||
|
||||
sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK)
|
||||
eosID := m.tokenizer.EOSToken()
|
||||
|
||||
// Per-sequence state.
|
||||
type seqState struct {
|
||||
tokens []Token
|
||||
finished bool
|
||||
}
|
||||
states := make([]seqState, N)
|
||||
|
||||
maxTokens := cfg.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = 256
|
||||
}
|
||||
|
||||
for step := 0; step < maxTokens; step++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Return partial results on cancellation.
|
||||
sortedResults := make([]BatchResult, N)
|
||||
for si := range states {
|
||||
sortedResults[si] = BatchResult{Tokens: states[si].tokens, Err: ctx.Err()}
|
||||
}
|
||||
results := make([]BatchResult, N)
|
||||
for si, origIdx := range indices {
|
||||
results[origIdx] = sortedResults[si]
|
||||
}
|
||||
return results, nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Sample next token for each sequence from last position logits.
|
||||
nextIDs := make([]int32, N)
|
||||
allFinished := true
|
||||
|
||||
for si := int32(0); si < N; si++ {
|
||||
if states[si].finished {
|
||||
nextIDs[si] = 0 // pad
|
||||
continue
|
||||
}
|
||||
|
||||
var posLogits *Array
|
||||
if step == 0 {
|
||||
// First step: gather from prefill at each prompt's last real position.
|
||||
lastPos := sortedLengths[si] - 1
|
||||
batchLogits := SliceAxis(logits, 0, si, si+1)
|
||||
posLogits = SliceAxis(batchLogits, 1, lastPos, lastPos+1)
|
||||
} else {
|
||||
// Subsequent steps: logits shape [N, 1, vocab].
|
||||
posLogits = SliceAxis(logits, 0, si, si+1)
|
||||
posLogits = SliceAxis(posLogits, 1, 0, 1)
|
||||
}
|
||||
posLogits = Reshape(posLogits, 1, int32(posLogits.Dim(posLogits.NumDims()-1)))
|
||||
|
||||
next := sampler.Sample(posLogits)
|
||||
if err := Eval(next); err != nil {
|
||||
return nil, fmt.Errorf("batch sample step %d seq %d: %w", step, si, err)
|
||||
}
|
||||
|
||||
id := int32(next.Int())
|
||||
nextIDs[si] = id
|
||||
|
||||
// Check stop conditions.
|
||||
if id == eosID {
|
||||
states[si].finished = true
|
||||
continue
|
||||
}
|
||||
for _, stop := range cfg.StopTokens {
|
||||
if id == stop {
|
||||
states[si].finished = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !states[si].finished {
|
||||
text := m.tokenizer.DecodeToken(id)
|
||||
states[si].tokens = append(states[si].tokens, Token{ID: id, Text: text})
|
||||
allFinished = false
|
||||
}
|
||||
}
|
||||
|
||||
if allFinished {
|
||||
break
|
||||
}
|
||||
|
||||
// Feed next tokens [N, 1] through the model.
|
||||
nextInput := FromValues(nextIDs, int(N), 1)
|
||||
logits = m.model.Forward(nextInput, caches)
|
||||
if err := Eval(logits); err != nil {
|
||||
return nil, fmt.Errorf("batch decode step %d: %w", step, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Unsort results back to original order.
|
||||
sortedResults := make([]BatchResult, N)
|
||||
for si := range states {
|
||||
sortedResults[si] = BatchResult{Tokens: states[si].tokens}
|
||||
}
|
||||
results := make([]BatchResult, N)
|
||||
for si, origIdx := range indices {
|
||||
results[origIdx] = sortedResults[si]
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// buildBatchMask constructs a combined causal + padding attention mask.
|
||||
// Shape: [N, 1, L, L]. Values: 0.0 = attend, -inf = ignore.
|
||||
// mask[b, 0, i, j] = 0.0 if j <= i AND j < promptLen[b], else -inf.
|
||||
func buildBatchMask(N, L int32, promptLens []int32) *Array {
|
||||
negInf := float32(math.Inf(-1))
|
||||
data := make([]float32, int(N)*int(L)*int(L))
|
||||
|
||||
for b := int32(0); b < N; b++ {
|
||||
pLen := promptLens[b]
|
||||
base := int(b) * int(L) * int(L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
for j := int32(0); j < L; j++ {
|
||||
if j <= i && j < pLen {
|
||||
data[base+int(i)*int(L)+int(j)] = 0
|
||||
} else {
|
||||
data[base+int(i)*int(L)+int(j)] = negInf
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mask := FromValues(data, int(N), 1, int(L), int(L))
|
||||
return mask
|
||||
}
|
||||
|
||||
// newCachesN creates N independent sets of per-layer caches for batched inference.
|
||||
// Since our KV cache implementation handles batch dimension internally,
|
||||
// we create a single set of caches (same as non-batched).
|
||||
func (m *Model) newCachesN(n int) []Cache {
|
||||
// KV caches handle the batch dimension automatically via the key/value
|
||||
// array shapes. A single cache set works for any batch size.
|
||||
return m.newCaches()
|
||||
}
|
||||
105
internal/metal/batch_test.go
Normal file
105
internal/metal/batch_test.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package metal
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildBatchMask_Shape(t *testing.T) {
|
||||
// 2 prompts, max length 4, prompt lengths [3, 2].
|
||||
mask := buildBatchMask(2, 4, []int32{3, 2})
|
||||
if err := Eval(mask); err != nil {
|
||||
t.Fatalf("Eval mask: %v", err)
|
||||
}
|
||||
|
||||
shape := mask.Shape()
|
||||
want := []int32{2, 1, 4, 4}
|
||||
if len(shape) != 4 {
|
||||
t.Fatalf("mask ndim = %d, want 4", len(shape))
|
||||
}
|
||||
for i, s := range shape {
|
||||
if s != want[i] {
|
||||
t.Errorf("mask shape[%d] = %d, want %d", i, s, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBatchMask_Values(t *testing.T) {
|
||||
// Single prompt of length 3, padded to 4.
|
||||
// Expected mask [1, 1, 4, 4]:
|
||||
// row 0: [0, -inf, -inf, -inf] (can only attend to pos 0)
|
||||
// row 1: [0, 0, -inf, -inf] (attend to pos 0,1)
|
||||
// row 2: [0, 0, 0, -inf] (attend to pos 0,1,2)
|
||||
// row 3: [0, 0, 0, -inf] (row 3 is padding — causal says j<=3 but j<3 caps it)
|
||||
mask := buildBatchMask(1, 4, []int32{3})
|
||||
if err := Eval(mask); err != nil {
|
||||
t.Fatalf("Eval mask: %v", err)
|
||||
}
|
||||
|
||||
// Flatten to get values.
|
||||
flat := Reshape(mask, 16)
|
||||
if err := Eval(flat); err != nil {
|
||||
t.Fatalf("Eval flat: %v", err)
|
||||
}
|
||||
vals := flat.Floats()
|
||||
|
||||
negInf := float32(math.Inf(-1))
|
||||
expected := []float32{
|
||||
// row 0: attend j=0 only
|
||||
0, negInf, negInf, negInf,
|
||||
// row 1: attend j=0,1
|
||||
0, 0, negInf, negInf,
|
||||
// row 2: attend j=0,1,2
|
||||
0, 0, 0, negInf,
|
||||
// row 3: padding row — causal allows j<=3 but padding caps at j<3
|
||||
0, 0, 0, negInf,
|
||||
}
|
||||
|
||||
for i, v := range vals {
|
||||
e := expected[i]
|
||||
if math.IsInf(float64(e), -1) {
|
||||
if !math.IsInf(float64(v), -1) {
|
||||
t.Errorf("vals[%d] = %f, want -inf", i, v)
|
||||
}
|
||||
} else if v != e {
|
||||
t.Errorf("vals[%d] = %f, want %f", i, v, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBatchMask_MultipleBatches(t *testing.T) {
|
||||
// 2 prompts: lengths [2, 1], max length 2.
|
||||
mask := buildBatchMask(2, 2, []int32{2, 1})
|
||||
if err := Eval(mask); err != nil {
|
||||
t.Fatalf("Eval mask: %v", err)
|
||||
}
|
||||
|
||||
flat := Reshape(mask, 8)
|
||||
if err := Eval(flat); err != nil {
|
||||
t.Fatalf("Eval flat: %v", err)
|
||||
}
|
||||
vals := flat.Floats()
|
||||
|
||||
negInf := float32(math.Inf(-1))
|
||||
expected := []float32{
|
||||
// batch 0 (len=2): full causal, no padding
|
||||
0, negInf,
|
||||
0, 0,
|
||||
// batch 1 (len=1): only first position is real
|
||||
0, negInf,
|
||||
0, negInf, // row 1: causal allows j<=1 but padding caps at j<1
|
||||
}
|
||||
|
||||
for i, v := range vals {
|
||||
e := expected[i]
|
||||
if math.IsInf(float64(e), -1) {
|
||||
if !math.IsInf(float64(v), -1) {
|
||||
t.Errorf("batch vals[%d] = %f, want -inf", i, v)
|
||||
}
|
||||
} else if v != e {
|
||||
t.Errorf("batch vals[%d] = %f, want %f", i, v, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -67,7 +67,8 @@ type fakeModel struct {
|
|||
numLayers int
|
||||
}
|
||||
|
||||
func (f *fakeModel) Forward(_ *Array, _ []Cache) *Array { return nil }
|
||||
func (f *fakeModel) Forward(_ *Array, _ []Cache) *Array { return nil }
|
||||
func (f *fakeModel) ForwardMasked(_ *Array, _ *Array, _ []Cache) *Array { return nil }
|
||||
func (f *fakeModel) NewCache() []Cache {
|
||||
caches := make([]Cache, f.numLayers)
|
||||
for i := range caches {
|
||||
|
|
|
|||
|
|
@ -323,6 +323,10 @@ func isLayerSliding(layerIdx, pattern int32) bool {
|
|||
|
||||
// Forward runs the text model forward pass.
|
||||
func (m *GemmaModel) Forward(tokens *Array, caches []Cache) *Array {
|
||||
return m.ForwardMasked(tokens, nil, caches)
|
||||
}
|
||||
|
||||
func (m *GemmaModel) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array {
|
||||
shape := tokens.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
|
|
@ -330,15 +334,15 @@ func (m *GemmaModel) Forward(tokens *Array, caches []Cache) *Array {
|
|||
h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.forward(h, caches[i], B, L, m.Cfg)
|
||||
h = layer.forward(h, caches[i], B, L, mask, m.Cfg)
|
||||
}
|
||||
|
||||
return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *TextConfig) *Array {
|
||||
func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *TextConfig) *Array {
|
||||
normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, mask, cfg)
|
||||
attnOut = RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := Add(x, attnOut)
|
||||
|
||||
|
|
@ -348,7 +352,7 @@ func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *TextConfig) *
|
|||
return Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg *TextConfig) *Array {
|
||||
func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, mask *Array, cfg *TextConfig) *Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
|
@ -386,7 +390,12 @@ func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg *
|
|||
}
|
||||
|
||||
// Scaled dot-product attention
|
||||
out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
var out *Array
|
||||
if mask != nil {
|
||||
out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale)
|
||||
} else {
|
||||
out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
}
|
||||
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,11 @@ type InternalModel interface {
|
|||
// Forward runs the model forward pass on token IDs with KV caches.
|
||||
Forward(tokens *Array, caches []Cache) *Array
|
||||
|
||||
// ForwardMasked runs the forward pass with an explicit attention mask.
|
||||
// mask shape: [B, 1, L, L] — additive mask (0 = attend, -inf = ignore).
|
||||
// Used for batched inference with padded sequences.
|
||||
ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array
|
||||
|
||||
// NewCache creates per-layer KV caches for generation.
|
||||
NewCache() []Cache
|
||||
|
||||
|
|
|
|||
|
|
@ -249,22 +249,29 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) {
|
|||
// Forward runs the Qwen 3 forward pass.
|
||||
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size).
|
||||
func (m *Qwen3Model) Forward(tokens *Array, caches []Cache) *Array {
|
||||
return m.ForwardMasked(tokens, nil, caches)
|
||||
}
|
||||
|
||||
// ForwardMasked runs the forward pass with an explicit attention mask.
|
||||
// mask shape: [B, 1, L, L] — additive mask (0 = attend, -inf = ignore).
|
||||
// When mask is nil, standard causal attention is used.
|
||||
func (m *Qwen3Model) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array {
|
||||
shape := tokens.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.forward(h, caches[i], B, L, m.Cfg)
|
||||
h = layer.forward(h, caches[i], B, L, mask, m.Cfg)
|
||||
}
|
||||
|
||||
return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array {
|
||||
func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array {
|
||||
// Pre-attention norm → attention → residual add
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, cfg)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, mask, cfg)
|
||||
h := Add(x, attnOut)
|
||||
|
||||
// Pre-MLP norm → MLP → residual add
|
||||
|
|
@ -273,7 +280,7 @@ func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Con
|
|||
return Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array {
|
||||
func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
|
@ -309,7 +316,12 @@ func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config
|
|||
}
|
||||
|
||||
// Scaled dot-product attention
|
||||
out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
var out *Array
|
||||
if mask != nil {
|
||||
out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale)
|
||||
} else {
|
||||
out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
}
|
||||
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
|
|
|||
110
mlx_test.go
110
mlx_test.go
|
|
@ -443,6 +443,116 @@ func TestLlama_Inference(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// --- Batch Inference tests (Gemma3-1B) ---
|
||||
|
||||
// TestClassify_Batch validates batched prefill-only classification.
|
||||
func TestClassify_Batch(t *testing.T) {
|
||||
modelPath := gemma3ModelPath(t)
|
||||
|
||||
m, err := inference.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel: %v", err)
|
||||
}
|
||||
defer func() { m.Close(); mlx.ClearCache() }()
|
||||
|
||||
ctx := context.Background()
|
||||
prompts := []string{
|
||||
"The capital of France is",
|
||||
"2 + 2 =",
|
||||
"The colour of the sky is",
|
||||
"Go is a programming",
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
results, err := m.Classify(ctx, prompts)
|
||||
dur := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatalf("Classify: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != len(prompts) {
|
||||
t.Fatalf("Classify returned %d results, want %d", len(results), len(prompts))
|
||||
}
|
||||
|
||||
for i, r := range results {
|
||||
if r.Token.ID == 0 {
|
||||
t.Errorf("prompt %d: got pad token (id=0)", i)
|
||||
}
|
||||
t.Logf("prompt %d %q → token %d %q", i, prompts[i], r.Token.ID, r.Token.Text)
|
||||
}
|
||||
t.Logf("Classified %d prompts in %s (%.1f prompts/s)", len(prompts), dur, float64(len(prompts))/dur.Seconds())
|
||||
}
|
||||
|
||||
// TestClassify_WithLogits validates that logits are returned when requested.
|
||||
func TestClassify_WithLogits(t *testing.T) {
|
||||
modelPath := gemma3ModelPath(t)
|
||||
|
||||
m, err := inference.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel: %v", err)
|
||||
}
|
||||
defer func() { m.Close(); mlx.ClearCache() }()
|
||||
|
||||
ctx := context.Background()
|
||||
results, err := m.Classify(ctx, []string{"Hello world"}, inference.WithLogits())
|
||||
if err != nil {
|
||||
t.Fatalf("Classify: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("got %d results, want 1", len(results))
|
||||
}
|
||||
if len(results[0].Logits) == 0 {
|
||||
t.Fatal("expected non-empty logits with WithLogits()")
|
||||
}
|
||||
t.Logf("Logits length: %d (vocab size)", len(results[0].Logits))
|
||||
}
|
||||
|
||||
// TestBatchGenerate validates batched autoregressive generation.
|
||||
func TestBatchGenerate(t *testing.T) {
|
||||
modelPath := gemma3ModelPath(t)
|
||||
|
||||
m, err := inference.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel: %v", err)
|
||||
}
|
||||
defer func() { m.Close(); mlx.ClearCache() }()
|
||||
|
||||
ctx := context.Background()
|
||||
prompts := []string{
|
||||
"The capital of France is",
|
||||
"2 + 2 =",
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
results, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(16))
|
||||
dur := time.Since(start)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchGenerate: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != len(prompts) {
|
||||
t.Fatalf("BatchGenerate returned %d results, want %d", len(results), len(prompts))
|
||||
}
|
||||
|
||||
for i, r := range results {
|
||||
if r.Err != nil {
|
||||
t.Errorf("prompt %d error: %v", i, r.Err)
|
||||
continue
|
||||
}
|
||||
if len(r.Tokens) == 0 {
|
||||
t.Errorf("prompt %d: no tokens generated", i)
|
||||
continue
|
||||
}
|
||||
var output strings.Builder
|
||||
for _, tok := range r.Tokens {
|
||||
output.WriteString(tok.Text)
|
||||
}
|
||||
t.Logf("prompt %d %q → %d tokens: %s", i, prompts[i], len(r.Tokens), output.String())
|
||||
}
|
||||
t.Logf("Batch generated in %s", dur)
|
||||
}
|
||||
|
||||
// TestLlama_Chat validates chat template for Llama 3 models.
|
||||
func TestLlama_Chat(t *testing.T) {
|
||||
modelPath := llamaModelPath(t)
|
||||
|
|
|
|||
|
|
@ -119,6 +119,51 @@ func (a *metalAdapter) Chat(ctx context.Context, messages []inference.Message, o
|
|||
}
|
||||
}
|
||||
|
||||
func (a *metalAdapter) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
cfg := inference.ApplyGenerateOpts(opts)
|
||||
mcfg := metal.GenerateConfig{
|
||||
Temperature: cfg.Temperature,
|
||||
TopK: cfg.TopK,
|
||||
}
|
||||
results, err := a.m.Classify(ctx, prompts, mcfg, cfg.ReturnLogits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]inference.ClassifyResult, len(results))
|
||||
for i, r := range results {
|
||||
out[i] = inference.ClassifyResult{
|
||||
Token: inference.Token{ID: r.Token.ID, Text: r.Token.Text},
|
||||
Logits: r.Logits,
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *metalAdapter) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
|
||||
cfg := inference.ApplyGenerateOpts(opts)
|
||||
mcfg := metal.GenerateConfig{
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Temperature: cfg.Temperature,
|
||||
TopK: cfg.TopK,
|
||||
TopP: cfg.TopP,
|
||||
StopTokens: cfg.StopTokens,
|
||||
RepeatPenalty: cfg.RepeatPenalty,
|
||||
}
|
||||
results, err := a.m.BatchGenerate(ctx, prompts, mcfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]inference.BatchResult, len(results))
|
||||
for i, r := range results {
|
||||
tokens := make([]inference.Token, len(r.Tokens))
|
||||
for j, t := range r.Tokens {
|
||||
tokens[j] = inference.Token{ID: t.ID, Text: t.Text}
|
||||
}
|
||||
out[i] = inference.BatchResult{Tokens: tokens, Err: r.Err}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *metalAdapter) ModelType() string { return a.m.ModelType() }
|
||||
func (a *metalAdapter) Err() error { return a.m.Err() }
|
||||
func (a *metalAdapter) Close() error { return a.m.Close() }
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue