fix(metal): address 4 minor code review items
- Rename `New()` → `newArray()` to signal internal-only intent (112 usages)
- Remove unused `Collect()` function and its test
- Fix discarded `json.Unmarshal` error in qwen3.go
- Document `AsStrided` stride formula in gemma3.go

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Parent: fb95cde30c
Commit: 443347a2f8
16 changed files with 89 additions and 108 deletions
@@ -281,7 +281,7 @@ m, _ := inference.LoadModel(path)
 for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(128)) { ... }
 ```

 ### New go-inference features available

 - `inference.List()` — returns all registered backend names
 - `inference.Backend.Available()` — hardware availability check
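The registry API named above (`Register`, `List`, `Default`) can be sketched in pure Go. This is an illustrative stand-in, not the real `inference` package: the `stub` type and the exact method set are assumptions based only on the names in the changelog.

```go
package main

import (
	"errors"
	"fmt"
	"sort"
	"sync"
)

// Backend is a minimal stand-in for the inference.Backend interface;
// only the methods mentioned in the changelog are sketched.
type Backend interface {
	Name() string
	Available() bool
}

var (
	mu       sync.Mutex
	backends = map[string]Backend{}
)

// Register adds a backend under its name (called from init() in build-tagged files).
func Register(b Backend) {
	mu.Lock()
	defer mu.Unlock()
	backends[b.Name()] = b
}

// List returns all registered backend names, sorted for determinism.
func List() []string {
	mu.Lock()
	defer mu.Unlock()
	names := make([]string, 0, len(backends))
	for n := range backends {
		names = append(names, n)
	}
	sort.Strings(names)
	return names
}

// Default returns the first available backend, or an error if none is usable.
func Default() (Backend, error) {
	mu.Lock()
	defer mu.Unlock()
	for _, b := range backends {
		if b.Available() {
			return b, nil
		}
	}
	return nil, errors.New("inference: no available backend")
}

type stub struct{ name string }

func (s stub) Name() string    { return s.name }
func (s stub) Available() bool { return true }

func main() {
	Register(stub{"metal"})
	fmt.Println(List()) // [metal]
}
```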
@@ -321,7 +321,7 @@ Was a stub that passed logits through unchanged. Now fully implemented:

 Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses MaxAxis to find the peak probability per position.

 ### New Bindings

 | Function | Header | Purpose |
 |----------|--------|---------|
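The min-p rule described above (mask tokens whose probability falls below `min_p * max_prob`) is easy to check on a plain float slice. This is an illustrative pure-Go sketch, not the Metal implementation:

```go
package main

import "fmt"

// minPMask zeroes out probabilities below minP times the peak probability,
// mirroring the MinP sampler rule: mask where p < min_p * max(p).
func minPMask(probs []float64, minP float64) []float64 {
	maxP := 0.0
	for _, p := range probs {
		if p > maxP {
			maxP = p
		}
	}
	out := make([]float64, len(probs))
	for i, p := range probs {
		if p >= minP*maxP {
			out[i] = p // kept
		} // else left at 0: masked
	}
	return out
}

func main() {
	// 0.04 < 0.1 * 0.5 = 0.05, so it is masked.
	fmt.Println(minPMask([]float64{0.5, 0.3, 0.04, 0.16}, 0.1)) // [0.5 0.3 0 0.16]
}
```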
@@ -424,7 +424,7 @@ Was a stub. Now masks tokens whose probability is below `min_p * max_prob`. Uses

 The old `checkError()` function logged errors via `slog.Error` and swallowed them. The mlx-c error handler (`mlx_go_error_handler`) stored the error string in a C static variable, but no Go code read it back as an error value.

 ### New error model

 1. **`lastError() error`** — reads and clears the C-level error string. Returns `fmt.Errorf("mlx: %s", msg)` or nil. All callers now get real MLX error messages instead of generic "failed" strings.
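The read-and-clear pattern behind `lastError()` can be sketched without any CGO: a handler writes the most recent error string to a slot, and `lastError` consumes it. The names mirror the changelog; the C static variable and mlx-c plumbing are replaced here by a Go string, so this is a shape sketch only.

```go
package main

import (
	"fmt"
	"sync"
)

var (
	errMu   sync.Mutex
	lastMsg string // set by the error handler, consumed by lastError
)

// setError is a Go-side analogue of the mlx_go_error_handler callback:
// it records the most recent error string.
func setError(msg string) {
	errMu.Lock()
	lastMsg = msg
	errMu.Unlock()
}

// lastError reads and clears the stored error, returning a wrapped
// "mlx: ..." error or nil — the model described in the changelog.
func lastError() error {
	errMu.Lock()
	defer errMu.Unlock()
	if lastMsg == "" {
		return nil
	}
	err := fmt.Errorf("mlx: %s", lastMsg)
	lastMsg = "" // clear so the next call sees no error
	return err
}

func main() {
	setError("metal device not found")
	fmt.Println(lastError()) // mlx: metal device not found
	fmt.Println(lastError()) // <nil>
}
```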
@@ -453,4 +453,4 @@ The old `checkError()` function logged errors via `slog.Error` and swallowed the
 ### Test results

 - 180 tests passing (176 existing + 4 new error handling tests)
 - New tests: Eval success, Eval nil safety, lastError no-error, LoadAllSafetensors missing file

TODO.md (+12 −12)
@@ -87,7 +87,7 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe

 - [x] **`-mmacosx-version-min=26.0` is wrong** — ✅ Changed to `13.3` (MLX's own minimum). No longer locks out macOS 14/15 users.

 - [x] **`LoadOption` is ignored in `metalBackend.LoadModel()`** — ✅ Now calls `inference.ApplyLoadOpts()`. `ContextLen` passed through to `metal.LoadConfig` → stored on `Model` → replaces unbounded `KVCache` with `RotatingKVCache(contextLen)` in generate loop. `GPULayers=0` logs a warning (Metal always uses full GPU offload). New test: `TestNewCaches_ContextLen`.

 ### Important — Should Fix
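A rotating KV cache of the kind wired in above is, at heart, a ring buffer over the sequence axis. This pure-Go sketch over plain token IDs is illustrative only — the real `RotatingKVCache` rotates Metal key/value arrays, and the method names here are assumptions:

```go
package main

import "fmt"

// rotatingCache keeps at most contextLen entries, overwriting the oldest —
// the bounded-memory behaviour that replaces an unbounded KV cache.
type rotatingCache struct {
	buf  []int
	next int  // ring write position
	full bool // true once the buffer has wrapped at least once
}

func newRotatingCache(contextLen int) *rotatingCache {
	return &rotatingCache{buf: make([]int, contextLen)}
}

func (c *rotatingCache) Put(tok int) {
	c.buf[c.next] = tok
	c.next = (c.next + 1) % len(c.buf)
	if c.next == 0 {
		c.full = true
	}
}

// Window returns the cached entries oldest-first.
func (c *rotatingCache) Window() []int {
	if !c.full {
		return append([]int(nil), c.buf[:c.next]...)
	}
	return append(append([]int(nil), c.buf[c.next:]...), c.buf[:c.next]...)
}

func main() {
	c := newRotatingCache(3)
	for tok := 1; tok <= 5; tok++ {
		c.Put(tok)
	}
	fmt.Println(c.Window()) // [3 4 5]
}
```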
@@ -103,13 +103,13 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe

 ### Minor — Nice to Have

-- [ ] **Rename `New()` → `newArray()`** — `array.go:32`: `New("INPUT")` is exported but only used internally as a pre-allocation pattern before C fills in `ctx`. Since this is `internal/metal`, exporting is harmless, but `newArray` better signals intent.
+- [x] **Rename `New()` → `newArray()`** — ✅ Renamed via IDE refactoring (112 usages updated). Unexported, signals internal-only intent.

-- [ ] **`Collect()` is unused** — `metal.go:125-133` defines `Collect()` to gather valid arrays for batch Materialize, but nothing in the codebase calls it. Dead code.
+- [x] **`Collect()` is unused** — ✅ Removed function and its test. Dead code eliminated.

-- [ ] **`qwen3.go` / `gemma3.go` — second `json.Unmarshal` error discarded** — Both model configs parse quantization with a second unmarshal whose error is silently dropped. Probably fine (same data already parsed successfully), but inconsistent with the first unmarshal which returns errors.
+- [x] **`qwen3.go` — second `json.Unmarshal` error discarded** — ✅ Now checks and returns the error. gemma3.go already handled it correctly.

-- [ ] **Document `AsStrided` stride formula** — `gemma3.go:354-359` uses stride manipulation for `[B,L,H*D]` → `[B,H,L,D]` virtual transpose. The formula is correct but non-obvious. A one-line comment explaining the stride derivation prevents future confusion.
+- [x] **Document `AsStrided` stride formula** — ✅ Added comment explaining the stride derivation for the `[B,L,H*D]` → `[B,H,L,D]` virtual transpose.

 ### Questions for You to Consider
@@ -124,4 +124,4 @@ Full codebase review after Phase 4 completion + go-inference integration. Groupe
 1. Virgil in core/go writes tasks here after research
 2. This repo's session picks up tasks in phase order
 3. Mark `[x]` when done, note commit hash
 4. New discoveries → add tasks, flag in FINDINGS.md
@@ -42,5 +42,5 @@ Tasks for the CLion Claude session. Written by GoLand Claude or Virgil.

 1. GoLand Claude or Virgil writes tasks here
 2. Pick up in order, mark `[x]` when done
 3. New findings → `cpp/FINDINGS.md`
 4. If Go changes needed → note in FINDINGS.md for GoLand Claude
@@ -60,7 +60,7 @@ go-mlx/
 │   ├── lora.go       LoRA adapters
 │   ├── optim.go      AdamW
 │   ├── generate.go   NEW: autoregressive generation loop
 │   └── backend.go    Implements mlx.TextModel, exports New()
 │
 ├── mlxlm/            Future: Python subprocess backend
 │   └── backend.go    Implements mlx.Backend via core/go/pkg/process
@@ -174,7 +174,7 @@ func Get(name string) (Backend, bool)
 func Default() (Backend, error)

 // register_metal.go (//go:build darwin && arm64)
 func init() { Register(metal.New()) }
 func MetalAvailable() bool { return true }
 ```
@@ -294,9 +294,9 @@ Thin subprocess wrapper using `core/go/pkg/process`. Implements `mlx.Backend`. R
 ## Testing

 - All 148 existing tests move into `internal/metal/` and must pass
 - New `generate_test.go` — autoregressive loop with Gemma3-1B from `/Volumes/Data/lem/safetensors/`
 - New `backend_test.go` — end-to-end LoadModel + Generate
 - New `memory_test.go` — 1000-token generation, assert GetPeakMemory() bounded
 - Root `mlx_test.go` — integration via public mlx.LoadModel() API

 ## Communication
@@ -312,7 +312,7 @@ Copy core CGO bridge from `mlx.go` to `internal/metal/metal.go`. Change:

 Copy `array.go` to `internal/metal/array.go`. Change:
 - `package mlx` → `package metal`
 - All types (`Array`, `New`, `FromValue`, `FromValues`, `Zeros`, `Free`) stay exported within the internal package

 **Step 5: Move stream.go → internal/metal/stream.go**
@@ -456,7 +456,7 @@ Check for conflicts when flattening:
 - `tokenizer.Load` → rename to `loadTokenizer`
 - `tokenizer.Tokenizer` struct → stays `Tokenizer`
 - `sample.Sampler` interface → stays `Sampler`
 - `sample.New` → rename to `newSampler`
 - `cache.Cache` interface → stays `Cache`

 **Step 3: Delete old sub-package directories**
@@ -26,10 +26,10 @@ type Array struct {
 	name string // debug label
 }

-// New creates a named Array and registers a GC finalizer.
+// newArray creates a named Array and registers a GC finalizer.
 // The inputs parameter is accepted for API compatibility but not stored —
 // MLX-C tracks inter-array references via its own refcounting.
-func New(name string, inputs ...*Array) *Array {
+func newArray(name string, inputs ...*Array) *Array {
 	t := &Array{name: name}
 	runtime.SetFinalizer(t, finalizeArray)
 	return t
@@ -50,7 +50,7 @@ type scalarTypes interface {
 // FromValue creates a scalar Array from a Go value.
 func FromValue[T scalarTypes](t T) *Array {
 	Init()
-	tt := New("")
+	tt := newArray("")
 	switch v := any(t).(type) {
 	case bool:
 		tt.ctx = C.mlx_array_new_bool(C.bool(v))

@@ -122,7 +122,7 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
 		panic(err)
 	}

-	tt := New("")
+	tt := newArray("")
 	tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
 	return tt
 }

@@ -134,7 +134,7 @@ func Zeros(shape []int32, dtype DType) *Array {
 	for i, s := range shape {
 		cShape[i] = C.int(s)
 	}
-	tt := New("ZEROS")
+	tt := newArray("ZEROS")
 	C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
 	return tt
 }

@@ -146,7 +146,7 @@ func (t *Array) Set(other *Array) {

 // Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
 func (t *Array) Clone() *Array {
-	tt := New(t.name)
+	tt := newArray(t.name)
 	C.mlx_array_set(&tt.ctx, t.ctx)
 	return tt
 }

@@ -231,7 +231,7 @@ func (t Array) IsRowContiguous() bool {
 // Contiguous returns a row-major contiguous copy of the array.
 // If the array is already row-contiguous, this is a no-op.
 func Contiguous(a *Array) *Array {
-	out := New("CONTIGUOUS", a)
+	out := newArray("CONTIGUOUS", a)
 	C.mlx_contiguous(&out.ctx, a.ctx, C._Bool(false), DefaultStream().ctx)
 	return out
 }
@@ -452,15 +452,3 @@ func TestArray_Float_DTypeFloat64(t *testing.T) {
 	}
 }

-// --- Collect ---
-
-func TestCollect_FiltersNilAndInvalid(t *testing.T) {
-	a := FromValue(float32(1.0))
-	b := FromValue(float32(2.0))
-	Materialize(a, b)
-
-	result := Collect(a, nil, b)
-	if len(result) != 2 {
-		t.Errorf("Collect returned %d arrays, want 2", len(result))
-	}
-}
@@ -12,21 +12,21 @@ import "unsafe"

 // RMSNorm applies Root Mean Square normalization using a fused Metal kernel.
 func RMSNorm(x, weight *Array, eps float32) *Array {
-	out := New("FAST_RMSNORM", x)
+	out := newArray("FAST_RMSNORM", x)
 	C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
 	return out
 }

 // LayerNorm applies Layer normalization using a fused Metal kernel.
 func LayerNorm(x, weight, bias *Array, eps float32) *Array {
-	out := New("FAST_LAYERNORM", x)
+	out := newArray("FAST_LAYERNORM", x)
 	C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx)
 	return out
 }

 // RoPE applies Rotary Position Embeddings using a fused Metal kernel.
 func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
-	out := New("FAST_ROPE", x)
+	out := newArray("FAST_ROPE", x)
 	freqs := C.mlx_array_new()
 	defer C.mlx_array_free(freqs)
 	C.mlx_fast_rope(

@@ -60,7 +60,7 @@ func ScaledDotProductAttention(query, key, value *Array, scale float32, causal b
 	sinksArr := C.mlx_array_new()
 	defer C.mlx_array_free(sinksArr)

-	out := New("FAST_SDPA", query, key, value)
+	out := newArray("FAST_SDPA", query, key, value)
 	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx)
 	return out
 }

@@ -73,7 +73,7 @@ func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale flo
 	sinksArr := C.mlx_array_new()
 	defer C.mlx_array_free(sinksArr)

-	out := New("FAST_SDPA", query, key, value, mask)
+	out := newArray("FAST_SDPA", query, key, value, mask)
 	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx)
 	return out
 }
@@ -350,7 +350,9 @@ func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, cfg *
 	k := a.KProj.Forward(x)
 	v := a.VProj.Forward(x)

-	// Reshape to [B, num_heads, L, head_dim]
+	// Virtual transpose [B,L,H*D] → [B,H,L,D] via stride manipulation.
+	// Strides: batch = L*H*D (full sequence), head = D (adjacent heads in memory),
+	// seq = H*D (jump one full row of heads), elem = 1 (contiguous within head).
 	q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
 		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
 	k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
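The stride formula being documented here can be verified with plain index arithmetic: for a row-major `[B, L, H*D]` buffer, element `(b, h, l, d)` of the virtual `[B, H, L, D]` view lives at flat index `b*L*H*D + h*D + l*H*D + d`. A small sketch, with no MLX involved and tiny made-up shapes:

```go
package main

import "fmt"

func main() {
	const B, L, H, D = 1, 2, 2, 3 // tiny shapes just for checking

	// Row-major [B, L, H*D] buffer filled with its own flat indices.
	buf := make([]int, B*L*H*D)
	for i := range buf {
		buf[i] = i
	}

	// Strides for the virtual [B, H, L, D] view, as documented:
	// batch = L*H*D, head = D, seq = H*D, elem = 1.
	strideB, strideH, strideL, strideD := L*H*D, D, H*D, 1

	// Read head h=1 at sequence position l=1: flat indices
	// 0*12 + 1*3 + 1*6 + d = 9..11, the last D elements of the buffer.
	var head []int
	for d := 0; d < D; d++ {
		head = append(head, buf[0*strideB+1*strideH+1*strideL+d*strideD])
	}
	fmt.Println(head) // [9 10 11]
}
```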
@@ -40,7 +40,7 @@ func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload
 	nInputs := int(C.mlx_vector_array_size(inputs))
 	goInputs := make([]*Array, nInputs)
 	for i := 0; i < nInputs; i++ {
-		a := New("GRAD_INPUT")
+		a := newArray("GRAD_INPUT")
 		C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
 		goInputs[i] = a
 	}

@@ -325,7 +325,7 @@ func vectorToArrays(vec C.mlx_vector_array) []*Array {
 	n := int(C.mlx_vector_array_size(vec))
 	out := make([]*Array, n)
 	for i := 0; i < n; i++ {
-		a := New("VEC_OUT")
+		a := newArray("VEC_OUT")
 		C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
 		out[i] = a
 	}

@@ -334,7 +334,7 @@ func vectorToArrays(vec C.mlx_vector_array) []*Array {

 // Log returns element-wise natural logarithm.
 func Log(a *Array) *Array {
-	out := New("LOG", a)
+	out := newArray("LOG", a)
 	C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

@@ -353,7 +353,7 @@ func MeanAll(a *Array) *Array {

 // OnesLike creates an array of ones with the same shape and type as the input.
 func OnesLike(a *Array) *Array {
-	out := New("ONES_LIKE", a)
+	out := newArray("ONES_LIKE", a)
 	C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }
@@ -109,8 +109,8 @@ func (l *LoRALinear) ParamCount() int {

 // LoRAConfig specifies which layers to apply LoRA to and with what parameters.
 type LoRAConfig struct {
-	Rank  int     // Decomposition rank (default 8)
-	Alpha float32 // Scaling factor (default 16)
+	Rank       int      // Decomposition rank (default 8)
+	Alpha      float32  // Scaling factor (default 16)
 	TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj)
 }

@@ -187,7 +187,7 @@ func (a *LoRAAdapter) Save(path string) error {
 // RandomNormal generates normal (Gaussian) random values with given mean and stddev.
 func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array {
 	Init()
-	out := New("RANDOM_NORMAL")
+	out := newArray("RANDOM_NORMAL")
 	cShape := make([]C.int, len(shape))
 	for i, s := range shape {
 		cShape[i] = C.int(s)
@@ -120,17 +120,6 @@ func MaterializeAsync(outputs ...*Array) {
 	}
 }

-// Collect gathers all valid arrays from a variadic list for batch Materialize.
-func Collect(arrays ...*Array) []*Array {
-	var out []*Array
-	for _, a := range arrays {
-		if a != nil && a.Valid() {
-			out = append(out, a)
-		}
-	}
-	return out
-}
-
 // MetalAvailable reports whether Metal GPU is available.
 func MetalAvailable() bool {
 	Init()
@@ -14,7 +14,7 @@ import "unsafe"

 // Add returns element-wise a + b.
 func Add(a, b *Array) *Array {
-	out := New("ADD", a, b)
+	out := newArray("ADD", a, b)
 	C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

@@ -27,7 +27,7 @@ func AddScalar(a *Array, s float32) *Array {

 // Mul returns element-wise a * b.
 func Mul(a, b *Array) *Array {
-	out := New("MUL", a, b)
+	out := newArray("MUL", a, b)
 	C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

@@ -40,21 +40,21 @@ func MulScalar(a *Array, s float32) *Array {

 // Divide returns element-wise a / b.
 func Divide(a, b *Array) *Array {
-	out := New("DIV", a, b)
+	out := newArray("DIV", a, b)
 	C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

 // Subtract returns element-wise a - b.
 func Subtract(a, b *Array) *Array {
-	out := New("SUB", a, b)
+	out := newArray("SUB", a, b)
 	C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

 // Negative returns element-wise -a.
 func Negative(a *Array) *Array {
-	out := New("NEG", a)
+	out := newArray("NEG", a)
 	C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

@@ -63,14 +63,14 @@ func Negative(a *Array) *Array {

 // Exp returns element-wise exp(a).
 func Exp(a *Array) *Array {
-	out := New("EXP", a)
+	out := newArray("EXP", a)
 	C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

 // Sigmoid returns element-wise 1/(1+exp(-a)).
 func Sigmoid(a *Array) *Array {
-	out := New("SIGMOID", a)
+	out := newArray("SIGMOID", a)
 	C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

@@ -82,56 +82,56 @@ func SiLU(a *Array) *Array {

 // Tanh returns element-wise tanh(a).
 func Tanh(a *Array) *Array {
-	out := New("TANH", a)
+	out := newArray("TANH", a)
 	C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

 // Sqrt returns element-wise sqrt(a).
 func Sqrt(a *Array) *Array {
-	out := New("SQRT", a)
+	out := newArray("SQRT", a)
 	C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

 // Rsqrt returns element-wise 1/sqrt(a).
 func Rsqrt(a *Array) *Array {
-	out := New("RSQRT", a)
+	out := newArray("RSQRT", a)
 	C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

 // Reciprocal returns element-wise 1/a.
 func Reciprocal(a *Array) *Array {
-	out := New("RECIPROCAL", a)
+	out := newArray("RECIPROCAL", a)
 	C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

 // Square returns element-wise a^2.
 func Square(a *Array) *Array {
-	out := New("SQUARE", a)
+	out := newArray("SQUARE", a)
 	C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx)
 	return out
 }

 // Power returns element-wise a^b.
 func Power(a, b *Array) *Array {
-	out := New("POWER", a, b)
+	out := newArray("POWER", a, b)
 	C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

 // Maximum returns element-wise max(a, b).
 func Maximum(a, b *Array) *Array {
-	out := New("MAX", a, b)
+	out := newArray("MAX", a, b)
 	C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

 // Minimum returns element-wise min(a, b).
 func Minimum(a, b *Array) *Array {
-	out := New("MIN", a, b)
+	out := newArray("MIN", a, b)
 	C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

@@ -140,14 +140,14 @@ func Minimum(a, b *Array) *Array {

 // Matmul returns the matrix product of a and b.
 func Matmul(a, b *Array) *Array {
-	out := New("MATMUL", a, b)
+	out := newArray("MATMUL", a, b)
 	C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

 // QuantizedMatmul performs quantized matrix multiplication.
 func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
-	out := New("QMATMUL", x, w, scales, biases)
+	out := newArray("QMATMUL", x, w, scales, biases)
 	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
 	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
 	mode := C.CString("affine")

@@ -164,7 +164,7 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit

 // Softmax returns softmax along the last axis.
 func Softmax(a *Array) *Array {
-	out := New("SOFTMAX", a)
+	out := newArray("SOFTMAX", a)
 	axis := []C.int{C.int(-1)}
 	C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
 	return out

@@ -172,21 +172,21 @@ func Softmax(a *Array) *Array {

 // Argmax returns the index of the maximum value along an axis.
 func Argmax(a *Array, axis int, keepDims bool) *Array {
-	out := New("ARGMAX", a)
+	out := newArray("ARGMAX", a)
 	C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }

 // TopK returns the top k values along the last axis.
 func TopK(a *Array, k int) *Array {
-	out := New("TOPK", a)
+	out := newArray("TOPK", a)
 	C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
 	return out
 }

 // Sum reduces by summation along the given axis.
 func Sum(a *Array, axis int, keepDims bool) *Array {
-	out := New("SUM", a)
+	out := newArray("SUM", a)
 	axes := []C.int{C.int(axis)}
 	C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
 	return out

@@ -194,7 +194,7 @@ func Sum(a *Array, axis int, keepDims bool) *Array {

 // Mean reduces by averaging along the given axis.
 func Mean(a *Array, axis int, keepDims bool) *Array {
-	out := New("MEAN", a)
+	out := newArray("MEAN", a)
 	axes := []C.int{C.int(axis)}
 	C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
 	return out

@@ -204,7 +204,7 @@ func Mean(a *Array, axis int, keepDims bool) *Array {

 // Reshape changes the shape of an array.
 func Reshape(a *Array, shape ...int32) *Array {
-	out := New("RESHAPE", a)
+	out := newArray("RESHAPE", a)
 	cShape := make([]C.int, len(shape))
 	for i, s := range shape {
 		cShape[i] = C.int(s)

@@ -215,7 +215,7 @@ func Reshape(a *Array, shape ...int32) *Array {

 // Transpose permutes dimensions. If no axes given, reverses all dims.
 func Transpose(a *Array, axes ...int) *Array {
-	out := New("TRANSPOSE", a)
+	out := newArray("TRANSPOSE", a)
 	if len(axes) == 0 {
 		C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
 	} else {

@@ -230,14 +230,14 @@ func Transpose(a *Array, axes ...int) *Array {

 // ExpandDims inserts a new axis at the given position.
 func ExpandDims(a *Array, axis int) *Array {
-	out := New("EXPAND_DIMS", a)
+	out := newArray("EXPAND_DIMS", a)
 	C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }

 // Squeeze removes dimensions of size 1.
 func Squeeze(a *Array, axes ...int) *Array {
-	out := New("SQUEEZE", a)
+	out := newArray("SQUEEZE", a)
 	cAxes := make([]C.int, len(axes))
 	for i, ax := range axes {
 		cAxes[i] = C.int(ax)

@@ -257,14 +257,14 @@ func Concatenate(arrays []*Array, axis int) *Array {
 		inputs[i] = a
 	}

-	out := New("CONCAT", inputs...)
+	out := newArray("CONCAT", inputs...)
 	C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
 	return out
 }

 // BroadcastTo broadcasts an array to the given shape.
 func BroadcastTo(a *Array, shape []int32) *Array {
-	out := New("BROADCAST", a)
+	out := newArray("BROADCAST", a)
 	cShape := make([]C.int, len(shape))
 	for i, s := range shape {
 		cShape[i] = C.int(s)

@@ -275,14 +275,14 @@ func BroadcastTo(a *Array, shape []int32) *Array {

 // AsType casts an array to a different dtype.
 func AsType(a *Array, dtype DType) *Array {
-	out := New("ASTYPE", a)
+	out := newArray("ASTYPE", a)
 	C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
 	return out
 }

 // AsStrided creates a view with custom strides.
 func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
-	out := New("AS_STRIDED", a)
+	out := newArray("AS_STRIDED", a)
 	cShape := make([]C.int, len(shape))
 	for i, s := range shape {
 		cShape[i] = C.int(s)

@@ -297,28 +297,28 @@ func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {

 // Take gathers elements from a along axis using indices.
 func Take(a, indices *Array, axis int) *Array {
-	out := New("TAKE", a, indices)
+	out := newArray("TAKE", a, indices)
 	C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }

 // Where selects elements from a or b based on condition.
 func Where(condition, a, b *Array) *Array {
-	out := New("WHERE", condition, a, b)
+	out := newArray("WHERE", condition, a, b)
 	C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

 // Argpartition partially sorts and returns indices for top-k selection.
 func Argpartition(a *Array, kth, axis int) *Array {
-	out := New("ARGPARTITION", a)
+	out := newArray("ARGPARTITION", a)
 	C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
 	return out
 }

 // Dequantize restores a quantized array to full precision.
 func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
-	out := New("DEQUANTIZE", w, scales, biases)
+	out := newArray("DEQUANTIZE", w, scales, biases)
 	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
 	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
 	mode := C.CString("affine")

@@ -330,7 +330,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {

 // PutAlongAxis places values into array at indices along axis.
 func PutAlongAxis(a, indices, values *Array, axis int) *Array {
-	out := New("PUT_ALONG_AXIS", a, indices, values)
+	out := newArray("PUT_ALONG_AXIS", a, indices, values)
 	// Use scatter approach: src[indices] = values
 	C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
 	return out

@@ -339,7 +339,7 @@ func PutAlongAxis(a, indices, values *Array, axis int) *Array {
 // TakeAlongAxis gathers elements from a along axis using indices.
 // Unlike Take, this uses the same number of dimensions for indices and input.
 func TakeAlongAxis(a, indices *Array, axis int) *Array {
-	out := New("TAKE_ALONG_AXIS", a, indices)
+	out := newArray("TAKE_ALONG_AXIS", a, indices)
 	C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }

@@ -347,7 +347,7 @@ func TakeAlongAxis(a, indices *Array, axis int) *Array {
 // LogSumExp computes log(sum(exp(a))) along the given axis.
 // Numerically stable reduction for cross-entropy loss.
 func LogSumExp(a *Array, axis int, keepDims bool) *Array {
-	out := New("LOGSUMEXP", a)
+	out := newArray("LOGSUMEXP", a)
 	C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }
@@ -355,35 +355,35 @@ func LogSumExp(a *Array, axis int, keepDims bool) *Array {
 // CumSum returns the cumulative sum along the given axis.
 // reverse=false for forward, inclusive=true to include the current element.
 func CumSum(a *Array, axis int, reverse, inclusive bool) *Array {
-	out := New("CUMSUM", a)
+	out := newArray("CUMSUM", a)
 	C.mlx_cumsum(&out.ctx, a.ctx, C.int(axis), C._Bool(reverse), C._Bool(inclusive), DefaultStream().ctx)
 	return out
 }

 // Sort returns the array sorted along the given axis.
 func Sort(a *Array, axis int) *Array {
-	out := New("SORT", a)
+	out := newArray("SORT", a)
 	C.mlx_sort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }

 // Argsort returns the indices that would sort the array along the given axis.
 func Argsort(a *Array, axis int) *Array {
-	out := New("ARGSORT", a)
+	out := newArray("ARGSORT", a)
 	C.mlx_argsort_axis(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }

 // Greater returns element-wise a > b as a bool array.
 func Greater(a, b *Array) *Array {
-	out := New("GREATER", a, b)
+	out := newArray("GREATER", a, b)
 	C.mlx_greater(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
 	return out
 }

 // MaxAxis returns the maximum value along the given axis.
 func MaxAxis(a *Array, axis int, keepDims bool) *Array {
-	out := New("MAX_AXIS", a)
+	out := newArray("MAX_AXIS", a)
 	C.mlx_max_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }
@@ -75,7 +75,9 @@ func parseQwen3Config(data []byte) (*Qwen3Config, error) {
 	var wrapper struct {
 		Quantization *QuantizationConfig `json:"quantization"`
 	}
-	json.Unmarshal(data, &wrapper)
+	if err := json.Unmarshal(data, &wrapper); err != nil {
+		return nil, fmt.Errorf("qwen3: parse quantization: %w", err)
+	}
 	cfg.Quantization = wrapper.Quantization

 	// Compute scale
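The fix above is the standard two-pass pattern: unmarshal the main config, then re-unmarshal the same bytes into a small wrapper for one optional field, checking the error both times. A standalone sketch — the `config`/`quantConfig` types are illustrative, not the real Qwen3Config:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type quantConfig struct {
	Bits int `json:"bits"`
}

type config struct {
	HiddenSize   int          `json:"hidden_size"`
	Quantization *quantConfig `json:"-"` // filled by the second pass
}

func parseConfig(data []byte) (*config, error) {
	var cfg config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	// Second pass for the optional quantization block; the error is
	// checked here too, which is the review fix.
	var wrapper struct {
		Quantization *quantConfig `json:"quantization"`
	}
	if err := json.Unmarshal(data, &wrapper); err != nil {
		return nil, fmt.Errorf("parse quantization: %w", err)
	}
	cfg.Quantization = wrapper.Quantization
	return &cfg, nil
}

func main() {
	cfg, err := parseConfig([]byte(`{"hidden_size": 1024, "quantization": {"bits": 4}}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(cfg.HiddenSize, cfg.Quantization.Bits) // 1024 4
}
```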
@@ -10,7 +10,7 @@ import "C"
 // RandomCategorical samples from a categorical distribution defined by logprobs.
 // Returns indices sampled according to the log-probability distribution along the last axis.
 func RandomCategorical(logprobs *Array) *Array {
-	out := New("RANDOM_CATEGORICAL", logprobs)
+	out := newArray("RANDOM_CATEGORICAL", logprobs)
 	key := C.mlx_array_new()
 	defer C.mlx_array_free(key)
 	C.mlx_random_categorical(

@@ -25,7 +25,7 @@ func RandomCategorical(logprobs *Array) *Array {

 // RandomUniform generates uniform random values in [low, high).
 func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
-	out := New("RANDOM_UNIFORM")
+	out := newArray("RANDOM_UNIFORM")
 	cShape := make([]C.int, len(shape))
 	for i, s := range shape {
 		cShape[i] = C.int(s)
@@ -10,7 +10,7 @@ import "C"
 // Slice extracts a sub-array using start and end indices for each dimension.
 // starts and ends must have the same length as the array's dimensions.
 func Slice(a *Array, starts, ends []int32) *Array {
-	out := New("SLICE", a)
+	out := newArray("SLICE", a)
 	cStarts := make([]C.int, len(starts))
 	cEnds := make([]C.int, len(ends))
 	for i := range starts {

@@ -47,7 +47,7 @@ func SliceAxis(a *Array, axis int, start, end int32) *Array {
 // SliceUpdateInplace updates a slice of the array in-place.
 // This is critical for KV cache updates.
 func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
-	out := New("SLICE_UPDATE", a, update)
+	out := newArray("SLICE_UPDATE", a, update)
 	cStarts := make([]C.int, len(starts))
 	cEnds := make([]C.int, len(ends))
 	for i := range starts {