diff --git a/pkg/help/search.go b/pkg/help/search.go index 8f1593c..c718a2b 100644 --- a/pkg/help/search.go +++ b/pkg/help/search.go @@ -1,8 +1,9 @@ package help import ( + "cmp" "regexp" - "sort" + "slices" "strings" "unicode" ) @@ -158,11 +159,11 @@ func (i *searchIndex) Search(query string) []*SearchResult { } // Sort by score (highest first) - sort.Slice(results, func(a, b int) bool { - if results[a].Score != results[b].Score { - return results[a].Score > results[b].Score + slices.SortFunc(results, func(a, b *SearchResult) int { + if a.Score != b.Score { + return cmp.Compare(b.Score, a.Score) // descending } - return results[a].Topic.Title < results[b].Topic.Title + return cmp.Compare(a.Topic.Title, b.Topic.Title) }) return results @@ -357,11 +358,11 @@ func highlight(text string, res []*regexp.Regexp) string { } // Sort matches by start position - sort.Slice(matches, func(i, j int) bool { - if matches[i].start != matches[j].start { - return matches[i].start < matches[j].start + slices.SortFunc(matches, func(a, b match) int { + if a.start != b.start { + return cmp.Compare(a.start, b.start) } - return matches[i].end > matches[j].end + return cmp.Compare(b.end, a.end) // descending }) // Merge overlapping or adjacent matches diff --git a/pkg/i18n/internal/validate/main.go b/pkg/i18n/internal/validate/main.go index d295c57..1887489 100644 --- a/pkg/i18n/internal/validate/main.go +++ b/pkg/i18n/internal/validate/main.go @@ -18,6 +18,7 @@ package main import ( + "cmp" "encoding/json" "errors" "fmt" @@ -26,7 +27,7 @@ import ( "go/token" "os" "path/filepath" - "sort" + "slices" "strings" ) @@ -508,11 +509,11 @@ func printReport(result ValidationResult) { fmt.Printf("-------------\n") // Sort by file then line - sort.Slice(result.MissingKeys, func(i, j int) bool { - if result.MissingKeys[i].File != result.MissingKeys[j].File { - return result.MissingKeys[i].File < result.MissingKeys[j].File + slices.SortFunc(result.MissingKeys, func(a, b KeyUsage) int { + if a.File != b.File { + return cmp.Compare(a.File, b.File) } - return result.MissingKeys[i].Line < result.MissingKeys[j].Line + return cmp.Compare(a.Line, b.Line) }) for _, usage := range result.MissingKeys { diff --git a/pkg/io/datanode/client.go b/pkg/io/datanode/client.go index 8f77041..f423b96 100644 --- a/pkg/io/datanode/client.go +++ b/pkg/io/datanode/client.go @@ -7,11 +7,12 @@ package datanode import ( + "cmp" goio "io" "io/fs" "os" "path" - "sort" + "slices" "strings" "sync" "time" @@ -359,8 +360,8 @@ func (m *Medium) List(p string) ([]fs.DirEntry, error) { } } - sort.Slice(entries, func(i, j int) bool { - return entries[i].Name() < entries[j].Name() + slices.SortFunc(entries, func(a, b fs.DirEntry) int { + return cmp.Compare(a.Name(), b.Name()) }) return entries, nil diff --git a/pkg/io/node/node.go b/pkg/io/node/node.go index 184ccc0..69c59bf 100644 --- a/pkg/io/node/node.go +++ b/pkg/io/node/node.go @@ -6,11 +6,12 @@ package node import ( "archive/tar" "bytes" + "cmp" goio "io" "io/fs" "os" "path" - "sort" + "slices" "strings" "time" @@ -335,8 +336,8 @@ func (n *Node) ReadDir(name string) ([]fs.DirEntry, error) { } } - sort.Slice(entries, func(i, j int) bool { - return entries[i].Name() < entries[j].Name() + slices.SortFunc(entries, func(a, b fs.DirEntry) int { + return cmp.Compare(a.Name(), b.Name()) }) return entries, nil diff --git a/pkg/lab/collector/influxdb.go b/pkg/lab/collector/influxdb.go index c5d79aa..950c80c 100644 --- a/pkg/lab/collector/influxdb.go +++ b/pkg/lab/collector/influxdb.go @@ -1,11 +1,12 @@ package collector import ( + "cmp" "context" "encoding/json" "fmt" "net/http" - "sort" + "slices" "strings" "time" @@ -126,8 +127,11 @@ func (i *InfluxDB) Collect(ctx context.Context) error { for _, r := range runSet { data.Runs = append(data.Runs, r) } - sort.Slice(data.Runs, func(i, j int) bool { - return data.Runs[i].Model < data.Runs[j].Model || (data.Runs[i].Model == data.Runs[j].Model && data.Runs[i].RunID < data.Runs[j].RunID) + slices.SortFunc(data.Runs, func(a, b lab.BenchmarkRun) int { + if c := cmp.Compare(a.Model, b.Model); c != 0 { + return c + } + return cmp.Compare(a.RunID, b.RunID) }) i.store.SetBenchmarks(data) diff --git a/pkg/lab/handler/chart.go b/pkg/lab/handler/chart.go index adcfc07..5e179ab 100644 --- a/pkg/lab/handler/chart.go +++ b/pkg/lab/handler/chart.go @@ -1,9 +1,11 @@ package handler import ( + "cmp" "fmt" "html/template" "math" + "slices" "sort" "strings" @@ -118,7 +120,7 @@ func LossChart(points []lab.LossPoint) template.HTML { // Draw train loss line (dimmed). if len(trainPts) > 1 { - sort.Slice(trainPts, func(i, j int) bool { return trainPts[i].Iteration < trainPts[j].Iteration }) + slices.SortFunc(trainPts, func(a, b lab.LossPoint) int { return cmp.Compare(a.Iteration, b.Iteration) }) sb.WriteString(` 0 { @@ -232,7 +234,7 @@ func ContentChart(points []lab.ContentPoint) template.HTML { if !ok || len(pts) < 2 { continue } - sort.Slice(pts, func(i, j int) bool { return pts[i].Iteration < pts[j].Iteration }) + slices.SortFunc(pts, func(a, b lab.ContentPoint) int { return cmp.Compare(a.Iteration, b.Iteration) }) // Average duplicate iterations. averaged := averageByIteration(pts) @@ -285,7 +287,7 @@ func CapabilityChart(points []lab.CapabilityPoint) template.HTML { overall = append(overall, p) } } - sort.Slice(overall, func(i, j int) bool { return overall[i].Iteration < overall[j].Iteration }) + slices.SortFunc(overall, func(a, b lab.CapabilityPoint) int { return cmp.Compare(a.Iteration, b.Iteration) }) if len(overall) == 0 { return template.HTML(`
No overall capability data
`) diff --git a/pkg/lab/handler/web.go b/pkg/lab/handler/web.go index ed3bfc4..146c560 100644 --- a/pkg/lab/handler/web.go +++ b/pkg/lab/handler/web.go @@ -1,11 +1,12 @@ package handler import ( + "cmp" "embed" "fmt" "html/template" "net/http" - "sort" + "slices" "strings" "time" @@ -376,11 +377,14 @@ func buildModelGroups(runs []lab.TrainingRunStatus, benchmarks lab.BenchmarkData } result = append(result, *g) } - sort.Slice(result, func(i, j int) bool { - if result[i].HasTraining != result[j].HasTraining { - return result[i].HasTraining + slices.SortFunc(result, func(a, b ModelGroup) int { + if a.HasTraining != b.HasTraining { + if a.HasTraining { + return -1 + } + return 1 } - return result[i].Model < result[j].Model + return cmp.Compare(a.Model, b.Model) }) return result } diff --git a/pkg/plugin/registry.go b/pkg/plugin/registry.go index b2f0a85..7685a68 100644 --- a/pkg/plugin/registry.go +++ b/pkg/plugin/registry.go @@ -1,9 +1,10 @@ package plugin import ( + "cmp" "encoding/json" "path/filepath" - "sort" + "slices" core "forge.lthn.ai/core/go/pkg/framework/core" "forge.lthn.ai/core/go/pkg/io" @@ -34,8 +35,8 @@ func (r *Registry) List() []*PluginConfig { for _, cfg := range r.plugins { result = append(result, cfg) } - sort.Slice(result, func(i, j int) bool { - return result[i].Name < result[j].Name + slices.SortFunc(result, func(a, b *PluginConfig) int { + return cmp.Compare(a.Name, b.Name) }) return result } diff --git a/pkg/session/parser.go b/pkg/session/parser.go index 4c3dd87..a66bae8 100644 --- a/pkg/session/parser.go +++ b/pkg/session/parser.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "slices" "sort" "strings" "time" @@ -156,8 +157,8 @@ func ListSessions(projectsDir string) ([]Session, error) { sessions = append(sessions, s) } - sort.Slice(sessions, func(i, j int) bool { - return sessions[i].StartTime.After(sessions[j].StartTime) + slices.SortFunc(sessions, func(a, b Session) int { + return b.StartTime.Compare(a.StartTime) // descending }) return sessions, nil