From f39126f6bd31525fd75dac6f7604f865fd01a7b5 Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 20:39:44 +0000 Subject: [PATCH] feat(metal): bind CumSum, implement TopP and MinP sampling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New ops: CumSum, Sort, Argsort, Greater, MaxAxis — all bound to mlx-c. TopP (nucleus) sampling now fully implemented: sorts probabilities descending, computes cumulative sum, masks tokens beyond the threshold, and scatters the mask back to original positions via argsort. MinP sampling now fully implemented: computes softmax, finds max probability, masks tokens below min_p * max_prob. Both were previously stubs that passed through logits unchanged. 10 new tests (CumSum variants, Sort, Argsort, Greater, MaxAxis, TopP, MinP). 176 total tests passing. Co-Authored-By: Virgil --- internal/metal/ops.go | 36 ++++++++++++++++ internal/metal/ops_test.go | 81 +++++++++++++++++++++++++++++++++++ internal/metal/sample.go | 61 ++++++++++++++++++++++---- internal/metal/sample_test.go | 50 +++++++++++++++++---- 4 files changed, 212 insertions(+), 16 deletions(-) diff --git a/internal/metal/ops.go b/internal/metal/ops.go index e027985..2bf554c 100644 --- a/internal/metal/ops.go +++ b/internal/metal/ops.go @@ -351,3 +351,39 @@ func LogSumExp(a *Array, axis int, keepDims bool) *Array { C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) return out } + +// CumSum returns the cumulative sum along the given axis. +// reverse=false for forward, inclusive=true to include the current element. +func CumSum(a *Array, axis int, reverse, inclusive bool) *Array { + out := New("CUMSUM", a) + C.mlx_cumsum(&out.ctx, a.ctx, C.int(axis), C._Bool(reverse), C._Bool(inclusive), DefaultStream().ctx) + return out +} + +// Sort returns the array sorted along the given axis. +func Sort(a *Array, axis int) *Array { + out := New("SORT", a) + C.mlx_sort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +// Argsort returns the indices that would sort the array along the given axis. +func Argsort(a *Array, axis int) *Array { + out := New("ARGSORT", a) + C.mlx_argsort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +// Greater returns element-wise a > b as a bool array. +func Greater(a, b *Array) *Array { + out := New("GREATER", a, b) + C.mlx_greater(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// MaxAxis returns the maximum value along the given axis. +func MaxAxis(a *Array, axis int, keepDims bool) *Array { + out := New("MAX_AXIS", a) + C.mlx_max_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) + return out +} diff --git a/internal/metal/ops_test.go b/internal/metal/ops_test.go index 85a154d..e7eade1 100644 --- a/internal/metal/ops_test.go +++ b/internal/metal/ops_test.go @@ -515,6 +515,87 @@ func TestAdd_Broadcasting(t *testing.T) { // --- Random --- +// --- Cumulative and sorting ops --- + +func TestCumSum(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4}, 1, 4) + c := CumSum(a, -1, false, true) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1, 3, 6, 10}) +} + +func TestCumSum_Exclusive(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4}, 1, 4) + c := CumSum(a, -1, false, false) // exclusive + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{0, 1, 3, 6}) +} + +func TestCumSum_Reverse(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4}, 1, 4) + c := CumSum(a, -1, true, true) // reverse + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{10, 9, 7, 4}) +} + +func TestSort(t *testing.T) { + a := FromValues([]float32{3, 1, 4, 1, 5}, 1, 5) + c := Sort(a, -1) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1, 1, 3, 4, 5}) +} + +func TestArgsort(t *testing.T) { + a := FromValues([]float32{3, 1, 4, 1, 5}, 1, 5) + c := Argsort(a, -1) + Materialize(c) + // indices of sorted order: [1, 3, 0, 2, 4] + got := c.Ints() + want := []int{1, 3, 0, 2, 4} + for i := range got { + if got[i] != want[i] { + t.Errorf("Argsort[%d] = %d, want %d", i, got[i], want[i]) + } + } +} + +func TestGreater(t *testing.T) { + a := FromValues([]float32{1, 5, 3}, 3) + b := FromValues([]float32{2, 2, 3}, 3) + c := Greater(a, b) + // Greater returns bool dtype — cast to int32 for data extraction + c = AsType(c, DTypeInt32) + Materialize(c) + // 1>2=false, 5>2=true, 3>3=false + got := c.DataInt32() + want := []int32{0, 1, 0} + for i := range got { + if got[i] != want[i] { + t.Errorf("Greater[%d] = %d, want %d", i, got[i], want[i]) + } + } +} + +func TestMaxAxis(t *testing.T) { + a := FromValues([]float32{1, 5, 3, 4, 2, 6}, 2, 3) + c := MaxAxis(a, -1, false) // max per row + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{5, 6}) +} + +func TestMaxAxis_KeepDims(t *testing.T) { + a := FromValues([]float32{1, 5, 3, 4, 2, 6}, 2, 3) + c := MaxAxis(a, -1, true) + Materialize(c) + + shape := c.Shape() + if shape[0] != 2 || shape[1] != 1 { + t.Errorf("shape = %v, want [2 1]", shape) + } +} + +// --- Random --- + func TestRandomCategorical(t *testing.T) { // Heavily weighted towards index 2 logprobs := FromValues([]float32{-100, -100, 0}, 1, 3) diff --git a/internal/metal/sample.go b/internal/metal/sample.go index 7ef7f97..1d29160 100644 --- a/internal/metal/sample.go +++ b/internal/metal/sample.go @@ -68,20 +68,65 @@ func (k TopKSampler) Sample(logits *Array) *Array { return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1) } -// TopP implements nucleus sampling (cumulative probability threshold). +// TopP implements nucleus (top-p) sampling. +// Keeps the smallest set of tokens whose cumulative probability exceeds p. type TopP float32 func (p TopP) Sample(logits *Array) *Array { - // TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly. - // For now, pass through. TopK + Temperature covers most use cases. - return logits + // Convert logits to probabilities + probs := Softmax(logits) + + // Sort descending via argsort of negated probs + neg := Negative(probs) + sortIdx := Argsort(neg, -1) + sortedProbs := TakeAlongAxis(probs, sortIdx, -1) + + // Cumulative sum of sorted probabilities + cumProbs := CumSum(sortedProbs, -1, false, true) + + // Mask in sorted space: keep tokens where cumprob (excluding current) <= threshold + shiftedCum := Subtract(cumProbs, sortedProbs) + threshold := FromValue(float32(p)) + sortedMask := Where( + Greater(shiftedCum, threshold), + FromValue(float32(math.Inf(-1))), + FromValue(float32(0)), + ) + + // Scatter mask back to original positions + mask := PutAlongAxis( + Zeros(logits.Shape(), DTypeFloat32), + sortIdx, + sortedMask, + -1, + ) + + // Apply mask: -inf where excluded, original logit where kept + return Where( + Greater(FromValue(float32(0)), mask), + FromValue(float32(math.Inf(-1))), + logits, + ) } -// MinPSampler masks tokens below min_p * max_prob. +// MinPSampler masks tokens whose probability is below min_p * max_prob. type MinPSampler float32 func (p MinPSampler) Sample(logits *Array) *Array { - // For now, pass through — MinP is an optimization over TopP. - // Full implementation requires finding max prob and masking below threshold. - return logits + // Convert logits to probabilities + probs := Softmax(logits) + + // Find the maximum probability + maxProb := MaxAxis(probs, -1, true) + + // Threshold = min_p * max_prob + threshold := MulScalar(maxProb, float32(p)) + + // Mask tokens below threshold + mask := Where( + Greater(threshold, probs), + FromValue(float32(math.Inf(-1))), + logits, + ) + return mask } diff --git a/internal/metal/sample_test.go b/internal/metal/sample_test.go index 9cead08..2e98752 100644 --- a/internal/metal/sample_test.go +++ b/internal/metal/sample_test.go @@ -89,26 +89,60 @@ func TestNew_Chain(t *testing.T) { } } -func TestTopP_PassThrough(t *testing.T) { - // TopP is currently a stub — verify it doesn't break the chain +func TestTopP_DominantLogit(t *testing.T) { + // With one dominant logit, TopP should always pick it logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4) - s := newSampler(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5 + s := newSampler(0.5, 0.9, 0, 0) // topP=0.9, temp=0.5 token := s.Sample(logits) Materialize(token) if token.Int() != 2 { - t.Errorf("topP stub + temp sample = %d, want 2", token.Int()) + t.Errorf("topP dominant sample = %d, want 2", token.Int()) } } -func TestMinP_PassThrough(t *testing.T) { - // MinP is currently a stub — verify it doesn't break the chain +func TestTopP_RestrictsOptions(t *testing.T) { + // Two equal high logits, two low. TopP=0.5 should mostly restrict to top tokens. + logits := FromValues([]float32{10, 10, -100, -100}, 1, 4) + s := newSampler(1.0, 0.5, 0, 0) // topP=0.5, temp=1.0 + + seen := map[int]bool{} + for range 30 { + token := s.Sample(logits) + Materialize(token) + seen[token.Int()] = true + } + + // Should only pick indices 0 or 1 (the two high-probability tokens) + for idx := range seen { + if idx != 0 && idx != 1 { + t.Errorf("topP=0.5 sampled index %d, expected only 0 or 1", idx) + } + } +} + +func TestMinP_DominantLogit(t *testing.T) { + // With one dominant logit, MinP should always pick it logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4) - s := newSampler(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5 + s := newSampler(0.5, 0, 0.1, 0) // minP=0.1, temp=0.5 token := s.Sample(logits) Materialize(token) if token.Int() != 2 { - t.Errorf("minP stub + temp sample = %d, want 2", token.Int()) + t.Errorf("minP dominant sample = %d, want 2", token.Int()) + } +} + +func TestMinP_RestrictsOptions(t *testing.T) { + // One very high logit, rest are low. MinP=0.1 should mask the low tokens. + logits := FromValues([]float32{-100, 50, -100, -100}, 1, 4) + s := newSampler(1.0, 0, 0.1, 0) // minP=0.1, temp=1.0 + + for range 20 { + token := s.Sample(logits) + Materialize(token) + if token.Int() != 1 { + t.Errorf("minP with dominant logit sampled %d, want 1", token.Int()) + } } }