From a53d744b01c65de77afb77aed4a576b317a90912 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Thu, 24 Apr 2025 11:51:19 -0700 Subject: [PATCH] llama: remove model loading for grammar (#10096) --- llama/llama.cpp/src/llama-grammar.cpp | 49 ++++- llama/llama.cpp/src/llama-grammar.h | 14 ++ llama/llama.cpp/src/llama-sampling.cpp | 4 +- llama/llama.go | 108 +++++---- ...add-ollama-vocab-for-grammar-support.patch | 207 ++++++++++++++++++ llama/sampling_ext.cpp | 47 ++++ llama/sampling_ext.h | 8 +- model/models/mistral3/model.go | 3 + model/process_text.go | 7 + model/process_text_spm.go | 4 + runner/ollamarunner/runner.go | 13 +- sample/samplers.go | 68 ++---- sample/samplers_test.go | 96 ++++++++ 13 files changed, 521 insertions(+), 107 deletions(-) create mode 100644 llama/patches/0021-add-ollama-vocab-for-grammar-support.patch diff --git a/llama/llama.cpp/src/llama-grammar.cpp b/llama/llama.cpp/src/llama-grammar.cpp index 973b47ae0..60d582362 100644 --- a/llama/llama.cpp/src/llama-grammar.cpp +++ b/llama/llama.cpp/src/llama-grammar.cpp @@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, + const struct ollama_vocab * ollama_vocab, const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index) { @@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl( // then the pointers would be invalidated when the local vec_rules goes out of scope. return new llama_grammar { vocab, + ollama_vocab, std::move(vec_rules), std::move(stacks), /* .partial_utf8 = */ {}, @@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl( struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, + const struct ollama_vocab * ollama_vocab, const char * grammar_str, const char * grammar_root, bool lazy, @@ -1067,6 +1070,7 @@ struct llama_grammar * llama_grammar_init_impl( // then the pointers would be invalidated when the local vec_rules goes out of scope. return new llama_grammar { vocab, + ollama_vocab, std::move(vec_rules), std::move(stacks), /* .partial_utf8 = */ {}, @@ -1089,6 +1093,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { auto * result = new llama_grammar { grammar.vocab, + grammar.o_vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, @@ -1116,7 +1121,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra } void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { - GGML_ASSERT(grammar.vocab != nullptr); if (grammar.awaiting_trigger) { return; @@ -1138,9 +1142,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; - const std::string & piece = grammar.vocab->token_to_piece(id); + const std::string piece = grammar.o_vocab ? + grammar.o_vocab->token_to_piece(id) : + grammar.vocab->token_to_piece(id); - if (grammar.vocab->is_eog(id)) { + const bool is_eog = grammar.o_vocab ? 
grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id); + + if (is_eog) { if (!allow_eog) { cur_p->data[i].logit = -INFINITY; } @@ -1159,9 +1167,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ } void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { - GGML_ASSERT(grammar.vocab != nullptr); - const auto & piece = grammar.vocab->token_to_piece(token); + const std::string piece = grammar.o_vocab ? + grammar.o_vocab->token_to_piece(token) : + grammar.vocab->token_to_piece(token); if (grammar.awaiting_trigger) { if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { @@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token } } - if (grammar.vocab->is_eog(token)) { + const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token); + if (is_eog) { for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; } } - GGML_ABORT("fatal error"); + GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty"); } llama_grammar_accept_str(grammar, piece); @@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); } } + + +const std::string & ollama_vocab::token_to_piece(const uint32_t token) const { + try { + return token_to_piece_map.at(token); + } catch (const std::out_of_range&) { + throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token)); + } +} + +void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) { + for (size_t i = 0; i < n_tokens; i++) { + token_to_piece_map[tokens[i]] = pieces[i]; + } +} + +bool ollama_vocab::is_eog(const uint32_t token) const { + return special_eog_ids.count(token) > 0; +} + +void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) { + for (size_t i = 0; i < n_tokens; i++) { + special_eog_ids.insert(tokens[i]); + } +} diff --git a/llama/llama.cpp/src/llama-grammar.h b/llama/llama.cpp/src/llama-grammar.h index f8c291de9..2a3a62db3 100644 --- a/llama/llama.cpp/src/llama-grammar.h +++ b/llama/llama.cpp/src/llama-grammar.h @@ -6,8 +6,19 @@ #include #include #include +#include struct llama_vocab; +struct ollama_vocab { + std::map token_to_piece_map; + std::set special_eog_ids; + + const std::string & token_to_piece(const uint32_t token) const; + void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces); + void set_eog_tokens(const uint32_t* tokens, size_t n_tokens); + bool is_eog(const uint32_t token) const; + +}; // grammar element type enum llama_gretype { @@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern { struct llama_grammar { // note: allow null vocab for testing (not great) const llama_vocab * vocab; + const ollama_vocab * o_vocab; const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; @@ -141,12 +153,14 @@ struct llama_grammar { // note: needed for tests (not great) struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, + const struct ollama_vocab * ollama_vocab, const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index); struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, + const struct ollama_vocab * ollama_vocab, const char * grammar_str, const char * grammar_root, 
bool lazy, diff --git a/llama/llama.cpp/src/llama-sampling.cpp b/llama/llama.cpp/src/llama-sampling.cpp index d14979850..b1a9dca3c 100644 --- a/llama/llama.cpp/src/llama-sampling.cpp +++ b/llama/llama.cpp/src/llama-sampling.cpp @@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); } - auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); @@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens), + /* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens), }; if (!ctx->grammar) { delete ctx; diff --git a/llama/llama.go b/llama/llama.go index 1c2329ce4..5fce0a622 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -35,6 +35,7 @@ import ( "runtime/cgo" "slices" "strings" + "sync" "unsafe" _ "github.com/ollama/ollama/llama/llama.cpp/common" @@ -249,20 +250,6 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) { return &m, nil } -func LoadVocabFromFile(path string) (*Vocab, error) { - mp := C.CString(path) - defer C.free(unsafe.Pointer(mp)) - v := Vocab{c: C.llama_load_vocab_from_file(mp)} - if v.c == nil { - return nil, fmt.Errorf("unable to load vocab: %s", path) - } - return &v, nil -} - -func FreeVocab(vocab *Vocab) { - C.llama_free_vocab(vocab.c) -} - func FreeModel(model *Model) { C.llama_model_free(model.c) } @@ -311,10 +298,6 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float return nil } -type Vocab struct { - c *C.struct_llama_vocab -} - func (m *Model) Vocab() *C.struct_llama_vocab { return C.llama_model_get_vocab(m.c) } @@ -692,35 +675,65 @@ func SchemaToGrammar(schema []byte) []byte { return buf[:n] } -type Sampler struct { - c *C.struct_llama_sampler -} - -func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler { - cGrammar := C.CString(grammar) - cRoot := C.CString("root") - defer C.free(unsafe.Pointer(cGrammar)) - defer C.free(unsafe.Pointer(cRoot)) - - sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)} - - return sampler -} - -func (s *Sampler) Accept(token int32) { - C.llama_sampler_accept(s.c, C.llama_token(token)) -} - type TokenData struct { - Id int32 + ID int32 Logit float32 } -func (s *Sampler) Apply(tokens []TokenData) { +type Grammar struct { + c *C.struct_llama_grammar + mu sync.Mutex +} + +func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar { + cGrammar := C.CString(grammar) + defer C.free(unsafe.Pointer(cGrammar)) + + cTokens := make([]C.uint32_t, len(vocabIds)) + for i, token := range vocabIds { + cTokens[i] = C.uint32_t(token) + } + + cPieces := make([]*C.char, len(vocabValues)) + for i, piece := range vocabValues { + cPieces[i] = C.CString(piece) + defer C.free(unsafe.Pointer(cPieces[i])) + } + + 
cEogTokens := make([]C.uint32_t, len(eogTokens)) + for i, token := range eogTokens { + cEogTokens[i] = C.uint32_t(token) + } + + g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens))) + if g == nil { + return nil + } + + return &Grammar{c: g} +} + +func (g *Grammar) Free() { + g.mu.Lock() + defer g.mu.Unlock() + if g.c != nil { + C.grammar_free(g.c) + g.c = nil + } +} + +func (g *Grammar) Apply(tokens []TokenData) { + g.mu.Lock() + defer g.mu.Unlock() + + if g.c == nil { + return + } + tds := make([]C.struct_llama_token_data, len(tokens)) for i, token := range tokens { tds[i] = C.struct_llama_token_data{ - id: C.int32_t(token.Id), + id: C.int32_t(token.ID), logit: C.float(token.Logit), p: C.float(0.0), } @@ -731,13 +744,24 @@ func (s *Sampler) Apply(tokens []TokenData) { selected: C.int64_t(-1), sorted: C.bool(false), } - var pinner runtime.Pinner pinner.Pin(&tds[0]) defer pinner.Unpin() - C.llama_sampler_apply(s.c, tda) + C.grammar_apply(g.c, tda) for i := range tokens { tokens[i].Logit = float32(tds[i].logit) } } + +func (g *Grammar) Accept(token int32) { + g.mu.Lock() + defer g.mu.Unlock() + + // Check if grammar was freed + if g.c == nil { + return + } + + C.grammar_accept(g.c, C.llama_token(token)) +} diff --git a/llama/patches/0021-add-ollama-vocab-for-grammar-support.patch b/llama/patches/0021-add-ollama-vocab-for-grammar-support.patch new file mode 100644 index 000000000..6193b755f --- /dev/null +++ b/llama/patches/0021-add-ollama-vocab-for-grammar-support.patch @@ -0,0 +1,207 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: ParthSareen +Date: Mon, 21 Apr 2025 13:30:31 -0700 +Subject: [PATCH] add ollama vocab for grammar support + +--- + src/llama-grammar.cpp | 49 ++++++++++++++++++++++++++++++++++++------ + src/llama-grammar.h | 14 ++++++++++++ + src/llama-sampling.cpp | 4 ++-- + 3 files changed, 58 insertions(+), 9 deletions(-) + +diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp +index 973b47ae..60d58236 100644 +--- a/src/llama-grammar.cpp ++++ b/src/llama-grammar.cpp +@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( + + struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, ++ const struct ollama_vocab * ollama_vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { +@@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl( + // then the pointers would be invalidated when the local vec_rules goes out of scope. + return new llama_grammar { + vocab, ++ ollama_vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, +@@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl( + + struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, ++ const struct ollama_vocab * ollama_vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, +@@ -1067,6 +1070,7 @@ struct llama_grammar * llama_grammar_init_impl( + // then the pointers would be invalidated when the local vec_rules goes out of scope. 
+ return new llama_grammar { + vocab, ++ ollama_vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, +@@ -1089,6 +1093,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { + struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { + auto * result = new llama_grammar { + grammar.vocab, ++ grammar.o_vocab, + grammar.rules, + grammar.stacks, + grammar.partial_utf8, +@@ -1116,7 +1121,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra + } + + void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { +- GGML_ASSERT(grammar.vocab != nullptr); + + if (grammar.awaiting_trigger) { + return; +@@ -1138,9 +1142,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ + + for (size_t i = 0; i < cur_p->size; ++i) { + const llama_token id = cur_p->data[i].id; +- const std::string & piece = grammar.vocab->token_to_piece(id); ++ const std::string piece = grammar.o_vocab ? ++ grammar.o_vocab->token_to_piece(id) : ++ grammar.vocab->token_to_piece(id); + +- if (grammar.vocab->is_eog(id)) { ++ const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id); ++ ++ if (is_eog) { + if (!allow_eog) { + cur_p->data[i].logit = -INFINITY; + } +@@ -1159,9 +1167,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ + } + + void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { +- GGML_ASSERT(grammar.vocab != nullptr); + +- const auto & piece = grammar.vocab->token_to_piece(token); ++ const std::string piece = grammar.o_vocab ? ++ grammar.o_vocab->token_to_piece(token) : ++ grammar.vocab->token_to_piece(token); + + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { +@@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token + } + } + +- if (grammar.vocab->is_eog(token)) { ++ const bool is_eog = grammar.o_vocab ? 
grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token); ++ if (is_eog) { + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + return; + } + } +- GGML_ABORT("fatal error"); ++ GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty"); + } + + llama_grammar_accept_str(grammar, piece); +@@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); + } + } ++ ++ ++const std::string & ollama_vocab::token_to_piece(const uint32_t token) const { ++ try { ++ return token_to_piece_map.at(token); ++ } catch (const std::out_of_range&) { ++ throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token)); ++ } ++} ++ ++void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) { ++ for (size_t i = 0; i < n_tokens; i++) { ++ token_to_piece_map[tokens[i]] = pieces[i]; ++ } ++} ++ ++bool ollama_vocab::is_eog(const uint32_t token) const { ++ return special_eog_ids.count(token) > 0; ++} ++ ++void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) { ++ for (size_t i = 0; i < n_tokens; i++) { ++ special_eog_ids.insert(tokens[i]); ++ } ++} +diff --git a/src/llama-grammar.h b/src/llama-grammar.h +index f8c291de..2a3a62db 100644 +--- a/src/llama-grammar.h ++++ b/src/llama-grammar.h +@@ -6,8 +6,19 @@ + #include + #include + #include ++#include + + struct llama_vocab; ++struct ollama_vocab { ++ std::map token_to_piece_map; ++ std::set special_eog_ids; ++ ++ const std::string & token_to_piece(const uint32_t token) const; ++ void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces); ++ void set_eog_tokens(const uint32_t* tokens, size_t n_tokens); ++ bool is_eog(const uint32_t token) const; ++ ++}; + + // grammar element type + enum llama_gretype { +@@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern { + struct llama_grammar { + // note: allow null vocab for testing (not great) + const llama_vocab * vocab; ++ const ollama_vocab * o_vocab; + + const llama_grammar_rules rules; // TODO: shared ptr + llama_grammar_stacks stacks; +@@ -141,12 +153,14 @@ struct llama_grammar { + // note: needed for tests (not great) + struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, ++ const struct ollama_vocab * ollama_vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + + struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, ++ const struct ollama_vocab * ollama_vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, +diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp +index d1497985..b1a9dca3 100644 +--- a/src/llama-sampling.cpp ++++ b/src/llama-sampling.cpp +@@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { + trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); + } + +- auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), ++ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); + +@@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( + 
/* .vocab = */ vocab, + /* .grammar_str = */ grammar_str, + /* .grammar_root = */ grammar_root, +- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens), ++ /* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens), + }; + if (!ctx->grammar) { + delete ctx; diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp index a87e23e5c..6a025c906 100644 --- a/llama/sampling_ext.cpp +++ b/llama/sampling_ext.cpp @@ -5,6 +5,7 @@ #include "llama.h" #include "llama-model.h" #include "llama-model-loader.h" +#include "llama-grammar.h" struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) { try { @@ -86,3 +87,49 @@ struct llama_vocab * llama_load_vocab_from_file(const char * fname) { void llama_free_vocab(struct llama_vocab * vocab) { delete vocab; } +struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens) { + try { + if (grammar == nullptr) { + LLAMA_LOG_ERROR("%s: null grammar input\n", __func__); + return nullptr; + } + + ollama_vocab *vocab = new ollama_vocab(); + vocab->set_eog_tokens(eog_tokens, n_eog_tokens); + vocab->add_token_pieces(tokens, n_tokens, pieces); + + struct llama_grammar *g = llama_grammar_init_impl(nullptr, vocab, grammar, "root", false, nullptr, 0, nullptr, 0); + if (g == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize grammar\n", __func__); + delete vocab; + return nullptr; + } + return g; + + } catch (const std::exception& e) { + LLAMA_LOG_ERROR("%s: exception during initialization: %s\n", __func__, e.what()); + return nullptr; + } +} + +void grammar_free(struct llama_grammar *g) { + if (g != nullptr) { + if (g->vocab != nullptr) { + delete g->vocab; + } + llama_grammar_free_impl(g); + } +} + +void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens) { + if (g == nullptr || tokens == nullptr) { + LLAMA_LOG_ERROR("%s: null grammar or tokens input\n", __func__); + return; + } + llama_grammar_apply_impl(*g, tokens); +} + + +void grammar_accept(struct llama_grammar *g, llama_token id) { + llama_grammar_accept_impl(*g, id); +} diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h index 9be7c100e..a9e610ba2 100644 --- a/llama/sampling_ext.h +++ b/llama/sampling_ext.h @@ -35,8 +35,12 @@ extern "C" int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len); - struct llama_vocab * llama_load_vocab_from_file(const char * fname); - void llama_free_vocab(struct llama_vocab * vocab); + + struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens); + void grammar_free(struct llama_grammar *g); + void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens); + void grammar_accept(struct llama_grammar *g, llama_token id); + #ifdef __cplusplus } diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index fca3896c3..f749fdcd2 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -26,6 +26,9 @@ type Model struct { // Implement MultimodalProcessor interface var _ model.MultimodalProcessor = (*Model)(nil) +// Implement TextProcessor interface +var _ model.TextProcessor = (*Model)(nil) + func New(c fs.Config) (model.Model, 
error) { textModel, err := NewTextModel(c) if err != nil { diff --git a/model/process_text.go b/model/process_text.go index f0fb77872..ce0b2d98a 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -32,6 +32,7 @@ type TextProcessor interface { Encode(s string, addSpecial bool) ([]int32, error) Decode([]int32) (string, error) Is(int32, Special) bool + Vocabulary() *Vocabulary } type Vocabulary struct { @@ -117,6 +118,8 @@ type BytePairEncoding struct { vocab *Vocabulary } +var _ TextProcessor = (*BytePairEncoding)(nil) + func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { return BytePairEncoding{ pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2), @@ -124,6 +127,10 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { } } +func (bpe BytePairEncoding) Vocabulary() *Vocabulary { + return bpe.vocab +} + func (bpe BytePairEncoding) Is(id int32, special Special) bool { return bpe.vocab.Is(id, special) } diff --git a/model/process_text_spm.go b/model/process_text_spm.go index c6e08dbd4..446d5d604 100644 --- a/model/process_text_spm.go +++ b/model/process_text_spm.go @@ -17,6 +17,10 @@ type SentencePieceModel struct { var _ TextProcessor = (*SentencePieceModel)(nil) +func (spm SentencePieceModel) Vocabulary() *Vocabulary { + return spm.vocab +} + func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index fee052805..0ac543888 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -298,12 +298,6 @@ type Server struct { // multimodalHash generates hashes for comparing equality // of non-text data multimodalHash maphash.Hash - - // vocab is a llama.cpp vocab required for gammar-based - // constrained generation (json mode, structured outputs) - // TODO: this is temporary until Ollama sampling supports - // constrained generation - vocab *sample.Vocab } func (s *Server) allNil() bool { @@ -606,14 +600,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - var grammar *sample.Grammar + var grammar *sample.GrammarSampler var err error if req.Grammar != "" { - grammar, err = sample.NewGrammar(s.vocab, req.Grammar) + grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar) if err != nil { http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError) return } + defer grammar.Free() } sampler := sample.NewSampler( @@ -789,8 +784,6 @@ func (s *Server) loadModel( panic(err) } - s.vocab = sample.NewVocab(mpath) - // TODO(jessegross): LoRA loading if lpath.String() != "" { panic("loras are not yet implemented") diff --git a/sample/samplers.go b/sample/samplers.go index ef8033691..f0846c8dd 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -5,9 +5,9 @@ import ( "math" "math/rand/v2" "slices" - "sync" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/model" ) // token represents information about a single token during sampling @@ -22,7 +22,7 @@ type Sampler struct { topP float32 minP float32 temperature float32 - grammar *Grammar + grammar *GrammarSampler } func (s *Sampler) Sample(logits []float32) (int32, error) { @@ -127,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) { } // TODO(parthsareen): update sampler interface to use json unmarshal 
https://github.com/ollama/ollama/issues/9278 -func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler { +func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler { var rng *rand.Rand if seed != -1 { // PCG requires two parameters: sequence and stream @@ -164,63 +164,43 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed } } -type Grammar struct { - vocab *Vocab - grammar string - sampler *llama.Sampler +type GrammarSampler struct { + grammar *llama.Grammar } -func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) { - v, err := vocab.Load() - if err != nil { - return nil, err +func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) { + vocabIds := make([]uint32, len(model.Vocabulary().Values)) + pieces := make([]string, len(model.Vocabulary().Values)) + for i := range model.Vocabulary().Values { + pieces[i], _ = model.Decode([]int32{int32(i)}) + vocabIds[i] = uint32(i) } - return &Grammar{ - vocab: vocab, - grammar: grammar, - sampler: llama.NewGrammarSampler(v, grammar), - }, nil + grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)}) + if grammar == nil { + return nil, errors.New("sample: failed to initialize grammar") + } + + return &GrammarSampler{grammar: grammar}, nil } -func (g *Grammar) Apply(tokens []token) { +func (g *GrammarSampler) Apply(tokens []token) { tds := make([]llama.TokenData, len(tokens)) for i, token := range tokens { - tds[i].Id = token.id + tds[i].ID = token.id tds[i].Logit = token.value } - - g.sampler.Apply(tds) + g.grammar.Apply(tds) for i := range tokens { tokens[i].value = tds[i].Logit } } -func (g *Grammar) Accept(token int32) { - g.sampler.Accept(token) +func (g *GrammarSampler) Accept(token int32) { + g.grammar.Accept(token) } -type Vocab struct { - once sync.Once - vocab *llama.Vocab - err error - path string -} - -func NewVocab(path string) *Vocab { - return &Vocab{path: path} -} - -// Load returns the lazily-loaded vocabulary -func (v *Vocab) Load() (*llama.Vocab, error) { - v.once.Do(func() { - vocab, err := llama.LoadVocabFromFile(v.path) - if err != nil { - v.err = err - return - } - v.vocab = vocab - }) - return v.vocab, v.err +func (g *GrammarSampler) Free() { + g.grammar.Free() } diff --git a/sample/samplers_test.go b/sample/samplers_test.go index d79dce474..cdcd55d43 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -1,9 +1,14 @@ package sample import ( + "encoding/json" "math" "math/rand/v2" + "os" + "path/filepath" "testing" + + "github.com/ollama/ollama/model" ) func TestWeighted(t *testing.T) { @@ -55,6 +60,97 @@ func TestWeighted(t *testing.T) { } } +func modelHelper(t testing.TB) model.BytePairEncoding { + t.Helper() + + f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + vocab := make(map[string]int32) + if err := json.NewDecoder(f).Decode(&vocab); err != nil { + t.Fatal(err) + } + + types := make([]uint32, len(vocab)) + tokens := make([]string, len(vocab)) + for token, id := range vocab { + tokens[id] = token + } + + merges := make([]string, 0, 1) + // Only need vocab for Grammar Test + return model.NewBytePairEncoding( + ``, + &model.Vocabulary{ + Values: tokens, + Types: types, + Merges: merges, + }, + ) +} + +func TestGrammar(t *testing.T) { + tokenizer := 
modelHelper(t) + + grammarJSON := ` + root ::= object + value ::= object | array | string | number | ("true" | "false" | "null") ws + object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + string ::= + "\"" ( + [^"\\\x7F\x00-\x1F] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + # Optional space: by convention, applied in this grammar after literal chars when allowed + ws ::= ([ \t\n] ws)? + ` + grammar, err := NewGrammarSampler(tokenizer, grammarJSON) + if err != nil { + t.Fatal(err) + } + defer grammar.Free() + + logits := make([]float32, len(tokenizer.Vocabulary().Values)) + for i := range logits { + logits[i] = rand.Float32() + } + tokens := make([]token, len(logits)) + for i := range tokens { + tokens[i].id = int32(i) + tokens[i].value = logits[i] + } + + grammar.Apply(tokens) + nonInfCount := 0 + infCount := 0 + for _, tok := range tokens { + if math.IsInf(float64(tok.value), -1) { + infCount++ + } else { + nonInfCount++ + } + } + if nonInfCount == 0 { + t.Error("expected at least one non -inf token after grammar application, got none") + } + if infCount == 0 { + t.Error("expected some -inf tokens after grammar application, got none") + } +} + func BenchmarkSample(b *testing.B) { samplers := map[string]Sampler{ "Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
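
For orientation, the net effect of the Go-side changes above is that grammar-constrained generation (JSON mode, structured outputs) no longer reloads the model file to obtain a llama.cpp vocab; the runner builds a GrammarSampler from the TextProcessor it already holds and hands it to NewSampler. The sketch below illustrates that intended call sequence. It is not part of the diff: the helper name, the sampling parameters, and the logits slice are placeholders, and the authoritative wiring is the completion handler in runner/ollamarunner/runner.go shown in the patch.

    package sketch

    import (
        "fmt"

        "github.com/ollama/ollama/model"
        "github.com/ollama/ollama/sample"
    )

    // constrainedSample is a hypothetical helper, not part of this patch. It
    // mirrors the runner's flow: build the grammar from the in-memory vocabulary,
    // hand it to the sampler, and free the C-side state when done.
    func constrainedSample(tp model.TextProcessor, grammarStr string, logits []float32) (int32, error) {
        grammar, err := sample.NewGrammarSampler(tp, grammarStr)
        if err != nil {
            return -1, fmt.Errorf("failed to initialize grammar: %w", err)
        }
        // GrammarSampler wraps C-allocated state, so it must be freed explicitly.
        defer grammar.Free()

        // Placeholder parameters: temperature, topK, topP, minP, seed.
        sampler := sample.NewSampler(0.8, 40, 0.9, 0.0, -1, grammar)

        // Sample consults the grammar so that tokens it rejects are masked
        // before a token is picked from the remaining candidates.
        return sampler.Sample(logits)
    }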
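
At the C boundary, grammar_init wraps the caller-supplied token pieces and end-of-generation ids in the new ollama_vocab and passes it to llama_grammar_init_impl with a null llama_vocab, so the apply and accept paths consult o_vocab when it is set and fall back to the llama.cpp vocab otherwise. On the Go side, Grammar serializes Apply, Accept, and Free behind a mutex and checks for a nil handle, presumably to avoid touching the C-side grammar after the deferred Free in the completion handler has released it.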