sample: do all sorting in topK

This commit is contained in:
ParthSareen 2025-03-12 13:40:25 -04:00 committed by Parth Sareen
parent 3ba91634c1
commit 4aeb67ef4c
3 changed files with 35 additions and 25 deletions

View file

@ -53,8 +53,17 @@ func temperature(ts []token, temp float32) []token {
// topK limits the number of tokens considered to the k highest logits
func topK(ts []token, k int) []token {
if k >= len(ts) {
sortLogits(ts)
if k >= len(ts) || k <= 0 {
slices.SortFunc(ts, func(a, b token) int {
switch {
case a.value < b.value:
return 1
case a.value > b.value:
return -1
default:
return 0
}
})
return ts
}
@ -125,17 +134,3 @@ func minP(ts []token, p float32) []token {
ts = validTokens
return ts
}
// sortLogits sorts the tokens in descending order of logits
func sortLogits(ts []token) {
slices.SortFunc(ts, func(a, b token) int {
switch {
case a.value < b.value:
return 1
case a.value > b.value:
return -1
default:
return 0
}
})
}