diff --git a/go.mod b/go.mod index c45c9892c..cc5789005 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,6 @@ require ( github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c golang.org/x/image v0.22.0 golang.org/x/tools v0.30.0 - gonum.org/v1/gonum v0.15.0 ) require ( @@ -45,6 +44,7 @@ require ( github.com/xtgo/set v1.0.0 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gonum.org/v1/gonum v0.15.0 // indirect gorgonia.org/vecf32 v0.9.0 // indirect gorgonia.org/vecf64 v0.9.0 // indirect ) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index e5189fa56..81e065624 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -589,11 +589,19 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + sampler := sample.NewSampler( + req.Temperature, + req.TopK, + req.TopP, + req.MinP, + req.Seed, + ) + seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ numPredict: req.NumPredict, stop: req.Stop, numKeep: int32(req.NumKeep), - sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized + sampler: sampler, embedding: false, }) if err != nil { diff --git a/sample/samplers.go b/sample/samplers.go index 1b8a5edd9..a5a0507ca 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -2,76 +2,103 @@ package sample import ( "errors" - "math" - - "golang.org/x/exp/rand" - "gonum.org/v1/gonum/stat/sampleuv" + "math/rand/v2" + "slices" ) +// Sampler is not thread-safe. 
Each goroutine should have its own instance type Sampler interface { Sample([]float32) (int32, error) } +// logit represents information about a single token during sampling +type logit struct { + id int32 // The token's unique identifier + value float32 // The raw logit or probability from the model +} + type weighted struct { - src rand.Source - transforms []Transform + rng *rand.Rand + tokens []logit + topK int + topP float32 + minP float32 + temperature float32 } -// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279 -func Weighted(seed *uint64, transforms ...Transform) Sampler { - var src rand.Source - if seed != nil { - src = rand.NewSource(*seed) +func (s *weighted) Sample(logits []float32) (int32, error) { + if len(s.tokens) < len(logits) { + s.tokens = make([]logit, len(logits)) } - return weighted{src: src, transforms: transforms} -} -func (s weighted) Sample(logits []float32) (int32, error) { - logits64 := make([]float64, len(logits)) + tokens := s.tokens[:len(logits)] + for i, v := range logits { - logits64[i] = float64(v) + tokens[i].id = int32(i) + tokens[i].value = v } - for _, t := range s.transforms { - logits64 = t.Apply(logits64) + // Tokens are sorted by logits in TopK or SortTokens + if s.topK > 0 { + tokens = topK(tokens, s.topK) + } else { + sortLogits(tokens) } - logitsCopy := make([]float64, 0, len(logits)) - indices := make([]int, 0, len(logits)) - for i, logit := range logits64 { - if !math.IsInf(logit, -1) { - logitsCopy = append(logitsCopy, logit) - indices = append(indices, i) + tokens = temperature(tokens, s.temperature) + tokens = softmax(tokens) + + tokens = topP(tokens, s.topP) + tokens = minP(tokens, s.minP) + + if len(tokens) == 0 { + return -1, errors.New("no valid logits found for weighted sampling") + } + + var r float32 + if s.rng != nil { + r = s.rng.Float32() + } else { + r = rand.Float32() + } + + // Calculate cumulative sum of probabilities + var sum float32 + for i := range tokens 
{ + sum += tokens[i].value + tokens[i].value = sum + } + r *= tokens[len(tokens)-1].value + + idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int { + // Compare cumulative probabilities + if token.value < target { + return -1 } + // First token that exceeds target + return 1 + }) + + if idx >= len(tokens) { + idx = len(tokens) - 1 } - if len(logitsCopy) == 0 { - return -1, errors.New("no valid logits found for weighed sampling") - } - - probs := softmax(logitsCopy) - w := sampleuv.NewWeighted(probs, s.src) - if idx, ok := w.Take(); ok { - return int32(indices[idx]), nil - } - return -1, errors.New("weighted sampler failed, no valid token found") + return tokens[idx].id, nil } type greedy struct{} -func Greedy() Sampler { - return greedy{} -} - -// Sample returns the index of the maximum value in logits. +// Greedy sample returns the index of the maximum value in logits. func (s greedy) Sample(logits []float32) (int32, error) { if len(logits) == 0 { return -1, errors.New("no logits provided for greedy sampling") } maxIdx := 0 - for i := range logits { - if logits[i] > logits[maxIdx] { + maxVal := logits[0] + for i := 1; i < len(logits); i++ { + if logits[i] > maxVal { + maxVal = logits[i] maxIdx = i } } @@ -80,41 +107,40 @@ func (s greedy) Sample(logits []float32) (int32, error) { } // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 -func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) { +func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler { if temperature == 0 { - return Greedy(), nil + return &greedy{} } - if temperature < 0 || temperature > 2 { - return nil, errors.New("temperature must be between 0 and 2") + var rng *rand.Rand + if seed != -1 { + // PCG requires two parameters: sequence and stream + // Use original seed for sequence + sequence := uint64(seed) + // Use golden ratio 
hash to generate statistically independent seeds + rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9)) + } + temperature = max(temperature, 1) + + if topP < 0.0 { + topP = 0.0 + } + if topP >= 1.0 { + topP = 1.0 } - transforms := []Transform{Temperature(temperature)} - - if topK != 0 { - if topK <= 0 { - return nil, errors.New("topK must be greater than 0") - } - transforms = append(transforms, TopK(topK)) + if minP < 0.0 { + minP = 0.0 + } + if minP >= 1.0 { + minP = 1.0 } - if topP != 0 { - if topP < 0 || topP >= 1 { - return nil, errors.New("topP must be between 0 and 1") - } - transforms = append(transforms, TopP(topP)) + return &weighted{ + rng: rng, + topK: topK, + topP: topP, + minP: minP, + temperature: temperature, } - - if minP != 0 { - if minP < 0 || minP >= 1 { - return nil, errors.New("minP must be between 0 and 1") - } - transforms = append(transforms, MinP(minP)) - } - - if seed >= 0 { - seed64 := uint64(seed) - return Weighted(&seed64, transforms...), nil - } - return Weighted(nil, transforms...), nil } diff --git a/sample/samplers_benchmark_test.go b/sample/samplers_benchmark_test.go new file mode 100644 index 000000000..41c0b487f --- /dev/null +++ b/sample/samplers_benchmark_test.go @@ -0,0 +1,104 @@ +package sample + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkWeightedSampler(b *testing.B) { + sizes := []int{10, 100, 1000, 10000} + + for _, size := range sizes { + b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) { + logits := make([]float32, size) + for i := range logits { + logits[i] = float32(rand.Float64()*10 - 5) + } + + sampler := NewSampler(0.8, 0, 0, 0, 42) + b.ResetTimer() + for b.Loop() { + _, err := sampler.Sample(logits) + if err != nil { + b.Fatalf("Sampling failed: %v", err) + } + } + }) + } + + configs := []struct { + name string + temperature float32 + topK int + topP float32 + minP float32 + seed int + }{ + {"Greedy", 0, -1, 0, 0, -1}, + {"Temperature", 0.8, -1, 0, 0, -1}, + {"TopK", 0.8, 50, 0, 0, 
-1}, + {"TopP", 0.8, -1, 0.9, 0, -1}, + {"MinP", 0.8, -1, 0, 0.05, -1}, + {"WithSeed", 0.8, 50, 0, 0, 42}, + } + + // Fixed size for common vocab size + size := 128000 + logits := make([]float32, size) + for i := range logits { + logits[i] = float32(rand.Float64()*10 - 5) + } + + for _, tc := range configs { + b.Run("Config"+tc.name, func(b *testing.B) { + sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed) + sampler.Sample(logits) + + b.ResetTimer() + + for b.Loop() { + _, err := sampler.Sample(logits) + if err != nil { + b.Fatalf("Sampling failed: %v", err) + } + } + }) + } + + // Test with combined transforms separately - topK influences performance greatly + b.Run("TransformCombined", func(b *testing.B) { + sampler := NewSampler(0.8, 50, 0.9, 0.05, 42) + b.ResetTimer() + + for b.Loop() { + _, err := sampler.Sample(logits) + if err != nil { + b.Fatalf("Sampling failed: %v", err) + } + } + }) +} + +func BenchmarkGreedySampler(b *testing.B) { + sizes := []int{10, 100, 1000, 10000, 100000} + + for _, size := range sizes { + b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) { + logits := make([]float32, size) + for i := range logits { + logits[i] = float32(rand.Float64()*10 - 5) + } + + sampler := NewSampler(0, -1, 0, 0, -1) + b.ResetTimer() + + for b.Loop() { + _, err := sampler.Sample(logits) + if err != nil { + b.Fatalf("Sampling failed: %v", err) + } + } + }) + } +} diff --git a/sample/samplers_test.go b/sample/samplers_test.go index 32364a3b7..dbbee17bb 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -1,15 +1,14 @@ package sample import ( - "math" "math/rand/v2" "testing" - - "github.com/google/go-cmp/cmp" ) func TestWeighted(t *testing.T) { - got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))}) + logits := []float32{-10, 3, -10, -10} + sampler := NewSampler(0, 0, 0, 0, 0) + got, err := sampler.Sample(logits) if err != nil { t.Error(err) return @@ 
-19,64 +18,19 @@ func TestWeighted(t *testing.T) { t.Errorf("index mismatch: want %d, got %d", want, got) } - got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))}) - if err == nil { - t.Error("expected error for no valid tokens, got index", got) - } - - seed := uint64(42) - got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4}) + logits = []float32{-100, -10, 0, 10} + sampler = NewSampler(0, 0, 0, 0, 0) + got, err = sampler.Sample(logits) if err != nil { t.Error(err) return } - // With seed 42, we expect a consistent sample - want = int32(3) // This will be deterministic due to the seed + want = int32(3) // Should pick highest probability with this r value if want != got { t.Errorf("index mismatch: want %d, got %d", want, got) } } -type testTransform struct { - id int - callOrder *[]int -} - -func (ts *testTransform) Apply(logits []float64) []float64 { - if ts.callOrder != nil { - *ts.callOrder = append(*ts.callOrder, ts.id) - } - return logits -} - -func TestSample(t *testing.T) { - input := []float32{1, 2, 3, 4} - - var callOrder []int - mock1 := &testTransform{ - id: 1, - callOrder: &callOrder, - } - mock2 := &testTransform{ - id: 2, - callOrder: &callOrder, - } - mock3 := &testTransform{ - id: 3, - callOrder: &callOrder, - } - - _, err := Weighted(nil, mock1, mock2, mock3).Sample(input) - if err != nil { - t.Error(err) - return - } - wantOrder := []int{1, 2, 3} - if diff := cmp.Diff(wantOrder, callOrder); diff != "" { - t.Errorf("call order mismatch (-want +got):\n%s", diff) - } -} - func TestNewSampler(t *testing.T) { tests := []struct { name string @@ -85,75 +39,41 @@ func TestNewSampler(t *testing.T) { topP float32 minP float32 seed int - wantErr bool + wantGreedy bool // Instead of wantErr, check if we get greedy sampler }{ - { - name: "no transforms", - // temperature is 0, so greedy should be used - wantErr: false, - }, { name: "temperature", temperature: 0.5, - wantErr: false, + wantGreedy: 
false, }, { - name: "invalid temperature negative", - temperature: -1, - wantErr: true, - }, - { - name: "invalid temperature too high", - temperature: 2.1, - wantErr: true, + name: "zero temperature - greedy", + temperature: 0, + wantGreedy: true, }, { name: "top k", + temperature: 0.1, topK: 10, - temperature: 0.8, - wantErr: false, - }, - { - name: "invalid top k negative", - topK: -1, - temperature: 0.8, - wantErr: true, + wantGreedy: false, }, { name: "top p", + temperature: 0.1, topP: 0.9, - temperature: 0.8, - wantErr: false, - }, - { - name: "invalid top p negative", - topP: -0.1, - temperature: 0.8, - wantErr: true, - }, - { - name: "invalid top p one", - topP: 1.0, - temperature: 0.8, - wantErr: true, + wantGreedy: false, }, { name: "min p", + temperature: 0.1, minP: 0.2, - temperature: 0.8, - wantErr: false, + wantGreedy: false, }, { - name: "invalid min p negative", - minP: -0.1, - temperature: 0.8, - wantErr: true, - }, - { - name: "invalid min p one", - minP: 1.0, - temperature: 0.8, - wantErr: true, + name: "seed - weighted", + temperature: 0.1, + seed: 42, + wantGreedy: false, }, { name: "default values", @@ -162,16 +82,16 @@ func TestNewSampler(t *testing.T) { topP: 0.9, minP: 0.0, seed: 0, - wantErr: false, + wantGreedy: false, }, { - name: "all zeroes", + name: "all zeroes - greedy", temperature: 0.0, topK: 0, topP: 0.0, minP: 0.0, seed: 0, - wantErr: false, // all zeroes means no transforms + wantGreedy: true, }, { name: "all transforms", @@ -180,33 +100,28 @@ func TestNewSampler(t *testing.T) { topP: 0.95, minP: 0.1, seed: 42, - wantErr: false, + wantGreedy: false, }, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed) - if (err != nil) != tt.wantErr { - t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr) + sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed) + _, isGreedy := sampler.(*greedy) + if isGreedy != 
tt.wantGreedy { + t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy) } }) } } func BenchmarkSample(b *testing.B) { - transforms := []Transform{ - Temperature(0.5), - TopK(10), - TopP(0.9), - MinP(0.2), - } - + weighted := NewSampler(0.5, 10, 0.9, 0.2, -1) samplers := map[string]Sampler{ - "Greedy": Greedy(), - "Weighted": Weighted(nil, transforms...), + "Greedy": NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy + "Weighted": weighted, } + // Generate random logits for benchmarking logits := make([]float32, 1<<16) for i := range logits { logits[i] = rand.Float32() @@ -215,7 +130,7 @@ func BenchmarkSample(b *testing.B) { for name, s := range samplers { b.Run(name, func(b *testing.B) { b.ResetTimer() - for range b.N { + for b.Loop() { if _, err := s.Sample(logits); err != nil { b.Error(err) } diff --git a/sample/transforms.go b/sample/transforms.go index 2dc6ebae1..f1f4f3b19 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -1,120 +1,203 @@ package sample import ( - "cmp" "math" "slices" - - pq "github.com/emirpasic/gods/v2/queues/priorityqueue" ) -type Transform interface { - Apply([]float64) []float64 -} - -// TODO(parthsareen): potentially cache softmax values -func softmax(logits []float64) []float64 { - var sum float64 - probs := make([]float64, len(logits)) - for i, v := range logits { - probs[i] = math.Exp(v) - sum += probs[i] +func softmax(ts []logit) []logit { + var sum float32 + for i, v := range ts { + ts[i].value = float32(math.Exp(float64(v.value))) + sum += ts[i].value } - for i := range probs { - probs[i] /= sum + for i := range ts { + ts[i].value /= sum } - return probs + return ts } -type Temperature float64 +func temperature(ti []logit, t float32) []logit { + if t == 1 { + return ti + } -func (t Temperature) Apply(logits []float64) []float64 { - temp := math.Max(float64(t), 1e-7) + temp := max(t, 1e-7) + maxLogit := float32(math.Inf(-1)) + for _, token := range ti { + if token.value > maxLogit 
{ + maxLogit = token.value + } + } // subtracting max logit to avoid under/overflow - maxLogit := slices.Max(logits) - for i := range logits { - logits[i] = (logits[i] - maxLogit) / temp + for i := range ti { + ti[i].value = (ti[i].value - maxLogit) / temp } - return logits + return ti } -type logitMap struct { - index int - logit float64 -} - -type TopK int - -// TODO(parthsareen): avoid having to check all logits after this transform -func (k TopK) Apply(logits []float64) []float64 { - if int(k) >= len(logits) { - return logits - } - q := pq.NewWith(func(a, b logitMap) int { - return -cmp.Compare(a.logit, b.logit) - }) - - for i, logit := range logits { - q.Enqueue(logitMap{index: i, logit: logit}) - } - - validLogits := make(map[int]float64) - for range k { - logitMap, _ := q.Dequeue() - validLogits[logitMap.index] = logitMap.logit - } - - for i := range logits { - if _, ok := validLogits[i]; !ok { - logits[i] = math.Inf(-1) - } - } - - return logits -} - -type TopP float64 - -func (p TopP) Apply(logits []float64) []float64 { - probs := softmax(logits) - indices := make([]int, len(probs)) - for i := range indices { - indices[i] = i - } - - // sort in descending order - slices.SortFunc(indices, func(i, j int) int { - return cmp.Compare(probs[j], probs[i]) - }) - - var sum float64 - for i, idx := range indices { - sum += probs[idx] - if sum > float64(p) { - for _, idx := range indices[i+1:] { - logits[idx] = math.Inf(-1) - } +// siftDown maintains a min-heap property by recursively moving larger elements down the heap. +// +// The heap is represented as an array where for any node at index i: +// - Left child is at index 2i + 1 +// - Right child is at index 2i + 2 +// - Parent is at index (i-1)/2 +// +// The function compares a node with its children and: +// 1. Finds the smallest value between the node and its children +// 2. If the node is not the smallest, swaps it with its smallest child +// 3. 
Continues this process down the affected path until the min-heap property is restored +func siftDown(data []logit, start, end int) { + root := start + for { + child := 2*root + 1 + if child >= end { break } + // Find smaller child (we want min heap) + if child+1 < end && data[child+1].value < data[child].value { + child++ + } + // Exit if root is already smaller than children + if data[root].value <= data[child].value { + break + } + // Swap with smaller child and continue + data[root], data[child] = data[child], data[root] + root = child } - return logits } -type MinP float64 +// topK limits the number of tokens considered to the k highest logits +func topK(ts []logit, k int) []logit { + if k >= len(ts) { + return ts + } + // Heapify + siftDown - O(nlog(k)) + // Build min-heap of first k elements + heap := ts[:k] + for i := k/2 - 1; i >= 0; i-- { + siftDown(heap, i, k) + } -func (p MinP) Apply(logits []float64) []float64 { - probs := softmax(logits) - threshold := slices.Max(probs) * float64(p) - - for i, prob := range probs { - if prob < threshold { - logits[i] = math.Inf(-1) + // Process remaining elements - if larger than heap root, replace root + for i := k; i < len(ts); i++ { + if ts[i].value > heap[0].value { + heap[0] = ts[i] + siftDown(heap, 0, k) } } - return logits + slices.Reverse(heap) + + ts = heap + return ts +} + +// topP limits tokens to those with cumulative probability p +func topP(ts []logit, p float32) []logit { + if p == 1.0 { + return ts + } + + // Find cutoff index where cumulative sum exceeds p + var sum float32 + for i, t := range ts { + sum += t.value + if sum > float32(p) { + ts = ts[:i+1] + return ts + } + } + + return ts +} + +// minP keeps only the tokens whose probability is at least p times the highest token probability (p == 1.0 returns the input unfiltered) +func minP(ts []logit, p float32) []logit { + if p == 1.0 { + return ts + } + + maxProb := float32(math.Inf(-1)) + for _, token := range ts { + if token.value > maxProb { + maxProb = token.value + } + } + + threshold := maxProb * float32(p) + + // 
Filter tokens in-place + validTokens := ts[:0] + for i, token := range ts { + if token.value >= threshold { + validTokens = append(validTokens, ts[i]) + } + } + + ts = validTokens + return ts +} + +// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584 +// Counting sort implementation to sort tokens by logits (descending). Only the top 8 bits of the 24-bit score select the bucket, so tokens that fall in the same bucket are not ordered by value +func sortLogits(tokens []logit) { + if len(tokens) <= 1 { + return + } + + // Find max/min in a single pass + minLogit, maxLogit := tokens[0].value, tokens[0].value + for _, t := range tokens[1:] { + if t.value < minLogit { + minLogit = t.value + } else if t.value > maxLogit { + maxLogit = t.value + } + } + + // Calculate scaling to map to uint32 range + logitRange := maxLogit - minLogit + if logitRange < 1e-6 { + return // All values effectively equal + } + + // Count frequencies directly from tokens + const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity + var counts [256]int // For first byte + + // First pass: count frequencies + for _, t := range tokens { + // Map to [0, maxInt] range + score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt) + counts[score>>16]++ + } + + // Calculate offsets + var offset int + for i := range counts { + count := counts[i] + counts[i] = offset + offset += count + } + + // Second pass: place elements in correct position + output := make([]logit, len(tokens)) + // Track current positions + countsCopy := counts + + for i, t := range tokens { + score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt) + + pos := countsCopy[score>>16] + countsCopy[score>>16]++ + output[len(tokens)-1-pos] = tokens[i] + } + + copy(tokens, output) } diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 05f76a274..950d79b35 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -4,77 +4,182 @@ import ( "math" "math/rand/v2" "testing" - - "github.com/google/go-cmp/cmp" ) -func TestTemperature(t *testing.T) { 
- got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0}) - want := []float64{-4, -10, 0, -14, -6, -12, -8} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) +// Helper to convert float64 slice to logit slice +func toLogits(values []float64) []logit { + tokens := make([]logit, len(values)) + for i, v := range values { + tokens[i] = logit{ + id: int32(i), + value: float32(v), + } + } + return tokens +} + +// Helper to compare logit slices +func compareLogits(t *testing.T, name string, want []float64, got []logit) { + t.Helper() + if len(want) != len(got) { + t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got)) + return + } + for i := range want { + if math.Abs(float64(got[i].value)-want[i]) > 1e-6 { + t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value) + } } } -func TestSoftmax(t *testing.T) { - got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4}) +func TestTemperature(t *testing.T) { + input := []float64{2, -1, 4, -3, 1, -2, 0} + want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp - want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("probs mismatch (-want +got):\n%s", diff) + got := temperature(toLogits(input), 0.5) + compareLogits(t, "Temperature", want, got) +} + +func TestSoftmax(t *testing.T) { + input := []float64{-3, -2, -1, 0, 1, 2, 4} + got := softmax(toLogits(input)) + + // Check probabilities sum to 1 + var sum float32 + for _, token := range got { + sum += token.value + } + if math.Abs(float64(sum)-1.0) > 1e-6 { + t.Errorf("probabilities don't sum to 1: got %f", sum) + } + + // Check relative ordering is preserved + for i := 1; i < len(got); i++ { + if got[i].value < got[i-1].value { + t.Errorf("probability ordering not preserved at index %d", i) + } } } 
func TestTopK(t *testing.T) { - got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) - want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) - } + input := []float64{-3, -2, -1, 0, 1, 2, 4} - got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) - - want = []float64{-3, -2, -1, 0, 1, 2, 4} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) + // Test k=3 + got := topK(toLogits(input), 3) + if len(got) != 3 { + t.Errorf("topK(3): wrong length: want 3, got %d", len(got)) } + // Should keep highest 3 values: 4, 2, 1 + want := []float64{4, 2, 1} + compareLogits(t, "topK(3)", want, got) + + // Test k > len + got = topK(toLogits(input), 10) + compareLogits(t, "topK(10)", input, got) } func TestTopP(t *testing.T) { - got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) - want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) + input := []float64{-3, -2, -1, 0, 1, 2, 4} + tokens := toLogits(input) + + // First apply temperature and softmax to get probabilities + tokens = temperature(tokens, 1) + tokens = softmax(tokens) + sortLogits(tokens) + + // Then apply topP + got := topP(tokens, 0.95) + + // Should keep tokens until cumsum > 0.95 + if len(got) > 3 { + t.Errorf("topP(0.95): kept too many tokens: got %d", len(got)) + t.Logf("got: %v", got) } } func TestMinP(t *testing.T) { - got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3}) - want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("logits mismatch (-want +got):\n%s", diff) + input := []float64{-3, -2, -1, 0, 1, 2, 4, 3} + tokens := toLogits(input) + + // First apply temperature 
and softmax + tokens = temperature(tokens, 1) + tokens = softmax(tokens) + + // Then apply minP + got := minP(tokens, 0.2) + + // Should keep tokens with prob >= 0.2 * max_prob + if len(got) > 3 { + t.Errorf("minP(0.2): kept too many tokens: got %d", len(got)) } } -func BenchmarkTransform(b *testing.B) { - transforms := map[string]Transform{ - "Temperature": Temperature(0.5), - "TopK": TopK(10), - "TopP": TopP(0.9), - "MinP": MinP(0.2), +func TestSortLogits(t *testing.T) { + input := []float64{3, 1, 4, 2, -1, 0, -2} + tokens := toLogits(input) + + sortLogits(tokens) + + for i := 1; i < len(tokens); i++ { + if tokens[i].value > tokens[i-1].value { + t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f", + i, tokens[i].value, tokens[i-1].value) + } } - logits := make([]float64, 1<<16) - for i := range logits { - logits[i] = rand.Float64() - } - - for name, transform := range transforms { - b.Run(name, func(b *testing.B) { - b.ResetTimer() - for range b.N { - transform.Apply(logits) - } - }) - } + want := []float64{4, 3, 2, 1, 0, -1, -2} + compareLogits(t, "sortLogits", want, tokens) +} + +func BenchmarkTransforms(b *testing.B) { + // Generate random logits + tokens := make([]logit, 1<<16) + for i := range tokens { + tokens[i] = logit{ + id: int32(i), + value: rand.Float32(), + } + } + + tokensCopy := make([]logit, len(tokens)) + + b.Run("Temperature", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + temperature(tokensCopy, 0.5) + } + }) + + b.Run("TopK", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + topK(tokensCopy, 10) + } + }) + + b.Run("TopP", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + topP(tokensCopy, 0.9) + } + }) + + b.Run("MinP", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + copy(tokensCopy, tokens) + minP(tokensCopy, 0.2) + } + }) + + b.Run("SortTokens", func(b *testing.B) { + b.ResetTimer() + for b.Loop() { 
+ copy(tokensCopy, tokens) + sortLogits(tokensCopy) + } + }) }