sample: add numerical stability to temperature/softmax transform (#9631)

This commit is contained in:
Parth Sareen 2025-03-10 14:43:53 -07:00 committed by GitHub
parent fe776293f7
commit 7e34f4fbfa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 42 deletions

View file

@ -90,8 +90,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
sortLogits(tokens) sortLogits(tokens)
} }
// token logit values are updated to probabilities
tokens = temperature(tokens, s.temperature) tokens = temperature(tokens, s.temperature)
tokens = softmax(tokens)
tokens = topP(tokens, s.topP) tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP) tokens = minP(tokens, s.minP)

View file

@ -5,13 +5,25 @@ import (
"slices" "slices"
) )
func softmax(ts []token) []token { // temperature applies scaling and softmax to the logits
func temperature(ts []token, temp float32) []token {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Apply temperature and compute exp(x - max)
temp = max(temp, 1e-7)
var sum float32 var sum float32
for i, v := range ts { for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value))) ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
sum += ts[i].value sum += ts[i].value
} }
// Normalize
for i := range ts { for i := range ts {
ts[i].value /= sum ts[i].value /= sum
} }
@ -19,27 +31,6 @@ func softmax(ts []token) []token {
return ts return ts
} }
func temperature(ti []token, t float32) []token {
if t == 1 {
return ti
}
temp := max(t, 1e-7)
maxLogit := float32(math.Inf(-1))
for _, token := range ti {
if token.value > maxLogit {
maxLogit = token.value
}
}
// subtracting max logit to avoid under/overflow
for i := range ti {
ti[i].value = (ti[i].value - maxLogit) / temp
}
return ti
}
// siftDown maintains a min-heap property by recursively moving larger elements down the heap. // siftDown maintains a min-heap property by recursively moving larger elements down the heap.
// //
// The heap is represented as an array where for any node at index i: // The heap is represented as an array where for any node at index i:
@ -145,7 +136,8 @@ func minP(ts []token, p float32) []token {
} }
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584 // TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// Conting sort implementation to sort tokens by logits // sortLogits sorts implementation to sort tokens by logits using counting sort
// counting sort is faster than built-in sort for this use case
func sortLogits(tokens []token) { func sortLogits(tokens []token) {
if len(tokens) <= 1 { if len(tokens) <= 1 {
return return

View file

@ -32,17 +32,9 @@ func compareLogits(t *testing.T, name string, want []float64, got []token) {
} }
} }
func TestTemperature(t *testing.T) { func TestTemperatureAndSoftmax(t *testing.T) {
input := []float64{2, -1, 4, -3, 1, -2, 0} input := []float64{1, 4, -2, 0}
want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
got := temperature(toTokens(input), 0.5) got := temperature(toTokens(input), 0.5)
compareLogits(t, "Temperature", want, got)
}
func TestSoftmax(t *testing.T) {
input := []float64{-3, -2, -1, 0, 1, 2, 4}
got := softmax(toTokens(input))
// Check probabilities sum to 1 // Check probabilities sum to 1
var sum float32 var sum float32
@ -53,11 +45,14 @@ func TestSoftmax(t *testing.T) {
t.Errorf("probabilities don't sum to 1: got %f", sum) t.Errorf("probabilities don't sum to 1: got %f", sum)
} }
// Check relative ordering is preserved got = temperature(toTokens(input), 1)
for i := 1; i < len(got); i++ { // Check probabilities sum to 1
if got[i].value < got[i-1].value { sum = 0.0
t.Errorf("probability ordering not preserved at index %d", i) for _, token := range got {
} sum += token.value
}
if math.Abs(float64(sum)-1.0) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
} }
} }
@ -84,7 +79,6 @@ func TestTopP(t *testing.T) {
// First apply temperature and softmax to get probabilities // First apply temperature and softmax to get probabilities
tokens = temperature(tokens, 1) tokens = temperature(tokens, 1)
tokens = softmax(tokens)
sortLogits(tokens) sortLogits(tokens)
// Then apply topP // Then apply topP
@ -103,7 +97,6 @@ func TestMinP(t *testing.T) {
// First apply temperature and softmax // First apply temperature and softmax
tokens = temperature(tokens, 1) tokens = temperature(tokens, 1)
tokens = softmax(tokens)
// Then apply minP // Then apply minP
got := minP(tokens, 0.2) got := minP(tokens, 0.2)