llama: remove model loading for grammar (#10096)

This commit is contained in:
Parth Sareen 2025-04-24 11:51:19 -07:00 committed by GitHub
parent 40b10eee6d
commit a53d744b01
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 521 additions and 107 deletions

View file

@ -1,9 +1,14 @@
package sample
import (
"encoding/json"
"math"
"math/rand/v2"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
)
func TestWeighted(t *testing.T) {
@ -55,6 +60,97 @@ func TestWeighted(t *testing.T) {
}
}
// modelHelper builds a byte-pair-encoding tokenizer from the llama3.2
// encoder.json fixture under ../model/testdata.
//
// The fixture is decoded as a token -> id map and inverted into a slice
// indexed by id. Token types are left zero-valued and the merge list is
// empty: the grammar test only needs the vocabulary values. Any failure to
// open or decode the fixture aborts the calling test via t.Fatal.
func modelHelper(t testing.TB) model.BytePairEncoding {
	t.Helper()

	fixture := filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json")
	file, err := os.Open(fixture)
	if err != nil {
		t.Fatal(err)
	}
	defer file.Close()

	idByToken := make(map[string]int32)
	err = json.NewDecoder(file).Decode(&idByToken)
	if err != nil {
		t.Fatal(err)
	}

	// Invert the map: position i holds the token whose id is i.
	// NOTE(review): assumes ids are dense in [0, len) — an id outside that
	// range would panic here; confirm against the fixture.
	size := len(idByToken)
	values := make([]string, size)
	for piece, idx := range idByToken {
		values[idx] = piece
	}

	// Only the vocabulary matters for the grammar test, so the first
	// argument is left as an empty string and no merges are supplied.
	vocabulary := &model.Vocabulary{
		Values: values,
		Types:  make([]uint32, size),
		Merges: make([]string, 0, 1),
	}
	return model.NewBytePairEncoding(``, vocabulary)
}
// TestGrammar checks that a grammar sampler built from a JSON grammar
// actually constrains the candidate set: after Apply is run over random
// logits covering the whole vocabulary, some tokens must have been masked to
// -inf and at least one must remain usable.
func TestGrammar(t *testing.T) {
	tokenizer := modelHelper(t)

	// Grammar text describing JSON objects; passed verbatim to the sampler.
	const grammarJSON = `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`

	sampler, err := NewGrammarSampler(tokenizer, grammarJSON)
	if err != nil {
		t.Fatal(err)
	}
	defer sampler.Free()

	// One candidate per vocabulary entry, each carrying a random logit.
	vocabSize := len(tokenizer.Vocabulary().Values)
	candidates := make([]token, vocabSize)
	for idx := 0; idx < vocabSize; idx++ {
		candidates[idx].id = int32(idx)
		candidates[idx].value = rand.Float32()
	}

	sampler.Apply(candidates)

	// Count the tokens the grammar rejected (value forced to -inf); every
	// other token counts as still allowed.
	masked := 0
	for idx := range candidates {
		if math.IsInf(float64(candidates[idx].value), -1) {
			masked++
		}
	}
	allowed := len(candidates) - masked

	if allowed == 0 {
		t.Error("expected at least one non -inf token after grammar application, got none")
	}
	if masked == 0 {
		t.Error("expected some -inf tokens after grammar application, got none")
	}
}
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy