mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
llama: remove model loading for grammar (#10096)
This commit is contained in:
parent
40b10eee6d
commit
a53d744b01
13 changed files with 521 additions and 107 deletions
49
llama/llama.cpp/src/llama-grammar.cpp
vendored
49
llama/llama.cpp/src/llama-grammar.cpp
vendored
|
@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_init_impl(
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
const struct llama_vocab * vocab,
|
const struct llama_vocab * vocab,
|
||||||
|
const struct ollama_vocab * ollama_vocab,
|
||||||
const llama_grammar_element ** rules,
|
const llama_grammar_element ** rules,
|
||||||
size_t n_rules,
|
size_t n_rules,
|
||||||
size_t start_rule_index) {
|
size_t start_rule_index) {
|
||||||
|
@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
return new llama_grammar {
|
return new llama_grammar {
|
||||||
vocab,
|
vocab,
|
||||||
|
ollama_vocab,
|
||||||
std::move(vec_rules),
|
std::move(vec_rules),
|
||||||
std::move(stacks),
|
std::move(stacks),
|
||||||
/* .partial_utf8 = */ {},
|
/* .partial_utf8 = */ {},
|
||||||
|
@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_init_impl(
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
const struct llama_vocab * vocab,
|
const struct llama_vocab * vocab,
|
||||||
|
const struct ollama_vocab * ollama_vocab,
|
||||||
const char * grammar_str,
|
const char * grammar_str,
|
||||||
const char * grammar_root,
|
const char * grammar_root,
|
||||||
bool lazy,
|
bool lazy,
|
||||||
|
@ -1067,6 +1070,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
return new llama_grammar {
|
return new llama_grammar {
|
||||||
vocab,
|
vocab,
|
||||||
|
ollama_vocab,
|
||||||
std::move(vec_rules),
|
std::move(vec_rules),
|
||||||
std::move(stacks),
|
std::move(stacks),
|
||||||
/* .partial_utf8 = */ {},
|
/* .partial_utf8 = */ {},
|
||||||
|
@ -1089,6 +1093,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||||
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
||||||
auto * result = new llama_grammar {
|
auto * result = new llama_grammar {
|
||||||
grammar.vocab,
|
grammar.vocab,
|
||||||
|
grammar.o_vocab,
|
||||||
grammar.rules,
|
grammar.rules,
|
||||||
grammar.stacks,
|
grammar.stacks,
|
||||||
grammar.partial_utf8,
|
grammar.partial_utf8,
|
||||||
|
@ -1116,7 +1121,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
|
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
|
||||||
GGML_ASSERT(grammar.vocab != nullptr);
|
|
||||||
|
|
||||||
if (grammar.awaiting_trigger) {
|
if (grammar.awaiting_trigger) {
|
||||||
return;
|
return;
|
||||||
|
@ -1138,9 +1142,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||||
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
const llama_token id = cur_p->data[i].id;
|
const llama_token id = cur_p->data[i].id;
|
||||||
const std::string & piece = grammar.vocab->token_to_piece(id);
|
const std::string piece = grammar.o_vocab ?
|
||||||
|
grammar.o_vocab->token_to_piece(id) :
|
||||||
|
grammar.vocab->token_to_piece(id);
|
||||||
|
|
||||||
if (grammar.vocab->is_eog(id)) {
|
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id);
|
||||||
|
|
||||||
|
if (is_eog) {
|
||||||
if (!allow_eog) {
|
if (!allow_eog) {
|
||||||
cur_p->data[i].logit = -INFINITY;
|
cur_p->data[i].logit = -INFINITY;
|
||||||
}
|
}
|
||||||
|
@ -1159,9 +1167,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
||||||
GGML_ASSERT(grammar.vocab != nullptr);
|
|
||||||
|
|
||||||
const auto & piece = grammar.vocab->token_to_piece(token);
|
const std::string piece = grammar.o_vocab ?
|
||||||
|
grammar.o_vocab->token_to_piece(token) :
|
||||||
|
grammar.vocab->token_to_piece(token);
|
||||||
|
|
||||||
if (grammar.awaiting_trigger) {
|
if (grammar.awaiting_trigger) {
|
||||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||||
|
@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (grammar.vocab->is_eog(token)) {
|
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token);
|
||||||
|
if (is_eog) {
|
||||||
for (const auto & stack : grammar.stacks) {
|
for (const auto & stack : grammar.stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_grammar_accept_str(grammar, piece);
|
llama_grammar_accept_str(grammar, piece);
|
||||||
|
@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
||||||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
|
||||||
|
try {
|
||||||
|
return token_to_piece_map.at(token);
|
||||||
|
} catch (const std::out_of_range&) {
|
||||||
|
throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) {
|
||||||
|
for (size_t i = 0; i < n_tokens; i++) {
|
||||||
|
token_to_piece_map[tokens[i]] = pieces[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ollama_vocab::is_eog(const uint32_t token) const {
|
||||||
|
return special_eog_ids.count(token) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) {
|
||||||
|
for (size_t i = 0; i < n_tokens; i++) {
|
||||||
|
special_eog_ids.insert(tokens[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
14
llama/llama.cpp/src/llama-grammar.h
vendored
14
llama/llama.cpp/src/llama-grammar.h
vendored
|
@ -6,8 +6,19 @@
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
struct llama_vocab;
|
struct llama_vocab;
|
||||||
|
struct ollama_vocab {
|
||||||
|
std::map<uint32_t, std::string> token_to_piece_map;
|
||||||
|
std::set<uint32_t> special_eog_ids;
|
||||||
|
|
||||||
|
const std::string & token_to_piece(const uint32_t token) const;
|
||||||
|
void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces);
|
||||||
|
void set_eog_tokens(const uint32_t* tokens, size_t n_tokens);
|
||||||
|
bool is_eog(const uint32_t token) const;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
// grammar element type
|
// grammar element type
|
||||||
enum llama_gretype {
|
enum llama_gretype {
|
||||||
|
@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern {
|
||||||
struct llama_grammar {
|
struct llama_grammar {
|
||||||
// note: allow null vocab for testing (not great)
|
// note: allow null vocab for testing (not great)
|
||||||
const llama_vocab * vocab;
|
const llama_vocab * vocab;
|
||||||
|
const ollama_vocab * o_vocab;
|
||||||
|
|
||||||
const llama_grammar_rules rules; // TODO: shared ptr
|
const llama_grammar_rules rules; // TODO: shared ptr
|
||||||
llama_grammar_stacks stacks;
|
llama_grammar_stacks stacks;
|
||||||
|
@ -141,12 +153,14 @@ struct llama_grammar {
|
||||||
// note: needed for tests (not great)
|
// note: needed for tests (not great)
|
||||||
struct llama_grammar * llama_grammar_init_impl(
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
const struct llama_vocab * vocab,
|
const struct llama_vocab * vocab,
|
||||||
|
const struct ollama_vocab * ollama_vocab,
|
||||||
const llama_grammar_element ** rules,
|
const llama_grammar_element ** rules,
|
||||||
size_t n_rules,
|
size_t n_rules,
|
||||||
size_t start_rule_index);
|
size_t start_rule_index);
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_init_impl(
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
const struct llama_vocab * vocab,
|
const struct llama_vocab * vocab,
|
||||||
|
const struct ollama_vocab * ollama_vocab,
|
||||||
const char * grammar_str,
|
const char * grammar_str,
|
||||||
const char * grammar_root,
|
const char * grammar_root,
|
||||||
bool lazy,
|
bool lazy,
|
||||||
|
|
4
llama/llama.cpp/src/llama-sampling.cpp
vendored
4
llama/llama.cpp/src/llama-sampling.cpp
vendored
|
@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||||
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
|
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||||
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
|
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||||
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
||||||
|
|
||||||
|
@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||||
/* .vocab = */ vocab,
|
/* .vocab = */ vocab,
|
||||||
/* .grammar_str = */ grammar_str,
|
/* .grammar_str = */ grammar_str,
|
||||||
/* .grammar_root = */ grammar_root,
|
/* .grammar_root = */ grammar_root,
|
||||||
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
/* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||||
};
|
};
|
||||||
if (!ctx->grammar) {
|
if (!ctx->grammar) {
|
||||||
delete ctx;
|
delete ctx;
|
||||||
|
|
108
llama/llama.go
108
llama/llama.go
|
@ -35,6 +35,7 @@ import (
|
||||||
"runtime/cgo"
|
"runtime/cgo"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
_ "github.com/ollama/ollama/llama/llama.cpp/common"
|
_ "github.com/ollama/ollama/llama/llama.cpp/common"
|
||||||
|
@ -249,20 +250,6 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadVocabFromFile(path string) (*Vocab, error) {
|
|
||||||
mp := C.CString(path)
|
|
||||||
defer C.free(unsafe.Pointer(mp))
|
|
||||||
v := Vocab{c: C.llama_load_vocab_from_file(mp)}
|
|
||||||
if v.c == nil {
|
|
||||||
return nil, fmt.Errorf("unable to load vocab: %s", path)
|
|
||||||
}
|
|
||||||
return &v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func FreeVocab(vocab *Vocab) {
|
|
||||||
C.llama_free_vocab(vocab.c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func FreeModel(model *Model) {
|
func FreeModel(model *Model) {
|
||||||
C.llama_model_free(model.c)
|
C.llama_model_free(model.c)
|
||||||
}
|
}
|
||||||
|
@ -311,10 +298,6 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vocab struct {
|
|
||||||
c *C.struct_llama_vocab
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Vocab() *C.struct_llama_vocab {
|
func (m *Model) Vocab() *C.struct_llama_vocab {
|
||||||
return C.llama_model_get_vocab(m.c)
|
return C.llama_model_get_vocab(m.c)
|
||||||
}
|
}
|
||||||
|
@ -692,35 +675,65 @@ func SchemaToGrammar(schema []byte) []byte {
|
||||||
return buf[:n]
|
return buf[:n]
|
||||||
}
|
}
|
||||||
|
|
||||||
type Sampler struct {
|
|
||||||
c *C.struct_llama_sampler
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
|
|
||||||
cGrammar := C.CString(grammar)
|
|
||||||
cRoot := C.CString("root")
|
|
||||||
defer C.free(unsafe.Pointer(cGrammar))
|
|
||||||
defer C.free(unsafe.Pointer(cRoot))
|
|
||||||
|
|
||||||
sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
|
|
||||||
|
|
||||||
return sampler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sampler) Accept(token int32) {
|
|
||||||
C.llama_sampler_accept(s.c, C.llama_token(token))
|
|
||||||
}
|
|
||||||
|
|
||||||
type TokenData struct {
|
type TokenData struct {
|
||||||
Id int32
|
ID int32
|
||||||
Logit float32
|
Logit float32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) Apply(tokens []TokenData) {
|
type Grammar struct {
|
||||||
|
c *C.struct_llama_grammar
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar {
|
||||||
|
cGrammar := C.CString(grammar)
|
||||||
|
defer C.free(unsafe.Pointer(cGrammar))
|
||||||
|
|
||||||
|
cTokens := make([]C.uint32_t, len(vocabIds))
|
||||||
|
for i, token := range vocabIds {
|
||||||
|
cTokens[i] = C.uint32_t(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
cPieces := make([]*C.char, len(vocabValues))
|
||||||
|
for i, piece := range vocabValues {
|
||||||
|
cPieces[i] = C.CString(piece)
|
||||||
|
defer C.free(unsafe.Pointer(cPieces[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
cEogTokens := make([]C.uint32_t, len(eogTokens))
|
||||||
|
for i, token := range eogTokens {
|
||||||
|
cEogTokens[i] = C.uint32_t(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens)))
|
||||||
|
if g == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Grammar{c: g}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Grammar) Free() {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
if g.c != nil {
|
||||||
|
C.grammar_free(g.c)
|
||||||
|
g.c = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Grammar) Apply(tokens []TokenData) {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
|
if g.c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
tds := make([]C.struct_llama_token_data, len(tokens))
|
tds := make([]C.struct_llama_token_data, len(tokens))
|
||||||
for i, token := range tokens {
|
for i, token := range tokens {
|
||||||
tds[i] = C.struct_llama_token_data{
|
tds[i] = C.struct_llama_token_data{
|
||||||
id: C.int32_t(token.Id),
|
id: C.int32_t(token.ID),
|
||||||
logit: C.float(token.Logit),
|
logit: C.float(token.Logit),
|
||||||
p: C.float(0.0),
|
p: C.float(0.0),
|
||||||
}
|
}
|
||||||
|
@ -731,13 +744,24 @@ func (s *Sampler) Apply(tokens []TokenData) {
|
||||||
selected: C.int64_t(-1),
|
selected: C.int64_t(-1),
|
||||||
sorted: C.bool(false),
|
sorted: C.bool(false),
|
||||||
}
|
}
|
||||||
|
|
||||||
var pinner runtime.Pinner
|
var pinner runtime.Pinner
|
||||||
pinner.Pin(&tds[0])
|
pinner.Pin(&tds[0])
|
||||||
defer pinner.Unpin()
|
defer pinner.Unpin()
|
||||||
|
|
||||||
C.llama_sampler_apply(s.c, tda)
|
C.grammar_apply(g.c, tda)
|
||||||
for i := range tokens {
|
for i := range tokens {
|
||||||
tokens[i].Logit = float32(tds[i].logit)
|
tokens[i].Logit = float32(tds[i].logit)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *Grammar) Accept(token int32) {
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if grammar was freed
|
||||||
|
if g.c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
C.grammar_accept(g.c, C.llama_token(token))
|
||||||
|
}
|
||||||
|
|
207
llama/patches/0021-add-ollama-vocab-for-grammar-support.patch
Normal file
207
llama/patches/0021-add-ollama-vocab-for-grammar-support.patch
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: ParthSareen <parth.sareen@ollama.com>
|
||||||
|
Date: Mon, 21 Apr 2025 13:30:31 -0700
|
||||||
|
Subject: [PATCH] add ollama vocab for grammar support
|
||||||
|
|
||||||
|
---
|
||||||
|
src/llama-grammar.cpp | 49 ++++++++++++++++++++++++++++++++++++------
|
||||||
|
src/llama-grammar.h | 14 ++++++++++++
|
||||||
|
src/llama-sampling.cpp | 4 ++--
|
||||||
|
3 files changed, 58 insertions(+), 9 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
|
||||||
|
index 973b47ae..60d58236 100644
|
||||||
|
--- a/src/llama-grammar.cpp
|
||||||
|
+++ b/src/llama-grammar.cpp
|
||||||
|
@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
+ const struct ollama_vocab * ollama_vocab,
|
||||||
|
const llama_grammar_element ** rules,
|
||||||
|
size_t n_rules,
|
||||||
|
size_t start_rule_index) {
|
||||||
|
@@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
|
return new llama_grammar {
|
||||||
|
vocab,
|
||||||
|
+ ollama_vocab,
|
||||||
|
std::move(vec_rules),
|
||||||
|
std::move(stacks),
|
||||||
|
/* .partial_utf8 = */ {},
|
||||||
|
@@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
+ const struct ollama_vocab * ollama_vocab,
|
||||||
|
const char * grammar_str,
|
||||||
|
const char * grammar_root,
|
||||||
|
bool lazy,
|
||||||
|
@@ -1067,6 +1070,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
|
return new llama_grammar {
|
||||||
|
vocab,
|
||||||
|
+ ollama_vocab,
|
||||||
|
std::move(vec_rules),
|
||||||
|
std::move(stacks),
|
||||||
|
/* .partial_utf8 = */ {},
|
||||||
|
@@ -1089,6 +1093,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||||
|
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
||||||
|
auto * result = new llama_grammar {
|
||||||
|
grammar.vocab,
|
||||||
|
+ grammar.o_vocab,
|
||||||
|
grammar.rules,
|
||||||
|
grammar.stacks,
|
||||||
|
grammar.partial_utf8,
|
||||||
|
@@ -1116,7 +1121,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
|
||||||
|
- GGML_ASSERT(grammar.vocab != nullptr);
|
||||||
|
|
||||||
|
if (grammar.awaiting_trigger) {
|
||||||
|
return;
|
||||||
|
@@ -1138,9 +1142,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||||
|
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
const llama_token id = cur_p->data[i].id;
|
||||||
|
- const std::string & piece = grammar.vocab->token_to_piece(id);
|
||||||
|
+ const std::string piece = grammar.o_vocab ?
|
||||||
|
+ grammar.o_vocab->token_to_piece(id) :
|
||||||
|
+ grammar.vocab->token_to_piece(id);
|
||||||
|
|
||||||
|
- if (grammar.vocab->is_eog(id)) {
|
||||||
|
+ const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id);
|
||||||
|
+
|
||||||
|
+ if (is_eog) {
|
||||||
|
if (!allow_eog) {
|
||||||
|
cur_p->data[i].logit = -INFINITY;
|
||||||
|
}
|
||||||
|
@@ -1159,9 +1167,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
||||||
|
- GGML_ASSERT(grammar.vocab != nullptr);
|
||||||
|
|
||||||
|
- const auto & piece = grammar.vocab->token_to_piece(token);
|
||||||
|
+ const std::string piece = grammar.o_vocab ?
|
||||||
|
+ grammar.o_vocab->token_to_piece(token) :
|
||||||
|
+ grammar.vocab->token_to_piece(token);
|
||||||
|
|
||||||
|
if (grammar.awaiting_trigger) {
|
||||||
|
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||||
|
@@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- if (grammar.vocab->is_eog(token)) {
|
||||||
|
+ const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token);
|
||||||
|
+ if (is_eog) {
|
||||||
|
for (const auto & stack : grammar.stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
- GGML_ABORT("fatal error");
|
||||||
|
+ GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_grammar_accept_str(grammar, piece);
|
||||||
|
@@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
||||||
|
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
+
|
||||||
|
+
|
||||||
|
+const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
|
||||||
|
+ try {
|
||||||
|
+ return token_to_piece_map.at(token);
|
||||||
|
+ } catch (const std::out_of_range&) {
|
||||||
|
+ throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token));
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
+void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) {
|
||||||
|
+ for (size_t i = 0; i < n_tokens; i++) {
|
||||||
|
+ token_to_piece_map[tokens[i]] = pieces[i];
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
+bool ollama_vocab::is_eog(const uint32_t token) const {
|
||||||
|
+ return special_eog_ids.count(token) > 0;
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
+void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) {
|
||||||
|
+ for (size_t i = 0; i < n_tokens; i++) {
|
||||||
|
+ special_eog_ids.insert(tokens[i]);
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
|
||||||
|
index f8c291de..2a3a62db 100644
|
||||||
|
--- a/src/llama-grammar.h
|
||||||
|
+++ b/src/llama-grammar.h
|
||||||
|
@@ -6,8 +6,19 @@
|
||||||
|
#include <regex>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
+#include <set>
|
||||||
|
|
||||||
|
struct llama_vocab;
|
||||||
|
+struct ollama_vocab {
|
||||||
|
+ std::map<uint32_t, std::string> token_to_piece_map;
|
||||||
|
+ std::set<uint32_t> special_eog_ids;
|
||||||
|
+
|
||||||
|
+ const std::string & token_to_piece(const uint32_t token) const;
|
||||||
|
+ void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces);
|
||||||
|
+ void set_eog_tokens(const uint32_t* tokens, size_t n_tokens);
|
||||||
|
+ bool is_eog(const uint32_t token) const;
|
||||||
|
+
|
||||||
|
+};
|
||||||
|
|
||||||
|
// grammar element type
|
||||||
|
enum llama_gretype {
|
||||||
|
@@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern {
|
||||||
|
struct llama_grammar {
|
||||||
|
// note: allow null vocab for testing (not great)
|
||||||
|
const llama_vocab * vocab;
|
||||||
|
+ const ollama_vocab * o_vocab;
|
||||||
|
|
||||||
|
const llama_grammar_rules rules; // TODO: shared ptr
|
||||||
|
llama_grammar_stacks stacks;
|
||||||
|
@@ -141,12 +153,14 @@ struct llama_grammar {
|
||||||
|
// note: needed for tests (not great)
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
+ const struct ollama_vocab * ollama_vocab,
|
||||||
|
const llama_grammar_element ** rules,
|
||||||
|
size_t n_rules,
|
||||||
|
size_t start_rule_index);
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
+ const struct ollama_vocab * ollama_vocab,
|
||||||
|
const char * grammar_str,
|
||||||
|
const char * grammar_root,
|
||||||
|
bool lazy,
|
||||||
|
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
|
||||||
|
index d1497985..b1a9dca3 100644
|
||||||
|
--- a/src/llama-sampling.cpp
|
||||||
|
+++ b/src/llama-sampling.cpp
|
||||||
|
@@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||||
|
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
- auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||||
|
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||||
|
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||||
|
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
||||||
|
|
||||||
|
@@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||||
|
/* .vocab = */ vocab,
|
||||||
|
/* .grammar_str = */ grammar_str,
|
||||||
|
/* .grammar_root = */ grammar_root,
|
||||||
|
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||||
|
+ /* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||||
|
};
|
||||||
|
if (!ctx->grammar) {
|
||||||
|
delete ctx;
|
47
llama/sampling_ext.cpp
vendored
47
llama/sampling_ext.cpp
vendored
|
@ -5,6 +5,7 @@
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "llama-model.h"
|
#include "llama-model.h"
|
||||||
#include "llama-model-loader.h"
|
#include "llama-model-loader.h"
|
||||||
|
#include "llama-grammar.h"
|
||||||
|
|
||||||
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
|
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
|
||||||
try {
|
try {
|
||||||
|
@ -86,3 +87,49 @@ struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
|
||||||
void llama_free_vocab(struct llama_vocab * vocab) {
|
void llama_free_vocab(struct llama_vocab * vocab) {
|
||||||
delete vocab;
|
delete vocab;
|
||||||
}
|
}
|
||||||
|
struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens) {
|
||||||
|
try {
|
||||||
|
if (grammar == nullptr) {
|
||||||
|
LLAMA_LOG_ERROR("%s: null grammar input\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ollama_vocab *vocab = new ollama_vocab();
|
||||||
|
vocab->set_eog_tokens(eog_tokens, n_eog_tokens);
|
||||||
|
vocab->add_token_pieces(tokens, n_tokens, pieces);
|
||||||
|
|
||||||
|
struct llama_grammar *g = llama_grammar_init_impl(nullptr, vocab, grammar, "root", false, nullptr, 0, nullptr, 0);
|
||||||
|
if (g == nullptr) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to initialize grammar\n", __func__);
|
||||||
|
delete vocab;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return g;
|
||||||
|
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LLAMA_LOG_ERROR("%s: exception during initialization: %s\n", __func__, e.what());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void grammar_free(struct llama_grammar *g) {
|
||||||
|
if (g != nullptr) {
|
||||||
|
if (g->vocab != nullptr) {
|
||||||
|
delete g->vocab;
|
||||||
|
}
|
||||||
|
llama_grammar_free_impl(g);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens) {
|
||||||
|
if (g == nullptr || tokens == nullptr) {
|
||||||
|
LLAMA_LOG_ERROR("%s: null grammar or tokens input\n", __func__);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
llama_grammar_apply_impl(*g, tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void grammar_accept(struct llama_grammar *g, llama_token id) {
|
||||||
|
llama_grammar_accept_impl(*g, id);
|
||||||
|
}
|
||||||
|
|
8
llama/sampling_ext.h
vendored
8
llama/sampling_ext.h
vendored
|
@ -35,8 +35,12 @@ extern "C"
|
||||||
|
|
||||||
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
|
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
|
||||||
|
|
||||||
struct llama_vocab * llama_load_vocab_from_file(const char * fname);
|
|
||||||
void llama_free_vocab(struct llama_vocab * vocab);
|
struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens);
|
||||||
|
void grammar_free(struct llama_grammar *g);
|
||||||
|
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens);
|
||||||
|
void grammar_accept(struct llama_grammar *g, llama_token id);
|
||||||
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,9 @@ type Model struct {
|
||||||
// Implement MultimodalProcessor interface
|
// Implement MultimodalProcessor interface
|
||||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||||
|
|
||||||
|
// Implement TextProcessor interface
|
||||||
|
var _ model.TextProcessor = (*Model)(nil)
|
||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
textModel, err := NewTextModel(c)
|
textModel, err := NewTextModel(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -32,6 +32,7 @@ type TextProcessor interface {
|
||||||
Encode(s string, addSpecial bool) ([]int32, error)
|
Encode(s string, addSpecial bool) ([]int32, error)
|
||||||
Decode([]int32) (string, error)
|
Decode([]int32) (string, error)
|
||||||
Is(int32, Special) bool
|
Is(int32, Special) bool
|
||||||
|
Vocabulary() *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vocabulary struct {
|
type Vocabulary struct {
|
||||||
|
@ -117,6 +118,8 @@ type BytePairEncoding struct {
|
||||||
vocab *Vocabulary
|
vocab *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||||
|
|
||||||
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
||||||
return BytePairEncoding{
|
return BytePairEncoding{
|
||||||
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
||||||
|
@ -124,6 +127,10 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||||
|
return bpe.vocab
|
||||||
|
}
|
||||||
|
|
||||||
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||||
return bpe.vocab.Is(id, special)
|
return bpe.vocab.Is(id, special)
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,10 @@ type SentencePieceModel struct {
|
||||||
|
|
||||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||||
|
|
||||||
|
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
||||||
|
return spm.vocab
|
||||||
|
}
|
||||||
|
|
||||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||||
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||||
|
|
||||||
|
|
|
@ -298,12 +298,6 @@ type Server struct {
|
||||||
// multimodalHash generates hashes for comparing equality
|
// multimodalHash generates hashes for comparing equality
|
||||||
// of non-text data
|
// of non-text data
|
||||||
multimodalHash maphash.Hash
|
multimodalHash maphash.Hash
|
||||||
|
|
||||||
// vocab is a llama.cpp vocab required for gammar-based
|
|
||||||
// constrained generation (json mode, structured outputs)
|
|
||||||
// TODO: this is temporary until Ollama sampling supports
|
|
||||||
// constrained generation
|
|
||||||
vocab *sample.Vocab
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) allNil() bool {
|
func (s *Server) allNil() bool {
|
||||||
|
@ -606,14 +600,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var grammar *sample.Grammar
|
var grammar *sample.GrammarSampler
|
||||||
var err error
|
var err error
|
||||||
if req.Grammar != "" {
|
if req.Grammar != "" {
|
||||||
grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
|
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer grammar.Free()
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := sample.NewSampler(
|
sampler := sample.NewSampler(
|
||||||
|
@ -789,8 +784,6 @@ func (s *Server) loadModel(
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.vocab = sample.NewVocab(mpath)
|
|
||||||
|
|
||||||
// TODO(jessegross): LoRA loading
|
// TODO(jessegross): LoRA loading
|
||||||
if lpath.String() != "" {
|
if lpath.String() != "" {
|
||||||
panic("loras are not yet implemented")
|
panic("loras are not yet implemented")
|
||||||
|
|
|
@ -5,9 +5,9 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// token represents information about a single token during sampling
|
// token represents information about a single token during sampling
|
||||||
|
@ -22,7 +22,7 @@ type Sampler struct {
|
||||||
topP float32
|
topP float32
|
||||||
minP float32
|
minP float32
|
||||||
temperature float32
|
temperature float32
|
||||||
grammar *Grammar
|
grammar *GrammarSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
|
@ -127,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
// PCG requires two parameters: sequence and stream
|
||||||
|
@ -164,63 +164,43 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Grammar struct {
|
type GrammarSampler struct {
|
||||||
vocab *Vocab
|
grammar *llama.Grammar
|
||||||
grammar string
|
|
||||||
sampler *llama.Sampler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
|
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
|
||||||
v, err := vocab.Load()
|
vocabIds := make([]uint32, len(model.Vocabulary().Values))
|
||||||
if err != nil {
|
pieces := make([]string, len(model.Vocabulary().Values))
|
||||||
return nil, err
|
for i := range model.Vocabulary().Values {
|
||||||
|
pieces[i], _ = model.Decode([]int32{int32(i)})
|
||||||
|
vocabIds[i] = uint32(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Grammar{
|
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
|
||||||
vocab: vocab,
|
if grammar == nil {
|
||||||
grammar: grammar,
|
return nil, errors.New("sample: failed to initialize grammar")
|
||||||
sampler: llama.NewGrammarSampler(v, grammar),
|
}
|
||||||
}, nil
|
|
||||||
|
return &GrammarSampler{grammar: grammar}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Grammar) Apply(tokens []token) {
|
func (g *GrammarSampler) Apply(tokens []token) {
|
||||||
tds := make([]llama.TokenData, len(tokens))
|
tds := make([]llama.TokenData, len(tokens))
|
||||||
for i, token := range tokens {
|
for i, token := range tokens {
|
||||||
tds[i].Id = token.id
|
tds[i].ID = token.id
|
||||||
tds[i].Logit = token.value
|
tds[i].Logit = token.value
|
||||||
}
|
}
|
||||||
|
g.grammar.Apply(tds)
|
||||||
g.sampler.Apply(tds)
|
|
||||||
|
|
||||||
for i := range tokens {
|
for i := range tokens {
|
||||||
tokens[i].value = tds[i].Logit
|
tokens[i].value = tds[i].Logit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Grammar) Accept(token int32) {
|
func (g *GrammarSampler) Accept(token int32) {
|
||||||
g.sampler.Accept(token)
|
g.grammar.Accept(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vocab struct {
|
func (g *GrammarSampler) Free() {
|
||||||
once sync.Once
|
g.grammar.Free()
|
||||||
vocab *llama.Vocab
|
|
||||||
err error
|
|
||||||
path string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewVocab(path string) *Vocab {
|
|
||||||
return &Vocab{path: path}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load returns the lazily-loaded vocabulary
|
|
||||||
func (v *Vocab) Load() (*llama.Vocab, error) {
|
|
||||||
v.once.Do(func() {
|
|
||||||
vocab, err := llama.LoadVocabFromFile(v.path)
|
|
||||||
if err != nil {
|
|
||||||
v.err = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
v.vocab = vocab
|
|
||||||
})
|
|
||||||
return v.vocab, v.err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWeighted(t *testing.T) {
|
func TestWeighted(t *testing.T) {
|
||||||
|
@ -55,6 +60,97 @@ func TestWeighted(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
vocab := make(map[string]int32)
|
||||||
|
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
types := make([]uint32, len(vocab))
|
||||||
|
tokens := make([]string, len(vocab))
|
||||||
|
for token, id := range vocab {
|
||||||
|
tokens[id] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
merges := make([]string, 0, 1)
|
||||||
|
// Only need vocab for Grammar Test
|
||||||
|
return model.NewBytePairEncoding(
|
||||||
|
``,
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: tokens,
|
||||||
|
Types: types,
|
||||||
|
Merges: merges,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGrammar(t *testing.T) {
|
||||||
|
tokenizer := modelHelper(t)
|
||||||
|
|
||||||
|
grammarJSON := `
|
||||||
|
root ::= object
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\\x7F\x00-\x1F] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
|
)* "\"" ws
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
|
ws ::= ([ \t\n] ws)?
|
||||||
|
`
|
||||||
|
grammar, err := NewGrammarSampler(tokenizer, grammarJSON)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer grammar.Free()
|
||||||
|
|
||||||
|
logits := make([]float32, len(tokenizer.Vocabulary().Values))
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = rand.Float32()
|
||||||
|
}
|
||||||
|
tokens := make([]token, len(logits))
|
||||||
|
for i := range tokens {
|
||||||
|
tokens[i].id = int32(i)
|
||||||
|
tokens[i].value = logits[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
grammar.Apply(tokens)
|
||||||
|
nonInfCount := 0
|
||||||
|
infCount := 0
|
||||||
|
for _, tok := range tokens {
|
||||||
|
if math.IsInf(float64(tok.value), -1) {
|
||||||
|
infCount++
|
||||||
|
} else {
|
||||||
|
nonInfCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nonInfCount == 0 {
|
||||||
|
t.Error("expected at least one non -inf token after grammar application, got none")
|
||||||
|
}
|
||||||
|
if infCount == 0 {
|
||||||
|
t.Error("expected some -inf tokens after grammar application, got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
samplers := map[string]Sampler{
|
samplers := map[string]Sampler{
|
||||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue