llama: remove model loading for grammar (#10096)

Parth Sareen 2025-04-24 11:51:19 -07:00 committed by GitHub
parent 40b10eee6d
commit a53d744b01
13 changed files with 521 additions and 107 deletions

View file

@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
@@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
@@ -1067,6 +1070,7 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@@ -1089,6 +1093,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
auto * result = new llama_grammar {
grammar.vocab,
grammar.o_vocab,
grammar.rules,
grammar.stacks,
grammar.partial_utf8,
@@ -1116,7 +1121,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
}
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
GGML_ASSERT(grammar.vocab != nullptr);
if (grammar.awaiting_trigger) {
return;
@@ -1138,9 +1142,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
for (size_t i = 0; i < cur_p->size; ++i) {
const llama_token id = cur_p->data[i].id;
const std::string & piece = grammar.vocab->token_to_piece(id);
const std::string piece = grammar.o_vocab ?
grammar.o_vocab->token_to_piece(id) :
grammar.vocab->token_to_piece(id);
if (grammar.vocab->is_eog(id)) {
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id);
if (is_eog) {
if (!allow_eog) {
cur_p->data[i].logit = -INFINITY;
}
@@ -1159,9 +1167,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
}
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
GGML_ASSERT(grammar.vocab != nullptr);
const auto & piece = grammar.vocab->token_to_piece(token);
const std::string piece = grammar.o_vocab ?
grammar.o_vocab->token_to_piece(token) :
grammar.vocab->token_to_piece(token);
if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
@@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
}
}
if (grammar.vocab->is_eog(token)) {
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token);
if (is_eog) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
}
}
GGML_ABORT("fatal error");
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
}
llama_grammar_accept_str(grammar, piece);
@@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}
}
const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
try {
return token_to_piece_map.at(token);
} catch (const std::out_of_range&) {
throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token));
}
}
void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) {
for (size_t i = 0; i < n_tokens; i++) {
token_to_piece_map[tokens[i]] = pieces[i];
}
}
bool ollama_vocab::is_eog(const uint32_t token) const {
return special_eog_ids.count(token) > 0;
}
void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) {
for (size_t i = 0; i < n_tokens; i++) {
special_eog_ids.insert(tokens[i]);
}
}
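
The ollama_vocab above is a minimal stand-in for the full llama_vocab: an id-to-piece map plus a set of end-of-generation ids, populated over cgo instead of by loading a GGUF just for its vocabulary. A minimal sketch (not from this commit) of the Go-side tables it is built from; the decode parameter is a hypothetical stand-in for the tokenizer:

// Sketch (assumed helper names): build the tables that ollama_vocab
// is populated from via add_token_pieces and set_eog_tokens.
func vocabTables(decode func(int32) string, vocabSize int, eosID, eotID uint32) (ids []uint32, pieces []string, eog []uint32) {
	ids = make([]uint32, vocabSize)
	pieces = make([]string, vocabSize)
	for i := range ids {
		ids[i] = uint32(i)
		pieces[i] = decode(int32(i)) // decoded text of token i
	}
	eog = []uint32{eosID, eotID} // ids ollama_vocab::is_eog reports true for
	return ids, pieces, eog
}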

View file

@@ -6,8 +6,19 @@
#include <regex>
#include <string>
#include <vector>
#include <set>
struct llama_vocab;
struct ollama_vocab {
std::map<uint32_t, std::string> token_to_piece_map;
std::set<uint32_t> special_eog_ids;
const std::string & token_to_piece(const uint32_t token) const;
void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces);
void set_eog_tokens(const uint32_t* tokens, size_t n_tokens);
bool is_eog(const uint32_t token) const;
};
// grammar element type
enum llama_gretype {
@@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern {
struct llama_grammar {
// note: allow null vocab for testing (not great)
const llama_vocab * vocab;
const ollama_vocab * o_vocab;
const llama_grammar_rules rules; // TODO: shared ptr
llama_grammar_stacks stacks;
@@ -141,12 +153,14 @@ struct llama_grammar {
// note: needed for tests (not great)
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,

View file

@@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
}
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
@@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
/* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
};
if (!ctx->grammar) {
delete ctx;

View file

@@ -35,6 +35,7 @@ import (
"runtime/cgo"
"slices"
"strings"
"sync"
"unsafe"
_ "github.com/ollama/ollama/llama/llama.cpp/common"
@@ -249,20 +250,6 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
return &m, nil
}
func LoadVocabFromFile(path string) (*Vocab, error) {
mp := C.CString(path)
defer C.free(unsafe.Pointer(mp))
v := Vocab{c: C.llama_load_vocab_from_file(mp)}
if v.c == nil {
return nil, fmt.Errorf("unable to load vocab: %s", path)
}
return &v, nil
}
func FreeVocab(vocab *Vocab) {
C.llama_free_vocab(vocab.c)
}
func FreeModel(model *Model) {
C.llama_model_free(model.c)
}
@@ -311,10 +298,6 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
return nil
}
type Vocab struct {
c *C.struct_llama_vocab
}
func (m *Model) Vocab() *C.struct_llama_vocab {
return C.llama_model_get_vocab(m.c)
}
@@ -692,35 +675,65 @@ func SchemaToGrammar(schema []byte) []byte {
return buf[:n]
}
type Sampler struct {
c *C.struct_llama_sampler
}
func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
cGrammar := C.CString(grammar)
cRoot := C.CString("root")
defer C.free(unsafe.Pointer(cGrammar))
defer C.free(unsafe.Pointer(cRoot))
sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
return sampler
}
func (s *Sampler) Accept(token int32) {
C.llama_sampler_accept(s.c, C.llama_token(token))
}
type TokenData struct {
Id int32
ID int32
Logit float32
}
func (s *Sampler) Apply(tokens []TokenData) {
type Grammar struct {
c *C.struct_llama_grammar
mu sync.Mutex
}
func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar {
cGrammar := C.CString(grammar)
defer C.free(unsafe.Pointer(cGrammar))
cTokens := make([]C.uint32_t, len(vocabIds))
for i, token := range vocabIds {
cTokens[i] = C.uint32_t(token)
}
cPieces := make([]*C.char, len(vocabValues))
for i, piece := range vocabValues {
cPieces[i] = C.CString(piece)
defer C.free(unsafe.Pointer(cPieces[i]))
}
cEogTokens := make([]C.uint32_t, len(eogTokens))
for i, token := range eogTokens {
cEogTokens[i] = C.uint32_t(token)
}
g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens)))
if g == nil {
return nil
}
return &Grammar{c: g}
}
func (g *Grammar) Free() {
g.mu.Lock()
defer g.mu.Unlock()
if g.c != nil {
C.grammar_free(g.c)
g.c = nil
}
}
func (g *Grammar) Apply(tokens []TokenData) {
g.mu.Lock()
defer g.mu.Unlock()
if g.c == nil {
return
}
tds := make([]C.struct_llama_token_data, len(tokens))
for i, token := range tokens {
tds[i] = C.struct_llama_token_data{
id: C.int32_t(token.Id),
id: C.int32_t(token.ID),
logit: C.float(token.Logit),
p: C.float(0.0),
}
@@ -731,13 +744,24 @@ func (s *Sampler) Apply(tokens []TokenData) {
selected: C.int64_t(-1),
sorted: C.bool(false),
}
var pinner runtime.Pinner
pinner.Pin(&tds[0])
defer pinner.Unpin()
C.llama_sampler_apply(s.c, tda)
C.grammar_apply(g.c, tda)
for i := range tokens {
tokens[i].Logit = float32(tds[i].logit)
}
}
func (g *Grammar) Accept(token int32) {
g.mu.Lock()
defer g.mu.Unlock()
// Check if grammar was freed
if g.c == nil {
return
}
C.grammar_accept(g.c, C.llama_token(token))
}
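
For orientation, a minimal sketch (not part of this diff) of one constrained sampling step with the Grammar wrapper above; the greedy argmax is an illustrative stand-in for the real sampler:

// sketch: apply the grammar mask to raw logits, pick a token greedily,
// and advance the grammar state with the choice.
func constrainedStep(g *Grammar, logits []float32) int32 {
	tds := make([]TokenData, len(logits))
	for i, l := range logits {
		tds[i] = TokenData{ID: int32(i), Logit: l}
	}
	g.Apply(tds) // grammar-forbidden tokens come back with -Inf logits
	best := 0
	for i := range tds {
		if tds[i].Logit > tds[best].Logit {
			best = i
		}
	}
	g.Accept(tds[best].ID) // advance the grammar state
	return tds[best].ID
}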

View file

@@ -0,0 +1,207 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: ParthSareen <parth.sareen@ollama.com>
Date: Mon, 21 Apr 2025 13:30:31 -0700
Subject: [PATCH] add ollama vocab for grammar support
---
src/llama-grammar.cpp | 49 ++++++++++++++++++++++++++++++++++++------
src/llama-grammar.h | 14 ++++++++++++
src/llama-sampling.cpp | 4 ++--
3 files changed, 58 insertions(+), 9 deletions(-)
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 973b47ae..60d58236 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
+ const struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
@@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
+ ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
+ const struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
@@ -1067,6 +1070,7 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
+ ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@@ -1089,6 +1093,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
auto * result = new llama_grammar {
grammar.vocab,
+ grammar.o_vocab,
grammar.rules,
grammar.stacks,
grammar.partial_utf8,
@@ -1116,7 +1121,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
}
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
- GGML_ASSERT(grammar.vocab != nullptr);
if (grammar.awaiting_trigger) {
return;
@@ -1138,9 +1142,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
for (size_t i = 0; i < cur_p->size; ++i) {
const llama_token id = cur_p->data[i].id;
- const std::string & piece = grammar.vocab->token_to_piece(id);
+ const std::string piece = grammar.o_vocab ?
+ grammar.o_vocab->token_to_piece(id) :
+ grammar.vocab->token_to_piece(id);
- if (grammar.vocab->is_eog(id)) {
+ const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id);
+
+ if (is_eog) {
if (!allow_eog) {
cur_p->data[i].logit = -INFINITY;
}
@@ -1159,9 +1167,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
}
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
- GGML_ASSERT(grammar.vocab != nullptr);
- const auto & piece = grammar.vocab->token_to_piece(token);
+ const std::string piece = grammar.o_vocab ?
+ grammar.o_vocab->token_to_piece(token) :
+ grammar.vocab->token_to_piece(token);
if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
@@ -1191,13 +1200,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
}
}
- if (grammar.vocab->is_eog(token)) {
+ const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token);
+ if (is_eog) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
}
}
- GGML_ABORT("fatal error");
+ GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
}
llama_grammar_accept_str(grammar, piece);
@@ -1217,3 +1227,28 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}
}
+
+
+const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
+ try {
+ return token_to_piece_map.at(token);
+ } catch (const std::out_of_range&) {
+ throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token));
+ }
+}
+
+void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) {
+ for (size_t i = 0; i < n_tokens; i++) {
+ token_to_piece_map[tokens[i]] = pieces[i];
+ }
+}
+
+bool ollama_vocab::is_eog(const uint32_t token) const {
+ return special_eog_ids.count(token) > 0;
+}
+
+void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) {
+ for (size_t i = 0; i < n_tokens; i++) {
+ special_eog_ids.insert(tokens[i]);
+ }
+}
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index f8c291de..2a3a62db 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -6,8 +6,19 @@
#include <regex>
#include <string>
#include <vector>
+#include <set>
struct llama_vocab;
+struct ollama_vocab {
+ std::map<uint32_t, std::string> token_to_piece_map;
+ std::set<uint32_t> special_eog_ids;
+
+ const std::string & token_to_piece(const uint32_t token) const;
+ void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces);
+ void set_eog_tokens(const uint32_t* tokens, size_t n_tokens);
+ bool is_eog(const uint32_t token) const;
+
+};
// grammar element type
enum llama_gretype {
@@ -114,6 +125,7 @@ struct llama_grammar_trigger_pattern {
struct llama_grammar {
// note: allow null vocab for testing (not great)
const llama_vocab * vocab;
+ const ollama_vocab * o_vocab;
const llama_grammar_rules rules; // TODO: shared ptr
llama_grammar_stacks stacks;
@@ -141,12 +153,14 @@ struct llama_grammar {
// note: needed for tests (not great)
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
+ const struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
+ const struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index d1497985..b1a9dca3 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1465,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
}
- auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
@@ -1547,7 +1547,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
+ /* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
};
if (!ctx->grammar) {
delete ctx;

View file

@@ -5,6 +5,7 @@
#include "llama.h"
#include "llama-model.h"
#include "llama-model-loader.h"
#include "llama-grammar.h"
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
try {
@@ -86,3 +87,49 @@ struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
void llama_free_vocab(struct llama_vocab * vocab) {
delete vocab;
}
struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens) {
try {
if (grammar == nullptr) {
LLAMA_LOG_ERROR("%s: null grammar input\n", __func__);
return nullptr;
}
ollama_vocab *vocab = new ollama_vocab();
vocab->set_eog_tokens(eog_tokens, n_eog_tokens);
vocab->add_token_pieces(tokens, n_tokens, pieces);
struct llama_grammar *g = llama_grammar_init_impl(nullptr, vocab, grammar, "root", false, nullptr, 0, nullptr, 0);
if (g == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize grammar\n", __func__);
delete vocab;
return nullptr;
}
return g;
} catch (const std::exception& e) {
LLAMA_LOG_ERROR("%s: exception during initialization: %s\n", __func__, e.what());
return nullptr;
}
}
void grammar_free(struct llama_grammar *g) {
if (g != nullptr) {
if (g->vocab != nullptr) {
delete g->vocab;
}
llama_grammar_free_impl(g);
}
}
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens) {
if (g == nullptr || tokens == nullptr) {
LLAMA_LOG_ERROR("%s: null grammar or tokens input\n", __func__);
return;
}
llama_grammar_apply_impl(*g, tokens);
}
void grammar_accept(struct llama_grammar *g, llama_token id) {
llama_grammar_accept_impl(*g, id);
}

View file

@@ -35,8 +35,12 @@ extern "C"
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
struct llama_vocab * llama_load_vocab_from_file(const char * fname);
void llama_free_vocab(struct llama_vocab * vocab);
struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens);
void grammar_free(struct llama_grammar *g);
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens);
void grammar_accept(struct llama_grammar *g, llama_token id);
#ifdef __cplusplus
}

View file

@@ -26,6 +26,9 @@ type Model struct {
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
textModel, err := NewTextModel(c)
if err != nil {

View file

@@ -32,6 +32,7 @@ type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
type Vocabulary struct {
@@ -117,6 +118,8 @@ type BytePairEncoding struct {
vocab *Vocabulary
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
return BytePairEncoding{
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
@@ -124,6 +127,10 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}

View file

@@ -17,6 +17,10 @@ type SentencePieceModel struct {
var _ TextProcessor = (*SentencePieceModel)(nil)
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])

View file

@@ -298,12 +298,6 @@ type Server struct {
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
// vocab is a llama.cpp vocab required for grammar-based
// constrained generation (json mode, structured outputs)
// TODO: this is temporary until Ollama sampling supports
// constrained generation
vocab *sample.Vocab
}
func (s *Server) allNil() bool {
@@ -606,14 +600,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
var grammar *sample.Grammar
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return
}
defer grammar.Free()
}
sampler := sample.NewSampler(
@@ -789,8 +784,6 @@ func (s *Server) loadModel(
panic(err)
}
s.vocab = sample.NewVocab(mpath)
// TODO(jessegross): LoRA loading
if lpath.String() != "" {
panic("loras are not yet implemented")

View file

@@ -5,9 +5,9 @@ import (
"math"
"math/rand/v2"
"slices"
"sync"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)
// token represents information about a single token during sampling
@@ -22,7 +22,7 @@ type Sampler struct {
topP float32
minP float32
temperature float32
grammar *Grammar
grammar *GrammarSampler
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
@@ -127,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@@ -164,63 +164,43 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
}
type Grammar struct {
vocab *Vocab
grammar string
sampler *llama.Sampler
type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
v, err := vocab.Load()
if err != nil {
return nil, err
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
return &Grammar{
vocab: vocab,
grammar: grammar,
sampler: llama.NewGrammarSampler(v, grammar),
}, nil
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}
return &GrammarSampler{grammar: grammar}, nil
}
func (g *Grammar) Apply(tokens []token) {
func (g *GrammarSampler) Apply(tokens []token) {
tds := make([]llama.TokenData, len(tokens))
for i, token := range tokens {
tds[i].Id = token.id
tds[i].ID = token.id
tds[i].Logit = token.value
}
g.sampler.Apply(tds)
g.grammar.Apply(tds)
for i := range tokens {
tokens[i].value = tds[i].Logit
}
}
func (g *Grammar) Accept(token int32) {
g.sampler.Accept(token)
func (g *GrammarSampler) Accept(token int32) {
g.grammar.Accept(token)
}
type Vocab struct {
once sync.Once
vocab *llama.Vocab
err error
path string
}
func NewVocab(path string) *Vocab {
return &Vocab{path: path}
}
// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
v.once.Do(func() {
vocab, err := llama.LoadVocabFromFile(v.path)
if err != nil {
v.err = err
return
}
v.vocab = vocab
})
return v.vocab, v.err
func (g *GrammarSampler) Free() {
g.grammar.Free()
}
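
Putting the pieces together, a minimal sketch (not from this commit) of wiring a GrammarSampler into NewSampler; the sampling parameter values are placeholders:

// sketch: build a grammar-constrained sampler from a TextProcessor.
func newConstrainedSampler(tp model.TextProcessor, grammarStr string) (Sampler, *GrammarSampler, error) {
	g, err := NewGrammarSampler(tp, grammarStr)
	if err != nil {
		return Sampler{}, nil, err
	}
	// temperature, topK, topP, minP, seed are placeholder values;
	// the caller must g.Free() once sampling is finished.
	return NewSampler(0.7, 40, 0.9, 0.05, -1, g), g, nil
}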

View file

@@ -1,9 +1,14 @@
package sample
import (
"encoding/json"
"math"
"math/rand/v2"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/model"
)
func TestWeighted(t *testing.T) {
@@ -55,6 +60,97 @@ func TestWeighted(t *testing.T) {
}
}
func modelHelper(t testing.TB) model.BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
t.Fatal(err)
}
types := make([]uint32, len(vocab))
tokens := make([]string, len(vocab))
for token, id := range vocab {
tokens[id] = token
}
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return model.NewBytePairEncoding(
``,
&model.Vocabulary{
Values: tokens,
Types: types,
Merges: merges,
},
)
}
func TestGrammar(t *testing.T) {
tokenizer := modelHelper(t)
grammarJSON := `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
grammar, err := NewGrammarSampler(tokenizer, grammarJSON)
if err != nil {
t.Fatal(err)
}
defer grammar.Free()
logits := make([]float32, len(tokenizer.Vocabulary().Values))
for i := range logits {
logits[i] = rand.Float32()
}
tokens := make([]token, len(logits))
for i := range tokens {
tokens[i].id = int32(i)
tokens[i].value = logits[i]
}
grammar.Apply(tokens)
nonInfCount := 0
infCount := 0
for _, tok := range tokens {
if math.IsInf(float64(tok.value), -1) {
infCount++
} else {
nonInfCount++
}
}
if nonInfCount == 0 {
t.Error("expected at least one non -inf token after grammar application, got none")
}
if infCount == 0 {
t.Error("expected some -inf tokens after grammar application, got none")
}
}
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy