llama: sync

This commit is contained in:
jmorganca 2025-04-25 16:38:05 -07:00
parent 4892872c18
commit f4ab82f0b4
3 changed files with 9 additions and 58 deletions

View file

@ -907,7 +907,6 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
@ -963,7 +962,6 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@ -977,7 +975,6 @@ struct llama_grammar * llama_grammar_init_impl(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
@ -1070,7 +1067,6 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@ -1093,7 +1089,6 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
auto * result = new llama_grammar {
grammar.vocab,
grammar.o_vocab,
grammar.rules,
grammar.stacks,
grammar.partial_utf8,
@ -1121,6 +1116,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
}
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
GGML_ASSERT(grammar.vocab != nullptr);
if (grammar.awaiting_trigger) {
return;
@ -1142,13 +1138,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
for (size_t i = 0; i < cur_p->size; ++i) {
const llama_token id = cur_p->data[i].id;
const std::string piece = grammar.o_vocab ?
grammar.o_vocab->token_to_piece(id) :
grammar.vocab->token_to_piece(id);
const std::string & piece = grammar.vocab->token_to_piece(id);
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id);
if (is_eog) {
if (grammar.vocab->is_eog(id)) {
if (!allow_eog) {
cur_p->data[i].logit = -INFINITY;
}
@ -1167,10 +1159,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
}
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
GGML_ASSERT(grammar.vocab != nullptr);
const std::string piece = grammar.o_vocab ?
grammar.o_vocab->token_to_piece(token) :
grammar.vocab->token_to_piece(token);
const auto & piece = grammar.vocab->token_to_piece(token);
if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
@ -1200,14 +1191,13 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
}
}
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token);
if (is_eog) {
if (grammar.vocab->is_eog(token)) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
}
}
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
GGML_ABORT("fatal error");
}
llama_grammar_accept_str(grammar, piece);
@ -1227,28 +1217,3 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}
}
const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
try {
return token_to_piece_map.at(token);
} catch (const std::out_of_range&) {
throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token));
}
}
void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) {
for (size_t i = 0; i < n_tokens; i++) {
token_to_piece_map[tokens[i]] = pieces[i];
}
}
bool ollama_vocab::is_eog(const uint32_t token) const {
return special_eog_ids.count(token) > 0;
}
void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) {
for (size_t i = 0; i < n_tokens; i++) {
special_eog_ids.insert(tokens[i]);
}
}

View file

@ -6,19 +6,8 @@
#include <regex>
#include <string>
#include <vector>
#include <set>
struct llama_vocab;
struct ollama_vocab {
std::map<uint32_t, std::string> token_to_piece_map;
std::set<uint32_t> special_eog_ids;
const std::string & token_to_piece(const uint32_t token) const;
void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces);
void set_eog_tokens(const uint32_t* tokens, size_t n_tokens);
bool is_eog(const uint32_t token) const;
};
// grammar element type
enum llama_gretype {
@ -125,7 +114,6 @@ struct llama_grammar_trigger_pattern {
struct llama_grammar {
// note: allow null vocab for testing (not great)
const llama_vocab * vocab;
const ollama_vocab * o_vocab;
const llama_grammar_rules rules; // TODO: shared ptr
llama_grammar_stacks stacks;
@ -153,14 +141,12 @@ struct llama_grammar {
// note: needed for tests (not great)
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,

View file

@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
}
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
};
if (!ctx->grammar) {
delete ctx;