sample: add numerical stability to temperature/softmax transform (#9631)

This commit is contained in:
Parth Sareen 2025-03-10 14:43:53 -07:00 committed by GitHub
parent fe776293f7
commit 7e34f4fbfa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 42 deletions

View file

@ -5,13 +5,25 @@ import (
"slices"
)
func softmax(ts []token) []token {
// temperature applies scaling and softmax to the logits
func temperature(ts []token, temp float32) []token {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Apply temperature and compute exp(x - max)
temp = max(temp, 1e-7)
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value)))
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
sum += ts[i].value
}
// Normalize
for i := range ts {
ts[i].value /= sum
}
@ -19,27 +31,6 @@ func softmax(ts []token) []token {
return ts
}
func temperature(ti []token, t float32) []token {
if t == 1 {
return ti
}
temp := max(t, 1e-7)
maxLogit := float32(math.Inf(-1))
for _, token := range ti {
if token.value > maxLogit {
maxLogit = token.value
}
}
// subtracting max logit to avoid under/overflow
for i := range ti {
ti[i].value = (ti[i].value - maxLogit) / temp
}
return ti
}
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
//
// The heap is represented as an array where for any node at index i:
@ -145,7 +136,8 @@ func minP(ts []token, p float32) []token {
}
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// Conting sort implementation to sort tokens by logits
// sortLogits sorts implementation to sort tokens by logits using counting sort
// counting sort is faster than built-in sort for this use case
func sortLogits(tokens []token) {
if len(tokens) <= 1 {
return