llama: remove model loading for grammar (#10096)

This commit is contained in:
Parth Sareen 2025-04-24 11:51:19 -07:00 committed by GitHub
parent 40b10eee6d
commit a53d744b01
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 521 additions and 107 deletions

View file

@ -5,9 +5,9 @@ import (
"math"
"math/rand/v2"
"slices"
"sync"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)
// token represents information about a single token during sampling
@ -22,7 +22,7 @@ type Sampler struct {
topP float32
minP float32
temperature float32
grammar *Grammar
grammar *GrammarSampler
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
@ -127,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@ -164,63 +164,43 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
}
type Grammar struct {
vocab *Vocab
grammar string
sampler *llama.Sampler
type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
v, err := vocab.Load()
if err != nil {
return nil, err
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
return &Grammar{
vocab: vocab,
grammar: grammar,
sampler: llama.NewGrammarSampler(v, grammar),
}, nil
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}
return &GrammarSampler{grammar: grammar}, nil
}
func (g *Grammar) Apply(tokens []token) {
func (g *GrammarSampler) Apply(tokens []token) {
tds := make([]llama.TokenData, len(tokens))
for i, token := range tokens {
tds[i].Id = token.id
tds[i].ID = token.id
tds[i].Logit = token.value
}
g.sampler.Apply(tds)
g.grammar.Apply(tds)
for i := range tokens {
tokens[i].value = tds[i].Logit
}
}
func (g *Grammar) Accept(token int32) {
g.sampler.Accept(token)
func (g *GrammarSampler) Accept(token int32) {
g.grammar.Accept(token)
}
type Vocab struct {
once sync.Once
vocab *llama.Vocab
err error
path string
}
func NewVocab(path string) *Vocab {
return &Vocab{path: path}
}
// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
v.once.Do(func() {
vocab, err := llama.LoadVocabFromFile(v.path)
if err != nil {
v.err = err
return
}
v.vocab = vocab
})
return v.vocab, v.err
func (g *GrammarSampler) Free() {
g.grammar.Free()
}