mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
sample: add numerical stability to temperature/softmax transform (#9631)
This commit is contained in:
parent
fe776293f7
commit
7e34f4fbfa
3 changed files with 28 additions and 42 deletions
|
@ -5,13 +5,25 @@ import (
|
|||
"slices"
|
||||
)
|
||||
|
||||
func softmax(ts []token) []token {
|
||||
// temperature applies scaling and softmax to the logits
|
||||
func temperature(ts []token, temp float32) []token {
|
||||
// Find max logit for numerical stability
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
for _, t := range ts {
|
||||
if t.value > maxLogit {
|
||||
maxLogit = t.value
|
||||
}
|
||||
}
|
||||
|
||||
// Apply temperature and compute exp(x - max)
|
||||
temp = max(temp, 1e-7)
|
||||
var sum float32
|
||||
for i, v := range ts {
|
||||
ts[i].value = float32(math.Exp(float64(v.value)))
|
||||
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
|
||||
sum += ts[i].value
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for i := range ts {
|
||||
ts[i].value /= sum
|
||||
}
|
||||
|
@ -19,27 +31,6 @@ func softmax(ts []token) []token {
|
|||
return ts
|
||||
}
|
||||
|
||||
func temperature(ti []token, t float32) []token {
|
||||
if t == 1 {
|
||||
return ti
|
||||
}
|
||||
|
||||
temp := max(t, 1e-7)
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
for _, token := range ti {
|
||||
if token.value > maxLogit {
|
||||
maxLogit = token.value
|
||||
}
|
||||
}
|
||||
|
||||
// subtracting max logit to avoid under/overflow
|
||||
for i := range ti {
|
||||
ti[i].value = (ti[i].value - maxLogit) / temp
|
||||
}
|
||||
|
||||
return ti
|
||||
}
|
||||
|
||||
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
|
||||
//
|
||||
// The heap is represented as an array where for any node at index i:
|
||||
|
@ -145,7 +136,8 @@ func minP(ts []token, p float32) []token {
|
|||
}
|
||||
|
||||
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
|
||||
// Conting sort implementation to sort tokens by logits
|
||||
// sortLogits sorts implementation to sort tokens by logits using counting sort
|
||||
// counting sort is faster than built-in sort for this use case
|
||||
func sortLogits(tokens []token) {
|
||||
if len(tokens) <= 1 {
|
||||
return
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue