mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
sample: do all sorting in topK
This commit is contained in:
parent
3ba91634c1
commit
4aeb67ef4c
3 changed files with 35 additions and 25 deletions
|
@ -53,8 +53,17 @@ func temperature(ts []token, temp float32) []token {
|
|||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
func topK(ts []token, k int) []token {
|
||||
if k >= len(ts) {
|
||||
sortLogits(ts)
|
||||
if k >= len(ts) || k <= 0 {
|
||||
slices.SortFunc(ts, func(a, b token) int {
|
||||
switch {
|
||||
case a.value < b.value:
|
||||
return 1
|
||||
case a.value > b.value:
|
||||
return -1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return ts
|
||||
}
|
||||
|
||||
|
@ -125,17 +134,3 @@ func minP(ts []token, p float32) []token {
|
|||
ts = validTokens
|
||||
return ts
|
||||
}
|
||||
|
||||
// sortLogits sorts the tokens in descending order of logits
|
||||
func sortLogits(ts []token) {
|
||||
slices.SortFunc(ts, func(a, b token) int {
|
||||
switch {
|
||||
case a.value < b.value:
|
||||
return 1
|
||||
case a.value > b.value:
|
||||
return -1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue