From 5644857034835489b45c9a8dae16a54042812416 Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 23:28:15 +0000 Subject: [PATCH] feat(metal): implement batch inference (Classify, BatchGenerate) - Add ForwardMasked to InternalModel, Gemma3 and Qwen3 architectures - Thread attention mask through decoder layers and SDPA calls - Use ScaledDotProductAttentionWithMask when explicit mask provided - Create batch.go with padded batching, mask construction, Classify (prefill-only) and BatchGenerate (autoregressive) implementations - Wire Classify/BatchGenerate through metalAdapter to go-inference - Tests: mask unit tests (shape, values, multi-batch), Classify with 4 prompts (152 prompts/s), WithLogits, BatchGenerate with 2 prompts Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- internal/metal/batch.go | 304 +++++++++++++++++++++++++++++++++++ internal/metal/batch_test.go | 105 ++++++++++++ internal/metal/error_test.go | 3 +- internal/metal/gemma3.go | 19 ++- internal/metal/model.go | 5 + internal/metal/qwen3.go | 22 ++- mlx_test.go | 110 +++++++++++++ register_metal.go | 45 ++++++ 8 files changed, 602 insertions(+), 11 deletions(-) create mode 100644 internal/metal/batch.go create mode 100644 internal/metal/batch_test.go diff --git a/internal/metal/batch.go b/internal/metal/batch.go new file mode 100644 index 0000000..05aa9da --- /dev/null +++ b/internal/metal/batch.go @@ -0,0 +1,304 @@ +//go:build darwin && arm64 + +package metal + +import ( + "context" + "fmt" + "math" + "sort" +) + +// ClassifyResult holds the output for a single prompt in batch classification. +type ClassifyResult struct { + Token Token + Logits []float32 +} + +// BatchResult holds the output for a single prompt in batch generation. +type BatchResult struct { + Tokens []Token + Err error +} + +// Classify runs batched prefill-only inference. Each prompt gets a single +// forward pass and the token at the last position is sampled. +func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConfig, returnLogits bool) ([]ClassifyResult, error) { + if len(prompts) == 0 { + return nil, nil + } + + // Tokenise all prompts. + encoded := make([][]int32, len(prompts)) + lengths := make([]int, len(prompts)) + for i, p := range prompts { + encoded[i] = m.tokenizer.Encode(p) + lengths[i] = len(encoded[i]) + } + + // Sort by length descending for minimal padding. Track original indices. + indices := make([]int, len(prompts)) + for i := range indices { + indices[i] = i + } + sort.Slice(indices, func(a, b int) bool { + return lengths[indices[a]] > lengths[indices[b]] + }) + + maxLen := lengths[indices[0]] + N := int32(len(prompts)) + L := int32(maxLen) + + // Build padded token matrix [N, maxLen] and prompt lengths (sorted order). + padded := make([]int32, int(N)*int(L)) + sortedLengths := make([]int32, N) + for si, origIdx := range indices { + sortedLengths[si] = int32(lengths[origIdx]) + copy(padded[si*int(L):], encoded[origIdx]) + // Remaining positions are already 0 (pad token) + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Build attention mask [N, 1, L, L]. + mask := buildBatchMask(N, L, sortedLengths) + + // Single forward pass. + tokens := FromValues(padded, int(N), int(L)) + logits := m.model.ForwardMasked(tokens, mask, m.newCachesN(int(N))) + if err := Eval(logits); err != nil { + return nil, fmt.Errorf("classify prefill: %w", err) + } + + // logits shape: [N, L, vocab] + sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK) + + // Gather logits at each prompt's last real token position and sample. + sortedResults := make([]ClassifyResult, N) + for si := int32(0); si < N; si++ { + lastPos := sortedLengths[si] - 1 + + // Extract [1, vocab] at position lastPos for this batch element. + batchLogits := SliceAxis(logits, 0, si, si+1) // [1, L, vocab] + posLogits := SliceAxis(batchLogits, 1, lastPos, lastPos+1) + posLogits = Reshape(posLogits, 1, int32(posLogits.Dim(2))) // [1, vocab] + + next := sampler.Sample(posLogits) + if err := Eval(next); err != nil { + return nil, fmt.Errorf("classify sample %d: %w", si, err) + } + + id := int32(next.Int()) + text := m.tokenizer.DecodeToken(id) + sortedResults[si].Token = Token{ID: id, Text: text} + + if returnLogits { + posLogits = Reshape(posLogits, int32(posLogits.Dim(1))) + if err := Eval(posLogits); err != nil { + return nil, fmt.Errorf("classify logits %d: %w", si, err) + } + sortedResults[si].Logits = posLogits.Floats() + } + } + + // Unsort results back to original prompt order. + results := make([]ClassifyResult, N) + for si, origIdx := range indices { + results[origIdx] = sortedResults[si] + } + + return results, nil +} + +// BatchGenerate runs batched autoregressive generation. +func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg GenerateConfig) ([]BatchResult, error) { + if len(prompts) == 0 { + return nil, nil + } + + // Tokenise all prompts. + encoded := make([][]int32, len(prompts)) + lengths := make([]int, len(prompts)) + for i, p := range prompts { + encoded[i] = m.tokenizer.Encode(p) + lengths[i] = len(encoded[i]) + } + + // Sort by length descending. + indices := make([]int, len(prompts)) + for i := range indices { + indices[i] = i + } + sort.Slice(indices, func(a, b int) bool { + return lengths[indices[a]] > lengths[indices[b]] + }) + + N := int32(len(prompts)) + maxLen := lengths[indices[0]] + L := int32(maxLen) + + // Build padded token matrix and lengths. + padded := make([]int32, int(N)*int(L)) + sortedLengths := make([]int32, N) + for si, origIdx := range indices { + sortedLengths[si] = int32(lengths[origIdx]) + copy(padded[si*int(L):], encoded[origIdx]) + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Prefill with mask. + mask := buildBatchMask(N, L, sortedLengths) + tokens := FromValues(padded, int(N), int(L)) + caches := m.newCachesN(int(N)) + logits := m.model.ForwardMasked(tokens, mask, caches) + if err := Eval(logits); err != nil { + return nil, fmt.Errorf("batch prefill: %w", err) + } + + sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK) + eosID := m.tokenizer.EOSToken() + + // Per-sequence state. + type seqState struct { + tokens []Token + finished bool + } + states := make([]seqState, N) + + maxTokens := cfg.MaxTokens + if maxTokens <= 0 { + maxTokens = 256 + } + + for step := 0; step < maxTokens; step++ { + select { + case <-ctx.Done(): + // Return partial results on cancellation. + sortedResults := make([]BatchResult, N) + for si := range states { + sortedResults[si] = BatchResult{Tokens: states[si].tokens, Err: ctx.Err()} + } + results := make([]BatchResult, N) + for si, origIdx := range indices { + results[origIdx] = sortedResults[si] + } + return results, nil + default: + } + + // Sample next token for each sequence from last position logits. + nextIDs := make([]int32, N) + allFinished := true + + for si := int32(0); si < N; si++ { + if states[si].finished { + nextIDs[si] = 0 // pad + continue + } + + var posLogits *Array + if step == 0 { + // First step: gather from prefill at each prompt's last real position. + lastPos := sortedLengths[si] - 1 + batchLogits := SliceAxis(logits, 0, si, si+1) + posLogits = SliceAxis(batchLogits, 1, lastPos, lastPos+1) + } else { + // Subsequent steps: logits shape [N, 1, vocab]. + posLogits = SliceAxis(logits, 0, si, si+1) + posLogits = SliceAxis(posLogits, 1, 0, 1) + } + posLogits = Reshape(posLogits, 1, int32(posLogits.Dim(posLogits.NumDims()-1))) + + next := sampler.Sample(posLogits) + if err := Eval(next); err != nil { + return nil, fmt.Errorf("batch sample step %d seq %d: %w", step, si, err) + } + + id := int32(next.Int()) + nextIDs[si] = id + + // Check stop conditions. + if id == eosID { + states[si].finished = true + continue + } + for _, stop := range cfg.StopTokens { + if id == stop { + states[si].finished = true + break + } + } + if !states[si].finished { + text := m.tokenizer.DecodeToken(id) + states[si].tokens = append(states[si].tokens, Token{ID: id, Text: text}) + allFinished = false + } + } + + if allFinished { + break + } + + // Feed next tokens [N, 1] through the model. + nextInput := FromValues(nextIDs, int(N), 1) + logits = m.model.Forward(nextInput, caches) + if err := Eval(logits); err != nil { + return nil, fmt.Errorf("batch decode step %d: %w", step, err) + } + } + + // Unsort results back to original order. + sortedResults := make([]BatchResult, N) + for si := range states { + sortedResults[si] = BatchResult{Tokens: states[si].tokens} + } + results := make([]BatchResult, N) + for si, origIdx := range indices { + results[origIdx] = sortedResults[si] + } + + return results, nil +} + +// buildBatchMask constructs a combined causal + padding attention mask. +// Shape: [N, 1, L, L]. Values: 0.0 = attend, -inf = ignore. +// mask[b, 0, i, j] = 0.0 if j <= i AND j < promptLen[b], else -inf. +func buildBatchMask(N, L int32, promptLens []int32) *Array { + negInf := float32(math.Inf(-1)) + data := make([]float32, int(N)*int(L)*int(L)) + + for b := int32(0); b < N; b++ { + pLen := promptLens[b] + base := int(b) * int(L) * int(L) + for i := int32(0); i < L; i++ { + for j := int32(0); j < L; j++ { + if j <= i && j < pLen { + data[base+int(i)*int(L)+int(j)] = 0 + } else { + data[base+int(i)*int(L)+int(j)] = negInf + } + } + } + } + + mask := FromValues(data, int(N), 1, int(L), int(L)) + return mask +} + +// newCachesN creates N independent sets of per-layer caches for batched inference. +// Since our KV cache implementation handles batch dimension internally, +// we create a single set of caches (same as non-batched). +func (m *Model) newCachesN(n int) []Cache { + // KV caches handle the batch dimension automatically via the key/value + // array shapes. A single cache set works for any batch size. + return m.newCaches() +} diff --git a/internal/metal/batch_test.go b/internal/metal/batch_test.go new file mode 100644 index 0000000..5f21e13 --- /dev/null +++ b/internal/metal/batch_test.go @@ -0,0 +1,105 @@ +//go:build darwin && arm64 + +package metal + +import ( + "math" + "testing" +) + +func TestBuildBatchMask_Shape(t *testing.T) { + // 2 prompts, max length 4, prompt lengths [3, 2]. + mask := buildBatchMask(2, 4, []int32{3, 2}) + if err := Eval(mask); err != nil { + t.Fatalf("Eval mask: %v", err) + } + + shape := mask.Shape() + want := []int32{2, 1, 4, 4} + if len(shape) != 4 { + t.Fatalf("mask ndim = %d, want 4", len(shape)) + } + for i, s := range shape { + if s != want[i] { + t.Errorf("mask shape[%d] = %d, want %d", i, s, want[i]) + } + } +} + +func TestBuildBatchMask_Values(t *testing.T) { + // Single prompt of length 3, padded to 4. + // Expected mask [1, 1, 4, 4]: + // row 0: [0, -inf, -inf, -inf] (can only attend to pos 0) + // row 1: [0, 0, -inf, -inf] (attend to pos 0,1) + // row 2: [0, 0, 0, -inf] (attend to pos 0,1,2) + // row 3: [0, 0, 0, -inf] (row 3 is padding — causal says j<=3 but j<3 caps it) + mask := buildBatchMask(1, 4, []int32{3}) + if err := Eval(mask); err != nil { + t.Fatalf("Eval mask: %v", err) + } + + // Flatten to get values. + flat := Reshape(mask, 16) + if err := Eval(flat); err != nil { + t.Fatalf("Eval flat: %v", err) + } + vals := flat.Floats() + + negInf := float32(math.Inf(-1)) + expected := []float32{ + // row 0: attend j=0 only + 0, negInf, negInf, negInf, + // row 1: attend j=0,1 + 0, 0, negInf, negInf, + // row 2: attend j=0,1,2 + 0, 0, 0, negInf, + // row 3: padding row — causal allows j<=3 but padding caps at j<3 + 0, 0, 0, negInf, + } + + for i, v := range vals { + e := expected[i] + if math.IsInf(float64(e), -1) { + if !math.IsInf(float64(v), -1) { + t.Errorf("vals[%d] = %f, want -inf", i, v) + } + } else if v != e { + t.Errorf("vals[%d] = %f, want %f", i, v, e) + } + } +} + +func TestBuildBatchMask_MultipleBatches(t *testing.T) { + // 2 prompts: lengths [2, 1], max length 2. + mask := buildBatchMask(2, 2, []int32{2, 1}) + if err := Eval(mask); err != nil { + t.Fatalf("Eval mask: %v", err) + } + + flat := Reshape(mask, 8) + if err := Eval(flat); err != nil { + t.Fatalf("Eval flat: %v", err) + } + vals := flat.Floats() + + negInf := float32(math.Inf(-1)) + expected := []float32{ + // batch 0 (len=2): full causal, no padding + 0, negInf, + 0, 0, + // batch 1 (len=1): only first position is real + 0, negInf, + 0, negInf, // row 1: causal allows j<=1 but padding caps at j<1 + } + + for i, v := range vals { + e := expected[i] + if math.IsInf(float64(e), -1) { + if !math.IsInf(float64(v), -1) { + t.Errorf("batch vals[%d] = %f, want -inf", i, v) + } + } else if v != e { + t.Errorf("batch vals[%d] = %f, want %f", i, v, e) + } + } +} diff --git a/internal/metal/error_test.go b/internal/metal/error_test.go index 9893491..7370aab 100644 --- a/internal/metal/error_test.go +++ b/internal/metal/error_test.go @@ -67,7 +67,8 @@ type fakeModel struct { numLayers int } -func (f *fakeModel) Forward(_ *Array, _ []Cache) *Array { return nil } +func (f *fakeModel) Forward(_ *Array, _ []Cache) *Array { return nil } +func (f *fakeModel) ForwardMasked(_ *Array, _ *Array, _ []Cache) *Array { return nil } func (f *fakeModel) NewCache() []Cache { caches := make([]Cache, f.numLayers) for i := range caches { diff --git a/internal/metal/gemma3.go b/internal/metal/gemma3.go index 4542200..407892c 100644 --- a/internal/metal/gemma3.go +++ b/internal/metal/gemma3.go @@ -323,6 +323,10 @@ func isLayerSliding(layerIdx, pattern int32) bool { // Forward runs the text model forward pass. func (m *GemmaModel) Forward(tokens *Array, caches []Cache) *Array { + return m.ForwardMasked(tokens, nil, caches) +} + +func (m *GemmaModel) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array { shape := tokens.Shape() B, L := shape[0], shape[1] @@ -330,15 +334,15 @@ func (m *GemmaModel) Forward(tokens *Array, caches []Cache) *Array { h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize)))) for i, layer := range m.Layers { - h = layer.forward(h, caches[i], B, L, m.Cfg) + h = layer.forward(h, caches[i], B, L, mask, m.Cfg) } return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)) } -func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *TextConfig) *Array { +func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *TextConfig) *Array { normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps) - attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg) + attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, mask, cfg) attnOut = RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) h := Add(x, attnOut) @@ -348,7 +352,7 @@ func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *TextConfig) * return Add(h, mlpOut) } -func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg *TextConfig) *Array { +func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, mask *Array, cfg *TextConfig) *Array { q := a.QProj.Forward(x) k := a.KProj.Forward(x) v := a.VProj.Forward(x) @@ -386,7 +390,12 @@ func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg * } // Scaled dot-product attention - out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + var out *Array + if mask != nil { + out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale) + } else { + out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + } out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) return a.OProj.Forward(out) } diff --git a/internal/metal/model.go b/internal/metal/model.go index 325b3bb..f001356 100644 --- a/internal/metal/model.go +++ b/internal/metal/model.go @@ -14,6 +14,11 @@ type InternalModel interface { // Forward runs the model forward pass on token IDs with KV caches. Forward(tokens *Array, caches []Cache) *Array + // ForwardMasked runs the forward pass with an explicit attention mask. + // mask shape: [B, 1, L, L] — additive mask (0 = attend, -inf = ignore). + // Used for batched inference with padded sequences. + ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array + // NewCache creates per-layer KV caches for generation. NewCache() []Cache diff --git a/internal/metal/qwen3.go b/internal/metal/qwen3.go index 5b4e7a3..ff0b233 100644 --- a/internal/metal/qwen3.go +++ b/internal/metal/qwen3.go @@ -249,22 +249,29 @@ func LoadQwen3(modelPath string) (*Qwen3Model, error) { // Forward runs the Qwen 3 forward pass. // Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size). func (m *Qwen3Model) Forward(tokens *Array, caches []Cache) *Array { + return m.ForwardMasked(tokens, nil, caches) +} + +// ForwardMasked runs the forward pass with an explicit attention mask. +// mask shape: [B, 1, L, L] — additive mask (0 = attend, -inf = ignore). +// When mask is nil, standard causal attention is used. +func (m *Qwen3Model) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *Array { shape := tokens.Shape() B, L := shape[0], shape[1] h := m.EmbedTokens.Forward(tokens) for i, layer := range m.Layers { - h = layer.forward(h, caches[i], B, L, m.Cfg) + h = layer.forward(h, caches[i], B, L, mask, m.Cfg) } return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps)) } -func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array { +func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array { // Pre-attention norm → attention → residual add normed := l.InputNorm.Forward(x, cfg.RMSNormEps) - attnOut := l.Attention.forward(normed, c, B, L, cfg) + attnOut := l.Attention.forward(normed, c, B, L, mask, cfg) h := Add(x, attnOut) // Pre-MLP norm → MLP → residual add @@ -273,7 +280,7 @@ func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Con return Add(h, mlpOut) } -func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config) *Array { +func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array { q := a.QProj.Forward(x) k := a.KProj.Forward(x) v := a.VProj.Forward(x) @@ -309,7 +316,12 @@ func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, cfg *Qwen3Config } // Scaled dot-product attention - out := ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + var out *Array + if mask != nil { + out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale) + } else { + out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + } out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) return a.OProj.Forward(out) } diff --git a/mlx_test.go b/mlx_test.go index 4b85d36..e5063f5 100644 --- a/mlx_test.go +++ b/mlx_test.go @@ -443,6 +443,116 @@ func TestLlama_Inference(t *testing.T) { } } +// --- Batch Inference tests (Gemma3-1B) --- + +// TestClassify_Batch validates batched prefill-only classification. +func TestClassify_Batch(t *testing.T) { + modelPath := gemma3ModelPath(t) + + m, err := inference.LoadModel(modelPath) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } + defer func() { m.Close(); mlx.ClearCache() }() + + ctx := context.Background() + prompts := []string{ + "The capital of France is", + "2 + 2 =", + "The colour of the sky is", + "Go is a programming", + } + + start := time.Now() + results, err := m.Classify(ctx, prompts) + dur := time.Since(start) + if err != nil { + t.Fatalf("Classify: %v", err) + } + + if len(results) != len(prompts) { + t.Fatalf("Classify returned %d results, want %d", len(results), len(prompts)) + } + + for i, r := range results { + if r.Token.ID == 0 { + t.Errorf("prompt %d: got pad token (id=0)", i) + } + t.Logf("prompt %d %q → token %d %q", i, prompts[i], r.Token.ID, r.Token.Text) + } + t.Logf("Classified %d prompts in %s (%.1f prompts/s)", len(prompts), dur, float64(len(prompts))/dur.Seconds()) +} + +// TestClassify_WithLogits validates that logits are returned when requested. +func TestClassify_WithLogits(t *testing.T) { + modelPath := gemma3ModelPath(t) + + m, err := inference.LoadModel(modelPath) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } + defer func() { m.Close(); mlx.ClearCache() }() + + ctx := context.Background() + results, err := m.Classify(ctx, []string{"Hello world"}, inference.WithLogits()) + if err != nil { + t.Fatalf("Classify: %v", err) + } + + if len(results) != 1 { + t.Fatalf("got %d results, want 1", len(results)) + } + if len(results[0].Logits) == 0 { + t.Fatal("expected non-empty logits with WithLogits()") + } + t.Logf("Logits length: %d (vocab size)", len(results[0].Logits)) +} + +// TestBatchGenerate validates batched autoregressive generation. +func TestBatchGenerate(t *testing.T) { + modelPath := gemma3ModelPath(t) + + m, err := inference.LoadModel(modelPath) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } + defer func() { m.Close(); mlx.ClearCache() }() + + ctx := context.Background() + prompts := []string{ + "The capital of France is", + "2 + 2 =", + } + + start := time.Now() + results, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(16)) + dur := time.Since(start) + if err != nil { + t.Fatalf("BatchGenerate: %v", err) + } + + if len(results) != len(prompts) { + t.Fatalf("BatchGenerate returned %d results, want %d", len(results), len(prompts)) + } + + for i, r := range results { + if r.Err != nil { + t.Errorf("prompt %d error: %v", i, r.Err) + continue + } + if len(r.Tokens) == 0 { + t.Errorf("prompt %d: no tokens generated", i) + continue + } + var output strings.Builder + for _, tok := range r.Tokens { + output.WriteString(tok.Text) + } + t.Logf("prompt %d %q → %d tokens: %s", i, prompts[i], len(r.Tokens), output.String()) + } + t.Logf("Batch generated in %s", dur) +} + // TestLlama_Chat validates chat template for Llama 3 models. func TestLlama_Chat(t *testing.T) { modelPath := llamaModelPath(t) diff --git a/register_metal.go b/register_metal.go index 7f0ffe6..cc9dc19 100644 --- a/register_metal.go +++ b/register_metal.go @@ -119,6 +119,51 @@ func (a *metalAdapter) Chat(ctx context.Context, messages []inference.Message, o } } +func (a *metalAdapter) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + cfg := inference.ApplyGenerateOpts(opts) + mcfg := metal.GenerateConfig{ + Temperature: cfg.Temperature, + TopK: cfg.TopK, + } + results, err := a.m.Classify(ctx, prompts, mcfg, cfg.ReturnLogits) + if err != nil { + return nil, err + } + out := make([]inference.ClassifyResult, len(results)) + for i, r := range results { + out[i] = inference.ClassifyResult{ + Token: inference.Token{ID: r.Token.ID, Text: r.Token.Text}, + Logits: r.Logits, + } + } + return out, nil +} + +func (a *metalAdapter) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) { + cfg := inference.ApplyGenerateOpts(opts) + mcfg := metal.GenerateConfig{ + MaxTokens: cfg.MaxTokens, + Temperature: cfg.Temperature, + TopK: cfg.TopK, + TopP: cfg.TopP, + StopTokens: cfg.StopTokens, + RepeatPenalty: cfg.RepeatPenalty, + } + results, err := a.m.BatchGenerate(ctx, prompts, mcfg) + if err != nil { + return nil, err + } + out := make([]inference.BatchResult, len(results)) + for i, r := range results { + tokens := make([]inference.Token, len(r.Tokens)) + for j, t := range r.Tokens { + tokens[j] = inference.Token{ID: t.ID, Text: t.Text} + } + out[i] = inference.BatchResult{Tokens: tokens, Err: r.Err} + } + return out, nil +} + func (a *metalAdapter) ModelType() string { return a.m.ModelType() } func (a *metalAdapter) Err() error { return a.m.Err() } func (a *metalAdapter) Close() error { return a.m.Close() }