Mirror of https://github.com/ollama/ollama.git, synced 2025-05-11 10:26:53 +02:00
sample: add numerical stability to temperature/softmax transform (#9631)
parent fe776293f7
commit 7e34f4fbfa

3 changed files with 28 additions and 42 deletions
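The fix folds temperature scaling and softmax into a single pass that subtracts the maximum logit before exponentiating: softmax is shift-invariant, so the result is mathematically unchanged, but the exponent is bounded at 0 and math.Exp can no longer overflow. A minimal sketch of that technique (the token struct and field names are assumptions for illustration, not the repo's exact code):

```go
package main

import (
	"fmt"
	"math"
)

// token pairs a vocabulary index with a logit (a probability after the
// transform). Field names here are assumed for illustration.
type token struct {
	id    int32
	value float32
}

// temperature scales logits by 1/temp and applies a numerically stable
// softmax: subtracting the max logit bounds every exponent at 0, so
// math.Exp never overflows; the shift cancels out during normalization.
func temperature(ts []token, temp float32) []token {
	if len(ts) == 0 {
		return ts
	}
	if temp < 1e-7 {
		temp = 1e-7 // clamp to avoid division by zero at temp == 0
	}

	// Find the max logit for the stability shift.
	maxLogit := ts[0].value
	for _, t := range ts[1:] {
		if t.value > maxLogit {
			maxLogit = t.value
		}
	}

	// exp((logit - max) / temp), accumulating the normalizer.
	var sum float32
	for i := range ts {
		ts[i].value = float32(math.Exp(float64((ts[i].value - maxLogit) / temp)))
		sum += ts[i].value
	}

	// Normalize so the values form a probability distribution.
	for i := range ts {
		ts[i].value /= sum
	}
	return ts
}

func main() {
	toks := []token{{0, 1}, {1, 4}, {2, -2}, {3, 0}}
	fmt.Println(temperature(toks, 0.5)) // values now sum to 1
}
```

At very low temperatures the division blows up instead of the exponential, which is why the sketch clamps temp to a small positive floor.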
sample/transforms_test.go:

```diff
@@ -32,17 +32,9 @@ func compareLogits(t *testing.T, name string, want []float64, got []token) {
 	}
 }
 
-func TestTemperature(t *testing.T) {
-	input := []float64{2, -1, 4, -3, 1, -2, 0}
-	want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
-
+func TestTemperatureAndSoftmax(t *testing.T) {
+	input := []float64{1, 4, -2, 0}
 	got := temperature(toTokens(input), 0.5)
-	compareLogits(t, "Temperature", want, got)
-}
-
-func TestSoftmax(t *testing.T) {
-	input := []float64{-3, -2, -1, 0, 1, 2, 4}
-	got := softmax(toTokens(input))
 
 	// Check probabilities sum to 1
 	var sum float32
```
```diff
@@ -53,11 +45,14 @@ func TestSoftmax(t *testing.T) {
 		t.Errorf("probabilities don't sum to 1: got %f", sum)
 	}
 
-	// Check relative ordering is preserved
-	for i := 1; i < len(got); i++ {
-		if got[i].value < got[i-1].value {
-			t.Errorf("probability ordering not preserved at index %d", i)
-		}
+	got = temperature(toTokens(input), 1)
+	// Check probabilities sum to 1
+	sum = 0.0
+	for _, token := range got {
+		sum += token.value
+	}
+	if math.Abs(float64(sum)-1.0) > 1e-6 {
+		t.Errorf("probabilities don't sum to 1: got %f", sum)
 	}
 }
 
```
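The rewritten test asserts only that the outputs sum to 1 at two different temperatures, which is the observable contract of the fused transform. A standalone demonstration of the failure mode the max-subtraction guards against (naiveSoftmax and stableSoftmax are illustrative names, not the repo's API):

```go
package main

import (
	"fmt"
	"math"
)

// naiveSoftmax exponentiates raw logits. For large inputs math.Exp
// overflows to +Inf, and Inf/Inf normalization yields NaN.
func naiveSoftmax(xs []float64) []float64 {
	out := make([]float64, len(xs))
	var sum float64
	for i, x := range xs {
		out[i] = math.Exp(x)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// stableSoftmax subtracts the max first, bounding every exponent at 0.
// Because softmax is shift-invariant, the result is the same in exact
// arithmetic, but nothing overflows in floating point.
func stableSoftmax(xs []float64) []float64 {
	maxX := xs[0]
	for _, x := range xs[1:] {
		if x > maxX {
			maxX = x
		}
	}
	out := make([]float64, len(xs))
	var sum float64
	for i, x := range xs {
		out[i] = math.Exp(x - maxX)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	logits := []float64{1000, 999, 998}
	fmt.Println(naiveSoftmax(logits))  // [NaN NaN NaN]
	fmt.Println(stableSoftmax(logits)) // ≈ [0.665 0.245 0.090]
}
```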
```diff
@@ -84,7 +79,6 @@ func TestTopP(t *testing.T) {
 
 	// First apply temperature and softmax to get probabilities
 	tokens = temperature(tokens, 1)
-	tokens = softmax(tokens)
 	sortLogits(tokens)
 
 	// Then apply topP
```
```diff
@@ -103,7 +97,6 @@ func TestMinP(t *testing.T) {
 
	// First apply temperature and softmax
 	tokens = temperature(tokens, 1)
-	tokens = softmax(tokens)
 
 	// Then apply minP
 	got := minP(tokens, 0.2)
```
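The TestTopP and TestMinP hunks each drop a tokens = softmax(tokens) line because temperature now already returns a normalized distribution, making a second softmax pass redundant. A sketch of a min-p style filter over those ready-made probabilities (the concrete minP below mirrors the test's call shape but is an assumption, not the repo's implementation):

```go
package main

import "fmt"

// minP keeps entries whose probability is at least p times the largest
// probability; it assumes its input is already normalized, as the fused
// temperature/softmax transform guarantees.
func minP(probs []float64, p float64) []float64 {
	maxProb := 0.0
	for _, v := range probs {
		if v > maxProb {
			maxProb = v
		}
	}
	var out []float64
	for _, v := range probs {
		if v >= p*maxProb {
			out = append(out, v)
		}
	}
	return out
}

func main() {
	// Probabilities straight out of the combined temperature/softmax pass.
	probs := []float64{0.66, 0.24, 0.09, 0.01}
	fmt.Println(minP(probs, 0.2)) // [0.66 0.24] — 0.09 < 0.2*0.66
}
```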