mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
llama: remove model loading for grammar (#10096)
This commit is contained in:
parent
40b10eee6d
commit
a53d744b01
13 changed files with 521 additions and 107 deletions
49
llama/llama.cpp/src/llama-grammar.cpp
vendored
49
llama/llama.cpp/src/llama-grammar.cpp
vendored
|
@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|||
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const struct ollama_vocab * ollama_vocab,
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
size_t start_rule_index) {
|
||||
|
@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||
return new llama_grammar {
|
||||
vocab,
|
||||
ollama_vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
|
@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const struct ollama_vocab * ollama_vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
|
@ -1067,6 +1070,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||
return new llama_grammar {
|
||||
vocab,
|
||||
ollama_vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
|
@ -1089,6 +1093,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
|||
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
||||
auto * result = new llama_grammar {
|
||||
grammar.vocab,
|
||||
grammar.o_vocab,
|
||||
grammar.rules,
|
||||
grammar.stacks,
|
||||
grammar.partial_utf8,
|
||||
|
@ -1116,7 +1121,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
|||
}
|
||||
|
||||
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(grammar.vocab != nullptr);
|
||||
|
||||
if (grammar.awaiting_trigger) {
|
||||
return;
|
||||
|
@ -1138,9 +1142,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
|||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
const llama_token id = cur_p->data[i].id;
|
||||
const std::string & piece = grammar.vocab->token_to_piece(id);
|
||||
const std::string piece = grammar.o_vocab ?
|
||||
grammar.o_vocab->token_to_piece(id) :
|
||||
grammar.vocab->token_to_piece(id);
|
||||
|
||||
if (grammar.vocab->is_eog(id)) {
|
||||
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id);
|
||||
|
||||
if (is_eog) {
|
||||
if (!allow_eog) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
|
@ -1159,9 +1167,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
|||
}
|
||||
|
||||
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
||||
GGML_ASSERT(grammar.vocab != nullptr);
|
||||
|
||||
const auto & piece = grammar.vocab->token_to_piece(token);
|
||||
const std::string piece = grammar.o_vocab ?
|
||||
grammar.o_vocab->token_to_piece(token) :
|
||||
grammar.vocab->token_to_piece(token);
|
||||
|
||||
if (grammar.awaiting_trigger) {
|
||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||
|
@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||
}
|
||||
}
|
||||
|
||||
if (grammar.vocab->is_eog(token)) {
|
||||
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token);
|
||||
if (is_eog) {
|
||||
for (const auto & stack : grammar.stacks) {
|
||||
if (stack.empty()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
|
||||
}
|
||||
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
|
@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
|||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
|
||||
try {
|
||||
return token_to_piece_map.at(token);
|
||||
} catch (const std::out_of_range&) {
|
||||
throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token));
|
||||
}
|
||||
}
|
||||
|
||||
void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) {
|
||||
for (size_t i = 0; i < n_tokens; i++) {
|
||||
token_to_piece_map[tokens[i]] = pieces[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool ollama_vocab::is_eog(const uint32_t token) const {
|
||||
return special_eog_ids.count(token) > 0;
|
||||
}
|
||||
|
||||
void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) {
|
||||
for (size_t i = 0; i < n_tokens; i++) {
|
||||
special_eog_ids.insert(tokens[i]);
|
||||
}
|
||||
}
|
||||
|
|
14
llama/llama.cpp/src/llama-grammar.h
vendored
14
llama/llama.cpp/src/llama-grammar.h
vendored
|
@ -6,8 +6,19 @@
|
|||
#include <regex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
struct llama_vocab;
|
||||
struct ollama_vocab {
|
||||
std::map<uint32_t, std::string> token_to_piece_map;
|
||||
std::set<uint32_t> special_eog_ids;
|
||||
|
||||
const std::string & token_to_piece(const uint32_t token) const;
|
||||
void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces);
|
||||
void set_eog_tokens(const uint32_t* tokens, size_t n_tokens);
|
||||
bool is_eog(const uint32_t token) const;
|
||||
|
||||
};
|
||||
|
||||
// grammar element type
|
||||
enum llama_gretype {
|
||||
|
@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern {
|
|||
struct llama_grammar {
|
||||
// note: allow null vocab for testing (not great)
|
||||
const llama_vocab * vocab;
|
||||
const ollama_vocab * o_vocab;
|
||||
|
||||
const llama_grammar_rules rules; // TODO: shared ptr
|
||||
llama_grammar_stacks stacks;
|
||||
|
@ -141,12 +153,14 @@ struct llama_grammar {
|
|||
// note: needed for tests (not great)
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const struct ollama_vocab * ollama_vocab,
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
size_t start_rule_index);
|
||||
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const struct ollama_vocab * ollama_vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
|
|
4
llama/llama.cpp/src/llama-sampling.cpp
vendored
4
llama/llama.cpp/src/llama-sampling.cpp
vendored
|
@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
|||
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
|
||||
}
|
||||
|
||||
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
||||
|
||||
|
@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|||
/* .vocab = */ vocab,
|
||||
/* .grammar_str = */ grammar_str,
|
||||
/* .grammar_root = */ grammar_root,
|
||||
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||
/* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||
};
|
||||
if (!ctx->grammar) {
|
||||
delete ctx;
|
||||
|
|
108
llama/llama.go
108
llama/llama.go
|
@ -35,6 +35,7 @@ import (
|
|||
"runtime/cgo"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
_ "github.com/ollama/ollama/llama/llama.cpp/common"
|
||||
|
@ -249,20 +250,6 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
|
|||
return &m, nil
|
||||
}
|
||||
|
||||
func LoadVocabFromFile(path string) (*Vocab, error) {
|
||||
mp := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(mp))
|
||||
v := Vocab{c: C.llama_load_vocab_from_file(mp)}
|
||||
if v.c == nil {
|
||||
return nil, fmt.Errorf("unable to load vocab: %s", path)
|
||||
}
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
func FreeVocab(vocab *Vocab) {
|
||||
C.llama_free_vocab(vocab.c)
|
||||
}
|
||||
|
||||
func FreeModel(model *Model) {
|
||||
C.llama_model_free(model.c)
|
||||
}
|
||||
|
@ -311,10 +298,6 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
|
|||
return nil
|
||||
}
|
||||
|
||||
type Vocab struct {
|
||||
c *C.struct_llama_vocab
|
||||
}
|
||||
|
||||
func (m *Model) Vocab() *C.struct_llama_vocab {
|
||||
return C.llama_model_get_vocab(m.c)
|
||||
}
|
||||
|
@ -692,35 +675,65 @@ func SchemaToGrammar(schema []byte) []byte {
|
|||
return buf[:n]
|
||||
}
|
||||
|
||||
type Sampler struct {
|
||||
c *C.struct_llama_sampler
|
||||
}
|
||||
|
||||
func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
|
||||
cGrammar := C.CString(grammar)
|
||||
cRoot := C.CString("root")
|
||||
defer C.free(unsafe.Pointer(cGrammar))
|
||||
defer C.free(unsafe.Pointer(cRoot))
|
||||
|
||||
sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
|
||||
|
||||
return sampler
|
||||
}
|
||||
|
||||
func (s *Sampler) Accept(token int32) {
|
||||
C.llama_sampler_accept(s.c, C.llama_token(token))
|
||||
}
|
||||
|
||||
type TokenData struct {
|
||||
Id int32
|
||||
ID int32
|
||||
Logit float32
|
||||
}
|
||||
|
||||
func (s *Sampler) Apply(tokens []TokenData) {
|
||||
type Grammar struct {
|
||||
c *C.struct_llama_grammar
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar {
|
||||
cGrammar := C.CString(grammar)
|
||||
defer C.free(unsafe.Pointer(cGrammar))
|
||||
|
||||
cTokens := make([]C.uint32_t, len(vocabIds))
|
||||
for i, token := range vocabIds {
|
||||
cTokens[i] = C.uint32_t(token)
|
||||
}
|
||||
|
||||
cPieces := make([]*C.char, len(vocabValues))
|
||||
for i, piece := range vocabValues {
|
||||
cPieces[i] = C.CString(piece)
|
||||
defer C.free(unsafe.Pointer(cPieces[i]))
|
||||
}
|
||||
|
||||
cEogTokens := make([]C.uint32_t, len(eogTokens))
|
||||
for i, token := range eogTokens {
|
||||
cEogTokens[i] = C.uint32_t(token)
|
||||
}
|
||||
|
||||
g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens)))
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Grammar{c: g}
|
||||
}
|
||||
|
||||
func (g *Grammar) Free() {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if g.c != nil {
|
||||
C.grammar_free(g.c)
|
||||
g.c = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Grammar) Apply(tokens []TokenData) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if g.c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tds := make([]C.struct_llama_token_data, len(tokens))
|
||||
for i, token := range tokens {
|
||||
tds[i] = C.struct_llama_token_data{
|
||||
id: C.int32_t(token.Id),
|
||||
id: C.int32_t(token.ID),
|
||||
logit: C.float(token.Logit),
|
||||
p: C.float(0.0),
|
||||
}
|
||||
|
@ -731,13 +744,24 @@ func (s *Sampler) Apply(tokens []TokenData) {
|
|||
selected: C.int64_t(-1),
|
||||
sorted: C.bool(false),
|
||||
}
|
||||
|
||||
var pinner runtime.Pinner
|
||||
pinner.Pin(&tds[0])
|
||||
defer pinner.Unpin()
|
||||
|
||||
C.llama_sampler_apply(s.c, tda)
|
||||
C.grammar_apply(g.c, tda)
|
||||
for i := range tokens {
|
||||
tokens[i].Logit = float32(tds[i].logit)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Grammar) Accept(token int32) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
// Check if grammar was freed
|
||||
if g.c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
C.grammar_accept(g.c, C.llama_token(token))
|
||||
}
|
||||
|
|
207
llama/patches/0021-add-ollama-vocab-for-grammar-support.patch
Normal file
207
llama/patches/0021-add-ollama-vocab-for-grammar-support.patch
Normal file
|
@ -0,0 +1,207 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: ParthSareen <parth.sareen@ollama.com>
|
||||
Date: Mon, 21 Apr 2025 13:30:31 -0700
|
||||
Subject: [PATCH] add ollama vocab for grammar support
|
||||
|
||||
---
|
||||
src/llama-grammar.cpp | 49 ++++++++++++++++++++++++++++++++++++------
|
||||
src/llama-grammar.h | 14 ++++++++++++
|
||||
src/llama-sampling.cpp | 4 ++--
|
||||
3 files changed, 58 insertions(+), 9 deletions(-)
|
||||
|
||||
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
|
||||
index 973b47ae..60d58236 100644
|
||||
--- a/src/llama-grammar.cpp
|
||||
+++ b/src/llama-grammar.cpp
|
||||
@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
+ const struct ollama_vocab * ollama_vocab,
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
size_t start_rule_index) {
|
||||
@@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||
return new llama_grammar {
|
||||
vocab,
|
||||
+ ollama_vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
@@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
+ const struct ollama_vocab * ollama_vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
@@ -1067,6 +1070,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||
return new llama_grammar {
|
||||
vocab,
|
||||
+ ollama_vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
@@ -1089,6 +1093,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
||||
auto * result = new llama_grammar {
|
||||
grammar.vocab,
|
||||
+ grammar.o_vocab,
|
||||
grammar.rules,
|
||||
grammar.stacks,
|
||||
grammar.partial_utf8,
|
||||
@@ -1116,7 +1121,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||
}
|
||||
|
||||
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
|
||||
- GGML_ASSERT(grammar.vocab != nullptr);
|
||||
|
||||
if (grammar.awaiting_trigger) {
|
||||
return;
|
||||
@@ -1138,9 +1142,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
const llama_token id = cur_p->data[i].id;
|
||||
- const std::string & piece = grammar.vocab->token_to_piece(id);
|
||||
+ const std::string piece = grammar.o_vocab ?
|
||||
+ grammar.o_vocab->token_to_piece(id) :
|
||||
+ grammar.vocab->token_to_piece(id);
|
||||
|
||||
- if (grammar.vocab->is_eog(id)) {
|
||||
+ const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id);
|
||||
+
|
||||
+ if (is_eog) {
|
||||
if (!allow_eog) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
@@ -1159,9 +1167,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||
}
|
||||
|
||||
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
||||
- GGML_ASSERT(grammar.vocab != nullptr);
|
||||
|
||||
- const auto & piece = grammar.vocab->token_to_piece(token);
|
||||
+ const std::string piece = grammar.o_vocab ?
|
||||
+ grammar.o_vocab->token_to_piece(token) :
|
||||
+ grammar.vocab->token_to_piece(token);
|
||||
|
||||
if (grammar.awaiting_trigger) {
|
||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||
@@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
}
|
||||
}
|
||||
|
||||
- if (grammar.vocab->is_eog(token)) {
|
||||
+ const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token);
|
||||
+ if (is_eog) {
|
||||
for (const auto & stack : grammar.stacks) {
|
||||
if (stack.empty()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
- GGML_ABORT("fatal error");
|
||||
+ GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
|
||||
}
|
||||
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
@@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
||||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
||||
}
|
||||
}
|
||||
+
|
||||
+
|
||||
+const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
|
||||
+ try {
|
||||
+ return token_to_piece_map.at(token);
|
||||
+ } catch (const std::out_of_range&) {
|
||||
+ throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token));
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) {
|
||||
+ for (size_t i = 0; i < n_tokens; i++) {
|
||||
+ token_to_piece_map[tokens[i]] = pieces[i];
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+bool ollama_vocab::is_eog(const uint32_t token) const {
|
||||
+ return special_eog_ids.count(token) > 0;
|
||||
+}
|
||||
+
|
||||
+void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) {
|
||||
+ for (size_t i = 0; i < n_tokens; i++) {
|
||||
+ special_eog_ids.insert(tokens[i]);
|
||||
+ }
|
||||
+}
|
||||
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
|
||||
index f8c291de..2a3a62db 100644
|
||||
--- a/src/llama-grammar.h
|
||||
+++ b/src/llama-grammar.h
|
||||
@@ -6,8 +6,19 @@
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
+#include <set>
|
||||
|
||||
struct llama_vocab;
|
||||
+struct ollama_vocab {
|
||||
+ std::map<uint32_t, std::string> token_to_piece_map;
|
||||
+ std::set<uint32_t> special_eog_ids;
|
||||
+
|
||||
+ const std::string & token_to_piece(const uint32_t token) const;
|
||||
+ void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces);
|
||||
+ void set_eog_tokens(const uint32_t* tokens, size_t n_tokens);
|
||||
+ bool is_eog(const uint32_t token) const;
|
||||
+
|
||||
+};
|
||||
|
||||
// grammar element type
|
||||
enum llama_gretype {
|
||||
@@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern {
|
||||
struct llama_grammar {
|
||||
// note: allow null vocab for testing (not great)
|
||||
const llama_vocab * vocab;
|
||||
+ const ollama_vocab * o_vocab;
|
||||
|
||||
const llama_grammar_rules rules; // TODO: shared ptr
|
||||
llama_grammar_stacks stacks;
|
||||
@@ -141,12 +153,14 @@ struct llama_grammar {
|
||||
// note: needed for tests (not great)
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
+ const struct ollama_vocab * ollama_vocab,
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
size_t start_rule_index);
|
||||
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
+ const struct ollama_vocab * ollama_vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
|
||||
index d1497985..b1a9dca3 100644
|
||||
--- a/src/llama-sampling.cpp
|
||||
+++ b/src/llama-sampling.cpp
|
||||
@@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
|
||||
}
|
||||
|
||||
- auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
||||
|
||||
@@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_str = */ grammar_str,
|
||||
/* .grammar_root = */ grammar_root,
|
||||
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||
+ /* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||
};
|
||||
if (!ctx->grammar) {
|
||||
delete ctx;
|
47
llama/sampling_ext.cpp
vendored
47
llama/sampling_ext.cpp
vendored
|
@ -5,6 +5,7 @@
|
|||
#include "llama.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-model-loader.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
|
||||
try {
|
||||
|
@ -86,3 +87,49 @@ struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
|
|||
void llama_free_vocab(struct llama_vocab * vocab) {
|
||||
delete vocab;
|
||||
}
|
||||
struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens) {
|
||||
try {
|
||||
if (grammar == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: null grammar input\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ollama_vocab *vocab = new ollama_vocab();
|
||||
vocab->set_eog_tokens(eog_tokens, n_eog_tokens);
|
||||
vocab->add_token_pieces(tokens, n_tokens, pieces);
|
||||
|
||||
struct llama_grammar *g = llama_grammar_init_impl(nullptr, vocab, grammar, "root", false, nullptr, 0, nullptr, 0);
|
||||
if (g == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize grammar\n", __func__);
|
||||
delete vocab;
|
||||
return nullptr;
|
||||
}
|
||||
return g;
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
LLAMA_LOG_ERROR("%s: exception during initialization: %s\n", __func__, e.what());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void grammar_free(struct llama_grammar *g) {
|
||||
if (g != nullptr) {
|
||||
if (g->vocab != nullptr) {
|
||||
delete g->vocab;
|
||||
}
|
||||
llama_grammar_free_impl(g);
|
||||
}
|
||||
}
|
||||
|
||||
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens) {
|
||||
if (g == nullptr || tokens == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: null grammar or tokens input\n", __func__);
|
||||
return;
|
||||
}
|
||||
llama_grammar_apply_impl(*g, tokens);
|
||||
}
|
||||
|
||||
|
||||
void grammar_accept(struct llama_grammar *g, llama_token id) {
|
||||
llama_grammar_accept_impl(*g, id);
|
||||
}
|
||||
|
|
8
llama/sampling_ext.h
vendored
8
llama/sampling_ext.h
vendored
|
@ -35,8 +35,12 @@ extern "C"
|
|||
|
||||
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
|
||||
|
||||
struct llama_vocab * llama_load_vocab_from_file(const char * fname);
|
||||
void llama_free_vocab(struct llama_vocab * vocab);
|
||||
|
||||
struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens);
|
||||
void grammar_free(struct llama_grammar *g);
|
||||
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens);
|
||||
void grammar_accept(struct llama_grammar *g, llama_token id);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -26,6 +26,9 @@ type Model struct {
|
|||
// Implement MultimodalProcessor interface
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// Implement TextProcessor interface
|
||||
var _ model.TextProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
textModel, err := NewTextModel(c)
|
||||
if err != nil {
|
||||
|
|
|
@ -32,6 +32,7 @@ type TextProcessor interface {
|
|||
Encode(s string, addSpecial bool) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(int32, Special) bool
|
||||
Vocabulary() *Vocabulary
|
||||
}
|
||||
|
||||
type Vocabulary struct {
|
||||
|
@ -117,6 +118,8 @@ type BytePairEncoding struct {
|
|||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
||||
return BytePairEncoding{
|
||||
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
||||
|
@ -124,6 +127,10 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
|||
}
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
return bpe.vocab
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
return bpe.vocab.Is(id, special)
|
||||
}
|
||||
|
|
|
@ -17,6 +17,10 @@ type SentencePieceModel struct {
|
|||
|
||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||
|
||||
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
|
|
|
@ -298,12 +298,6 @@ type Server struct {
|
|||
// multimodalHash generates hashes for comparing equality
|
||||
// of non-text data
|
||||
multimodalHash maphash.Hash
|
||||
|
||||
// vocab is a llama.cpp vocab required for gammar-based
|
||||
// constrained generation (json mode, structured outputs)
|
||||
// TODO: this is temporary until Ollama sampling supports
|
||||
// constrained generation
|
||||
vocab *sample.Vocab
|
||||
}
|
||||
|
||||
func (s *Server) allNil() bool {
|
||||
|
@ -606,14 +600,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var grammar *sample.Grammar
|
||||
var grammar *sample.GrammarSampler
|
||||
var err error
|
||||
if req.Grammar != "" {
|
||||
grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
|
||||
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer grammar.Free()
|
||||
}
|
||||
|
||||
sampler := sample.NewSampler(
|
||||
|
@ -789,8 +784,6 @@ func (s *Server) loadModel(
|
|||
panic(err)
|
||||
}
|
||||
|
||||
s.vocab = sample.NewVocab(mpath)
|
||||
|
||||
// TODO(jessegross): LoRA loading
|
||||
if lpath.String() != "" {
|
||||
panic("loras are not yet implemented")
|
||||
|
|
|
@ -5,9 +5,9 @@ import (
|
|||
"math"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
// token represents information about a single token during sampling
|
||||
|
@ -22,7 +22,7 @@ type Sampler struct {
|
|||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
grammar *Grammar
|
||||
grammar *GrammarSampler
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
|
@ -127,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
|
@ -164,63 +164,43 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
|||
}
|
||||
}
|
||||
|
||||
type Grammar struct {
|
||||
vocab *Vocab
|
||||
grammar string
|
||||
sampler *llama.Sampler
|
||||
type GrammarSampler struct {
|
||||
grammar *llama.Grammar
|
||||
}
|
||||
|
||||
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
|
||||
v, err := vocab.Load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||
pieces := make([]string, len(model.Vocabulary().Values))
|
||||
for i := range model.Vocabulary().Values {
|
||||
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||
vocabIds[i] = uint32(i)
|
||||
}
|
||||
|
||||
return &Grammar{
|
||||
vocab: vocab,
|
||||
grammar: grammar,
|
||||
sampler: llama.NewGrammarSampler(v, grammar),
|
||||
}, nil
|
||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
|
||||
if grammar == nil {
|
||||
return nil, errors.New("sample: failed to initialize grammar")
|
||||
}
|
||||
|
||||
return &GrammarSampler{grammar: grammar}, nil
|
||||
}
|
||||
|
||||
func (g *Grammar) Apply(tokens []token) {
|
||||
func (g *GrammarSampler) Apply(tokens []token) {
|
||||
tds := make([]llama.TokenData, len(tokens))
|
||||
for i, token := range tokens {
|
||||
tds[i].Id = token.id
|
||||
tds[i].ID = token.id
|
||||
tds[i].Logit = token.value
|
||||
}
|
||||
|
||||
g.sampler.Apply(tds)
|
||||
g.grammar.Apply(tds)
|
||||
|
||||
for i := range tokens {
|
||||
tokens[i].value = tds[i].Logit
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Grammar) Accept(token int32) {
|
||||
g.sampler.Accept(token)
|
||||
func (g *GrammarSampler) Accept(token int32) {
|
||||
g.grammar.Accept(token)
|
||||
}
|
||||
|
||||
type Vocab struct {
|
||||
once sync.Once
|
||||
vocab *llama.Vocab
|
||||
err error
|
||||
path string
|
||||
}
|
||||
|
||||
func NewVocab(path string) *Vocab {
|
||||
return &Vocab{path: path}
|
||||
}
|
||||
|
||||
// Load returns the lazily-loaded vocabulary
|
||||
func (v *Vocab) Load() (*llama.Vocab, error) {
|
||||
v.once.Do(func() {
|
||||
vocab, err := llama.LoadVocabFromFile(v.path)
|
||||
if err != nil {
|
||||
v.err = err
|
||||
return
|
||||
}
|
||||
v.vocab = vocab
|
||||
})
|
||||
return v.vocab, v.err
|
||||
func (g *GrammarSampler) Free() {
|
||||
g.grammar.Free()
|
||||
}
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
package sample
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
|
@ -55,6 +60,97 @@ func TestWeighted(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
types := make([]uint32, len(vocab))
|
||||
tokens := make([]string, len(vocab))
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
}
|
||||
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return model.NewBytePairEncoding(
|
||||
``,
|
||||
&model.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Merges: merges,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestGrammar(t *testing.T) {
|
||||
tokenizer := modelHelper(t)
|
||||
|
||||
grammarJSON := `
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
grammar, err := NewGrammarSampler(tokenizer, grammarJSON)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer grammar.Free()
|
||||
|
||||
logits := make([]float32, len(tokenizer.Vocabulary().Values))
|
||||
for i := range logits {
|
||||
logits[i] = rand.Float32()
|
||||
}
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range tokens {
|
||||
tokens[i].id = int32(i)
|
||||
tokens[i].value = logits[i]
|
||||
}
|
||||
|
||||
grammar.Apply(tokens)
|
||||
nonInfCount := 0
|
||||
infCount := 0
|
||||
for _, tok := range tokens {
|
||||
if math.IsInf(float64(tok.value), -1) {
|
||||
infCount++
|
||||
} else {
|
||||
nonInfCount++
|
||||
}
|
||||
}
|
||||
if nonInfCount == 0 {
|
||||
t.Error("expected at least one non -inf token after grammar application, got none")
|
||||
}
|
||||
if infCount == 0 {
|
||||
t.Error("expected some -inf tokens after grammar application, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue