diff --git a/FINDINGS.md b/FINDINGS.md index f83ad9e..25d9979 100644 --- a/FINDINGS.md +++ b/FINDINGS.md @@ -281,7 +281,7 @@ m, _ := inference.LoadModel(path) for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(128)) { ... } ``` -### New go-inference features available +### newArray go-inference features available - `inference.List()` — returns all registered backend names - `inference.Backend.Available()` — hardware availability check @@ -321,7 +321,7 @@ Was a stub that passed logits through unchanged. Now fully implemented: Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses MaxAxis to find the peak probability per position. -### New Bindings +### newArray Bindings | Function | Header | Purpose | |----------|--------|---------| @@ -424,7 +424,7 @@ Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses The old `checkError()` function logged errors via `slog.Error` and swallowed them. The mlx-c error handler (`mlx_go_error_handler`) stored the error string in a C static variable, but no Go code read it back as an error value. -### New error model +### newArray error model 1. **`lastError() error`** — reads and clears the C-level error string. Returns `fmt.Errorf("mlx: %s", msg)` or nil. All callers now get real MLX error messages instead of generic "failed" strings. @@ -453,4 +453,4 @@ The old `checkError()` function logged errors via `slog.Error` and swallowed the ### Test results - 180 tests passing (176 existing + 4 new error handling tests) -- New tests: Eval success, Eval nil safety, lastError no-error, LoadAllSafetensors missing file +- newArray tests: Eval success, Eval nil safety, lastError no-error, LoadAllSafetensors missing file diff --git a/TODO.md b/TODO.md index ff5497c..575c1ed 100644 --- a/TODO.md +++ b/TODO.md @@ -87,7 +87,7 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe - [x] **`-mmacosx-version-min=26.0` is wrong** — ✅ Changed to `13.3` (MLX's own minimum). No longer locks out macOS 14/15 users. -- [x] **`LoadOption` is ignored in `metalBackend.LoadModel()`** — ✅ Now calls `inference.ApplyLoadOpts()`. `ContextLen` passed through to `metal.LoadConfig` → stored on `Model` → replaces unbounded `KVCache` with `RotatingKVCache(contextLen)` in generate loop. `GPULayers=0` logs a warning (Metal always uses full GPU offload). New test: `TestNewCaches_ContextLen`. +- [x] **`LoadOption` is ignored in `metalBackend.LoadModel()`** — ✅ Now calls `inference.ApplyLoadOpts()`. `ContextLen` passed through to `metal.LoadConfig` → stored on `Model` → replaces unbounded `KVCache` with `RotatingKVCache(contextLen)` in generate loop. `GPULayers=0` logs a warning (Metal always uses full GPU offload). newArray test: `TestNewCaches_ContextLen`. ### Important — Should Fix @@ -103,13 +103,13 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe ### Minor — Nice to Have -- [ ] **Rename `New()` → `newArray()`** — `array.go:32`: `New("INPUT")` is exported but only used internally as a pre-allocation pattern before C fills in `ctx`. Since this is `internal/metal`, exporting is harmless, but `newArray` better signals intent. +- [x] **Rename `New()` → `newArray()`** — ✅ Renamed via IDE refactoring (112 usages updated). Unexported, signals internal-only intent. -- [ ] **`Collect()` is unused** — `metal.go:125-133` defines `Collect()` to gather valid arrays for batch Materialize, but nothing in the codebase calls it. Dead code. +- [x] **`Collect()` is unused** — ✅ Removed function and its test. Dead code eliminated. -- [ ] **`qwen3.go` / `gemma3.go` — second `json.Unmarshal` error discarded** — Both model configs parse quantization with a second unmarshal whose error is silently dropped. Probably fine (same data already parsed successfully), but inconsistent with the first unmarshal which returns errors. +- [x] **`qwen3.go` — second `json.Unmarshal` error discarded** — ✅ Now checks and returns the error. gemma3.go already handled it correctly. -- [ ] **Document `AsStrided` stride formula** — `gemma3.go:354-359` uses stride manipulation for `[B,L,H*D]` → `[B,H,L,D]` virtual transpose. The formula is correct but non-obvious. A one-line comment explaining the stride derivation prevents future confusion. +- [x] **Document `AsStrided` stride formula** — ✅ Added comment explaining the stride derivation for the `[B,L,H*D]` → `[B,H,L,D]` virtual transpose. ### Questions for You to Consider @@ -124,4 +124,4 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe 1. Virgil in core/go writes tasks here after research 2. This repo's session picks up tasks in phase order 3. Mark `[x]` when done, note commit hash -4. New discoveries → add tasks, flag in FINDINGS.md +4. newArray discoveries → add tasks, flag in FINDINGS.md diff --git a/cpp/TODO.md b/cpp/TODO.md index da73c5e..e9b9924 100644 --- a/cpp/TODO.md +++ b/cpp/TODO.md @@ -42,5 +42,5 @@ Tasks for the CLion Claude session. Written by GoLand Claude or Virgil. 1. GoLand Claude or Virgil writes tasks here 2. Pick up in order, mark `[x]` when done -3. New findings → `cpp/FINDINGS.md` +3. newArray findings → `cpp/FINDINGS.md` 4. If Go changes needed → note in FINDINGS.md for GoLand Claude diff --git a/docs/plans/2026-02-19-backend-abstraction-design.md b/docs/plans/2026-02-19-backend-abstraction-design.md index ddf8f6f..f976372 100644 --- a/docs/plans/2026-02-19-backend-abstraction-design.md +++ b/docs/plans/2026-02-19-backend-abstraction-design.md @@ -60,7 +60,7 @@ go-mlx/ │ ├── lora.go LoRA adapters │ ├── optim.go AdamW │ ├── generate.go NEW: autoregressive generation loop -│ └── backend.go Implements mlx.TextModel, exports New() +│ └── backend.go Implements mlx.TextModel, exports newArray() │ ├── mlxlm/ Future: Python subprocess backend │ └── backend.go Implements mlx.Backend via core/go/pkg/process @@ -174,7 +174,7 @@ func Get(name string) (Backend, bool) func Default() (Backend, error) // register_metal.go (//go:build darwin && arm64) -func init() { Register(metal.New()) } +func init() { Register(metal.newArray()) } func MetalAvailable() bool { return true } ``` @@ -294,9 +294,9 @@ Thin subprocess wrapper using `core/go/pkg/process`. Implements `mlx.Backend`. R ## Testing - All 148 existing tests move into `internal/metal/` and must pass -- New `generate_test.go` — autoregressive loop with Gemma3-1B from `/Volumes/Data/lem/safetensors/` -- New `backend_test.go` — end-to-end LoadModel + Generate -- New `memory_test.go` — 1000-token generation, assert GetPeakMemory() bounded +- newArray `generate_test.go` — autoregressive loop with Gemma3-1B from `/Volumes/Data/lem/safetensors/` +- newArray `backend_test.go` — end-to-end LoadModel + Generate +- newArray `memory_test.go` — 1000-token generation, assert GetPeakMemory() bounded - Root `mlx_test.go` — integration via public mlx.LoadModel() API ## Communication diff --git a/docs/plans/2026-02-19-backend-abstraction-plan.md b/docs/plans/2026-02-19-backend-abstraction-plan.md index a1a8343..281077c 100644 --- a/docs/plans/2026-02-19-backend-abstraction-plan.md +++ b/docs/plans/2026-02-19-backend-abstraction-plan.md @@ -312,7 +312,7 @@ Copy core CGO bridge from `mlx.go` to `internal/metal/metal.go`. Change: Copy `array.go` to `internal/metal/array.go`. Change: - `package mlx` → `package metal` -- All types (`Array`, `New`, `FromValue`, `FromValues`, `Zeros`, `Free`) stay exported within the internal package +- All types (`Array`, `newArray`, `FromValue`, `FromValues`, `Zeros`, `Free`) stay exported within the internal package **Step 5: Move stream.go → internal/metal/stream.go** @@ -456,7 +456,7 @@ Check for conflicts when flattening: - `tokenizer.Load` → rename to `loadTokenizer` - `tokenizer.Tokenizer` struct → stays `Tokenizer` - `sample.Sampler` interface → stays `Sampler` -- `sample.New` → rename to `newSampler` +- `sample.newArray` → rename to `newSampler` - `cache.Cache` interface → stays `Cache` **Step 3: Delete old sub-package directories** diff --git a/internal/metal/array.go b/internal/metal/array.go index f178543..b357fc3 100644 --- a/internal/metal/array.go +++ b/internal/metal/array.go @@ -26,10 +26,10 @@ type Array struct { name string // debug label } -// New creates a named Array and registers a GC finalizer. +// newArray creates a named Array and registers a GC finalizer. // The inputs parameter is accepted for API compatibility but not stored — // MLX-C tracks inter-array references via its own refcounting. -func New(name string, inputs ...*Array) *Array { +func newArray(name string, inputs ...*Array) *Array { t := &Array{name: name} runtime.SetFinalizer(t, finalizeArray) return t @@ -50,7 +50,7 @@ type scalarTypes interface { // FromValue creates a scalar Array from a Go value. func FromValue[T scalarTypes](t T) *Array { Init() - tt := New("") + tt := newArray("") switch v := any(t).(type) { case bool: tt.ctx = C.mlx_array_new_bool(C.bool(v)) @@ -122,7 +122,7 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array { panic(err) } - tt := New("") + tt := newArray("") tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype)) return tt } @@ -134,7 +134,7 @@ func Zeros(shape []int32, dtype DType) *Array { for i, s := range shape { cShape[i] = C.int(s) } - tt := New("ZEROS") + tt := newArray("ZEROS") C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx) return tt } @@ -146,7 +146,7 @@ func (t *Array) Set(other *Array) { // Clone creates a new Go wrapper sharing the same C handle (increments C refcount). func (t *Array) Clone() *Array { - tt := New(t.name) + tt := newArray(t.name) C.mlx_array_set(&tt.ctx, t.ctx) return tt } @@ -231,7 +231,7 @@ func (t Array) IsRowContiguous() bool { // Contiguous returns a row-major contiguous copy of the array. // If the array is already row-contiguous, this is a no-op. func Contiguous(a *Array) *Array { - out := New("CONTIGUOUS", a) + out := newArray("CONTIGUOUS", a) C.mlx_contiguous(&out.ctx, a.ctx, C._Bool(false), DefaultStream().ctx) return out } diff --git a/internal/metal/array_test.go b/internal/metal/array_test.go index 4871e56..15ee6f6 100644 --- a/internal/metal/array_test.go +++ b/internal/metal/array_test.go @@ -452,15 +452,3 @@ func TestArray_Float_DTypeFloat64(t *testing.T) { } } -// --- Collect --- - -func TestCollect_FiltersNilAndInvalid(t *testing.T) { - a := FromValue(float32(1.0)) - b := FromValue(float32(2.0)) - Materialize(a, b) - - result := Collect(a, nil, b) - if len(result) != 2 { - t.Errorf("Collect returned %d arrays, want 2", len(result)) - } -} diff --git a/internal/metal/fast.go b/internal/metal/fast.go index c2140fd..643d79d 100644 --- a/internal/metal/fast.go +++ b/internal/metal/fast.go @@ -12,21 +12,21 @@ import "unsafe" // RMSNorm applies Root Mean Square normalization using a fused Metal kernel. func RMSNorm(x, weight *Array, eps float32) *Array { - out := New("FAST_RMSNORM", x) + out := newArray("FAST_RMSNORM", x) C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx) return out } // LayerNorm applies Layer normalization using a fused Metal kernel. func LayerNorm(x, weight, bias *Array, eps float32) *Array { - out := New("FAST_LAYERNORM", x) + out := newArray("FAST_LAYERNORM", x) C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx) return out } // RoPE applies Rotary Position Embeddings using a fused Metal kernel. func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array { - out := New("FAST_ROPE", x) + out := newArray("FAST_ROPE", x) freqs := C.mlx_array_new() defer C.mlx_array_free(freqs) C.mlx_fast_rope( @@ -60,7 +60,7 @@ func ScaledDotProductAttention(query, key, value *Array, scale float32, causal b sinksArr := C.mlx_array_new() defer C.mlx_array_free(sinksArr) - out := New("FAST_SDPA", query, key, value) + out := newArray("FAST_SDPA", query, key, value) C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx) return out } @@ -73,7 +73,7 @@ func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale flo sinksArr := C.mlx_array_new() defer C.mlx_array_free(sinksArr) - out := New("FAST_SDPA", query, key, value, mask) + out := newArray("FAST_SDPA", query, key, value, mask) C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx) return out } diff --git a/internal/metal/gemma3.go b/internal/metal/gemma3.go index 9da21fc..e03102d 100644 --- a/internal/metal/gemma3.go +++ b/internal/metal/gemma3.go @@ -350,7 +350,9 @@ func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg * k := a.KProj.Forward(x) v := a.VProj.Forward(x) - // Reshape to [B, num_heads, L, head_dim] + // Virtual transpose [B,L,H*D] → [B,H,L,D] via stride manipulation. + // Strides: batch = L*H*D (full sequence), head = D (adjacent heads in memory), + // seq = H*D (jump one full row of heads), elem = 1 (contiguous within head). q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, diff --git a/internal/metal/grad.go b/internal/metal/grad.go index 0199cce..485a432 100644 --- a/internal/metal/grad.go +++ b/internal/metal/grad.go @@ -40,7 +40,7 @@ func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload nInputs := int(C.mlx_vector_array_size(inputs)) goInputs := make([]*Array, nInputs) for i := 0; i < nInputs; i++ { - a := New("GRAD_INPUT") + a := newArray("GRAD_INPUT") C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i)) goInputs[i] = a } @@ -325,7 +325,7 @@ func vectorToArrays(vec C.mlx_vector_array) []*Array { n := int(C.mlx_vector_array_size(vec)) out := make([]*Array, n) for i := 0; i < n; i++ { - a := New("VEC_OUT") + a := newArray("VEC_OUT") C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i)) out[i] = a } @@ -334,7 +334,7 @@ func vectorToArrays(vec C.mlx_vector_array) []*Array { // Log returns element-wise natural logarithm. func Log(a *Array) *Array { - out := New("LOG", a) + out := newArray("LOG", a) C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx) return out } @@ -353,7 +353,7 @@ func MeanAll(a *Array) *Array { // OnesLike creates an array of ones with the same shape and type as the input. func OnesLike(a *Array) *Array { - out := New("ONES_LIKE", a) + out := newArray("ONES_LIKE", a) C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx) return out } diff --git a/internal/metal/lora.go b/internal/metal/lora.go index 9f649ee..73d4d33 100644 --- a/internal/metal/lora.go +++ b/internal/metal/lora.go @@ -109,8 +109,8 @@ func (l *LoRALinear) ParamCount() int { // LoRAConfig specifies which layers to apply LoRA to and with what parameters. type LoRAConfig struct { - Rank int // Decomposition rank (default 8) - Alpha float32 // Scaling factor (default 16) + Rank int // Decomposition rank (default 8) + Alpha float32 // Scaling factor (default 16) TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj) } @@ -187,7 +187,7 @@ func (a *LoRAAdapter) Save(path string) error { // RandomNormal generates normal (Gaussian) random values with given mean and stddev. func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array { Init() - out := New("RANDOM_NORMAL") + out := newArray("RANDOM_NORMAL") cShape := make([]C.int, len(shape)) for i, s := range shape { cShape[i] = C.int(s) diff --git a/internal/metal/metal.go b/internal/metal/metal.go index 4361814..b24980c 100644 --- a/internal/metal/metal.go +++ b/internal/metal/metal.go @@ -120,17 +120,6 @@ func MaterializeAsync(outputs ...*Array) { } } -// Collect gathers all valid arrays from a variadic list for batch Materialize. -func Collect(arrays ...*Array) []*Array { - var out []*Array - for _, a := range arrays { - if a != nil && a.Valid() { - out = append(out, a) - } - } - return out -} - // MetalAvailable reports whether Metal GPU is available. func MetalAvailable() bool { Init() diff --git a/internal/metal/ops.go b/internal/metal/ops.go index 2bf554c..70470b5 100644 --- a/internal/metal/ops.go +++ b/internal/metal/ops.go @@ -14,7 +14,7 @@ import "unsafe" // Add returns element-wise a + b. func Add(a, b *Array) *Array { - out := New("ADD", a, b) + out := newArray("ADD", a, b) C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } @@ -27,7 +27,7 @@ func AddScalar(a *Array, s float32) *Array { // Mul returns element-wise a * b. func Mul(a, b *Array) *Array { - out := New("MUL", a, b) + out := newArray("MUL", a, b) C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } @@ -40,21 +40,21 @@ func MulScalar(a *Array, s float32) *Array { // Divide returns element-wise a / b. func Divide(a, b *Array) *Array { - out := New("DIV", a, b) + out := newArray("DIV", a, b) C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } // Subtract returns element-wise a - b. func Subtract(a, b *Array) *Array { - out := New("SUB", a, b) + out := newArray("SUB", a, b) C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } // Negative returns element-wise -a. func Negative(a *Array) *Array { - out := New("NEG", a) + out := newArray("NEG", a) C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx) return out } @@ -63,14 +63,14 @@ func Negative(a *Array) *Array { // Exp returns element-wise exp(a). func Exp(a *Array) *Array { - out := New("EXP", a) + out := newArray("EXP", a) C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx) return out } // Sigmoid returns element-wise 1/(1+exp(-a)). func Sigmoid(a *Array) *Array { - out := New("SIGMOID", a) + out := newArray("SIGMOID", a) C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx) return out } @@ -82,56 +82,56 @@ func SiLU(a *Array) *Array { // Tanh returns element-wise tanh(a). func Tanh(a *Array) *Array { - out := New("TANH", a) + out := newArray("TANH", a) C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx) return out } // Sqrt returns element-wise sqrt(a). func Sqrt(a *Array) *Array { - out := New("SQRT", a) + out := newArray("SQRT", a) C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx) return out } // Rsqrt returns element-wise 1/sqrt(a). func Rsqrt(a *Array) *Array { - out := New("RSQRT", a) + out := newArray("RSQRT", a) C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx) return out } // Reciprocal returns element-wise 1/a. func Reciprocal(a *Array) *Array { - out := New("RECIPROCAL", a) + out := newArray("RECIPROCAL", a) C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx) return out } // Square returns element-wise a^2. func Square(a *Array) *Array { - out := New("SQUARE", a) + out := newArray("SQUARE", a) C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx) return out } // Power returns element-wise a^b. func Power(a, b *Array) *Array { - out := New("POWER", a, b) + out := newArray("POWER", a, b) C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } // Maximum returns element-wise max(a, b). func Maximum(a, b *Array) *Array { - out := New("MAX", a, b) + out := newArray("MAX", a, b) C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } // Minimum returns element-wise min(a, b). func Minimum(a, b *Array) *Array { - out := New("MIN", a, b) + out := newArray("MIN", a, b) C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } @@ -140,14 +140,14 @@ func Minimum(a, b *Array) *Array { // Matmul returns the matrix product of a and b. func Matmul(a, b *Array) *Array { - out := New("MATMUL", a, b) + out := newArray("MATMUL", a, b) C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } // QuantizedMatmul performs quantized matrix multiplication. func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array { - out := New("QMATMUL", x, w, scales, biases) + out := newArray("QMATMUL", x, w, scales, biases) gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)} b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)} mode := C.CString("affine") @@ -164,7 +164,7 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit // Softmax returns softmax along the last axis. func Softmax(a *Array) *Array { - out := New("SOFTMAX", a) + out := newArray("SOFTMAX", a) axis := []C.int{C.int(-1)} C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx) return out @@ -172,21 +172,21 @@ func Softmax(a *Array) *Array { // Argmax returns the index of the maximum value along an axis. func Argmax(a *Array, axis int, keepDims bool) *Array { - out := New("ARGMAX", a) + out := newArray("ARGMAX", a) C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) return out } // TopK returns the top k values along the last axis. func TopK(a *Array, k int) *Array { - out := New("TOPK", a) + out := newArray("TOPK", a) C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx) return out } // Sum reduces by summation along the given axis. func Sum(a *Array, axis int, keepDims bool) *Array { - out := New("SUM", a) + out := newArray("SUM", a) axes := []C.int{C.int(axis)} C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx) return out @@ -194,7 +194,7 @@ func Sum(a *Array, axis int, keepDims bool) *Array { // Mean reduces by averaging along the given axis. func Mean(a *Array, axis int, keepDims bool) *Array { - out := New("MEAN", a) + out := newArray("MEAN", a) axes := []C.int{C.int(axis)} C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx) return out @@ -204,7 +204,7 @@ func Mean(a *Array, axis int, keepDims bool) *Array { // Reshape changes the shape of an array. func Reshape(a *Array, shape ...int32) *Array { - out := New("RESHAPE", a) + out := newArray("RESHAPE", a) cShape := make([]C.int, len(shape)) for i, s := range shape { cShape[i] = C.int(s) @@ -215,7 +215,7 @@ func Reshape(a *Array, shape ...int32) *Array { // Transpose permutes dimensions. If no axes given, reverses all dims. func Transpose(a *Array, axes ...int) *Array { - out := New("TRANSPOSE", a) + out := newArray("TRANSPOSE", a) if len(axes) == 0 { C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx) } else { @@ -230,14 +230,14 @@ func Transpose(a *Array, axes ...int) *Array { // ExpandDims inserts a new axis at the given position. func ExpandDims(a *Array, axis int) *Array { - out := New("EXPAND_DIMS", a) + out := newArray("EXPAND_DIMS", a) C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx) return out } // Squeeze removes dimensions of size 1. func Squeeze(a *Array, axes ...int) *Array { - out := New("SQUEEZE", a) + out := newArray("SQUEEZE", a) cAxes := make([]C.int, len(axes)) for i, ax := range axes { cAxes[i] = C.int(ax) @@ -257,14 +257,14 @@ func Concatenate(arrays []*Array, axis int) *Array { inputs[i] = a } - out := New("CONCAT", inputs...) + out := newArray("CONCAT", inputs...) C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) return out } // BroadcastTo broadcasts an array to the given shape. func BroadcastTo(a *Array, shape []int32) *Array { - out := New("BROADCAST", a) + out := newArray("BROADCAST", a) cShape := make([]C.int, len(shape)) for i, s := range shape { cShape[i] = C.int(s) @@ -275,14 +275,14 @@ func BroadcastTo(a *Array, shape []int32) *Array { // AsType casts an array to a different dtype. func AsType(a *Array, dtype DType) *Array { - out := New("ASTYPE", a) + out := newArray("ASTYPE", a) C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx) return out } // AsStrided creates a view with custom strides. func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array { - out := New("AS_STRIDED", a) + out := newArray("AS_STRIDED", a) cShape := make([]C.int, len(shape)) for i, s := range shape { cShape[i] = C.int(s) @@ -297,28 +297,28 @@ func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array { // Take gathers elements from a along axis using indices. func Take(a, indices *Array, axis int) *Array { - out := New("TAKE", a, indices) + out := newArray("TAKE", a, indices) C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) return out } // Where selects elements from a or b based on condition. func Where(condition, a, b *Array) *Array { - out := New("WHERE", condition, a, b) + out := newArray("WHERE", condition, a, b) C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } // Argpartition partially sorts and returns indices for top-k selection. func Argpartition(a *Array, kth, axis int) *Array { - out := New("ARGPARTITION", a) + out := newArray("ARGPARTITION", a) C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx) return out } // Dequantize restores a quantized array to full precision. func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array { - out := New("DEQUANTIZE", w, scales, biases) + out := newArray("DEQUANTIZE", w, scales, biases) gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)} b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)} mode := C.CString("affine") @@ -330,7 +330,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array { // PutAlongAxis places values into array at indices along axis. func PutAlongAxis(a, indices, values *Array, axis int) *Array { - out := New("PUT_ALONG_AXIS", a, indices, values) + out := newArray("PUT_ALONG_AXIS", a, indices, values) // Use scatter approach: src[indices] = values C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx) return out @@ -339,7 +339,7 @@ func PutAlongAxis(a, indices, values *Array, axis int) *Array { // TakeAlongAxis gathers elements from a along axis using indices. // Unlike Take, this uses the same number of dimensions for indices and input. func TakeAlongAxis(a, indices *Array, axis int) *Array { - out := New("TAKE_ALONG_AXIS", a, indices) + out := newArray("TAKE_ALONG_AXIS", a, indices) C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) return out } @@ -347,7 +347,7 @@ func TakeAlongAxis(a, indices *Array, axis int) *Array { // LogSumExp computes log(sum(exp(a))) along the given axis. // Numerically stable reduction for cross-entropy loss. func LogSumExp(a *Array, axis int, keepDims bool) *Array { - out := New("LOGSUMEXP", a) + out := newArray("LOGSUMEXP", a) C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) return out } @@ -355,35 +355,35 @@ func LogSumExp(a *Array, axis int, keepDims bool) *Array { // CumSum returns the cumulative sum along the given axis. // reverse=false for forward, inclusive=true to include the current element. func CumSum(a *Array, axis int, reverse, inclusive bool) *Array { - out := New("CUMSUM", a) + out := newArray("CUMSUM", a) C.mlx_cumsum(&out.ctx, a.ctx, C.int(axis), C._Bool(reverse), C._Bool(inclusive), DefaultStream().ctx) return out } // Sort returns the array sorted along the given axis. func Sort(a *Array, axis int) *Array { - out := New("SORT", a) + out := newArray("SORT", a) C.mlx_sort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx) return out } // Argsort returns the indices that would sort the array along the given axis. func Argsort(a *Array, axis int) *Array { - out := New("ARGSORT", a) + out := newArray("ARGSORT", a) C.mlx_argsort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx) return out } // Greater returns element-wise a > b as a bool array. func Greater(a, b *Array) *Array { - out := New("GREATER", a, b) + out := newArray("GREATER", a, b) C.mlx_greater(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } // MaxAxis returns the maximum value along the given axis. func MaxAxis(a *Array, axis int, keepDims bool) *Array { - out := New("MAX_AXIS", a) + out := newArray("MAX_AXIS", a) C.mlx_max_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) return out } diff --git a/internal/metal/qwen3.go b/internal/metal/qwen3.go index 33f32c5..70c9ba9 100644 --- a/internal/metal/qwen3.go +++ b/internal/metal/qwen3.go @@ -75,7 +75,9 @@ func parseQwen3Config(data []byte) (*Qwen3Config, error) { var wrapper struct { Quantization *QuantizationConfig `json:"quantization"` } - json.Unmarshal(data, &wrapper) + if err := json.Unmarshal(data, &wrapper); err != nil { + return nil, fmt.Errorf("qwen3: parse quantization: %w", err) + } cfg.Quantization = wrapper.Quantization // Compute scale diff --git a/internal/metal/random.go b/internal/metal/random.go index bd6e90b..a7e792c 100644 --- a/internal/metal/random.go +++ b/internal/metal/random.go @@ -10,7 +10,7 @@ import "C" // RandomCategorical samples from a categorical distribution defined by logprobs. // Returns indices sampled according to the log-probability distribution along the last axis. func RandomCategorical(logprobs *Array) *Array { - out := New("RANDOM_CATEGORICAL", logprobs) + out := newArray("RANDOM_CATEGORICAL", logprobs) key := C.mlx_array_new() defer C.mlx_array_free(key) C.mlx_random_categorical( @@ -25,7 +25,7 @@ func RandomCategorical(logprobs *Array) *Array { // RandomUniform generates uniform random values in [low, high). func RandomUniform(low, high float32, shape []int32, dtype DType) *Array { - out := New("RANDOM_UNIFORM") + out := newArray("RANDOM_UNIFORM") cShape := make([]C.int, len(shape)) for i, s := range shape { cShape[i] = C.int(s) diff --git a/internal/metal/slice.go b/internal/metal/slice.go index 494435f..4e8b4af 100644 --- a/internal/metal/slice.go +++ b/internal/metal/slice.go @@ -10,7 +10,7 @@ import "C" // Slice extracts a sub-array using start and end indices for each dimension. // starts and ends must have the same length as the array's dimensions. func Slice(a *Array, starts, ends []int32) *Array { - out := New("SLICE", a) + out := newArray("SLICE", a) cStarts := make([]C.int, len(starts)) cEnds := make([]C.int, len(ends)) for i := range starts { @@ -47,7 +47,7 @@ func SliceAxis(a *Array, axis int, start, end int32) *Array { // SliceUpdateInplace updates a slice of the array in-place. // This is critical for KV cache updates. func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array { - out := New("SLICE_UPDATE", a, update) + out := newArray("SLICE_UPDATE", a, update) cStarts := make([]C.int, len(starts)) cEnds := make([]C.int, len(ends)) for i := range starts {