mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
sample: temporarily use grammars for constrained generation in new engine (#9586)
This commit is contained in:
parent
a1cda80bcb
commit
e093db92c4
10 changed files with 301 additions and 213 deletions
|
@ -5,7 +5,7 @@ import (
|
|||
"slices"
|
||||
)
|
||||
|
||||
func softmax(ts []logit) []logit {
|
||||
func softmax(ts []token) []token {
|
||||
var sum float32
|
||||
for i, v := range ts {
|
||||
ts[i].value = float32(math.Exp(float64(v.value)))
|
||||
|
@ -19,7 +19,7 @@ func softmax(ts []logit) []logit {
|
|||
return ts
|
||||
}
|
||||
|
||||
func temperature(ti []logit, t float32) []logit {
|
||||
func temperature(ti []token, t float32) []token {
|
||||
if t == 1 {
|
||||
return ti
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ func temperature(ti []logit, t float32) []logit {
|
|||
// 1. Finds the smallest value between the node and its children
|
||||
// 2. If the node is not the smallest, swaps it with its smallest child
|
||||
// 3. Continues this process down the affected path until the min-heap property is restored
|
||||
func siftDown(data []logit, start, end int) {
|
||||
func siftDown(data []token, start, end int) {
|
||||
root := start
|
||||
for {
|
||||
child := 2*root + 1
|
||||
|
@ -73,7 +73,7 @@ func siftDown(data []logit, start, end int) {
|
|||
}
|
||||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
func topK(ts []logit, k int) []logit {
|
||||
func topK(ts []token, k int) []token {
|
||||
if k >= len(ts) {
|
||||
return ts
|
||||
}
|
||||
|
@ -99,7 +99,7 @@ func topK(ts []logit, k int) []logit {
|
|||
}
|
||||
|
||||
// topP limits tokens to those with cumulative probability p
|
||||
func topP(ts []logit, p float32) []logit {
|
||||
func topP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ func topP(ts []logit, p float32) []logit {
|
|||
}
|
||||
|
||||
// minP limits tokens to those with cumulative probability p
|
||||
func minP(ts []logit, p float32) []logit {
|
||||
func minP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func minP(ts []logit, p float32) []logit {
|
|||
|
||||
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
|
||||
// Counting sort implementation to sort tokens by logits
|
||||
func sortLogits(tokens []logit) {
|
||||
func sortLogits(tokens []token) {
|
||||
if len(tokens) <= 1 {
|
||||
return
|
||||
}
|
||||
|
@ -187,7 +187,7 @@ func sortLogits(tokens []logit) {
|
|||
}
|
||||
|
||||
// Second pass: place elements in correct position
|
||||
output := make([]logit, len(tokens))
|
||||
output := make([]token, len(tokens))
|
||||
// Track current positions
|
||||
countsCopy := counts
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue