sample: add numerical stability to temperature/softmax transform (#9631)

This commit is contained in:
Parth Sareen 2025-03-10 14:43:53 -07:00 committed by GitHub
parent fe776293f7
commit 7e34f4fbfa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 42 deletions

View file

@ -32,17 +32,9 @@ func compareLogits(t *testing.T, name string, want []float64, got []token) {
}
}
func TestTemperature(t *testing.T) {
input := []float64{2, -1, 4, -3, 1, -2, 0}
want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
func TestTemperatureAndSoftmax(t *testing.T) {
input := []float64{1, 4, -2, 0}
got := temperature(toTokens(input), 0.5)
compareLogits(t, "Temperature", want, got)
}
func TestSoftmax(t *testing.T) {
input := []float64{-3, -2, -1, 0, 1, 2, 4}
got := softmax(toTokens(input))
// Check probabilities sum to 1
var sum float32
@ -53,11 +45,14 @@ func TestSoftmax(t *testing.T) {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
// Check relative ordering is preserved
for i := 1; i < len(got); i++ {
if got[i].value < got[i-1].value {
t.Errorf("probability ordering not preserved at index %d", i)
}
got = temperature(toTokens(input), 1)
// Check probabilities sum to 1
sum = 0.0
for _, token := range got {
sum += token.value
}
if math.Abs(float64(sum)-1.0) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
}
@ -84,7 +79,6 @@ func TestTopP(t *testing.T) {
// First apply temperature and softmax to get probabilities
tokens = temperature(tokens, 1)
tokens = softmax(tokens)
sortLogits(tokens)
// Then apply topP
@ -103,7 +97,6 @@ func TestMinP(t *testing.T) {
// First apply temperature and softmax
tokens = temperature(tokens, 1)
tokens = softmax(tokens)
// Then apply minP
got := minP(tokens, 0.2)