Merge pull request 'chore: Go 1.26 modernization' (#2) from chore/go-1.26-modernization into main
All checks were successful
Security Scan / security (push) Successful in 15s
Test / Vet & Build (push) Successful in 38s

This commit is contained in:
Charon 2026-02-24 18:01:47 +00:00
commit c1baeb9254
6 changed files with 17 additions and 20 deletions

View file

@ -3,11 +3,11 @@
package metal
import (
"cmp"
"context"
"fmt"
"math"
"slices"
"sort"
"time"
)
@ -49,8 +49,8 @@ func (m *Model) Classify(ctx context.Context, prompts []string, cfg GenerateConf
for i := range indices {
indices[i] = i
}
sort.Slice(indices, func(a, b int) bool {
return lengths[indices[a]] > lengths[indices[b]]
slices.SortFunc(indices, func(a, b int) int {
return cmp.Compare(lengths[b], lengths[a])
})
maxLen := lengths[indices[0]]
@ -160,8 +160,8 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
for i := range indices {
indices[i] = i
}
sort.Slice(indices, func(a, b int) bool {
return lengths[indices[a]] > lengths[indices[b]]
slices.SortFunc(indices, func(a, b int) int {
return cmp.Compare(lengths[b], lengths[a])
})
N := int32(len(prompts))
@ -208,7 +208,7 @@ func (m *Model) BatchGenerate(ctx context.Context, prompts []string, cfg Generat
maxTokens = 256
}
for step := 0; step < maxTokens; step++ {
for step := range maxTokens {
select {
case <-ctx.Done():
// Return partial results on cancellation.

View file

@ -4,6 +4,7 @@ package metal
import (
"context"
"errors"
"fmt"
"iter"
"slices"
@ -185,7 +186,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
Free(logits)
}()
for i := 0; i < cfg.MaxTokens; i++ {
for i := range cfg.MaxTokens {
select {
case <-ctx.Done():
m.lastErr = ctx.Err()
@ -256,7 +257,7 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
func (m *Model) InspectAttention(ctx context.Context, prompt string) (*AttentionResult, error) {
tokens := m.tokenizer.Encode(prompt)
if len(tokens) == 0 {
return nil, fmt.Errorf("empty prompt after tokenisation")
return nil, errors.New("empty prompt after tokenisation")
}
caches := m.newCaches()
@ -304,7 +305,7 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
keys[i] = make([][]float32, numHeads)
stride := validLen * headDim
for h := 0; h < numHeads; h++ {
for h := range numHeads {
start := h * stride
end := start + stride
if end > len(flat) {

View file

@ -39,7 +39,7 @@ func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload
nInputs := int(C.mlx_vector_array_size(inputs))
goInputs := make([]*Array, nInputs)
for i := 0; i < nInputs; i++ {
for i := range nInputs {
a := newArray("GRAD_INPUT")
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
goInputs[i] = a
@ -324,7 +324,7 @@ func MSELoss(predictions, targets *Array) *Array {
func vectorToArrays(vec C.mlx_vector_array) []*Array {
n := int(C.mlx_vector_array_size(vec))
out := make([]*Array, n)
for i := 0; i < n; i++ {
for i := range n {
a := newArray("VEC_OUT")
C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
out[i] = a

View file

@ -12,10 +12,11 @@ import (
"encoding/json"
"fmt"
"log/slog"
"maps"
"math"
"os"
"path/filepath"
"sort"
"slices"
"strings"
"unsafe"
)
@ -152,12 +153,7 @@ func (a *LoRAAdapter) TotalParams() int {
// SortedNames returns layer names in deterministic sorted order.
func (a *LoRAAdapter) SortedNames() []string {
names := make([]string, 0, len(a.Layers))
for name := range a.Layers {
names = append(names, name)
}
sort.Strings(names)
return names
return slices.Sorted(maps.Keys(a.Layers))
}
// AllTrainableParams returns all trainable arrays (A and B from every layer),

View file

@ -31,7 +31,7 @@ func SliceAxis(a *Array, axis int, start, end int32) *Array {
ndim := a.NumDims()
starts := make([]int32, ndim)
ends := make([]int32, ndim)
for i := 0; i < ndim; i++ {
for i := range ndim {
starts[i] = 0
ends[i] = int32(a.Dim(i))
}

View file

@ -191,7 +191,7 @@ func (t *Tokenizer) bpeMerge(symbols []string) []string {
// Find the pair with the lowest merge rank.
bestRank := -1
bestIdx := -1
for i := 0; i < len(symbols)-1; i++ {
for i := range len(symbols) - 1 {
key := symbols[i] + " " + symbols[i+1]
if rank, ok := t.mergeRanks[key]; ok {
if bestRank < 0 || rank < bestRank {