mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
sample: temporarily use grammars for constrained generation in new engine (#9586)
This commit is contained in:
parent
a1cda80bcb
commit
e093db92c4
10 changed files with 301 additions and 213 deletions
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
func TestWeighted(t *testing.T) {
|
||||
logits := []float32{-10, 3, -10, -10}
|
||||
sampler := NewSampler(0, 0, 0, 0, 0)
|
||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
||||
got, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -19,7 +19,7 @@ func TestWeighted(t *testing.T) {
|
|||
}
|
||||
|
||||
logits = []float32{-100, -10, 0, 10}
|
||||
sampler = NewSampler(0, 0, 0, 0, 0)
|
||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -31,94 +31,10 @@ func TestWeighted(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewSampler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
temperature float32
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
seed int
|
||||
wantGreedy bool // Instead of wantErr, check if we get greedy sampler
|
||||
}{
|
||||
{
|
||||
name: "temperature",
|
||||
temperature: 0.5,
|
||||
wantGreedy: false,
|
||||
},
|
||||
{
|
||||
name: "zero temperature - greedy",
|
||||
temperature: 0,
|
||||
wantGreedy: true,
|
||||
},
|
||||
{
|
||||
name: "top k",
|
||||
temperature: 0.1,
|
||||
topK: 10,
|
||||
wantGreedy: false,
|
||||
},
|
||||
{
|
||||
name: "top p",
|
||||
temperature: 0.1,
|
||||
topP: 0.9,
|
||||
wantGreedy: false,
|
||||
},
|
||||
{
|
||||
name: "min p",
|
||||
temperature: 0.1,
|
||||
minP: 0.2,
|
||||
wantGreedy: false,
|
||||
},
|
||||
{
|
||||
name: "seed - weighted",
|
||||
temperature: 0.1,
|
||||
seed: 42,
|
||||
wantGreedy: false,
|
||||
},
|
||||
{
|
||||
name: "default values",
|
||||
temperature: 0.8,
|
||||
topK: 40,
|
||||
topP: 0.9,
|
||||
minP: 0.0,
|
||||
seed: 0,
|
||||
wantGreedy: false,
|
||||
},
|
||||
{
|
||||
name: "all zeroes - greedy",
|
||||
temperature: 0.0,
|
||||
topK: 0,
|
||||
topP: 0.0,
|
||||
minP: 0.0,
|
||||
seed: 0,
|
||||
wantGreedy: true,
|
||||
},
|
||||
{
|
||||
name: "all transforms",
|
||||
temperature: 0.8,
|
||||
topK: 50,
|
||||
topP: 0.95,
|
||||
minP: 0.1,
|
||||
seed: 42,
|
||||
wantGreedy: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
||||
_, isGreedy := sampler.(*greedy)
|
||||
if isGreedy != tt.wantGreedy {
|
||||
t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": weighted,
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
||||
}
|
||||
|
||||
// Generate random logits for benchmarking
|
||||
|
@ -132,7 +48,7 @@ func BenchmarkSample(b *testing.B) {
|
|||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
if _, err := s.Sample(logits); err != nil {
|
||||
b.Error(err)
|
||||
b.Fatalf("error sampling: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue