feat(metal): bind CumSum, implement TopP and MinP sampling
New ops: CumSum, Sort, Argsort, Greater, MaxAxis — all bound to mlx-c. TopP (nucleus) sampling now fully implemented: sorts probabilities descending, computes cumulative sum, masks tokens beyond the threshold, and scatters the mask back to original positions via argsort. MinP sampling now fully implemented: computes softmax, finds max probability, masks tokens below min_p * max_prob. Both were previously stubs that passed through logits unchanged. 10 new tests (CumSum variants, Sort, Argsort, Greater, MaxAxis, TopP, MinP). 176 total tests passing. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
df0b300b1a
commit
f39126f6bd
4 changed files with 212 additions and 16 deletions
|
|
@ -351,3 +351,39 @@ func LogSumExp(a *Array, axis int, keepDims bool) *Array {
|
|||
C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// CumSum returns the cumulative sum along the given axis.
|
||||
// reverse=false for forward, inclusive=true to include the current element.
|
||||
func CumSum(a *Array, axis int, reverse, inclusive bool) *Array {
|
||||
out := New("CUMSUM", a)
|
||||
C.mlx_cumsum(&out.ctx, a.ctx, C.int(axis), C._Bool(reverse), C._Bool(inclusive), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Sort returns the array sorted along the given axis.
|
||||
func Sort(a *Array, axis int) *Array {
|
||||
out := New("SORT", a)
|
||||
C.mlx_sort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Argsort returns the indices that would sort the array along the given axis.
|
||||
func Argsort(a *Array, axis int) *Array {
|
||||
out := New("ARGSORT", a)
|
||||
C.mlx_argsort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Greater returns element-wise a > b as a bool array.
|
||||
func Greater(a, b *Array) *Array {
|
||||
out := New("GREATER", a, b)
|
||||
C.mlx_greater(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// MaxAxis returns the maximum value along the given axis.
|
||||
func MaxAxis(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("MAX_AXIS", a)
|
||||
C.mlx_max_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
|
|
|||
|
|
@ -515,6 +515,87 @@ func TestAdd_Broadcasting(t *testing.T) {
|
|||
|
||||
// --- Random ---
|
||||
|
||||
// --- Cumulative and sorting ops ---
|
||||
|
||||
func TestCumSum(t *testing.T) {
|
||||
a := FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
||||
c := CumSum(a, -1, false, true)
|
||||
Materialize(c)
|
||||
floatSliceApprox(t, c.Floats(), []float32{1, 3, 6, 10})
|
||||
}
|
||||
|
||||
func TestCumSum_Exclusive(t *testing.T) {
|
||||
a := FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
||||
c := CumSum(a, -1, false, false) // exclusive
|
||||
Materialize(c)
|
||||
floatSliceApprox(t, c.Floats(), []float32{0, 1, 3, 6})
|
||||
}
|
||||
|
||||
func TestCumSum_Reverse(t *testing.T) {
|
||||
a := FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
||||
c := CumSum(a, -1, true, true) // reverse
|
||||
Materialize(c)
|
||||
floatSliceApprox(t, c.Floats(), []float32{10, 9, 7, 4})
|
||||
}
|
||||
|
||||
func TestSort(t *testing.T) {
|
||||
a := FromValues([]float32{3, 1, 4, 1, 5}, 1, 5)
|
||||
c := Sort(a, -1)
|
||||
Materialize(c)
|
||||
floatSliceApprox(t, c.Floats(), []float32{1, 1, 3, 4, 5})
|
||||
}
|
||||
|
||||
func TestArgsort(t *testing.T) {
|
||||
a := FromValues([]float32{3, 1, 4, 1, 5}, 1, 5)
|
||||
c := Argsort(a, -1)
|
||||
Materialize(c)
|
||||
// indices of sorted order: [1, 3, 0, 2, 4]
|
||||
got := c.Ints()
|
||||
want := []int{1, 3, 0, 2, 4}
|
||||
for i := range got {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("Argsort[%d] = %d, want %d", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGreater(t *testing.T) {
|
||||
a := FromValues([]float32{1, 5, 3}, 3)
|
||||
b := FromValues([]float32{2, 2, 3}, 3)
|
||||
c := Greater(a, b)
|
||||
// Greater returns bool dtype — cast to int32 for data extraction
|
||||
c = AsType(c, DTypeInt32)
|
||||
Materialize(c)
|
||||
// 1>2=false, 5>2=true, 3>3=false
|
||||
got := c.DataInt32()
|
||||
want := []int32{0, 1, 0}
|
||||
for i := range got {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("Greater[%d] = %d, want %d", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxAxis(t *testing.T) {
|
||||
a := FromValues([]float32{1, 5, 3, 4, 2, 6}, 2, 3)
|
||||
c := MaxAxis(a, -1, false) // max per row
|
||||
Materialize(c)
|
||||
floatSliceApprox(t, c.Floats(), []float32{5, 6})
|
||||
}
|
||||
|
||||
func TestMaxAxis_KeepDims(t *testing.T) {
|
||||
a := FromValues([]float32{1, 5, 3, 4, 2, 6}, 2, 3)
|
||||
c := MaxAxis(a, -1, true)
|
||||
Materialize(c)
|
||||
|
||||
shape := c.Shape()
|
||||
if shape[0] != 2 || shape[1] != 1 {
|
||||
t.Errorf("shape = %v, want [2 1]", shape)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Random ---
|
||||
|
||||
func TestRandomCategorical(t *testing.T) {
|
||||
// Heavily weighted towards index 2
|
||||
logprobs := FromValues([]float32{-100, -100, 0}, 1, 3)
|
||||
|
|
|
|||
|
|
@ -68,20 +68,65 @@ func (k TopKSampler) Sample(logits *Array) *Array {
|
|||
return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1)
|
||||
}
|
||||
|
||||
// TopP implements nucleus (top-p) sampling.
// Keeps the smallest set of tokens whose cumulative probability exceeds p.
type TopP float32
|
||||
func (p TopP) Sample(logits *Array) *Array {
|
||||
// TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly.
|
||||
// For now, pass through. TopK + Temperature covers most use cases.
|
||||
return logits
|
||||
// Convert logits to probabilities
|
||||
probs := Softmax(logits)
|
||||
|
||||
// Sort descending via argsort of negated probs
|
||||
neg := Negative(probs)
|
||||
sortIdx := Argsort(neg, -1)
|
||||
sortedProbs := TakeAlongAxis(probs, sortIdx, -1)
|
||||
|
||||
// Cumulative sum of sorted probabilities
|
||||
cumProbs := CumSum(sortedProbs, -1, false, true)
|
||||
|
||||
// Mask in sorted space: keep tokens where cumprob (excluding current) <= threshold
|
||||
shiftedCum := Subtract(cumProbs, sortedProbs)
|
||||
threshold := FromValue(float32(p))
|
||||
sortedMask := Where(
|
||||
Greater(shiftedCum, threshold),
|
||||
FromValue(float32(math.Inf(-1))),
|
||||
FromValue(float32(0)),
|
||||
)
|
||||
|
||||
// Scatter mask back to original positions
|
||||
mask := PutAlongAxis(
|
||||
Zeros(logits.Shape(), DTypeFloat32),
|
||||
sortIdx,
|
||||
sortedMask,
|
||||
-1,
|
||||
)
|
||||
|
||||
// Apply mask: -inf where excluded, original logit where kept
|
||||
return Where(
|
||||
Greater(FromValue(float32(0)), mask),
|
||||
FromValue(float32(math.Inf(-1))),
|
||||
logits,
|
||||
)
|
||||
}
|
||||
|
||||
// MinPSampler masks tokens whose probability is below min_p * max_prob.
type MinPSampler float32
|
||||
|
||||
func (p MinPSampler) Sample(logits *Array) *Array {
|
||||
// For now, pass through — MinP is an optimization over TopP.
|
||||
// Full implementation requires finding max prob and masking below threshold.
|
||||
return logits
|
||||
// Convert logits to probabilities
|
||||
probs := Softmax(logits)
|
||||
|
||||
// Find the maximum probability
|
||||
maxProb := MaxAxis(probs, -1, true)
|
||||
|
||||
// Threshold = min_p * max_prob
|
||||
threshold := MulScalar(maxProb, float32(p))
|
||||
|
||||
// Mask tokens below threshold
|
||||
mask := Where(
|
||||
Greater(threshold, probs),
|
||||
FromValue(float32(math.Inf(-1))),
|
||||
logits,
|
||||
)
|
||||
return mask
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,26 +89,60 @@ func TestNew_Chain(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTopP_PassThrough(t *testing.T) {
|
||||
// TopP is currently a stub — verify it doesn't break the chain
|
||||
func TestTopP_DominantLogit(t *testing.T) {
|
||||
// With one dominant logit, TopP should always pick it
|
||||
logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := newSampler(0.5, 0.9, 0, 0) // topP=0.9 (stub), temp=0.5
|
||||
s := newSampler(0.5, 0.9, 0, 0) // topP=0.9, temp=0.5
|
||||
token := s.Sample(logits)
|
||||
Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("topP stub + temp sample = %d, want 2", token.Int())
|
||||
t.Errorf("topP dominant sample = %d, want 2", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP_PassThrough(t *testing.T) {
|
||||
// MinP is currently a stub — verify it doesn't break the chain
|
||||
func TestTopP_RestrictsOptions(t *testing.T) {
|
||||
// Two equal high logits, two low. TopP=0.5 should mostly restrict to top tokens.
|
||||
logits := FromValues([]float32{10, 10, -100, -100}, 1, 4)
|
||||
s := newSampler(1.0, 0.5, 0, 0) // topP=0.5, temp=1.0
|
||||
|
||||
seen := map[int]bool{}
|
||||
for range 30 {
|
||||
token := s.Sample(logits)
|
||||
Materialize(token)
|
||||
seen[token.Int()] = true
|
||||
}
|
||||
|
||||
// Should only pick indices 0 or 1 (the two high-probability tokens)
|
||||
for idx := range seen {
|
||||
if idx != 0 && idx != 1 {
|
||||
t.Errorf("topP=0.5 sampled index %d, expected only 0 or 1", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP_DominantLogit(t *testing.T) {
|
||||
// With one dominant logit, MinP should always pick it
|
||||
logits := FromValues([]float32{-10, -10, 100, -10}, 1, 4)
|
||||
s := newSampler(0.5, 0, 0.1, 0) // minP=0.1 (stub), temp=0.5
|
||||
s := newSampler(0.5, 0, 0.1, 0) // minP=0.1, temp=0.5
|
||||
token := s.Sample(logits)
|
||||
Materialize(token)
|
||||
|
||||
if token.Int() != 2 {
|
||||
t.Errorf("minP stub + temp sample = %d, want 2", token.Int())
|
||||
t.Errorf("minP dominant sample = %d, want 2", token.Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP_RestrictsOptions(t *testing.T) {
|
||||
// One very high logit, rest are low. MinP=0.1 should mask the low tokens.
|
||||
logits := FromValues([]float32{-100, 50, -100, -100}, 1, 4)
|
||||
s := newSampler(1.0, 0, 0.1, 0) // minP=0.1, temp=1.0
|
||||
|
||||
for range 20 {
|
||||
token := s.Sample(logits)
|
||||
Materialize(token)
|
||||
if token.Int() != 1 {
|
||||
t.Errorf("minP with dominant logit sampled %d, want 1", token.Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue