fix(metal): address 4 minor code review items

- Rename New() → newArray() to signal internal-only intent (112 usages)
- Remove unused Collect() function and its test
- Fix discarded json.Unmarshal error in qwen3.go
- Document AsStrided stride formula in gemma3.go

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Snider 2026-02-19 21:36:40 +00:00
parent fb95cde30c
commit 443347a2f8
16 changed files with 89 additions and 108 deletions


@ -281,7 +281,7 @@ m, _ := inference.LoadModel(path)
for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(128)) { ... }
```
### New go-inference features available
- `inference.List()` — returns all registered backend names
- `inference.Backend.Available()` — hardware availability check
@ -321,7 +321,7 @@ Was a stub that passed logits through unchanged. Now fully implemented:
Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses MaxAxis to find the peak probability per position.
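The min_p rule lends itself to a quick plain-Go illustration (slices instead of MLX arrays; `minPMask` is a hypothetical helper, not the package's sampler):

```go
package main

import (
	"fmt"
	"math"
)

// minPMask zeroes out probabilities below minP * max(probs),
// mirroring the masking rule described above.
func minPMask(probs []float64, minP float64) []float64 {
	peak := math.Inf(-1)
	for _, p := range probs {
		if p > peak {
			peak = p // track the peak probability, as MaxAxis does per position
		}
	}
	threshold := minP * peak
	out := make([]float64, len(probs))
	for i, p := range probs {
		if p >= threshold {
			out[i] = p // keep tokens at or above min_p * max_prob
		}
	}
	return out
}

func main() {
	probs := []float64{0.5, 0.3, 0.15, 0.05}
	fmt.Println(minPMask(probs, 0.2)) // threshold is 0.2*0.5 = 0.1; only 0.05 is masked
}
```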
### New Bindings
| Function | Header | Purpose |
|----------|--------|---------|
@ -424,7 +424,7 @@ Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses
The old `checkError()` function logged errors via `slog.Error` and swallowed them. The mlx-c error handler (`mlx_go_error_handler`) stored the error string in a C static variable, but no Go code read it back as an error value.
### New error model
1. **`lastError() error`** — reads and clears the C-level error string. Returns `fmt.Errorf("mlx: %s", msg)` or nil. All callers now get real MLX error messages instead of generic "failed" strings.
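The read-and-clear pattern can be sketched without CGO by letting a package variable stand in for the C static (names and message format here are illustrative):

```go
package main

import "fmt"

// cErrorBuf stands in for the C static that mlx_go_error_handler
// writes to; in the real code this lives on the C side.
var cErrorBuf string

// lastError reads and clears the stored error, returning it as a
// Go error value, or nil when nothing was recorded.
func lastError() error {
	if cErrorBuf == "" {
		return nil
	}
	msg := cErrorBuf
	cErrorBuf = "" // clear so the next call sees a clean slate
	return fmt.Errorf("mlx: %s", msg)
}

func main() {
	cErrorBuf = "[matmul] shapes (2,3) and (2,3) cannot be multiplied"
	fmt.Println(lastError()) // first read surfaces the real MLX message
	fmt.Println(lastError()) // second read: buffer was cleared, nil
}
```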
@ -453,4 +453,4 @@ The old `checkError()` function logged errors via `slog.Error` and swallowed the
### Test results
- 180 tests passing (176 existing + 4 new error handling tests)
- New tests: Eval success, Eval nil safety, lastError no-error, LoadAllSafetensors missing file

TODO.md

@ -87,7 +87,7 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe
- [x] **`-mmacosx-version-min=26.0` is wrong** — ✅ Changed to `13.3` (MLX's own minimum). No longer locks out macOS 14/15 users.
- [x] **`LoadOption` is ignored in `metalBackend.LoadModel()`** — ✅ Now calls `inference.ApplyLoadOpts()`. `ContextLen` passed through to `metal.LoadConfig` → stored on `Model` → replaces unbounded `KVCache` with `RotatingKVCache(contextLen)` in generate loop. `GPULayers=0` logs a warning (Metal always uses full GPU offload). New test: `TestNewCaches_ContextLen`.
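A bounded cache of this shape can be sketched as a ring buffer (illustrative only; `rotatingCache` is not the metal package's actual `RotatingKVCache`, and plain ints stand in for per-step K/V tensors):

```go
package main

import "fmt"

// rotatingCache keeps at most contextLen entries: once full,
// each new write overwrites the oldest slot.
type rotatingCache struct {
	buf  []int
	next int
	full bool
}

func newRotatingCache(contextLen int) *rotatingCache {
	return &rotatingCache{buf: make([]int, contextLen)}
}

func (c *rotatingCache) put(v int) {
	c.buf[c.next] = v
	c.next = (c.next + 1) % len(c.buf) // wrap around at contextLen
	if c.next == 0 {
		c.full = true
	}
}

func (c *rotatingCache) size() int {
	if c.full {
		return len(c.buf)
	}
	return c.next
}

func main() {
	c := newRotatingCache(4)
	for v := 1; v <= 6; v++ { // write 6 steps into a 4-slot cache
		c.put(v)
	}
	fmt.Println(c.size(), c.buf) // stays bounded at 4; steps 1 and 2 were overwritten
}
```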
### Important — Should Fix
@ -103,13 +103,13 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe
### Minor — Nice to Have
- [ ] **Rename `New()` → `newArray()`** — `array.go:32`: `New("INPUT")` is exported but only used internally as a pre-allocation pattern before C fills in `ctx`. Since this is `internal/metal`, exporting is harmless, but `newArray` better signals intent.
- [x] **Rename `New()` → `newArray()`** — ✅ Renamed via IDE refactoring (112 usages updated). Unexported, signals internal-only intent.
- [ ] **`Collect()` is unused** — `metal.go:125-133` defines `Collect()` to gather valid arrays for batch Materialize, but nothing in the codebase calls it. Dead code.
- [x] **`Collect()` is unused** — ✅ Removed function and its test. Dead code eliminated.
- [ ] **`qwen3.go` / `gemma3.go` — second `json.Unmarshal` error discarded** — Both model configs parse quantization with a second unmarshal whose error is silently dropped. Probably fine (same data already parsed successfully), but inconsistent with the first unmarshal which returns errors.
- [x] **`qwen3.go` — second `json.Unmarshal` error discarded** — ✅ Now checks and returns the error. gemma3.go already handled it correctly.
- [ ] **Document `AsStrided` stride formula** — `gemma3.go:354-359` uses stride manipulation for a `[B,L,H*D]` → `[B,H,L,D]` virtual transpose. The formula is correct but non-obvious. A one-line comment explaining the stride derivation prevents future confusion.
- [x] **Document `AsStrided` stride formula** — ✅ Added comment explaining the stride derivation for the `[B,L,H*D]` → `[B,H,L,D]` virtual transpose.
### Questions for You to Consider
@ -124,4 +124,4 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe
1. Virgil in core/go writes tasks here after research
2. This repo's session picks up tasks in phase order
3. Mark `[x]` when done, note commit hash
4. New discoveries → add tasks, flag in FINDINGS.md


@ -42,5 +42,5 @@ Tasks for the CLion Claude session. Written by GoLand Claude or Virgil.
1. GoLand Claude or Virgil writes tasks here
2. Pick up in order, mark `[x]` when done
3. New findings → `cpp/FINDINGS.md`
4. If Go changes needed → note in FINDINGS.md for GoLand Claude


@ -60,7 +60,7 @@ go-mlx/
│ ├── lora.go LoRA adapters
│ ├── optim.go AdamW
│ ├── generate.go NEW: autoregressive generation loop
│ └── backend.go Implements mlx.TextModel, exports New()
├── mlxlm/ Future: Python subprocess backend
│ └── backend.go Implements mlx.Backend via core/go/pkg/process
@ -174,7 +174,7 @@ func Get(name string) (Backend, bool)
func Default() (Backend, error)
// register_metal.go (//go:build darwin && arm64)
func init() { Register(metal.New()) }
func MetalAvailable() bool { return true }
```
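The registry flow above can be illustrated with a self-contained stand-in (types simplified; the real `inference` package's interfaces differ):

```go
package main

import (
	"errors"
	"fmt"
)

// Backend is a minimal stand-in for the inference backend interface.
type Backend interface{ Name() string }

type metalBackend struct{}

func (metalBackend) Name() string { return "metal" }

var registry = map[string]Backend{}

// Register adds a backend under its name; in the real code this is
// called from init() in a file gated by //go:build darwin && arm64.
func Register(b Backend) { registry[b.Name()] = b }

// Get looks up a backend by name.
func Get(name string) (Backend, bool) {
	b, ok := registry[name]
	return b, ok
}

// Default returns a registered backend, or an error when none exist.
func Default() (Backend, error) {
	if b, ok := registry["metal"]; ok {
		return b, nil
	}
	return nil, errors.New("no backend available")
}

func main() {
	Register(metalBackend{})
	b, err := Default()
	fmt.Println(b.Name(), err)
}
```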
@ -294,9 +294,9 @@ Thin subprocess wrapper using `core/go/pkg/process`. Implements `mlx.Backend`. R
## Testing
- All 148 existing tests move into `internal/metal/` and must pass
- New `generate_test.go` — autoregressive loop with Gemma3-1B from `/Volumes/Data/lem/safetensors/`
- New `backend_test.go` — end-to-end LoadModel + Generate
- New `memory_test.go` — 1000-token generation, assert GetPeakMemory() bounded
- Root `mlx_test.go` — integration via public mlx.LoadModel() API
## Communication


@ -312,7 +312,7 @@ Copy core CGO bridge from `mlx.go` to `internal/metal/metal.go`. Change:
Copy `array.go` to `internal/metal/array.go`. Change:
- `package mlx``package metal`
- All types (`Array`, `New`, `FromValue`, `FromValues`, `Zeros`, `Free`) stay exported within the internal package
**Step 5: Move stream.go → internal/metal/stream.go**
@ -456,7 +456,7 @@ Check for conflicts when flattening:
- `tokenizer.Load` → rename to `loadTokenizer`
- `tokenizer.Tokenizer` struct → stays `Tokenizer`
- `sample.Sampler` interface → stays `Sampler`
- `sample.New` → rename to `newSampler`
- `cache.Cache` interface → stays `Cache`
**Step 3: Delete old sub-package directories**


@ -26,10 +26,10 @@ type Array struct {
name string // debug label
}
// New creates a named Array and registers a GC finalizer.
// newArray creates a named Array and registers a GC finalizer.
// The inputs parameter is accepted for API compatibility but not stored —
// MLX-C tracks inter-array references via its own refcounting.
func New(name string, inputs ...*Array) *Array {
func newArray(name string, inputs ...*Array) *Array {
t := &Array{name: name}
runtime.SetFinalizer(t, finalizeArray)
return t
@ -50,7 +50,7 @@ type scalarTypes interface {
// FromValue creates a scalar Array from a Go value.
func FromValue[T scalarTypes](t T) *Array {
Init()
tt := New("")
tt := newArray("")
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))
@ -122,7 +122,7 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
panic(err)
}
tt := New("")
tt := newArray("")
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
return tt
}
@ -134,7 +134,7 @@ func Zeros(shape []int32, dtype DType) *Array {
for i, s := range shape {
cShape[i] = C.int(s)
}
tt := New("ZEROS")
tt := newArray("ZEROS")
C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
return tt
}
@ -146,7 +146,7 @@ func (t *Array) Set(other *Array) {
// Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
func (t *Array) Clone() *Array {
tt := New(t.name)
tt := newArray(t.name)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
@ -231,7 +231,7 @@ func (t Array) IsRowContiguous() bool {
// Contiguous returns a row-major contiguous copy of the array.
// If the array is already row-contiguous, this is a no-op.
func Contiguous(a *Array) *Array {
out := New("CONTIGUOUS", a)
out := newArray("CONTIGUOUS", a)
C.mlx_contiguous(&out.ctx, a.ctx, C._Bool(false), DefaultStream().ctx)
return out
}


@ -452,15 +452,3 @@ func TestArray_Float_DTypeFloat64(t *testing.T) {
}
}
// --- Collect ---
func TestCollect_FiltersNilAndInvalid(t *testing.T) {
a := FromValue(float32(1.0))
b := FromValue(float32(2.0))
Materialize(a, b)
result := Collect(a, nil, b)
if len(result) != 2 {
t.Errorf("Collect returned %d arrays, want 2", len(result))
}
}


@ -12,21 +12,21 @@ import "unsafe"
// RMSNorm applies Root Mean Square normalization using a fused Metal kernel.
func RMSNorm(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
out := newArray("FAST_RMSNORM", x)
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
// LayerNorm applies Layer normalization using a fused Metal kernel.
func LayerNorm(x, weight, bias *Array, eps float32) *Array {
out := New("FAST_LAYERNORM", x)
out := newArray("FAST_LAYERNORM", x)
C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
// RoPE applies Rotary Position Embeddings using a fused Metal kernel.
func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
out := New("FAST_ROPE", x)
out := newArray("FAST_ROPE", x)
freqs := C.mlx_array_new()
defer C.mlx_array_free(freqs)
C.mlx_fast_rope(
@ -60,7 +60,7 @@ func ScaledDotProductAttention(query, key, value *Array, scale float32, causal b
sinksArr := C.mlx_array_new()
defer C.mlx_array_free(sinksArr)
out := New("FAST_SDPA", query, key, value)
out := newArray("FAST_SDPA", query, key, value)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx)
return out
}
@ -73,7 +73,7 @@ func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale flo
sinksArr := C.mlx_array_new()
defer C.mlx_array_free(sinksArr)
out := New("FAST_SDPA", query, key, value, mask)
out := newArray("FAST_SDPA", query, key, value, mask)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx)
return out
}


@ -350,7 +350,9 @@ func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg *
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
// Virtual transpose [B,L,H*D] → [B,H,L,D] via stride manipulation.
// Strides: batch = L*H*D (full sequence), head = D (adjacent heads in memory),
// seq = H*D (jump one full row of heads), elem = 1 (contiguous within head).
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},


@ -40,7 +40,7 @@ func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload
nInputs := int(C.mlx_vector_array_size(inputs))
goInputs := make([]*Array, nInputs)
for i := 0; i < nInputs; i++ {
a := New("GRAD_INPUT")
a := newArray("GRAD_INPUT")
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
goInputs[i] = a
}
@ -325,7 +325,7 @@ func vectorToArrays(vec C.mlx_vector_array) []*Array {
n := int(C.mlx_vector_array_size(vec))
out := make([]*Array, n)
for i := 0; i < n; i++ {
a := New("VEC_OUT")
a := newArray("VEC_OUT")
C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
out[i] = a
}
@ -334,7 +334,7 @@ func vectorToArrays(vec C.mlx_vector_array) []*Array {
// Log returns element-wise natural logarithm.
func Log(a *Array) *Array {
out := New("LOG", a)
out := newArray("LOG", a)
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
@ -353,7 +353,7 @@ func MeanAll(a *Array) *Array {
// OnesLike creates an array of ones with the same shape and type as the input.
func OnesLike(a *Array) *Array {
out := New("ONES_LIKE", a)
out := newArray("ONES_LIKE", a)
C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}


@ -109,8 +109,8 @@ func (l *LoRALinear) ParamCount() int {
// LoRAConfig specifies which layers to apply LoRA to and with what parameters.
type LoRAConfig struct {
Rank int // Decomposition rank (default 8)
Alpha float32 // Scaling factor (default 16)
Rank int // Decomposition rank (default 8)
Alpha float32 // Scaling factor (default 16)
TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj)
}
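The Rank/Alpha fields combine in the usual LoRA update, `y = W·x + (Alpha/Rank)·B·(A·x)` with `A` of shape rank×in and `B` of shape out×rank. A plain-slice sketch (not the package's actual `LoRALinear`):

```go
package main

import "fmt"

// matVec computes a dense matrix-vector product.
func matVec(m [][]float32, x []float32) []float32 {
	out := make([]float32, len(m))
	for i, row := range m {
		for j, v := range row {
			out[i] += v * x[j]
		}
	}
	return out
}

// loraApply sketches the LoRA forward pass for one input vector:
// the frozen weight's output plus a scaled rank-r correction.
func loraApply(w, a, b [][]float32, x []float32, alpha float32, rank int) []float32 {
	base := matVec(w, x)
	low := matVec(b, matVec(a, x)) // rank-r bottleneck: A projects down, B back up
	scale := alpha / float32(rank)
	for i := range base {
		base[i] += scale * low[i]
	}
	return base
}

func main() {
	w := [][]float32{{1, 0}, {0, 1}} // frozen weight (identity, for illustration)
	a := [][]float32{{1, 1}}         // rank 1: 1x2
	b := [][]float32{{1}, {0}}       // 2x1
	fmt.Println(loraApply(w, a, b, []float32{2, 3}, 16, 8))
}
```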
@ -187,7 +187,7 @@ func (a *LoRAAdapter) Save(path string) error {
// RandomNormal generates normal (Gaussian) random values with given mean and stddev.
func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array {
Init()
out := New("RANDOM_NORMAL")
out := newArray("RANDOM_NORMAL")
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)


@ -120,17 +120,6 @@ func MaterializeAsync(outputs ...*Array) {
}
}
// Collect gathers all valid arrays from a variadic list for batch Materialize.
func Collect(arrays ...*Array) []*Array {
var out []*Array
for _, a := range arrays {
if a != nil && a.Valid() {
out = append(out, a)
}
}
return out
}
// MetalAvailable reports whether Metal GPU is available.
func MetalAvailable() bool {
Init()


@ -14,7 +14,7 @@ import "unsafe"
// Add returns element-wise a + b.
func Add(a, b *Array) *Array {
out := New("ADD", a, b)
out := newArray("ADD", a, b)
C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
@ -27,7 +27,7 @@ func AddScalar(a *Array, s float32) *Array {
// Mul returns element-wise a * b.
func Mul(a, b *Array) *Array {
out := New("MUL", a, b)
out := newArray("MUL", a, b)
C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
@ -40,21 +40,21 @@ func MulScalar(a *Array, s float32) *Array {
// Divide returns element-wise a / b.
func Divide(a, b *Array) *Array {
out := New("DIV", a, b)
out := newArray("DIV", a, b)
C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Subtract returns element-wise a - b.
func Subtract(a, b *Array) *Array {
out := New("SUB", a, b)
out := newArray("SUB", a, b)
C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Negative returns element-wise -a.
func Negative(a *Array) *Array {
out := New("NEG", a)
out := newArray("NEG", a)
C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
@ -63,14 +63,14 @@ func Negative(a *Array) *Array {
// Exp returns element-wise exp(a).
func Exp(a *Array) *Array {
out := New("EXP", a)
out := newArray("EXP", a)
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Sigmoid returns element-wise 1/(1+exp(-a)).
func Sigmoid(a *Array) *Array {
out := New("SIGMOID", a)
out := newArray("SIGMOID", a)
C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
@ -82,56 +82,56 @@ func SiLU(a *Array) *Array {
// Tanh returns element-wise tanh(a).
func Tanh(a *Array) *Array {
out := New("TANH", a)
out := newArray("TANH", a)
C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Sqrt returns element-wise sqrt(a).
func Sqrt(a *Array) *Array {
out := New("SQRT", a)
out := newArray("SQRT", a)
C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Rsqrt returns element-wise 1/sqrt(a).
func Rsqrt(a *Array) *Array {
out := New("RSQRT", a)
out := newArray("RSQRT", a)
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Reciprocal returns element-wise 1/a.
func Reciprocal(a *Array) *Array {
out := New("RECIPROCAL", a)
out := newArray("RECIPROCAL", a)
C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Square returns element-wise a^2.
func Square(a *Array) *Array {
out := New("SQUARE", a)
out := newArray("SQUARE", a)
C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Power returns element-wise a^b.
func Power(a, b *Array) *Array {
out := New("POWER", a, b)
out := newArray("POWER", a, b)
C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Maximum returns element-wise max(a, b).
func Maximum(a, b *Array) *Array {
out := New("MAX", a, b)
out := newArray("MAX", a, b)
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Minimum returns element-wise min(a, b).
func Minimum(a, b *Array) *Array {
out := New("MIN", a, b)
out := newArray("MIN", a, b)
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
@ -140,14 +140,14 @@ func Minimum(a, b *Array) *Array {
// Matmul returns the matrix product of a and b.
func Matmul(a, b *Array) *Array {
out := New("MATMUL", a, b)
out := newArray("MATMUL", a, b)
C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// QuantizedMatmul performs quantized matrix multiplication.
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
out := New("QMATMUL", x, w, scales, biases)
out := newArray("QMATMUL", x, w, scales, biases)
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
mode := C.CString("affine")
@ -164,7 +164,7 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
// Softmax returns softmax along the last axis.
func Softmax(a *Array) *Array {
out := New("SOFTMAX", a)
out := newArray("SOFTMAX", a)
axis := []C.int{C.int(-1)}
C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
return out
@ -172,21 +172,21 @@ func Softmax(a *Array) *Array {
// Argmax returns the index of the maximum value along an axis.
func Argmax(a *Array, axis int, keepDims bool) *Array {
out := New("ARGMAX", a)
out := newArray("ARGMAX", a)
C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
return out
}
// TopK returns the top k values along the last axis.
func TopK(a *Array, k int) *Array {
out := New("TOPK", a)
out := newArray("TOPK", a)
C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
return out
}
// Sum reduces by summation along the given axis.
func Sum(a *Array, axis int, keepDims bool) *Array {
out := New("SUM", a)
out := newArray("SUM", a)
axes := []C.int{C.int(axis)}
C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
return out
@ -194,7 +194,7 @@ func Sum(a *Array, axis int, keepDims bool) *Array {
// Mean reduces by averaging along the given axis.
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN", a)
out := newArray("MEAN", a)
axes := []C.int{C.int(axis)}
C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
return out
@ -204,7 +204,7 @@ func Mean(a *Array, axis int, keepDims bool) *Array {
// Reshape changes the shape of an array.
func Reshape(a *Array, shape ...int32) *Array {
out := New("RESHAPE", a)
out := newArray("RESHAPE", a)
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
@ -215,7 +215,7 @@ func Reshape(a *Array, shape ...int32) *Array {
// Transpose permutes dimensions. If no axes given, reverses all dims.
func Transpose(a *Array, axes ...int) *Array {
out := New("TRANSPOSE", a)
out := newArray("TRANSPOSE", a)
if len(axes) == 0 {
C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
} else {
@ -230,14 +230,14 @@ func Transpose(a *Array, axes ...int) *Array {
// ExpandDims inserts a new axis at the given position.
func ExpandDims(a *Array, axis int) *Array {
out := New("EXPAND_DIMS", a)
out := newArray("EXPAND_DIMS", a)
C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// Squeeze removes dimensions of size 1.
func Squeeze(a *Array, axes ...int) *Array {
out := New("SQUEEZE", a)
out := newArray("SQUEEZE", a)
cAxes := make([]C.int, len(axes))
for i, ax := range axes {
cAxes[i] = C.int(ax)
@ -257,14 +257,14 @@ func Concatenate(arrays []*Array, axis int) *Array {
inputs[i] = a
}
out := New("CONCAT", inputs...)
out := newArray("CONCAT", inputs...)
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
// BroadcastTo broadcasts an array to the given shape.
func BroadcastTo(a *Array, shape []int32) *Array {
out := New("BROADCAST", a)
out := newArray("BROADCAST", a)
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
@ -275,14 +275,14 @@ func BroadcastTo(a *Array, shape []int32) *Array {
// AsType casts an array to a different dtype.
func AsType(a *Array, dtype DType) *Array {
out := New("ASTYPE", a)
out := newArray("ASTYPE", a)
C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
// AsStrided creates a view with custom strides.
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
out := New("AS_STRIDED", a)
out := newArray("AS_STRIDED", a)
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
@ -297,28 +297,28 @@ func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
// Take gathers elements from a along axis using indices.
func Take(a, indices *Array, axis int) *Array {
out := New("TAKE", a, indices)
out := newArray("TAKE", a, indices)
C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// Where selects elements from a or b based on condition.
func Where(condition, a, b *Array) *Array {
out := New("WHERE", condition, a, b)
out := newArray("WHERE", condition, a, b)
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Argpartition partially sorts and returns indices for top-k selection.
func Argpartition(a *Array, kth, axis int) *Array {
out := New("ARGPARTITION", a)
out := newArray("ARGPARTITION", a)
C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
// Dequantize restores a quantized array to full precision.
func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
out := New("DEQUANTIZE", w, scales, biases)
out := newArray("DEQUANTIZE", w, scales, biases)
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
mode := C.CString("affine")
@ -330,7 +330,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
// PutAlongAxis places values into array at indices along axis.
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", a, indices, values)
out := newArray("PUT_ALONG_AXIS", a, indices, values)
// Use scatter approach: src[indices] = values
C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
@ -339,7 +339,7 @@ func PutAlongAxis(a, indices, values *Array, axis int) *Array {
// TakeAlongAxis gathers elements from a along axis using indices.
// Unlike Take, this uses the same number of dimensions for indices and input.
func TakeAlongAxis(a, indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS", a, indices)
out := newArray("TAKE_ALONG_AXIS", a, indices)
C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
@ -347,7 +347,7 @@ func TakeAlongAxis(a, indices *Array, axis int) *Array {
// LogSumExp computes log(sum(exp(a))) along the given axis.
// Numerically stable reduction for cross-entropy loss.
func LogSumExp(a *Array, axis int, keepDims bool) *Array {
out := New("LOGSUMEXP", a)
out := newArray("LOGSUMEXP", a)
C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
return out
}
@ -355,35 +355,35 @@ func LogSumExp(a *Array, axis int, keepDims bool) *Array {
// CumSum returns the cumulative sum along the given axis.
// reverse=false for forward, inclusive=true to include the current element.
func CumSum(a *Array, axis int, reverse, inclusive bool) *Array {
out := New("CUMSUM", a)
out := newArray("CUMSUM", a)
C.mlx_cumsum(&out.ctx, a.ctx, C.int(axis), C._Bool(reverse), C._Bool(inclusive), DefaultStream().ctx)
return out
}
// Sort returns the array sorted along the given axis.
func Sort(a *Array, axis int) *Array {
out := New("SORT", a)
out := newArray("SORT", a)
C.mlx_sort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// Argsort returns the indices that would sort the array along the given axis.
func Argsort(a *Array, axis int) *Array {
out := New("ARGSORT", a)
out := newArray("ARGSORT", a)
C.mlx_argsort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// Greater returns element-wise a > b as a bool array.
func Greater(a, b *Array) *Array {
out := New("GREATER", a, b)
out := newArray("GREATER", a, b)
C.mlx_greater(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// MaxAxis returns the maximum value along the given axis.
func MaxAxis(a *Array, axis int, keepDims bool) *Array {
out := New("MAX_AXIS", a)
out := newArray("MAX_AXIS", a)
C.mlx_max_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
return out
}


@ -75,7 +75,9 @@ func parseQwen3Config(data []byte) (*Qwen3Config, error) {
var wrapper struct {
Quantization *QuantizationConfig `json:"quantization"`
}
json.Unmarshal(data, &wrapper)
if err := json.Unmarshal(data, &wrapper); err != nil {
return nil, fmt.Errorf("qwen3: parse quantization: %w", err)
}
cfg.Quantization = wrapper.Quantization
// Compute scale
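The corrected two-pass parse can be sketched standalone (field names illustrative, not the exact qwen3 config schema):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// quantizationConfig holds the optional quantization block.
type quantizationConfig struct {
	GroupSize int `json:"group_size"`
	Bits      int `json:"bits"`
}

// parseQuantization is the second pass: a small wrapper struct pulls
// out just the quantization block, and the error is now returned
// instead of silently dropped.
func parseQuantization(data []byte) (*quantizationConfig, error) {
	var wrapper struct {
		Quantization *quantizationConfig `json:"quantization"`
	}
	if err := json.Unmarshal(data, &wrapper); err != nil {
		return nil, fmt.Errorf("parse quantization: %w", err)
	}
	return wrapper.Quantization, nil
}

func main() {
	q, err := parseQuantization([]byte(`{"quantization":{"group_size":64,"bits":4}}`))
	fmt.Println(q.GroupSize, q.Bits, err)
	_, err = parseQuantization([]byte(`not json`)) // malformed input now surfaces an error
	fmt.Println(err != nil)
}
```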


@ -10,7 +10,7 @@ import "C"
// RandomCategorical samples from a categorical distribution defined by logprobs.
// Returns indices sampled according to the log-probability distribution along the last axis.
func RandomCategorical(logprobs *Array) *Array {
out := New("RANDOM_CATEGORICAL", logprobs)
out := newArray("RANDOM_CATEGORICAL", logprobs)
key := C.mlx_array_new()
defer C.mlx_array_free(key)
C.mlx_random_categorical(
@ -25,7 +25,7 @@ func RandomCategorical(logprobs *Array) *Array {
// RandomUniform generates uniform random values in [low, high).
func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
out := New("RANDOM_UNIFORM")
out := newArray("RANDOM_UNIFORM")
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)


@ -10,7 +10,7 @@ import "C"
// Slice extracts a sub-array using start and end indices for each dimension.
// starts and ends must have the same length as the array's dimensions.
func Slice(a *Array, starts, ends []int32) *Array {
out := New("SLICE", a)
out := newArray("SLICE", a)
cStarts := make([]C.int, len(starts))
cEnds := make([]C.int, len(ends))
for i := range starts {
@ -47,7 +47,7 @@ func SliceAxis(a *Array, axis int, start, end int32) *Array {
// SliceUpdateInplace updates a slice of the array in-place.
// This is critical for KV cache updates.
func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
out := New("SLICE_UPDATE", a, update)
out := newArray("SLICE_UPDATE", a, update)
cStarts := make([]C.int, len(starts))
cEnds := make([]C.int, len(ends))
for i := range starts {