sample: temporarily use grammars for constrained generation in new engine (#9586)

This commit is contained in:
Jeffrey Morgan 2025-03-10 16:17:39 +01:00 committed by GitHub
parent a1cda80bcb
commit e093db92c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 301 additions and 213 deletions

View file

@ -5,7 +5,7 @@ import (
"slices"
)
func softmax(ts []logit) []logit {
func softmax(ts []token) []token {
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value)))
@ -19,7 +19,7 @@ func softmax(ts []logit) []logit {
return ts
}
func temperature(ti []logit, t float32) []logit {
func temperature(ti []token, t float32) []token {
if t == 1 {
return ti
}
@ -51,7 +51,7 @@ func temperature(ti []logit, t float32) []logit {
// 1. Finds the smallest value between the node and its children
// 2. If the node is not the smallest, swaps it with its smallest child
// 3. Continues this process down the affected path until the min-heap property is restored
func siftDown(data []logit, start, end int) {
func siftDown(data []token, start, end int) {
root := start
for {
child := 2*root + 1
@ -73,7 +73,7 @@ func siftDown(data []logit, start, end int) {
}
// topK limits the number of tokens considered to the k highest logits
func topK(ts []logit, k int) []logit {
func topK(ts []token, k int) []token {
if k >= len(ts) {
return ts
}
@ -99,7 +99,7 @@ func topK(ts []logit, k int) []logit {
}
// topP limits tokens to those with cumulative probability p
func topP(ts []logit, p float32) []logit {
func topP(ts []token, p float32) []token {
if p == 1.0 {
return ts
}
@ -118,7 +118,7 @@ func topP(ts []logit, p float32) []logit {
}
// minP limits tokens to those with cumulative probability p
func minP(ts []logit, p float32) []logit {
func minP(ts []token, p float32) []token {
if p == 1.0 {
return ts
}
@ -146,7 +146,7 @@ func minP(ts []logit, p float32) []logit {
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// Counting sort implementation to sort tokens by logits
func sortLogits(tokens []logit) {
func sortLogits(tokens []token) {
if len(tokens) <= 1 {
return
}
@ -187,7 +187,7 @@ func sortLogits(tokens []logit) {
}
// Second pass: place elements in correct position
output := make([]logit, len(tokens))
output := make([]token, len(tokens))
// Track current positions
countsCopy := counts