diff --git a/calibrate.go b/calibrate.go index 0ed9091..9babeca 100644 --- a/calibrate.go +++ b/calibrate.go @@ -2,6 +2,7 @@ package i18n import ( "context" + "errors" "fmt" "time" @@ -48,7 +49,7 @@ func CalibrateDomains(ctx context.Context, modelA, modelB inference.TextModel, samples []CalibrationSample, opts ...ClassifyOption) (*CalibrationStats, error) { if len(samples) == 0 { - return nil, fmt.Errorf("calibrate: empty sample set") + return nil, errors.New("calibrate: empty sample set") } cfg := defaultClassifyConfig() diff --git a/reversal/reference.go b/reversal/reference.go index e5e4d03..3baf514 100644 --- a/reversal/reference.go +++ b/reversal/reference.go @@ -1,9 +1,10 @@ package reversal import ( - "fmt" + "errors" + "maps" "math" - "sort" + "slices" ) // ClassifiedText is a text sample with a domain label (from 1B model or ground truth). @@ -45,7 +46,7 @@ type ImprintClassification struct { // per unique domain label. func BuildReferences(tokeniser *Tokeniser, samples []ClassifiedText) (*ReferenceSet, error) { if len(samples) == 0 { - return nil, fmt.Errorf("empty sample set") + return nil, errors.New("empty sample set") } // Group imprints by domain. @@ -60,7 +61,7 @@ func BuildReferences(tokeniser *Tokeniser, samples []ClassifiedText) (*Reference } if len(grouped) == 0 { - return nil, fmt.Errorf("no samples with domain labels") + return nil, errors.New("no samples with domain labels") } rs := &ReferenceSet{Domains: make(map[string]*ReferenceDistribution)} @@ -105,7 +106,15 @@ func (rs *ReferenceSet) Classify(imprint GrammarImprint) ImprintClassification { for d, m := range distances { ranked = append(ranked, scored{d, m.CosineSimilarity}) } - sort.Slice(ranked, func(i, j int) bool { return ranked[i].sim > ranked[j].sim }) + slices.SortFunc(ranked, func(a, b scored) int { + if a.sim > b.sim { + return -1 + } + if a.sim < b.sim { + return 1 + } + return 0 + }) result := ImprintClassification{Distances: distances} if len(ranked) > 0 { @@ -121,12 +130,7 @@ func (rs *ReferenceSet) Classify(imprint GrammarImprint) ImprintClassification { // DomainNames returns sorted domain names in the reference set. func (rs *ReferenceSet) DomainNames() []string { - names := make([]string, 0, len(rs.Domains)) - for d := range rs.Domains { - names = append(names, d) - } - sort.Strings(names) - return names + return slices.Sorted(maps.Keys(rs.Domains)) } // computeCentroid averages imprints into a single centroid. diff --git a/service.go b/service.go index 17df471..67adde9 100644 --- a/service.go +++ b/service.go @@ -3,6 +3,7 @@ package i18n import ( "embed" "encoding/json" + "errors" "fmt" "io/fs" "maps" @@ -97,7 +98,7 @@ func NewWithLoader(loader Loader, opts ...Option) (*Service, error) { langs := loader.Languages() if len(langs) == 0 { - return nil, fmt.Errorf("no languages available from loader") + return nil, errors.New("no languages available from loader") } for _, lang := range langs { @@ -181,7 +182,7 @@ func (s *Service) SetLanguage(lang string) error { return fmt.Errorf("invalid language tag %q: %w", lang, err) } if len(s.availableLangs) == 0 { - return fmt.Errorf("no languages available") + return errors.New("no languages available") } matcher := language.NewMatcher(s.availableLangs) bestMatch, _, confidence := matcher.Match(requestedLang)